Use cityscapes dataset, trains directly on damaged images.

Beginning with U-net architecture, will explore other as well.

In [None]:
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch
import torchvision
from glob import glob
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as transform
from torch.utils.data import DataLoader,Dataset

#dataset: https://www.kaggle.com/datasets/dansbecker/cityscapes-image-pairs/data

class CityscapesDataset(Dataset):
    def __init__(self, image_dir, cut_half = True, transform = None):
        self.image_dir = image_dir
        self.imgs = os.listdir(image_dir)

        self.cut_half = cut_half
        self.transforms = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_mask = Image.open(os.path.join(self.image_dir, self.imgs[idx]))
        if self.cut_half:
            x_width, y_height = img_mask.size
            split = x_width / 2

            img = img_mask.crop((0, 0, split, y_height))

            mask = img_mask.crop((split, 0, split + split, y_height))

            if self.transforms:
                img = self.transforms(img)
                mask = self.transforms(mask)

            return img, mask

        return img_mask


In [3]:
transform_init = transform.Compose([
    transform.ToTensor()
])

# NEED TO SWITCH THIS
# SO THAT THEY ARE ALL EITHER SNOWY, RAINY, etc.

dataset = CityscapesDataset(image_dir='/Users/nathanieljames/Desktop/direct/cityscapes_data/train', cut_half=True, transform=transform_init)
val_dataset = CityscapesDataset(image_dir='/Users/nathanieljames/Desktop/direct/cityscapes_data/val', cut_half=True, transform=transform_init)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
valloader = DataLoader(val_dataset)
print(len(dataset), len(dataloader), len(valloader))

2975 93 500


In [4]:
dtype = torch.float
device = torch.device("mps")

In [5]:
class Convblock(nn.Module):

      def __init__(self,input_channel,output_channel,kernel=3,stride=1,padding=1):

        super().__init__()
        self.convblock = nn.Sequential(
            nn.Conv2d(input_channel,output_channel,kernel,stride,padding),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channel,output_channel,kernel),
            nn.ReLU(inplace=True),
        )

      def forward(self,x):
        x = self.convblock(x)
        return x

In [6]:
class DirectUNet(nn.Module):

    def __init__(self,input_channel,retain=True):

        super().__init__()

        self.conv1 = Convblock(input_channel,32)
        self.conv2 = Convblock(32,64)
        self.conv3 = Convblock(64,128)
        # next 4 lines are bottom layer
        #self.conv4 = Convblock(128,256)
        #self.neck = nn.Conv2d(256,512,3,1)
        #self.upconv4 = nn.ConvTranspose2d(512,256,3,2,0,1)
        #self.dconv4 = Convblock(512,256)
        # begin replace here
        self.neck = nn.Conv2d(128, 256, 3, 1)
        self.upconv3 = nn.ConvTranspose2d(256,128,3,2,0,1)
        self.dconv3 = Convblock(256,128)
        self.upconv2 = nn.ConvTranspose2d(128,64,3,2,0,1)
        self.dconv2 = Convblock(128,64)
        self.upconv1 = nn.ConvTranspose2d(64,32,3,2,0,1)
        self.dconv1 = Convblock(64,32)
        self.out = nn.Conv2d(32,3,1,1)
        self.retain = retain

    def forward(self,x):

        # Encoder Network

        # Conv down 1
        conv1 = self.conv1(x)
        pool1 = F.max_pool2d(conv1,kernel_size=2,stride=2)
        # Conv down 2
        conv2 = self.conv2(pool1)
        pool2 = F.max_pool2d(conv2,kernel_size=2,stride=2)
        # Conv down 3
        conv3 = self.conv3(pool2)
        pool3 = F.max_pool2d(conv3,kernel_size=2,stride=2)
        # Conv down 4 (removed)
        # conv4 = self.conv4(pool3)
        # pool4 = F.max_pool2d(conv4,kernel_size=2,stride=2)

        # BottelNeck
        neck = self.neck(pool3) #cb to pool 4 if layer 4 left in

        # Decoder Network

        # Upconv 1 again removed
        # upconv4 = self.upconv4(neck)
        # croped = self.crop(conv4,upconv4)
        # # Making the skip connection 1
        # dconv4 = self.dconv4(torch.cat([upconv4,croped],1))
        # Upconv 2
        upconv3 = self.upconv3(neck) # replace to dconv 4 if unskip
        croped = self.crop(conv3,upconv3)
        # Making the skip connection 2
        dconv3 = self.dconv3(torch.cat([upconv3,croped],1))
        # Upconv 3
        upconv2 = self.upconv2(dconv3)
        croped = self.crop(conv2,upconv2)
        # Making the skip connection 3
        dconv2 = self.dconv2(torch.cat([upconv2,croped],1))
        # Upconv 4
        upconv1 = self.upconv1(dconv2)
        croped = self.crop(conv1,upconv1)
        # Making the skip connection 4
        dconv1 = self.dconv1(torch.cat([upconv1,croped],1))
        # Output Layer
        out = self.out(dconv1)

        if self.retain == True:
            out = F.interpolate(out,list(x.shape)[2:])

        return out

    def crop(self,input_tensor,target_tensor):
        # Crops for border kernels
        _,_,H,W = target_tensor.shape
        return transform.CenterCrop([H,W])(input_tensor)

In [9]:
model = DirectUNet(3).float()

from torchsummary import summary
summary(model, (3,256,256))
model = model.to(device)

epochs = 5

loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

train_acc = []
val_acc = []
train_loss = []
val_loss = []

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 256, 256]             896
       BatchNorm2d-2         [-1, 32, 256, 256]              64
              ReLU-3         [-1, 32, 256, 256]               0
            Conv2d-4         [-1, 32, 254, 254]           9,248
              ReLU-5         [-1, 32, 254, 254]               0
         Convblock-6         [-1, 32, 254, 254]               0
            Conv2d-7         [-1, 64, 127, 127]          18,496
       BatchNorm2d-8         [-1, 64, 127, 127]             128
              ReLU-9         [-1, 64, 127, 127]               0
           Conv2d-10         [-1, 64, 125, 125]          36,928
             ReLU-11         [-1, 64, 125, 125]               0
        Convblock-12         [-1, 64, 125, 125]               0
           Conv2d-13          [-1, 128, 62, 62]          73,856
      BatchNorm2d-14          [-1, 128,

In [10]:
for i in range(epochs):

    trainloss = 0
    valloss = 0

    for img,label in tqdm(dataloader):
        #print("new image")
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device)
        output = model(img)
        loss = loss_func(output,label)
        loss.backward()
        optimizer.step()
        trainloss+=loss.item()

    train_loss.append(trainloss/len(dataloader))

    for img,label in tqdm(valloader):
        img = img.to(device)
        label = label.to(device)
        output = model(img)
        loss = loss_func(output,label)
        valloss+=loss.item()

    val_loss.append(valloss/len(valloader))

    print("epoch : {} ,train loss : {} ,valid loss : {} ".format(i,train_loss[-1],val_loss[-1]))

100%|██████████| 93/93 [1:24:23<00:00, 54.45s/it]
100%|██████████| 500/500 [00:19<00:00, 25.79it/s]


epoch : 0 ,train loss : 0.076130265309926 ,valid loss : 0.04493915801867843 


100%|██████████| 93/93 [3:19:26<00:00, 128.67s/it]  
100%|██████████| 500/500 [02:30<00:00,  3.33it/s]


epoch : 1 ,train loss : 0.04552027643207581 ,valid loss : 0.04342434840649367 


100%|██████████| 93/93 [2:40:01<00:00, 103.24s/it]  
100%|██████████| 500/500 [00:13<00:00, 36.29it/s]


epoch : 2 ,train loss : 0.040475767426272874 ,valid loss : 0.03571418998017907 


100%|██████████| 93/93 [2:36:51<00:00, 101.20s/it]  
100%|██████████| 500/500 [00:11<00:00, 43.24it/s]


epoch : 3 ,train loss : 0.035446759393458725 ,valid loss : 0.034184001591056584 


100%|██████████| 93/93 [2:23:38<00:00, 92.67s/it]   
100%|██████████| 500/500 [00:15<00:00, 32.31it/s]


epoch : 4 ,train loss : 0.03387145539845831 ,valid loss : 0.03324828418344259 


In [11]:
#save model
torch.save(model.state_dict(), '/Users/nathanieljames/Desktop/direct/models/direct_unet.pth')