In [11]:
#Make Google Drive reachable from this Colab runtime
from google.colab import drive

#Directory inside the runtime under which the drive contents appear
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [12]:
#Switch the working directory to the project folder on the mounted Google Drive.
#Later cells build the dataset paths from os.getcwd(), so this must run first;
#presumably the project-local modules imported below also live in this folder (verify).
%cd /content/drive/My Drive/MLDL2024_project1-Enrico

In [13]:
!pip install -U fvcore

In [14]:
#Importing the necessary libraries
import os
import torch
import numpy as np
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
from datasets.gta5 import GTA5Custom
from datasets.cityscapes import CityscapesCustom
from models.bisenet.build_bisenet import BiSeNet
from models.discriminator.discriminator import FCDiscriminator
from train_adversarial import train_adversarial
from utils import test_latency_FPS, test_FLOPs_params, plot_miou_over_epochs

In [15]:
#Set device agnostic code: use the GPU when one is visible, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

#Set the manual seeds
#A single seed value feeds every RNG this notebook has in scope: NumPy's global
#generator as well as torch (CPU and, when present, all CUDA devices).
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

#Set training parameters
#Image size (height, width) and batch size used for the GTA5 dataset
gta5_height, gta5_width = (720, 1280)
gta5_batch_size = 8

#Image size (height, width) and batch size used for the Cityscapes dataset
cityscapes_height, cityscapes_width = (512, 1024)
cityscapes_batch_size = 8

n_epochs = 50

#Labels for adversarial training
source_flag = 0
target_flag = 1

#Set up the interpolation layers
#Each layer bilinearly resizes its input to the image size of one dataset.
#align_corners=False is the value PyTorch already falls back to for bilinear
#mode; stating it explicitly documents the choice and avoids the ambiguity
#warning emitted by older PyTorch releases.
interp_s = nn.Upsample(size=(gta5_height, gta5_width), mode='bilinear', align_corners=False)
interp_t = nn.Upsample(size=(cityscapes_height, cityscapes_width), mode='bilinear', align_corners=False)

Using device: cpu


In [16]:
#Build the datasets and dataloaders: GTA5 for training, Cityscapes ('val' split) for testing

#Both dataset folders sit next to the project folder, i.e. under the parent of the working directory
#NOTE(review): the 'Cityspaces' spelling is kept as-is; it has to match the folder name on disk
datasets_root = os.path.dirname(os.getcwd())
gta5_dir = datasets_root + '/GTA5/GTA5/'
cityscapes_dir = datasets_root + '/Cityscapes/Cityspaces/'

#Color jitter is the augmentation handed to the GTA5 training dataset
augment1 = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3)
gta5_train_dataset_aug1 = GTA5Custom(gta5_dir, gta5_height, gta5_width, augment=augment1)
cityscapes_test_dataset = CityscapesCustom(cityscapes_dir, 'val', cityscapes_height, cityscapes_width)

#Training batches are shuffled, test batches keep the dataset order
gta5_train_dataloader_aug1 = DataLoader(
    gta5_train_dataset_aug1,
    batch_size=gta5_batch_size,
    shuffle=True,
    num_workers=4,
)
cityscapes_test_dataloader = DataLoader(
    cityscapes_test_dataset,
    batch_size=cityscapes_batch_size,
    shuffle=False,
    num_workers=4,
)

#Class names as reported by the Cityscapes dataset object
class_names = cityscapes_test_dataset.get_class_names()

#Summary of dataset and dataloader sizes
print(f'GTA5 (Train): {len(gta5_train_dataset_aug1)} images, divided into {len(gta5_train_dataloader_aug1)} batches of size {gta5_train_dataloader_aug1.batch_size}')
print(f'Cityscapes (Test): {len(cityscapes_test_dataset)} images, divided into {len(cityscapes_test_dataloader)} batches of size {cityscapes_test_dataloader.batch_size}')

GTA5 (Train): 2500 images, divided into 313 batches of size 8
Cityscapes (Test): 500 images, divided into 63 batches of size 8


In [None]:
#Domain Adaptation experiment: adversarial training of BiSeNet against a discriminator

#Values shared by several calls in this cell
num_classes = 19
run_name = 'BiSeNet_adversarial'

#Segmentation network, playing the role of the Generator
#NOTE(review): whether pretrained weights are loaded is decided inside BiSeNet - confirm
g_model = BiSeNet(num_classes=num_classes, context_path='resnet18').to(device)

#Generator: cross-entropy loss (label index 255 is ignored) optimized with SGD
g_criterion = nn.CrossEntropyLoss(ignore_index=255)
g_optimizer = torch.optim.SGD(g_model.parameters(), lr=2.5e-2, momentum=0.9, weight_decay=1e-4)

#Discriminator network
d_model = FCDiscriminator(num_classes=num_classes).to(device)

#Discriminator: binary cross-entropy on logits optimized with Adam
d_criterion = nn.BCEWithLogitsLoss()
d_optimizer = torch.optim.Adam(d_model.parameters(), lr=1e-4, betas=(0.9, 0.99))

#Run the adversarial training
#NOTE(review): the Cityscapes loader passed here is built from the 'val' split;
#confirm how train_adversarial uses it (target data for the adversarial step, evaluation, or both)
training_results = train_adversarial(
    g_model, d_model, g_criterion, d_criterion,
    g_optimizer, d_optimizer,
    source_flag, target_flag,
    gta5_train_dataloader_aug1, cityscapes_test_dataloader,
    interp_s, interp_t,
    class_names, device, n_epochs, run_name,
)
all_train_miou, all_test_miou, best_epoch = training_results

#Plot the per-epoch train/test mIoU returned by the training run
plot_miou_over_epochs(all_train_miou, all_test_miou, best_epoch, run_name)