In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF

In [8]:
class DoubleConv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(DoubleConv, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, 1, 1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

  def forward(self, x):
    return self.conv(x)

In [15]:
class UNET(nn.Module):
  def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
    super(UNET, self).__init__()
    self.up_sampling = nn.ModuleList()
    self.down_sampling = nn.ModuleList()
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # Down the U network
    for feature in features:
      self.down_sampling.append(DoubleConv(in_channels, feature))
      in_channels = feature

    # Up the U network
    for feature in reversed(features):
      self.up_sampling.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
      self.up_sampling.append(DoubleConv(feature * 2, feature))

    self.bottle_neck = DoubleConv(features[-1], features[-1] * 2)
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self, x):
    skip_connections = []
    for down_sample in self.down_sampling:
      x = down_sample(x)
      skip_connections.append(x)
      x = self.max_pool(x)

    x = self.bottle_neck(x)
    skip_connections = skip_connections[::-1]

    for i in range(0, len(self.up_sampling), 2):
      x = self.up_sampling[i](x)
      skip_connection = skip_connections[i//2]

      if x.shape != skip_connection.shape:
        x = TF.resize(x, size=skip_connection.shape[2:])

      concat_skip = torch.cat((skip_connection, x), dim=1)
      x = self.up_sampling[i + 1](concat_skip)

    return self.final_conv(x)

In [16]:
x = torch.randn((3, 1, 164, 164))
model = UNET(in_channels=1, out_channels=1)
print(x)

tensor([[[[-1.0468, -0.0166,  0.0301,  ...,  1.6915,  0.5362, -1.6300],
          [ 0.9801, -0.3030,  0.6119,  ...,  0.9953,  0.8321, -0.6704],
          [-0.0898,  0.1024,  2.3816,  ...,  0.0412, -0.6030, -0.6178],
          ...,
          [-0.0446, -0.5421, -0.9536,  ...,  2.4646, -0.6498, -0.7474],
          [-0.2652,  0.1830, -2.3945,  ...,  0.1613, -1.1950, -1.2325],
          [ 0.6813, -0.5447, -0.8322,  ...,  0.2796, -0.8650, -1.0577]]],


        [[[ 1.0751,  0.0246, -0.5136,  ...,  0.4104,  0.2865, -0.2959],
          [ 0.7612, -0.4717,  0.6451,  ...,  0.3040, -0.0145, -0.9153],
          [-1.0468, -0.4714,  0.6018,  ...,  0.0491,  1.7783,  0.4017],
          ...,
          [ 1.5515, -2.0278, -1.8483,  ...,  0.4439,  0.5086, -0.5200],
          [-0.5622,  0.4031, -0.9494,  ..., -1.5475,  2.1311, -1.2994],
          [ 1.0823,  0.2312,  0.7589,  ..., -0.4231,  1.1247, -0.7772]]],


        [[[ 1.3303,  0.5727, -0.0660,  ...,  0.2148, -1.4475, -1.1001],
          [-0.7953, -1.050

In [17]:
preds = model(x)

print(preds)



tensor([[[[-7.8615e-01, -9.2691e-01, -6.0194e-01,  ..., -6.1102e-01,
           -5.9466e-01, -4.2123e-01],
          [-3.3110e-01,  5.0412e-02, -8.5804e-02,  ..., -3.5462e-01,
           -4.5124e-01, -2.9182e-01],
          [-6.5551e-02, -4.8520e-01, -1.3581e+00,  ..., -1.8820e-01,
           -1.5520e-01,  3.1639e-03],
          ...,
          [-5.7518e-01,  1.7181e-01, -9.2582e-02,  ..., -6.8580e-02,
           -1.9454e-01, -4.0080e-01],
          [-9.4960e-01, -3.5209e-01, -3.9878e-01,  ..., -3.0820e-01,
            1.5029e-01, -3.1823e-01],
          [-7.8362e-01, -9.5320e-01, -3.1767e-01,  ..., -3.7102e-01,
           -6.0932e-01, -8.3625e-01]]],


        [[[-5.1371e-01, -7.1526e-01, -5.1797e-01,  ..., -9.2383e-01,
           -8.4946e-01, -3.7277e-01],
          [-8.2932e-01, -1.7809e-01, -3.8386e-01,  ..., -6.7520e-01,
           -5.2506e-01, -2.0573e-01],
          [-9.3634e-02, -9.0712e-01, -3.1235e-01,  ..., -5.0498e-01,
           -5.9980e-01, -4.7795e-01],
          ...,
   

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNET(in_channels=1, out_channels=1).to(device)

criteria = nn.CrossEntropyLoss()
optimizer = optim.Adam(unet.parameters(), lr=0.001)

In [19]:
from torch.utils.data import Dataset, DataLoader

class CustomImageDataset(Dataset):
  def __init__(self, num_samples, image_size):
    self.num_samples = num_samples
    self.image_size = image_size

  def __len__(self):
    return self.num_samples

  def __getitem__(self, index):
    image = torch.randn(1, self.image_size, self.image_size)
    target_mask = torch.randint(0, 2, (1, self.image_size, self.image_size)).float()
    return image, target_mask

In [20]:
dataset = CustomImageDataset(1000, 256)
batch_size = 6
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Training Loop
num_epochs = 5

for epoch in range(num_epochs):
  for batch in dataloader:
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = unet(inputs.to(device))
    loss = criteria(outputs, targets.to(device))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch}, Loss: {loss}")

Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Los