# Federated Semantic Segmentation for Self Driving Cars

## STEP 0: SETUP PHASE

### Installs

In [1]:
!pip install torchmetrics
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmetrics
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 KB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.11.4
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.14.2-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m72.4 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.19.1-py2.py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.2/199.2 KB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting GitPython!=3.1.29,>=1.0.0
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
[2K     

### Mounting Google Drive

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
cd /content/drive/MyDrive/AMLproject/

/content/drive/MyDrive/AMLproject


### Cloning FedDrive

In [4]:
import os
if not os.path.isdir('./FedDrive'):
  !git clone https://github.com/Erosinho13/FedDrive


### Imports

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as torch_data
import torch.nn.functional as F 
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
import random
# import torchvision
# import logging
import warnings
import math
import json
import wandb
# TODO: remove all cv2 code
import cv2

# from torchvision import transforms
from torch.backends import cudnn
from torch import from_numpy
from PIL import Image
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassJaccardIndex
from torchvision.datasets import VisionDataset
from collections import OrderedDict

from drive.MyDrive.AMLproject import transform as T
from FedDrive.src.modules.bisenetv2 import BiSeNetV2

warnings.resetwarnings()
warnings.simplefilter('ignore')

### Parameter Configuration Step 2

In [None]:
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

PARTITION = 'A' # A or B to choose which dataloader to use
SEED = 42
DEVICE = 'cuda' # 'cuda' or 'cpu'

NUM_CLASSES = 19
if NUM_CLASSES == 19:
  cl19 = True

BATCH_SIZE = 8     # Higher batch sizes allows for larger learning rates. An empirical heuristic suggests that, when changing ***
                     # the batch size, learning rate should change by the same factor to have comparable results

LR = 0.05           # The initial Learning Rate ***
MOMENTUM = 0.9       # Hyperparameter for SGD, keep this at 0.9 when using SGD ***
WEIGHT_DECAY =0.0005  # Regularization, you can keep this at the default ***

NUM_EPOCHS = 10     # Total number of training epochs (iterations over dataset)

# servono per decrementare il lerning rate nel tempo 
# STEP_SIZE = 500       # How many epochs before decreasing learning rate (if using a step-down policy)
# GAMMA = 0.8          # Multiplicative factor for learning rate step-down

LOG_FREQUENCY = 20
LOG_FREQUENCY_EPOCH = 3

ROOT_DIR = os.path.join('data', 'Cityscapes')

### Parameter configuration Step 3

In [None]:
PARTITION = "B"  # 'A' or 'B'
SPLIT = 1  # 1 or 2 // 1 = Uniform : 2 = Heterogenous
MAX_SAMPLE_PER_CLIENT = 20

IMAGES_FINAL = "leftImg8bit"
TARGET_FINAL = "gtFine_labelIds"

N_ROUND = 50
CLIENT_PER_ROUND = 5  # clients picked each round
NUM_EPOCHS = 2


CHECKPOINTS = 5


if PARTITION == 'A':
  if SPLIT == 1:
    TOT_CLIENTS = 36
  else:
    TOT_CLIENTS = 46
else:
  if SPLIT == 1:
    TOT_CLIENTS = 25
  else:
    TOT_CLIENTS = 33


### Parameter configuration Step 4

In [None]:
PARTITION = 'A' # A or B to choose which dataloader to use
SPLIT = 1  # 1 or 2 // 1 = Uniform : 2 = Heterogenous

SEED = 42
DEVICE = 'cuda' # 'cuda' or 'cpu'

NUM_CLASSES = 19
cl19 = True

BATCH_SIZE = 8     # Higher batch sizes allows for larger learning rates. An empirical heuristic suggests that, when changing ***
                     # the batch size, learning rate should change by the same factor to have comparable results

LR = 0.05           # The initial Learning Rate ***
MOMENTUM = 0.9       # Hyperparameter for SGD, keep this at 0.9 when using SGD ***
WEIGHT_DECAY = 0.0005  # Regularization, you can keep this at the default ***

NUM_EPOCHS = 10     # Total number of training epochs (iterations over dataset)

# servono per decrementare il lerning rate nel tempo 
# STEP_SIZE = 500       # How many epochs before decreasing learning rate (if using a step-down policy)
# GAMMA = 0.8          # Multiplicative factor for learning rate step-down

LOG_FREQUENCY = 20
LOG_FREQUENCY_EPOCH = 3


LOAD_CKPT_PATH = 'ckpt_0.pth'
CKPT_PATH = './checkpoints/'
CKPT_DIR = './checkpoints/'
LOAD_CKPT = False

# TODO: move
if not os.path.isdir(CKPT_DIR):
  os.mkdir(CKPT_DIR)

CTSC_ROOT = "./data/Cityscapes/"
GTA5_ROOT = "./data/GTA5/"

# For FDA 
FDA = True
MAX_SAMPLE_PER_CLIENT = 20
N_STYLE = 5
# b == 0 --> 1x1, b == 1 --> 3x3, b == 2 --> 5x5, ...'
BETA_WINDOW_SIZE = 1

if PARTITION == 'A':
  if SPLIT == 1:
    TOT_CLIENTS = 36
  else:
    TOT_CLIENTS = 46
else:
  if SPLIT == 1:
    TOT_CLIENTS = 25
  else:
    TOT_CLIENTS = 33


### Parameter configuration Step 5

In [32]:
T_ROUND = 2 
PSEUDO_LAB = True
FDA = False

PARTITION = "A"  # 'A' or 'B'
SPLIT = 1  # 1 or 2 // 1 = Uniform : 2 = Heterogenous

SEED = 42
DEVICE = 'cuda' # 'cuda' or 'cpu'

NUM_CLASSES = 19
if NUM_CLASSES == 19:
  cl19 = True  

BATCH_SIZE = 8     # Higher batch sizes allows for larger learning rates. An empirical heuristic suggests that, when changing ***
                     # the batch size, learning rate should change by the same factor to have comparable results

LR = 0.05           # The initial Learning Rate ***
MOMENTUM = 0.9       # Hyperparameter for SGD, keep this at 0.9 when using SGD ***
WEIGHT_DECAY = 0.0005  # Regularization, you can keep this at the default ***

LOG_FREQUENCY = 20
LOG_FREQUENCY_EPOCH = 3

MAX_SAMPLE_PER_CLIENT = 20

IMAGES_FINAL = "leftImg8bit"
TARGET_FINAL = "gtFine_labelIds"

N_ROUND = 50
CLIENT_PER_ROUND = 5  # clients picked each round
NUM_EPOCHS = 2

CHECKPOINTS = 5

CTSC_ROOT = "./data/Cityscapes/"
GTA5_ROOT = "./data/GTA5/"

# Used to load the model from the previous step
ckpt_path = 'step4_A_pretraining_0.15.pth'
CKPT_DIR = './checkpoints'

if PARTITION == 'A':
  if SPLIT == 1:
    TOT_CLIENTS = 36
  else:
    TOT_CLIENTS = 46
else:
  if SPLIT == 1:
    TOT_CLIENTS = 25
  else:
    TOT_CLIENTS = 33


### Data Augmentation Options

In [7]:
# data augmentation options
# set as None if a transformation should not be used

RANDOM_HORIZONTAL_FLIP = 0.5
#RANDOM_HORIZONTAL_FLIP = None      #probability of the image being flipped
#COLOR_JITTER = (0.2,0.3,0.2,0.2) # (brighteness, contrast, saturation, hue)
COLOR_JITTER = None
RANDOM_ROTATION = 5              # degree of rotation
#RANDOM_ROTATION = None
#RANDOM_CROP = (512,1024)         # output size of the crop
RANDOM_CROP = None
RESIZE = (512,1024)              # output size
#RESIZE = None
#RANDOM_VERTICAL_FLIP  = 0.3     # probability of the image being flipped
RANDOM_VERTICAL_FLIP = None
CENTRAL_CROP = (512,1024)
#CENTRAL_CROP = (512,1024)
#RANDOM_RESIZE_CROP = (1024,2048)
RANDOM_RESIZE_CROP = None

### Transforms setup

In [8]:
def setup_transform():
    transformers = []
    if RANDOM_HORIZONTAL_FLIP is not None:
        transformers.append(T.RandomHorizontalFlip(RANDOM_HORIZONTAL_FLIP))
    if COLOR_JITTER is not None:
        transformers.append(T.ColorJitter(*COLOR_JITTER))
    if RANDOM_ROTATION is not None:
        transformers.append(T.RandomRotation(RANDOM_ROTATION))
    if RANDOM_CROP is not None:
        transformers.append(T.RandomCrop(RANDOM_CROP))
    if RANDOM_VERTICAL_FLIP is not None:
        transformers.append(T.RandomVerticalFlip(RANDOM_VERTICAL_FLIP))
    if CENTRAL_CROP is not None:
        transformers.append(T.CenterCrop(CENTRAL_CROP))
    if RANDOM_RESIZE_CROP is not None:
        transformers.append(T.RandomResizedCrop(RANDOM_RESIZE_CROP))
    if RESIZE is not None:
        transformers.append(T.Resize(RESIZE))

    transformers.append(T.ToTensor())

    transforms = T.Compose(transformers)

    return transforms

### Fixing random seed

In [9]:
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

### Datasets

#### Cityscapes

In [10]:
class Cityscapes(torch_data.Dataset):

    """
    image path: data/Cityscapes/images/name_leftImg8bit.png
    taget path: data/Cityscapes/labels/name_gtFine_labelIds.png
    """

    def __init__(self, root, transform=None, cl19=False, filename=None, id_client=None):
        eval_classes = [
            7,
            8,
            11,
            12,
            13,
            17,
            19,
            20,
            21,
            22,
            23,
            24,
            25,
            26,
            27,
            28,
            31,
            32,
            33,
        ]
        self.root = root

        if filename is None:
            raise ValueError("filename is None")

        if id_client is not None:
            with open(os.path.join(root, filename)) as f:
                dict_data = json.load(f)

            self.paths_images = [l[0] for l in dict_data[str(id_client)]]
            self.paths_tagets = [l[1] for l in dict_data[str(id_client)]]

        else:
            with open(os.path.join(root, filename), "r") as f:
                lines = f.readlines()

            # manipulate each file row in order to obtain the correct path
            self.paths_images = [l.strip().split("@")[0] for l in lines]
            self.paths_tagets = [l.strip().split("@")[1] for l in lines]

            # self.len = len(self.paths_images)
            # self.transform = transform

        self.len = len(self.paths_images)
        self.transform = transform
        self.return_unprocessed_image = False

        if cl19:
            classes = eval_classes
            mapping = np.zeros((256,), dtype=np.int64) + 255
            for i, cl in enumerate(classes):
                mapping[cl] = i
            self.target_transform = lambda x: from_numpy(mapping[x])

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the label of segmentation.
        """

        # # using PIL
        img = Image.open(os.path.join(self.root,"images",self.paths_images[index]))
        target = Image.open(os.path.join(self.root,"labels",self.paths_tagets[index]))

        if self.return_unprocessed_image:
            return img

        if self.transform:
            img, target = self.transform(img, target)

        if self.target_transform:
            target = self.target_transform(target)

        return img, target  # output: Tensor[image_channels, image_height, image_width]

    def __len__(self):
        return self.len


#### GTA5

In [11]:
class GTA5(VisionDataset):
    labels2train = {
        "cityscapes": {
            7: 0,
            8: 1,
            11: 2,
            12: 3,
            13: 4,
            17: 5,
            19: 6,
            20: 7,
            21: 8,
            22: 9,
            23: 10,
            24: 11,
            25: 12,
            26: 13,
            27: 14,
            28: 15,
            31: 16,
            32: 17,
            33: 18,
        },
    }

    def __init__(
        self,
        root,
        transform=None,
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        cv2=False,
        target_dataset="cityscapes",
    ):
        assert (
            target_dataset in GTA5.labels2train
        ), f"Class mapping missing for {target_dataset}, choose from: {GTA5.labels2train.keys()}"

        self.labels2train = GTA5.labels2train[target_dataset]

        # super().__init__(root, transform=transform, target_transform=None)

        self.root = root
        self.transform = transform
        self.mean = mean
        self.std = std
        self.cv2 = cv2

        self.target_transform = self.__map_labels()

        self.return_unprocessed_image = False
        self.style_tf_fn = None

        with open(os.path.join(self.root, "train.txt"), "r") as f:
            lines = f.readlines()

        # manipulate each file row in order to obtain the correct path
        self.paths_images = [l.strip() for l in lines]

        self.len = len(self.paths_images)

    def set_style_tf_fn(self, style_tf_fn):
        self.style_tf_fn = style_tf_fn

    def reset_style_tf_fn(self):
        self.style_tf_fn = None

    def __getitem__(self, index):
        x_path = os.path.join(self.root, "images", self.paths_images[index])
        y_path = os.path.join(self.root, "labels", self.paths_images[index])

        x = Image.open(x_path)
        y = Image.open(y_path)

        if self.return_unprocessed_image:
            return x

        if self.style_tf_fn is not None:
            x = self.style_tf_fn(x)

        if self.transform is not None:
            x, y = self.transform(x, y)
        y = self.target_transform(y)

        return x, y

    def __len__(self):
        return self.len

    def __map_labels(self):
        mapping = np.zeros((256,), dtype=np.int64) + 255
        for k, v in self.labels2train.items():
            mapping[k] = v
        return lambda x: from_numpy(mapping[x])


### Validation Utilities

In [12]:
colors = [
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]    
        ]

label_colours = dict(zip(range(NUM_CLASSES), colors))

def decode_segmap(temp):
    #convert gray scale to color
    #print colored map
    temp=temp.numpy()
    r = temp.copy()
    g = temp.copy()
    b = temp.copy()
    for l in range(0, NUM_CLASSES):
        r[temp == l] = label_colours[l][0]
        g[temp == l] = label_colours[l][1]
        b[temp == l] = label_colours[l][2]

    rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    return rgb

def compute_miou(net, val_dataloader):
    net = net.to(DEVICE)
    net.train(False)  # Set Network to evaluation mode
    jaccard = MulticlassJaccardIndex(num_classes=NUM_CLASSES, ignore_index=255).to(
        DEVICE
    )

    jacc = 0
    count = 0
    for images, labels in val_dataloader:
        images = images.to(DEVICE, dtype=torch.float32)
        labels = labels.to(DEVICE, dtype=torch.long)
        # Forward Pass
        outputs = net(images)
        # Get predictions
        _, preds = torch.max(outputs.data, 1)

        # Update Corrects
        jacc += jaccard(preds, labels.squeeze())
        count += 1

    # Calculate Accuracy
    metric = jacc.item() / count
    # net.train(True)
    return metric


def validation_plot(net, val_dataloader, n_images):
    net = net.to(DEVICE)
    net.train(False)
    rows = 1
    columns = 3
    n = 0
    for imgs, targets in val_dataloader:
        # i = random.randint(BATCH_SIZE)
        if n >= n_images:
          break
        imgsfloat = imgs.to(DEVICE, dtype=torch.float32)
        outputs = net(imgsfloat)
        _, preds = torch.max(outputs.data, 1)
        # Added in order to use the decode_segmap function
        preds = preds.cpu()  # or equally preds = preds.to('cpu')
        
        for i in range(len(imgs)):
          if n >= n_images:
            break
          # pick the first image of each batch
          print(imgs[i].shape, targets[i].shape)
          print("img:", imgs[i].squeeze().shape, " target:", targets[i].squeeze().shape)
          print("pred:", preds.shape)

          figure = plt.figure(figsize=(10, 20))
          figure.add_subplot(rows, columns, 1)
          #plt.imshow(imgs[0].permute((1, 2, 0)).squeeze())
          plt.imshow(imgs[i].permute((1, 2, 0)).squeeze())
          plt.axis("off")
          plt.title("Image")

          figure.add_subplot(rows, columns, 2)
          #plt.imshow(decode_segmap(targets[0].permute((1, 2, 0)).squeeze()))
          plt.imshow(decode_segmap(targets[i]))
          plt.axis("off")
          plt.title("Groundtruth")

          figure.add_subplot(rows, columns, 3)
          #plt.imshow(decode_segmap(preds[0].squeeze()))
          plt.imshow(decode_segmap(preds[i]))
          plt.axis("off")
          plt.title("Prediction")
          plt.show()
          n += 1
    return

## STEP 1 : GENERATING DATASETS

### Generate splits for step 2

In [None]:

IMAGES_FINAL = "leftImg8bit"
TARGET_FINAL = "gtFine_labelIds"

with open(os.path.join(ROOT_DIR, "train.txt"), "r") as ft:
    lines_train = ft.readlines()
with open(os.path.join(ROOT_DIR, "val.txt"), "r") as fv:
    lines_val = fv.readlines()

lines = lines_train + lines_val
images = [
    (
        l.split("/")[0],
        l.strip().split("/")[1],
        l.strip().split("/")[1].replace(IMAGES_FINAL, TARGET_FINAL),
    )
    for l in lines
]

with open(os.path.join(ROOT_DIR, "train_B.txt"), "w") as f:
    for l in lines_train:
        img = l.strip().split("/")[1]
        lbl = img.replace(IMAGES_FINAL, TARGET_FINAL)
        f.write(img + "@" + lbl + "\n")

with open(os.path.join(ROOT_DIR, "test_B.txt"), "w") as f:
    for l in lines_val:
        img = l.strip().split("/")[1]
        lbl = img.replace(IMAGES_FINAL, TARGET_FINAL)
        f.write(img + "@" + lbl + "\n")

city_dic = {}

for i in images:
    if i[0] not in city_dic:
        city_dic[i[0]] = []
        city_dic[i[0]].append(tuple(i[1:]))
    else:
        city_dic[i[0]].append(tuple(i[1:]))
test_img = []
train_img = []

for c in city_dic.values():
    s = random.sample(c, 2)
    test_img += s
    train_img += c
    for img in s:
        train_img.remove(img)

# save the split
with open(os.path.join(ROOT_DIR, "test_A.txt"), "w") as f:
    for img in test_img:
        f.write(img[0] + "@" + img[1] + "\n")

with open(os.path.join(ROOT_DIR, "train_A.txt"), "w") as f:
    for img in train_img:
        f.write(str(img[0]) + "@" + img[1] + "\n")


### generating splits for step 3

In [None]:
if PARTITION == "A":
    # train A
    with open(os.path.join(ROOT_DIR, "train_A.txt"), "r") as f:
        lines = f.readlines()
        images = [
            (
                l.strip().split("@")[0],
                l.strip().split("@")[1],
            )
            for l in lines
        ]
if PARTITION == "B":
    # train B
    with open(os.path.join(ROOT_DIR, "train.txt"), "r") as f:
        lines = f.readlines()
        images = [
            (
                l.strip().split("/")[1],
                l.strip().split("/")[1].replace(IMAGES_FINAL, TARGET_FINAL),
            )
            for l in lines
        ]

city_dic = {}
for i in images:
    city_name = i[0].split("_")[0]
    if city_name not in city_dic:
        city_dic[city_name] = []
    city_dic[city_name].append(i)

if SPLIT == 1:
    # uniform
    # every client has images from different cityes
    n_sample = len(images)
    n_client_per_city = math.ceil(n_sample / MAX_SAMPLE_PER_CLIENT)
    city_enum = list(enumerate(city_dic.keys()))
    choices = [k for k in city_dic.keys()]
    weights = [len(city_dic[c]) for c in choices]
    client_dict = {}
    for i in range(n_client_per_city):
        client_dict[i] = []
        for _ in range(MAX_SAMPLE_PER_CLIENT):
            choices = [k for k in city_dic.keys()]
            weights = [len(city_dic[c]) for c in choices]
            try:
                c = random.choices(choices, weights=weights, k=1)[0]
            except:
                break
            img, lable = city_dic[c].pop()
            client_dict[i].append((img, lable))
            if len(city_dic[c]) == 0:
                city_dic.pop(c)

    with open(
        os.path.join(ROOT_DIR, f"uniform{PARTITION}.json"), "w"
    ) as outfile:
        json.dump(client_dict, outfile, indent=4)

if SPLIT == 2:
    # heterogeneous
    # every client has images from only one city
    client_dict = {}
    tot_clients = 0
    for city in city_dic.keys():
        n_samples_per_city = len(city_dic[city])
        n_client_per_city = math.ceil(n_samples_per_city / MAX_SAMPLE_PER_CLIENT)
        avg = len(city_dic[city]) // n_client_per_city

        for i in range(tot_clients, tot_clients + n_client_per_city):
            client_dict[i] = []
            for _ in range(avg):
                img, lbl = city_dic[city].pop()
                client_dict[i].append((img, lbl))
            tot_clients += 1
        if len(city_dic[city]) > 0:
            for img, lbl in city_dic[city]:
                client_dict[i].append((img, lbl))
    with open(
        os.path.join(
            ROOT_DIR , f"heterogeneuos{PARTITION}.json"
        ),
        "w",
    ) as outfile:
        json.dump(client_dict, outfile, indent=4)

## STEP 2 : CENTRALIZED BASELINE

### Setup for WanDB

In [None]:
wandb.login()

transformer_dictionary = {
    "random-horizontal-flip":RANDOM_HORIZONTAL_FLIP,
    "color-jitter":COLOR_JITTER,
    "random-rotation":RANDOM_ROTATION,
    "random-crop":RANDOM_CROP,
    "random-vertical-flip":RANDOM_VERTICAL_FLIP,
    "central-crop":CENTRAL_CROP,
    "random-resized-crop":RANDOM_RESIZE_CROP,
    "resize":RESIZE,
    }

config = {
        "batch_size": BATCH_SIZE,
        "lr": LR,
        "num_epochs": NUM_EPOCHS,
        #"step_size": STEP_SIZE,
        "transformers": transformer_dictionary,
    }
name = f"Step_2_{PARTITION}_lr{LR}_bs{BATCH_SIZE}_e{NUM_EPOCHS}"
wandb.init(
    project = " STEP2",
    # entity = "lor-bellino",
    config = config,
    name = name,
)

###  Creating Transforms

In [None]:
transforms = setup_transform()

### Creating Dataloader

In [None]:

if PARTITION == "A":
    train_dataset = Cityscapes(
        root=ROOT_DIR,
        transform=transforms,
        cl19=True,
        filename="train_A.txt",
    )
elif PARTITION == "B":
    train_dataset = Cityscapes(
        root=ROOT_DIR,
        transform=transforms,
        cl19=True,
        filename="train_B.txt",
    )
else:
    raise NotImplementedError

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)


### Training Loop

In [None]:
model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)
criterion = nn.CrossEntropyLoss(ignore_index=255)
parameters_to_optimize = model.parameters()
optimizer = optim.SGD(
    parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
)
model = model.to(DEVICE)

cudnn.benchmark  # Calling this optimizes runtime

epochs = []

wandb.watch(model, log="all")

current_step = 0
# Start iterating over the epochs
for epoch in range(NUM_EPOCHS):
    print("Starting epoch {}/{}".format(epoch + 1, NUM_EPOCHS))
    epochs.append(epoch + 1)

    # Iterate over the dataset
    for images, labels in train_dataloader:
        images = images.to(DEVICE, dtype=torch.float32)
        labels = labels.to(DEVICE, dtype=torch.long)
        model.train()
        optimizer.zero_grad()

        predictions = model(images)
        loss = criterion(predictions, labels.squeeze())

        # Log loss
        if current_step % LOG_FREQUENCY == 0:
            print("Step {}, Loss {}".format(current_step, loss.item()))
            wandb.log({"train/loss": loss})
        # Compute gradients for each layer and update weights
        loss.backward()
        optimizer.step()

        current_step += 1

### Save the Model

In [None]:
name = f"step2_{PARTITION}_model.pth"
if not os.path.exists("./models/STEP2/"):
    print("creating models directory")
    os.makedirs("./models/STEP2/")
torch.save(model.state_dict(), "./models/STEP2/" + name)

### Creating Validation Dataloader

In [None]:
if PARTITION == "A":
    val_dataset = Cityscapes(
        root=ROOT_DIR,
        transform=transforms,
        cl19=True,
        filename="test_A.txt",
    )
elif PARTITION == "B":
    val_dataset = Cityscapes(
        root=ROOT_DIR,
        transform=transforms,
        cl19=True,
        filename="test_B.txt",
    )
else:
    raise NotImplementedError

val_dataloader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True
)


### Validation

In [None]:
print("computing miou ...")
miou = compute_miou(net=model, val_dataloader=val_dataloader)
print("Validation MIoU: {}".format(miou))
wandb.log({"val/miou": miou})
wandb.finish()

### Validation plot

In [None]:
print("validation plot : ")
validation_plot(net=model, val_dataloader=val_dataloader, n_images=20)
torch.cuda.empty_cache()

## STEP 3 : FEDERATED + SEMANTIC SEGMENTATION

### WanDB Setup

In [None]:
wandb.login()
config = {
        "batch_size": BATCH_SIZE,
        "lr": LR,
        "momentum": MOMENTUM,
        "num_epochs": NUM_EPOCHS,
        "n_client": CLIENT_PER_ROUND,
        "round": N_ROUND,
        "tot_client": TOT_CLIENTS,
        }
name = (
        f"Step_3_{PARTITION}_S{SPLIT}_R{N_ROUND}_c{CLIENT_PER_ROUND}"
        )
wandb.init(
        project=f"STEP3",
        # entity="lor-bellino",
        config=config,
        name=name,
    )

[34m[1mwandb[0m: Currently logged in as: [33mlor-bellino[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Client Class

In [None]:
class Client():
  #def __init__(self, client_id, dataset, model, logger, writer, args, batch_size, world_size, rank, device=None, **kwargs):
  def __init__(self, client_id, dataset, model):
    self.id = client_id
    self.dataset = dataset
    self.model = model #copy.deepcopy(model)
    self.device = DEVICE
    self.batch_size = BATCH_SIZE
    #self.args = args
    self.loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, drop_last=True)


  def client_train(self):
    
    num_train_samples = len(self.dataset)
    # Define loss function
    criterion = nn.CrossEntropyLoss(ignore_index= 255)
    parameters_to_optimize = self.model.parameters() 
    optimizer = optim.SGD(parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

    self.model = self.model.to(DEVICE)
    self.model.train() # Sets module in training mode

    cudnn.benchmark # Calling this optimizes runtime
    
    # Start iterating over the epochs
    for epoch in range(NUM_EPOCHS):
      if epoch % LOG_FREQUENCY_EPOCH == 0: 
        print('Starting epoch {}/{}'.format(epoch+1, NUM_EPOCHS))

      # Iterate over the dataset
      for current_step, (images, labels) in enumerate(self.loader):
        images = images.to(DEVICE, dtype=torch.float32)
        labels = labels.to(DEVICE, dtype=torch.long)
        
        optimizer.zero_grad()
        predictions = self.model(images)
        loss = criterion(predictions, labels.squeeze())

        # Log loss
        if current_step % LOG_FREQUENCY == 0:
          print('Step {}, Loss {}'.format(current_step, loss.item()))
          #wandb.log({f"client{self.id}/loss":loss})
          wandb.log({f"client/loss":loss})

        loss.backward()  # backward pass: computes gradients
        optimizer.step() # update weights based on accumulated gradients

    return num_train_samples, copy.deepcopy(self.model.state_dict()) #generate_update

  def test(self, metrics, ret_samples_ids=None, silobn_type=None, train_cl_bn_stats=None, loader=None):
    return

  def save_model(self, epochs, path, optimizer, scheduler):
    return

### Server Class

In [None]:
class Server():
  #def __init__(self, model, logger, writer, local_rank, lr, momentum, optimizer=None):
  def __init__(self, model, lr= None , momentum= None):
    self.model = copy.deepcopy(model)
    self.model_params_dict = copy.deepcopy(self.model.state_dict())
    self.selected_clients = []
    self.updates = []
    self.lr = lr
    self.momentum = momentum
    self.optimizer = optim.SGD(params=self.model.parameters(), lr=1, momentum=0.9)
    self.total_grad = 0

  def select_clients(self, my_round, possible_clients, num_clients=4):
    num_clients = min(num_clients, len(possible_clients))
    np.random.seed(my_round)
    self.selected_clients = np.random.choice(possible_clients, num_clients, replace=False)
  
  def _compute_client_delta(self, cmodel):
      delta = OrderedDict.fromkeys(cmodel.keys())
      for k, x, y in zip(self.model_params_dict.keys(), self.model_params_dict.values(), cmodel.values()):
        delta[k] = y - x if "running" not in k and "num_batches_tracked" not in k else y
      return delta
  
  def train_round(self):

    self.optimizer.zero_grad()

    clients = self.selected_clients

    for i, c in enumerate(clients):

        print(f"CLIENT {i + 1}/{len(clients)} -> {c.id}:")

        c.model.load_state_dict(self.model_params_dict) # load_server_model_on_client
        out = c.client_train()

        num_samples, update = out

        update = self._compute_client_delta(update)
        
        self.updates.append((num_samples, update))
    return 

  def _server_opt(self, pseudo_gradient):
    for n, p in self.model.named_parameters():
        p.grad = -1.0 * pseudo_gradient[n]

    self.optimizer.step()

    bn_layers = OrderedDict(
        {k: v for k, v in pseudo_gradient.items() if "running" in k or "num_batches_tracked" in k})
    self.model.load_state_dict(bn_layers, strict=False)

  def _aggregation(self):
    total_weight = 0.
    base = OrderedDict()

    for (client_samples, client_model) in self.updates:

        total_weight += client_samples
        for key, value in client_model.items():
            if key in base:
                base[key] += client_samples * value.type(torch.FloatTensor)
            else:
                base[key] = client_samples * value.type(torch.FloatTensor)
    averaged_sol_n = copy.deepcopy(self.model_params_dict)
    for key, value in base.items():
        if total_weight != 0:
            averaged_sol_n[key] = value.to('cuda') / total_weight

    return averaged_sol_n

  def _get_model_total_grad(self):
    total_norm = 0
    for name, p in self.model.named_parameters():
        if p.requires_grad:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_grad = total_norm ** 0.5
    return total_grad

  def update_model(self):
    """FedAvg on the clients' updates for the current round.
    Weighted average of self.updates, where the weight is given by the number
    of samples seen by the corresponding client at training time.
    Saves the new central model in self.client_model and its state dictionary in self.model
    """

    averaged_sol_n = self._aggregation()
    
    self._server_opt(averaged_sol_n)
    self.total_grad = self._get_model_total_grad()
    self.model_params_dict = copy.deepcopy(self.model.state_dict())

    self.updates = []

  def test_model(self, clients_to_test, metrics, ret_samples_bool=False, silobn_type=''):
    return 



### Creating Transforms

In [None]:
transforms = setup_transform()

### Client Setup

In [None]:
def setup_clients(n_client, model):

  clients = []
  if PARTITION == 'A':
    if SPLIT == 1:
      filename="uniformA.json"
    else:
      filename='heterogeneuosA.json'
    for i in range(n_client):
      train_dataset = Cityscapes(root=ROOT_DIR, transform=transforms, cl19 = cl19,filename=filename,  id_client = i)
      client = Client(client_id = i, dataset = train_dataset, model = model)
      clients.append(client)
  else:
    if SPLIT == 1:
      filename="uniformB.json"
    else:
      filename='heterogeneuosB.json'
    for i in range(n_client):
      train_dataset = Cityscapes(root=ROOT_DIR, transform=transforms, cl19 = cl19, filename=filename,  id_client = i)
      client = Client(client_id = i, dataset = train_dataset, model = model)
      clients.append(client)

  return clients

### Validation Dataloader between rounds

In [None]:
def create_val_dataloader(transforms):
    if SPLIT == 1:
        filename = f"uniform{PARTITION}.json"
    else:
        filename = f"heterogeneuos{PARTITION}.json"

    val_dataset = Cityscapes(
        root=ROOT_DIR,
        transform=transforms,
        cl19=cl19,
        filename=filename,
        id_client=0,
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True
    )
    return val_dataloader

### Server Training Loop

In [None]:
model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)
model = model.to(DEVICE)

wandb.watch(model, log='all')

model_path = "./models/Step3/"
if not os.path.exists(model_path):
    print("creating models directory")
    os.makedirs(model_path)

train_clients = setup_clients(n_client = TOT_CLIENTS, model = model)
print('Number of total clients:', len(train_clients))
val_dataloader = create_val_dataloader(transforms)

server = Server(model, lr=LR, momentum = MOMENTUM)

for r in range(N_ROUND):
  print(f"ROUND {r + 1}/{N_ROUND}: Training {CLIENT_PER_ROUND} Clients...")
  server.select_clients(r, train_clients, num_clients=CLIENT_PER_ROUND) 
  server.train_round()
  server.update_model()
  miou = compute_miou(net=server.model, val_dataloader=val_dataloader)
  wandb.log({"server/miou": miou})
  print(f"Validation MIoU: {miou}")
  if r%CHECKPOINTS == 0:
    print(f"Saving the model")
    torch.save(model.state_dict(), "./models/STEP3/"+f"model_P{PARTITION}_S{SPLIT}_round{r:02}.pth")

torch.save(model.state_dict(), "./models/STEP3/"+f"model_P{PARTITION}_S{SPLIT}_round{r:02}.pth")

### Creating Validation Dataset

In [None]:
if PARTITION == 'A': 
  filename="test_A.txt"
elif PARTITION == 'B': 
  filename="test_B.txt"

val_dataset = Cityscapes(root=ROOT_DIR,transform=transforms, cl19 = cl19, filename=filename)
print("Dataset dimension:", len(val_dataset))
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

Dataset dimension:  250


### Validation

In [None]:
print("computing miou ...")
miou = compute_miou(net=model, val_dataloader=val_dataloader)
print("Validation MIoU: {}".format(miou))
wandb.log({"val/miou": miou})
wandb.finish()

computing miou ...
Validation MIoU: 0.2096177839463757


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
client/loss,█▇▆▄▃▃▃▃▂▂▃▃▂▂▂▂▂▁▂▂▁▂▂▂▂▁▁▂▂▁▁▂▂▁▂▂▁▁▁▁
server/miou,▁▂▂▃▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇██████████████
val/miou,▁

0,1
client/loss,0.46031
server/miou,0.20814
val/miou,0.20962


### Validation Plot

In [None]:
torch.save(model.state_dict(),"./models/model_S3_{PARTITION}.pth")

In [None]:
print("validation plot : ")
validation_plot(net=model, val_dataloader=val_dataloader, n_images=20)
torch.cuda.empty_cache()

validation plot : 


## STEP 4 : MOVING TOWARDS FREEDA - PRETRAINING PHASE

### Setup for WanDB

In [None]:
wandb.login()

transformer_dictionary = {
    "random-horizontal-flip":RANDOM_HORIZONTAL_FLIP,
    "color-jitter":COLOR_JITTER,
    "random-rotation":RANDOM_ROTATION,
    "random-crop":RANDOM_CROP,
    "random-vertical-flip":RANDOM_VERTICAL_FLIP,
    "central-crop":CENTRAL_CROP,
    "random-resized-crop":RANDOM_RESIZE_CROP,
    "resize":RESIZE,
    }

config = {
        "batch_size": BATCH_SIZE,
        "lr": LR,
        "num_epochs": NUM_EPOCHS,
        "weight_decay": WEIGHT_DECAY,
        "momentum" : MOMENTUM,
        #"step_size": STEP_SIZE,
        "transformers": transformer_dictionary,
        "FDA" : FDA,
        "num_styles": N_STYLE,
        "beta_window_size": BETA_WINDOW_SIZE
    }

name = f"step4{'_FDA' if FDA else ''}_{PARTITION}_S{SPLIT}_lr{LR}_bs{BATCH_SIZE}_e{NUM_EPOCHS}"
# if FDA:
#   name = f"Step_4_FDA_{PARTITION}_S{SPLIT}_lr{LR}_bs{BATCH_SIZE}_e{NUM_EPOCHS}"
# else: 
#   name = f"Step_4_{PARTITION}_S{SPLIT}_lr{LR}_bs{BATCH_SIZE}_e{NUM_EPOCHS}"

wandb.init(
    project = f"STEP{'4.4' if FDA else '4.2' }",
    # entity = "lor-bellino",
    config = config,
    name = name,
)

### Creating Transforms

In [None]:
transforms = setup_transform()

### Style Augment

In [None]:
class StyleAugment:

    def __init__(self, n_images_per_style=10, L=0.1, size=(1024, 512), b=None):
        self.styles = []
        self.styles_names = []
        self.n_images_per_style = n_images_per_style
        self.L = L
        self.size = size
        self.sizes = None
        self.cv2 = False
        self.b = b

    def preprocess(self, x):
        if isinstance(x, np.ndarray):
            x = cv2.resize(x, self.size, interpolation=cv2.INTER_CUBIC)
            self.cv2 = True
        else:
            x = x.resize(self.size, Image.BICUBIC)
        x = np.asarray(x, np.float32)
        x = x[:, :, ::-1]
        x = x.transpose((2, 0, 1))
        return x.copy()

    def deprocess(self, x, size):
        if self.cv2:
            x = cv2.resize(np.uint8(x).transpose((1, 2, 0))[:, :, ::-1], size, interpolation=cv2.INTER_CUBIC)
        else:
            x = Image.fromarray(np.uint8(x).transpose((1, 2, 0))[:, :, ::-1])
            x = x.resize(size, Image.BICUBIC)
        return x

    def add_style(self, loader, multiple_styles=False, name=None):
        if self.n_images_per_style < 0:
            return

        if name is not None:
            self.styles_names.append([name] * self.n_images_per_style if multiple_styles else [name])

        loader.return_unprocessed_image = True
        n = 0
        styles = []

        for sample in tqdm(loader, total=min(len(loader), self.n_images_per_style)):

            image = self.preprocess(sample)

            if n >= self.n_images_per_style:
                break
            styles.append(self._extract_style(image))
            n += 1

        if self.n_images_per_style > 1:
            if multiple_styles:
                self.styles += styles
            else:
                styles = np.stack(styles, axis=0)
                style = np.mean(styles, axis=0)
                self.styles.append(style)
        elif self.n_images_per_style == 1:
            self.styles += styles

        loader.return_unprocessed_image = False

    def _extract_style(self, img_np):
        fft_np = np.fft.fft2(img_np, axes=(-2, -1))
        amp = np.abs(fft_np)
        amp_shift = np.fft.fftshift(amp, axes=(-2, -1))
        if self.sizes is None:
            self.sizes = self.compute_size(amp_shift)
        h1, h2, w1, w2 = self.sizes
        style = amp_shift[:, h1:h2, w1:w2]
        return style

    def compute_size(self, amp_shift):
        _, h, w = amp_shift.shape
        b = (np.floor(np.amin((h, w)) * self.L)).astype(int) if self.b is None else self.b
        c_h = np.floor(h / 2.0).astype(int)
        c_w = np.floor(w / 2.0).astype(int)
        h1 = c_h - b
        h2 = c_h + b + 1
        w1 = c_w - b
        w2 = c_w + b + 1
        return h1, h2, w1, w2

    def apply_style(self, image):
        return self._apply_style(image)

    def _apply_style(self, img):

        if self.n_images_per_style < 0:
            return img

        if len(self.styles) > 0:
            n = random.randint(0, len(self.styles) - 1)
            style = self.styles[n]
        else:
            style = self.styles[0]

        if isinstance(img, np.ndarray):
            H, W = img.shape[0:2]
        else:
            W, H = img.size
        img_np = self.preprocess(img)

        fft_np = np.fft.fft2(img_np, axes=(-2, -1))
        amp, pha = np.abs(fft_np), np.angle(fft_np)
        amp_shift = np.fft.fftshift(amp, axes=(-2, -1))
        h1, h2, w1, w2 = self.sizes
        amp_shift[:, h1:h2, w1:w2] = style
        amp_ = np.fft.ifftshift(amp_shift, axes=(-2, -1))

        fft_ = amp_ * np.exp(1j * pha)
        img_np_ = np.fft.ifft2(fft_, axes=(-2, -1))
        img_np_ = np.real(img_np_)
        img_np__ = np.clip(np.round(img_np_), 0., 255.)

        img_with_style = self.deprocess(img_np__, (W, H))

        return img_with_style

    def test(self, images_np, images_target_np=None, size=None):

        Image.fromarray(np.uint8(images_np.transpose((1, 2, 0)))[:, :, ::-1]).show()
        fft_np = np.fft.fft2(images_np, axes=(-2, -1))
        amp = np.abs(fft_np)
        amp_shift = np.fft.fftshift(amp, axes=(-2, -1))
        h1, h2, w1, w2 = self.sizes
        style = amp_shift[:, h1:h2, w1:w2]

        fft_np_ = np.fft.fft2(images_np if images_target_np is None else images_target_np, axes=(-2, -1))
        amp_, pha_ = np.abs(fft_np_), np.angle(fft_np_)
        amp_shift_ = np.fft.fftshift(amp_, axes=(-2, -1))
        h1, h2, w1, w2 = self.sizes
        amp_shift_[:, h1:h2, w1:w2] = style
        amp__ = np.fft.ifftshift(amp_shift_, axes=(-2, -1))

        fft_ = amp__ * np.exp(1j * pha_)
        img_np_ = np.fft.ifft2(fft_, axes=(-2, -1))
        img_np_ = np.real(img_np_)
        img_np__ = np.clip(np.round(img_np_), 0., 255.)
        Image.fromarray(np.uint8(images_target_np.transpose((1, 2, 0)))[:, :, ::-1]).show()
        Image.fromarray(np.uint8(img_np__).transpose((1, 2, 0))[:, :, ::-1]).show()


### Creating Dataloader

In [None]:
# define the clients type
if SPLIT == 1:
  filename = f'uniform{PARTITION}.json'
else:
  filename = f'heterogeneuos{PARTITION}.json'
  
if FDA:
  #L 0.01, 0.05, 0.09
  # b == 0 --> 1x1, b == 1 --> 3x3, b == 2 --> 5x5, ...'
  SA = StyleAugment(n_images_per_style=MAX_SAMPLE_PER_CLIENT, L=0.01, size=(1024, 512), b=BETA_WINDOW_SIZE) 

  clients = random.sample([_ for _ in range(TOT_CLIENTS)],N_STYLE)
  for c in clients:
    client_dataset = Cityscapes(root=CTSC_ROOT, transform=transforms, cl19 = cl19,filename=filename,  id_client = c)
    SA.add_style(client_dataset)

100%|██████████| 16/16 [00:06<00:00,  2.44it/s]
100%|██████████| 18/18 [00:07<00:00,  2.41it/s]
100%|██████████| 16/16 [00:11<00:00,  1.45it/s]
100%|██████████| 18/18 [00:12<00:00,  1.44it/s]
100%|██████████| 17/17 [00:08<00:00,  1.94it/s]


In [None]:
train_dataset = GTA5(root=GTA5_ROOT, transform=transforms)
if FDA:
  train_dataset.set_style_tf_fn(SA.apply_style)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

### Creating MIOU dataloader

In [None]:
miou_filename = f"uniform{PARTITION}.json"
miou_dataset = Cityscapes(root=CTSC_ROOT, transform=transforms, cl19 = cl19, filename=miou_filename, id_client = random.randint(0,TOT_CLIENTS-1))
miou_dataloader = DataLoader(miou_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

### Loading the model

In [None]:
if LOAD_CKPT:
    print("loading the model")
    model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=False)
    model.load_state_dict(torch.load("models/" + LOAD_CKPT_PATH))
else:
    print("creating a new model")
    model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)

### Training Loop

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=255)
parameters_to_optimize = model.parameters()
optimizer = optim.SGD(parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
model = model.to(DEVICE)

cudnn.benchmark  # Calling this optimizes runtime

epochs = []

wandb.watch(model, log="all")

current_step = 0

best_miou = 0

# Start iterating over the epochs
for epoch in range(NUM_EPOCHS):
    print(f"Starting epoch {epoch + 1}/{NUM_EPOCHS}")
    epochs.append(epoch + 1)
    for images, labels in train_dataloader:
       images = images.to(DEVICE, dtype=torch.float32)
       labels = labels.to(DEVICE, dtype=torch.long)
       model.train()
       optimizer.zero_grad()

       predictions = model(images)
       loss = criterion(predictions, labels.squeeze())

       # Log loss
       if current_step % LOG_FREQUENCY == 0:
           print(f"Step {current_step}, Loss {loss.item()}")
           wandb.log({"train/loss": loss})
        # Compute gradients for each layer and update weights
       loss.backward()
       optimizer.step()

       current_step += 1
    miou = compute_miou(model, miou_dataloader)
    wandb.log({"train/miou": miou})
    if miou > best_miou:
        print(f"Saving the model with miou {miou:.2} (best so far)")
        best_miou = miou
        torch.save(
               {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": loss,
                    "miou": miou,
                },
                CKPT_PATH + f"step4_{PARTITION}_S{SPLIT}_{'FDA' if FDA else 'pretraining'}_{miou:.2}.pth",
            )

### Creating Validation Dataloader

In [None]:
val_dataset = Cityscapes(root=CTSC_ROOT, transform=transforms, cl19 = cl19, filename=f"test_{PARTITION}.txt")
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

### Extract the best model

In [None]:
torch.cuda.empty_cache()
ckpt = torch.load(CKPT_PATH + f"step4_{PARTITION}_S{SPLIT}_{'FDA' if FDA else 'pretraining'}_{best_miou:.2}.pth")
model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
model.to(DEVICE)

### Validation

In [None]:
print("computing miou ...")
miou = compute_miou(net=model, val_dataloader=val_dataloader)
print("Validation MIoU: {}".format(miou))
wandb.log({"val/miou": miou})
wandb.finish()

### Validation Plot

In [None]:
print("validation plot : ")
validation_plot(net=model, val_dataloader=val_dataloader, n_images=5)
torch.cuda.empty_cache()

## STEP 5 : FEDERATED SELF TRAINING USING PSEUDO-LABELS

### Setup for WanDB

In [15]:
wandb.login()

transformer_dictionary = {
    "random-horizontal-flip":RANDOM_HORIZONTAL_FLIP,
    "color-jitter":COLOR_JITTER,
    "random-rotation":RANDOM_ROTATION,
    "random-crop":RANDOM_CROP,
    "random-vertical-flip":RANDOM_VERTICAL_FLIP,
    "central-crop":CENTRAL_CROP,
    "random-resized-crop":RANDOM_RESIZE_CROP,
    "resize":RESIZE,
    }

config = {
        "batch_size": BATCH_SIZE,
        "lr": LR,
        "momentum": MOMENTUM,
        "weight_decay": WEIGHT_DECAY,
        "num_epochs": NUM_EPOCHS,
        "n_client": CLIENT_PER_ROUND,
        "round": N_ROUND,
        "tot_client": TOT_CLIENTS,
        "transformers": transformer_dictionary,
    }


name = f"Step_5{'_FDA' if FDA else ''}_T{T_ROUND}_{PARTITION}_S{SPLIT}_R{N_ROUND}_c{CLIENT_PER_ROUND}"

wandb.init(
    project = "STEP5",
    # entity = "lor-bellino",
    config = config,
    name = name,
)

[34m[1mwandb[0m: Currently logged in as: [33melia-fontana[0m ([33maml2023[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Creating Transform

In [16]:
transforms = setup_transform()

### Self Training Loss


In [17]:
class SelfTrainingLoss(nn.Module):
    requires_reduction = False

    def __init__(self, conf_th=0.9, fraction=0.66, ignore_index=255, lambda_selftrain=1, **kwargs):
        super().__init__()
        self.conf_th = conf_th
        self.fraction = fraction
        self.ignore_index = ignore_index
        self.teacher = None
        self.lambda_selftrain = lambda_selftrain

    def set_teacher(self, model):
        self.teacher = model

    def get_image_mask(self, prob, pseudo_lab):
        max_prob = prob.detach().clone().max(0)[0]
        mask_prob = max_prob > self.conf_th if 0. < self.conf_th < 1. else torch.zeros(max_prob.size(),
                                                                                       dtype=torch.bool).to(
            max_prob.device)
        mask_topk = torch.zeros(max_prob.size(), dtype=torch.bool).to(max_prob.device)
        if 0. < self.fraction < 1.:
            for c in pseudo_lab.unique():
                mask_c = pseudo_lab == c
                max_prob_c = max_prob.clone()
                max_prob_c[~mask_c] = 0
                _, idx_c = torch.topk(max_prob_c.flatten(), k=int(mask_c.sum() * self.fraction))
                mask_topk_c = torch.zeros_like(max_prob_c.flatten(), dtype=torch.bool)
                mask_topk_c[idx_c] = 1
                mask_c &= mask_topk_c.unflatten(dim=0, sizes=max_prob_c.size())
                mask_topk |= mask_c
        return mask_prob | mask_topk

    def get_batch_mask(self, pred, pseudo_lab):
        b, _, _, _ = pred.size()
        mask = torch.stack([self.get_image_mask(pb, pl) for pb, pl in zip(F.softmax(pred, dim=1), pseudo_lab)], dim=0)
        return mask

    def get_pseudo_lab(self, pred, imgs=None, return_mask_fract=False, model=None):
        teacher = self.teacher if model is None else model
        if teacher is not None:
            with torch.no_grad():
                try:
                    pred = teacher(imgs)['out']
                except:
                    pred = teacher(imgs)
                pseudo_lab = pred.detach().max(1)[1]
        else:
            pseudo_lab = pred.detach().max(1)[1]
        mask = self.get_batch_mask(pred, pseudo_lab)
        pseudo_lab[~mask] = self.ignore_index
        if return_mask_fract:
            return pseudo_lab, F.softmax(pred, dim=1), mask.sum() / mask.numel()
        return pseudo_lab

    def forward(self, pred, imgs=None):
        pseudo_lab = self.get_pseudo_lab(pred, imgs)
        loss = F.cross_entropy(input=pred, target=pseudo_lab, ignore_index=self.ignore_index, reduction='none')
        return loss.mean() * self.lambda_selftrain

### Client class

In [18]:
class Client():
  #def __init__(self, client_id, dataset, model, logger, writer, args, batch_size, world_size, rank, device=None, **kwargs):
  def __init__(self, client_id, dataset, model,  pseudo_lab=False, teacher_model=None):
    self.id = client_id
    self.dataset = dataset
    self.model = model #copy.deepcopy(model)
    self.device = DEVICE
    self.batch_size = BATCH_SIZE
    self.loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, drop_last=True)

    # ADDED
    self.pseudo_lab = pseudo_lab
    self.teacher_model = copy.deepcopy(teacher_model)
    # Define loss function
    if self.pseudo_lab:
      self.criterion = SelfTrainingLoss()
      self.criterion.set_teacher(self.teacher_model)
    else:
      self.criterion = nn.CrossEntropyLoss(ignore_index= 255)

  def client_train(self):
    
    num_train_samples = len(self.dataset)

    parameters_to_optimize = self.model.parameters() 
    optimizer = optim.SGD(parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

    self.model = self.model.to(DEVICE)
    self.model.train() # Sets module in training mode

    cudnn.benchmark # Calling this optimizes runtime
    
    # Start iterating over the epochs
    for epoch in range(NUM_EPOCHS):
      if epoch % LOG_FREQUENCY_EPOCH == 0: 
        print('Starting epoch {}/{}'.format(epoch+1, NUM_EPOCHS))

      # Iterate over the dataset
      for current_step, (images, labels) in enumerate(self.loader):
        images = images.to(DEVICE, dtype=torch.float32)
        labels = labels.to(DEVICE, dtype=torch.long)
        
        optimizer.zero_grad()

        predictions = self.model(images)
        # ADDED
        if self.pseudo_lab:
            loss = self.criterion(predictions, images.squeeze())
        else:
            loss = self.criterion(predictions, labels.squeeze())

        # Log loss
        if current_step % LOG_FREQUENCY == 0:
          print('Step {}, Loss {}'.format(current_step, loss.item()))
          wandb.log({f"client/loss":loss})

        loss.backward()  # backward pass: computes gradients
        optimizer.step() # update weights based on accumulated gradients

    return num_train_samples, copy.deepcopy(self.model.state_dict()) #generate_update

  def test(self, metrics, ret_samples_ids=None, silobn_type=None, train_cl_bn_stats=None, loader=None):
    return

  def save_model(self, epochs, path, optimizer, scheduler):
    return

### Server class

In [19]:
class Server():
  #def __init__(self, model, logger, writer, local_rank, lr, momentum, optimizer=None):
  def __init__(self, model, lr= None , momentum= None, pseudo_lab=False):
    self.model = copy.deepcopy(model)
    self.model_params_dict = copy.deepcopy(self.model.state_dict())
    self.selected_clients = []
    self.updates = []
    self.lr = lr
    self.momentum = momentum
    self.optimizer = optim.SGD(params=self.model.parameters(), lr=1, momentum=0.9)
    self.total_grad = 0

    # ADDED
    self.pseudo_lab = pseudo_lab

  def select_clients(self, my_round, possible_clients, num_clients=4):
    num_clients = min(num_clients, len(possible_clients))
    np.random.seed(my_round)
    self.selected_clients = np.random.choice(possible_clients, num_clients, replace=False)

  def set_clients_teacher(self, train_clients):
    for c in train_clients:
      c.criterion.set_teacher(copy.deepcopy(self.model))

  
  def _compute_client_delta(self, cmodel):
      delta = OrderedDict.fromkeys(cmodel.keys())
      for k, x, y in zip(self.model_params_dict.keys(), self.model_params_dict.values(), cmodel.values()):
        delta[k] = y - x if "running" not in k and "num_batches_tracked" not in k else y
      return delta
  
  def train_round(self):

    self.optimizer.zero_grad()

    clients = self.selected_clients

    for i, c in enumerate(clients):
        print(f"CLIENT {i + 1}/{len(clients)} -> {c.id}:")

        c.model.load_state_dict(self.model_params_dict) # load_server_model_on_client
        out = c.client_train()

        num_samples, update = out

        update = self._compute_client_delta(update)
        
        self.updates.append((num_samples, update))
    return 

  def _server_opt(self, pseudo_gradient):
    for n, p in self.model.named_parameters():
        p.grad = -1.0 * pseudo_gradient[n]

    self.optimizer.step()

    bn_layers = OrderedDict(
        {k: v for k, v in pseudo_gradient.items() if "running" in k or "num_batches_tracked" in k})
    self.model.load_state_dict(bn_layers, strict=False)

  def _aggregation(self):
    total_weight = 0.
    base = OrderedDict()

    for (client_samples, client_model) in self.updates:

        total_weight += client_samples
        for key, value in client_model.items():
            if key in base:
                base[key] += client_samples * value.type(torch.FloatTensor)
            else:
                base[key] = client_samples * value.type(torch.FloatTensor)
    averaged_sol_n = copy.deepcopy(self.model_params_dict)
    for key, value in base.items():
        if total_weight != 0:
            averaged_sol_n[key] = value.to('cuda') / total_weight

    return averaged_sol_n

  def _get_model_total_grad(self):
    total_norm = 0
    for name, p in self.model.named_parameters():
        if p.requires_grad:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_grad = total_norm ** 0.5
    return total_grad

  def update_model(self):
    """FedAvg on the clients' updates for the current round.
    Weighted average of self.updates, where the weight is given by the number
    of samples seen by the corresponding client at training time.
    Saves the new central model in self.client_model and its state dictionary in self.model
    """

    averaged_sol_n = self._aggregation()
    
    self._server_opt(averaged_sol_n)
    self.total_grad = self._get_model_total_grad()
    self.model_params_dict = copy.deepcopy(self.model.state_dict())

    self.updates = []

  def test_model(self, clients_to_test, metrics, ret_samples_bool=False, silobn_type=''):
    return 

### Creating MIOU dataloader

In [23]:
if SPLIT == 1:
  filename = f'uniform{PARTITION}.json'
else:
  filename = f'heterogeneous{PARTITION}.json'

miou_dataset = Cityscapes(root=CTSC_ROOT, transform=transforms, cl19 = cl19, filename=filename, id_client = random.randint(0,TOT_CLIENTS-1))
miou_dataloader = DataLoader(miou_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

### Setup clients

In [24]:
def setup_clients(n_client, model):

  clients = []
  if SPLIT == 1:
    filename=f"uniform{PARTITION}.json"
  else:
    filename=f'heterogeneuos{PARTITION}.json'
    
  for i in range(n_client):
    train_dataset = Cityscapes(root=CTSC_ROOT, transform=transforms, cl19 = cl19,filename=filename,  id_client = i)
    client = Client(client_id = i, dataset = train_dataset, model = model, pseudo_lab = PSEUDO_LAB, teacher_model=model)
    clients.append(client)
    
  return clients

### Training loop

In [None]:
model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)
# Load the model from the previous step
if ckpt_path in os.listdir(CKPT_DIR):
  checkpoint = torch.load(os.path.join(CKPT_DIR,ckpt_path))
  model.load_state_dict(checkpoint['model_state_dict'])
else:
  raise FileNotFoundError

model = model.to(DEVICE)

wandb.watch(model, log='all')

model_path = "./models/Step5/"
if not os.path.exists(model_path):
  print("creating models directory")
  os.makedirs(model_path)


train_clients = setup_clients(n_client = TOT_CLIENTS, model = model)
print(len(train_clients))

server = Server(model, lr=LR, momentum = MOMENTUM, pseudo_lab=PSEUDO_LAB)

for r in range(N_ROUND):
  print(f"ROUND {r + 1}/{N_ROUND}: Training {CLIENT_PER_ROUND} Clients...")
  if r % T_ROUND == 0: 
    print('Teacher model updated')
    server.set_clients_teacher(train_clients)
  server.select_clients(r, train_clients, num_clients=CLIENT_PER_ROUND) 
  server.train_round()
  server.update_model()

  # TODO: move validation inside Server class
  miou = compute_miou(net=server.model, val_dataloader=miou_dataloader)
  wandb.log({"server/miou": miou})
  print(f"Validation MIoU: {miou}")
  
  if r%CHECKPOINTS == 0:
    print(f"Saving the model")
    torch.save(model.state_dict(), model_path+f"model{'_FDA' if FDA else ''}_T{T_ROUND}_P{PARTITION}_S{SPLIT}_round{r:02}.pth")

torch.save(model.state_dict(), model_path+f"model{'_FDA' if FDA else ''}_T{T_ROUND}_P{PARTITION}_S{SPLIT}.pth")

### Creating Validation Dataloader


In [34]:
val_dataset = Cityscapes(root=CTSC_ROOT, transform=transforms, cl19 = cl19, filename=f"test_{PARTITION}.txt")
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

### Extracting the best model

In [None]:
torch.cuda.empty_cache()
#ckpt = torch.load(torch.load_state_dict(model_path+f"model{'_FDA' if FDA else ''}_T{T_ROUND}_P{PARTITION}_S{SPLIT}.pth"))
model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)
model.load_state_dict(torch.load(model_path+f"model{'_FDA' if FDA else ''}_T{T_ROUND}_P{PARTITION}_S{SPLIT}.pth"))
model.eval()
model.to(DEVICE)

### Validation

In [None]:
print("computing miou ...")
miou = compute_miou(net=model, val_dataloader=val_dataloader)
print("Validation MIoU: {}".format(miou))
wandb.log({"val/miou": miou})
wandb.finish()

### Validation Plot

In [None]:
print("validation plot : ")
validation_plot(net=model, val_dataloader=val_dataloader, n_images=5)
torch.cuda.empty_cache()

## STEP 6 : MOBILENET V2

### Setup for WanDB

In [None]:
wandb.login()

transformer_dictionary = {
    "random-horizontal-flip":RANDOM_HORIZONTAL_FLIP,
    "color-jitter":COLOR_JITTER,
    "random-rotation":RANDOM_ROTATION,
    "random-crop":RANDOM_CROP,
    "random-vertical-flip":RANDOM_VERTICAL_FLIP,
    "central-crop":CENTRAL_CROP,
    "random-resized-crop":RANDOM_RESIZE_CROP,
    "resize":RESIZE,
    }

config = {
        "batch_size": BATCH_SIZE,
        "lr": LR,
        "num_epochs": NUM_EPOCHS,
        #"step_size": STEP_SIZE,
        "transformers": transformer_dictionary,
    }
name = f"Step_6_{PARTITION}_lr{LR}_bs{BATCH_SIZE}_e{NUM_EPOCHS}"
wandb.init(
    project = " STEP6",
    # entity = "lor-bellino",
    config = config,
    name = name,
)

### MobileNet V2

In [None]:
import copy
from collections import OrderedDict

import torch

# from .mobilenetv2 import MobileNetV2
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.segmentation.deeplabv3 import DeepLabV3, DeepLabHead
from torchvision._internally_replaced_utils import load_state_dict_from_url

import torch.nn as nn
import math

__all__ = ['mobilenetv2', 'MobileNetV2']


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def conv_3x3_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.identity = stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, width_mult=1., in_channels=3):
        super(MobileNetV2, self).__init__()
        self.cfgs = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        input_channel = _make_divisible(32 * width_mult, 4 if width_mult == 0.1 else 8)
        layers = [conv_3x3_bn(in_channels, input_channel, 2)]
        block = InvertedResidual
        for t, c, n, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 4 if width_mult == 0.1 else 8)
            for i in range(n):
                layers.append(block(input_channel, output_channel, s if i == 0 else 1, t))
                input_channel = output_channel

        output_channel = _make_divisible(1280 * width_mult, 4 if width_mult == 0.1 else 8) if width_mult > 1.0 else 1280
        layers.append(conv_1x1_bn(input_channel, output_channel))
        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(output_channel, num_classes)

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def mobilenetv2(**kwargs):
    return MobileNetV2(**kwargs)



def _deeplabv3_mobilenetv2(
        backbone: MobileNetV2,
        num_classes: int,
) -> DeepLabV3:
    backbone = backbone.features

    out_pos = len(backbone) - 1
    out_inplanes = backbone[out_pos][0].out_channels
    return_layers = {str(out_pos): "out"}

    backbone = create_feature_extractor(backbone, return_layers)
    classifier = DeepLabHead(out_inplanes, num_classes)

    return DeepLabV3(backbone, classifier)


def deeplabv3_mobilenetv2(
        num_classes: int = 21,
        in_channels: int = 3
) -> DeepLabV3:

    width_mult = 1
    backbone = MobileNetV2(width_mult=width_mult, in_channels=in_channels)
    model_urls = {
        0.5: 'https://github.com/d-li14/mobilenetv2.pytorch/raw/master/pretrained/mobilenetv2_0.5-eaa6f9ad.pth',
        1.: 'https://github.com/d-li14/mobilenetv2.pytorch/raw/master/pretrained/mobilenetv2_1.0-0c6065bc.pth'}
    state_dict = load_state_dict_from_url(model_urls[width_mult], progress=True)
    state_dict_updated = state_dict.copy()
    for k, v in state_dict.items():
        if 'features' not in k and 'classifier' not in k:
            state_dict_updated[k.replace('conv', 'features.18')] = v
            del state_dict_updated[k]

    if in_channels == 4:
        aux = torch.zeros((32, 4, 3, 3))
        aux[:, 0, :, :] = copy.deepcopy(state_dict_updated['features.0.0.weight'][:, 0, :, :])
        aux[:, 1, :, :] = copy.deepcopy(state_dict_updated['features.0.0.weight'][:, 1, :, :])
        aux[:, 2, :, :] = copy.deepcopy(state_dict_updated['features.0.0.weight'][:, 2, :, :])
        aux[:, 3, :, :] = copy.deepcopy(state_dict_updated['features.0.0.weight'][:, 2, :, :])
        state_dict_updated['features.0.0.weight'] = aux
    backbone.load_state_dict(state_dict_updated, strict=False)

    model = _deeplabv3_mobilenetv2(backbone, num_classes)
    model.task = 'segmentation'

    return model


### Creating Transforms

In [None]:
transforms = setup_transform()

### Creating Dataloader

In [None]:

if PARTITION == "A":
    train_dataset = Cityscapes(
        root=ROOT_DIR,
        transform=transforms,
        cl19=True,
        filename="train_A.txt",
    )
elif PARTITION == "B":
    train_dataset = Cityscapes(
        root=ROOT_DIR,
        transform=transforms,
        cl19=True,
        filename="train_B.txt",
    )
else:
    raise NotImplementedError

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)


### Loading the model

In [None]:
if LOAD_CKPT:
    print("loading the model")
    model = deeplabv3_mobilenetv2(num_classes=NUM_CLASSES)
    # model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=False)
    model.load_state_dict(torch.load("models/" + LOAD_CKPT_PATH))
else:
    print("creating a new model")
    model = deeplabv3_mobilenetv2(num_classes=NUM_CLASSES)
    # model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)

### Creating Miou dataloader

In [None]:
if SPLIT == 1:
  filename = f'uniform{PARTITION}.json'
else:
  filename = f'heterogeneous{PARTITION}.json'

miou_dataset = Cityscapes(root=CTSC_ROOT, transform=transforms, cl19 = cl19, filename=filename, id_client = random.randint(0,TOT_CLIENTS-1))
miou_dataloader = DataLoader(miou_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

### Training Loop

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=255)
parameters_to_optimize = model.parameters()
optimizer = optim.SGD(parameters_to_optimize, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
model = model.to(DEVICE)

cudnn.benchmark  # Calling this optimizes runtime

epochs = []

wandb.watch(model, log="all")

current_step = 0

best_miou = 0

# Start iterating over the epochs
for epoch in range(NUM_EPOCHS):
    print(f"Starting epoch {epoch + 1}/{NUM_EPOCHS}")
    epochs.append(epoch + 1)
    for images, labels in train_dataloader:
       images = images.to(DEVICE, dtype=torch.float32)
       labels = labels.to(DEVICE, dtype=torch.long)
       model.train()
       optimizer.zero_grad()

       predictions = model(images)
       loss = criterion(predictions, labels.squeeze())

       # Log loss
       if current_step % LOG_FREQUENCY == 0:
           print(f"Step {current_step}, Loss {loss.item()}")
           wandb.log({"train/loss": loss})
        # Compute gradients for each layer and update weights
       loss.backward()
       optimizer.step()

       current_step += 1
    miou = compute_miou(model, miou_dataloader)
    wandb.log({"train/miou": miou})
    if miou > best_miou:
        print(f"Saving the model with miou {miou:.2} (best so far)")
        best_miou = miou
        torch.save(
               {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": loss,
                    "miou": miou,
                },
                CKPT_PATH + f"step4_{PARTITION}_S{SPLIT}_{'FDA' if FDA else 'pretraining'}_{miou:.2}.pth",
            )

### Creating Validation Dataloader

In [None]:
val_dataset = Cityscapes(root=CTSC_ROOT, transform=transforms, cl19 = cl19, filename=f"test_{PARTITION}.txt")
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

### Extract the best model

In [None]:
torch.cuda.empty_cache()
ckpt = torch.load(CKPT_PATH + f"step4_{PARTITION}_S{SPLIT}_{'FDA' if FDA else 'pretraining'}_{best_miou:.2}.pth")
model = BiSeNetV2(NUM_CLASSES, output_aux=False, pretrained=True)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
model.to(DEVICE)

### Validation

In [None]:
print("computing miou ...")
miou = compute_miou(net=model, val_dataloader=val_dataloader)
print("Validation MIoU: {}".format(miou))
wandb.log({"val/miou": miou})
wandb.finish()

### Validation Plot

In [None]:
print("validation plot : ")
validation_plot(net=model, val_dataloader=val_dataloader, n_images=5)
torch.cuda.empty_cache()