Uses the Cityscapes dataset and trains directly on the damaged images.

Begins with a U-Net architecture; other architectures will be explored as well.

In [None]:
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch
import torchvision
from glob import glob
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as transform
from torch.utils.data import DataLoader,Dataset

import sys
sys.path.append("/Users/nathanieljames/Desktop/dl_final_proj/dl_final_proj")
from utils.cityscapes import CityscapesDataset, get_loaders

#dataset: https://www.kaggle.com/datasets/dansbecker/cityscapes-image-pairs/data
# change to germany se

In [None]:
# Build the training and validation loaders from the project helper.
# NOTE(review): the meaning of subclass="5" is defined in utils.cityscapes,
# not here — confirm against get_loaders.
dataloader, valloader = get_loaders(
    batch_size=32,
    subclass="5",
)

# Report the length of each loader (presumably the number of batches per epoch).
print(len(dataloader), len(valloader))

In [None]:
# Default floating-point type used by the notebook.
dtype = torch.float

# Pick the best available accelerator instead of hard-coding "mps": Apple's
# Metal backend is still preferred when present, but the notebook now also
# runs on CUDA machines (e.g. Colab) and on CPU-only hosts.
# The getattr guard keeps this working on torch builds without the MPS backend.
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class Convblock(nn.Module):
    """Two stacked 2-D convolutions, each followed by an in-place ReLU.

    Only the first convolution is followed by batch normalisation and only
    the first one receives ``stride`` and ``padding``; the second uses
    ``nn.Conv2d`` defaults (stride 1, no padding). With the default
    ``kernel=3, stride=1, padding=2`` the first convolution grows each
    spatial dimension by 2 and the second shrinks it by 2 again.

    Args:
        input_channel: Channels of the incoming feature map.
        output_channel: Channels produced by both convolutions.
        kernel: Kernel size shared by both convolutions.
        stride: Stride of the first convolution.
        padding: Zero padding of the first convolution.
    """

    def __init__(self, input_channel, output_channel, kernel=3, stride=1, padding=2):
        super().__init__()

        # Layers are created in the order they run so that parameter
        # initialisation and state_dict indices stay stable.
        layers = [
            nn.Conv2d(
                input_channel,
                output_channel,
                kernel_size=kernel,
                stride=stride,
                padding=padding,
            ),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channel, output_channel, kernel_size=kernel),
            nn.ReLU(inplace=True),
        ]
        self.convblock = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv/BN/ReLU/conv/ReLU stack."""
        return self.convblock(x)

In [None]:
class DirectUNet(nn.Module):
    """U-Net style encoder/decoder producing a 3-channel output.

    The encoder applies four ``Convblock`` stages (32, 64, 128 and 256
    channels), each followed by 2x2 max pooling, and a 512-channel
    bottleneck convolution. The decoder mirrors this with four transposed
    convolutions; after each one the matching encoder feature map is
    centre-cropped to the upsampled size, concatenated along the channel
    axis and passed through another ``Convblock``.

    Args:
        input_channel: Number of channels of the input images.
        retain: When truthy, the network output is resized with
            ``F.interpolate`` (default mode) to the input's spatial size.
    """

    def __init__(self,input_channel,retain=True):

        super().__init__()

        # Encoder stages. Stages 3 and 4 use 1x1 kernels without padding.
        self.conv1 = Convblock(input_channel,32)
        self.conv2 = Convblock(32,64)
        self.conv3 = Convblock(64,128, kernel = 1, padding = 0)
        self.conv4 = Convblock(128,256, kernel = 1, padding = 0)

        # Bottleneck: unpadded 3x3 convolution (shrinks each spatial dim by 2).
        self.neck = nn.Conv2d(256,512,kernel_size=3,stride=1)

        # Decoder stages. Every Convblock takes twice the channels of its
        # transposed convolution because of the concatenated skip connection.
        self.upconv4 = nn.ConvTranspose2d(512,256,kernel_size=3,stride=2,padding=0,output_padding=1)
        self.dconv4 = Convblock(512,256)
        self.upconv3 = nn.ConvTranspose2d(256,128,kernel_size=3,stride=2,padding=0,output_padding=1)
        self.dconv3 = Convblock(256,128)
        self.upconv2 = nn.ConvTranspose2d(128,64,kernel_size=3,stride=2,padding=0,output_padding=1)
        self.dconv2 = Convblock(128,64)
        self.upconv1 = nn.ConvTranspose2d(64,32,kernel_size=3,stride=2,padding=0,output_padding=1)
        self.dconv1 = Convblock(64,32)

        # 1x1 projection to 3 channels; padding=1 grows each spatial dim by 2.
        self.out = nn.Conv2d(32,3,kernel_size=1,stride=1,padding=1)
        self.retain = retain

    def forward(self,x):
        """Run the encoder, bottleneck and decoder on ``x``.

        Args:
            x: Input batch; the indexing below assumes a 4-D tensor whose
                last two dimensions are spatial.

        Returns:
            The 3-channel output, resized to ``x``'s spatial size when
            ``self.retain`` is truthy.
        """
        # Encoder: keep the pre-pooling activations for the skip connections.
        conv1 = self.conv1(x)
        conv2 = self.conv2(F.max_pool2d(conv1,kernel_size=2,stride=2))
        conv3 = self.conv3(F.max_pool2d(conv2,kernel_size=2,stride=2))
        conv4 = self.conv4(F.max_pool2d(conv3,kernel_size=2,stride=2))

        # Bottleneck
        neck = self.neck(F.max_pool2d(conv4,kernel_size=2,stride=2))

        # Decoder: upsample, crop the matching encoder map to the upsampled
        # size, concatenate on the channel axis, then convolve.
        upconv4 = self.upconv4(neck)
        dconv4 = self.dconv4(torch.cat([upconv4,self.crop(conv4,upconv4)],1))

        upconv3 = self.upconv3(dconv4)
        dconv3 = self.dconv3(torch.cat([upconv3,self.crop(conv3,upconv3)],1))

        upconv2 = self.upconv2(dconv3)
        dconv2 = self.dconv2(torch.cat([upconv2,self.crop(conv2,upconv2)],1))

        upconv1 = self.upconv1(dconv2)
        dconv1 = self.dconv1(torch.cat([upconv1,self.crop(conv1,upconv1)],1))

        # Output layer
        out = self.out(dconv1)

        if self.retain:
            out = F.interpolate(out,list(x.shape)[2:])

        return out

    def crop(self,input_tensor,target_tensor):
        """Centre-crop ``input_tensor`` to the spatial size of ``target_tensor``.

        Implemented with plain slicing so no transform object is built on
        every forward pass. The offsets use ``int(round(diff / 2))``, the
        same centring rule as ``torchvision.transforms.CenterCrop``; an
        input smaller than the target in some dimension is first
        zero-padded symmetrically, as that transform does.

        Args:
            input_tensor: Tensor whose last two dimensions are cropped.
            target_tensor: 4-D tensor providing the target height and width.

        Returns:
            A tensor with ``input_tensor``'s leading dimensions and the
            target's height and width.
        """
        _,_,H,W = target_tensor.shape
        in_h, in_w = input_tensor.shape[-2:]

        # Zero-pad when the input is smaller than the requested size.
        pad_h = max(H - in_h, 0)
        pad_w = max(W - in_w, 0)
        if pad_h or pad_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            input_tensor = F.pad(
                input_tensor,
                [pad_w // 2, (pad_w + 1) // 2, pad_h // 2, (pad_h + 1) // 2],
            )
            in_h, in_w = input_tensor.shape[-2:]

        top = int(round((in_h - H) / 2.0))
        left = int(round((in_w - W) / 2.0))
        return input_tensor[..., top:top + H, left:left + W]

In [None]:
from torchsummary import summary

# Instantiate the network for 3-channel input in float32 and print a layer
# summary for a (3, 256, 256) input before moving the model to `device`.
model = DirectUNet(input_channel=3).float()
summary(model, (3, 256, 256))
model = model.to(device)

# Training configuration.
epochs = 50
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.015)

# Per-epoch history buffers.
train_acc, val_acc = [], []
train_loss, val_loss = [], []

In [None]:
# Train for `epochs` epochs, recording the mean training and validation loss
# per epoch and saving the weights whenever the validation loss improves.
best_loss = float('inf')
for i in range(epochs):

    trainloss = 0
    valloss = 0

    # Training pass. train() puts BatchNorm layers into batch-statistics
    # mode; it must be re-enabled every epoch because validation below
    # switches the model to eval mode.
    model.train()
    for img,label in tqdm(dataloader):
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device)
        output = model(img)
        loss = loss_func(output,label)
        loss.backward()
        optimizer.step()
        trainloss+=loss.item()

    train_loss.append(trainloss/len(dataloader))

    # Validation pass. eval() makes BatchNorm use (and stop updating) its
    # running statistics, and no_grad() avoids building autograd graphs that
    # are never back-propagated.
    model.eval()
    with torch.no_grad():
        for img,label in tqdm(valloader):
            img = img.to(device)
            label = label.to(device)
            output = model(img)
            loss = loss_func(output,label)
            valloss+=loss.item()

    epoch_loss = valloss/len(valloader)
    val_loss.append(epoch_loss)
    print("epoch : {} ,train loss : {} ,valid loss : {} ".format(i,train_loss[-1],val_loss[-1]))

    # Checkpoint only when the validation loss reaches a new minimum.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        print(f"new best loss: {best_loss}")
        print("saving")
        torch.save(model.state_dict(), '/content/drive/MyDrive/direct_unet_2.pth')