In [None]:
import json
import os
import pprint
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.models.segmentation.deeplabv3 as dlv3
import torchvision.transforms.functional as tf
from PIL import Image
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Subset
from torchsummary import summary
from tqdm import tqdm

import utils

## Data utilities

In [None]:
class SemanticLabelMapper():
    """Translate dataset-specific semantic label ids into the shared 'common' id space.

    The instance is configured with the name of one of the tables in
    ``MAPPING`` (e.g. ``'carla_to_common'``). Each table is a list whose
    position is the source label id and whose value is the common label id.
    """

    # Common label id -> human-readable class name.
    ID_TO_STRING = {
        'common': {
            0: 'road',
            1: 'sidewalk',
            2: 'building',
            3: 'wall',
            4: 'fence',
            5: 'pole',
            6: 'trafficlight',
            7: 'trafficsign',
            8: 'vegetation',
            9: 'terrain',
            10: 'sky',
            11: 'pedestrian',
            12: 'rider',
            13: 'car',
            14: 'truck',
            15: 'bus',
            16: 'train',
            17: 'motorcycle',
            18: 'bicycle',
            19: 'any'
        }
    }

    # Common label id -> three-component color (presumably RGB; confirm with the plotting code).
    ID_TO_COLOR = {
        'common': {
            0: [70, 70, 70],
            1: [100, 40, 40],
            2: [55, 90, 80],
            3: [220, 20, 60],
            4: [153, 153, 153],
            5: [157, 234, 50],
            6: [128, 64, 128],
            7: [244, 35, 232],
            8: [107, 142, 35],
            9: [0, 0, 142],
            10: [102, 102, 156],
            11: [220, 220, 0],
            12: [70, 130, 180],
            13: [81, 0, 81],
            14: [150, 100, 100],
            15: [230, 150, 140],
            16: [180, 165, 180],
            17: [250, 170, 30],
            18: [110, 190, 160],
            19: [145, 170, 100]
        }
    }

    # Lookup tables: list index = source label id, value = common label id.
    MAPPING = {
        'carla_to_common': [
            19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 19, 19, 19, 0, 19, 19, 19, 19
        ],
        'cityscapes_to_common': [
            19, 19, 19, 19, 19, 19, 19, 0, 1, 19, 19, 2, 3, 4, 19, 19, 19, 5, 19, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 19, 19, 16, 17, 18, 19
        ]
    }

    def __init__(self, type=None) -> None:
        """Store the name of the ``MAPPING`` table to use (a key of ``MAPPING``)."""
        super().__init__()
        self.type = type

    def __map_value(self, pixel):
        """Map a single source label id to its common label id."""
        return SemanticLabelMapper.MAPPING[self.type][pixel]

    def map_image(self, input):
        """Map every label id in ``input`` to the common id space.

        Args:
            input: integer array-like of source label ids, any shape.

        Returns:
            ``numpy.ndarray`` of dtype ``np.int_`` with the same shape as ``input``.

        Raises:
            KeyError: if ``self.type`` is not a key of ``MAPPING``.
            IndexError: if an id lies outside the selected table.
        """
        # One vectorised table lookup instead of a Python call per pixel
        # (np.vectorize); this also works for empty arrays, which
        # np.vectorize rejects.
        lookup = np.asarray(SemanticLabelMapper.MAPPING[self.type], dtype=np.int_)
        return lookup[np.asarray(input)]

    def map_from_dir(self, src_path, dst_path, extension):
        """Map every ``extension`` file in ``src_path`` and save it under the same name in ``dst_path``.

        Args:
            src_path: directory containing the source label images.
            dst_path: directory receiving the mapped images (created if missing).
            extension: filename suffix selecting the files to convert.
        """
        if dst_path:
            os.makedirs(dst_path, exist_ok=True)
        for file in tqdm(os.listdir(src_path)):
            if not file.endswith(extension):
                continue
            src_image_path = os.path.join(src_path, file)
            dst_image_path = os.path.join(dst_path, file)
            # The context manager releases the file handle as soon as the pixels are copied.
            with Image.open(src_image_path) as src_file:
                src_image = np.array(src_file)
            dst_image = self.map_image(src_image)
            # Saved as a single-channel 8-bit image.
            dst_image = Image.fromarray(np.uint8(dst_image), 'L')
            dst_image.save(dst_image_path)

class HybridDataset(Dataset):
    """Paired RGB image / semantic-segmentation dataset read from two sibling folders.

    Sample ``index`` is loaded from
    ``<root_path>/<input_dir>/<type>_rgb_<index>.png`` and
    ``<root_path>/<target_dir>/<type>_semantic_segmentation_<index>.png``.

    Args:
        root_path: directory containing ``input_dir`` and ``target_dir``.
        input_dir: sub-directory with the RGB images.
        target_dir: sub-directory with the label images.
        transform: optional callable applied to the input tensor and, unless
            ``target_transform`` is given, to the target tensor as well.
        type: filename prefix of the samples.
        labels_mapping: optional ``SemanticLabelMapper.MAPPING`` key; when set,
            target ids are remapped on load.
        target_transform: optional callable applied to the target tensor
            instead of ``transform``. Use it when ``transform`` interpolates
            (e.g. a bilinear resize), since interpolation can produce label
            ids that are not present in the original mask.
    """

    def __init__(self, root_path, input_dir, target_dir, transform=None, type='real', labels_mapping=None,
                 target_transform=None) -> None:
        super(HybridDataset, self).__init__()
        self.root_path = root_path
        self.input_data = input_dir
        self.target_data = target_dir
        self.transform = transform
        self.type = type
        self.labels_mapping = labels_mapping
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples.

        Raises:
            ValueError: if the input and target folders hold a different
                number of files (previously this returned ``None``, which
                made ``len()`` fail with an unrelated ``TypeError``).
        """
        input_length = len(os.listdir(os.path.join(self.root_path, self.input_data)))
        target_length = len(os.listdir(os.path.join(self.root_path, self.target_data)))
        if target_length != input_length:
            raise ValueError(
                f'Input and target folders differ in size: {input_length} file(s) in '
                f'"{self.input_data}" vs {target_length} file(s) in "{self.target_data}".')
        return target_length

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(input_tensor, target_tensor)``."""
        img_path_input_patch = os.path.join(self.root_path, self.input_data, f"{self.type}_rgb_{index}.png")
        img_path_tgt_patch = os.path.join(self.root_path, self.target_data, f"{self.type}_semantic_segmentation_{index}.png")

        # Context managers close the files once the pixel data has been copied.
        with Image.open(img_path_input_patch) as input_image:
            ipt_patch = np.array(input_image).astype(np.float32)
        with Image.open(img_path_tgt_patch) as target_image:
            tgt_patch = np.array(target_image).astype(np.int_)

        if self.labels_mapping is not None:
            try:
                semantic_label_mapper = SemanticLabelMapper(self.labels_mapping)
                tgt_patch = semantic_label_mapper.map_image(tgt_patch)
            except Exception as e:
                # Chain the original error so its traceback is not lost.
                raise Exception(f'Could not perform label mapping!\n {e}') from e

        # to_tensor adds the channel axis to 2-D arrays itself, so the former
        # discarded np.expand_dims call is not needed. It rescales only uint8
        # data, so the float32 input keeps its original value range.
        ipt_patch_tensor = tf.to_tensor(ipt_patch)
        tgt_patch_tensor = tf.to_tensor(tgt_patch)

        # Fall back to the shared transform when no dedicated one was given.
        target_transform = self.target_transform if self.target_transform is not None else self.transform
        if self.transform:
            ipt_patch_tensor = self.transform(ipt_patch_tensor)
        if target_transform:
            tgt_patch_tensor = target_transform(tgt_patch_tensor)

        return ipt_patch_tensor, tgt_patch_tensor

## Function utilities

In [None]:
def iou(predictions, targets, num_classes, smooth=sys.float_info.epsilon):
    """Compute the intersection-over-union of every class.

    Args:
        predictions: predicted class ids; flattened before use.
        targets: ground-truth class ids; flattened before use.
        num_classes: number of classes; ids ``0 .. num_classes - 1`` are scored.
        smooth: added to numerator and denominator to avoid 0/0. With the
            default, a class absent from both arrays scores 1.0.

    Returns:
        List with one IoU value per class, ordered by class id.
    """
    matrix = confusion_matrix(targets.flatten(), predictions.flatten(), labels=range(num_classes))
    # The diagonal holds the pixels on which prediction and target agree.
    hits = np.diag(matrix)
    # Row total + column total counts the agreeing pixels twice.
    union = matrix.sum(axis=1) + matrix.sum(axis=0) - hits
    return list((hits + smooth) / (union + smooth))

def add_weight_decay(net, l2_value, skip_list=()):
    """Build two optimizer parameter groups: one without and one with weight decay.

    Parameters that do not require gradients are left out. A trainable
    parameter gets no decay when it is one-dimensional, when its name ends
    in ``.bias``, or when its name is listed in ``skip_list``.

    Args:
        net: module whose ``named_parameters()`` are grouped.
        l2_value: weight-decay value of the decayed group.
        skip_list: parameter names that must not be decayed.

    Returns:
        ``[no_decay_group, decay_group]``, each a dict with ``'params'`` and
        ``'weight_decay'`` keys.
    """
    exempt_params = []
    decayed_params = []
    for name, param in net.named_parameters():
        if not param.requires_grad:
            # Frozen weights are not handed to the optimizer.
            continue
        is_exempt = len(param.shape) == 1 or name.endswith(".bias") or name in skip_list
        bucket = exempt_params if is_exempt else decayed_params
        bucket.append(param)
    return [
        {'params': exempt_params, 'weight_decay': 0.},
        {'params': decayed_params, 'weight_decay': l2_value},
    ]


## Training-related functions

In [None]:
def validate_epoch(model, val_dataloader, loss_function):
    """Evaluate ``model`` on every validation batch without gradient tracking.

    Args:
        model: module whose forward pass returns a mapping with an ``'out'``
            entry holding per-class scores with the class axis at dim 1.
        val_dataloader: iterable of ``(inputs, targets)`` batches; the targets
            carry a singleton axis at dim 1 that is squeezed away.
        loss_function: callable ``(scores, targets) -> scalar tensor``.

    Returns:
        Dict with the batch-averaged ``'miou'`` and ``'loss'``. Both are NaN
        when the loader yields no batch.
    """
    model.eval()
    # Run on whatever device the model lives on instead of forcing CUDA,
    # so the CPU fallback chosen elsewhere in the notebook keeps working.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    metrics = {'miou': 0, 'loss': 0}
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in val_dataloader:
            ipts = inputs.to(device)
            tgts = targets.to(device)

            scores = model(ipts)['out']
            loss = loss_function(scores, tgts.squeeze(1).long()).item()

            # Take the class count from the model output rather than a hard-coded 20.
            num_classes = scores.shape[1]
            preds = torch.argmax(scores, dim=1).cpu().numpy()
            labels = torch.squeeze(targets, dim=1).cpu().numpy()

            # Uses the iou helper defined in this notebook (previously utils.iou).
            class_ious = iou(preds, labels, num_classes=num_classes)
            metrics['miou'] += sum(class_ious) / num_classes
            metrics['loss'] += loss
            num_batches += 1

    if num_batches == 0:
        # Nothing was evaluated; report NaN instead of dividing by zero.
        return {'miou': float('nan'), 'loss': float('nan')}

    metrics['miou'] /= float(num_batches)
    metrics['loss'] /= float(num_batches)
    return metrics

In [None]:
def train_epoch(model, train_dataloader, loss_function, optimizer):
    """Run one optimisation pass over ``train_dataloader``.

    Args:
        model: module whose forward pass returns a mapping with an ``'out'``
            entry holding per-class scores.
        train_dataloader: iterable of ``(inputs, targets)`` batches; the
            targets carry a singleton axis at dim 1 that is squeezed away.
        loss_function: callable ``(scores, targets) -> scalar tensor``.
        optimizer: optimizer updating the parameters of ``model``.

    Returns:
        List with the loss value (Python float) of every batch, in order.
    """
    model.train()
    # Follow the model's device instead of forcing CUDA, so training also
    # runs when the notebook fell back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    batch_losses = []
    for inputs, targets in train_dataloader:
        ipts = inputs.to(device)
        tgts = targets.to(device)

        optimizer.zero_grad()
        pred = model(ipts)['out']
        loss = loss_function(pred, tgts.squeeze(1).long())
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return batch_losses

In [None]:
def train(model, train_dataloader, val_dataloader, epochs, loss_function, optimizer, lr_decay):
    """Train and validate ``model`` for ``epochs`` epochs, then plot the learning curves.

    Args:
        model: network to optimise.
        train_dataloader: batches passed to ``train_epoch``.
        val_dataloader: batches passed to ``validate_epoch``.
        epochs: number of epochs to run.
        loss_function: criterion forwarded to both epoch functions.
        optimizer: optimizer forwarded to ``train_epoch``; its learning rates
            are rewritten each epoch when ``lr_decay`` is truthy.
        lr_decay: when truthy, apply the polynomial ("poly") schedule
            ``lr = base_lr * (1 - epoch / epochs) ** 0.9``.
    """
    model.train()
    epoch_train_losses = []
    epoch_val_losses = []
    epoch_val_mious = []
    # Remember the starting rates so the schedule is computed from them every
    # epoch instead of compounding on the previous epoch's value.
    base_lrs = [param_group['lr'] for param_group in optimizer.param_groups]

    for epoch in tqdm(range(epochs), desc='Epoch progress'):
        if lr_decay:
            # The former factor (1 + epoch/epochs) ** 0.9 is >= 1 and was
            # applied multiplicatively, so the rate grew every epoch. The
            # poly policy shrinks it from base_lr towards zero.
            poly_factor = (1.0 - float(epoch) / epochs) ** 0.9
            for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
                param_group['lr'] = base_lr * poly_factor

        batch_train_losses = train_epoch(
            model=model,
            train_dataloader=train_dataloader,
            loss_function=loss_function,
            optimizer=optimizer)

        batch_val_metrics = validate_epoch(
            model=model,
            val_dataloader=val_dataloader,
            loss_function=loss_function)

        epoch_average_train_loss = np.mean(batch_train_losses)
        epoch_train_losses += [epoch_average_train_loss]
        epoch_val_losses += [batch_val_metrics['loss']]
        epoch_val_mious += [batch_val_metrics['miou']]

        print(f'\n\n[TRAIN] Epoch average loss: {epoch_average_train_loss:.2f}')
        print(f'[VAL] Epoch average loss: {batch_val_metrics["loss"]:.2f}')
        print(f'[VAL] Epoch average miou: {batch_val_metrics["miou"]:.2f}\n')

    # Loss curves of both splits in one figure.
    plt.plot(epoch_train_losses, label='Train loss', color='blue')
    plt.plot(epoch_val_losses, label='Validation loss', color='yellow')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.show()

    # Validation mIoU in a second figure.
    plt.plot(epoch_val_mious, label='Validation mIoU', color='green')
    plt.xlabel('Epoch')
    plt.ylabel('mIoU')
    plt.title('mIoU over Epochs')
    plt.legend()
    plt.show()

## Training configuration loading

In [None]:
# Location of the JSON file holding every tunable setting of this run.
CONFIG_PATH = './configs/config-1.json'

# An explicit encoding keeps parsing independent of the platform's default
# code page.
with open(CONFIG_PATH, encoding='utf-8') as json_file:
    args = json.load(json_file)
pprint.pprint(args)

In [None]:
torch.cuda.empty_cache()

# Directory holding one sub-folder per data source. It can be overridden with
# an optional "datasets_root" entry in the config; the fallback is the
# location used so far.
DATASETS_ROOT = args.get(
    "datasets_root",
    'C:\\Users\\Manuel\\Projects\\GitHub_Repositories\\master_thesis\\datasets')

_TARGET_SIZE = (args["image_height"], args["image_width"])
_resize_image = torchvision.transforms.Resize(_TARGET_SIZE)
# Nearest-neighbour keeps label ids intact; an interpolating resize would
# blend neighbouring ids into values of unrelated classes at object borders.
_resize_labels = torchvision.transforms.Resize(
    _TARGET_SIZE, interpolation=torchvision.transforms.InterpolationMode.NEAREST)


def resize_sample(tensor):
    """Resize a sample tensor, choosing the interpolation from its dtype.

    HybridDataset applies one transform to both the float image tensor and
    the integer label tensor, so the dtype tells the two apart here.
    """
    if tensor.is_floating_point():
        return _resize_image(tensor)
    return _resize_labels(tensor)


def build_split(split, subset_size, batch_size, shuffle):
    """Create the (subset, dataloader) pair of one dataset split ('train' or 'val')."""
    dataset = HybridDataset(root_path=os.path.join(DATASETS_ROOT, args["data_source"], split),
                            input_dir='rgb',
                            target_dir='semantic_segmentation_mapped',
                            transform=resize_sample,
                            type=args["data_source"],
                            labels_mapping=None)
    subset = Subset(dataset, indices=range(subset_size))
    dataloader = DataLoader(dataset=subset, batch_size=batch_size, shuffle=shuffle)
    return subset, dataloader


train_dataset, train_dataloader = build_split(
    'train', args["data_subset_size"], args["data_batch_size"], shuffle=True)
# Validation is not shuffled so the per-batch metrics are reproducible.
val_dataset, val_dataloader = build_split(
    'val', args["val_data_subset_size"], args["val_data_batch_size"], shuffle=False)
    

## Training preparation

In [None]:
# Pick the device first so a checkpoint can be mapped onto it while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): output_stride is not a documented argument of
# deeplabv3_resnet50; confirm that the installed torchvision actually uses it.
model = dlv3.deeplabv3_resnet50(pretrained=False, progress=True, output_stride=256, num_classes=len(SemanticLabelMapper.ID_TO_STRING['common']))
if args["fine_tune"]:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(args["fine_tune_model_path"], map_location=device))
    # Backbone parameters whose name contains one of these tags are frozen.
    frozen_stages = ('layer1', 'layer2', 'layer3', 'layer4', 'layer5')
    for name, param in model.backbone.named_parameters():
        if any(stage in name for stage in frozen_stages):
            print(f'---> Freezing layer: {name}.')
            param.requires_grad = False

model.to(device)
print('Loaded model to device.')

In [None]:
# L2 penalty applied to the weights that are not exempt from decay.
WEIGHT_DECAY = 0.0001

# Uses the add_weight_decay helper defined in this notebook (previously
# utils.add_weight_decay), so the cell does not depend on an external module.
params = add_weight_decay(model, l2_value=WEIGHT_DECAY)
optimizer = torch.optim.Adam(params=params, lr=args["learning_rate"])
# Target pixels equal to the configured ignore label do not contribute to the loss.
loss_function = nn.CrossEntropyLoss(ignore_index=args["ignore_label"])

In [None]:
# Launch the training run with the epoch count and learning-rate decay
# switch taken from the loaded configuration.
num_epochs = args["epochs"]
use_paper_lr_decay = args["learning_rate_paper_decay"]
train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=num_epochs,
    loss_function=loss_function,
    optimizer=optimizer,
    lr_decay=use_paper_lr_decay,
)

In [None]:
# Persist the trained weights. The target directory is created first so a
# missing folder cannot discard the result of the whole training run.
model_save_path = str(args["model_save_path"])
model_save_dir = os.path.dirname(model_save_path)
if model_save_dir:
    os.makedirs(model_save_dir, exist_ok=True)
torch.save(model.state_dict(), model_save_path)