In [None]:
import os
import subprocess

REPO_URL = "https://github.com/lorenzo-saccaro/NNDL-recoloring-GAN"
REPO_DIR = os.path.basename(REPO_URL)  # folder name that `git clone` creates

# Clone only when the repo is not already available, either as a sub-folder or
# because the working directory is the repo itself. Cloning into an existing
# folder fails, and cloning from inside the repo would create a nested copy.
already_cloned = os.path.isdir(REPO_DIR) or os.path.basename(os.getcwd()) == REPO_DIR
if not already_cloned:
    # check=True makes a failed clone (e.g. no network) stop the notebook here
    # instead of the exit status being silently ignored.
    subprocess.run(["git", "clone", REPO_URL], check=True)

In [None]:
import os

# KAGGLE_CONTAINER_NAME is used as the marker that we are running on Kaggle.
kaggle = 'KAGGLE_CONTAINER_NAME' in os.environ

if kaggle:
    # Work from inside the cloned repo so its .py modules (e.g. `dataset`) import.
    os.chdir('/kaggle/working/NNDL-recoloring-GAN')
    print(os.getcwd())
    # Pull the latest commits so the .py files are up to date.
    os.system('git pull')
    

In [None]:
from dataset import CocoDataset
from torchvision.transforms import Compose, ToTensor, Grayscale, Resize
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import tqdm
import matplotlib.pyplot as plt

# Dataset creation

In [None]:
# Location of the COCO 2017 dataset. The defaults cover the Kaggle dataset and
# the author's local machine; set the COCO_ROOT environment variable to point
# somewhere else without editing the notebook.
if kaggle:
    default_root = '/kaggle/input/coco-2017-dataset/coco2017'
else:
    default_root = 'C:\\Users\\loren\\Datasets\\coco2017'

DATASET_ROOT = os.environ.get('COCO_ROOT', default_root)

print(DATASET_ROOT)

## Define transformations to apply to each dataset input and output

In [None]:
# Both pipelines turn the image into a tensor and resize it to a fixed square;
# the input (x) pipeline additionally converts it to grayscale.
# TODO: tune the target size.
# TODO: consider other transformations / data augmentation. Random transforms
# must be applied identically to x and y, which probably requires rewriting
# the dataset class.
# NOTE(review): Resize runs on a tensor here; depending on the torchvision
# version it may not antialias by default (and may warn) — confirm.
transform_y_train = Compose([
    ToTensor(),
    Resize((256, 256)),
])

transform_x_train = Compose([
    ToTensor(),
    Resize((256, 256)),
    Grayscale(),
])

# TODO: decide whether validation/test need transforms different from training.
transform_x_val = transform_x_test = transform_x_train
transform_y_val = transform_y_test = transform_y_train

## Build the train / val / test datasets with the `CocoDataset` class

In [None]:
# One CocoDataset per split, each with its own (x, y) transform pair.
train_dataset = CocoDataset(
    dataset_folder=DATASET_ROOT,
    dataset_type='train',
    transform_x=transform_x_train,
    transform_y=transform_y_train,
)

val_dataset = CocoDataset(
    dataset_folder=DATASET_ROOT,
    dataset_type='val',
    transform_x=transform_x_val,
    transform_y=transform_y_val,
)

test_dataset = CocoDataset(
    dataset_folder=DATASET_ROOT,
    dataset_type='test',
    transform_x=transform_x_test,
    transform_y=transform_y_test,
)

# TODO: consider working in the Lab colour space and predicting only the two
# chrominance channels (a, b) as output.


## Define corresponding dataloaders

In [None]:
BATCH_SIZE = 32  # TODO: to be tuned

# os.cpu_count() returns None when the CPU count cannot be determined, but
# DataLoader expects an int; fall back to 0 (load in the main process).
NUM_WORKERS = os.cpu_count() or 0

# Only the training split is shuffled; val/test keep a fixed order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)



In [None]:
import itertools

# Smoke test of the data pipeline. Iterating over the whole training split just
# to check that loading works costs a full epoch of I/O, so only the first
# SMOKE_TEST_BATCHES batches are pulled. Set it to None to scan every batch
# (slow, but catches any single image that breaks loading or collation).
SMOKE_TEST_BATCHES = 1

if SMOKE_TEST_BATCHES is None:
    n_batches = len(train_dataloader)
else:
    n_batches = min(SMOKE_TEST_BATCHES, len(train_dataloader))

# islice(..., None) iterates the whole loader.
iterator = tqdm.tqdm(itertools.islice(train_dataloader, SMOKE_TEST_BATCHES), total=n_batches)
for batch_idx, (x_batch, y_batch) in enumerate(iterator):
    if batch_idx == 0:
        # Shapes of the first batch: input x, target y.
        print(x_batch.shape, y_batch.shape)


# Model Definition

### Generator

In [None]:
class ConvBlock(nn.Module):
    """Two convolutions, each followed by BatchNorm and ReLU.

    With the default kernel=3 / padding=1 the spatial size is preserved, so an
    input of shape (N, in_size, H, W) becomes (N, out_size, H, W).

    Args:
        in_size: number of input channels.
        out_size: number of output channels of both convolutions.
        kernel: kernel size of both convolutions.
        padding: zero padding of both convolutions.
    """

    def __init__(self, in_size, out_size, kernel = 3, padding = 1):
        super().__init__()

        self.conv_block = nn.Sequential(
            nn.Conv2d(in_size, out_size, kernel_size=kernel, padding=padding),
            nn.BatchNorm2d(out_size),
            # Non-linearity between the two convolutions: without it,
            # conv -> BN -> conv collapses into a single affine map.
            nn.ReLU(),

            nn.Conv2d(out_size, out_size, kernel_size=kernel, padding=padding),
            nn.BatchNorm2d(out_size),
            nn.ReLU(),
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule; intended for use with ``nn.Module.apply``.

        Conv2d: Xavier-uniform weights, zero bias.
        BatchNorm2d: weights drawn from N(1.0, 0.2), zero bias.
        """
        if isinstance(module, nn.Conv2d):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            # NOTE(review): DCGAN-style init uses std 0.02 here; 0.2 may be a typo — confirm.
            nn.init.normal_(module.weight, 1.0, 0.2)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Apply the two conv-BN-ReLU stages to ``x``."""
        return self.conv_block(x)
        

In [None]:
# Sanity check of a freshly initialised ConvBlock on random noise: the output
# should keep the 256x256 spatial size with 64 channels. Plot the input next to
# one output channel.
torch.manual_seed(0)  # reproducible weights and noise

c_b = ConvBlock(1, 64)

foo = torch.normal(0, 1, [1, 1, 256, 256])
print(foo.shape)
with torch.no_grad():  # inference only: no autograd graph needed
    result = c_b(foo)
print(result.shape)

CHANNEL_TO_SHOW = 1
result = result[0, CHANNEL_TO_SHOW]
print(result.shape)

fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 5))
ax_in.imshow(foo.squeeze(), cmap='gray')
ax_in.set_title('Input noise')
ax_out.imshow(result, cmap='gray')
ax_out.set_title(f'ConvBlock output, channel {CHANNEL_TO_SHOW}')
plt.show()

In [None]:
class EncoderBlock(nn.Module):
    """U-Net encoder stage: a ConvBlock followed by max-pooling.

    ``forward`` returns both the pre-pooling features (the skip connection for
    the matching decoder stage) and the pooled features (input to the next
    encoder stage).

    Args:
        in_size: number of input channels.
        out_size: number of channels produced by the ConvBlock.
        pool_size: kernel size of the max-pooling.
    """

    def __init__(self, in_size, out_size, pool_size = (2,2)):
        super().__init__()

        # ConvBlock initialises its own weights in its constructor.
        self.conv_block = ConvBlock(in_size, out_size)
        self.pool = nn.MaxPool2d(pool_size)

    def forward(self, x):
        """Return ``(skip, pooled)`` for input ``x``."""
        skip = self.conv_block(x)  # kept for the decoder's skip connection
        pooled = self.pool(skip)

        return skip, pooled

In [None]:
class DecoderBlock(nn.Module):
    """U-Net decoder stage: up-sample, concatenate with the skip, ConvBlock.

    With the defaults (kernel=2, stride=2, padding=0) the transposed
    convolution maps H x W to exactly 2H x 2W. That size matches the skip
    tensor of an encoder stage that pooled by 2, so the two can be
    concatenated along the channel axis.

    Args:
        in_size: channels of the incoming lower-resolution features.
        out_size: channels produced by the up-sampling and by the block.
            The skip tensor is assumed to have ``out_size`` channels too.
        kernel: kernel size of the transposed convolution.
        padding: padding of the transposed convolution.
        stride: stride of the transposed convolution (the up-sampling factor).
    """

    def __init__(self, in_size, out_size, kernel = 2, padding = 0, stride = 2):
        super().__init__()

        self.t_conv = nn.ConvTranspose2d(in_size, out_size, kernel_size=kernel,
                                         stride=stride, padding=padding)
        # The ConvBlock sees [skip, upsampled] concatenated: 2 * out_size channels.
        # It initialises its own weights in its constructor.
        self.conv_block = ConvBlock(2*out_size, out_size)

        # Only ConvTranspose2d is handled by _init_weights, so this does not
        # re-initialise the ConvBlock's layers.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and zero bias for ConvTranspose2d submodules."""
        if isinstance(module, nn.ConvTranspose2d):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x, skip):
        """Up-sample ``x``, concatenate it after ``skip`` on dim 1, refine it."""
        y = self.t_conv(x)
        y = torch.cat([skip, y], dim=1)
        y = self.conv_block(y)

        return y

In [None]:
class UNet(nn.Module):
    """U-Net generator.

    The network has four EncoderBlocks (64-128-256-512 channels), a
    1024-channel ConvBlock bottleneck, and four DecoderBlocks that consume the
    encoders' skip features in reverse order. A 1x1 convolution head produces
    the output channels.

    Args:
        in_size: number of input channels (e.g. 1 for a grayscale image).
        out_size: number of output channels; defaults to 3 (RGB). Use 2 to
            predict only the a/b channels of a Lab image.

    NOTE(review): assumes each DecoderBlock up-samples by exactly 2, so H and W
    must be divisible by 16 (four 2x poolings) for the skip tensors to line up.
    """

    def __init__(self, in_size, out_size=3):
        super().__init__()

        self.e1 = EncoderBlock(in_size, 64)
        self.e2 = EncoderBlock(64, 128)
        self.e3 = EncoderBlock(128, 256)
        self.e4 = EncoderBlock(256, 512)

        self.bottleneck = ConvBlock(512, 1024)

        self.d4 = DecoderBlock(1024, 512)
        self.d3 = DecoderBlock(512, 256)
        self.d2 = DecoderBlock(256, 128)
        self.d1 = DecoderBlock(128, 64)

        self.last = nn.Conv2d(64, out_size, kernel_size=1, padding=0)

        # The sub-blocks initialise themselves in their constructors; give the
        # 1x1 head the same Xavier / zero-bias scheme.
        nn.init.xavier_uniform_(self.last.weight)
        nn.init.zeros_(self.last.bias)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, project to ``out_size`` channels."""
        # Each encoder returns (skip, pooled): the pooled map feeds the next
        # stage and the un-pooled skip goes to the matching decoder.
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b = self.bottleneck(p4)

        d4 = self.d4(b, s4)
        d3 = self.d3(d4, s3)
        d2 = self.d2(d3, s2)
        d1 = self.d1(d2, s1)

        return self.last(d1)
        

In [None]:
#class UNet(nn.Module):
#    def __init__(self, in_size, out_size):
#        
#        super.__init__()
#        
#        self.unet = nn.Sequential(
#            EncoderBlock(in_size, 64),
#            EncoderBlock(64, 128),
#            EncoderBlock(128, 256),
#            EncoderBlock(256,512),
#        
#            ConvBlock(512,1024),
#        
#            DecoderBlock(1024,512),
#            DecoderBlock(512,256),
#            DecoderBlock(256,128),
#            DecoderBlock(128,64),
#        
#            nn.Conv2d(64, 2, kernel_size=1, padding=0)
#        )
#        
#    def forward(self,data_input):
#        self.unet(data_input)
#        

### Inspect a sample, then initialize the generator's weights (an untrained generator should only produce a noisy image)

In [None]:
# Look at one test sample. Presumably (x, y) = (grayscale input, colour target);
# only the colour image is shown (channels moved last for imshow).
sample_idx = 10
image, colored = test_dataset[sample_idx]
plt.imshow(colored.permute(1, 2, 0))
print(image.shape, colored.shape)

In [None]:
def weights_init(module):
    """Re-initialise one submodule; intended for use with ``nn.Module.apply``.

    Conv2d / ConvTranspose2d: Xavier-uniform weights, zero bias.
    BatchNorm2d: weights drawn from N(1.0, 0.2), zero bias (the same scheme
    ConvBlock uses).
    """
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, 1.0, 0.2)
        nn.init.zeros_(module.bias)


# The generator consumes the images built by transform_x_*, which end with
# Grayscale() (single output channel by default), so in_size is 1.
generator = UNet(1)
generator.apply(weights_init)