# Storage

In [1]:
# Path to the `mldl2023` project folder in your drive. Replace the `<your_drive>`
# placeholder before running: later cells chdir into this folder and derive the
# data, pretrained-weights and storage paths from it.
rootMldl = '<your_drive>/mldl2023'

# Initialization

In [2]:
# packages 1
import shutil
import os
import torch

# Fail fast when no CUDA device is visible to PyTorch, before any data is copied
# or any model is built.
if not torch.cuda.is_available():
    raise RuntimeError('The model cannot operate without CUDA!')

In [3]:
# getting the notebook's name
from requests import get
from socket import gethostname, gethostbyname

# The name is only used to tag entries written to the storage log, so a failed
# lookup must not abort the whole notebook: the request is bounded by a timeout
# and a placeholder name is used when the sessions endpoint cannot be reached,
# returns no session, or returns an unexpected payload.
# (requests' exceptions derive from OSError; its JSON errors derive from ValueError.)
try:
    notebookName = get(
        f'http://{gethostbyname(gethostname())}:9000/api/sessions',
        timeout=10
    ).json()[0]['name']
except (OSError, ValueError, LookupError, TypeError):
    notebookName = 'unknown-notebook'

In [None]:
# Mount Google Drive at /content/drive (requires the Colab runtime).
# Presumably `rootMldl` points inside this mount -- verify the placeholder above.
from google.colab import drive
drive.mount('/content/drive')

In [5]:
# Make the project folder the working directory. Presumably this is what lets the
# project-local packages imported in the next cell (`datasets`, `utils`, `models`)
# be resolved -- confirm against the folder layout.
os.chdir(rootMldl)

In [6]:
# packages 2
import random
import string
from typing import Any
from typing import List
import numpy as np
from PIL import Image
from torch import from_numpy
from torchvision.datasets import VisionDataset
import datasets.ss_transforms as tr
from utils.stream_metrics import StreamSegMetrics
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from models import deeplabv3, mobilenetv2
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch, os, copy
import tqdm.notebook as tqdm
import gc
from utils.utils import HardNegativeMining, MeanReduction
import math
import json
import csv
from pprint import pprint
from torch import from_numpy
import datetime
import time

# Memory Management

In [None]:
# packages
!pip install gputil
import prettytable
import psutil
import GPUtil

In [8]:
# garbage collector
def clearCache():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory."""
    # collect first so that memory of unreachable tensors can be handed back by empty_cache
    gc.collect()
    torch.cuda.empty_cache()

In [9]:
# memory scanner
def printMemoryUsage(title='Memory status', diskPath='/content/'):
    """Print a table with the disk, RAM and per-GPU memory usage.

    Args:
        title: Heading printed above the table (a ':' is appended).
        diskPath: Path whose file system is inspected for the disk row.
            Defaults to '/content/', the location used so far.

    The 'available' column is the sum of the 'used' and 'free' columns of the
    same row. Disk and RAM figures are converted from bytes to MB; the GPU
    figures are printed as reported by GPUtil (presumably already in MB --
    confirm against the GPUtil documentation).
    """
    bytesPerMb = 1024*1024

    disk = shutil.disk_usage(diskPath)
    # a single snapshot, so the three RAM figures of the row agree with each other
    ram = psutil.virtual_memory()

    table = prettytable.PrettyTable(['type', 'available (MB)', 'used (MB)', 'free (MB)'])
    for field in table.field_names:
        table.align[field] = 'l'

    table.add_row([
        'disk',
        round((disk.used+disk.free)/bytesPerMb, 1),
        round(disk.used/bytesPerMb, 1),
        round(disk.free/bytesPerMb, 1)
    ])
    table.add_row([
        'ram',
        round((ram.available+ram.used)/bytesPerMb, 1),
        round(ram.used/bytesPerMb, 1),
        round(ram.available/bytesPerMb, 1)
    ])
    for i, gpu in enumerate(GPUtil.getGPUs()):
        table.add_row([
            f'gpu-{i} ram',
            round(gpu.memoryUsed+gpu.memoryFree, 1),
            round(gpu.memoryUsed, 1),
            round(gpu.memoryFree, 1)
        ])

    print(title+':')
    print(table)
    print('')

# Dataset Class

In [10]:
class IDDADataset(VisionDataset):
    """Image/label pairs read from `<root>/images/<name>.jpg` and `<root>/labels/<name>.png`.

    `fileNames` holds the file stems (no extension) that make up the dataset.
    `transform`, when given, is called with both the image and the label and
    must return the transformed pair. Labels are finally remapped with the
    lookup table built by `get_mapping`.
    """

    @staticmethod
    def get_mapping():
        """Return a callable that remaps raw label ids through a 256-entry lookup table.

        The first 21 raw ids are mapped to the values listed in `classes`; every
        other id maps to 255. The callable returns the result as a torch tensor.
        """
        classes = [255, 2, 4, 255, 11, 5, 0, 0, 1, 8, 13, 3, 7, 6, 255, 255, 15, 14, 12, 9, 10]
        lookup = np.full(256, 255, dtype=np.int64)
        lookup[:len(classes)] = classes

        def remap(x):
            return from_numpy(lookup[x])

        return remap

    def __init__(self, root, fileNames, transform=None):
        super().__init__(root=root, transform=transform, target_transform=IDDADataset.get_mapping())
        self.fileNames = fileNames

    def __getitem__(self, index):
        """Load, transform and return the `(image, remapped label)` pair at `index`."""
        stem = self.fileNames[index]
        image = Image.open(self.root+'/images/'+stem+'.jpg').convert('RGB')
        label = Image.open(self.root+'/labels/'+stem+'.png').convert('L')
        if self.transform is not None:
            # joint transform: image and label go through together
            image, label = self.transform(image, label)
        return image, self.target_transform(label)

    def __len__(self):
        return len(self.fileNames)

# Pseudo-Label Functions

In [11]:
def get_image_mask(prob, pseudo_lab, th):
    """Return a boolean mask that is True where the top class probability exceeds `th`.

    Args:
        prob: Tensor whose dim 0 indexes the classes.
        pseudo_lab: Not used by this function; kept for the call signature.
        th: Confidence threshold (strict `>` comparison).
    """
    # work on a detached copy so neither the caller's tensor nor its autograd
    # history is involved; `.max(0)` yields (values, indices) and only the
    # values, i.e. the highest probability along dim 0, are needed
    top_prob, _ = prob.detach().clone().max(0)
    return top_prob > th

In [12]:
def get_batch_mask(pred, pseudo_lab, th):
    """Return the per-pixel confidence mask for a whole batch of predictions.

    The raw scores in `pred` are normalised with a softmax over the class axis
    (dim 1); a position is True when its highest class probability is strictly
    greater than `th`.

    Args:
        pred: Score tensor with the batch on dim 0 and the classes on dim 1.
        pseudo_lab: Not used for the computation; kept for the call signature.
        th: Confidence threshold.

    Returns:
        Boolean tensor shaped like `pred` with the class axis removed.
    """
    # One vectorised pass over the batch instead of a Python loop over images;
    # unlike stacking per-image masks, this also works for an empty batch.
    top_prob = torch.softmax(pred.detach(), dim=1).max(dim=1)[0]
    return top_prob > th

In [13]:
def get_pseudo_lab(imgs, model, th):
    """Produce pseudo-labels for `imgs` from `model`, blanking low-confidence pixels.

    The model is switched to eval mode (and left there) and run without gradient
    tracking. Each pixel gets the index of its highest-scoring class; pixels for
    which `get_batch_mask` reports a confidence not above `th` are set to 255.

    Args:
        imgs: Input batch passed to the model as is.
        model: Callable returning a mapping whose "out" entry holds the class
            scores with the classes on dim 1.
        th: Confidence threshold forwarded to `get_batch_mask`.

    Returns:
        Tensor of class indices with 255 at the masked-out positions.
    """
    model.eval()
    with torch.no_grad():
        # raw class scores; the softmax is applied later inside get_batch_mask
        scores = model(imgs)["out"]

    # `.max(1)` yields (values, indices) along the class axis; the indices are the labels
    _, pseudo_lab = scores.detach().max(1)

    confident = get_batch_mask(scores, pseudo_lab, th)

    # 255 marks the pixels whose top probability did not pass the threshold
    pseudo_lab[torch.logical_not(confident)] = 255
    return pseudo_lab

# Client class

In [14]:
class Client:
    """A training participant that fits the student model on teacher pseudo-labels.

    Attributes:
        device: Device the image batches are moved to.
        name: Identifier of the client.
        dataLoaderTrain: Shuffling loader over the client's data; incomplete
            trailing batches are dropped.
        importance: Number of samples in the client's dataset.
    """

    def __init__(self,
                 device, name,
                 datasetTrain,
                 batchSizeTrain):

        self.device = device
        self.name = name

        self.dataLoaderTrain = DataLoader(
            datasetTrain,
            batch_size=batchSizeTrain,
            shuffle=True,
            drop_last=True
        )
        self.importance = len(datasetTrain)

    def train(self, modelTeacher, modelStudent, num_epochs, pbar, threshold, optimizer, criterion, metric):
        """Train `modelStudent` for `num_epochs` epochs on pseudo-labels from `modelTeacher`.

        The dataset's own labels are ignored; the targets come from
        `get_pseudo_lab(images, modelTeacher, threshold)`. `pbar` must expose a
        `client` attribute plus `set_postfix` and `update`; it is advanced by
        one per batch. `metric` is updated with (pseudo-label, prediction) pairs.

        Returns:
            `(loss, miou)`: the mean over epochs of the per-epoch mean batch
            loss, and the 'Mean IoU' entry of `metric.get_results()`.
        """
        clearCache()
        modelStudent.train()

        numBatches = len(self.dataLoaderTrain)
        meanLossRun = 0.0
        for epochIdx in range(num_epochs):
            meanLossEpoch = 0.0
            for batchIdx, (images, _) in enumerate(self.dataLoaderTrain):
                progress = {
                    'client': pbar.client,
                    'epoch': f'{epochIdx+1}/{num_epochs}',
                    'batch': f'{batchIdx+1}/{numBatches}'
                }
                pbar.set_postfix(progress)

                images = images.to(self.device)
                pseudoLabels = get_pseudo_lab(images, modelTeacher, threshold).to(self.device)

                optimizer.zero_grad()
                scores = modelStudent(images)['out']
                loss = criterion(scores, pseudoLabels)
                loss.backward()
                optimizer.step()

                # incremental mean of the batch losses seen so far in this epoch
                meanLossEpoch = (batchIdx*meanLossEpoch + loss.item())/(batchIdx + 1)
                prediction = scores.max(dim=1)[1]
                metric.update(pseudoLabels.cpu().numpy(), prediction.cpu().numpy())

                pbar.update(1)

            # incremental mean of the epoch losses
            meanLossRun = (epochIdx*meanLossRun + meanLossEpoch)/(epochIdx + 1)

        return meanLossRun, metric.get_results()['Mean IoU']

# Server class

In [15]:
class Server:
    """Coordinator of the federated training of a student model guided by a teacher.

    Every round the server samples clients, lets each of them train a fresh copy
    of the current student on teacher pseudo-labels, averages the resulting
    states weighted by client importance, evaluates the student on three test
    loaders and persists metrics, both models and a log entry to `pathStorage`.

    Args:
        device: Device used for evaluation and aggregation.
        modelTeacher, modelStudent: Models returning a mapping with an 'out' entry.
        clients: Sequence of `Client` objects.
        datasetTestIdda, datasetTestSame, datasetTestDiff: Evaluation datasets.
        batchSizeTest: Batch size of the evaluation loaders.
        metricClass: Class instantiated as `metricClass(n_classes=21, name='miou')`
            and offering `update` and `get_results()['Mean IoU']`.
        num_rounds: Index one past the last round to run.
        max_num_clients_per_round: Upper bound on the clients sampled per round.
        max_num_epochs_per_client: Epochs each selected client trains for.
        threshold: Confidence threshold for the teacher's pseudo-labels.
        updatePeriod: The teacher is overwritten with the student at every round
            whose index is a multiple of this value; 0 disables the update.
        scheduler_dict: Keys 'lr_initial', 'lr_min' and 'T_max' of the cosine schedule.
        optimizer_dict: Extra keyword arguments for `torch.optim.SGD`.
        notebookName: Name written to the log entries.
        pathStorage: Existing folder receiving metrics.csv, the models and log.txt.
        lastRound: Index of the last completed round (-1 when starting fresh).
    """

    def __init__(self,
                 device, modelTeacher, modelStudent, clients,
                 datasetTestIdda, datasetTestSame, datasetTestDiff,
                 batchSizeTest,
                 metricClass, num_rounds, max_num_clients_per_round, max_num_epochs_per_client, threshold, updatePeriod,
                 scheduler_dict, optimizer_dict,
                 notebookName, pathStorage, lastRound):

        self.device = device
        self.modelTeacher = modelTeacher
        self.modelStudent = modelStudent
        self.clients = clients

        self.dataLoaderTestIdda = DataLoader(datasetTestIdda, batch_size=batchSizeTest, shuffle=False, drop_last=False)
        self.dataLoaderTestSame = DataLoader(datasetTestSame, batch_size=batchSizeTest, shuffle=False, drop_last=False)
        self.dataLoaderTestDiff = DataLoader(datasetTestDiff, batch_size=batchSizeTest, shuffle=False, drop_last=False)

        self.metricClass = metricClass
        self.num_rounds = num_rounds
        self.max_num_clients_per_round = max_num_clients_per_round
        self.max_num_epochs_per_client = max_num_epochs_per_client
        self.threshold = threshold
        self.updatePeriod = updatePeriod

        # manually programmed cosine annealing scheduler (maps a round index to a learning rate)
        self.scheduler = lambda roundIndex: scheduler_dict['lr_min'] + 0.5*(scheduler_dict['lr_initial']-scheduler_dict['lr_min'])*(1+math.cos(math.pi*roundIndex/scheduler_dict['T_max']))
        self.optimizer_dict = optimizer_dict

        # STORAGE
        self.notebookName = notebookName                                        # STORAGE
        self.pathStorage = pathStorage                                          # STORAGE
        self.initialRound = lastRound + 1                                       # STORAGE
        # STORAGE

        # defined here as well, so train_round()/step()/aggregate() are usable
        # even when they are called before train()
        self.round = self.initialRound
        self.client_importances = []
        self.client_states = []

        self.lr = self.scheduler(self.initialRound)

    @staticmethod
    def _timed(fn, *args):
        """Call `fn(*args)`; return `(result, elapsed whole seconds)`."""
        start = time.perf_counter()
        result = fn(*args)
        return result, int(time.perf_counter() - start)

    def step(self):
        """Advance the round counter and refresh the learning rate from the schedule."""
        self.round += 1
        self.lr = self.scheduler(self.round)

    def select_clients(self):
        """Return a random sample of at most `max_num_clients_per_round` clients."""
        num_clients = min(self.max_num_clients_per_round, len(self.clients))
        return random.sample(self.clients, num_clients)

    def train(self):
        """Run the rounds `initialRound .. num_rounds-1`, evaluating and storing after each."""
        self.round = self.initialRound
        for roundIndex in range(self.initialRound, self.num_rounds):
            print(f'--------------------- round {roundIndex:03d} out of {self.num_rounds:03d} ---------------------')

            (lossTrain, miouTrain), durationTrain = self._timed(self.train_round)
            print(f'-- loss (training): {lossTrain:.5f} -- miou (training): {100*miouTrain:.3f}% -- duration (training): {durationTrain}s')

            miouTestIdda, durationTestIdda = self._timed(self.test, self.dataLoaderTestIdda)
            print(f'-- miou (test - idda): {100*miouTestIdda:.3f}% -- duration (test - idda): {durationTestIdda}s')

            miouTestSame, durationTestSame = self._timed(self.test, self.dataLoaderTestSame)
            print(f'-- miou (test - same): {100*miouTestSame:.3f}% -- duration (test - same): {durationTestSame}s')

            miouTestDiff, durationTestDiff = self._timed(self.test, self.dataLoaderTestDiff)
            print(f'-- miou (test - diff): {100*miouTestDiff:.3f}% -- duration (test - diff): {durationTestDiff}s')

            ## STORAGE
            print('-- storing data')
            # Column order follows the header written when the run is created:
            # round, lossTrain, miouTrain, miouIdda, miouTestSame, miouTestDiff,
            # durationTrain, durationIdda, durationTestSame, durationTestDiff.
            # NOTE(review): rows stored by earlier versions of this notebook had
            # the idda/same miou values swapped relative to that header.
            with open(self.pathStorage+'/'+'metrics.csv', 'a', newline='') as file:
                csv.writer(file).writerow([
                    roundIndex, lossTrain, miouTrain,
                    miouTestIdda, miouTestSame, miouTestDiff,
                    durationTrain, durationTestIdda, durationTestSame, durationTestDiff
                ])

            torch.save(self.modelTeacher.state_dict(), self.pathStorage+f'/model-teacher.pth')
            torch.save(self.modelStudent.state_dict(), self.pathStorage+f'/model-student.pth')

            # the last line of the log doubles as the resume marker
            with open(self.pathStorage+'/'+'log.txt', 'a') as file:
                file.write(f'-- models were overwritten on '+datetime.datetime.now().strftime('%Y-%m-%d at %H:%M:%S')+f' by {self.notebookName}\n')
                file.write(f'   models: [teacher, student]\n')
                file.write(f'   last successful round: {roundIndex}\n')
            ## STORAGE

    def train_round(self):
        """Train the sampled clients from the same starting student, then aggregate.

        Returns:
            `(loss, miou)`: unweighted means of the values returned by the clients.
        """

        # clients
        active_clients = self.select_clients()
        num_epochs_per_client = {client.name:self.max_num_epochs_per_client for client in active_clients}

        # updating the teacher model
        # (load_state_dict copies the values into the teacher's own tensors,
        # so no intermediate deep copy is required)
        if self.updatePeriod != 0 and (self.round % self.updatePeriod) == 0:
            self.modelTeacher.load_state_dict(self.modelStudent.state_dict())

        # making a backup of the student model
        # (state_dict() hands out references to the live tensors, hence the deep copy)
        backup = copy.deepcopy(self.modelStudent.state_dict())

        # declaring initial values
        self.client_importances = []
        self.client_states = []

        # round (loop over clients)
        lossCum = 0.0
        miouCum = 0.0
        pbar = tqdm.tqdm(
            total = sum([num_epochs_per_client[client.name]*len(client.dataLoaderTrain) for client in active_clients]),
            desc = 'training'
        )
        try:
            for client_index, client in enumerate(active_clients):

                # restoring the student model
                clearCache()
                self.modelStudent.load_state_dict(backup)

                # sending the model to the client
                pbar.client = f'{client_index+1}/{len(active_clients)}'
                loss, miou = client.train(
                    modelTeacher = self.modelTeacher,
                    modelStudent = self.modelStudent,
                    num_epochs = num_epochs_per_client[client.name],
                    pbar = pbar,
                    threshold = self.threshold,
                    optimizer = torch.optim.SGD(self.modelStudent.parameters(), lr=self.lr, **self.optimizer_dict),
                    criterion = nn.CrossEntropyLoss(ignore_index=255),
                    metric = self.metricClass(n_classes=21, name='miou')
                )

                # round outputs (incremental means over the clients)
                lossCum = (client_index*lossCum + loss)/(client_index + 1)
                miouCum = (client_index*miouCum + miou)/(client_index + 1)

                # storing the proposed importance of the client
                self.client_importances.append(float(client.importance))

                # storing the state of the client (kept on the cpu)
                state = {}
                for key, tensor in self.modelStudent.state_dict().items():
                    state[key] = copy.deepcopy(tensor.to('cpu'))
                self.client_states.append(state)
        finally:
            # closing the progress bar, also when a client fails
            pbar.close()

        # stepping the scheduler
        self.step()

        # calling the aggregator
        self.aggregate()

        # returning the results
        return lossCum, miouCum

    def aggregate(self):
        """Load the importance-weighted average of the stored client states into the student.

        Does nothing when no client state was collected, so the student is never
        overwritten with an empty (all-zero or NaN) aggregate.
        """
        if not self.client_states:
            return

        # normalizing the client importances
        importances = torch.tensor(self.client_importances, dtype=torch.float, device=self.device)
        weights = (importances/importances.sum()).tolist()

        # declaring an aggregate
        clearCache()
        agg = {}
        for key, tensor in self.modelStudent.state_dict().items():
            agg[key] = torch.zeros(tensor.shape, device=self.device)

        # aggregating (the product allocates a new tensor; the stored states stay untouched)
        for weight, state in zip(weights, self.client_states):
            for key, tensor in state.items():
                agg[key] += weight*tensor.to(self.device)

        # updating the model
        self.modelStudent.load_state_dict(agg)

    def test(self, dataLoader):
        """Evaluate the student on `dataLoader` and return the 'Mean IoU' of the metric."""
        clearCache()
        metric = self.metricClass(n_classes=21, name='miou')
        pbar = tqdm.tqdm(total=len(dataLoader), desc='testing')
        try:
            with torch.no_grad():
                self.modelStudent.eval()
                for batch, (images, labels) in enumerate(dataLoader):
                    pbar.set_postfix({
                        'batch': f'{batch+1}/{len(dataLoader)}'
                    })
                    images = images.to(self.device)
                    labels = labels.to(self.device)
                    outputs = self.modelStudent(images)['out']
                    _, prediction = outputs.max(dim=1)
                    metric.update(labels.cpu().numpy(), prediction.cpu().numpy())
                    pbar.update(1)
        finally:
            pbar.close()
        miouCum = metric.get_results()['Mean IoU']
        return miouCum

# Different Things

In [16]:
# params class
class Params:
    """Plain attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **args):
        vars(self).update(args)

In [17]:
# names reader
def getFileNames(root, containerName):
    """Return the lines of the text file `root/containerName` as a list of strings.

    Line terminators are stripped; empty lines are kept as empty strings.
    """
    with open(os.path.join(root, containerName), 'r') as file:
        return file.read().splitlines()

In [18]:
# main function
def _readLastRound(pathLog):
    """Return the text after the most recent 'last successful round:' marker in the log."""
    marker = 'last successful round:'
    with open(pathLog) as file:
        for line in reversed(file.readlines()):
            if marker in line:
                # strip() keeps this independent of a trailing newline
                return line.split(marker, 1)[1].strip()
    raise RuntimeError(f'No `{marker}` entry was found in `{pathLog}`!')


def _writeLogEntry(pathLog, mode, action, lastRound):
    """Write one three-line log entry; `mode` is the file mode ('w' or 'a')."""
    with open(pathLog, mode) as file:
        file.write(f'-- models were {action} on '+datetime.datetime.now().strftime('%Y-%m-%d at %H:%M:%S')+f' by {notebookName}\n')
        file.write('   models: [teacher, student]\n')
        file.write(f'   last successful round: {lastRound}\n')


def _loadModelPair(device, pathTeacher, pathStudent):
    """Build the teacher and student networks on `device` and load their weights."""
    clearCache()
    modelTeacher = deeplabv3.deeplabv3_mobilenetv2().to(device)
    modelStudent = deeplabv3.deeplabv3_mobilenetv2().to(device)
    # map_location keeps loading independent of the device the file was saved from
    modelTeacher.load_state_dict(torch.load(pathTeacher, map_location=device))
    modelStudent.load_state_dict(torch.load(pathStudent, map_location=device))
    return modelTeacher, modelStudent


def main(pathStorage, params=None, config=None):
    """Create or resume the run stored in `pathStorage` and train it to completion.

    When `pathStorage` exists, the run is resumed: the last completed round is
    read from log.txt, the parameters from params.json and both models from
    their stored files; `params` and `config` are ignored. Otherwise a new run
    is created: `config` overrides the matching attributes of `params` (the
    object passed in is modified in place), the folder is initialised with
    params.json, config.json, the metrics.csv header, the pretrained weights
    for both models and a first log entry.

    Args:
        pathStorage: Folder of the run.
        params: `Params` with the default settings (needed for a new run).
        config: Mapping of overrides; every key must already be an attribute
            of `params` (needed for a new run).

    Raises:
        RuntimeError: If a new run is requested without `params`/`config`, if a
            `config` key is not an attribute of `params`, or if the log of an
            existing run holds no round marker.
    """
    print('path: '+pathStorage)

    # storage
    if os.path.exists(pathStorage):
        lastRound = _readLastRound(pathStorage+'/log.txt')

        with open(pathStorage+'/params.json') as file:
            params = Params(**json.load(file))

        modelTeacher, modelStudent = _loadModelPair(
            params.device,
            pathStorage+'/'+'model-teacher.pth',
            pathStorage+'/'+'model-student.pth'
        )

        _writeLogEntry(pathStorage+'/log.txt', 'a', 'loaded', lastRound)

        # 'n/a' is what a freshly created run stores before its first round
        lastRound = -1 if lastRound == 'n/a' else int(lastRound)
        if lastRound >= params.num_rounds - 1:
            print('-- this config is finished!')
            return

    else:
        if params is None or config is None:
            raise RuntimeError('Both `params` and `config` are required to create a new run!')

        for key, value in config.items():
            if hasattr(params, key):
                setattr(params, key, value)
            else:
                raise RuntimeError(f'Make sure each key of `config` is already an attribute of `params`! The key `{key}` does not exist!')

        os.makedirs(pathStorage, exist_ok=True)

        with open(pathStorage+'/params.json', 'w') as file:
            json.dump(params.__dict__, file, indent=4)

        with open(pathStorage+'/config.json', 'w') as file:
            json.dump(dict(config), file, indent=4)

        with open(pathStorage+'/metrics.csv', 'w', newline='') as file:
            csv_writer = csv.writer(file)
            csv_writer.writerow(['round', 'lossTrain', 'miouTrain', 'miouIdda', 'miouTestSame', 'miouTestDiff', 'durationTrain', 'durationIdda', 'durationTestSame', 'durationTestDiff'])

        # both models start from the same pretrained weights
        modelTeacher, modelStudent = _loadModelPair(
            params.device,
            params.rootPretrained+'/model_pretrained_step_3_2.pth',
            params.rootPretrained+'/model_pretrained_step_3_2.pth'
        )

        torch.save(modelTeacher.state_dict(), pathStorage+'/model-teacher.pth')
        torch.save(modelStudent.state_dict(), pathStorage+'/model-student.pth')
        _writeLogEntry(pathStorage+'/log.txt', 'w', 'created', 'n/a')
        lastRound = -1

    # transformers
    transformsTrain = tr.Compose([
        tr.RandomResizedCrop(size=tuple(params.transformer_imageSize),scale=params.transformer_scale),
        tr.ColorJitter(*params.transformer_jitter),
        tr.RandomHorizontalFlip(),
        tr.ToTensor(),
        tr.Normalize(tuple(params.transformer_means), tuple(params.transformer_stds))
    ])
    transformsTest = tr.Compose([
        tr.ToTensor(),
        tr.Normalize(tuple(params.transformer_means), tuple(params.transformer_stds))
    ])

    # clients (train.json maps each client name to its list of file names)
    with open(params.rootIdda+'/train.json') as file:
        clientsInfo = json.load(file)

    clients = []
    for clientName, fileNames in clientsInfo.items():
        client = Client(
            device = params.device,
            name = clientName,
            datasetTrain = IDDADataset(
                root = params.rootIdda,
                fileNames = fileNames,
                transform = transformsTrain
            ),
            batchSizeTrain = params.batchSizeTrain
        )
        clients.append(client)

    # server
    server = Server(
        device = params.device,
        modelTeacher = modelTeacher,
        modelStudent = modelStudent,
        clients = clients,
        # the "idda" evaluation set is the list in train.txt, read with the test transforms
        datasetTestIdda = IDDADataset(
            root = params.rootIdda,
            fileNames = getFileNames(params.rootIdda, 'train.txt'),
            transform = transformsTest
        ),
        datasetTestSame = IDDADataset(
            root = params.rootIdda,
            fileNames = getFileNames(params.rootIdda, 'test_same_dom.txt'),
            transform = transformsTest
        ),
        datasetTestDiff = IDDADataset(
            root = params.rootIdda,
            fileNames = getFileNames(params.rootIdda, 'test_diff_dom.txt'),
            transform = transformsTest
        ),
        batchSizeTest = params.batchSizeTest,
        metricClass = StreamSegMetrics,
        num_rounds = params.num_rounds,
        max_num_clients_per_round = params.max_num_clients_per_round,
        max_num_epochs_per_client = params.max_num_epochs_per_client,
        threshold = params.threshold,
        updatePeriod = params.updatePeriod,
        scheduler_dict = {
            'lr_initial': params.scheduler_lr_initial,
            'lr_min':  params.scheduler_lr_min,
            'T_max': params.scheduler_T_max
        },
        optimizer_dict = {
            'momentum': params.optimizer_momentum,
            'weight_decay': params.optimizer_weight_decay
        },
        notebookName = notebookName,                                            # STORAGE
        pathStorage = pathStorage,                                              # STORAGE
        lastRound = lastRound                                                   # STORAGE
    )

    # train federated learning
    server.train()

# Server Run

In [19]:
# params
## the initial parameters (default config)
## a parameter that also has a value in `config` is overwritten with that value
params = Params(
    device = 'cuda:0',
    rootIdda = '/content/data/idda',
    rootPretrained = rootMldl+'/pretrained',
    rootStorage = rootMldl+'/storage/step4/part2',
    batchSizeTrain = 3,
    batchSizeTest = 3,
    transformer_imageSize = [1080, 1920],
    transformer_scale = [0.25, 1],
    transformer_jitter = [0.4, 0.4, 0.5, 0.1],
    transformer_means = [0.320888, 0.292300, 0.288562],
    transformer_stds  = [0.250606, 0.248234, 0.253670],
    num_rounds = 20,                                                            # server
    max_num_clients_per_round = 7,                                              # server
    max_num_epochs_per_client = 5,                                              # client
    threshold = 0.9,                                                            # server (confidence threshold for the teacher's pseudo-labels)
    updatePeriod = 10,                                                          # server (teacher model update period; 0 means no update)
    scheduler_lr_initial = 0.01,                                                # server (scheduler parameter)
    scheduler_lr_min = 0,                                                       # server (scheduler parameter)
    scheduler_T_max = 30,                                                       # server (scheduler parameter)
    optimizer_momentum = 0.65,                                                  # client (optimizer parameter)
    optimizer_weight_decay = 0.0005                                             # client (optimizer parameter)
)

In [None]:
# copies
# Copy the IDDA data from the project folder to `params.rootIdda`, the location
# the datasets read from; `dirs_exist_ok=True` allows re-running the cell over
# an existing copy.
print('Copying IDDA data ...')
shutil.copytree(rootMldl+'/data/idda', params.rootIdda, dirs_exist_ok=True)

In [None]:
# configs
## every key of `config` must already be an attribute of `params`
config = {
    'updatePeriod': 10,
    'max_num_clients_per_round': 8,
    'max_num_epochs_per_client': 1,
}

# storage
## the folder name lists each setting as <abbreviation>=<value>, where the
## abbreviation joins the capitalized first two letters of every `_`-separated
## word of the key, e.g. `Up=10,MaNuClPeRo=8,MaNuEpPeCl=1`
folderName = ','.join(
    ''.join(word[:2].capitalize() for word in name.split('_'))+'='+str(setting)
    for name, setting in config.items()
)

# driver
## a run whose folder is already stored is resumed from its own files;
## otherwise a new run is created from `params` and `config`
pathStorage = params.rootStorage+'/'+folderName
if not os.path.exists(params.rootStorage) or (folderName not in os.listdir(params.rootStorage)):
    main(
        pathStorage = params.rootStorage+'/'+folderName,
        params = params,
        config = config
    )
else:
    main(pathStorage)