# Storage

In [None]:
## the path to the `mldl2023` folder in your drive
## (it can also be provided through the `MLDL_ROOT` environment variable, e.g. `%env MLDL_ROOT=...`)
import os
rootMldl = os.environ.get('MLDL_ROOT', '<your_drive>/mldl2023')

# Initialization

In [None]:
# packages 1
import shutil
import os
import torch

## the default configuration trains on 'cuda:0', so stop right away on a CPU-only runtime
hasCuda = torch.cuda.is_available()
if not hasCuda:
    raise RuntimeError('The model cannot operate without CUDA!')

In [None]:
# getting the notebook's name
from requests import get
from socket import gethostname, gethostbyname
## best-effort: the name only appears in the storage logs, so a failed lookup must not stop the notebook
try:
    notebookName = get(f'http://{gethostbyname(gethostname())}:9000/api/sessions', timeout=10).json()[0]['name']
except (OSError, ValueError, KeyError, IndexError, TypeError) as error:
    notebookName = 'unknown-notebook'
    print(f'Could not read the notebook name ({error!r}); using `{notebookName}` in the logs.')

In [None]:
# mounting the drive
from google.colab import drive
## Google Drive becomes reachable under this folder
driveMountPoint = '/content/drive'
drive.mount(driveMountPoint)

Mounted at /content/drive


In [None]:
# changing the root
## presumably the project packages imported next (`datasets`, `utils`, `models`) are resolved from this working directory
workingDirectory = rootMldl
os.chdir(workingDirectory)

In [None]:
# packages 2
import random
import string
from typing import Any
from typing import List
import numpy as np
from PIL import Image
from torch import from_numpy
from torchvision.datasets import VisionDataset
import datasets.ss_transforms as tr
from utils.stream_metrics import StreamSegMetrics
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from models import deeplabv3, mobilenetv2
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch, os, copy
import tqdm.notebook as tqdm
import gc
from utils.utils import HardNegativeMining, MeanReduction
import math
import json
import csv
from pprint import pprint
from torch import from_numpy
import datetime
import time

# Memory Management

In [None]:
# packages
!pip install gputil
import prettytable
import psutil
import GPUtil

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gputil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: gputil
  Building wheel for gputil (setup.py) ... [?25l[?25hdone
  Created wheel for gputil: filename=GPUtil-1.4.0-py3-none-any.whl size=7393 sha256=cf40a7455ae00be0d2cbfc9bcce3c9f8712fcc821e59d64879f40b5fbcbc8b13
  Stored in directory: /root/.cache/pip/wheels/a9/8a/bd/81082387151853ab8b6b3ef33426e98f5cbfebc3c397a9d4d0
Successfully built gputil
Installing collected packages: gputil
Successfully installed gputil-1.4.0


In [None]:
# garbage collector
def clearCache():
    """Free memory between heavy steps.

    First runs Python's garbage collector so unreachable objects (and the tensors
    they hold) are released, then asks PyTorch to release its cached CUDA blocks.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

In [None]:
# memory scanner
def printMemoryUsage(title='Memory status', path='/content/'):
    """Print a table with the disk, RAM and per-GPU memory usage.

    Args:
        title: heading printed above the table.
        path: a path on the file system whose disk usage is reported
            (defaults to `/content/`, as before).

    In every row the "available" column is `used + free`, so the three columns
    always add up. Each resource is sampled once, so a row is one consistent snapshot.
    """
    toMB = lambda numBytes: round(numBytes/(1024*1024), 1)
    disk = shutil.disk_usage(path)
    ram = psutil.virtual_memory()                                               # single snapshot: used/available belong together

    table = prettytable.PrettyTable(['type', 'available (MB)', 'used (MB)', 'free (MB)'])
    for field in table.field_names:
        table.align[field] = 'l'
    table.add_row(['disk', toMB(disk.used + disk.free), toMB(disk.used), toMB(disk.free)])
    table.add_row(['ram', toMB(ram.available + ram.used), toMB(ram.used), toMB(ram.available)])
    for i, gpu in enumerate(GPUtil.getGPUs()):
        # GPUtil values are used as-is (presumably already in MB)
        table.add_row([
            f'gpu-{i} ram',
            round(gpu.memoryUsed + gpu.memoryFree, 1),
            round(gpu.memoryUsed, 1),
            round(gpu.memoryFree, 1)
        ])
    print(title+':')
    print(table)
    print('')

# Dataset Class

In [None]:
class IDDADataset(VisionDataset):
    """IDDA segmentation samples, addressed by name.

    Sample `name` is read from `<root>/images/<name>.jpg` (converted to RGB) and
    `<root>/labels/<name>.png` (converted to single channel). After the optional
    joint `transform(image, label)`, raw label ids are remapped to training ids
    with the lookup built by `get_mapping`; unmapped ids become 255.
    """

    @staticmethod
    def get_mapping():
        """Return a callable that turns raw IDDA label ids into training ids (a tensor)."""
        trainIds = [255, 2, 4, 255, 11, 5, 0, 0, 1, 8, 13, 3, 7, 6, 255, 255, 15, 14, 12, 9, 10]
        lookup = np.full(256, 255, dtype=np.int64)                              # ids beyond the table -> 255 (the loss' ignore_index)
        lookup[:len(trainIds)] = trainIds

        def remap(x):
            return from_numpy(lookup[x])

        return remap

    def __init__(self, root, fileNames, transform=None):
        super().__init__(root=root, transform=transform, target_transform=IDDADataset.get_mapping())
        self.fileNames = fileNames

    def __getitem__(self, index):
        name = self.fileNames[index]
        image = Image.open('/'.join((self.root, 'images', name + '.jpg'))).convert('RGB')
        label = Image.open('/'.join((self.root, 'labels', name + '.png'))).convert('L')
        if self.transform is not None:
            image, label = self.transform(image, label)
        return image, self.target_transform(label)

    def __len__(self):
        return len(self.fileNames)

# Client Class

In [None]:
class Client:
    """A federated client that owns a private training split.

    Args:
        device: device every batch is moved to before the forward pass.
        name: identifier of the client (the server keys its score table by it).
        datasetTrain: the client's local dataset.
        batchSizeTrain: mini-batch size; the last incomplete batch is dropped.
    """

    def __init__(self,
                 device, name,
                 datasetTrain,
                 batchSizeTrain):

        self.device, self.name = device, name

        # the number of local samples is the weight this client brings to aggregation
        self.importance = len(datasetTrain)
        self.dataLoaderTrain = DataLoader(datasetTrain, batch_size=batchSizeTrain, shuffle=True, drop_last=True)

    def train(self, model, num_epochs, pbar, optimizer, criterion, metric):
        """Train `model` in place for `num_epochs` local epochs.

        `pbar` is advanced once per batch and must carry a `cluster` attribute,
        which is shown in its postfix. `metric` accumulates predictions over all epochs.

        Returns:
            (mean of the per-epoch mean losses, mean IoU of the accumulated metric, score),
            where the score is that same mean IoU.
        """
        clearCache()
        model.train()
        numBatches = len(self.dataLoaderTrain)

        meanLoss = 0.0
        for epochIdx in range(num_epochs):
            epochLoss = 0.0
            for batchIdx, (images, labels) in enumerate(self.dataLoaderTrain):
                pbar.set_postfix({
                    'cluster': pbar.cluster,
                    'epoch': f'{epochIdx+1}/{num_epochs}',
                    'batch': f'{batchIdx+1}/{numBatches}'
                })
                images, labels = images.to(self.device), labels.to(self.device)

                # one optimization step on this batch
                optimizer.zero_grad()
                logits = model(images)['out']
                batchLoss = criterion(logits, labels)
                batchLoss.backward()
                optimizer.step()

                # incremental mean of the batch losses of this epoch
                epochLoss = (batchIdx*epochLoss + batchLoss.item())/(batchIdx + 1)
                predicted = logits.max(dim=1).indices
                metric.update(labels.cpu().numpy(), predicted.cpu().numpy())

                pbar.update(1)

            # incremental mean over epochs
            meanLoss = (epochIdx*meanLoss + epochLoss)/(epochIdx + 1)

        meanIou = metric.get_results()['Mean IoU']
        return meanLoss, meanIou, meanIou                                       # score must be between zero and one

# Server Class

In [None]:
class Server:
    """Federated server that trains one shared model through clusters of clients.

    Each round:
        1. Sample clients and split them into clusters.
        2. Let every cluster train its own copy of the current model. Its clients
           train one after the other, repeated for a number of cycles.
        3. Average the cluster models, weighted by each cluster's number of
           training samples.

    After every round the model is evaluated on the `same` and `diff` test sets.
    Metrics, client scores, checkpoints and a log are then written under `pathStorage`.

    The conditional allocator turns normalized client scores (see
    `normalize_scores`) into the number of cycles per cluster and of local
    epochs per client. When an allocator is disabled, the neutral score is used instead.
    """

    def __init__(self,
                 device, model, clients,
                 datasetTestSame, datasetTestDiff,
                 batchSizeTest,
                 metricClass, num_rounds, max_num_clients_per_round, max_num_clusters_per_round, max_num_cycles_per_cluster, max_num_epochs_per_client,
                 cond_allocator_neutral_score, cond_allocator_cycle_is_active, cond_allocator_epoch_is_active, cond_allocator_method,
                 scheduler_dict, optimizer_dict,
                 notebookName, pathStorage, lastRound, scores, bestRoundSame, bestMiouSame, bestRoundDiff, bestMiouDiff):

        self.device = device
        self.model = model
        self.clients = clients

        # evaluation must see every test image exactly once: no shuffling, no dropped tail batch
        self.dataLoaderTestSame = DataLoader(datasetTestSame, batch_size=batchSizeTest, shuffle=False, drop_last=False)
        self.dataLoaderTestDiff = DataLoader(datasetTestDiff, batch_size=batchSizeTest, shuffle=False, drop_last=False)

        self.metricClass = metricClass
        self.num_rounds = num_rounds
        self.max_num_clients_per_round = max_num_clients_per_round
        self.max_num_clusters_per_round = max_num_clusters_per_round
        self.max_num_cycles_per_cluster = max_num_cycles_per_cluster
        self.max_num_epochs_per_client = max_num_epochs_per_client

        self.cond_allocator_neutral_score = cond_allocator_neutral_score
        self.cond_allocator_cycle_is_active = cond_allocator_cycle_is_active
        self.cond_allocator_epoch_is_active = cond_allocator_epoch_is_active
        self.cond_allocator_method = cond_allocator_method

        # manually programmed cosine annealing scheduler:
        # lr(r) = lr_min + (lr_initial - lr_min)*(1 + cos(pi*r/T_max))/2
        self.scheduler = lambda roundIndex: scheduler_dict['lr_min'] + 0.5*(scheduler_dict['lr_initial']-scheduler_dict['lr_min'])*(1+math.cos(math.pi*roundIndex/scheduler_dict['T_max']))
        self.optimizer_dict = optimizer_dict

        # STORAGE
        self.notebookName = notebookName                                        # STORAGE
        self.pathStorage = pathStorage                                          # STORAGE
        self.initialRound = lastRound + 1                                       # STORAGE
        self.scores = scores                                                    # STORAGE (client name -> latest score)
        self.bestRoundSame = bestRoundSame                                      # STORAGE
        self.bestMiouSame = bestMiouSame                                        # STORAGE
        self.bestRoundDiff = bestRoundDiff                                      # STORAGE
        self.bestMiouDiff = bestMiouDiff                                        # STORAGE
        # STORAGE

        self.lr = self.scheduler(self.initialRound)

    def step(self):
        """Advance the round counter and refresh the learning rate."""
        self.round += 1
        self.lr = self.scheduler(self.round)

    def normalize_scores(self, scores):
        """Min-max normalize `scores` (client name -> score or None) in place into [0, 1].

        Missing (None) or NaN scores get the neutral score. All scores get the
        neutral score when fewer than two valid ones exist, or when every valid
        score is equal (there is nothing to rank). 'higher-more' maps the highest
        score to 1; 'lower-more' maps it to 0.

        Raises:
            RuntimeError: if the normalization method is unknown (checked only
                when there are at least two valid scores).
        """
        known = {key: value for key, value in scores.items()
                 if value is not None and not math.isnan(value)}
        if len(known) < 2:
            for key in scores:
                scores[key] = self.cond_allocator_neutral_score
            return
        if self.cond_allocator_method not in ('higher-more', 'lower-more'):
            raise RuntimeError(f'The normaliztion method `{self.cond_allocator_method}` does not exist!')
        minimum = float(min(known.values()))
        maximum = float(max(known.values()))
        for key in scores:
            if key not in known or maximum == minimum:                          # unknown score, or all equal: avoid 0/0
                scores[key] = self.cond_allocator_neutral_score
            elif self.cond_allocator_method == 'higher-more':
                scores[key] = (known[key] - minimum)/(maximum - minimum)
            else:
                scores[key] = (maximum - known[key])/(maximum - minimum)

    def select_clients(self):
        """Uniformly sample at most `max_num_clients_per_round` distinct clients."""
        num_clients = min(self.max_num_clients_per_round, len(self.clients))
        return random.sample(self.clients, num_clients)

    def assign_clients_to_clusters(self, active_clients):
        """Split `active_clients` into randomly named clusters and allocate cycles and epochs.

        Clients are dealt round-robin into at most `max_num_clusters_per_round`
        clusters, never more than there are clients, so no cluster is empty.

        Returns:
            (clustering: cluster name -> list of clients,
             num_cycles_per_cluster: cluster name -> int in [1, max_num_cycles_per_cluster],
             num_epochs_per_client: client name -> int in [1, max_num_epochs_per_client])
        """

        # random clustering
        clustering = {}
        num_clusters = min(self.max_num_clusters_per_round, len(active_clients))
        for i in range(num_clusters):
            cluster_name = None
            while cluster_name is None or cluster_name in clustering:           # a name clash would silently drop a cluster
                cluster_name = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(5))
            clustering[cluster_name] = active_clients[i::num_clusters]

        # normalizing the client scores
        scores = {}
        for _, cluster_clients in clustering.items():
            for client in cluster_clients:
                scores[client.name] = self.scores[client.name] if client.name in self.scores else None
        self.normalize_scores(scores)

        # allocating cycles to clusters (mean normalized score of the cluster)
        num_cycles_per_cluster = {}
        for cluster_name, cluster_clients in clustering.items():
            if self.cond_allocator_cycle_is_active:
                cluster_scores = [scores[client.name] for client in cluster_clients]
                num_cycles_per_cluster[cluster_name] = math.floor((self.max_num_cycles_per_cluster-1)*sum(cluster_scores)/len(cluster_scores)+1)
            else:
                num_cycles_per_cluster[cluster_name] = math.floor(self.cond_allocator_neutral_score*(self.max_num_cycles_per_cluster-1)+1)

        # allocating epochs to clients
        num_epochs_per_client = {}
        for _, cluster_clients in clustering.items():
            for client in cluster_clients:
                if self.cond_allocator_epoch_is_active:
                    num_epochs_per_client[client.name] = math.floor((self.max_num_epochs_per_client-1)*scores[client.name]+1)
                else:
                    num_epochs_per_client[client.name] = math.floor(self.cond_allocator_neutral_score*(self.max_num_epochs_per_client-1)+1)

        # printing some info
        info = []
        for cluster_name, cluster_clients in clustering.items():
            cluster_info = [client.name for client in cluster_clients]
            info.append(cluster_name+'->['+', '.join(cluster_info)+']')
        print('-- clustering: '+', '.join(info))
        info = [f'{cluster_name}->{num_cycles}' for cluster_name, num_cycles in num_cycles_per_cluster.items()]
        print('-- num of cycles per cluster: '+', '.join(info))
        info = [f'{client_name}->{num_epochs}' for client_name, num_epochs in num_epochs_per_client.items()]
        print('-- num of epochs per client: '+', '.join(info))

        return clustering, num_cycles_per_cluster, num_epochs_per_client

    def train(self):
        """Run rounds `initialRound` .. `num_rounds-1`; after each, test the model and persist everything."""
        self.round = self.initialRound
        for roundIndex in range(self.initialRound, self.num_rounds):
            print(f'--------------------- round {roundIndex:03d} out of {self.num_rounds:03d} ---------------------')

            start = time.perf_counter()
            lossTrain, miouTrain = self.train_round()
            durationTrain = int(time.perf_counter() - start)
            print(f'-- loss (training): {lossTrain:.5f} -- miou (training): {100*miouTrain:.3f}% -- duration (training): {durationTrain}s')

            start = time.perf_counter()
            miouTestSame = self.test(self.dataLoaderTestSame)
            durationTestSame = int(time.perf_counter() - start)
            print(f'-- miou (test - same): {100*miouTestSame:.3f}% -- duration (test - same): {durationTestSame}s')

            start = time.perf_counter()
            miouTestDiff = self.test(self.dataLoaderTestDiff)
            durationTestDiff = int(time.perf_counter() - start)
            print(f'-- miou (test - diff): {100*miouTestDiff:.3f}% -- duration (test - diff): {durationTestDiff}s')

            ## STORAGE
            print('-- storing data')
            # newline='' as required by the csv module (matches the header writer)
            with open(self.pathStorage+'/'+'metrics.csv', 'a', newline='') as file:
                csv.writer(file).writerow([roundIndex, lossTrain, miouTrain, miouTestSame, miouTestDiff, durationTrain, durationTestSame, durationTestDiff])

            # the server's own storage folder, not whatever global path the notebook holds
            with open(self.pathStorage+'/scores.json', 'w') as file:
                json.dump(self.scores, file, indent=4)

            torch.save(self.model.state_dict(), self.pathStorage+'/model-main.pth')
            models = ['main']
            if miouTestSame > self.bestMiouSame:
                self.bestRoundSame = roundIndex
                self.bestMiouSame = miouTestSame
                torch.save(self.model.state_dict(), self.pathStorage+'/model-bestSame.pth')
                models.append('bestSame')
            if miouTestDiff > self.bestMiouDiff:
                self.bestRoundDiff = roundIndex
                self.bestMiouDiff = miouTestDiff
                torch.save(self.model.state_dict(), self.pathStorage+'/model-bestDiff.pth')
                models.append('bestDiff')

            # the last three lines of each entry are parsed back when a run is resumed
            with open(self.pathStorage+'/'+'log.txt', 'a') as file:
                file.write('-- models were overwritten on '+datetime.datetime.now().strftime('%Y-%m-%d at %H:%M:%S')+f' by {self.notebookName}\n')
                file.write(f'   models: [{", ".join(models)}]\n')
                file.write(f'   last successful round: {roundIndex}\n')
                file.write(f'   best record of same: [{self.bestRoundSame}, {self.bestMiouSame}]\n')
                file.write(f'   best record of diff: [{self.bestRoundDiff}, {self.bestMiouDiff}]\n')
            ## STORAGE

    def train_round(self):
        """Train one federated round and aggregate it into `self.model`.

        Returns:
            (mean training loss, mean training mIoU), averaged client -> cycle -> cluster.
        """

        # clients and clusters
        active_clients = self.select_clients()
        clustering, num_cycles_per_cluster, num_epochs_per_client = self.assign_clients_to_clusters(active_clients)

        # every cluster starts from the same model: keep a backup of it
        backup = copy.deepcopy(self.model.state_dict())

        # declaring initial values
        self.cluster_importances = []
        self.cluster_states = []

        # round (loop over clusters)
        lossCum = 0.0
        miouCum = 0.0
        pbar = tqdm.tqdm(
            total = sum([
                num_cycles_per_cluster[cluster_name]*sum([
                    num_epochs_per_client[client.name]*len(client.dataLoaderTrain) for client in cluster_clients
                ]) for (cluster_name, cluster_clients) in clustering.items()
            ]),
            desc = 'training'
        )
        for cluster_index, (cluster_name, cluster_clients) in enumerate(clustering.items()):

            # restoring the model
            clearCache()
            self.model.load_state_dict(copy.deepcopy(backup))

            # cluster (loop over cycles)
            lossCluster = 0.0
            miouCluster = 0.0
            for cycle_index in range(num_cycles_per_cluster[cluster_name]):

                # cycle (loop over clients)
                lossCycle = 0.0
                miouCycle = 0.0
                for client_index, client in enumerate(cluster_clients):

                    # sending the model to the client (fresh optimizer, criterion and metric per client)
                    pbar.cluster = f'{cluster_index+1}/{len(clustering)}, cycle={cycle_index+1}/{num_cycles_per_cluster[cluster_name]}, client={client_index+1}/{len(cluster_clients)}'
                    loss, miou, self.scores[client.name] = client.train(
                        model = self.model,
                        num_epochs = num_epochs_per_client[client.name],
                        pbar = pbar,
                        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, **self.optimizer_dict),
                        criterion = nn.CrossEntropyLoss(ignore_index=255),
                        metric = self.metricClass(n_classes=21, name='miou')
                    )

                    # cycle outputs
                    lossCycle = (client_index*lossCycle + loss)/(client_index + 1)
                    miouCycle = (client_index*miouCycle + miou)/(client_index + 1)

                # cluster outputs
                lossCluster = (cycle_index*lossCluster + lossCycle)/(cycle_index + 1)
                miouCluster = (cycle_index*miouCluster + miouCycle)/(cycle_index + 1)

            # storing the proposed importance of the cluster (its total number of samples)
            importance = 0.0
            for client in cluster_clients:
                importance += client.importance
            self.cluster_importances.append(importance)

            # storing a CPU copy of the state of the cluster
            state = {}
            for key, tensor in self.model.state_dict().items():
                state[key] = copy.deepcopy(tensor.to('cpu'))
            self.cluster_states.append(state)

            # round outputs
            lossCum = (cluster_index*lossCum + lossCluster)/(cluster_index + 1)
            miouCum = (cluster_index*miouCum + miouCluster)/(cluster_index + 1)

        # closing the progress bar
        pbar.close()

        # stepping the scheduler
        self.step()

        # calling the aggregator
        self.aggregate()

        # returning the results
        return lossCum, miouCum

    def aggregate(self):
        """Replace the model weights with the importance-weighted average of the cluster states.

        Keeps the current model if no cluster trained this round. Falls back to
        uniform weights when the total importance is zero.
        """
        if not self.cluster_states:
            return

        # normalizing the cluster importances
        importances = torch.tensor(self.cluster_importances, dtype=torch.float, device=self.device)
        total = importances.sum()
        if total.item() > 0:
            importances /= total
        else:
            importances = torch.full_like(importances, 1.0/len(self.cluster_states))

        # declaring an aggregate
        clearCache()
        agg = {key: torch.zeros(tensor.shape, device=self.device) for key, tensor in self.model.state_dict().items()}

        # aggregating (`.to` + multiplication already produce new tensors; the stored states stay untouched)
        for weight, state in zip(importances.tolist(), self.cluster_states):
            for key, tensor in state.items():
                agg[key] += weight*tensor.to(self.device)

        # updating the model (load_state_dict copies the values into the parameters)
        self.model.load_state_dict(agg)

    def test(self, dataLoader):
        """Evaluate the model on `dataLoader` and return its mean IoU."""
        clearCache()
        metric = self.metricClass(n_classes=21, name='miou')
        with torch.no_grad():
            self.model.eval()
            pbar = tqdm.tqdm(total=len(dataLoader), desc='testing')
            for batch, (images, labels) in enumerate(dataLoader):
                pbar.set_postfix({
                    'batch': f'{batch+1}/{len(dataLoader)}'
                })
                images = images.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(images)['out']
                _, prediction = outputs.max(dim=1)
                metric.update(labels.cpu().numpy(), prediction.cpu().numpy())
                pbar.update(1)
        pbar.close()
        miouCum = metric.get_results()['Mean IoU']
        return miouCum

# Helpers (Params Container, File-Name Reader, Main Function)

In [None]:
# params class
class Params:
    """Plain attribute bag: every keyword argument becomes an attribute of the instance."""

    def __init__(self, **args):
        self.__dict__.update(args)

In [None]:
# names reader
def getFileNames(root, containerName):
    """Read the sample names stored one per line in `root/containerName`.

    Surrounding whitespace is stripped and blank lines are skipped. This keeps a
    stray empty line from becoming a bogus '' sample name that would fail later
    when its image is opened.

    Returns:
        list of str, in file order.
    """
    with open(os.path.join(root, containerName), 'r') as file:
        return [line.strip() for line in file.read().splitlines() if line.strip()]

In [None]:
# main function
def main(pathStorage, params=None, config=None):
    """Train, or resume training of, one federated configuration stored under `pathStorage`.

    If `pathStorage` exists, the run is resumed. Params, client scores, the
    latest model and the best records are read back from the folder, and the
    given `params`/`config` are ignored. Otherwise a new run is created from a
    copy of `params` overridden by `config`, and its initial files are written.

    Args:
        pathStorage: folder holding params.json, config.json, metrics.csv,
            scores.json, log.txt and the model-*.pth checkpoints.
        params: default `Params`; required only for a new run. It is copied,
            never modified.
        config: dict of overrides (None means no overrides). Every key must
            already be an attribute of `params`.

    Raises:
        RuntimeError: if a `config` key is not an attribute of `params`, or if
            a new run is requested without `params`.
    """
    print('path: '+pathStorage)

    # storage
    if os.path.exists(pathStorage):
        # resume: the last three log lines hold the last round and the two best records
        with open(pathStorage+'/log.txt') as file:
            lines = file.readlines()
        lastRoundLine = lines[-3]
        recordSameLine = lines[-2]
        recordDiffLine = lines[-1]
        lastRound = lastRoundLine[lastRoundLine.find(':')+2:-1]                 # 'n/a' or a round number
        bestRoundSame = int(recordSameLine[recordSameLine.find(':')+3:recordSameLine.find(',')])
        bestMiouSame = float(recordSameLine[recordSameLine.find(',')+2:-2])
        bestRoundDiff = int(recordDiffLine[recordDiffLine.find(':')+3:recordDiffLine.find(',')])
        bestMiouDiff = float(recordDiffLine[recordDiffLine.find(',')+2:-2])

        with open(pathStorage+'/params.json') as file:
            params = Params(**json.load(file))

        with open(pathStorage+'/scores.json') as file:
            scores = json.load(file)

        clearCache()
        model = deeplabv3.deeplabv3_mobilenetv2().to(params.device)
        model.load_state_dict(torch.load(pathStorage+'/'+'model-main.pth', map_location=params.device))

        with open(pathStorage+'/log.txt', 'a') as file:
            file.write('-- models were loaded on '+datetime.datetime.now().strftime('%Y-%m-%d at %H:%M:%S')+f' by {notebookName}\n')
            file.write('   models: [main]\n')
            file.write(f'   last successful round: {lastRound}\n')
            file.write(recordSameLine)
            file.write(recordDiffLine)
        lastRound = -1 if lastRound == 'n/a' else int(lastRound)
        if lastRound >= params.num_rounds - 1:
            print('-- this config is finished!')
            return

    else:
        if params is None:
            raise RuntimeError('`params` is required to create a new storage folder!')
        config = dict(config or {})
        # work on a copy: the overrides below must not leak into the caller's default params
        params = Params(**vars(params))
        for key, value in config.items():
            if hasattr(params, key):
                setattr(params, key, value)
            else:
                raise RuntimeError(f'Make sure each key of `config` is already an attribute of `params`! The key `{key}` does not exist!')

        os.makedirs(pathStorage, exist_ok=True)

        with open(pathStorage+'/params.json', 'w') as file:
            json.dump(params.__dict__, file, indent=4)

        with open(pathStorage+'/config.json', 'w') as file:
            json.dump(config, file, indent=4)

        with open(pathStorage+'/metrics.csv', 'w', newline='') as file:
            csv_writer = csv.writer(file)
            csv_writer.writerow(['round', 'lossTrain', 'miouTrain', 'miouTestSame', 'miouTestDiff', 'durationTrain', 'durationTestSame', 'durationTestDiff'])

        scores = {}
        with open(pathStorage+'/scores.json', 'w') as file:
            json.dump(scores, file, indent=4)

        clearCache()
        model = deeplabv3.deeplabv3_mobilenetv2().to(params.device)

        torch.save(model.state_dict(), pathStorage+'/model-main.pth')
        torch.save(model.state_dict(), pathStorage+'/model-bestSame.pth')
        torch.save(model.state_dict(), pathStorage+'/model-bestDiff.pth')
        with open(pathStorage+'/log.txt', 'w') as file:
            file.write('-- models were created on '+datetime.datetime.now().strftime('%Y-%m-%d at %H:%M:%S')+f' by {notebookName}\n')
            file.write('   models: [main, bestSame, bestDiff]\n')
            file.write('   last successful round: n/a\n')
            file.write('   best record of same: [0, 0.0]\n')
            file.write('   best record of diff: [0, 0.0]\n')
        lastRound = -1
        bestRoundSame = 0
        bestMiouSame = 0.0
        bestRoundDiff = 0
        bestMiouDiff = 0.0

    # transformers
    transformsTrain = tr.Compose([
        tr.RandomResizedCrop(size=tuple(params.transformer_imageSize),scale=params.transformer_scale),
        tr.ColorJitter(*params.transformer_jitter),
        tr.RandomHorizontalFlip(),
        tr.ToTensor(),
        tr.Normalize(tuple(params.transformer_means), tuple(params.transformer_stds))
    ])
    transformsTest = tr.Compose([
        tr.ToTensor(),
        tr.Normalize(tuple(params.transformer_means), tuple(params.transformer_stds))
    ])

    # clients (train.json: client name -> list of sample names)
    with open(params.rootIdda+'/train.json') as file:
        clientsInfo = json.load(file)

    clients = []
    for clientName, fileNames in clientsInfo.items():
        client = Client(
            device = params.device,
            name = clientName,
            datasetTrain = IDDADataset(
                root = params.rootIdda,
                fileNames = fileNames,
                transform = transformsTrain
            ),
            batchSizeTrain = params.batchSizeTrain
        )
        clients.append(client)

    # server
    server = Server(
        device = params.device,
        model = model,
        clients = clients,
        datasetTestSame = IDDADataset(
            root = params.rootIdda,
            fileNames = getFileNames(params.rootIdda, 'test_same_dom.txt'),
            transform = transformsTest
        ),
        datasetTestDiff = IDDADataset(
            root = params.rootIdda,
            fileNames = getFileNames(params.rootIdda, 'test_diff_dom.txt'),
            transform = transformsTest
        ),
        batchSizeTest = params.batchSizeTest,
        metricClass = StreamSegMetrics,
        num_rounds = params.num_rounds,
        max_num_clients_per_round = params.max_num_clients_per_round,
        max_num_clusters_per_round = params.max_num_clusters_per_round,
        max_num_cycles_per_cluster = params.max_num_cycles_per_cluster,
        max_num_epochs_per_client = params.max_num_epochs_per_client,
        cond_allocator_neutral_score = params.cond_allocator_neutral_score,
        cond_allocator_cycle_is_active = params.cond_allocator_cycle_is_active,
        cond_allocator_epoch_is_active = params.cond_allocator_epoch_is_active,
        cond_allocator_method = params.cond_allocator_method,
        scheduler_dict = {
            'lr_initial': params.scheduler_lr_initial,
            'lr_min':  params.scheduler_lr_min,
            'T_max': params.scheduler_T_max
        },
        optimizer_dict = {
            'momentum': params.optimizer_momentum,
            'weight_decay': params.optimizer_weight_decay
        },
        notebookName = notebookName,                                            # STORAGE
        pathStorage = pathStorage,                                              # STORAGE
        lastRound = lastRound,                                                  # STORAGE
        scores = scores,                                                        # STORAGE
        bestRoundSame = bestRoundSame,                                          # STORAGE
        bestMiouSame = bestMiouSame,                                            # STORAGE
        bestRoundDiff = bestRoundDiff,                                          # STORAGE
        bestMiouDiff = bestMiouDiff                                             # STORAGE
    )

    # train federated learning
    server.train()

# Server Driver

In [None]:
# params
## the initial parameters (default config)
## if there's a config value for a parameter, it'll be overwritten
defaultParams = {
    # environment and paths
    'device': 'cuda:0',
    'rootIdda': '/content/data/idda',
    'rootStorage': rootMldl+'/storage/step5',
    # data loading and augmentation
    'batchSizeTrain': 3,
    'batchSizeTest': 3,
    'transformer_imageSize': [1080, 1920],
    'transformer_scale': [0.25, 1],
    'transformer_jitter': [0.4, 0.4, 0.5, 0.1],
    'transformer_means': [0.320888, 0.292300, 0.288562],
    'transformer_stds': [0.250606, 0.248234, 0.253670],
    # federated schedule
    'num_rounds': 200,                                                          # server
    'max_num_clients_per_round': 5,                                             # server
    'max_num_clusters_per_round': 3,                                            # server
    'max_num_cycles_per_cluster': 1,                                            # cluster
    'max_num_epochs_per_client': 1,                                             # client
    # conditional allocator
    'cond_allocator_neutral_score': 0.5,                                        # server
    'cond_allocator_cycle_is_active': True,                                     # server
    'cond_allocator_epoch_is_active': True,                                     # server
    'cond_allocator_method': 'higher-more',                                     # server
    # cosine-annealing scheduler
    'scheduler_lr_initial': 0.1,                                                # server
    'scheduler_lr_min': 0,                                                      # server
    'scheduler_T_max': 300,                                                     # server
    # SGD optimizer
    'optimizer_momentum': 0.65,                                                 # client
    'optimizer_weight_decay': 0.0005                                            # client
}
params = Params(**defaultParams)

In [None]:
# copies
## the dataset is copied from the drive to `params.rootIdda`, the folder training reads from
iddaOnDrive = rootMldl+'/data/idda'
print('Copying IDDA data ...')
shutil.copytree(iddaOnDrive, params.rootIdda, dirs_exist_ok=True)

In [None]:
# configs
## make sure each key of `config` is already an attribute of `params`
config = {
    'max_num_clients_per_round': 5,
    'max_num_clusters_per_round': 5,
    'max_num_cycles_per_cluster': 1,
    'max_num_epochs_per_client': 5,
    'cond_allocator_cycle_is_active': False,
    'cond_allocator_epoch_is_active': False,
    'cond_allocator_method': 'higher-more'
}

# storage
## every key is abbreviated to the first two letters of each of its words, e.g. `MaNuCyPeCl=1`
## NOTE: abbreviations can collide (`max_num_clients_per_round` and `max_num_clusters_per_round`
## are both `MaNuClPeRo`), so a resumed folder is checked against the config it was created with
folderName = ','.join(
    ''.join(word[:2].capitalize() for word in key.split('_'))+'='+str(value)
    for key, value in config.items()
)

# driver
pathStorage = params.rootStorage+'/'+folderName
if os.path.isdir(pathStorage):
    pathConfig = pathStorage+'/config.json'
    if os.path.exists(pathConfig):
        with open(pathConfig) as file:
            storedConfig = json.load(file)
        if storedConfig != config:
            raise RuntimeError(f'The folder `{pathStorage}` was created with a different config: {storedConfig}')
    main(pathStorage)
else:
    main(
        params = params,
        config = config,
        pathStorage = pathStorage
    )