In [None]:
import torch
import numpy as np
import os
from torch.utils.data import Dataset
import cv2 as cv
from torchsummary import summary
import pandas as pd
import psutil
import matplotlib.pyplot as plt
from datetime import datetime

In [None]:
# Train on the first GPU when CUDA is available; otherwise fall back to the CPU
# so the notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Limit PyTorch's intra-op CPU thread pool to a fixed size.
NUM_THREADS = 4
torch.set_num_threads(NUM_THREADS)

In [None]:
# Dataset locations (relative to the notebook's working directory) and the
# hyper-parameters shared by the cells below.
TRAIN_IMG_PATH = 'train/image'
TRAIN_LABEL_PATH = 'train/label'
VAL_IMG_PATH = 'val/image'
VAL_LABEL_PATH = 'val/label'
BATCHSIZE = 32
IMAGESIZE = 128  # images and labels are resized to IMAGESIZE x IMAGESIZE

# Seed the global RNGs (weight initialisation, DataLoader shuffling) so a
# Restart-&-Run-All reproduces the same run as far as those RNGs go.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
#https://www.kaggle.com/code/ivanshingel/cars-segmentation-research
class CityscapesDataset(Dataset):
    """Segmentation dataset of ``.npy`` image/label pairs.

    Every ``<id>.npy`` file in the image directory is paired with
    ``<int(id)>_label.npy`` in the label directory. Both arrays are resized to
    IMAGESIZE x IMAGESIZE with nearest-neighbour interpolation.

    Args:
        transforms: stored on the instance but not applied anywhere yet.
        train: read the TRAIN_* directories when True, the VAL_* ones otherwise.

    Each item is ``(image, label)``:
        image: float32 tensor (C, IMAGESIZE, IMAGESIZE); the stored array is
            assumed to be H x W x C and is permuted channels-first.
        label: float32 tensor (1, IMAGESIZE, IMAGESIZE) holding the stored
            label values + 1.
    """
    def __init__(self, transforms= None, train= True):
        self.train = train
        self.transforms = transforms  # NOTE(review): kept for API compatibility; not applied
        if train:
            self.images_path = TRAIN_IMG_PATH
            self.labels_path = TRAIN_LABEL_PATH
        else:
            self.images_path = VAL_IMG_PATH
            self.labels_path = VAL_LABEL_PATH

        # List the directory once. The original re-listed it (twice) on every
        # __getitem__ call, which made an epoch O(n^2) in directory reads.
        # Sorting gives a deterministic index -> file mapping, because
        # os.listdir order is unspecified. Only .npy files are kept, so stray
        # files (e.g. .DS_Store) cannot break the int() parse below.
        self.image_files = sorted(f for f in os.listdir(self.images_path) if f.endswith('.npy'))
        self.len = len(self.image_files)

    def __getitem__(self, index):
        file_name = self.image_files[index]
        # '<id>.npy' -> int id; the label file name is built from the int form.
        naming_label = int(file_name.split('.')[0])

        image = np.asarray(np.load(os.path.join(self.images_path, file_name)), dtype= np.float32)
        image = cv.resize(image, (IMAGESIZE, IMAGESIZE), interpolation = cv.INTER_NEAREST)
        image = torch.from_numpy(image).permute(2, 0, 1)  # H x W x C -> C x H x W

        label = np.asarray(np.load(os.path.join(self.labels_path, f'{naming_label}_label.npy')))
        # Nearest-neighbour keeps label values discrete (no blending between classes).
        label = cv.resize(label, (IMAGESIZE, IMAGESIZE), interpolation = cv.INTER_NEAREST)
        label = torch.as_tensor(label, dtype=torch.float32).reshape(1, IMAGESIZE, IMAGESIZE)
        # NOTE(review): presumably shifts a -1 "void" id to 0 so values index the
        # model's 20 output classes — confirm against the label encoding.
        label = label + 1

        return image, label

    def __len__(self):
        return self.len

In [None]:
def show_img_and_mask(img, label, epoch = None):
    """Plot an image next to a mask, side by side.

    Args:
        img: (C, H, W) tensor, drawn channel-last after scaling by IMAGESIZE.
        label: (C, H, W) mask tensor, e.g. the (1, H, W) labels produced by
            CityscapesDataset or a (1, H, W) predicted class map.
        epoch: when given, the figure is saved to ``pic/epoch_<epoch>.png`` and
            closed; when ``None`` it is left open for inline display.
    """
    fig = plt.figure(figsize=(10, 7))

    fig.add_subplot(1, 2, 1)
    # detach()/cpu() lets tensors that require grad or live on the GPU be handed
    # to matplotlib; the inline branch previously did neither.
    img_np = img.detach().cpu().permute(1, 2, 0).numpy()
    # NOTE(review): scaling by IMAGESIZE (128) looks like it was meant to be a
    # pixel-range factor (e.g. 255); kept unchanged — confirm against the data range.
    plt.imshow(img_np * IMAGESIZE)

    fig.add_subplot(1, 2, 2)
    plt.imshow(label.detach().cpu().permute(1, 2, 0).numpy())

    if epoch is not None:
        os.makedirs('pic', exist_ok=True)  # savefig does not create directories
        plt.savefig('pic/epoch_'+str(epoch)+'.png')
        plt.close(fig)  # avoid accumulating open figures when called in a loop

In [None]:
# One dataset per split: train=True reads TRAIN_* paths, train=False reads VAL_* paths.
training_set, validation_set = (CityscapesDataset(train=is_train) for is_train in (True, False))

In [None]:
#https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
# Batch both splits with the same size; only the training split is shuffled.
training_loader, validation_loader = (
    torch.utils.data.DataLoader(split, batch_size=BATCHSIZE, shuffle=is_train)
    for split, is_train in ((training_set, True), (validation_set, False))
)

In [None]:
class ConvLayer(torch.nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages (one U-Net "stair step").

    padding=1 keeps height and width unchanged; only the channel count changes,
    from ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                torch.nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                torch.nn.BatchNorm2d(out_channels),
                torch.nn.ReLU(inplace=True),
            ])
        # Attribute name and Sequential indices 0..5 determine the state_dict
        # keys, so saved checkpoints keep loading.
        self.convlayer = torch.nn.Sequential(*stages)

    def forward(self, x):
        return self.convlayer(x)

class Down(torch.nn.Module):
    """Encoder step: a ConvLayer followed by 2x2, stride-2 max-pooling.

    ``forward`` returns ``(pooled, features)``. ``features`` is the
    full-resolution ConvLayer output, kept for the skip connection; ``pooled``
    is that output with height and width halved (floored).
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = ConvLayer(in_channels, out_channels)
        self.down_sample = torch.nn.MaxPool2d((2,2), stride=2)

    def forward(self, x):
        features = self.double_conv(x)
        return self.down_sample(features), features

class Up(torch.nn.Module):
    """Decoder step: a ConvLayer, then a 2x2 stride-2 transposed conv that doubles H and W.

    ``in_channels`` is the ConvLayer input width (in Unet, the concatenated
    skip + decoder channels). The transposed conv keeps ``out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = ConvLayer(in_channels, out_channels)
        self.up_sample = torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        return self.up_sample(self.double_conv(x))

class Unet(torch.nn.Module):
    """Four-level U-Net for semantic segmentation.

    Input:  (N, in_channels, H, W), with H and W divisible by 16 (four 2x
            poolings must line up with the upsampled decoder maps).
    Output: (N, num_classes, H, W) raw, unnormalised class logits.

    The head deliberately has no Softmax. torch.nn.CrossEntropyLoss applies
    log-softmax itself, so a softmax in front of it double-normalises and
    squashes the gradients. Use ``out.argmax(dim=1)`` for predictions or
    ``out.softmax(dim=1)`` for probabilities. Softmax had no parameters, so the
    state_dict keys (``out2.0.weight`` etc.) are unchanged and older
    checkpoints still load.

    Args:
        in_channels: channels of the input image (default 3).
        num_classes: number of output classes (default 20).
    """
    def __init__(self, in_channels=3, num_classes=20):
        super(Unet, self).__init__()
        self.down1 = Down(in_channels = in_channels, out_channels = 32)
        self.down2 = Down(in_channels = 32, out_channels = 64)
        self.down3 = Down(in_channels = 64, out_channels = 128)
        self.down4 = Down(in_channels = 128, out_channels = 256)

        #The bottom step: widen to 512 channels, then upsample back to skip4's size
        self.bottom = Up(in_channels = 256, out_channels = 512)

        #Decoder steps: input width = upsampled channels + matching skip channels
        self.up4 = Up(in_channels = 512+256, out_channels = 256)
        self.up3 = Up(in_channels = 256+128, out_channels = 128)
        self.up2 = Up(in_channels = 128+64, out_channels = 64)
        #Final full-resolution step (no further upsampling)
        self.out1 = ConvLayer(in_channels = 64+32, out_channels = 32)

        # 1x1 conv to per-class logits; kept inside a Sequential so the
        # parameter key stays "out2.0.weight".
        self.out2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, num_classes, kernel_size=1, bias = False),
        )

    def forward(self, x):
        #Encoder: keep each full-resolution feature map for the skip connections
        x, skip1 = self.down1(x)
        x, skip2 = self.down2(x)
        x, skip3 = self.down3(x)
        x, skip4 = self.down4(x)
        #Bottom
        x = self.bottom(x)
        #Decoder: concatenate skip features along the channel axis
        x = self.up4(torch.cat((skip4, x), dim = 1))
        x = self.up3(torch.cat((skip3, x), dim = 1))
        x = self.up2(torch.cat((skip2, x), dim = 1))
        x = self.out1(torch.cat((skip1, x), dim = 1))
        return self.out2(x)

In [None]:
#Input size
# Training and validation send every batch to `device`, so the weights must live
# there too; move the model before the optimizer cell collects its parameters.
model = Unet().to(device)
# torchsummary builds its dummy input on the given device type ("cuda" by
# default), so pass the type the model actually lives on.
summary(model, input_size = (3, IMAGESIZE, IMAGESIZE), device = device.type)

# Train model

In [None]:
# Per-pixel cross-entropy over class scores; Adam at lr=1e-3 with its default betas.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, betas=(0.9, 0.999))

In [None]:
def train_one_epoch(epoch_index):
    """Run one optimisation pass over ``training_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn`` and ``device``.
    The caller is responsible for putting the model in train mode.

    Args:
        epoch_index: index of the current epoch (not used; kept for the call signature).

    Returns:
        Mean per-batch training loss over the whole epoch, or 0.0 if the loader
        yielded no batches. This is comparable with the averaged validation loss.
    """
    running_loss = 0.
    num_batches = 0

    for inputs, labels in training_loader:
        inputs = inputs.to(device)
        # CityscapesDataset yields (N, 1, H, W) float labels; CrossEntropyLoss
        # expects (N, H, W) int64 class indices.
        labels = labels.to(device).long().squeeze(1)

        optimizer.zero_grad()
        outputs = model(inputs)  # (N, classes, H, W)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Average over every batch. The old `running_loss / b` divided a never-reset
    # sum by a zero-based index, and only every 10th batch.
    return running_loss / num_batches if num_batches else 0.0

In [None]:
EPOCHS = 100
SAVE_EVERY = 10  # periodic checkpoint interval, in epochs
MODEL_DIR = 'models'

os.makedirs(MODEL_DIR, exist_ok=True)  # torch.save does not create missing directories

best_vloss = float('inf')
timeLine = []  # rows of [epoch, train loss, validation loss], mirrored to loss.csv

for epoch in range(EPOCHS):
    start = datetime.now()
    print('EPOCH {}:'.format(epoch + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch)

    # Evaluation mode: BatchNorm uses its running statistics.
    model.eval()
    running_vloss = 0.0
    num_vbatches = 0
    # No autograd graph is needed for validation; this also saves memory.
    with torch.no_grad():
        for vinputs, vlabels in validation_loader:
            vinputs = vinputs.to(device)
            # (N, 1, H, W) float labels -> (N, H, W) int64 class indices
            vlabels = vlabels.to(device).long().squeeze(1)
            voutputs = model(vinputs)
            # .item() keeps the accumulator a Python float rather than a device tensor.
            running_vloss += loss_fn(voutputs, vlabels).item()
            num_vbatches += 1
    if num_vbatches == 0:
        raise RuntimeError('validation_loader produced no batches')
    avg_vloss = running_vloss / num_vbatches

    print('LOSS train {} valid {} time {}'.format(avg_loss, avg_vloss, datetime.now()-start ))
    timeLine.append([epoch, avg_loss, avg_vloss])

    # Save on a new best validation loss, on the final epoch, and every
    # SAVE_EVERY epochs otherwise.
    model_path = os.path.join(MODEL_DIR, 'model_{}.pt'.format(epoch))
    if avg_vloss < best_vloss or epoch + 1 == EPOCHS:
        print('model saved')
        best_vloss = min(best_vloss, avg_vloss)  # the final-epoch save must not inflate the best
        torch.save(model.state_dict(), model_path)
    elif epoch % SAVE_EVERY == SAVE_EVERY - 1:
        print('model saved periodically')
        torch.save(model.state_dict(), model_path)

    # Rewrite the CSV each epoch so progress survives an interrupted run.
    pd.DataFrame(timeLine, columns= ['Epoch', 'Train Loss', 'Validation Loss']).to_csv('loss.csv')

In [None]:
# Take one training batch (index 1, as before) for the visual check below.
# training_loader shuffles, so the batch contents differ between runs unless
# the RNGs are seeded. Reading straight from an iterator avoids loading the
# whole training set just to keep one batch.
SAMPLE_BATCH_INDEX = 1
batch_iter = iter(training_loader)
for _ in range(SAMPLE_BATCH_INDEX):
    next(batch_iter, None)
sample_batch = next(batch_iter, None)
if sample_batch is None:
    raise RuntimeError('training_loader has fewer than {} batches'.format(SAMPLE_BATCH_INDEX + 1))
# Every data instance is an input + label pair
inputs, labels = sample_batch

In [None]:
# Reload every saved checkpoint on the CPU. For each one, save a figure of the
# sample image next to that checkpoint's predicted mask (pic/epoch_<n>.png).
SAMPLE_IDX = 28  # sample within the batch to visualise (batches hold BATCHSIZE items)

# Only 'model_<epoch>.pt' files, in numeric epoch order ('model_10' after 'model_9').
checkpoints = sorted(
    (f for f in os.listdir('models') if f.startswith('model_') and f.endswith('.pt')),
    key=lambda f: int(f[len('model_'):-len('.pt')]),
)

for weights in checkpoints:
    # File names store the zero-based epoch; figures use 1-based numbering.
    epoch = int(weights[len('model_'):-len('.pt')]) + 1
    # A separate name, so the trained `model` from the training cell is not clobbered.
    ckpt_model = Unet()
    ckpt_model.load_state_dict(torch.load(os.path.join('models', weights), map_location=torch.device('cpu')))
    ckpt_model.eval()

    with torch.no_grad():
        out = ckpt_model(inputs)

    # `out` has one channel per class. Reduce it to a (1, H, W) class-index map
    # so it can be drawn; passing the raw multi-channel output as the image fails.
    pred_mask = out[SAMPLE_IDX].argmax(dim=0, keepdim=True)
    show_img_and_mask(inputs[SAMPLE_IDX], pred_mask, epoch = epoch)