<a href="https://colab.research.google.com/github/mgp87/Jupyter_Notebooks_Collection/blob/main/UnetCourse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF

In [None]:
from torch.nn.modules.batchnorm import BatchNorm2d
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use kernel 3, stride 1 and padding 1, so the spatial
    size of the input is preserved; only the channel count changes.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def stage(c_in, c_out):
            # One conv -> BN -> ReLU triple (kernel 3, stride 1, padding 1).
            return [
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # The attribute name `conv` is kept so state_dict keys stay "conv.<i>.*".
        self.conv = nn.Sequential(
            *stage(in_channels, out_channels),
            *stage(out_channels, out_channels),
        )

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.conv(x)

In [None]:
class UNET(nn.Module):
    """U-Net: a convolutional encoder-decoder with skip connections.

    The encoder applies one ``DoubleConv`` per entry of ``features``, each
    followed by 2x2 max-pooling. A bottleneck ``DoubleConv`` doubles the
    deepest width. The decoder mirrors the encoder: a 2x2 transposed
    convolution upsamples, the matching encoder output (skip connection) is
    concatenated on the channel axis, and a ``DoubleConv`` fuses them. A final
    1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map (raw scores, no activation).
        features: encoder widths from shallowest to deepest.

    The output has the same spatial size as the input. When a spatial side is
    not divisible by 2**len(features), pooling floors the size; each upsampled
    map is then bilinearly resized to match its skip connection.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default (was a shared list); any sequence is accepted.
        features = tuple(features)

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level, each consuming the previous width.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder, deepest level first. Stored interleaved:
        # ups[2k] is the upsampler and ups[2k+1] the DoubleConv. The DoubleConv
        # takes feature*2 channels because it sees [skip, upsampled].
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the U-Net on ``x`` and return raw per-pixel scores.

        Args:
            x: input batch; assumed (N, in_channels, H, W) since every layer
                is a 2D conv/pool — TODO confirm callers never pass other ranks.

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.maxpool(x)

        x = self.bottleneck(x)

        upsamplers = self.ups[0::2]
        decoders = self.ups[1::2]
        for upsample, decode, skip in zip(
            upsamplers, decoders, reversed(skip_connections)
        ):
            x = upsample(x)
            # Odd input sizes are floored by pooling, so the upsampled map can
            # be one pixel short; resize it to the skip's spatial size.
            if x.shape[2:] != skip.shape[2:]:
                x = nn.functional.interpolate(
                    x,
                    size=tuple(skip.shape[2:]),
                    mode="bilinear",
                    align_corners=False,
                )
            x = decode(torch.cat((skip, x), dim=1))

        return self.final_conv(x)




In [None]:
# Dummy batch: 3 single-channel 161x161 images. 161 is odd, so pooling floors
# the size and UNET.forward has to resize before concatenating skips.
x = torch.randn(size=(3, 1, 161, 161))
model = UNET(in_channels=1, out_channels=1)

In [None]:

# Shape sanity check for a throwaway forward pass. eval() stops BatchNorm from
# updating its running statistics; no_grad() skips building an autograd graph.
model.eval()
with torch.no_grad():
    preds = model(x)

print(preds.shape)
assert preds.shape[0] == x.shape[0] and preds.shape[2:] == x.shape[2:], (
    f"expected spatial size {tuple(x.shape[2:])}, got {tuple(preds.shape[2:])}"
)



tensor([[[[-0.1661, -0.0652,  0.0616,  ..., -0.3198,  0.0018,  0.1467],
          [-0.2329, -0.5740, -0.1873,  ..., -0.1648,  0.0733, -0.3024],
          [-0.2453, -0.0542, -0.1204,  ..., -0.0448,  0.2311,  0.1178],
          ...,
          [-0.0850,  0.2718,  0.2988,  ..., -0.1926, -0.3610,  0.0313],
          [-0.2913, -0.8278,  0.3615,  ..., -0.2830, -0.1674,  0.3928],
          [-0.5364, -0.4217, -0.2075,  ...,  0.1660, -0.4051,  0.1708]]],


        [[[ 0.0293, -0.1051, -0.1352,  ..., -0.0688, -0.3108, -0.0438],
          [-0.1427, -0.4800, -0.0430,  ..., -0.3914,  0.2262,  0.2094],
          [-0.2194,  0.3042, -0.6773,  ..., -0.1752, -0.5149,  0.0470],
          ...,
          [-0.3282, -0.1372,  0.1623,  ...,  0.0692, -0.2013,  0.0843],
          [-0.2625, -0.0244,  0.1406,  ..., -0.2039,  0.0562, -0.1420],
          [-0.3425, -0.2763,  0.0009,  ..., -0.8312, -0.4066,  0.2560]]],


        [[[-0.1290, -0.1345,  0.1322,  ..., -0.3665, -0.0730,  0.1422],
          [-0.5223, -0.416

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both names refer to the same network on `device`.
unet = model = UNET(in_channels=1, out_channels=1).to(device)

# Binary segmentation with ONE output channel of raw logits -> BCEWithLogitsLoss.
# CrossEntropyLoss softmaxes over the channel axis; with a single channel the
# softmax is always 1, its log always 0, so the loss is identically 0 (the
# "Loss: -0.0" log) and no gradient ever reaches the network.
criterion = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(unet.parameters(), lr=0.001)

In [None]:
from torch.utils.data import Dataset, DataLoader

class CustomImageDataset(Dataset):
    """Synthetic segmentation dataset of random images and random binary masks.

    Nothing is cached: every ``__getitem__`` call draws fresh random tensors,
    so the same index yields different samples on different calls.

    Args:
        num_samples: value reported by ``len(dataset)``.
        image_size: side length of the square image and mask.
    """

    def __init__(self, num_samples, image_size):
        self.num_samples = num_samples
        self.image_size = image_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, target_mask)``; ``idx`` is ignored.

        image: float tensor of shape (1, image_size, image_size), standard normal.
        target_mask: float tensor of the same shape with values in {0.0, 1.0}.
        """
        shape = (1, self.image_size, self.image_size)
        image = torch.randn(shape)
        target_mask = torch.randint(low=0, high=2, size=shape).float()
        return image, target_mask


In [None]:
# 1000 random 1x256x256 image/mask pairs, served in shuffled batches of 6
# (the last batch holds the remaining 4 samples).
dataset = CustomImageDataset(num_samples=1000, image_size=256)

batch_size = 6
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [None]:
num_epochs = 5

for epoch in range(num_epochs):
    # Explicitly enable training behaviour (BatchNorm batch statistics) in case
    # the network was switched to eval mode somewhere earlier.
    unet.train()
    epoch_loss = 0.0
    num_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = unet(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Assuming criterion uses reduction="mean": weight by batch size so the
        # smaller final batch doesn't skew the epoch average.
        batch_n = inputs.size(0)
        epoch_loss += loss.item() * batch_n
        num_seen += batch_n

    # One line per epoch instead of one per batch.
    print(f"Epoch: {epoch}, Loss: {epoch_loss / max(num_seen, 1)}")




Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Loss: -0.0
Epoch: 0, Los