In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF

In [2]:
class DoubleConv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(DoubleConv, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, 1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, 1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

  def forward(self, x):
    return self.conv(x)

In [3]:
class UNET(nn.Module):
  def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
    super(UNET, self).__init__()
    self.up_sampling = nn.ModuleList()
    self.down_sampling = nn.ModuleList()
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # Down the U network
    for feature in features:
      self.down_sampling.append(DoubleConv(in_channels, out_channels))
      in_channels = feature

    # Up the U network
    for feature in reversed(features):
      self.up_sampling.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
      self.up_sampling.append(DoubleConv(feature * 2, out_channels))

    self.bottle_neck = DoubleConv(features[-1], features[-1] * 2)
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self, x):
    skip_connections = []
    for down_sample in self.down_sampling:
      x = down_sample(x)
      skip_connections.append(x)
      x = self.max_pool(x)

    x = self.bottle_neck(x)
    skip_connections = skip_connections[::-1]

    for i in range(0, len(self.up_sampling), 2):
      x = self.up_sampling[i](x)
      skip_connection = skip_connections[i//2]

      if x.shape != skip_connection.shape():
        x = TF.resize(x, size=skip_connection.shape[2:])

      concat_skip = torch.cat((skip_connection, x), dim=1)
      x = self.up_sampling[i + 1](concat_skip)

    return self.final_conv(x)

In [4]:
x = torch.randn((3, 1, 164, 164))
model = UNET(in_channels=1, out_channels=1)
print(x)

tensor([[[[ 0.4745,  0.3990,  0.4629,  ..., -0.4358, -2.7471,  1.2854],
          [-0.8002, -1.3700,  0.1394,  ..., -1.0008, -0.1001,  0.3647],
          [ 0.4885, -0.2957,  0.3457,  ..., -0.8069,  0.0458, -1.3301],
          ...,
          [-0.5858, -0.7971,  0.6741,  ...,  1.6228, -0.3367,  2.7377],
          [-0.6080, -1.0116,  0.5573,  ...,  0.4941,  1.8874,  1.1634],
          [-1.2275,  0.4125, -0.2102,  ...,  0.1442,  0.3863, -0.5045]]],


        [[[ 0.5432,  1.7162, -1.2448,  ..., -1.7358,  0.5377,  0.3850],
          [ 0.2976, -0.4702,  0.3024,  ..., -0.3201,  1.1905, -0.1385],
          [ 0.3186, -0.6002, -0.4442,  ...,  2.2087, -0.7032,  0.0036],
          ...,
          [-2.1623,  0.3663, -0.7531,  ...,  0.3838, -0.3361, -1.0889],
          [ 0.4942, -0.4374, -0.7796,  ..., -1.2864,  1.8577, -0.3994],
          [-0.2362, -1.0003, -0.6744,  ...,  1.7848,  0.4088, -0.0438]]],


        [[[-0.1844,  0.4704, -1.1642,  ...,  0.0059,  0.5202, -1.2484],
          [ 0.6056,  0.077

In [5]:
preds = model(x)

print(preds)

RuntimeError: Given groups=1, weight of size [1, 64, 3, 3], expected input[3, 1, 80, 80] to have 64 channels, but got 1 channels instead

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNET(in_channels=1, out_channels=1).to(device)

criteria = nn.CrossEntropyLoss()
optimizer = optim.Adam(unet.parameters(), lr=0.001)

In [7]:
from torch.utils.data import Dataset, DataLoader

class CustomImageDataset(Dataset):
  def __init__(self, num_samples, image_size):
    self.num_samples = num_samples
    self.image_size = image_size

  def __len__(self):
    return self.num_samples

  def __getitem__(self, index):
    image = torch.randn(1, self.image_size, self.image_size)
    target_mask = torch.randint(0, 2, (1, self.image_size, self.image_size)).float()
    return image, target_mask

In [8]:
dataset = CustomImageDataset(1000, 256)
batch_size = 6
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [9]:
# Training Loop
num_epochs = 5

for epoch in range(num_epochs):
  for batch in dataloader:
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = unet(inputs.to(device))
    loss = criteria(outputs, targets.to(device))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch}, Loss: {loss}")

RuntimeError: Given groups=1, weight of size [1, 64, 3, 3], expected input[6, 1, 126, 126] to have 64 channels, but got 1 channels instead