In [1]:
#Math Libraries
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

#Metrics Libraries
from torchmetrics import Metric
from torchmetrics.segmentation import MeanIoU
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.regression import MeanAbsoluteError

#Tensorboard
from torch.utils.tensorboard import SummaryWriter

#Graphic Libraries
from tqdm import tqdm
import matplotlib.pyplot as plt

#System Libraries
import os
import glob

2025-01-22 19:36:17.040593: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-01-22 19:36:17.080741: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Print the name of every visible CUDA device (prints nothing on a CPU-only machine).
for device_idx in range(torch.cuda.device_count()):
    device_props = torch.cuda.get_device_properties(device_idx)
    print(device_props.name)

In [None]:
# (free, total) memory in bytes of the current CUDA device.
# Guarded so the notebook still runs top-to-bottom on a CPU-only machine, where
# torch.cuda.mem_get_info() raises instead of returning.
torch.cuda.mem_get_info() if torch.cuda.is_available() else None

In [5]:
# Training configuration: everything a reader might tune lives here.
BATCH_SIZE = 8
LEARNING_RATE = 0.0001
EPOCHS = 100
LABELS = 7  # number of label classes; selects the label_{LABELS} folder (seg heads add a background class)
SEED = 42
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the RNGs so weight initialisation, cross-stitch alphas and DataLoader shuffling
# are reproducible across Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [6]:
class CityscapesDataset(Dataset):
    """Preprocessed Cityscapes split stored as one ``.npy`` file per sample and modality.

    Expected layout under ``root/split``::

        image/*.npy           # channels-last array, returned permuted (2, 0, 1) -> channels-first
        label_{labels}/*.npy  # segmentation map, returned unchanged
        depth/*.npy           # depth map, returned with axis 2 squeezed (expected size 1)

    Samples are paired across the three folders by sorted filename order, so the
    folders must contain the same number of files.

    Args:
        root: directory containing the preprocessed splits.
        split: split sub-directory, e.g. ``"train"`` or ``"val"``.
        labels: number of label classes; selects the ``label_{labels}`` folder.

    Raises:
        ValueError: if the image, label and depth folders hold different numbers of
            files (index-based pairing would silently misalign samples).
    """

    def __init__(self, root="./cityscapes_preprocessed", split="train", labels=7):
        self.root = root
        self.split = split
        split_dir = os.path.join(root, split)
        self.images = sorted(glob.glob(os.path.join(split_dir, "image", "*.npy")))
        self.labels = sorted(glob.glob(os.path.join(split_dir, f"label_{labels}", "*.npy")))
        self.depth = sorted(glob.glob(os.path.join(split_dir, "depth", "*.npy")))

        counts = (len(self.images), len(self.labels), len(self.depth))
        if len(set(counts)) != 1:
            raise ValueError(
                f"Mismatched file counts in {split_dir}: {counts[0]} images, "
                f"{counts[1]} labels (label_{labels}), {counts[2]} depth maps"
            )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label, depth)`` tensors for sample ``idx``."""
        image = torch.from_numpy(np.load(self.images[idx])).permute(2, 0, 1)
        label = torch.from_numpy(np.load(self.labels[idx]))
        depth = torch.from_numpy(np.load(self.depth[idx])).squeeze(2)
        return image, label, depth

In [None]:
# Training split. Shuffle so each epoch visits samples in a new order; the
# validation loader below keeps a fixed order.
cityscapes_train = CityscapesDataset(split="train", labels=LABELS)
train_dl = DataLoader(cityscapes_train, batch_size=BATCH_SIZE, shuffle=True)

# Sanity-check one batch: shapes and value ranges of each modality.
image, label, depth = next(iter(train_dl))
print(f'Image: {image.shape}, Label: {label.shape}, Depth: {depth.shape}')
print(f'Image: {image.max()}, {image.min()}')
print(f'Label: {label.max()}, {label.min()}')
print(f'Depth: {depth.max()}, {depth.min()}')

# Show the first sample of the batch (the dataset returns images channels-first,
# so permute back to channels-last for imshow).
fig, axes = plt.subplots(3, figsize=(10, 10))
axes[0].imshow(image[0].permute(1, 2, 0))
axes[0].set_title('Image')
axes[1].imshow(label[0])
axes[1].set_title('Label')
axes[2].imshow(depth[0])
axes[2].set_title('Depth')
plt.show()

cityscapes_val = CityscapesDataset(split="val", labels=LABELS)
val_dl = DataLoader(cityscapes_val, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
def init_weights(model):
    """Re-initialise the conv and batch-norm layers of ``model`` in place.

    Every ``nn.Conv2d`` weight gets Kaiming-normal init (``fan_out`` mode, ReLU gain);
    every ``nn.BatchNorm2d`` gets weight 1 and bias 0. Conv biases and all other module
    types are left untouched.
    """
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
            continue
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

def count_params(model):
    """Return the total number of trainable (``requires_grad``) scalar parameters in ``model``."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

# Basic Modules

In [9]:
class ConvLayer(nn.Module):
    """3x3 convolution (padding 1, no bias) -> BatchNorm2d -> ReLU.

    Spatial size is preserved; channels go from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # No conv bias: the following BatchNorm provides its own learnable shift.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class DownSampleBlock(nn.Module):
    """ConvLayer followed by a 2x2 / stride-2 max-pool that also returns its indices.

    ``forward`` returns ``(pooled, indices, pre_pool)``: the pooled tensor, the max-pool
    indices (usable by ``nn.MaxUnpool2d``), and the ConvLayer output before pooling.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_layer = ConvLayer(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, x):
        features = self.conv_layer(x)
        pooled, pool_indices = self.pool(features)
        return pooled, pool_indices, features

class UpSampleBlock(nn.Module):
    """2x max-unpool (driven by indices from a matching max-pool) followed by a ConvLayer.

    ``forward(x, indices)`` returns ``(conv_out, unpooled)``: the ConvLayer output and the
    raw unpooled tensor that was fed into it.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.conv_layer = ConvLayer(in_channels, out_channels)

    def forward(self, x, indices):
        unpooled = self.unpool(x, indices)
        return self.conv_layer(unpooled), unpooled

# MTAN model definition

In [10]:
class EncoderSH(nn.Module):
    """Shared encoder: one ConvLayer + DownSampleBlock pair per entry of ``filter``.

    Stage 0 maps 3 input channels to ``filter[0]``; stage i maps ``filter[i-1]`` to
    ``filter[i]``. Each stage halves the spatial resolution.

    ``forward`` returns ``(final, enc_layer, down_layer, down_indices, out)``. For each stage:
    ``enc_layer`` holds the ConvLayer output, ``down_layer`` the DownSampleBlock's
    pre-pool output, ``down_indices`` the max-pool indices, and ``out`` the pooled output.
    ``final`` is the pooled output of the last stage.
    """

    def __init__(self, filter):
        super().__init__()
        self.enc_blocks = nn.ModuleList()
        self.down_blocks = nn.ModuleList()
        stage_channels = [3, *filter]
        for c_in, c_out in zip(stage_channels[:-1], stage_channels[1:]):
            self.enc_blocks.append(ConvLayer(c_in, c_out))
            self.down_blocks.append(DownSampleBlock(c_out, c_out))

    def forward(self, x):
        enc_layer, down_layer, down_indices, out = [], [], [], []
        features = x
        for enc_block, down_block in zip(self.enc_blocks, self.down_blocks):
            features = enc_block(features)
            enc_layer.append(features)
            features, indices, pre_pool = down_block(features)
            down_layer.append(pre_pool)
            out.append(features)
            down_indices.append(indices)
        return features, enc_layer, down_layer, down_indices, out

class DecoderSH(nn.Module):
    """Shared decoder mirroring EncoderSH; ``filter`` is given deepest-stage first.

    Stage i unpools with encoder indices consumed last-to-first and maps ``filter[i]`` to
    ``filter[i+1]`` channels (the last stage keeps ``filter[-1]``), then applies a ConvLayer.
    ``forward`` returns ``(up_layer, dec_layer)``: per stage, the raw unpooled tensor and
    the stage output.
    """

    def __init__(self, filter):
        super().__init__()
        self.up_blocks = nn.ModuleList()
        self.dec_blocks = nn.ModuleList()
        # Each stage narrows to the next width; the final stage keeps the last width.
        stage_out = list(filter[1:]) + [filter[-1]]
        for c_in, c_out in zip(filter, stage_out):
            self.up_blocks.append(UpSampleBlock(c_in, c_out))
            self.dec_blocks.append(ConvLayer(c_out, c_out))

    def forward(self, x, down_indices):
        up_layer, dec_layer = [], []
        features = x
        for stage, (up_block, dec_block) in enumerate(zip(self.up_blocks, self.dec_blocks)):
            features, unpooled = up_block(features, down_indices[-(stage + 1)])
            up_layer.append(unpooled)
            features = dec_block(features)
            dec_layer.append(features)
        return up_layer, dec_layer
    
class SharedNet(nn.Module):
    """Shared encoder-decoder backbone (EncoderSH + DecoderSH with reversed widths).

    ``forward`` returns ``(enc_dict, dec_dict, down_indices, out_dict)``:

    - ``enc_dict = {'out': per-stage ConvLayer outputs, 'down': per-stage pre-pool outputs}``
    - ``dec_dict = {'out': per-stage decoder outputs, 'up': per-stage unpooled tensors}``
    - ``out_dict = {'enc': per-stage pooled encoder outputs, 'dec': per-stage decoder outputs}``
    """

    def __init__(self, filter):
        super().__init__()
        self.enc = EncoderSH(filter)
        self.dec = DecoderSH(list(reversed(filter)))

    def forward(self, x):
        bottleneck, enc_out, enc_down, down_indices, enc_pooled = self.enc(x)
        up_layer, dec_layer = self.dec(bottleneck, down_indices)
        enc_dict = {'out': enc_out, 'down': enc_down}
        dec_dict = {'out': dec_layer, 'up': up_layer}
        out_dict = {'enc': enc_pooled, 'dec': dec_layer}
        return enc_dict, dec_dict, down_indices, out_dict
    
class AttEncBlock(nn.Module):
    """MTAN encoder attention block.

    Builds a sigmoid attention mask from ``enc_layer`` (concatenated along channels with
    the previous attention output ``x`` when one is given). The mask is multiplied
    element-wise with ``down_layer``, then passed through a ConvLayer and a 2x2 max-pool.

    Args:
        in_channels: channels of ``enc_layer`` (plus those of ``x`` when it is passed).
        mid_channels: channels of the attention mask; must broadcast with ``down_layer``.
        out_channels: channels produced by the ConvLayer.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        # g: 1x1 projection of the (optionally concatenated) input to mid_channels.
        self.att_layer_g = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU()
        )
        # h: 1x1 conv + sigmoid -> attention mask with values in (0, 1).
        self.att_layer_h = nn.Sequential(
            nn.Conv2d(mid_channels, mid_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.Sigmoid()
        )
        # f: ConvLayer applied to the masked shared features.
        self.att_layer_f = ConvLayer(mid_channels, out_channels)
        self.down = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, enc_layer, down_layer, x=None):
        # `is None`, not `== None`: `==` would dispatch to Tensor.__eq__ when x is a tensor.
        logits = enc_layer if x is None else torch.cat([enc_layer, x], dim=1)
        g = self.att_layer_g(logits)
        h = self.att_layer_h(g)
        p = h * down_layer
        logits = self.att_layer_f(p)
        logits = self.down(logits)
        return logits


class AttDecBlock(nn.Module):
    """MTAN decoder attention block.

    Upsamples the previous attention output ``x`` by 2 and passes it through a ConvLayer
    (``in_channels -> out_channels``). The result is concatenated with ``up_layer`` along
    channels (``mid_channels`` in total) and turned into a sigmoid mask, which multiplies
    ``dec_layer`` element-wise. The masked ``dec_layer`` is returned.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.att_layer_f = ConvLayer(in_channels, out_channels)
        self.att_layer_g = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.att_layer_h = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.Sigmoid()
        )

    def forward(self, x, up_layer, dec_layer):
        refined = self.att_layer_f(self.up(x))
        gate_input = torch.cat([refined, up_layer], dim=1)
        mask = self.att_layer_h(self.att_layer_g(gate_input))
        return mask * dec_layer
    
class AttNet(nn.Module):
    """Task-specific MTAN attention network over the SharedNet features.

    The encoder attention blocks walk down ``enc_dict`` stage by stage. Every block after
    the first also receives the previous block's output, hence its doubled input width.
    The decoder attention blocks then walk ``dec_dict``. ``forward(enc_dict, dec_dict)``
    returns the final decoder attention output.
    """

    def __init__(self, filter):
        super().__init__()
        n = len(filter)
        enc_specs = [(filter[0], filter[0], filter[1])]
        enc_specs += [(2 * filter[i], filter[i], filter[i + 1]) for i in range(1, n - 1)]
        enc_specs.append((2 * filter[-1], filter[-1], filter[-1]))

        dec_specs = [
            (filter[-i], filter[-i] + filter[-i - 1], filter[-i - 1]) for i in range(1, n)
        ]
        dec_specs.append((filter[0], 2 * filter[0], filter[0]))

        self.enc_att = nn.ModuleList([AttEncBlock(*spec) for spec in enc_specs])
        self.dec_att = nn.ModuleList([AttDecBlock(*spec) for spec in dec_specs])

    def forward(self, enc_dict, dec_dict):
        # The first encoder block receives x=None, i.e. no previous attention output.
        att = None
        for stage, block in enumerate(self.enc_att):
            att = block(enc_dict['out'][stage], enc_dict['down'][stage], att)
        for stage, block in enumerate(self.dec_att):
            att = block(att, dec_dict['up'][stage], dec_dict['out'][stage])
        return att
    
class MTAN(nn.Module):
    """Multi-Task Attention Network: a SharedNet backbone plus one AttNet per task.

    ``forward`` returns ``(seg_logits, depth)``. Segmentation uses the task-0 attention
    output and a 1x1 conv to ``self.classes`` channels (trained with cross-entropy).
    Depth uses the task-1 output, a 1x1 conv to 1 channel and a Sigmoid (trained with L1).
    ``tasks`` must therefore be at least 2.
    """

    def __init__(self, classes=7, tasks=2):
        super().__init__()
        self.name = "mtan"
        filter = [64, 128, 256, 512, 512]
        self.classes = classes + 1  # one extra output for the background class
        self.tasks = tasks
        self.sh_net = SharedNet(filter)
        self.attnet_task = nn.ModuleList([AttNet(filter) for _ in range(tasks)])
        self.seg_head = nn.Conv2d(filter[0], self.classes, kernel_size=1)
        self.depth_head = nn.Sequential(
            nn.Conv2d(filter[0], 1, kernel_size=1),
            nn.Sigmoid()
        )
        init_weights(self)

    def forward(self, x):
        enc_dict, dec_dict, _, _ = self.sh_net(x)
        task_features = [attnet(enc_dict, dec_dict) for attnet in self.attnet_task]
        return self.seg_head(task_features[0]), self.depth_head(task_features[1])

# DenseNet model definition

In [11]:
class TaskNet(nn.Module):
    """Task-specific branch of DenseNet.

    An encoder-decoder that, at every stage, concatenates (along channels) its running
    features with the matching shared-backbone output from ``out_dict``, then applies two
    ConvLayers and a resampling layer.
    - Encoder stages read ``out_dict['enc'][i]`` and max-pool; the last one upsamples instead.
    - Decoder stages read ``out_dict['dec'][i]`` and upsample; the last one keeps resolution.
    A head of two ConvLayers and a 1x1 conv maps ``filter[0]`` channels to ``classes`` outputs.
    """

    def __init__(self, filter, classes):
        super().__init__()
        self.classes = classes
        self.start_conv = nn.Sequential(
            ConvLayer(3, filter[0]),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        def dense_block(c_in, c_out, *tail):
            # Two ConvLayers followed by an optional resampling layer.
            return nn.Sequential(ConvLayer(c_in, c_out), ConvLayer(c_out, c_out), *tail)

        n_stages = len(filter) - 1

        enc_blocks = [
            dense_block(2 * filter[i], filter[i + 1], nn.MaxPool2d(kernel_size=2, stride=2))
            for i in range(n_stages)
        ]
        enc_blocks.append(dense_block(2 * filter[-1], filter[-1], nn.Upsample(scale_factor=2)))
        self.dense_enc = nn.ModuleList(enc_blocks)

        dec_blocks = [
            dense_block(filter[-i - 1] + filter[-i - 2], filter[-i - 2], nn.Upsample(scale_factor=2))
            for i in range(n_stages)
        ]
        dec_blocks.append(dense_block(2 * filter[0], filter[0]))
        self.dense_dec = nn.ModuleList(dec_blocks)

        self.head = nn.Sequential(
            ConvLayer(filter[0], filter[0]),
            ConvLayer(filter[0], filter[0]),
            nn.Conv2d(filter[0], classes, kernel_size=1)
        )

    def forward(self, x, out_dict):
        features = self.start_conv(x)
        for stage, block in enumerate(self.dense_enc):
            features = block(torch.cat((features, out_dict['enc'][stage]), dim=1))
        for stage, block in enumerate(self.dense_dec):
            features = block(torch.cat((features, out_dict['dec'][stage]), dim=1))
        return self.head(features)

class DenseNet(nn.Module):
    """SharedNet backbone feeding two TaskNet branches (segmentation and depth).

    ``forward`` returns ``(seg_logits, depth)``, where ``seg_logits`` has ``self.classes``
    channels and ``depth`` has one channel.

    Args:
        classes: number of semantic classes; the seg branch outputs ``classes + 1``
            logits (extra background class).
    """

    def __init__(self, classes=7):
        super().__init__()
        self.name = "densenet"
        filter = [64, 128, 256, 512, 512]
        # Store the seg output width (classes + background) like the other models do:
        # the training loop sizes its segmentation metrics from `model.classes`.
        self.classes = classes + 1
        self.sh_net = SharedNet(filter)
        self.seg_net = TaskNet(filter, self.classes)
        self.depth_net = TaskNet(filter, classes=1)
        init_weights(self)

    def forward(self, x):
        _, _, _, out_dict = self.sh_net(x)
        logits_seg = self.seg_net(x, out_dict)
        logits_depth = self.depth_net(x, out_dict)
        return logits_seg, logits_depth

# SegNet, DepthNet, SplitNet model definitions

In [12]:
class Encoder(nn.Module):
    """SegNet-style encoder: per ``filter`` entry, three ConvLayers then a DownSampleBlock.

    ``forward`` returns ``(final, down_indices)``: the pooled output of the last stage and
    the list of max-pool indices, one per stage (shallowest first).
    """

    def __init__(self, filter):
        super().__init__()
        self.enc_blocks = nn.ModuleList()
        self.down_blocks = nn.ModuleList()
        stage_channels = [3, *filter]
        for c_in, c_out in zip(stage_channels[:-1], stage_channels[1:]):
            self.enc_blocks.append(nn.Sequential(
                ConvLayer(c_in, c_out),
                ConvLayer(c_out, c_out),
                ConvLayer(c_out, c_out),
            ))
            self.down_blocks.append(DownSampleBlock(c_out, c_out))

    def forward(self, x):
        down_indices = []
        features = x
        for conv_stack, down_block in zip(self.enc_blocks, self.down_blocks):
            features, indices, _ = down_block(conv_stack(features))
            down_indices.append(indices)
        return features, down_indices

class Decoder(nn.Module):
    """SegNet-style decoder; ``filter`` is given deepest-stage first.

    Stage i unpools with encoder indices consumed last-to-first, maps ``filter[i]`` to
    ``filter[i+1]`` channels (the last stage keeps ``filter[-1]``), then applies three
    ConvLayers. ``forward`` returns the output of the final stage.
    """

    def __init__(self, filter):
        super().__init__()
        self.up_blocks = nn.ModuleList()
        self.dec_blocks = nn.ModuleList()

        def conv_stack(channels):
            return nn.Sequential(*(ConvLayer(channels, channels) for _ in range(3)))

        for c_in, c_out in zip(filter[:-1], filter[1:]):
            # The conv stack is built before its unpool block here (and after it for the
            # last stage below); this keeps the original parameter-creation order.
            self.dec_blocks.append(conv_stack(c_out))
            self.up_blocks.append(UpSampleBlock(c_in, c_out))
        self.up_blocks.append(UpSampleBlock(filter[-1], filter[-1]))
        self.dec_blocks.append(conv_stack(filter[-1]))

    def forward(self, x, down_indices):
        features = x
        for stage, (up_block, conv_block) in enumerate(zip(self.up_blocks, self.dec_blocks)):
            features, _ = up_block(features, down_indices[-(stage + 1)])
            features = conv_block(features)
        return features

class SegNet(nn.Module):
    """Single-task segmentation baseline: Encoder -> ``mid_layers`` ConvLayers -> Decoder -> head.

    The head is a ConvLayer followed by a 1x1 conv to ``self.classes`` (= ``classes + 1``,
    including background) logits.
    """

    def __init__(self, classes=7, mid_layers=4):
        super().__init__()
        self.name = "segnet"
        filter = [64, 128, 256, 512, 512]
        self.classes = classes + 1  # one extra output for background
        self.enc_net = Encoder(filter)
        self.mid_net = nn.Sequential(*(ConvLayer(filter[-1], filter[-1]) for _ in range(mid_layers)))
        self.dec_net = Decoder(filter[::-1])
        self.seg_head = nn.Sequential(
            ConvLayer(filter[0], filter[0]),
            nn.Conv2d(filter[0], self.classes, kernel_size=1)
        )
        init_weights(self)

    def forward(self, x):
        bottleneck, down_indices = self.enc_net(x)
        decoded = self.dec_net(self.mid_net(bottleneck), down_indices)
        return self.seg_head(decoded)

class DepthNet(nn.Module):
    """Single-task depth baseline: Encoder -> ``mid_layers`` ConvLayers -> Decoder -> head.

    The head is a ConvLayer, a 1x1 conv to one channel, and a Sigmoid, so predictions
    lie in (0, 1).
    """

    def __init__(self, mid_layers=4):
        super().__init__()
        filter = [64, 128, 256, 512, 512]
        self.name = "depthnet"
        self.classes = 1  # single depth channel
        self.enc_net = Encoder(filter)
        self.mid_net = nn.Sequential(*(ConvLayer(filter[-1], filter[-1]) for _ in range(mid_layers)))
        self.dec_net = Decoder(filter[::-1])
        self.depth_head = nn.Sequential(
            ConvLayer(filter[0], filter[0]),
            nn.Conv2d(filter[0], self.classes, kernel_size=1),
            nn.Sigmoid()
        )
        init_weights(self)

    def forward(self, x):
        bottleneck, down_indices = self.enc_net(x)
        decoded = self.dec_net(self.mid_net(bottleneck), down_indices)
        return self.depth_head(decoded)

class SplitNet(nn.Module):
    """Hard-parameter-sharing baseline for segmentation + depth.

    One shared Encoder -> ``mid_layers`` ConvLayers -> Decoder trunk feeds two heads:

    - segmentation: two ConvLayers and a 1x1 conv to ``self.classes`` logits;
    - depth: two ConvLayers, a 1x1 conv to one channel, and a Sigmoid.

    ``forward`` returns ``(seg_logits, depth)``.

    Args:
        classes: number of semantic classes; the seg head outputs ``classes + 1`` logits
            (extra background class).
        mid_layers: number of bottleneck ConvLayers. The default of 4 reproduces the
            previously hard-coded trunk and mirrors SegNet/DepthNet's ``mid_layers`` option.
    """

    def __init__(self, classes=7, mid_layers=4):
        super().__init__()
        self.name = "splitnet"
        filter = [64, 128, 256, 512, 512]
        self.classes = classes + 1
        self.enc_net = Encoder(filter)
        self.mid_net = nn.Sequential(*[ConvLayer(filter[-1], filter[-1]) for _ in range(mid_layers)])
        self.dec_net = Decoder([filter[-(i+1)] for i in range(len(filter))])
        self.seg_head = nn.Sequential(
            ConvLayer(filter[0], filter[0]),
            ConvLayer(filter[0], filter[0]),
            nn.Conv2d(filter[0], self.classes, kernel_size=1)
        )
        self.depth_head = nn.Sequential(
            ConvLayer(filter[0], filter[0]),
            ConvLayer(filter[0], filter[0]),
            nn.Conv2d(filter[0], 1, kernel_size=1),
            nn.Sigmoid()
        )
        init_weights(self)

    def forward(self, x):
        logits, down_indices = self.enc_net(x)
        logits = self.mid_net(logits)
        logits = self.dec_net(logits, down_indices)
        logits_seg = self.seg_head(logits)
        logits_depth = self.depth_head(logits)
        return logits_seg, logits_depth

# CrossStitch model definition

In [13]:
class CrossStitchNet(nn.Module):
    """Two-column cross-stitch network: column 0 does segmentation, column 1 does depth.

    Each column is an encoder-decoder built from ConvLayers, DownSampleBlocks and
    UpSampleBlocks. After every down- and up-sampling block, the two columns' activations
    are linearly mixed with learnable weights ``alphas``, computed from the *pre-stitch*
    activations of both columns::

        seg'   = alphas[0][0] * seg   + alphas[0][1] * depth
        depth' = alphas[1][0] * depth + alphas[1][1] * seg

    One pair of alpha vectors is shared by all stitch points.
    - Column 0 ends in a ConvLayer and a 1x1 conv to ``self.classes`` logits.
    - Column 1 ends in a ConvLayer, a 1x1 conv to one channel, and a Sigmoid.

    Args:
        classes: number of semantic classes; the seg head has ``classes + 1`` outputs
            (extra background class).
        tasks: number of task columns; only 2 (segmentation, depth) is supported.

    Raises:
        ValueError: if ``tasks != 2``.
    """

    def __init__(self, classes=7, tasks=2):
        super().__init__()
        if tasks != 2:
            raise ValueError(f"CrossStitchNet supports exactly 2 tasks (segmentation, depth), got {tasks}")
        self.name = "crossstitch"
        filter = [64, 128, 256, 512, 512]
        self.classes = classes + 1
        self.tasks = tasks
        # alphas[t] = (weight of column t's own activation, weight of the other column's).
        self.alphas = nn.ParameterList([nn.Parameter(torch.rand(2)) for _ in range(2)])

        self.nets = nn.ModuleList([nn.ModuleList() for _ in range(tasks)])

        # Stem.
        for j in range(tasks):
            self.nets[j].append(ConvLayer(3, filter[0]))

        # Encoder: per stage two ConvLayers then a DownSampleBlock (stitch point).
        for i in range(len(filter) - 1):
            for j in range(tasks):
                self.nets[j].append(ConvLayer(filter[i], filter[i + 1]))
                self.nets[j].append(ConvLayer(filter[i + 1], filter[i + 1]))
                self.nets[j].append(DownSampleBlock(filter[i + 1], filter[i + 1]))

        # Decoder: per stage an UpSampleBlock (stitch point) then two ConvLayers.
        for i in range(len(filter) - 1):
            for j in range(tasks):
                self.nets[j].append(UpSampleBlock(filter[-(i + 1)], filter[-(i + 2)]))
                self.nets[j].append(ConvLayer(filter[-(i + 2)], filter[-(i + 2)]))
                self.nets[j].append(ConvLayer(filter[-(i + 2)], filter[-(i + 2)]))

        # Heads: segmentation logits (column 0) and sigmoid depth (column 1).
        heads = [
            nn.Conv2d(filter[0], self.classes, kernel_size=1),
            nn.Sequential(
                nn.Conv2d(filter[0], 1, kernel_size=1),
                nn.Sigmoid()
            ),
        ]
        for j in range(tasks):
            self.nets[j].append(ConvLayer(filter[0], filter[0]))
            self.nets[j].append(heads[j])

    def _stitch(self, seg, depth):
        """Cross-stitch the two columns; both outputs are computed from the pre-stitch inputs."""
        new_seg = self.alphas[0][0] * seg + self.alphas[0][1] * depth
        new_depth = self.alphas[1][0] * depth + self.alphas[1][1] * seg
        return new_seg, new_depth

    def forward(self, x):
        logits_seg = x
        logits_depth = x
        indices_seg = []
        indices_depth = []
        for mod_seg, mod_depth in zip(self.nets[0], self.nets[1]):
            if isinstance(mod_seg, DownSampleBlock):
                logits_seg, idx_seg, _ = mod_seg(logits_seg)
                logits_depth, idx_depth, _ = mod_depth(logits_depth)
                indices_seg.append(idx_seg)
                indices_depth.append(idx_depth)
                logits_seg, logits_depth = self._stitch(logits_seg, logits_depth)
            elif isinstance(mod_seg, UpSampleBlock):
                # Unpool with the matching encoder indices, deepest stage first.
                logits_seg, _ = mod_seg(logits_seg, indices_seg.pop())
                logits_depth, _ = mod_depth(logits_depth, indices_depth.pop())
                logits_seg, logits_depth = self._stitch(logits_seg, logits_depth)
            else:  # ConvLayer or task head
                logits_seg = mod_seg(logits_seg)
                logits_depth = mod_depth(logits_depth)
        return logits_seg, logits_depth
                

In [None]:
# Build each architecture once (same construction order as before, so RNG-driven
# initialisation is unchanged) and report its trainable parameter count.
segnet = SegNet()
depthnet = DepthNet()
mtan = MTAN()
splitnet = SplitNet()
crossstitch = CrossStitchNet()
densenet = DenseNet()

for _arch_name, _arch in (
    ("SegNet", segnet),
    ("DepthNet", depthnet),
    ("MTAN", mtan),
    ("DenseNet", densenet),
    ("CrossStitchNet", crossstitch),
    ("SplitNet", splitnet),
):
    print(f"{_arch_name} parameters: {count_params(_arch)}")

# Training Pipelines

In [15]:
class MeanAbsoluteRelativeError(Metric):
    """Mean per-sample absolute relative error.

    For each sample ``b`` (dim 0 of ``target``) the error is
    ``sum(|preds[b] - target[b]|) / sum(|target[b]|)``, with sums over all non-batch
    dimensions. ``compute()`` returns the average of these values over every sample passed
    to ``update`` since the last ``reset``.

    Numerator and counter are both per sample. Adding one batch-level ratio per update
    while counting samples would shrink the metric by roughly the batch size.
    """

    def __init__(self):
        super().__init__()
        # Running sum of per-sample relative errors, and number of samples seen.
        self.add_state("sum_rel_err", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_obs", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the relative error of every sample in the batch."""
        batch = target.shape[0]
        abs_err = torch.abs(preds - target).reshape(batch, -1).sum(dim=1)
        abs_target = torch.abs(target).reshape(batch, -1).sum(dim=1)
        self.sum_rel_err += torch.sum(abs_err / abs_target)
        self.num_obs += batch

    def compute(self):
        """Mean per-sample relative error (nan if nothing has been accumulated)."""
        return self.sum_rel_err / self.num_obs

In [16]:
def add_plt(plt, data):
    """Append the current value of every entry in ``data`` to the matching list in ``plt``.

    ``plt`` maps keys to history lists (inside this function the name shadows
    ``matplotlib.pyplot``). torchmetrics ``Metric`` entries are ``compute()``d and moved
    to CPU; any other value is appended unchanged.
    """
    for key, value in data.items():
        entry = value.compute().cpu() if isinstance(value, Metric) else value
        plt[key].append(entry)

def compute_lambdas(losses_seg, losses_depth, T, K):
    """Dynamic Weight Average task weights for (segmentation, depth).

    For each task, ``w = mean(history['new']) / mean(history['old'])`` measures how quickly
    its loss is changing. The weights are ``K * softmax([w_seg / T, w_depth / T])``, so they
    sum to ``K``; ``T`` is the softmax temperature (larger gives more uniform weights).

    Returns:
        1-D tensor ``[lambda_seg, lambda_depth]``.
    """
    ratios = [
        np.mean(history['new']) / np.mean(history['old'])
        for history in (losses_seg, losses_depth)
    ]
    scaled = torch.tensor([ratio / T for ratio in ratios])
    return F.softmax(scaled, dim=0) * K

def update_stats(stats, x, y):
    """Call ``update(x, y)`` on every metric held in the ``stats`` dict."""
    for metric in stats.values():
        metric.update(x, y)

def reset_stats(stats):
    """Call ``reset()`` on every metric held in the ``stats`` dict."""
    for metric in stats.values():
        metric.reset()

def update_losses(losses_seg, losses_depth):
    """Rotate both DWA loss windows in place: ``'new'`` becomes ``'old'``, then ``'new'`` is emptied."""
    histories = (losses_seg, losses_depth)
    for history in histories:
        history['old'] = history['new']
    for history in histories:
        history['new'] = []

# def save_fig_plots(model, epochs, plots=None):
#     plt.savefig(f"./models/{model.name}/{model.name}_train{epochs}.png")
#     torch.save(model.state_dict(), f"./models/{model.name}/{model.name}_train{epochs}.pth")
#     if plots is not None:
#         if not os.path.exists(f"./models/{model.name}/plots"):
#             os.makedirs(f"./models/{model.name}/plots")
#         for k in plots.keys():
#             torch.save(plots[k], f"./models/{model.name}/plots/{k}.pth")

def save_model_opt(model, opt, epochs):
    """Save the state dicts of ``model`` and ``opt`` under ``./models/{model.name}/``.

    Writes ``{name}_train{epochs}.pth`` (model) and ``{name}_opt_train{epochs}.pth``
    (optimizer). The directory is created if missing, so the helper also works when the
    caller has not created it.
    """
    save_dir = f"./models/{model.name}"
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), f"{save_dir}/{model.name}_train{epochs}.pth")
    torch.save(opt.state_dict(), f"{save_dir}/{model.name}_opt_train{epochs}.pth")

def compute_loss_multitask(model, x, y_seg, y_dis, stats_seg, stats_depth):
    """Run one batch through a two-headed model and return ``(loss_seg, loss_depth)``.

    - segmentation loss: cross-entropy on the first model output, ignoring label ``-1``;
    - depth loss: L1 between the second output with dim 1 squeezed and ``y_dis``.

    Inputs are moved to ``DEVICE``: ``x`` and ``y_dis`` as float, ``y_seg`` as long.
    Side effects:
    - ``stats_seg`` is updated with argmax predictions vs labels, restricted to pixels whose
      label is not ``-1`` (both flattened to shape ``(1, N)``);
    - ``stats_depth`` is updated with the squeezed depth output vs ``y_dis``.
    """
    x = x.to(device=DEVICE, dtype=torch.float)
    y_seg = y_seg.to(device=DEVICE, dtype=torch.long)
    y_dis = y_dis.to(device=DEVICE, dtype=torch.float)

    output_seg, output_depth = model(x)
    depth_pred = output_depth.squeeze(1)

    loss_seg = nn.CrossEntropyLoss(ignore_index=-1)(output_seg, y_seg)
    loss_depth = nn.L1Loss()(depth_pred, y_dis)

    # Segmentation metrics only see labelled pixels (label -1 = ignore).
    labelled = y_seg != -1
    preds_labelled = output_seg.argmax(dim=1)[labelled].unsqueeze(0)
    targets_labelled = y_seg[labelled].unsqueeze(0)

    update_stats(stats_seg, preds_labelled, targets_labelled)
    update_stats(stats_depth, depth_pred, y_dis)
    return loss_seg, loss_depth

def compute_grad(model):
    """Return the global L2 norm of the current gradients of ``model``.

    Only parameters that require grad and have a populated ``.grad`` contribute. Returns
    ``0.0`` when none do. Per-parameter norms are read back with ``.item()`` and their
    squares are summed as Python floats.
    """
    squared_norms = (
        p.grad.detach().data.norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None and p.requires_grad
    )
    return sum(squared_norms) ** 0.5

In [17]:
def train_multitask_dwa(model, opt, train_dl, val_dl=None, epochs=10, update_lambdas=10, T=2, save=False, check=5, grad=False):
    """Train a segmentation + depth model with Dynamic Weight Average (DWA) loss weighting.

    Each step minimises ``lambdas[0] * loss_seg + lambdas[1] * loss_depth``. Both weights
    start at 1. They are first recomputed by ``compute_lambdas`` after
    ``2 * update_lambdas`` batches, then every ``update_lambdas`` batches. Each update uses
    the ratio of the latest window's mean per-batch losses to the previous window's, with
    temperature ``T``. The weights always sum to the number of tasks (2).

    Args:
        model: module returning ``(seg_logits, depth)``; must expose ``name`` and
            ``classes`` (the number of segmentation outputs, used to size the metrics).
        opt: optimizer over ``model``'s parameters.
        train_dl: loader yielding ``(image, seg_label, depth)`` batches.
        val_dl: optional validation loader; evaluated every ``check`` epochs.
        epochs: number of training epochs.
        update_lambdas: DWA window length, in batches.
        T: DWA softmax temperature.
        save: if true, write a checkpoint every ``check`` epochs plus final plots and
            weights under ``./models/{model.name}/``.
        check: interval in epochs for printing, checkpointing and validation.
        grad: if true, also record the gradient norm of the last batch of each epoch.

    Losses and metrics are logged to TensorBoard under ``./runs/{model.name}``.
    """
    model = model.to(DEVICE)
    writer = SummaryWriter(f'./runs/{model.name}')

    # DWA task weights. compute_lambdas rescales its softmax to sum to K, so K must be
    # the number of tasks (matching the initial [1, 1]), not the number of classes.
    lambdas = np.array([1, 1])
    num_tasks = len(lambdas)
    losses_seg = {'new': [], 'old': []}
    losses_depth = {'new': [], 'old': []}
    plt_losses_train = {'seg': [], 'depth': [], 'total': []}
    plt_stats_train = {'miou': [], 'pix_acc': [], 'mae': [], 'mre': []}
    plt_lambdas = {'lambda0': [], 'lambda1': []}
    plt_grad = []

    miou = MeanIoU(num_classes=model.classes, per_class=False, include_background=False, input_format='index').to(DEVICE)
    pix_acc = MulticlassAccuracy(num_classes=model.classes, multidim_average='global', average='micro').to(DEVICE)
    stats_seg = {'miou': miou, 'pix_acc': pix_acc}
    mae = MeanAbsoluteError().to(DEVICE)
    mre = MeanAbsoluteRelativeError().to(DEVICE)
    stats_depth = {'mae': mae, 'mre': mre}
    if val_dl is not None:
        plt_losses_val = {'seg': [], 'depth': [], 'total': []}
        plt_stats_val = {'miou': [], 'pix_acc': [], 'mae': [], 'mre': []}

    if save:
        os.makedirs(f"./models/{model.name}", exist_ok=True)

    for epoch in range(epochs):
        model.train()

        reset_stats(stats_seg)
        reset_stats(stats_depth)

        total_loss = 0
        total_loss_seg = 0
        total_loss_depth = 0
        for x, y_seg, y_dis in tqdm(train_dl):
            opt.zero_grad()
            loss_seg, loss_depth = compute_loss_multitask(model, x, y_seg, y_dis, stats_seg, stats_depth)
            loss = lambdas[0]*loss_seg + lambdas[1]*loss_depth
            loss.backward()
            opt.step()

            losses_seg['new'].append(loss_seg.item())
            losses_depth['new'].append(loss_depth.item())
            # First DWA update: split the first 2*update_lambdas batch losses into an
            # "old" and a "new" window.
            if len(losses_seg['new']) == 2*update_lambdas and len(losses_seg['old']) == 0:
                losses_seg['old'] = losses_seg['new'][0:update_lambdas]
                losses_depth['old'] = losses_depth['new'][0:update_lambdas]
                losses_seg['new'] = losses_seg['new'][update_lambdas:]
                losses_depth['new'] = losses_depth['new'][update_lambdas:]

                lambdas = compute_lambdas(losses_seg, losses_depth, T, num_tasks)
                update_losses(losses_seg, losses_depth)

            # Subsequent DWA updates: every update_lambdas batches.
            if len(losses_seg['new']) == update_lambdas and len(losses_seg['old']) == update_lambdas:
                lambdas = compute_lambdas(losses_seg, losses_depth, T, num_tasks)
                update_losses(losses_seg, losses_depth)

            total_loss += loss.item()
            total_loss_seg += loss_seg.item()
            total_loss_depth += loss_depth.item()

        plt_lambdas['lambda0'].append(lambdas[0].item())
        plt_lambdas['lambda1'].append(lambdas[1].item())
        total_loss /= len(train_dl)
        total_loss_seg /= len(train_dl)
        total_loss_depth /= len(train_dl)
        plt_losses_train['seg'].append(total_loss_seg)
        plt_losses_train['depth'].append(total_loss_depth)
        plt_losses_train['total'].append(total_loss)
        print_stats = dict(stats_seg, **stats_depth)
        add_plt(plt_stats_train, print_stats)
        if grad:
            # Norm of the gradients left over from the epoch's last batch.
            grad_norm = compute_grad(model)
            plt_grad.append(grad_norm)
            writer.add_scalar('Train/Gradient', grad_norm, epoch)
        if epoch % check == 0:
            print(f"Epoch {epoch}/{epochs} - Train Total Loss: {total_loss:.4f}")
            print(f"Lambda_0: {lambdas[0]} - Train Loss Segmentation: {total_loss_seg:.4f}")
            print(f"Lambda_1: {lambdas[1]} - Train Loss Depth: {total_loss_depth:.4f}")
            for k in print_stats.keys():
                print(f"{k}: {print_stats[k].compute().cpu()}")
            if grad:
                print(f"Gradient Norm: {grad_norm}\n")
            else:
                print("\n")
            if save:
                save_model_opt(model, opt, epoch)
        writer.add_scalar('Train/Loss/Total', total_loss, epoch)
        writer.add_scalar('Train/Loss/Segmentation', total_loss_seg, epoch)
        writer.add_scalar('Train/Loss/Depth', total_loss_depth, epoch)
        for k in print_stats.keys():
            writer.add_scalar(f'Train/{k}', print_stats[k].compute().cpu(), epoch)

        if val_dl is not None and epoch % check == 0:
            losses_tmp, stats_tmp = val_epoch_multitask(model, val_dl, writer, epoch)
            add_plt(plt_losses_val, losses_tmp)
            add_plt(plt_stats_val, stats_tmp)

    # Training curves (one extra row for the gradient norm when tracked).
    n_rows, fig_side = (5, 50) if grad else (4, 40)
    _, ax = plt.subplots(n_rows, 2, figsize=(fig_side, fig_side))
    ax[0][0].plot(plt_losses_train['seg'])
    ax[0][0].set_title('Segmentation Loss')
    ax[0][1].plot(plt_losses_train['depth'])
    ax[0][1].set_title('Depth Loss')
    ax[1][0].plot(plt_lambdas['lambda0'])
    ax[1][0].plot(plt_lambdas['lambda1'])
    ax[1][0].set_title('Lambdas')
    ax[1][1].plot(plt_losses_train['total'])
    ax[1][1].set_title('Total Loss')
    ax[2][0].plot(plt_stats_train['miou'])
    ax[2][0].set_title('Mean IoU')
    ax[2][1].plot(plt_stats_train['pix_acc'])
    ax[2][1].set_title('Pixel Accuracy')
    ax[3][0].plot(plt_stats_train['mae'])
    ax[3][0].set_title('Mean Absolute Error')
    ax[3][1].plot(plt_stats_train['mre'])
    ax[3][1].set_title('Mean Absolute Relative Error')
    if grad:
        ax[4][0].plot(plt_grad)
        ax[4][0].set_title('Gradient Norm')
    if save:
        plt.savefig(f"./models/{model.name}/{model.name}_train{epochs}.png")
        torch.save(model.state_dict(), f"./models/{model.name}/{model.name}_train{epochs}.pth")

    # Validation loss curves (one point per validated epoch).
    if val_dl is not None:
        _, ax = plt.subplots(3, 1, figsize=(20, 20))
        ax[0].plot(plt_losses_val['seg'])
        ax[0].set_title('Segmentation Loss')
        ax[1].plot(plt_losses_val['depth'])
        ax[1].set_title('Depth Loss')
        ax[2].plot(plt_losses_val['total'])
        ax[2].set_title('Total Loss')
        if save:
            plt.savefig(f"./models/{model.name}/{model.name}_val{epochs}.png")

def val_epoch_multitask(model, val_dl, writer, epoch):
    """Evaluate a segmentation + depth multitask model on ``val_dl``.

    Runs under ``torch.no_grad()`` with the model in eval mode. Fresh metric
    objects are created for this call and updated batch by batch through
    ``compute_loss_multitask``. Mean per-batch losses and the final metric
    values are printed and written to ``writer`` under the ``Val/`` prefix at
    step ``epoch``.

    Returns:
        (losses, stats_comp):
            losses: dict with keys 'total', 'seg', 'depth' holding the mean
                per-batch loss (total = seg + depth).
            stats_comp: dict with keys 'miou', 'pix_acc', 'mae', 'mre' holding
                the updated metric objects; callers call ``.compute()`` on them.
    """
    with torch.no_grad():
        model.eval()

        seg_metrics = {
            'miou': MeanIoU(num_classes=model.classes, per_class=False, include_background=False, input_format='index').to(DEVICE),
            'pix_acc': MulticlassAccuracy(num_classes=model.classes, multidim_average='global', average='micro').to(DEVICE),
        }
        depth_metrics = {
            'mae': MeanAbsoluteError().to(DEVICE),
            'mre': MeanAbsoluteRelativeError().to(DEVICE),
        }

        sum_total, sum_seg, sum_depth = 0, 0, 0
        for images, seg_targets, depth_targets in tqdm(val_dl):
            seg_loss, depth_loss = compute_loss_multitask(
                model, images, seg_targets, depth_targets, seg_metrics, depth_metrics
            )
            # The combined loss is summed as tensors before converting to a Python float.
            sum_total += (seg_loss + depth_loss).item()
            sum_seg += seg_loss.item()
            sum_depth += depth_loss.item()

        n_batches = len(val_dl)
        mean_total = sum_total / n_batches
        mean_seg = sum_seg / n_batches
        mean_depth = sum_depth / n_batches

        writer.add_scalar('Val/Loss/Total', mean_total, epoch)
        writer.add_scalar('Val/Loss/Segmentation', mean_seg, epoch)
        writer.add_scalar('Val/Loss/Depth', mean_depth, epoch)
        print(f"Val Total Loss: {mean_total:.4f}")
        print(f"Val Loss Segmentation: {mean_seg:.4f}")
        print(f"Val Loss Depth: {mean_depth:.4f}")
        losses = {'total': mean_total, 'seg': mean_seg, 'depth': mean_depth}

        # Segmentation metrics first, then depth metrics (insertion order matters for printing).
        stats_comp = {**seg_metrics, **depth_metrics}
        for metric_name, metric in stats_comp.items():
            print(f"{metric_name}: {metric.compute().cpu()}")
            writer.add_scalar(f'Val/{metric_name}', metric.compute().cpu(), epoch)
        print("\n")
        return losses, stats_comp

In [None]:
# Multitask (segmentation + depth) MTAN run trained with train_multitask_dwa.
model = MTAN()
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
nparams = count_params(model)
print(f"Number of trainable parameters: {nparams}")
train_multitask_dwa(
    model,
    opt,
    train_dl,
    val_dl,
    epochs=10,
    update_lambdas=10,
    T=2,  # NOTE(review): presumably the DWA weighting temperature — confirm in train_multitask_dwa
    save=True,
    check=2,
    grad=True,
)

In [112]:
def compute_loss_singletask(model, x, y, loss_fn, stats):
    """Forward one batch, return its loss and update ``stats``.

    The task is inferred from ``loss_fn``:
      * ``nn.CrossEntropyLoss``: predictions are the argmax over dim 1. The
        metrics only see pixels whose target is not -1, flattened and given a
        leading dim of size 1.
      * anything else: dim 1 of the output is squeezed and compared directly
        with ``y`` (regression / depth).

    Both ``x`` and ``y`` are moved to ``DEVICE`` first.
    """
    inputs = x.to(DEVICE)
    targets = y.to(DEVICE)
    logits = model(inputs)

    if not isinstance(loss_fn, nn.CrossEntropyLoss):
        depth_pred = logits.squeeze(1)
        loss = loss_fn(depth_pred, targets)
        update_stats(stats, depth_pred, targets)
        return loss

    loss = loss_fn(logits, targets)
    flat_pred = torch.argmax(logits, dim=1).view(-1)
    flat_target = targets.view(-1)
    # Drop ignored pixels (label -1) so they don't affect the metrics.
    keep = torch.where(flat_target != -1)[0]
    update_stats(stats, flat_pred[keep].unsqueeze(0), flat_target[keep].unsqueeze(0))
    return loss

def train_singletask(model, opt, train_dl, loss_fn, val_dl=None, epochs=10, save=False, check=5, grad=False):
    """Train a single-task model (segmentation or depth) and plot its curves.

    The task is selected by ``loss_fn``:
      * ``nn.CrossEntropyLoss``: trains on ``y_seg`` (cast to long, dim 1
        squeezed) and tracks mIoU and pixel accuracy.
      * any other loss: trains on ``y_dis`` (cast to float) and tracks MAE
        and mean absolute relative error.

    Args:
        model: network exposing ``name`` (used for log/save paths) and, for
            segmentation, ``classes``.
        opt: optimizer over ``model``'s parameters.
        train_dl: loader yielding ``(x, y_seg, y_dis)`` batches.
        loss_fn: loss module; its type selects the task.
        val_dl: optional validation loader, evaluated every ``check`` epochs.
        epochs: number of training epochs.
        save: if True, call ``save_model_opt`` every ``check`` epochs and save
            the final weights and figures under ``./models/{model.name}/``.
        check: interval (in epochs) for printing, checkpointing and validation.
        grad: if True, record ``compute_grad(model)`` after each epoch
            (presumably the gradient norm of the last batch — see compute_grad).
    """
    model = model.to(DEVICE)
    is_seg = isinstance(loss_fn, nn.CrossEntropyLoss)
    writer = SummaryWriter(f'./runs/{model.name}')
    try:
        plt_loss_train = []
        plt_grad = []
        if is_seg:
            miou = MeanIoU(num_classes=model.classes, per_class=False, include_background=False, input_format='index').to(DEVICE)
            pix_acc = MulticlassAccuracy(num_classes=model.classes, multidim_average='global', average='micro').to(DEVICE)
            stats = {'miou': miou, 'pix_acc': pix_acc}
        else:
            mae = MeanAbsoluteError().to(DEVICE)
            mre = MeanAbsoluteRelativeError().to(DEVICE)
            stats = {'mae': mae, 'mre': mre}
        stats_str = list(stats.keys())
        plt_stats_train = {k: [] for k in stats_str}
        if val_dl is not None:
            plt_loss_val = []
            plt_stats_val = {k: [] for k in stats_str}

        model_dir = f"./models/{model.name}"
        if save:
            os.makedirs(model_dir, exist_ok=True)

        for epoch in range(epochs):
            model.train()
            reset_stats(stats)

            total_loss = 0
            for x, y_seg, y_dis in tqdm(train_dl):
                x = x.to(torch.float)
                y = y_seg.to(torch.long).squeeze(dim=1) if is_seg else y_dis.to(torch.float)

                opt.zero_grad()
                loss = compute_loss_singletask(model, x, y, loss_fn, stats)
                loss.backward()
                opt.step()

                total_loss += loss.item()
            total_loss /= len(train_dl)
            writer.add_scalar('Train/Loss/Segmentation' if is_seg else 'Train/Loss/Depth', total_loss, epoch)
            plt_loss_train.append(total_loss)
            add_plt(plt_stats_train, stats)

            grad_norm = None
            if grad:
                grad_norm = compute_grad(model)
                plt_grad.append(grad_norm)
                writer.add_scalar('Train/Gradient', grad_norm, epoch)

            if epoch % check == 0:
                print(f"Epoch {epoch}/{epochs} - Train Loss: {total_loss:.4f}")
                for k in stats.keys():
                    print(f"{k}: {stats[k].compute().cpu()}")
                # grad_norm is only computed when grad=True; printing it otherwise raised NameError.
                if grad:
                    print(f"Gradient Norm: {grad_norm}")
                print()
                if save:
                    save_model_opt(model, opt, epoch)

            if val_dl is not None and epoch % check == 0:
                val_loss, val_stats = val_epoch_singletask(model, val_dl, loss_fn, writer, epoch)
                plt_loss_val.append(val_loss)
                for k, metric in val_stats.items():
                    plt_stats_val[k].append(metric.compute().cpu())

        _, ax = plt.subplots(2, 2, figsize=(40, 40))
        ax[0][0].plot(plt_loss_train)
        ax[0][0].set_title('Loss')
        ax[0][1].plot(plt_stats_train[stats_str[0]])
        ax[0][1].set_title(stats_str[0])
        ax[1][0].plot(plt_stats_train[stats_str[1]])
        ax[1][0].set_title(stats_str[1])
        if grad:
            ax[1][1].plot(plt_grad)
            ax[1][1].set_title('Gradient Norm')
        if save:
            plt.savefig(f"{model_dir}/{model.name}_train{epochs}.png")
            torch.save(model.state_dict(), f"{model_dir}/{model.name}_train{epochs}.pth")

        if val_dl is not None:
            _, ax = plt.subplots(3, 1, figsize=(20, 20))
            ax[0].plot(plt_loss_val)
            ax[0].set_title('Loss')
            for i, k in enumerate(stats_str, start=1):
                ax[i].plot(plt_stats_val[k])
                ax[i].set_title(k)
            if save:
                # Saved inside the model directory, next to the train figure and weights.
                plt.savefig(f"{model_dir}/{model.name}_val{epochs}.png")
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()


def val_epoch_singletask(model, dl, loss_fn, writer, epoch):
    """Evaluate a single-task model on ``dl`` without gradient tracking.

    The task is selected by ``loss_fn`` in the same way as ``train_singletask``:
    ``nn.CrossEntropyLoss`` evaluates segmentation (mIoU, pixel accuracy);
    any other loss evaluates depth (MAE, mean absolute relative error).
    The mean per-batch loss and each metric are printed and logged to
    ``writer`` under the ``Test/`` prefix at step ``epoch``.

    Returns:
        (total_loss, stats): mean per-batch loss, and a dict of the updated
        metric objects (callers call ``.compute()`` on them).
    """
    is_seg = isinstance(loss_fn, nn.CrossEntropyLoss)
    with torch.no_grad():
        model.eval()
        total_loss = 0

        if is_seg:
            miou = MeanIoU(num_classes=model.classes, per_class=False, include_background=False, input_format='index').to(DEVICE)
            pix_acc = MulticlassAccuracy(num_classes=model.classes, multidim_average='global', average='micro').to(DEVICE)
            stats = {'miou': miou, 'pix_acc': pix_acc}
        else:
            mae = MeanAbsoluteError().to(DEVICE)
            mre = MeanAbsoluteRelativeError().to(DEVICE)
            stats = {'mae': mae, 'mre': mre}
        for x, y_seg, y_dis in tqdm(dl):
            x = x.to(torch.float)
            y = y_seg.to(torch.long).squeeze(dim=1) if is_seg else y_dis.to(torch.float)
            loss = compute_loss_singletask(model, x, y, loss_fn, stats)
            total_loss += loss.item()
        total_loss /= len(dl)
        # Evaluation losses go under 'Test/' for both tasks so they never mix into the 'Train/' curves.
        writer_string = 'Test/Loss/Segmentation' if is_seg else 'Test/Loss/Depth'
        writer.add_scalar(writer_string, total_loss, epoch)
        print("Test Loss: ", total_loss)
        for k, metric in stats.items():
            value = metric.compute().cpu()
            print(f"{k}: {value}")
            writer.add_scalar(f'Test/{k}', value, epoch)
        print("\n")
    return total_loss, stats

In [None]:
# Single-task SegNet baseline: segmentation only; pixels labelled -1 are ignored by the loss.
segnet = SegNet()
opt = torch.optim.Adam(segnet.parameters(), lr=LEARNING_RATE)
nparams = sum(param.numel() for param in segnet.parameters() if param.requires_grad)
print(f"Number of trainable parameters: {nparams}")
train_singletask(
    segnet,
    opt,
    train_dl,
    nn.CrossEntropyLoss(ignore_index=-1),
    val_dl,
    epochs=EPOCHS,
    save=True,
    check=5,
    grad=True,
)

In [97]:
def masked_pixel_accuracy(pred, target, ignore_index=-1):
    """Fraction of non-ignored pixels where ``pred`` equals ``target``.

    ``pred`` and ``target`` are integer label tensors whose shapes broadcast
    against each other (e.g. ``(H, W)`` vs ``(1, H, W)``). Pixels where
    ``target == ignore_index`` are excluded from both the numerator and the
    denominator. Returns NaN when no pixel is valid.
    """
    valid = target != ignore_index
    correct = (pred == target) & valid
    n_valid = int(valid.sum())
    if n_valid == 0:
        return float('nan')
    return int(correct.sum()) / n_valid


def visualize_results_multitask(model, img, img_seg, img_dis, save=False):
    """Show one sample's input, plus ground-truth vs predicted segmentation and depth.

    Args:
        model: multitask network returning ``(seg_logits, depth)`` for a batch;
            must expose ``name`` when ``save`` is True.
        img: single channel-first input image tensor (a batch dim is added here).
        img_seg: ground-truth label map; pixels equal to -1 are ignored when
            scoring and drawn as class 0. The caller's tensor is not modified.
        img_dis: ground-truth depth map, displayed in grayscale.
        save: if True, also save the 2x2 comparison figure to
            ``./models/{model.name}/{model.name}_results.png``.
    """
    with torch.no_grad():
        model = model.to(DEVICE)
        model.eval()
        img = img.to(DEVICE).to(torch.float)
        output_seg, output_dis = model(img.unsqueeze(0))
        pred_seg = torch.argmax(output_seg, dim=1).squeeze(0).cpu()
        pred_dis = output_dis.squeeze(0, 1).cpu().numpy()

    target_seg = img_seg.cpu()
    # Score only valid pixels: overwriting -1 with 0 before comparing would count
    # ignored pixels predicted as class 0 as correct and inflate the accuracy.
    print(f"Accuracy: {masked_pixel_accuracy(pred_seg, target_seg)}")

    # Work on a copy so the caller's label tensor (e.g. a dataloader batch) is untouched.
    seg_display = target_seg.clone()
    seg_display[seg_display == -1] = 0

    _, img_ax = plt.subplots()
    img_ax.imshow(img.cpu().permute(1, 2, 0))

    fig, ax = plt.subplots(2, 2, figsize=(10, 8))
    ax[0][0].imshow(seg_display.numpy())
    ax[0][0].set_title('Ground Truth Segmentation')
    ax[0][1].imshow(pred_seg.numpy())
    ax[0][1].set_title('Predicted Segmentation')
    ax[1][0].imshow(img_dis, cmap='gray')
    ax[1][0].set_title('Ground Truth Depth')
    ax[1][1].imshow(pred_dis, cmap='gray')
    ax[1][1].set_title('Predicted Depth')
    if save:
        os.makedirs(f"./models/{model.name}", exist_ok=True)
        # Save through the Figure handle before plt.show(): calling plt.savefig()
        # after show() targets a new, empty current figure.
        fig.savefig(f"./models/{model.name}/{model.name}_results.png")
    plt.show()

In [None]:
# for i, (img, img_seg, img_dis) in enumerate(val_dl):
#     visualize_results_multitask(model, img[1], img_seg[1], img_dis[1])
#     if i == 5:
#         break

In [23]:
# from matplotlib import colors

# cmap = colors.ListedColormap(['k','b','y','g','r'])
# rm = np.random.randint(0,5,(5,5))
# print(rm)
# plt.imshow(rm, interpolation='nearest', cmap=cmap)
# plt.tight_layout()
# plt.show()