# Baseline (Centralised)

In [None]:
! pip install medmnist
! pip3 install pyyaml

In [None]:
# Mount Google Drive at /content/drive (Colab only); the project folder used
# by the following cell lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd /content/drive/MyDrive/FederatedLearning/MICCAI2024/

In [None]:
# Import libraries
import medmnist
from medmnist import INFO

from tqdm import tqdm
import collections
import time
import numpy as np
import random
import os
from statistics import mean
import PIL
from itertools import chain
import yaml

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torchvision.transforms as transforms

# Models
from models.CNN import CNN5
from models.VGG11 import VGG11
from models.ResNet18 import ResNet_18
from models.SimCNN import SimCNN
from models.mobilenetv2 import MyMobileNet_v2

# Train and Test/validation
from utils.Training import Train
from utils.Evaluation import Evaluator

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# import EarlyStopping
from utils.pytorchtools import EarlyStopping

# making code reproducible
from utils.seedeverything import seedevrything

In [None]:
# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
print('Using device:', device)

In [None]:
# Fix the random seed (42) through the project's seedevrything helper.
# Which RNGs it covers is defined in utils/seedeverything, not in this notebook.
seedevrything.set_seed(42)

## Preparing Dataset including Server and Clients

In [None]:
# Show the installed MedMNIST version and the project's homepage URL.
print(f"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}")

In [None]:
# Display the MedMNIST metadata dictionary, keyed by dataset flag
# (the full mapping is shown, so the output is long).
medmnist.INFO

## Load Configs

In [None]:
# Parse the experiment settings from configs.yaml.
# The file handle gets its own name so `configs` only ever refers to the
# parsed settings.
with open("configs.yaml", 'r') as config_file:
    configs = yaml.safe_load(config_file)

In [None]:
# Datasets used in this experiment, read from the `dataset.data_flags` entry
# of the config. Hard-coded alternatives are kept below for reference.
# data_flags = ['octmnist', 'organamnist', 'tissuemnist']
data_flags = configs['dataset']['data_flags']
# data_flags = ['octmnist']
# data_flags = ['organamnist']
# data_flags = ['tissuemnist']

In [None]:
# Total number of classes across the selected datasets: the sizes of the
# per-dataset label dictionaries are added together.
n_classes = sum(len(medmnist.INFO[flag]['label']) for flag in data_flags)
print(n_classes)

In [None]:
def get_transform():
  '''
  Build the image preprocessing pipeline.

  Returns a torchvision ``Compose`` that resizes to 224x224 with
  nearest-neighbour interpolation, converts the image to a tensor and
  normalises it with mean 0.5 and std 0.5.
  '''
  steps = [
      transforms.Resize((224, 224), interpolation=PIL.Image.NEAREST),
      transforms.ToTensor(),
      transforms.Normalize(mean=[.5], std=[.5]),
  ]
  return transforms.Compose(steps)

In [None]:
def get_dataset(dataset_name='pathmnist', split='train'):
  '''
  Load one split of a MedMNIST dataset.

  dataset_name: MedMNIST data flag used to look up the metadata in INFO.
  split: split name forwarded to the dataset class.

  Returns a ``(dataset, info)`` tuple: the instantiated dataset (downloaded if
  needed, with the ``get_transform()`` pipeline attached) and its INFO entry.
  '''
  info = INFO[dataset_name]
  # The metadata names the class in the medmnist package that implements this dataset.
  dataset_cls = getattr(medmnist, info['python_class'])
  preprocessing = get_transform()
  return dataset_cls(split=split, transform=preprocessing, download=True), info

In [None]:
# Load the training split of every selected dataset, keyed by its data flag.
# Each value is the (dataset, info) pair returned by get_dataset.
dataflag_to_dataset = {
    data_flag: get_dataset(dataset_name=data_flag, split='train')
    for data_flag in data_flags
}

In [None]:
# info = INFO['octmnist']
# n_channels = info['n_channels']
# print("n_channels:", n_channels)

In [None]:
# https://stackoverflow.com/questions/28663856/how-do-i-count-the-occurrence-of-a-certain-item-in-an-ndarray
def print_unique_count(new_list):
    """Print every distinct value of ``new_list`` followed by its occurrence count.

    new_list: a list of hashable values.

    One line is printed per distinct value, in the iteration order of
    ``set(new_list)``. Nothing is printed for an empty list.
    """
    # Count everything in one pass instead of calling list.count() once per
    # distinct value, which rescans the whole list each time.
    counts = collections.Counter(new_list)
    for value in set(new_list):
        print(value, counts[value])

def np_print_unique_count(np_ndarray_list):
    """Print the distinct values of the input, then the full value -> count tally.

    np_ndarray_list: any iterable of hashable values accepted by
    ``collections.Counter``.
    """
    tally = collections.Counter(np_ndarray_list)
    for view in (tally.keys(), tally):
        print(view)

In [None]:
# Dataset parameter configuration
amount = configs['dataset']['amount']  # number of leading samples taken from each dataset
shift = 0  # running label offset applied per dataset when the label spaces are merged
total_labels = 0  # running count of distinct labels seen across the datasets
test_ratio = configs['dataset']['test_ratio']  # fraction of each dataset held out for testing
valid_ratio = configs['dataset']['valid_ratio']  # fraction of each dataset held out for validation

In [None]:
# Build stratified train/validation/test splits per dataset and merge them
# into one multi-dataset problem with a shared label space.
#
# `shift` and `total_labels` are also initialised in the configuration cell;
# resetting them here keeps this cell idempotent, so re-running it on its own
# does not push the labels further and further up.
shift = 0
total_labels = 0

X_train_set, Y_train_set = [], []
X_valid_set, Y_valid_set = [], []
X_test_set, Y_test_set = [], []

for data_flag, (data, info) in dataflag_to_dataset.items():
  # Keep only the first `amount` images and their labels (flattened to 1-D).
  images = data.imgs[0: amount]
  labels = list(np.ravel(data.labels[0: amount]))
  total_labels += len(set(labels))

  # Offset this dataset's labels so they do not collide with the labels of
  # the datasets processed before it, then advance the offset by the number
  # of classes this dataset declares.
  labels = [label + shift for label in labels]
  shift += len(info['label'].keys())

  # First carve out the test set, then split the remainder into train and
  # validation. valid_ratio is expressed relative to the full dataset, hence
  # the rescaling by (1 - test_ratio).
  X_trainvalid_data, X_test, y_trainvalid, y_test = train_test_split(
      images, labels, test_size=test_ratio, random_state=42, stratify=labels)
  X_train, X_valid, y_train, y_valid = train_test_split(
      X_trainvalid_data, y_trainvalid, test_size=valid_ratio / (1 - test_ratio),
      random_state=42, stratify=y_trainvalid)

  X_train_set.append(X_train)
  Y_train_set.append(y_train)
  X_valid_set.append(X_valid)
  Y_valid_set.append(y_valid)
  X_test_set.append(X_test)
  Y_test_set.append(y_test)

# Stack the per-dataset pieces into single arrays.
X_train_set = np.concatenate(X_train_set, axis=0)
Y_train_set = np.concatenate(Y_train_set, axis=0)
X_valid_set = np.concatenate(X_valid_set, axis=0)
Y_valid_set = np.concatenate(Y_valid_set, axis=0)
X_test_set = np.concatenate(X_test_set, axis=0)
Y_test_set = np.concatenate(Y_test_set, axis=0)

In [None]:
# Class distribution of the merged training labels.
np_print_unique_count(Y_train_set)

In [None]:
# Keep the pixel data as 8-bit integers instead of casting to float32.
# The images are later passed through torchvision's ToPILImage, which
# multiplies floating-point tensors by 255 before converting them to bytes;
# float pixels in the 0-255 range would overflow there and corrupt the images.
# NOTE(review): assumes the MedMNIST arrays hold 0-255 pixel values — confirm.
X_train_set = X_train_set.astype('uint8')
X_test_set = X_test_set.astype('uint8')
X_valid_set = X_valid_set.astype('uint8')

# Add a channel axis to the images (N, H, W) -> (N, 1, H, W) and turn the
# label vectors into column vectors (N,) -> (N, 1).
X_train_set = X_train_set[:, None, :, :]
Y_train_set = Y_train_set[:, None]

X_valid_set = X_valid_set[:, None, :, :]
Y_valid_set = Y_valid_set[:, None]

X_test_set = X_test_set[:, None, :, :]
Y_test_set = Y_test_set[:, None]

# Convert every array to a torch tensor.
X_train_set, Y_train_set, X_valid_set, Y_valid_set, X_test_set, Y_test_set = map(
    torch.tensor,
    (X_train_set, Y_train_set, X_valid_set, Y_valid_set, X_test_set, Y_test_set))

# Transfer data from CPU to GPU
X_train_set = X_train_set.to(device)
Y_train_set = Y_train_set.to(device)
X_valid_set = X_valid_set.to(device)
Y_valid_set = Y_valid_set.to(device)
X_test_set = X_test_set.to(device)
Y_test_set = Y_test_set.to(device)

In [None]:
# Define a custom transformation function to resize the images
def resize_transform(image_tensor, new_height, new_width):
  '''
  Resize one image tensor to (new_height, new_width).

  The tensor is converted to a PIL image, resized with torchvision's Resize
  (default interpolation) and converted back to a tensor with ToTensor.

  NOTE(review): ToPILImage rescales floating-point tensors by 255 — confirm
  callers pass uint8 data or floats in [0, 1].
  '''
  pipeline = transforms.Compose([
      transforms.ToPILImage(),
      transforms.Resize((new_height, new_width)),
      transforms.ToTensor(),
  ])
  return pipeline(image_tensor)

In [None]:
server_batch_size = configs['model']['barch_size']


def _resize_all(images):
  # Resize every image to 32x32, one at a time; returns a list of tensors.
  return [resize_transform(image, 32, 32) for image in images]


# Apply the custom transformation to the entire dataset and pair the resized
# images with their labels, split by split.
resized_images_train = _resize_all(X_train_set)
train_ds_server = TensorDataset(torch.stack(resized_images_train), Y_train_set)

resized_images_valid = _resize_all(X_valid_set)
valid_ds_server = TensorDataset(torch.stack(resized_images_valid), Y_valid_set)

resized_images_test = _resize_all(X_test_set)
test_ds_server = TensorDataset(torch.stack(resized_images_test), Y_test_set)

# Batched loaders: train and validation are shuffled, test keeps its order.
train_dl_server = DataLoader(train_ds_server, shuffle=True, batch_size=server_batch_size)
valid_dl_server = DataLoader(valid_ds_server, shuffle=True, batch_size=server_batch_size)
test_dl_server = DataLoader(test_ds_server, batch_size=server_batch_size)

## Classification Models

In [None]:
# Bottleneck configuration handed to MyMobileNet_v2, one tuple per stage:
# (expansion, out_dimension, number_of_times, stride)
bottleneckLayerDetails = [
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
]

In [None]:
def get_model(model_name):
  '''
  Instantiate the classification network selected by ``model_name``.

  Supported names: 'cnn5', 'vgg11', 'simcnn', 'mobilenetv2', 'resnet18'.
  Every model except 'mobilenetv2' is built with ``num_classes=n_classes``
  (module-level value); 'mobilenetv2' is built from the module-level
  ``bottleneckLayerDetails`` with ``width_multiplier=1``.

  Returns a freshly constructed model (not yet moved to any device).

  Raises:
    ValueError: if ``model_name`` is not one of the supported names. Without
      this check the function would return None and the failure would only
      surface later, e.g. on ``model.to(device)``.
  '''
  if model_name == 'cnn5':
    return CNN5(num_classes=n_classes)
  if model_name == 'vgg11':
    return VGG11(num_classes=n_classes)
  if model_name == 'simcnn':
    return SimCNN(num_classes=n_classes)
  if model_name == 'mobilenetv2':
    # NOTE(review): n_classes is not passed here — confirm that
    # MyMobileNet_v2's default output size matches the merged label space.
    return MyMobileNet_v2(bottleneckLayerDetails, width_multiplier=1)
  if model_name == 'resnet18':
    return ResNet_18(num_classes=n_classes)

  raise ValueError(
      f"Unknown model name {model_name!r}; expected one of "
      "'cnn5', 'vgg11', 'simcnn', 'mobilenetv2', 'resnet18'")

In [None]:
def create_file():
    """Open the result log files for the current run and write their headers.

    File names are built from the module-level experiment settings
    (STRATEGY, MODEL, SPILT, NUM_ITERATION, learning_rate, batch_size, amount,
    momentum, STOP). The output directory ``results/<STRATEGY>/<MODEL>/<SPILT>``
    is created if it does not exist yet, so a fresh checkout does not fail
    with FileNotFoundError.

    Returns:
        A 4-tuple of open, writable text files in this order: train metrics,
        validation metrics, test metrics, elapsed time. The caller is
        responsible for closing them.

    Raises:
        OSError: if the directory or any file cannot be created; files that
        were already opened are closed before the error propagates.
    """
    out_dir = f"results/{STRATEGY}/{MODEL}/{SPILT}"
    os.makedirs(out_dir, exist_ok=True)

    # Common tail shared by all four file names.
    run_tag = f"iter{NUM_ITERATION}_lr{learning_rate}_bs{batch_size}_ds{amount}_m{momentum}_{STOP}.txt"

    # (file-name part, header line) in the order the handles are returned.
    specs = [
        ("train", "iteration\ttrain_accuracy\ttrain_loss\ttrain_sensitivity\ttrain_specificity\ttrain_f1\n"),
        ("validation", "iteration\tvalid_accuracy\tvalid_loss\tvalid_sensitivity\tvalid_specificity\tvalid_f1\n"),
        ("test", "iteration\ttest_accuracy\ttest_loss\ttest_sensitivity\ttest_specificity\ttest_f1\n"),
        ("elapsed_time", "elapsed_time\n"),
    ]

    handles = []
    try:
        for kind, header in specs:
            handle = open(f"{out_dir}/{STRATEGY}_{kind}_{run_tag}", 'w')
            handles.append(handle)
            handle.write(header)
    except Exception:
        # Do not leak the files that were opened before the failure.
        for handle in handles:
            handle.close()
        raise

    baseline_metrics_train, baseline_metrics_validation, baseline_metrics_test, elapsed_time = handles
    return baseline_metrics_train, baseline_metrics_validation, baseline_metrics_test, elapsed_time

### Parameter Initialization: Models, Optimizers, and Loss Functions

In [None]:
# --- Optimisation hyper-parameters, read from the loaded config ---
learning_rate = configs['model']['learning_rate']
batch_size = configs['model']['barch_size']
momentum = configs['model']['momentum']
NUM_ITERATION = configs['model']['num_iteration']

# --- Dataset ---
n_classes = n_classes  # re-binding of the class count computed earlier in the notebook
n_channels = configs['dataset']['num_channels']

# --- Experiment labels (they also appear in result and checkpoint file names) ---
# Architecture passed to get_model: 'cnn5', 'simcnn', 'vgg11', 'resnet18' or 'mobilenetv2'
MODEL = 'simcnn'

# STOP = ["wes","woes"]
STOP = "woes"

# STRATEGY = ['baseline', 'fedavg', 'seq']
STRATEGY = "baseline"

# SPILT = ['random', 'stratified']
SPILT = "stratified"

# TASK = ["singletask", "multitask"]
# TASK = "multitask"

In [None]:
# Instantiate the selected architecture and move it to the compute device.
# NOTE(review): the training cell builds its own `main_model`; `base_model`
# is not referenced anywhere else in the notebook as shown — confirm it is needed.
base_model = get_model(model_name=MODEL)
base_model.to(device)
# base_model (to print the architecture of the model)

In [None]:
# Open the result log files for this run (train/validation/test metrics and
# elapsed time); they are written to and closed by the training cell.
baseline_metrics_train, baseline_metrics_validation, baseline_metrics_test, elapsed_time = create_file()

In [None]:
print(f"Model is {MODEL}")

checkpoint_path = f"checkpoints/{STRATEGY}_{MODEL}_{SPILT}_iter{NUM_ITERATION}_lr{learning_rate}_bs{batch_size}_ds{amount}_m{momentum}_{STOP}.pt"

start_server_training_time = time.time()

# train server on mixed data.
main_model = get_model(model_name = MODEL)
main_model.to(device)
# Use the configured momentum so the optimiser matches the `m{momentum}` tag
# that is written into the result and checkpoint file names.
main_optimizer = torch.optim.SGD(main_model.parameters(), lr=learning_rate, momentum=momentum)
main_criterion = nn.CrossEntropyLoss()

# early stopping patience; how long to wait after last time validation loss improved.
patience = 50
# initialize the early_stopping object
# NOTE(review): early stopping is active regardless of the STOP setting — confirm this is intended.
early_stopping = EarlyStopping(patience=patience, verbose=True, path=checkpoint_path)

trainer = Train(n_classes)
validator = Evaluator(n_classes)

# Index of the last epoch that ran (0-based); stays -1 if NUM_ITERATION is 0.
epoch = -1

try:
  for epoch in range(NUM_ITERATION):
    print("-------Baseline result----------")
    # model training
    central_train_loss, central_train_accuracy, central_train_accuracy_skl, central_train_sens, central_train_spec, central_train_f1 = trainer.train(main_model, train_dl_server, main_criterion, main_optimizer)
    # model validation
    central_valid_loss, central_valid_accuracy, central_valid_accuracy_skl, central_valid_sens, central_valid_spec, central_valid_f1 = validator.validation(main_model, valid_dl_server, main_criterion)

    print("epoch: {:3.0f}".format(epoch+1) + " | train accuracy: {:7.4f}".format(central_train_accuracy) +  " | valid accuracy: {:7.4f}".format(central_valid_accuracy))
    print("epoch: {:3.0f}".format(epoch+1) + " | train loss: {:7.4f}".format(central_train_loss) + " | valid loss: {:7.4f}".format(central_valid_loss))

    # Log the epoch before the early-stopping check so the epoch that
    # triggers the stop is recorded as well.
    baseline_metrics_train.write(f"{epoch}\t{central_train_accuracy}\t{central_train_loss}\t{central_train_sens}\t{central_train_spec}\t{central_train_f1}\n")
    baseline_metrics_validation.write(f"{epoch}\t{central_valid_accuracy}\t{central_valid_loss}\t{central_valid_sens}\t{central_valid_spec}\t{central_valid_f1}\n")

    # early_stopping needs the validation loss to check if it has decreased,
    # and if it has, it will make a checkpoint of the current model
    early_stopping(central_valid_loss, main_model)

    if early_stopping.early_stop:
      print("Early stopping")
      break

  # Testing final model
  # NOTE(review): this evaluates the weights from the last epoch, not the
  # checkpoint written by EarlyStopping — confirm which one should be reported.
  central_test_loss, central_test_accuracy, central_test_accuracy_skl, central_test_sens, central_test_spec, central_test_f1 = validator.validation(main_model, test_dl_server, main_criterion)
  # Six values, matching the six-column header of the test log and the layout
  # of the train/validation rows.
  baseline_metrics_test.write(f"{epoch}\t{central_test_accuracy}\t{central_test_loss}\t{central_test_sens}\t{central_test_spec}\t{central_test_f1}\n")

  print("final test accuracy: {:7.4f}".format(central_test_accuracy))

  end_server_training_time = time.time()
  elapsed_server_training_time = end_server_training_time - start_server_training_time
  elapsed_time.write(f"{elapsed_server_training_time}\n")
  print("elapsed_server_training_time: ", elapsed_server_training_time)
finally:
  # Always flush and close the logs, even when training is interrupted or fails.
  baseline_metrics_train.close()
  baseline_metrics_validation.close()
  baseline_metrics_test.close()
  elapsed_time.close()