In [1]:
import zipfile
import os

In [2]:
# Unpack the image and depth archives into their target directories.
# The context manager closes each archive even when extraction raises,
# which the manual open/extract/close sequence did not guarantee.
for archive_name, target_dir in (("images.zip", "/images"), ("depth.zip", "/depth")):
    with zipfile.ZipFile(archive_name) as zip_ref:
        zip_ref.extractall(target_dir)

In [3]:
#Depth Code:
!pip install piqa

Collecting piqa
  Downloading piqa-1.3.2-py3-none-any.whl (32 kB)
Installing collected packages: piqa
Successfully installed piqa-1.3.2
[0m

In [4]:
import os
import json
import random
import torch
from torchvision import transforms

def create_json_dataset(image_dir, mask_dir):
    """Write one ``<split>.json`` index file per split folder found in ``image_dir``.

    The expected layout is ``image_dir/<split>/<sub_folder>/<image file>``.
    Every index entry is an ``[image_path, mask_path]`` pair. The mask path
    mirrors the image path under ``mask_dir``, with the last 15 characters of
    the file name replaced by ``depth.png``.

    Directory listings are sorted so the generated index is identical from run
    to run (``os.listdir`` returns entries in arbitrary order), and entries that
    are not directories are skipped instead of raising ``NotADirectoryError``.

    Args:
        image_dir: Root directory containing one folder per split.
        mask_dir: Root directory containing the masks in the same layout.

    Returns:
        None. The JSON files are written to the current working directory.
    """
    for folder_name in sorted(os.listdir(image_dir)):
        image_path = image_dir + "/" + folder_name
        if not os.path.isdir(image_path):
            # Stray file next to the split folders (e.g. an OS metadata file).
            continue
        mask_path = mask_dir + "/" + folder_name

        data = []
        for sub_folder in sorted(os.listdir(image_path)):
            sub_folder_path = image_path + "/" + sub_folder
            if not os.path.isdir(sub_folder_path):
                continue
            for image in sorted(os.listdir(sub_folder_path)):
                image_name = sub_folder_path + "/" + image
                # image[:-15] strips a fixed-length 15-character suffix
                # (presumably "leftImg8bit.png" -- TODO confirm against the data).
                mask_name = mask_path + "/" + sub_folder + "/" + image[:-15] + "depth.png"
                data.append([image_name, mask_name])

        with open(f'{folder_name}.json', "w", encoding='utf-8') as f:
            json.dump(data, f)

def transform(image, mask):
    """Apply identical random flips to an image/mask pair, then convert both to tensors.

    A horizontal flip and then a vertical flip are each applied when a fresh
    ``random.random()`` draw is greater than 0.5. The same decision is used for
    the image and the mask so the pair stays aligned.

    Args:
        image: Input image (the dataset in this notebook passes a PIL image).
        mask: Matching mask; receives exactly the flips applied to ``image``.

    Returns:
        Tuple ``(image, mask)`` holding the two ``PILToTensor`` outputs.
    """
    # p=1 makes each flip unconditional; the coin toss is done here instead,
    # so that a single draw governs both the image and its mask.
    flips = (
        transforms.RandomHorizontalFlip(p=1),
        transforms.RandomVerticalFlip(p=1),
    )
    for flip in flips:
        if random.random() > 0.5:
            image, mask = flip(image), flip(mask)

    to_tensor = transforms.PILToTensor()
    return to_tensor(image), to_tensor(mask)

def add_result(result, version):
    """Append one line to the results log ``results_v{version}.txt``.

    Args:
        result: Text to append; a trailing newline is added.
        version: Value interpolated into the log file name.

    Returns:
        None. The file is created in the current working directory if missing.
    """
    # The with-block closes the file on exit, so no explicit close() is needed.
    # The encoding is stated explicitly, matching the JSON index writer.
    with open(f'results_v{version}.txt', 'a', encoding='utf-8') as f:
        f.write(result + "\n")

def save_checkpoint(epoch, model, version):
    """Serialize the epoch number and the model to ``Depth_v{version}.pth.tar``.

    The file holds a dict with the keys ``'epoch'`` and ``'model'``. The model
    object itself is stored, not only its ``state_dict``.

    Args:
        epoch: Epoch number to record.
        model: Object stored under the ``'model'`` key.
        version: Value interpolated into the checkpoint file name.
    """
    torch.save({'epoch': epoch, 'model': model}, f'Depth_v{version}.pth.tar')
    
def draw_loss_graph(file_name, save_name, from_epoch = 0):
    """Plot training and validation loss from a results log and save the figure.

    Each non-blank log line is split on single spaces and read positionally:
    token 1 is the epoch, token 5 the training loss and the last token the
    validation loss. This matches lines of the form
    ``Epoch: <n> | Average Loss: <x> | Val Loss: <y>``.

    Args:
        file_name: Path of the log file, or a URL starting with ``http``.
        save_name: Output path handed to ``plt.savefig``.
        from_epoch: Number of leading log lines to skip before plotting.

    Returns:
        None. The figure is written to ``save_name`` and then closed.
    """
    # Imported here because the visible notebook cells never import either
    # name, so calling the function would otherwise raise NameError.
    from urllib.request import urlopen
    import matplotlib.pyplot as plt

    # Both branches close their handle; splitlines() also copes with a log that
    # has no trailing newline (the previous split("\n")[:-1] dropped its last row).
    if file_name.startswith("http"):
        with urlopen(file_name) as response:
            lines = response.read().decode('utf-8').splitlines()
    else:
        with open(file_name, 'r') as f:
            lines = f.read().splitlines()

    # Blank lines carry no data and would break the positional parsing below.
    lines = [line for line in lines if line.strip()]

    x = []
    y = []
    y_val = []
    for line in lines[from_epoch:]:
        temp = line.split(" ")
        x.append(int(temp[1]))
        y.append(float(temp[5]))
        y_val.append(float(temp[-1]))

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.plot(x, y, color='c', label='Train Loss')
    plt.plot(x, y_val, color='orange',label='Val Loss')
    plt.ticklabel_format(useOffset=False, style='plain')
    plt.legend()
    plt.savefig(save_name)
    plt.close()
    

In [5]:
from torch.utils.data import Dataset
import json
from PIL import Image
import torch

class CityScapeDepth(Dataset):
    """Dataset of (image, depth mask) pairs listed in a JSON index file.

    Every sample is resized to ``size`` and then cut along the width into a left
    and a right half, so a single index yields two crops. ``collate_fn``
    concatenates those crops, which means a batch of ``B`` samples produces
    ``2 * B`` images.
    """

    def __init__(self, filename, size = (512, 256),transform=None):
        """Load the index of ``[image_path, mask_path]`` pairs.

        Args:
            filename: Path of a JSON file containing a list of
                ``[image_path, mask_path]`` pairs.
            size: ``(width, height)`` passed to ``PIL.Image.resize``.
            transform: Callable ``(image, mask) -> (image, mask)``.
                ``__getitem__`` divides and slices the results as 3-D tensors,
                so the callable has to return tensors. With ``None`` the PIL
                images reach that arithmetic unchanged and it fails.
        """
        super(CityScapeDepth, self).__init__()
        # Close the index file as soon as it is parsed instead of leaking the handle.
        with open(filename, encoding='utf-8') as f:
            self.data = json.load(f)
        self.transform = transform
        self.size = size

    def __len__(self):
        """Return the number of image/mask pairs in the index."""
        return len(self.data)

    def __getitem__(self, index):
        """Load one pair and return its left/right half crops.

        Returns:
            Tuple ``(image_stack, mask_stack)`` of float32 tensors with shapes
            ``(2, 3, height, width // 2)`` and ``(2, 1, height, width // 2)``.
        """
        image_path, mask_path = self.data[index]
        image = Image.open(image_path).convert("RGB").resize(self.size)
        # Nearest-neighbour resampling keeps the mask values un-interpolated.
        mask = Image.open(mask_path).convert('L').resize(self.size, Image.Resampling.NEAREST)

        if self.transform:
            image, mask = self.transform(image, mask)
        # Scale by 1/255 (maps 8-bit pixel values into [0, 1]).
        image = image / 255.
        mask = mask / 255.

        width, height = self.size
        half_width = width // 2

        image_stack = torch.zeros(2, 3, height, half_width)
        mask_stack = torch.zeros(2, 1, height, half_width)

        # Index 0 holds the left half, index 1 the right half.
        image_stack[0] = image[:, :, :half_width]
        image_stack[1] = image[:, :, half_width:]

        mask_stack[0] = mask[:, :, :half_width]
        mask_stack[1] = mask[:, :, half_width:]
        return image_stack.to(torch.float32), mask_stack.to(torch.float32)

    def collate_fn(self, batch):
        """Concatenate the per-sample crop stacks along dimension 0.

        Args:
            batch: Sequence of ``(image_stack, mask_stack)`` tuples.

        Returns:
            Tuple ``(images, masks)`` of concatenated tensors.
        """
        images = torch.cat([sample[0] for sample in batch], dim=0)
        masks = torch.cat([sample[1] for sample in batch], dim=0)

        return images, masks

In [6]:
import torch
import torch.nn as nn
from piqa import SSIM

class AttentionGate(nn.Module):
    """Attention gate that rescales a skip feature map with a single-channel mask.

    The gating signal ``g`` is projected to ``x_in_c`` channels, ``x`` is
    projected with a stride-2 1x1 convolution, and the two are concatenated. A
    ReLU, a 1x1 convolution down to one channel and a sigmoid turn that into a
    mask, which is upsampled by 2 and multiplied into ``x``.
    """

    def __init__(self, g_in_c, x_in_c):
        """
        Args:
            g_in_c: Channel count of the gating signal ``g``.
            x_in_c: Channel count of the skip feature map ``x``.
        """
        super().__init__()

        self.g_conv_layer = nn.Conv2d(g_in_c, x_in_c, kernel_size=1, stride=1)
        # Stride 2 halves the spatial size of x before it is concatenated with g.
        self.x_conv_layer = nn.Conv2d(x_in_c, x_in_c, kernel_size=1, stride=2)
        self.si_conv_layer = nn.Conv2d(x_in_c*2, 1, kernel_size=1, stride=1)
        self.resampling = nn.Upsample(scale_factor=2)

    def forward(self, g, x):
        """Return ``x`` multiplied element-wise by the upsampled attention mask."""
        relu, sigmoid = nn.ReLU(), nn.Sigmoid()

        gate_features = torch.cat([self.g_conv_layer(g), self.x_conv_layer(x)], dim=1)
        attention = sigmoid(self.si_conv_layer(relu(gate_features)))
        # Bring the mask back to the spatial size of x before applying it.
        return x * self.resampling(attention)

class ConvLayers(nn.Module):
    """Two 3x3 convolutions with a dense-style skip, then batch norm and ReLU.

    The output of the first convolution is concatenated with the block input
    along the channel dimension before the second convolution, which is why
    ``conv2`` takes ``out_c + in_c`` input channels.
    """

    def __init__(self, in_c, out_c):
        """
        Args:
            in_c: Number of input channels.
            out_c: Number of output channels.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_c + in_c, out_c, kernel_size=3, padding=1)
        self.batchNorm = nn.BatchNorm2d(out_c)

    def forward(self, x):
        """Return the activated features; padding=1 keeps the spatial size."""
        stacked = torch.cat([self.conv1(x), x], dim=1)
        normalized = self.batchNorm(self.conv2(stacked))
        return nn.ReLU()(normalized)

class DownSampling(nn.Module):
    """Encoder stage: two ``ConvLayers`` blocks, then 2x2 max-pooling with dropout.

    ``forward`` returns both the full-resolution features and the pooled,
    dropout-regularised features.
    """

    def __init__(self, in_c, out_c):
        """
        Args:
            in_c: Number of input channels.
            out_c: Number of channels produced by both convolution blocks.
        """
        super().__init__()
        self.conv1 = ConvLayers(in_c=in_c, out_c=out_c)
        self.conv2 = ConvLayers(in_c=out_c, out_c=out_c)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x):
        """Return ``(features, pooled)``: pre-pooling and pooled+dropout tensors."""
        features = self.conv2(self.conv1(x))
        pooled = nn.MaxPool2d(2)(features)
        return features, self.dropout(pooled)

class UpSampling(nn.Module):
    """Decoder stage: gate the skip features, upsample, concatenate and convolve."""

    def __init__(self, in_c, out_c):
        """
        Args:
            in_c: Channel count of the incoming decoder features.
            out_c: Channel count of the skip features and of the stage output.
        """
        super().__init__()
        self.attention_layer = AttentionGate(in_c, out_c)
        self.upsampling_layer = nn.Upsample(scale_factor=2)
        # The convolution sees the upsampled input (in_c) plus the gated skip (out_c).
        self.conv_layer = ConvLayers(in_c + out_c, out_c)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x, intermediate_value):
        """Fuse decoder features ``x`` with the skip tensor ``intermediate_value``."""
        # The gate must see x before it is upsampled.
        gated_skip = self.attention_layer(x, intermediate_value)
        upsampled = self.upsampling_layer(x)
        merged = torch.cat([upsampled, gated_skip], dim=1)
        return self.dropout(self.conv_layer(merged))

class UNET(nn.Module):
    """Attention U-Net with four encoder stages, a bottleneck and four decoder stages.

    Channel widths run 32 -> 64 -> 128 -> 256 in the encoder, 512 in the
    bottleneck and back down to 32 in the decoder. A 1x1 convolution followed
    by a sigmoid produces the output.
    """

    def __init__(self, in_c, out_c):
        """
        Args:
            in_c: Number of input channels.
            out_c: Number of output channels.
        """
        super().__init__()
        encoder_widths = (32, 64, 128, 256)
        decoder_widths = (512, 256, 128, 64, 32)

        self.layer1 = DownSampling(in_c, encoder_widths[0])
        self.downLayers = nn.ModuleList(
            [DownSampling(c_in, c_out)
             for c_in, c_out in zip(encoder_widths[:-1], encoder_widths[1:])]
        )
        self.intermediate_layer = ConvLayers(encoder_widths[-1], decoder_widths[0])
        self.upLayers = nn.ModuleList(
            [UpSampling(c_in, c_out)
             for c_in, c_out in zip(decoder_widths[:-1], decoder_widths[1:])]
        )
        self.final_layer = nn.Conv2d(decoder_widths[-1], out_channels=out_c, kernel_size=1)
        self.activation_layer = nn.Sigmoid()

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the sigmoid output."""
        skips = []
        for encoder in [self.layer1, *self.downLayers]:
            skip, x = encoder(x)
            skips.append(skip)

        x = self.intermediate_layer(x)

        # Decoder stages consume the skip tensors deepest-first.
        for decoder, skip in zip(self.upLayers, reversed(skips)):
            x = decoder(x, skip)

        return self.activation_layer(self.final_layer(x))


class DepthEstimationLoss(nn.Module):
    """Sum of MSE, smooth-L1 and an SSIM-based term between two depth maps."""

    def __init__(self):
        super().__init__()
        self.mse_loss_layer = nn.MSELoss()
        self.smooth_l1_loss_layer = nn.SmoothL1Loss()
        self.ssim = SSIM(n_channels=1)

    def forward(self, predicted_depth, ground_truth_depth):
        """Return ``MSE + SmoothL1 + (1 - SSIM) / 2`` for the given pair."""
        pair = (predicted_depth, ground_truth_depth)

        mse_term = self.mse_loss_layer(*pair)
        smooth_l1_term = self.smooth_l1_loss_layer(*pair)
        # Halving (1 - SSIM) gives a term that shrinks as the SSIM score rises.
        ssim_term = (1. - self.ssim(*pair))/2

        return mse_term + smooth_l1_term + ssim_term


In [6]:
# Build the per-split JSON index files from the extracted archives.
# create_json_dataset is defined in an earlier cell of this notebook.
image_folder_path, depth_folder_path = "/images", "/depth"

create_json_dataset(image_dir=image_folder_path, mask_dir=depth_folder_path)

In [None]:
# from dataset import CityScapeDepth
# from model import UNET, DepthEstimationLoss
from torch.utils.data import DataLoader
import torch
import torch.backends.cudnn as cudnn
# from utils import transform, add_result, save_checkpoint
import os
import time

def train(checkpoint):
    """Train the depth UNET, optionally resuming from a saved checkpoint.

    Reads these module-level names, which must exist before the call:
    ``device``, ``lr``, ``version``, ``epochs``, ``train_gen`` and ``val_gen``.

    After every epoch the model is checkpointed with ``save_checkpoint`` and the
    mean training/validation losses are appended to the results log and printed.

    Args:
        checkpoint: Path of a checkpoint file to resume from, or ``None`` to
            start with a fresh ``UNET`` at epoch 0. Only the epoch number and
            the model are read from it, so the optimizer always starts fresh.
    """
    if checkpoint is None:
        model = UNET(in_c=3, out_c=1)
        start_epoch = 0
    else:
        # map_location lets a checkpoint written on a GPU be resumed on a
        # CPU-only machine.
        # NOTE(review): the checkpoint holds the whole pickled model; PyTorch
        # releases where torch.load defaults to weights_only=True will need
        # weights_only=False here -- confirm against the installed version.
        checkpoint = torch.load(checkpoint, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']

    model = model.to(device=device)
    criterion = DepthEstimationLoss().to(device=device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-6, amsgrad=True)

    print(f" -- Initiating the Training Process -- Version: {version}")
    print(f"Epoch: {start_epoch}: ")

    for epoch in range(start_epoch, epochs + 1):
        # ---- training pass ----
        model.train()
        train_loss_sum = 0
        train_batches = 0
        for i, (image, mask) in enumerate(train_gen):
            image = image.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()

            pred_mask = model(image)
            loss = criterion(pred_mask, mask)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss_sum = train_loss_sum + loss.item()
            train_batches += 1
            torch.cuda.empty_cache()

            # Lightweight progress bar: one "=" every 50 batches.
            if i%50 == 0:
                print("=", end="")

        # ---- validation pass ----
        model.eval()
        val_loss_sum = 0
        val_batches = 0
        # No gradients are needed for evaluation; disabling them avoids building
        # the autograd graph and the memory it holds.
        with torch.no_grad():
            for image, mask in val_gen:
                image = image.to(device)
                mask = mask.to(device)
                pred_mask = model(image)
                loss = criterion(pred_mask, mask)

                val_loss_sum = val_loss_sum + loss.item()
                val_batches += 1
                torch.cuda.empty_cache()

        # Explicit batch counters replace the leaked loop indices, which were
        # undefined when a loader yielded no batches.
        average_loss = train_loss_sum / max(train_batches, 1)
        validation_loss = val_loss_sum / max(val_batches, 1)
        summary = f"Epoch: {epoch} | Average Loss: {average_loss} | Val Loss: {validation_loss}"

        save_checkpoint(epoch=epoch, model=model, version=version)
        add_result(summary, version)
        print(f"   {summary} {time.ctime()}")

        
if __name__ == "__main__":
    # Run configuration. train() reads version, device, epochs, lr, train_gen
    # and val_gen as module-level names, so they are all defined here.
    version = 9
    batch_size = 9
    workers = 8
    epochs = 3000
    lr = 1e-5
    size = (512, 256)
    train_file = "train.json"
    val_file = "val.json"

    cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resume from this version's checkpoint when one is present in the working directory.
    checkpoint_name = f"Depth_v{version}.pth.tar"
    checkpoint = checkpoint_name if checkpoint_name in os.listdir() else None

    # NOTE(review): the validation set uses the same random-flip transform as
    # the training set -- confirm that this is intended.
    train_dataset = CityScapeDepth(train_file, size, transform)
    val_dataset = CityScapeDepth(val_file, size, transform)

    loader_options = dict(batch_size=batch_size, num_workers=workers)

    train_gen = DataLoader(
        dataset=train_dataset,
        shuffle=True,
        pin_memory=True,
        collate_fn=train_dataset.collate_fn,
        **loader_options
    )

    val_gen = DataLoader(
        dataset=val_dataset,
        collate_fn=val_dataset.collate_fn,
        **loader_options
    )

    train(checkpoint)

 -- Initiating the Training Process -- Version: 9
Epoch: 1849: 
=====