In [1]:
import torch.nn as nn
import torchvision
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from random import choices, sample
import numpy as np
import random
from sklearn.model_selection import train_test_split
import warnings

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42
warnings.filterwarnings("ignore")

In [3]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [4]:
mock_images = torch.randint(0, 256, (1, 1, 572, 572)).float().to(device)

In [5]:
class UNet(nn.Module):

  def __init__(self):
    super().__init__()

    self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
    self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3)
    self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3)
    self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3)
    self.conv7 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3)
    self.conv8 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3)
    self.conv9 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3)
    self.conv10 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3)
    self.conv11 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
    self.conv12 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3)
    self.conv13 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3)
    self.conv14 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
    self.conv15 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3)
    self.conv16 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3)
    self.conv17 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
    self.conv18 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3)
    self.conv19 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3)
    self.conv20 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
    self.conv21 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3)
    self.conv22 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
    self.conv23 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

    self.crop1 = torchvision.transforms.CenterCrop((392, 392))
    self.crop2 = torchvision.transforms.CenterCrop((200, 200))
    self.crop3 = torchvision.transforms.CenterCrop((104, 104))
    self.crop4 = torchvision.transforms.CenterCrop((56, 56))

    self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.relu = nn.ReLU()

  def forward(self, x):
    x = self.relu(self.conv1(x))
    x = self.relu(self.conv2(x))
    x1 = self.crop1(x.clone().detach())
    x = self.maxpool(x)
    x = self.relu(self.conv3(x))
    x = self.relu(self.conv4(x))
    x2 = self.crop2(x.clone().detach())
    x = self.maxpool(x)
    x = self.relu(self.conv5(x))
    x = self.relu(self.conv6(x))
    x3 = self.crop3(x.clone().detach())
    x = self.maxpool(x)
    x = self.relu(self.conv7(x))
    x = self.relu(self.conv8(x))
    x4 = self.crop4(x.clone().detach())
    x = self.maxpool(x)
    x = self.relu(self.conv9(x))
    x = self.relu(self.conv10(x))
    x = self.conv11(x)
    x = torch.cat((x4, x), dim=1)
    x = self.relu(self.conv12(x))
    x = self.relu(self.conv13(x))
    x = self.conv14(x)
    x = torch.cat((x3, x), dim=1)
    x = self.relu(self.conv15(x))
    x = self.relu(self.conv16(x))
    x = self.conv17(x)
    x = torch.cat((x2, x), dim=1)
    x = self.relu(self.conv18(x))
    x = self.relu(self.conv19(x))
    x = self.conv20(x)
    x = torch.cat((x1, x), dim=1)
    x = self.relu(self.conv21(x))
    x = self.relu(self.conv22(x))
    x = self.conv23(x)
    return x


In [6]:
model = UNet().to(device)
model(mock_images)

tensor([[[[-0.9295, -0.4955, -0.0597,  ..., -0.8960, -1.2050, -0.2833],
          [-1.6364, -0.7513, -1.0377,  ..., -0.8439, -0.6263, -0.7552],
          [-0.7889, -1.4408, -1.2966,  ..., -0.5833,  0.4957, -0.1332],
          ...,
          [-0.5248, -1.1461, -0.9973,  ..., -1.0610, -0.5198,  0.3151],
          [-0.9917, -0.6281, -0.9026,  ..., -0.4739, -1.1499,  0.7924],
          [-0.6772,  0.2353, -0.7214,  ...,  0.2731, -1.5637, -0.1155]],

         [[-0.6106, -0.0313, -0.7558,  ...,  0.3762,  0.6754,  1.4451],
          [ 0.0461,  1.0008,  0.5537,  ..., -0.1377,  0.7356,  0.2764],
          [-0.4387,  0.1490,  0.1790,  ...,  0.7754,  0.3303,  0.2420],
          ...,
          [-1.0367, -0.9383,  0.0205,  ..., -0.0858,  0.8842,  0.7873],
          [ 0.5883,  0.4170,  0.8936,  ...,  0.4267,  0.2625,  0.5653],
          [ 0.3385, -0.5467, -0.1472,  ..., -0.2908,  0.0461,  0.8120]]]],
       device='cuda:0', grad_fn=<ConvolutionBackward0>)