In [None]:
## Import and setup ##

from google.colab import drive
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision import models
import os
import gc
import glob
import random
import collections
from collections import deque
import contextlib

drive.mount('/content/drive')


Definition of the BiSeNet model with a ResNet backbone


In [None]:
## Bisenet model with backbone ResNet18 or ResNet101 ##

class resnet18(nn.Module):
    """ResNet-18 context-path backbone.

    Wraps torchvision's ``models.resnet18`` and re-exposes its stem and the
    four residual stages as direct attributes so the forward pass can tap
    intermediate feature maps.

    Args:
        pretrained: forwarded to ``models.resnet18``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.resnet18(pretrained=pretrained)
        backbone = self.features
        # Stem
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool1 = backbone.maxpool
        # Residual stages
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        """Return ``(feature3, feature4, tail)``.

        ``feature3`` and ``feature4`` are the outputs of the third and fourth
        residual stages (1/16 and 1/32 of the input resolution); ``tail`` is
        ``feature4`` averaged over both spatial axes with the dimensions kept,
        i.e. a global-average-pooled descriptor.
        """
        stem = self.maxpool1(self.relu(self.bn1(self.conv1(input))))
        feature2 = self.layer2(self.layer1(stem))  # 1/4 -> 1/8
        feature3 = self.layer3(feature2)  # 1/16
        feature4 = self.layer4(feature3)  # 1/32
        # Global average pooling: width axis first, then height axis.
        tail = feature4.mean(dim=3, keepdim=True).mean(dim=2, keepdim=True)
        return feature3, feature4, tail


class resnet101(nn.Module):
    """ResNet-101 context-path backbone.

    Same interface as the ResNet-18 wrapper: torchvision's ``models.resnet101``
    is stored in ``features`` and its stem / residual stages are aliased as
    attributes of this module.

    Args:
        pretrained: forwarded to ``models.resnet101``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        net = models.resnet101(pretrained=pretrained)
        self.features = net
        self.conv1, self.bn1, self.relu = net.conv1, net.bn1, net.relu
        self.maxpool1 = net.maxpool
        self.layer1, self.layer2 = net.layer1, net.layer2
        self.layer3, self.layer4 = net.layer3, net.layer4

    def forward(self, input):
        """Return ``(feature3, feature4, tail)``.

        ``feature3`` (1/16 resolution) and ``feature4`` (1/32 resolution) are
        the outputs of the last two residual stages; ``tail`` is the spatial
        mean of ``feature4`` with both spatial dimensions kept at size 1.
        """
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)

        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)  # resolutions: 1/4, 1/8, 1/16, 1/32
            stage_outputs.append(x)
        feature3, feature4 = stage_outputs[2], stage_outputs[3]

        # Global average pooling to build the tail (width, then height).
        tail = torch.mean(torch.mean(feature4, 3, keepdim=True), 2, keepdim=True)
        return feature3, feature4, tail


def build_contextpath(name):
    """Build the context-path backbone selected by ``name``.

    Only the requested backbone is instantiated. The previous implementation
    created both pretrained networks inside a dict literal on every call, so
    the unused one was still constructed (with its pretrained weights) and
    immediately discarded.

    Args:
        name: ``'resnet18'`` or ``'resnet101'``.

    Returns:
        The corresponding backbone module, built with ``pretrained=True``.

    Raises:
        KeyError: if ``name`` is not a supported backbone (same exception type
            as before, now with an explanatory message).
    """
    # Map names to constructors, not instances, so construction stays lazy.
    builders = {
        'resnet18': resnet18,
        'resnet101': resnet101,
    }
    if name not in builders:
        raise KeyError(
            "Unsupported context path {!r}; expected one of {}".format(name, sorted(builders))
        )
    return builders[name](pretrained=True)


class ConvBlock(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and ``ReLU``.

    With the default arguments (3x3 kernel, stride 2, padding 1) the block
    halves the spatial resolution of its input.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Apply convolution, batch normalisation and ReLU, in that order."""
        normalized = self.bn(self.conv1(input))
        return self.relu(normalized)


class Spatial_path(torch.nn.Module):
    """Spatial path: three stacked ``ConvBlock``s (3 -> 64 -> 128 -> 256 channels).

    Each block uses ``ConvBlock``'s default stride of 2, so the output is
    downsampled three times relative to the input image.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 64, 128, 256)
        # Registers convblock1, convblock2, convblock3 in that order.
        for index, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"convblock{index}", ConvBlock(in_channels=c_in, out_channels=c_out))

    def forward(self, input):
        """Run the input through the three blocks sequentially."""
        return self.convblock3(self.convblock2(self.convblock1(input)))


class AttentionRefinementModule(torch.nn.Module):
    """Channel attention: re-weights a feature map by a per-channel gate.

    The gate is computed from the globally average-pooled input through a
    1x1 convolution, batch normalisation and a sigmoid, then multiplied back
    onto the input (broadcast over the spatial dimensions).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        """Return ``input`` scaled channel-wise by the attention gate.

        Raises:
            AssertionError: if the input's channel count differs from the
                ``in_channels`` given at construction.
        """
        pooled = self.avgpool(input)  # global average pooling -> (N, C, 1, 1)
        assert self.in_channels == pooled.size(1), 'in_channels and out_channels should all be {}'.format(pooled.size(1))
        gate = self.sigmoid(self.bn(self.conv(pooled)))
        # The gate must have as many channels as the input for the product to broadcast.
        return torch.mul(input, gate)


class FeatureFusionModule(torch.nn.Module):
    """Fuses spatial-path and context-path features into ``num_classes`` channels.

    The two inputs are concatenated along the channel axis, projected with a
    stride-1 ``ConvBlock`` and then refined by a channel gate
    (global average pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid) that is
    applied residually: ``feature * gate + feature``.

    ``in_channels`` must equal the channel count of the concatenated inputs:
        resnet101: 3328 = 256 (spatial path) + 1024 + 2048 (context path)
        resnet18:  1024 = 256 (spatial path) + 256 + 512 (context path)
    """

    def __init__(self, num_classes, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.convblock = ConvBlock(in_channels=self.in_channels, out_channels=num_classes, stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        """Fuse the two feature maps.

        Raises:
            AssertionError: if the concatenated channel count differs from
                ``in_channels``.
        """
        fused = torch.cat((input_1, input_2), dim=1)
        assert self.in_channels == fused.size(1), 'in_channels of ConvBlock should be {}'.format(fused.size(1))
        feature = self.convblock(fused)

        # Channel gate computed from the globally pooled fused feature.
        gate = self.avgpool(feature)
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))

        # Residual re-weighting.
        return torch.add(torch.mul(feature, gate), feature)


class BiSeNet(nn.Module):
    """BiSeNet segmentation network with a ResNet-18 or ResNet-101 context path.

    Args:
        num_classes: number of output channels (segmentation classes).
        context_path: ``'resnet18'`` or ``'resnet101'``.

    Raises:
        ValueError: if ``context_path`` is not supported. Previously an
            unsupported value was only reported with ``print`` and the
            constructor went on to fail with an unrelated error.
    """

    # backbone name -> (channels of the 1/16 feature, channels of the 1/32 feature,
    #                   fusion input channels = 256 spatial-path channels + both context features)
    _CONTEXT_CHANNELS = {
        'resnet18': (256, 512, 1024),
        'resnet101': (1024, 2048, 3328),
    }

    def __init__(self, num_classes, context_path='resnet18'):
        super().__init__()
        if context_path not in self._CONTEXT_CHANNELS:
            raise ValueError(
                "Unsupported context_path network {!r}; expected one of {}".format(
                    context_path, sorted(self._CONTEXT_CHANNELS)))
        mid_channels, deep_channels, fusion_channels = self._CONTEXT_CHANNELS[context_path]

        # build spatial path
        self.spatial_path = Spatial_path()

        # build context path
        self.context_path = build_contextpath(context_path)

        # attention refinement modules for the two context features
        self.attention_refinement_module1 = AttentionRefinementModule(mid_channels, mid_channels)
        self.attention_refinement_module2 = AttentionRefinementModule(deep_channels, deep_channels)

        # auxiliary supervision heads (used only in training mode)
        self.supervision1 = nn.Conv2d(mid_channels, num_classes, kernel_size=1)
        self.supervision2 = nn.Conv2d(deep_channels, num_classes, kernel_size=1)

        # feature fusion module
        self.feature_fusion_module = FeatureFusionModule(num_classes, fusion_channels)

        # final 1x1 convolution
        self.conv = nn.Conv2d(num_classes, num_classes, kernel_size=1)

        self.init_weight()

        # Every sub-module except the backbone. Presumably meant for a separate
        # learning-rate group; it is not consumed anywhere in the visible code.
        self.mul_lr = [
            self.spatial_path,
            self.attention_refinement_module1,
            self.attention_refinement_module2,
            self.supervision1,
            self.supervision2,
            self.feature_fusion_module,
            self.conv,
        ]

    def init_weight(self):
        """Re-initialise every Conv2d / BatchNorm2d outside the context path.

        Conv weights get Kaiming-normal initialisation; batch-norm layers are
        reset to weight 1, bias 0, eps 1e-5, momentum 0.1. Modules whose name
        contains ``'context_path'`` (the backbone) are left untouched.
        """
        for name, m in self.named_modules():
            if 'context_path' in name:
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.eps = 1e-5
                m.momentum = 0.1
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, input):
        """Segment ``input``.

        Returns:
            In training mode a tuple ``(result, cx1_sup, cx2_sup)`` where the
            two auxiliary outputs are resized to the input's spatial size.
            In eval mode only ``result``.

        Note:
            ``result`` is the fused feature upsampled by a fixed factor of 8,
            so it matches the input size only when the spatial path reduces
            the input by exactly 8 (height and width multiples of 8).
        """
        interpolate = torch.nn.functional.interpolate

        # output of spatial path
        sx = self.spatial_path(input)

        # output of context path
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)

        # upsample both context features to the spatial-path resolution
        # (align_corners=False is the default behaviour, stated explicitly)
        cx1 = interpolate(cx1, size=sx.size()[-2:], mode='bilinear', align_corners=False)
        cx2 = interpolate(cx2, size=sx.size()[-2:], mode='bilinear', align_corners=False)
        cx = torch.cat((cx1, cx2), dim=1)

        if self.training:
            cx1_sup = interpolate(self.supervision1(cx1), size=input.size()[-2:],
                                  mode='bilinear', align_corners=False)
            cx2_sup = interpolate(self.supervision2(cx2), size=input.size()[-2:],
                                  mode='bilinear', align_corners=False)

        # output of feature fusion module
        result = self.feature_fusion_module(sx, cx)

        # upsample to (nominally) input resolution, then final 1x1 conv
        result = interpolate(result, scale_factor=8, mode='bilinear', align_corners=False)
        result = self.conv(result)

        if self.training:
            return result, cx1_sup, cx2_sup

        return result

def get_bisenet_model(num_classes, context_path='resnet18'):
    """Factory helper: build a ``BiSeNet`` with the given class count and backbone."""
    model = BiSeNet(num_classes=num_classes, context_path=context_path)
    return model


Definition of discriminator, architecture as in article https://openaccess.thecvf.com/content_cvpr_2018/papers/Tsai_Learning_to_Adapt_CVPR_2018_paper.pdf

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional domain discriminator.

    Five 4x4, stride-2, padding-1 convolutions (64, 128, 256, 512, 1 output
    channels); every convolution except the last is followed by
    ``LeakyReLU(0.2)``. The output is a single-channel map of logits whose
    spatial size is the input size halved five times.

    Args:
        in_channels: channels of the input class map (19 by default).
    """

    def __init__(self, in_channels=19):
        super().__init__()
        channel_plan = [in_channels, 64, 128, 256, 512]
        layers = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final projection to one logit channel, no activation.
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the discriminator logits for ``x``."""
        return self.model(x)

For EXTENSION: definition of a variation of the cross-entropy loss: the focal loss

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss over the channel dimension.

    Per-pixel loss is ``-(1 - p_t) ** gamma * log(p_t)`` where ``p_t`` is the
    softmax probability of the ground-truth class. With ``gamma == 0`` this
    reduces to cross-entropy.

    Args:
        gamma: focusing exponent.
        ignore_index: target value excluded from the loss (``None`` disables
            the filtering).
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            un-reduced 1-D tensor of losses for the valid pixels.
    """

    def __init__(self, gamma, ignore_index: int = 255, reduction: str = "mean"):
        super().__init__()
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            logits: model outputs with classes on dim 1, e.g. (N, C, H, W).
            target: ground-truth class indices, same shape minus the class dim.

        Returns:
            The reduced loss. If every target pixel equals ``ignore_index``
            the result is a zero scalar that is still attached to the
            autograd graph.
        """
        target = target.long()

        if self.ignore_index is not None:
            valid = target != self.ignore_index
            if not valid.any():
                # Sum over an empty slice: exactly 0, but connected to
                # ``logits`` so ``backward()`` works and yields zero gradients.
                # (A fresh constant tensor would make ``backward()`` raise.)
                return logits[:0].sum()
            # Replace ignored labels by a valid index so gather() cannot go
            # out of range; those positions are dropped again below.
            safe_target = target.masked_fill(~valid, 0)
        else:
            valid = torch.ones_like(target, dtype=torch.bool)
            safe_target = target

        # Log-probability of the ground-truth class at every position.
        log_probs = F.log_softmax(logits, dim=1)
        log_pt = log_probs.gather(1, safe_target.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        # Down-weight well-classified pixels.
        focal_weight = (1 - pt).pow(self.gamma)
        loss = (-focal_weight * log_pt)[valid]

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

Definition of transformations used as preprocessing and data augmentation

In [None]:
## Class for transformation ##

class SegmentationTransform:
    """Joint preprocessing / augmentation for an (image, mask) pair.

    Pipeline: resize (bilinear for the image, nearest for the mask), then —
    only when enabled and ``train`` is true, each with probability 0.5 —
    horizontal flip (image and mask), colour jitter (image only) and Gaussian
    blur (image only); finally the image is converted to a tensor and
    normalised with ImageNet statistics, and the mask becomes a long tensor.
    """

    def __init__(self, resize, flip = False, color_jitter=False, gaussian_blur=False, train=True):
        # Separate resizers so label values are never interpolated.
        self.resize_img = transforms.Resize(resize, interpolation=Image.BILINEAR)
        self.resize_mask = transforms.Resize(resize, interpolation=Image.NEAREST)

        # Augmentations are active only in training mode.
        self.flip_flag = flip and train
        self.color_jitter_flag = color_jitter and train
        self.gaussian_blur_flag = gaussian_blur and train
        self.color_jitter = transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)

        # ImageNet channel statistics.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    @staticmethod
    def _fires(enabled):
        """Draw a 50% coin, but only when the augmentation is enabled."""
        return enabled and random.random() < 0.5

    def __call__(self, img, mask):
        """Transform one sample.

        Args:
            img: PIL image.
            mask: NumPy label array (cast to uint8 before resizing).

        Returns:
            ``(img, mask)`` as a normalised float tensor and a long tensor.
        """
        img = self.resize_img(img)
        mask = self.resize_mask(Image.fromarray(mask.astype(np.uint8)))

        if self._fires(self.flip_flag):
            img, mask = TF.hflip(img), TF.hflip(mask)

        if self._fires(self.color_jitter_flag):
            img = self.color_jitter(img)

        if self._fires(self.gaussian_blur_flag):
            img = img.filter(ImageFilter.GaussianBlur(radius=1.5))

        img = TF.normalize(TF.to_tensor(img), mean=self.mean, std=self.std)
        mask = torch.from_numpy(np.array(mask, dtype=np.uint8)).long()

        return img, mask

Dataset GTA5

In [None]:
# Original label ID -> Cityscapes training ID. The listed source IDs map, in
# order, to the consecutive training IDs 0..18.
ID_TO_TRAINID = {
    source_id: train_id
    for train_id, source_id in enumerate(
        (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
    )
}


def mapping_labels(label_np):
    """
    Convert an array of GTA5 label IDs into Cityscapes training IDs.

    Every ID without an entry in ``ID_TO_TRAINID`` becomes the ignore index
    255. A new array with the dtype and shape of ``label_np`` is returned;
    the input is not modified.
    """
    remapped = np.full_like(label_np, 255)
    for source_id, train_id in ID_TO_TRAINID.items():
        matches = label_np == source_id
        remapped[matches] = train_id
    return remapped

class GTA5Dataset(Dataset):
    """
    GTA5 segmentation dataset.

    Collects every ``*.png`` under ``{root}/images`` and ``{root}/labels``
    (recursively), sorts both lists and pairs them by position. If the lists
    differ in length, both are truncated to the shorter one.
    """

    def __init__(self, root, transform=None):
        self.transform = transform

        image_paths = sorted(glob.glob(f"{root}/images/**/*.png", recursive=True))
        label_paths = sorted(glob.glob(f"{root}/labels/**/*.png", recursive=True))

        # Keep only as many pairs as both folders can provide.
        pair_count = min(len(image_paths), len(label_paths))
        self.images = image_paths[:pair_count]
        self.labels = label_paths[:pair_count]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The label is remapped to Cityscapes training IDs before the optional
        transform is applied to the pair.
        """
        img = Image.open(self.images[idx]).convert("RGB")
        raw_label = np.array(Image.open(self.labels[idx]), dtype=np.int64)
        label_mapped = mapping_labels(raw_label)

        if self.transform:
            img, label_mapped = self.transform(img, label_mapped)

        return img, label_mapped

Dataset Cityscapes

In [None]:
# Dataset Cityscapes
class CityscapesDataset(Dataset):
    """
    Cityscapes segmentation dataset for one split.

    Images are every ``*.png`` under ``{root}/images/{split}``; labels are the
    ``*_labelTrainIds.png`` files under ``{root}/gtFine/{split}``. Both lists
    are sorted, paired by position and truncated to the shorter length.
    """

    def __init__(self, root, split="val", transform=None):
        self.transform = transform

        image_pattern = f"{root}/images/{split}/**/*.png"
        label_pattern = f"{root}/gtFine/{split}/**/*_labelTrainIds.png"  # train-ID annotations only
        image_paths = sorted(glob.glob(image_pattern, recursive=True))
        label_paths = sorted(glob.glob(label_pattern, recursive=True))

        pair_count = min(len(image_paths), len(label_paths))
        self.images = image_paths[:pair_count]
        self.labels = label_paths[:pair_count]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, transformed if a transform is set."""
        img = Image.open(self.images[idx]).convert("RGB")
        label_np = np.array(Image.open(self.labels[idx]), dtype=np.int64)

        if self.transform:
            img, label_np = self.transform(img, label_np)

        return img, label_np

import torch


In [None]:
# Training transform: colour jitter + Gaussian blur (the augmentation pair
# that performed best in step 3b), images resized to 720x1280.
transform_gta5_best_2_3 = SegmentationTransform(
    resize=(720, 1280),
    color_jitter=True,
    gaussian_blur=True,
    train=True,
)

# Validation transform: resize to 512x1024 only, no augmentation.
transform_cityscapes_val = SegmentationTransform(
    resize=(512, 1024),
    train=False,
)

Utility functions

In [None]:
def compute_miou(preds, labels, num_classes=19, device=None, verbose=True):
    """
    Compute the mean Intersection over Union (mIoU) for semantic segmentation.

    For each class the true positives, false positives and false negatives
    are counted over all pixels; pixels labelled 255 ("void") are excluded
    from the false-positive count and, since 255 is never a class index,
    never contribute to TP or FN either.

    Args:
        preds: tensor of predicted class indices.
        labels: tensor of ground-truth class indices, same shape as ``preds``.
        num_classes: number of classes to evaluate.
        device: device for the counters. ``None`` (default) uses
            ``labels.device``. The old default ``"cuda"`` made the function
            crash on CPU-only machines even for CPU inputs; passing a device
            explicitly still works as before.
        verbose: print the per-class TP/FP/FN/IoU line (default, as before).

    Returns:
        ``(mean_iou, iou_per_class)`` — the list holds the IoU of each class
        with at least one TP/FP/FN pixel (classes absent from both prediction
        and ground truth are skipped); ``mean_iou`` is their mean, or 0.0 if
        the list is empty.
    """
    if device is None:
        device = labels.device

    tp = torch.zeros(num_classes, dtype=torch.int64, device=device)
    fp = torch.zeros(num_classes, dtype=torch.int64, device=device)
    fn = torch.zeros(num_classes, dtype=torch.int64, device=device)

    not_void = labels != 255  # loop-invariant mask
    for cls in range(num_classes):
        is_label = labels == cls
        is_pred = preds == cls
        tp[cls] = (is_label & is_pred).sum().to(device)
        fp[cls] = (~is_label & not_void & is_pred).sum().to(device)
        fn[cls] = (is_label & ~is_pred).sum().to(device)

    denom = tp + fp + fn
    iou = tp.float() / (denom.float() + 1e-10)

    # One transfer per tensor instead of several .item() calls per class.
    tp_l, fp_l, fn_l, denom_l, iou_l = (t.tolist() for t in (tp, fp, fn, denom, iou))

    iou_per_class = []
    for cls in range(num_classes):
        if verbose:
            print(f"Class {cls}: TP={tp_l[cls]}, FP={fp_l[cls]}, FN={fn_l[cls]}, IoU={iou_l[cls]:.4f}")
        if denom_l[cls] > 0:  # only include classes with at least one pixel
            iou_per_class.append(iou_l[cls])

    mean_iou = np.mean(iou_per_class) if iou_per_class else 0.0
    return mean_iou, iou_per_class



def extract_images(batch):
    """
    Extract the image tensor from several possible batch formats.

    - a tensor is returned as is;
    - for a list/tuple, the first element that is a tensor is returned; if
      there is none, the search recurses into the first element;
    - for a dict, the keys 'image', 'img', 'images', 'pixel_values', 'x' are
      tried in that order, then the first tensor among the values.

    Raises:
        ValueError: if no tensor can be found. Previously the function fell
            through and returned ``None`` (or raised ``IndexError`` for an
            empty list), which only surfaced later as an unrelated error at
            the call site.
    """
    # Case 1: already a tensor
    if torch.is_tensor(batch):
        return batch

    # Case 2: list or tuple
    if isinstance(batch, (list, tuple)):
        if not batch:
            raise ValueError("extract_images: received an empty list/tuple batch")
        for item in batch:
            if torch.is_tensor(item):
                return item
        return extract_images(batch[0])

    # Case 3: dictionary — well-known keys first, then any tensor value
    if isinstance(batch, dict):
        for key in ('image', 'img', 'images', 'pixel_values', 'x'):
            candidate = batch.get(key)
            if torch.is_tensor(candidate):
                return candidate
        for value in batch.values():
            if torch.is_tensor(value):
                return value

    raise ValueError(
        f"extract_images: no tensor found in batch of type {type(batch).__name__}"
    )



Definition of Datasets and Dataloaders for training and validation

In [None]:
## Dataset and dataloader ##

# Dataset roots on the mounted Google Drive
GTA5_ROOT = "/content/drive/MyDrive/GTA5/GTA5_"
CITYSCAPES_ROOT = "/content/drive/MyDrive/Cityscapes/Cityspaces"

# Datasets: GTA5 is the labelled source domain, the Cityscapes "train" split is
# the target domain (its labels are loaded but the training loop only uses the
# images), and the Cityscapes "val" split is used for evaluation.
train_dataset = GTA5Dataset(GTA5_ROOT, transform=transform_gta5_best_2_3)
city_dataset = CityscapesDataset(
    CITYSCAPES_ROOT,
    split="train",
    transform=transform_gta5_best_2_3,
)
city_val_dataset = CityscapesDataset(
    CITYSCAPES_ROOT,
    split="val",
    transform=transform_cityscapes_val,
)

# DataLoaders: both training loaders are shuffled; validation keeps file order.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
city_loader = DataLoader(city_dataset, batch_size=4, shuffle=True, num_workers=2)
city_val_loader = DataLoader(city_val_dataset, batch_size=8, shuffle=False, num_workers=2)

Function for the learning rate: polynomial learning rate

In [None]:
## learning rate ##

def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1,
                      max_iter=300, power=0.9):
    """
    Polynomial decay of the learning rate.

    Computes ``lr = init_lr * (1 - iter / max_iter) ** power`` and writes it
    to the optimizer's first parameter group.

        :param optimizer: optimizer whose ``param_groups[0]['lr']`` is updated
        :param init_lr: base learning rate
        :param iter: current iteration (or epoch) index
        :param lr_decay_iter: accepted for interface compatibility; not used
        :param max_iter: iteration at which the learning rate reaches zero
        :param power: polynomial power
        :return: the scalar learning rate that was set

    Once ``iter`` exceeds ``max_iter`` the learning rate stays at 0. Without
    the clamp the base of the power becomes negative and, for a fractional
    ``power``, Python returns a complex number.
    """
    remaining = max(0.0, 1 - iter / max_iter)
    lr = init_lr * remaining ** power
    optimizer.param_groups[0]['lr'] = lr
    return lr

Initialization of model, optimizers and loss functions

In [None]:
## Model, loss functions and optimizer ##

num_classes = 19
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## hyperparameters
initial_lr_s = 2.5e-4    # segmentation network learning rate
initial_lr_d = 1e-4      # discriminator learning rate
aux_loss_weight = 0.4    # weight of each auxiliary segmentation loss
lambda_adv = 0.001       # weight of the adversarial term in the generator loss

# Segmentation network: BiSeNet with a ResNet-18 context path
model = get_bisenet_model(num_classes=num_classes, context_path='resnet18').to(device)

# Discriminator working on the num_classes-channel softmax maps
discriminator = Discriminator(in_channels=num_classes).to(device)

# Optimizers: SGD for the segmentation network, Adam for the discriminator
optimizer_s = torch.optim.SGD(
    model.parameters(),
    lr=initial_lr_s,
    momentum=0.9,
    weight_decay=1e-4,
)
optimizer_d = torch.optim.Adam(
    discriminator.parameters(),
    lr=initial_lr_d,
    betas=(0.9, 0.99),
)

#-- FOR EXTENSION --#
# optimizer_s = torch.optim.Adam(model.parameters(), lr=initial_lr_s, betas=(0.9, 0.99), weight_decay=1e-4)

# Loss functions: BCE-with-logits for the domain classification,
# cross-entropy ignoring label 255 for the segmentation
adv_criterion = nn.BCEWithLogitsLoss().to(device)
seg_criterion = nn.CrossEntropyLoss(ignore_index=255).to(device)

#-- FOR EXTENSION --#
# gammas = 0
# gammas = 2
# gammas = 3
# seg_criterion = FocalLoss(gamma=gammas, ignore_index=255).to(device)

Checkpoint Management, Device Setup, and AMP Initialization.

This section prepares the training environment by:
   - Creating a directory to save model checkpoints
   - Selecting the computation device (GPU if available, else CPU).
   - Initializing PyTorch’s Automatic Mixed Precision (AMP) GradScaler

In [None]:
## Setup directories, device, and AMP scaler ##

checkpoint_dir = "/content/drive/MyDrive/4_step_Machine/MachineLearning/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
save_interval = 3
max_checkpoints = 2
# No ``maxlen`` here: a bounded deque silently drops its oldest entry on
# append, so the training loop's cleanup (which deletes the dropped files from
# disk) would never see anything to remove.
saved_checkpoints = deque()

# Select device (GPU if available, else CPU) before deriving the AMP settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == 'cuda')

# Automatic Mixed Precision (AMP): a disabled GradScaler is a transparent
# pass-through, so the training loop can call scaler.scale/step/update on CPU
# as well instead of hitting ``None``.
scaler = torch.amp.GradScaler(device=device.type, enabled=use_amp)


# Management of checkpoints
resume_training = True
latest_checkpoint = None
latest_epoch = -1

checkpoint_prefix, checkpoint_suffix = "checkpoint_epoch", ".pt"
for fname in os.listdir(checkpoint_dir):
    if fname.startswith(checkpoint_prefix) and fname.endswith(checkpoint_suffix):
        try:
            epoch_num = int(fname[len(checkpoint_prefix):-len(checkpoint_suffix)])
        except ValueError:
            # Not one of the periodic checkpoints (unexpected name); skip it.
            continue
        if epoch_num > latest_epoch:
            latest_epoch = epoch_num
            latest_checkpoint = os.path.join(checkpoint_dir, fname)

start_epoch = 0
if resume_training and latest_checkpoint:
    # The checkpoint is produced by this notebook (trusted) and stores plain
    # Python / NumPy scalars next to the state dicts, hence weights_only=False.
    checkpoint = torch.load(latest_checkpoint, map_location=device, weights_only=False)

    # Restore model and optimizer states. The keys match what the training
    # loop saves ('optimizer_s_state_dict' / 'optimizer_d_state_dict'); the
    # previous code referenced an undefined ``optimizer`` and a key that is
    # never written.
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer_s.load_state_dict(checkpoint['optimizer_s_state_dict'])
    optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])

    # Discriminator weights are restored when the checkpoint contains them.
    if 'discriminator_state_dict' in checkpoint:
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    else:
        print("Checkpoint has no discriminator weights; discriminator restarts from its initialization")

    # Restore AMP scaler if used
    if checkpoint.get('scaler_state_dict'):
        scaler.load_state_dict(checkpoint['scaler_state_dict'])

    # Resume from next epoch
    start_epoch = checkpoint['epoch'] + 1
    print(f"Restored from {latest_checkpoint}")
elif latest_checkpoint:
    print(f"resume_training is False: ignoring {latest_checkpoint} and starting from scratch")
else:
    print(" No checkpoint found. Starting from scratch")

 No checkpoint found. Starting from scratch


Training of the model:
this section sets up an adversarial training loop for domain adaptation.

In [None]:

## Variables for training progress ##
num_epochs = 50
num_classes = 19
save_interval = 3
global_iter = 0
adv_epochs = 3

# Best mIoU, corresponding checkpoint info and list for storing evaluation of loss function during epochs
best_miou = 0.0
best_epoch_ckpt = None
epoch_loss_list = []

# If the setup cell left the scaler as None (no GPU), fall back to a disabled
# scaler: scale() returns the loss unchanged and step() just steps the optimizer.
if scaler is None:
    scaler = torch.amp.GradScaler(device=device.type, enabled=False)


def _save_checkpoint(path, epoch, epoch_loss, epoch_miou):
    """Write model, discriminator, optimizer and scaler states plus metrics to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_s_state_dict': optimizer_s.state_dict(),
        'optimizer_d_state_dict': optimizer_d.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': epoch_loss,
        # Plain float: a NumPy scalar in the pickle prevents loading the file
        # with torch.load's weights-only mode.
        'miou': float(epoch_miou),
    }, path)


def _set_requires_grad(module, flag):
    """Enable or disable gradient tracking for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad_(flag)


## training loop ##

for epoch in range(start_epoch, num_epochs):

    model.train()
    discriminator.train()

    total_loss = 0.0

    # Update learning rate. The schedule is stepped per epoch, so its horizon
    # is the number of epochs; with the function's default max_iter=300 the
    # rate would barely decay over the 50 epochs of training.
    lr = poly_lr_scheduler(optimizer_s, initial_lr_s, epoch, max_iter=num_epochs)

    target_iter = iter(city_loader)

    for images_s, labels_s in train_loader:
        global_iter += 1

        images_s = images_s.to(device, non_blocking=True)
        labels_s = labels_s.long().to(device, non_blocking=True)

        # Extract a batch from dataset cityscapes (target domain); restart the
        # target loader whenever it is exhausted before the source loader.
        try:
            batch_t = next(target_iter)
        except StopIteration:
            target_iter = iter(city_loader)
            batch_t = next(target_iter)

        images_t = extract_images(batch_t).to(device, non_blocking=True)

        ## ---- Generator (segmentation network) step ---- ##
        # The discriminator only provides the adversarial signal here, so its
        # parameters are frozen to avoid computing gradients that would be
        # thrown away.
        _set_requires_grad(discriminator, False)

        with torch.autocast(device_type=device.type, enabled=use_amp):

            ## (1) Forward pass on source (with labels) and target (without labels)
            pred_s = model(images_s)
            pred_t = model(images_t)

            if isinstance(pred_s, (list, tuple)):
                main_pred_s, *aux_preds_s = pred_s
                main_pred_t = pred_t[0]
            else:
                main_pred_s, aux_preds_s = pred_s, []
                main_pred_t = pred_t

            # Main loss + weighted auxiliary losses (source domain only)
            loss_seg = seg_criterion(main_pred_s, labels_s)
            for aux_s in aux_preds_s:
                loss_seg = loss_seg + aux_loss_weight * seg_criterion(aux_s, labels_s)

            ## (2) Class-probability maps fed to the discriminator
            pred_s_prob = F.interpolate(F.softmax(main_pred_s, dim=1), size=images_s.shape[2:],
                                        mode="bilinear", align_corners=False)
            pred_t_prob = F.interpolate(F.softmax(main_pred_t, dim=1), size=images_t.shape[2:],
                                        mode="bilinear", align_corners=False)

            ## (3) Adversarial loss: target predictions should be classified as source
            d_out_t_for_gen = discriminator(pred_t_prob)
            loss_adv_gen = adv_criterion(d_out_t_for_gen, torch.ones_like(d_out_t_for_gen))

            ## (4) Total generator loss = supervised loss + adversarial loss
            loss_g = loss_seg + lambda_adv * loss_adv_gen

        optimizer_s.zero_grad(set_to_none=True)
        scaler.scale(loss_g).backward()
        scaler.step(optimizer_s)

        ## ---- Discriminator step ---- ##
        _set_requires_grad(discriminator, True)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            # Detached maps: the discriminator loss must not reach the generator.
            d_out_s = discriminator(pred_s_prob.detach())
            d_out_t = discriminator(pred_t_prob.detach())

            # Domain labels: source = 1, target = 0
            loss_adv_d = (adv_criterion(d_out_s, torch.ones_like(d_out_s))
                          + adv_criterion(d_out_t, torch.zeros_like(d_out_t)))

        optimizer_d.zero_grad(set_to_none=True)
        scaler.scale(loss_adv_d).backward()
        scaler.step(optimizer_d)

        # One scaler shared by both optimizers: update once per iteration,
        # after both steps, so both losses use the same scale factor.
        scaler.update()

        total_loss += loss_g.item()

    if device.type == 'cuda':
        torch.cuda.empty_cache()
    gc.collect()

    epoch_loss = total_loss / len(train_loader)
    epoch_loss_list.append(epoch_loss)

    ## Miou on validation set ##

    all_predictions = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for images, labels in city_val_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.long()  # stays on CPU; only its shape is needed on the GPU side

            outputs = model(images)
            main_pred = outputs[0] if isinstance(outputs, (list, tuple)) else outputs

            main_pred = F.interpolate(main_pred, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            # argmax of the logits equals argmax of the softmax
            predictions = torch.argmax(main_pred, dim=1)

            # Move each batch off the GPU immediately instead of keeping the
            # whole validation set in GPU memory until the end of the epoch.
            all_predictions.append(predictions.cpu())
            all_labels.append(labels)

    all_predictions = torch.cat(all_predictions, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    epoch_miou, IoU_per_class = compute_miou(all_predictions, all_labels, num_classes=num_classes,
                                             device=all_predictions.device)
    print(f"End of epoch {epoch+1} — Loss: {epoch_loss:.4f}, mIoU: {epoch_miou:.4f}, LR of the final epoch: {lr:.6f}")

    ## Saving checkpoints every 3 epochs ##
    if (epoch + 1) % save_interval == 0:
        checkpoint_filename = f"checkpoint_epoch{epoch+1}.pt"
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
        _save_checkpoint(checkpoint_path, epoch, epoch_loss, epoch_miou)
        print(f"Checkpoint saved: {checkpoint_filename}")

        # Evict (and delete from disk) the oldest checkpoints BEFORE appending,
        # so the cleanup also works when ``saved_checkpoints`` is a bounded
        # deque that would otherwise drop entries silently.
        while len(saved_checkpoints) >= max_checkpoints:
            old_ckpt = saved_checkpoints.popleft()
            if os.path.exists(old_ckpt):
                os.remove(old_ckpt)
                print(f"Removed old checkpoint: {os.path.basename(old_ckpt)}")
        saved_checkpoints.append(checkpoint_path)

    ## Saving the best model ##
    if epoch_miou > best_miou:
        best_miou = epoch_miou
        best_epoch_ckpt = os.path.join(checkpoint_dir, "4_step_best_epoch.pt")
        _save_checkpoint(best_epoch_ckpt, epoch, epoch_loss, epoch_miou)
        print(f"New best model saved: {best_epoch_ckpt} con mIoU {best_miou:.4f}")

print(f"epoch_loss_list: {epoch_loss_list}")

