# Multi-task Sequential Federated Learning on MedMNIST

In [None]:
! pip install medmnist
! pip3 install pyyaml

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd /content/drive/MyDrive/FederatedLearning/MICCAI2024/

In [1]:
# Import libraries
import medmnist
from medmnist import INFO

import random
from tqdm import tqdm
import time
import numpy as np
from statistics import mean
import PIL
from itertools import chain
import yaml

import torch
import torch.nn as nn

# Models
from models.CNN import CNN5
from models.VGG11 import VGG11
from models.ResNet18 import ResNet_18, Block
from models.SimCNN import SimCNN

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# Train and Test/validation
from utils.Training import Train
from utils.Evaluation import Evaluator

# data
from utils.StratifiedDatasets import StratifiedData
from utils.IIDDatasets import IIDData

# import EarlyStopping
from utils.pytorchtools import EarlyStopping

# making code reproducible
from utils.seedeverything import seedevrything

In [2]:
import os
os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

In [3]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

Using device: cuda



In [4]:
# Setting Up Random Seeds In PyTorch
seedevrything.set_seed(42)

Random seed set as 42


### Dataset loading

In [5]:
print(f"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}")

MedMNIST v3.0.1 @ https://github.com/MedMNIST/MedMNIST/


In [6]:
# getting all data flags in MedMNIST
medmnist.INFO

{'pathmnist': {'python_class': 'PathMNIST',
  'description': 'The PathMNIST is based on a prior study for predicting survival from colorectal cancer histology slides, providing a dataset (NCT-CRC-HE-100K) of 100,000 non-overlapping image patches from hematoxylin & eosin stained histological images, and a test dataset (CRC-VAL-HE-7K) of 7,180 image patches from a different clinical center. The dataset is comprised of 9 types of tissues, resulting in a multi-class classification task. We resize the source images of 3×224×224 into 3×28×28, and split NCT-CRC-HE-100K into training and validation set with a ratio of 9:1. The CRC-VAL-HE-7K is treated as the test set.',
  'url': 'https://zenodo.org/records/10519652/files/pathmnist.npz?download=1',
  'MD5': 'a8b06965200029087d5bd730944a56c1',
  'url_64': 'https://zenodo.org/records/10519652/files/pathmnist_64.npz?download=1',
  'MD5_64': '55aa9c1e0525abe5a6b9d8343a507616',
  'url_128': 'https://zenodo.org/records/10519652/files/pathmnist_128.np

## Load configs

In [7]:
# Parse the experiment settings (dataset flags, optimizer and schedule values) from configs.yaml
with open("configs.yaml", 'r') as config_stream:
    configs = yaml.safe_load(config_stream)

### Dataset loading

In [8]:
data_flags = configs['dataset']['data_flags']

In [9]:
# Total number of labels over all selected MedMNIST datasets; every model in this
# notebook is built with this many output classes.
# A generator expression is used so that no loop variable is left behind in the
# notebook namespace (the name `dataset` is bound to the data object in the next cell).
n_classes = sum(len(medmnist.INFO[flag]['label']) for flag in data_flags)
print(f"The numebr of labels in all datasets is: {n_classes}")

The numebr of labels in all datasets is: 23


In [10]:
# get an instance from Data class
# `amount` is handed to the data class and is also part of the result file names
# (its exact meaning is defined in utils.StratifiedDatasets / utils.IIDDatasets -- TODO confirm)
amount = 8000
# client_splitting_method
cmp = None

# SAMPLING = ['iid', 'noniid']
SAMPLING = "noniid"

if SAMPLING == 'noniid':
    # No-IID sampling
    dataset = StratifiedData(data_flags = data_flags, amount = amount, num_clients = 8)
    cmp = "ordered"
elif SAMPLING == 'iid':
    # IID sampling
    dataset = IIDData(data_flags = data_flags, amount = amount, split_method="stratify")
    cmp = 'equal'
else:
    # fail here instead of continuing with an undefined/stale `dataset` and cmp=None
    raise ValueError(f"Unknown SAMPLING {SAMPLING!r}; expected 'noniid' or 'iid'")

# get the server data
train_dl_server, valid_dl_server, test_dl_server = dataset.get_server_data()
# get clients data
x_train_dict, y_train_dict, x_valid_dict, y_valid_dict, x_test_dict, y_test_dict = dataset.get_clients_data(client_splitting_method=cmp)

Using downloaded and verified file: /home/srahmani/.medmnist/octmnist.npz
Using downloaded and verified file: /home/srahmani/.medmnist/organamnist.npz
Using downloaded and verified file: /home/srahmani/.medmnist/tissuemnist.npz


### Classification Model

In [11]:
def get_model(model_name):
  '''
  builds a new classifier with n_classes (notebook-level global) output classes.

  @param: model_name --> one of 'cnn5', 'vgg11', 'simcnn', 'resnet18'

  return: a freshly constructed model (the caller moves it to the device)

  raises: ValueError when model_name is not one of the supported names
          (previously the function silently returned None in that case)
  '''
  if model_name == 'cnn5':
    return CNN5(num_classes=n_classes)
  elif model_name == 'vgg11':
    return VGG11(num_classes=n_classes)
  elif model_name == 'simcnn':
    return SimCNN(num_classes=n_classes)
  elif model_name == 'resnet18':
    return ResNet_18(num_classes=n_classes)

  raise ValueError(f"Unknown model name {model_name!r}; expected 'cnn5', 'vgg11', 'simcnn' or 'resnet18'")

In [12]:
trainer = Train(n_classes)
validator = Evaluator(n_classes)

## Functions for Federated

In [13]:
def create_model_optimizer_criterion_dict(number_of_samples):
  '''
  builds one model, one SGD optimizer and one cross-entropy criterion per client.

  The entries are keyed "model<i>", "optimizer<i>" and "criterion<i>" for
  i = 0 .. number_of_samples-1. MODEL, device, learning_rate and momentum are
  read from the notebook-level globals.

  @param: number_of_samples  --> number of clients

  return: model_dict, optimizer_dict, criterion_dict
  '''
  model_dict, optimizer_dict, criterion_dict = {}, {}, {}

  for client_idx in range(number_of_samples):
    suffix = str(client_idx)

    # every client gets its own network instance on the training device
    client_model = get_model(model_name=MODEL)
    client_model.to(device)
    model_dict["model" + suffix] = client_model

    # the optimizer is bound to this client's parameters only
    optimizer_dict["optimizer" + suffix] = torch.optim.SGD(client_model.parameters(), lr=learning_rate, momentum=momentum)

    criterion_dict["criterion" + suffix] = nn.CrossEntropyLoss()

  return model_dict, optimizer_dict, criterion_dict

In [14]:
def send_main_model_to_nodes_and_update_model_dict(main_model, model_dict, number_of_samples):
  '''
  sends main_model to clients and updates clients model (model_dict).

  The first number_of_samples entries of model_dict (in insertion order) receive a
  copy of the server weights. The client names are taken from model_dict itself, so
  the function no longer depends on the notebook-level list name_of_models, which is
  only created in a later cell.

  @param: main_model --> the server model
  @param: model_dict --> dictionary of all client's models
  @param: number_of_samples --> number of clients

  return: model_dict: dictionary including client's model (same object that was passed in)

  raises: IndexError if number_of_samples is larger than the number of models in model_dict
  '''
  # e.g. ['model0', 'model1', ...]
  client_names = list(model_dict.keys())
  main_state = main_model.state_dict()

  # no_grad() is a context manager that disables gradient calculation;
  # copying weights must not be recorded by autograd.
  with torch.no_grad():
    for client_id in range(number_of_samples):
      # load_state_dict copies the server weights into the client's own tensors
      model_dict[client_names[client_id]].load_state_dict(main_state)
  return model_dict

In [15]:
def start_train_end_node_process_print_some(number_of_samples, print_amount, num_local_epochs=5):
  '''
  Warm-up round: trains every client model on its own data, one client after the
  other, and measures how fast each client's training loss changes.

  Client data, models, optimizers, criteria, batch_size, trainer and validator are
  read from the notebook-level globals (x_train_dict, model_dict, name_of_models, ...).

  @param: number_of_samples --> number of clients
  @param: print_amount --> progress is printed only for clients whose index is < print_amount
  @param: num_local_epochs --> local epochs per client (default 5, the value that used to be hard-coded)

  return: slopes --> dictionary {client index: mean slope of the client's batch losses}
          (the mean losses and the raw loss list are only printed, not returned)
  '''
  loss_dict = dict()
  slopes = dict()

  for i in range(number_of_samples):
    train_ds = TensorDataset(x_train_dict[name_of_x_train_sets[i]], y_train_dict[name_of_y_train_sets[i]])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    valid_ds = TensorDataset(x_valid_dict[name_of_x_valid_sets[i]], y_valid_dict[name_of_y_valid_sets[i]])
    valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=True)

    # the client test split is not evaluated in this round, so no test loader is built

    model = model_dict[name_of_models[i]]
    criterion = criterion_dict[name_of_criterions[i]]
    optimizer = optimizer_dict[name_of_optimizers[i]]

    verbose = i < print_amount
    if verbose:
      print("Subset", i)

    loss_lst_per_sample = []
    train_acc = []
    valid_acc = []

    for epoch in range(num_local_epochs):
      train_loss_batch_list, train_accuracy, train_accuracy_skl, train_sens, train_spec, train_f1 = trainer.train_with_batch_loss(model, train_dl, criterion, optimizer)
      valid_loss, valid_accuracy, valid_accuracy_skl, valid_sens, valid_spec, valid_f1 =  validator.validation(model, valid_dl, criterion)

      train_acc.append(train_accuracy)
      valid_acc.append(valid_accuracy)
      loss_lst_per_sample.append(train_loss_batch_list)

      if verbose:
        # NOTE: the value labelled "test accuracy" is the accuracy on the client's validation split
        print("epoch: {:3.0f}".format(epoch+1) + " | train accuracy: {:7.5f}".format(train_accuracy) + " | test accuracy: {:7.5f}".format(valid_accuracy))

    loss_lst_per_sample = list(chain.from_iterable(loss_lst_per_sample)) # flatten the list of batch losses

    loss_dict[i] = mean(loss_lst_per_sample)
    slopes[i] = calculate_slope(losses=loss_lst_per_sample) # mean-slope

    if verbose:
      # both numbers are averages over the local epochs (the validation value used to be the last epoch only)
      print("For all local epochs:" + " | train accuracy: {:7.5f}".format(mean(train_acc)) + " | test accuracy: {:7.5f}".format(mean(valid_acc)))
      print(f"loss_dict:{loss_dict}")
      print(f"loss_lst_per_client:{loss_lst_per_sample}")
      print(f"slope:{slopes}")

  return slopes

### Functions for Sequential Federated Averaging

In [16]:
# check data poisoning and model integrity
def calculate_integrity(model, baseline_model):
    '''
    L2 norm of the difference between the FIRST parameter tensor of model and the
    first parameter tensor of baseline_model (same as np.linalg.norm(v) on the difference).

    NOTE(review): only the first tensor returned by .parameters() is compared, not the
    whole model -- confirm that this is intended.

    @param: model --> model to check
    @param: baseline_model --> reference model with the same architecture

    return: the norm as a numpy float
    '''
    first_param = list(model.parameters())[0]
    first_baseline_param = list(baseline_model.parameters())[0]

    # .cpu() is needed before .numpy(): tensors that live on the GPU cannot be
    # converted to numpy directly (for CPU tensors .cpu() is a no-op)
    difference = first_param.detach().cpu().numpy() - first_baseline_param.detach().cpu().numpy()
    return np.linalg.norm(difference)

def verify_model_integrity(model, baseline_model, integrity_threshold=0.01, next_model=None):
    '''
    Compares model against baseline_model and, when the check passes, copies the
    weights of model into next_model.

    The check FAILS when the value from calculate_integrity is below integrity_threshold.
    NOTE(review): this treats a model that barely differs from the baseline as
    suspicious -- confirm that the direction of the comparison is the intended one.

    @param: model --> the client model that was just trained
    @param: baseline_model --> reference model used for the comparison
    @param: integrity_threshold --> minimum accepted distance to the baseline
    @param: next_model --> model that receives the weights when the check passes.
                           The old code read this name without defining it (NameError);
                           it is now an optional argument.

    return: next_model (unchanged when the check fails or when it is None)
    '''
    integrity = calculate_integrity(model, baseline_model)

    if integrity < integrity_threshold:
        print(f"Model integrity check failed. Integrity: {integrity}")
        return next_model

    print(f"Model integrity check passed. Integrity: {integrity}")
    if next_model is not None:
        # load_state_dict does not return the module, so return next_model explicitly
        next_model.load_state_dict(model.state_dict())
    return next_model

In [17]:
# for sequential training
def check_privacy_in_server_and_send_model_dict_to_next_clients(model, next_model):
  '''
  Hands the weights of the client that has just been trained (model) over to the
  model of the following client (next_model).

  The integrity check is currently disabled; the value of integrity_threshold still
  has to be discussed:
  verify_model_integrity(model, baseline_model, integrity_threshold=0.01)

  return: next_model, now holding a copy of model's parameters
  '''
  trained_state = model.state_dict()

  # copy the parameters without tracking the operation in autograd
  with torch.no_grad():
    next_model.load_state_dict(trained_state)

  return next_model

In [18]:
# model_save_dir = "checkpoints/"
# logger = Logger().get_logger(logger_name='Advertima logs')

In [19]:
def sequential_training(model_dict, number_of_samples, sorted_samples):
  '''
  train local models sequentially
  clients are visited in the order given by sorted_samples (sorted by slope, less to more, by the caller)
  and then we have this procedure: client1 --> server --> client2 --> server --> client3 --> ...

  Besides its arguments the function reads notebook-level globals: base_model, device,
  name_of_* lists, x_/y_ data dicts, criterion_dict, optimizer_dict, batch_size, numEpoch,
  patience_local, trainer, validator, test_dl_server, the three metric file handles and
  `iter` (the loop variable of the global training loop, written as the first column of every row).

  @param: model_dict --> dictionary of all client's models
  @param: number_of_samples --> number of clients
  @param: sorted_samples --> list of (client_id, slope) pairs; only the client_id is used

  return: (model, slopes)
          model --> the model of the last client that was trained
          slopes --> dictionary {client_id: mean slope of the batch losses of this call}
  '''

  # a dict to save the losses of epoch * batch for reordering
  slopes = dict()

  # NOTE: this is an alias of the global base_model, not a copy -- base_model itself is modified below
  next_model = base_model
  next_model.to(device)

  # Initialize weights and biases to zero in the new model
  # (they are overwritten with the first client's weights in the i == 0 branch)
  for param in next_model.parameters():
    nn.init.zeros_(param)

  # i: a counter on clients
  # j: real client id
  for i in range(number_of_samples):
    if (i==0):
      j = sorted_samples[i][0]

      # first client: the carrier model takes over this client's current weights
      next_model.load_state_dict(model_dict[name_of_models[j]].state_dict())

    else:
      j = sorted_samples[i][0]

      # later clients: start from the weights left behind by the previous client
      model_dict[name_of_models[j]].load_state_dict(next_model.state_dict())

    train_ds = TensorDataset(x_train_dict[name_of_x_train_sets[j]], y_train_dict[name_of_y_train_sets[j]])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    valid_ds = TensorDataset(x_valid_dict[name_of_x_valid_sets[j]], y_valid_dict[name_of_y_valid_sets[j]])
    valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=True)

    # NOTE(review): the client test loader is built but never used; the test below runs on test_dl_server
    test_ds = TensorDataset(x_test_dict[name_of_x_test_sets[j]], y_test_dict[name_of_y_test_sets[j]])
    test_dl = DataLoader(test_ds, batch_size= batch_size)

    model = model_dict[name_of_models[j]]
    criterion = criterion_dict[name_of_criterions[j]]
    optimizer = optimizer_dict[name_of_optimizers[j]]

    loss_lst_per_sample = []
    # early stopping patience; how long to wait after last time validation loss improved.

    # initialize the early_stopping object
    early_stopping_local = EarlyStopping(patience=patience_local, verbose=True)

    # state of the "slight improvement" stop rule below
    prev_loss = float('inf')
    clinet_stop_counter = 0

    for epoch in range(numEpoch):
      train_loss_batch_list, train_accuracy, train_accuracy_skl, train_sens, train_spec, train_f1 = trainer.train_with_batch_loss(model, train_dl, criterion, optimizer)
      valid_loss, valid_accuracy, valid_accuracy_skl, valid_sens, valid_spec, valid_f1 =  validator.validation(model, valid_dl, criterion)
      loss_lst_per_sample.append(train_loss_batch_list)

      # record information per epoch for each client training
      local_metrics_train.write(f"{iter}\t{j}\t{train_accuracy}\t{np.mean(train_loss_batch_list)}\t{train_sens}\t{train_spec}\t{train_f1}\n")
      local_metrics_validation.write(f"{iter}\t{j}\t{valid_accuracy}\t{valid_loss}\t{valid_sens}\t{valid_spec}\t{valid_f1}\n")

      # testing each hospital on the same server test data for generalizability
      test_loss, test_accuracy, test_accuracy_skl, test_sens, test_spec, test_f1 = validator.validation(model, test_dl_server, criterion)
      global_metrics_test.write(f"{iter}\t{j}\t{test_accuracy}\t{test_loss}\t{test_sens}\t{test_spec}\t{test_f1}\n")

      '''
      Here we are going to stop training if the valdation loss decresses a little bit. The general trend of our validation loss is "Descending".
      So, we compare the exsiting val_loss with the previous loss and if there were not a signigicant improvment, we stop the client training.
            '''
      # The counter grows whenever the loss did not rise but improved by less than 0.01.
      # It is never reset, and prev_loss only moves when the loss did not rise.
      if valid_loss <= prev_loss:
        if prev_loss - valid_loss < 0.01:
          print("prev_loss - valid_loss= ", prev_loss - valid_loss)
          clinet_stop_counter += 1

        if clinet_stop_counter >= patience_local:
            print(f"Stoping Client {j}")
            break

        prev_loss = valid_loss

      # second, independent stop rule: EarlyStopping gets the validation loss and the model
      # (per the original note it checkpoints the model when the loss has decreased)
      early_stopping_local(valid_loss, model)
      if early_stopping_local.early_stop:
        print("Early Stopping Local")
        break

    loss_lst_per_sample = list(chain.from_iterable(loss_lst_per_sample)) # flatten the list of batch losses
    slopes[j] = calculate_slope(losses=loss_lst_per_sample) # mean-slope

    # hand the trained weights over to the carrier model for the next client
    next_model = check_privacy_in_server_and_send_model_dict_to_next_clients(model, next_model)
    print(f"privacy checked and client {j} transfered its weight to server and server transfered it to next client")

  # NOTE(review): `model` is unbound when number_of_samples == 0
  return model, slopes

In [20]:
# Function to calculate the slope of the loss function
def calculate_slope(losses):
  '''
  calculate the slope of the loss function

  The slope is the mean of the differences between consecutive loss values,
  i.e. a negative value means that the loss went down on average.

  @param: losses --> sequence of loss values in the order they were recorded

  return: mean slope; 0.0 when fewer than two values are given (there is no
          difference to average, and a NaN would break the ordering of the clients)
  '''
  if len(losses) < 2:
    return 0.0

  slopes = np.diff(losses)
  mean_slope = np.mean(slopes)
  return mean_slope

In [21]:
def reorder_samples_wise_task_complexity(slopes):
  '''
  ranks the clients by the slope of their loss, smallest slope first;
  clients with the same slope are ordered by client id.

  @param: slopes --> dictionary of client id to the client slope (i.e., {c1: s1})

  return: sorted_clients --> list of (client_id, slope) pairs, e.g. [(c0, s0), (c1, s1)]
  '''
  def slope_then_client(pair):
    client_id, slope = pair
    return (slope, client_id)

  # For ordering
  sorted_clients = list(slopes.items())
  sorted_clients.sort(key=slope_then_client)

  # Alternatives that were used for comparison runs:
  #   no ordering:     sorted_clients = sorted(slopes.items())
  #   random ordering: sorted_clients = sorted(slopes.items()); random.shuffle(sorted_clients)

  return sorted_clients

In [22]:
def server_training(main_model, main_optimizer, main_criterion):
  '''
  trains the server model on the server's own (mixed) data for at most numEpoch epochs.

  One row per epoch is written to global_metrics_train_server and
  global_metrics_valid_server; the first column is the global `iter` of the outer
  training loop. Training ends early when either stop rule below fires.
  Both rules use patience_local.

  return: main_model, main_optimizer, main_criterion (the objects that were passed in)
  '''
  stopper = EarlyStopping(patience=patience_local, verbose=True)

  # state of the "slight improvement" rule
  last_accepted_loss = float('inf')
  small_gain_epochs = 0

  for epoch in range(numEpoch):
    central_train_loss, central_train_accuracy, central_train_accuracy_skl, central_train_sens, central_train_spec, central_train_f1 = trainer.train(
        main_model, train_dl_server, main_criterion, main_optimizer)
    central_valid_loss, central_valid_accuracy, central_valid_accuracy_skl, central_valid_sens, central_valid_spec, central_valid_f1 = validator.validation(
        main_model, valid_dl_server, main_criterion)

    global_metrics_train_server.write(f"{iter}\t{epoch}\t{central_train_accuracy}\t{central_train_loss}\t{central_train_sens}\t{central_train_spec}\t{central_train_f1}\n")
    global_metrics_valid_server.write(f"{iter}\t{epoch}\t{central_valid_accuracy}\t{central_valid_loss}\t{central_valid_sens}\t{central_valid_spec}\t{central_valid_f1}\n")

    # Rule 1: the validation loss is expected to go down. An epoch whose loss did not rise
    # but gained less than 0.01 counts as "no significant improvement"; after
    # patience_local such epochs the server stops.
    if central_valid_loss <= last_accepted_loss:
      gain = last_accepted_loss - central_valid_loss
      if gain < 0.01:
        print("prev_loss - valid_loss= ", gain)
        small_gain_epochs += 1

      if small_gain_epochs >= patience_local:
        print(f"Stoping server")
        break

      last_accepted_loss = central_valid_loss

    # Rule 2: EarlyStopping receives the validation loss and the model and raises its
    # early_stop flag once its own patience is used up
    stopper(central_valid_loss, main_model)
    if stopper.early_stop:
      print("Early Stopping server")
      break

  # summary of the last epoch that ran (the value labelled "test accuracy" is the validation accuracy)
  summary = "epoch: {:3.0f}".format(epoch+1) + " | train accuracy: {:7.4f}".format(central_train_accuracy) + " | test accuracy: {:7.4f}".format(central_valid_accuracy)
  print(summary)
  return main_model, main_optimizer, main_criterion

In [23]:
def update_accumulation(main_model, model_seq, alpha=None):
    '''
    Element-wise weighted average of the server model and the sequentially trained model.

    Every parameter tensor of main_model is replaced by
        alpha * main_param + (1 - alpha) * model_seq_param
    Only tensors returned by .parameters() are mixed; model_seq is not modified.

    Example for one weight with alpha = 0.5:
        main_model: -0.0799, model_seq: 0.0295  -->  updated main_model: -0.0252

    @param: main_model --> the server model (updated in place)
    @param: model_seq --> the model that results from sequential training
    @param: alpha --> weight of the server model; when None the notebook-level
                      mixing weight `a` is used (previous behaviour)

    return: main_model
    '''
    if alpha is None:
      alpha = a

    for (main_params, model_seq_params) in zip(main_model.parameters(), model_seq.parameters()):
      main_params.data = (alpha * main_params.data) + ((1-alpha) * model_seq_params.data)
    return main_model

In [24]:
def print_model_weights(model):
    '''
    Debug helper: prints the entry at index 1 of the first tensor returned by
    model.parameters().

    NOTE(review): the printed caption says "last layer", but the tensor that is
    indexed is the first parameter of the model.
    '''
    all_params = list(model.parameters())
    first_tensor = all_params[0]

    print("\n The last layer of main_model")
    print(first_tensor[1])

In [25]:
# Experiment settings. All names defined here are read as globals by the functions
# above and are also part of the result/checkpoint file names.

# MODEL
# optimizer settings for the SGD optimizers of the clients and the server
learning_rate = configs['model']['learning_rate']
batch_size = configs['model']['batch_size']
momentum = configs['model']['momentum']

# maximum number of local epochs per client / server training call
numEpoch = configs['model']['num_epoch']
# numEpoch = 5

# number of global rounds (sequential training + server training + accumulation)
NUM_ITERATION = configs['model']['num_iteration']
# NUM_ITERATION = 200
# MODEL = ['simcnn', 'cnn5', 'vgg11', 'resnet18']
MODEL = 'simcnn'

# STOP = ["wes","woes"]
# NOTE(review): in this notebook STOP is only used as a tag in file names
STOP = "woes"
# For "wes" we have:
# early stopping patience; how long to wait after last time validation loss improved.
patience_local = 7 # wes=7  woes=7
patience_global = 200 # wes=10  woes=200

# tag for the client stop rule; only used in file names in this notebook
STOP_CLIENTS="slightimprove"

# DATASET
# n_classes = 23
n_channels = configs['dataset']['num_channels']

# 'seq-worder': seq using our ordering method
#
# STRATEGY = ['baseline', 'fedavg', 'seq-noorder', 'seq-worder',]
STRATEGY = "seq-worder"

# SPILIT = ['random', 'stratified']
# (the variable name SPILT is misspelled but is referenced under this name by the file-name templates)
SPILT = "stratified"


# Status = 'withclass'
# num_client: is the number of all clients (used in the result file names)
num_client = 3*8 #24
# number_of_samples: number of clients that are created and trained (3*8 = 24)
number_of_samples= 3*8 #24
a = 0.7 #the mixing weight in taking average server parameters and the last client's parameters in sequential training
# progress of the warm-up round is printed for the first print_amount clients only
print_amount=1

In [None]:
def create_file():
    '''
    Opens the eight result files of a run under
    results/<SAMPLING>/<STRATEGY>/<MODEL>/<SPILT>/ and writes the header row of each.

    NOTE: a later cell defines create_file() again (results/a/... variant with the
    number of clients in the file name); when both cells are run, the later definition wins.

    The target folder is created when it is missing, and files that were already opened
    are closed again if one of the remaining files cannot be opened.

    return: (local_metrics_train, local_metrics_validation, global_metrics_test,
             global_metrics_train_server, global_metrics_valid_server,
             accumulate_metrics_valid, final_metrics_test, elapsed_time)
            -- open text files in write mode; the caller has to close them
    '''
    out_dir = f"results/{SAMPLING}/{STRATEGY}/{MODEL}/{SPILT}"
    # open(..., 'w') does not create missing folders
    os.makedirs(out_dir, exist_ok=True)

    # part of the file name shared by all files of this run
    run_tag = f"iter{NUM_ITERATION}_epoch{numEpoch}_lr{learning_rate}_bs{batch_size}_ds{amount}_m{momentum}_a{a}_{STOP}_{STOP_CLIENTS}_pl{patience_local}_pg{patience_global}"

    # (file name prefix, header row) in the order of the returned tuple
    file_specs = [
        ("local_train", "iteration\tclient_id\ttrain_accuracy\ttrain_loss\ttrain_sensitivity\ttrain_specificity\ttrain_f1\n"),
        ("local_valid", "iteration\tclient_id\tvalid_accuracy\tvalid_loss\tvalid_sensitivity\tvalid_specificity\tvalid_f1\n"),
        ("global_test", "iteration\tclient_id\ttest_accuracy\ttest_loss\ttest_sensitivity\ttest_specificity\ttest_f1\n"),
        # server
        ("global_train_server", "iteration\tepoch\ttrain_accuracy\ttrain_loss\ttrain_sensitivity\ttrain_specificity\ttrain_f1\n"),
        ("global_valid_server", "iteration\tepoch\tvalid_accuracy\tvalid_loss\tvalid_sensitivity\tvalid_specificity\tvalid_f1\n"),
        ("accumulate_valid", "iteration\tvalid_accuracy\tvalid_loss\tvalid_sensitivity\tvalid_specificity\tvalid_f1\n"),
        ("final_test", "iteration\ttest_accuracy\ttest_loss\ttest_sensitivity\ttest_specificity\ttest_f1\n"),
        (f"{STRATEGY}_elapsed_time", "elapsed_time\n"),
    ]

    opened_files = []
    try:
        for prefix, header in file_specs:
            metrics_file = open(f"{out_dir}/{prefix}_{run_tag}.txt", 'w')
            opened_files.append(metrics_file)
            metrics_file.write(header)
    except Exception:
        # do not leak the handles that were opened before the failure
        for metrics_file in opened_files:
            metrics_file.close()
        raise

    return tuple(opened_files)

## Create these files for evaluating scalability / the `a` parameter across `num_client` clients

In [26]:
def create_file():
    '''
    Opens the eight result files of a run under
    results/a/<SAMPLING>/<STRATEGY>/<MODEL>/<SPILT>/ and writes the header row of each.
    The file names carry the mixing weight (a...) and the number of clients (nc...).

    The target folder is created when it is missing, and files that were already opened
    are closed again if one of the remaining files cannot be opened.

    return: (local_metrics_train, local_metrics_validation, global_metrics_test,
             global_metrics_train_server, global_metrics_valid_server,
             accumulate_metrics_valid, final_metrics_test, elapsed_time)
            -- open text files in write mode; the caller has to close them
    '''
    out_dir = f"results/a/{SAMPLING}/{STRATEGY}/{MODEL}/{SPILT}"
    # open(..., 'w') does not create missing folders
    os.makedirs(out_dir, exist_ok=True)

    # part of the file name shared by all files of this run
    run_tag = f"iter{NUM_ITERATION}_epoch{numEpoch}_lr{learning_rate}_bs{batch_size}_ds{amount}_m{momentum}_a{a}_nc{num_client}_{STOP}_{STOP_CLIENTS}_pl{patience_local}_pg{patience_global}"

    # (file name prefix, header row) in the order of the returned tuple
    file_specs = [
        ("local_train", "iteration\tclient_id\ttrain_accuracy\ttrain_loss\ttrain_sensitivity\ttrain_specificity\ttrain_f1\n"),
        ("local_valid", "iteration\tclient_id\tvalid_accuracy\tvalid_loss\tvalid_sensitivity\tvalid_specificity\tvalid_f1\n"),
        ("global_test", "iteration\tclient_id\ttest_accuracy\ttest_loss\ttest_sensitivity\ttest_specificity\ttest_f1\n"),
        # server
        ("global_train_server", "iteration\tepoch\ttrain_accuracy\ttrain_loss\ttrain_sensitivity\ttrain_specificity\ttrain_f1\n"),
        ("global_valid_server", "iteration\tepoch\tvalid_accuracy\tvalid_loss\tvalid_sensitivity\tvalid_specificity\tvalid_f1\n"),
        ("accumulate_valid", "iteration\tvalid_accuracy\tvalid_loss\tvalid_sensitivity\tvalid_specificity\tvalid_f1\n"),
        ("final_test", "iteration\ttest_accuracy\ttest_loss\ttest_sensitivity\ttest_specificity\ttest_f1\n"),
        (f"{STRATEGY}_elapsed_time", "elapsed_time\n"),
    ]

    opened_files = []
    try:
        for prefix, header in file_specs:
            metrics_file = open(f"{out_dir}/{prefix}_{run_tag}.txt", 'w')
            opened_files.append(metrics_file)
            metrics_file.write(header)
    except Exception:
        # do not leak the handles that were opened before the failure
        for metrics_file in opened_files:
            metrics_file.close()
        raise

    return tuple(opened_files)

In [27]:
local_metrics_train, local_metrics_validation, global_metrics_test, global_metrics_train_server, global_metrics_valid_server, accumulate_metrics_valid, final_metrics_test, elapsed_time = create_file()

### Parameter Initialization: Models, Optimizers, and Loss Functions

In [28]:
base_model = get_model(model_name=MODEL)
base_model.to(device)
# base_model (to print the architecture of the model)

SimCNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=1600, out_features=2048, bias=True)
  (output): Linear(in_features=2048, out_features=23, bias=True)
)

In [29]:
model_dict, optimizer_dict, criterion_dict = create_model_optimizer_criterion_dict(number_of_samples)

In [30]:
name_of_x_train_sets = list(x_train_dict.keys())
name_of_y_train_sets = list(y_train_dict.keys())

name_of_x_valid_sets = list(x_valid_dict.keys())
name_of_y_valid_sets = list(y_valid_dict.keys())

name_of_x_test_sets = list(x_test_dict.keys())
name_of_y_test_sets = list(y_test_dict.keys())

# model_dcit = {'model_1': model_info_1, ..}
# names_of_models = ['model_1', 'model_2', ...]
name_of_models = list(model_dict.keys())
name_of_optimizers=list(optimizer_dict.keys())
name_of_criterions=list(criterion_dict.keys())

In [None]:
# Federated learning with sequential learning strategy based on the slope of the loss function
print("Model is ", MODEL)
print("lr is ", learning_rate)
print("Num iteration is ", NUM_ITERATION)
print("Num epochs is ", numEpoch)
print("Stop", STOP)
print("patience_local is", patience_local)
print("patience_global is", patience_global)


strat_training_time = time.time()
checkpoint_path_global = f"checkpoints/{SAMPLING}_{STRATEGY}_{MODEL}_{SPILT}_iter{NUM_ITERATION}_epoch{numEpoch}_lr{learning_rate}_bs{batch_size}_ds{amount}_m{momentum}_{STOP}_{STOP_CLIENTS}_Global.pt"

# The metric files were opened by create_file(); try/finally makes sure they are flushed
# and closed even when training raises or the cell is interrupted.
try:
  # 0: creat main model
  main_model = get_model(model_name=MODEL)
  main_model.to(device)
  main_optimizer = torch.optim.SGD(main_model.parameters(), lr=learning_rate, momentum=momentum)
  main_criterion = nn.CrossEntropyLoss()

  # 1: First training run
  # every client starts from the server weights; the slopes of this warm-up round
  # give the initial visiting order
  print("First training round is started ...")
  model_dict = send_main_model_to_nodes_and_update_model_dict(main_model, model_dict, number_of_samples)
  slopes_seq = start_train_end_node_process_print_some(number_of_samples, print_amount)
  sorted_clients = reorder_samples_wise_task_complexity(slopes_seq)
  print(f"samples have been sorted by their slope values: {sorted_clients}")
  print("First training round is finished!")

  # initialize the early_stopping object
  early_stopping_global = EarlyStopping(patience=patience_global, verbose=True, path=checkpoint_path_global)

  # NOTE: the loop variable must keep the name `iter` (although it shadows the builtin):
  # sequential_training() and server_training() read it as a global for their log rows.
  for iter in range(NUM_ITERATION):
    print("global epoch", iter)

    # 2: Sequential training
    model_seq, slopes = sequential_training(model_dict, number_of_samples, sorted_clients)

    # 3: Server training
    main_model, main_optimizer, main_criterion = server_training(main_model, main_optimizer, main_criterion)

    # 4: Update accumulation
    main_model = update_accumulation(main_model, model_seq)

    valid_loss_accumulation, valid_accuracy_accumulation, valid_accuracy_skl_accumulation, valid_sens_accumulation, valid_spec_accumulation, valid_f1_accumulation = validator.validation(main_model, valid_dl_server, main_criterion)
    accumulate_metrics_valid.write(f"{iter}\t{valid_accuracy_accumulation}\t{valid_loss_accumulation}\t{valid_sens_accumulation}\t{valid_spec_accumulation}\t{valid_f1_accumulation}\n")

    early_stopping_global(valid_loss_accumulation, main_model)
    if early_stopping_global.early_stop:
      print("Early Stopping Global")
      break

    # 5: Reordering (ranking hospitals/clients)
    sorted_clients = reorder_samples_wise_task_complexity(slopes)

  # final evaluation of the server model on the server test data
  test_loss, test_accuracy, test_accuracy_skl, test_sens, test_spec, test_f1= validator.validation(main_model, test_dl_server, main_criterion)
  print("Iteration", str(iter), ": final main_model accuracy on server test data after accumulation: {:7.4f}".format(test_accuracy))
  final_metrics_test.write(f"{iter}\t{test_accuracy}\t{test_loss}\t{test_sens}\t{test_spec}\t{test_f1}\n")


  end_training_time = time.time()
  elapsed_training_time = end_training_time - strat_training_time
  print("elapsed_training_time:", elapsed_training_time)
  elapsed_time.write(f"{elapsed_training_time}\n")
finally:
  for metrics_file in (local_metrics_train, local_metrics_validation, global_metrics_test,
                       global_metrics_train_server, global_metrics_valid_server,
                       accumulate_metrics_valid, final_metrics_test, elapsed_time):
    metrics_file.close()