In [None]:
import copy
import math
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from IPython.display import clear_output
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from dice_loss import dice_coeff
from unet import UNet
############################
# Helper func
############################
from helper import *
#################################
TRAIN_RATIO = 0.8
RS = 30448  # random state; applied to the Python, NumPy and PyTorch RNGs at the end of this cell
N_CHANNELS, N_CLASSES = 1, 1
bilinear = True  # NOTE(review): the UNet constructors in this notebook pass bilinear=True literally
# NOTE(review): the training DataLoaders in this notebook pass batch_size=32 explicitly,
# not BATCH_SIZE — confirm which value is intended.
BATCH_SIZE, EPOCHS = 16, 250

img_size = 224
CROP_SIZE = (224, 224)
#########################################
data_path = './data'
PTH = './model/'
# Client cohorts; 'Interobs' and 'Lung1' are treated as evaluation-only cohorts elsewhere in this notebook.
CLIENTS = ['TJCH', 'GDPH', 'CHSUMC', 'Rider', 'Interobs', 'Lung1']
CLIENTS_2 = [cl + '_2' for cl in CLIENTS]
TOTAL_CLIENTS = len(CLIENTS)

# RS was declared but never applied. Seed every RNG the pipeline may draw from
# (model initialisation, DataLoader shuffling, random augmentation) so runs are repeatable.
# torch.manual_seed also seeds the CUDA generators.
random.seed(RS)
np.random.seed(RS)
torch.manual_seed(RS)

In [None]:
# Run everything on the first CUDA device.
device = torch.device('cuda', 0)

# Adam learning rate and weight decay, plus a threshold TH (not referenced in the visible cells).
LR = 1.5e-5
WD = 1e-5
TH = 0.9

## Build the training and testing datasets

In [None]:
# One dataset object per split. Clients with local data get '<client>_train' and
# '<client>_test' entries; the external cohorts 'Interobs' and 'Lung1' get a single
# evaluation dataset keyed by the bare client name.
lung_dataset = {}
for client in CLIENTS:
    if client in ('Interobs', 'Lung1'):
        lung_dataset[client] = BasicDataset(
            data_path, split=client, train=False,
            transforms=transforms.Compose([RandomGenerator(output_size=CROP_SIZE, train=False)]))
        continue
    # Train split first, then test split, matching the original construction order.
    for is_train, suffix in ((True, '_train'), (False, '_test')):
        lung_dataset[client + suffix] = BasicDataset(
            data_path, split=client, train=is_train,
            transforms=transforms.Compose([RandomGenerator(output_size=CROP_SIZE, train=is_train)]))


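## Initialize the aggregation weights

Each training client's initial weight is proportional to its number of training samples.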

In [None]:
# Initial aggregation weights: proportional to each training client's number of
# training samples. The evaluation-only cohorts 'Interobs' and 'Lung1' are excluded.
TOTAL_DATA = []
for client in CLIENTS:
    if client != 'Interobs' and client != 'Lung1':
        n_train = len(lung_dataset[client + '_train'])
        print(n_train)
        TOTAL_DATA.append(n_train)


DATA_AMOUNT = sum(TOTAL_DATA)
assert DATA_AMOUNT > 0, "no training samples found for any client"
WEIGHTS = [t / DATA_AMOUNT for t in TOTAL_DATA]
# Pristine copy; the training loop rescales from this baseline every epoch.
ORI_WEIGHTS = copy.deepcopy(WEIGHTS)

# Per-training-client validation score used for dynamic weighting: one slot per
# training client (previously a hard-coded [0, 0, 0, 0]).
score = [0] * len(TOTAL_DATA)


## Storage for data loaders, per-client metrics, models and optimizers

In [None]:
# Per-client containers, keyed by client name.
training_clients = {}   # DataLoaders over each client's training split
testing_clients = {}    # DataLoaders over each client's evaluation split

# Per-epoch metric histories (one list per client, filled in later cells).
acc_train = {}
acc_valid = {}
loss_train = {}
loss_test = {}
alpha_acc = []

# Models (including 'global') and their optimizers.
nets = {}
optimizers = {}

In [None]:
# Global (aggregated) model. ema_net is bound to the SAME object, not a copy, so the
# in-place detach below applies to nets['global'] itself, and saving ema_net saves the
# global weights.
# NOTE(review): if a separate EMA teacher was intended, this should be
# copy.deepcopy(nets['global']) — confirm against train_model/test in helper.
nets['global'] = UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES,
                      bilinear=True).to(device)
ema_net = nets['global']
for param in ema_net.parameters():
    param.detach_()  # detach in place so no gradients are tracked for these parameters


for client in CLIENTS:
    if client in ('Interobs', 'Lung1'):
        # External cohorts: evaluation loader only, no local model or optimizer.
        testing_clients[client] = DataLoader(lung_dataset[client], batch_size=1,
                                             shuffle=False, num_workers=1)
    else:
        # NOTE(review): batch_size=32 here while the config cell sets BATCH_SIZE = 16.
        training_clients[client] = DataLoader(lung_dataset[client + '_train'], batch_size=32,
                                              shuffle=True, num_workers=8)
        testing_clients[client] = DataLoader(lung_dataset[client + '_test'], batch_size=1,
                                             shuffle=False, num_workers=1)
        nets[client] = UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES,
                            bilinear=True).to(device)
        optimizers[client] = optim.Adam(nets[client].parameters(), lr=LR, weight_decay=WD)

    # Every client, trained or not, gets empty metric histories.
    acc_train[client], acc_valid[client] = [], []
    loss_train[client], loss_test[client] = [], []


# Report the size of each external evaluation cohort.
for client in CLIENTS:
    if client not in ('Lung1', 'Interobs'):
        continue
    print(client)
    print(len(lung_dataset[client]))

## FedSPCA


In [None]:
# Aggregation weights used from WARMUP_EPOCH onward, aligned index-by-index with CLIENTS:
# most mass on the labeled clients, a small fixed share for the unlabeled ones, and none
# for the two external evaluation cohorts.
WEIGHTS_POSTWARMUP = [0.3, 0.65, 0.025, 0.025, 0, 0]  # put more weight on clients with strong supervision
WARMUP_EPOCH = 150
# Supervision type per client (aligned with CLIENTS). Fixed typo: 'ENTERNAL2' -> 'EXTERNAL2'.
CLIENTS_SUPERVISION = ['labeled', 'labeled', 'unlabeled', 'unlabeled', 'EXTERNAL1', 'EXTERNAL2']
assert len(WEIGHTS_POSTWARMUP) == len(CLIENTS_SUPERVISION), "weights/supervision misaligned"


### Training: for the first 150 epochs (`WARMUP_EPOCH`) only the labeled clients are trained locally; afterwards the unlabeled clients join

In [None]:
def _dynamic_weights(raw_score, ori_weights, current_weights, supervision,
                     use_unlabeled, labeled_share=0.95):
    """Rescale the aggregation weights from this epoch's validation scores.

    Args:
        raw_score: per-training-client validation score; only the labeled entries
            are refreshed each epoch, the others keep their previous value.
        ori_weights: data-size-proportional weights, one per training client.
        current_weights: weights used this epoch (not mutated).
        supervision: supervision tag per client, aligned with CLIENTS.
        use_unlabeled: True once the post-warm-up phase is active.
        labeled_share: total mass assigned to the labeled clients after warm-up
            (0.95 leaves 0.05 for the unlabeled entries of WEIGHTS_POSTWARMUP).

    Returns:
        (new_weights, normalized_score). During warm-up new_weights is the
        score-scaled, renormalised data weights. Afterwards only the labeled entries
        of current_weights are replaced (scaled by labeled_share); the rest are kept.
    """
    total_score = sum(raw_score)
    # Guard: if every labeled client scored 0 the previous code divided by zero.
    if total_score > 0:
        norm_score = [sc / total_score for sc in raw_score]
    else:
        norm_score = [0.0] * len(raw_score)
    scaled = [wt * norm_score[i] for i, wt in enumerate(ori_weights)]
    total_scaled = sum(scaled)
    # With no usable score signal, fall back to the pure data-size weights.
    if total_scaled > 0:
        data_weights = [wt / total_scaled for wt in scaled]
    else:
        data_weights = list(ori_weights)
    if not use_unlabeled:
        return data_weights, norm_score
    new_weights = list(current_weights)
    for i, tag in enumerate(supervision):
        if tag == 'labeled':
            new_weights[i] = data_weights[i] * labeled_share
    return new_weights, norm_score


best_avg_acc, best_epoch_avg = 0, 0
best_epoch = 0
index = []
# NOTE(review): iter_nums is passed to train_model but never advanced in this cell —
# confirm whether train_model is expected to return an updated iteration count.
iter_nums = 0

USE_UNLABELED_CLIENT = False

# Output locations, created up front so torch.save / np.save do not fail on a fresh checkout.
OUTCOME_DIR = os.path.join(PTH, 'outcome')
save_mode_path = "./epoch/"
os.makedirs(OUTCOME_DIR, exist_ok=True)
os.makedirs(save_mode_path, exist_ok=True)

# Per-epoch history of aggregation weights and normalised scores. Initialised outside
# the loop (previously reset every epoch, so only the last epoch was ever saved).
w = []
s = []

for epoch in range(EPOCHS):
    print('epoch {} :'.format(epoch))
    if epoch == WARMUP_EPOCH:
        # Copy so later per-epoch updates do not mutate the WEIGHTS_POSTWARMUP constant.
        WEIGHTS = list(WEIGHTS_POSTWARMUP)
        USE_UNLABELED_CLIENT = True

    index.append(epoch)

    #################### copy fed model ###################
    copy_fed(CLIENTS, nets, fed_name='global')

    #### conduct training #####
    # During warm-up unlabeled clients are skipped and get placeholder zeros;
    # the external cohorts have no local model and are never trained.
    for client, supervision_t in zip(CLIENTS, CLIENTS_SUPERVISION):
        if supervision_t == 'unlabeled' and not USE_UNLABELED_CLIENT:
            acc_train[client].append(0)
            loss_train[client].append(0)
            continue

        if client != 'Interobs' and client != 'Lung1':
            train_model(epoch, training_clients[client], optimizers[client], device,
                        nets[client], ema_model=ema_net,
                        acc=acc_train[client],
                        loss=loss_train[client],
                        supervision_type=supervision_t,
                        learning_rate=LR, iter_num=iter_nums)

    aggr_fed(CLIENTS, WEIGHTS, nets)
    ################### test ################################
    avg_acc = 0.0
    for order, (client, supervision_t) in enumerate(zip(CLIENTS, CLIENTS_SUPERVISION)):
        test(epoch, testing_clients[client], device, nets['global'], ema_net, acc_valid[client],
             loss_test[client])
        avg_acc += acc_valid[client][-1]
        if supervision_t == "labeled":
            score[order] = acc_valid[client][-1]

    ####### dynamic weighting #########
    print("Score is :", score)
    WEIGHTS, score = _dynamic_weights(score, ORI_WEIGHTS, WEIGHTS, CLIENTS_SUPERVISION,
                                      USE_UNLABELED_CLIENT)
    print("weight is::::", WEIGHTS)
    # Pad each row to one slot per client. WEIGHTS covers only the training clients
    # during warm-up and every client afterwards; padding keeps the saved history rectangular.
    w.append([WEIGHTS[i] if i < len(WEIGHTS) else 0.0 for i in range(TOTAL_CLIENTS)])
    s.append(list(score))

    avg_acc = avg_acc / TOTAL_CLIENTS
    ############################################################
    if avg_acc > best_avg_acc:
        best_avg_acc = avg_acc
        best_epoch = best_epoch_avg = epoch
        save_model_4(PTH, epoch, nets, ema_net)
    torch.save(nets['global'].state_dict(), save_mode_path + 'epoch_' + str(epoch) + '.pth')
    torch.save(ema_net.state_dict(), save_mode_path + 'emaepoch_' + str(epoch) + '.pth')

    ################################
    # persist metric histories #####
    ################################
    np.save(os.path.join(OUTCOME_DIR, 'acc_train'), acc_train)
    np.save(os.path.join(OUTCOME_DIR, 'acc_test'), acc_valid)
    np.save(os.path.join(OUTCOME_DIR, 'loss_train'), loss_train)
    np.save(os.path.join(OUTCOME_DIR, 'weight'), w)
    np.save(os.path.join(OUTCOME_DIR, 'score'), s)
    clear_output(wait=True)