In [3]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from torchvision import datasets, transforms
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import glob
import random
import cv2
import imutils
# %run data.ipynb

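# Basic U-Net for image colorization

A U-Net that maps a single-channel (grayscale) image to a 2-channel, per-pixel colour output.

- Architecture based on: https://www.mdpi.com/2073-8994/14/11/2295
- PyTorch syntax reference: https://github.com/richzhang/colorization/blob/master/colorizers/eccv16.py
- PyTorch syntax reference: https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/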

In [4]:
class CUNet(nn.Module):
    """U-Net style network mapping a 1-channel image to a 2-channel per-pixel output in [-1, 1].

    Expects input of shape (N, 1, H, W) and returns (N, 2, H, W).

    Encoder: pairs of 3x3 convs at 64, 128 and 256 channels, each pair followed
    by 2x2 average pooling, then a 512-channel bottleneck (two 3x3 convs, four
    3x3 convs with dilation 2, two more 3x3 convs). Every conv preserves H and
    W; only the pooling changes them.

    Decoder: three 4x4 stride-2 transposed convs, each doubling H and W, then
    concatenation with the matching encoder output (skip connection) and a
    3x3 conv. A final 3x3 conv + Tanh produces the 2 output channels.

    Every hidden layer is Conv -> Sigmoid -> normalization. Attribute names and
    Sequential indices (0 = conv, 1 = Sigmoid, 2 = norm) are fixed, so
    state_dict keys such as ``conv1_1.0.weight`` are stable.

    Args:
        batch: normalization layer factory, called as ``batch(num_channels)``
            after every hidden conv. Defaults to ``nn.BatchNorm2d``.
    """

    def __init__(self, batch = nn.BatchNorm2d):
        super().__init__()

        # 2x2 average pooling: halves H and W (floor(dim / 2) for odd sizes).
        self.down = torch.nn.AvgPool2d(2)

        # Encoder (a self.down is applied between stages in forward()).
        self.conv1_1 = self._conv(1, 64, batch)
        self.conv1_2 = self._conv(64, 64, batch)
        self.conv2_1 = self._conv(64, 128, batch)
        self.conv2_2 = self._conv(128, 128, batch)
        self.conv3_1 = self._conv(128, 256, batch)
        self.conv3_2 = self._conv(256, 256, batch)
        self.conv4_1 = self._conv(256, 512, batch)
        self.conv4_2 = self._conv(512, 512, batch)

        # Bottleneck: dilated convs enlarge the receptive field without pooling.
        self.extconv5_1 = self._conv(512, 512, batch, dilation = 2)
        self.extconv5_2 = self._conv(512, 512, batch, dilation = 2)
        self.extconv6_1 = self._conv(512, 512, batch, dilation = 2)
        self.extconv6_2 = self._conv(512, 512, batch, dilation = 2)
        self.conv7_1 = self._conv(512, 512, batch)
        self.conv7_2 = self._conv(512, 512, batch)

        # Decoder: each deconv halves channels and doubles H, W. The conv after
        # it takes twice as many channels because the skip tensor is
        # concatenated onto the deconv output first.
        self.deconv8_1 = self._deconv(512, 256, batch)
        self.conv8_2 = self._conv(512, 256, batch)
        self.deconv9_1 = self._deconv(256, 128, batch)
        self.conv9_2 = self._conv(256, 128, batch)
        self.deconv10_1 = self._deconv(128, 64, batch)
        self.conv10_2 = self._conv(128, 64, batch)

        # Final per-pixel output layer: 2 channels squashed to [-1, 1] by Tanh.
        self.conv11 = nn.Sequential(
            nn.Conv2d(64, 2, kernel_size = 3, stride = 1, padding = 1),
            nn.Tanh()
        )

    @staticmethod
    def _conv(in_channels: int, out_channels: int, norm, dilation: int = 1) -> nn.Sequential:
        """3x3 conv -> Sigmoid -> norm(out_channels); padding == dilation keeps H and W.

        NOTE(review): Sigmoid saturates for large |x|, which can slow gradient
        flow through a stack this deep; consider ReLU. Kept to preserve behaviour.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1,
                      padding = dilation, dilation = dilation),
            nn.Sigmoid(),
            norm(out_channels),
        )

    @staticmethod
    def _deconv(in_channels: int, out_channels: int, norm) -> nn.Sequential:
        """4x4 stride-2 transposed conv -> Sigmoid -> norm; exactly doubles H and W.

        Output size = (in - 1) * 2 - 2 * 1 + 4 = 2 * in.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 4, stride = 2, padding = 1),
            nn.Sigmoid(),
            norm(out_channels),
        )

    @staticmethod
    def _match_size(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` to ``ref``'s spatial (H, W) size if they differ.

        When H or W is not divisible by 8, floor-pooling drops a row/column
        that the stride-2 deconv cannot restore, so the skip tensors would not
        line up for ``torch.cat``. No-op when sizes already match.
        """
        if x.shape[-2:] != ref.shape[-2:]:
            x = nn.functional.interpolate(x, size = tuple(ref.shape[-2:]))
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder; returns (N, 2, H, W) in [-1, 1]."""
        # Encoder: keep each stage's output for the skip connections.
        x = self.conv1_1(x)
        x = self.conv1_2(x)
        conv1_res = x
        x = self.down(x)
        x = self.conv2_1(x)
        x = self.conv2_2(x)
        conv2_res = x
        x = self.down(x)
        x = self.conv3_1(x)
        x = self.conv3_2(x)
        conv3_res = x
        x = self.down(x)

        # Bottleneck.
        x = self.conv4_1(x)
        x = self.conv4_2(x)
        x = self.extconv5_1(x)
        x = self.extconv5_2(x)
        x = self.extconv6_1(x)
        x = self.extconv6_2(x)
        x = self.conv7_1(x)
        x = self.conv7_2(x)

        # Decoder: upsample, align to the skip tensor, concatenate, convolve.
        x = self.deconv8_1(x)
        x = torch.cat([self._match_size(x, conv3_res), conv3_res], dim = 1)
        x = self.conv8_2(x)
        x = self.deconv9_1(x)
        x = torch.cat([self._match_size(x, conv2_res), conv2_res], dim = 1)
        x = self.conv9_2(x)
        x = self.deconv10_1(x)
        x = torch.cat([self._match_size(x, conv1_res), conv1_res], dim = 1)
        x = self.conv10_2(x)
        return self.conv11(x)
    

SyntaxError: incomplete input (2804254556.py, line 66)

In [33]:
# Testing how dimensions work: AvgPool2d(2) halves H and W, and the decoder's
# ConvTranspose2d(kernel_size=4, stride=2, padding=1) doubles them again:
#   out = (in - 1) * 2 - 2 * 1 + 4 = 2 * in
# NOTE: ceil_mode=True here differs from the model's AvgPool2d(2) (floor);
# the two only disagree for odd sizes.
down = torch.nn.AvgPool2d(2, ceil_mode = True)
deconv = nn.ConvTranspose2d(1, 1, kernel_size = 4, stride = 2, padding = 1)
w = 256
h = 256
# (N, C, H, W) tensor of ones, built directly instead of via a nested list
# comprehension (which produced (1, 1, w, h) and shadowed the builtin `range`).
t1 = torch.ones(1, 1, h, w)
print(t1.size())
t1 = down(t1)
print(t1.size())
t1 = deconv(t1)
print(t1.size())

torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 128, 128])
torch.Size([1, 1, 256, 256])


In [46]:
# Smoke-test the model's forward pass on a (N, C, H, W) = (1, 1, 256, 256)
# input; CUNet's last layer has 2 output channels and every conv/deconv pair
# restores the input H and W, so the result should be (1, 2, 256, 256).
w = 256
h = 256
t1 = torch.ones(1, 1, h, w)
model = CUNet()
with torch.no_grad():  # shape check only: no autograd graph needed
    res = model(t1)  # call the module (not .forward()) so hooks run
print(t1.size())
print(res.size())
assert res.shape == (1, 2, h, w), f"unexpected output shape {tuple(res.shape)}"

TypeError: _BatchNorm.__init__() missing 1 required positional argument: 'num_features'