In [1]:
import os
import numpy as np
import random
import pickle
import copy
from matplotlib import pyplot as plt
from datetime import datetime


# Single seed shared by numpy's legacy global RNG and Python's `random`, so runs are reproducible
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

#from sklearn.model_selection import KFold

> All of the copy-and-pasted definitions below (standing in for imports), since Jupyter doesn't treat the environment as a package...

In [8]:
class ModelBase:
    """Shared state for the federated-learning server and its clients.

    Holds the current decoder ``w`` (shape ``(2, PCA_comps)``), a log of past
    decoders, the cost-function weights, and the train/test loss logs that
    subclasses append to.
    """
    def __init__(self, ID, w, opt_method, smoothbatch_lr=1, alphaF=0.0, alphaE=1e-6, alphaD=1e-4, verbose=False, starting_update=9, PCA_comps=64, current_round=0, num_clients=14, log_init=0):
        """
        Args:
            ID: identifier used by ``__repr__`` and ``display_info``.
            w: initial decoder. If ``w.shape != (2, PCA_comps)`` it is replaced by
                ``np.random.rand(2, PCA_comps)``; otherwise a deep copy is stored so
                the caller's array is never aliased by this object.
            opt_method: optimisation method name; stored upper-cased.
            smoothbatch_lr: SmoothBatch mixing coefficient.
            alphaF: user-effort (||F||^2) weight.
            alphaE: performance (||DF - V+||^2) weight.
            alphaD: decoder-effort (||D||^2) weight.
            verbose: stored flag.
            starting_update: stored as-is (presumably an index into ``update_ix`` — TODO confirm).
            PCA_comps: decoder input dimension; equal to ``pca_channel_default`` (64) means no PCA.
            current_round: starting global round.
            num_clients: stored as-is.
            log_init: stored as-is.
        """
        # Not input
        self.num_updates = 19
        self.starting_update = starting_update
        # Presumably the sample index at which each of the num_updates updates begins — TODO confirm
        self.update_ix = [0,  1200,  2402,  3604,  4806,  6008,  7210,  8412,  9614, 10816, 12018, 13220, 14422, 15624, 16826, 18028, 19230, 20432, 20769]
        # Plot color per client ID
        self.id2color = {0:'lightcoral', 1:'maroon', 2:'chocolate', 3:'darkorange', 4:'gold', 5:'olive', 6:'olivedrab', 
                7:'lawngreen', 8:'aquamarine', 9:'deepskyblue', 10:'steelblue', 11:'violet', 12:'darkorchid', 13:'deeppink'}
        
        self.type = 'BaseClass'
        self.ID = ID
        self.PCA_comps = PCA_comps
        self.pca_channel_default = 64  # When PCA_comps equals this, DONT DO PCA
        if w.shape != (2, self.PCA_comps):
            # Provided decoder has the wrong shape for this PCA setting: start from a random one instead
            self.w = np.random.rand(2, self.PCA_comps)
        else:
            # Copy so later in-place updates of self.w can never leak back into the caller's array
            # (e.g. a shared initial decoder that is reused across folds)
            self.w = copy.deepcopy(w)
        self.w_prev = copy.deepcopy(self.w)
        self.dec_log = [copy.deepcopy(self.w)]
        self.num_clients = num_clients
        self.log_init = log_init

        self.alphaF = alphaF
        self.alphaE = alphaE
        self.alphaD = alphaD

        self.local_train_error_log = []
        self.global_train_error_log = []
        self.local_test_error_log = []
        self.global_test_error_log = []
        
        self.opt_method = opt_method.upper()
        self.current_round = current_round
        self.verbose = verbose
        self.smoothbatch_lr = smoothbatch_lr

    def __repr__(self): 
        return f"{self.type}{self.ID}"
    
    def display_info(self): 
        """Return a short multi-line summary: type/ID, current round and optimisation method."""
        return f"{self.type} model: {self.ID}\nCurrent Round: {self.current_round}\nOptimization Method: {self.opt_method}"
    

In [9]:
class Server(ModelBase):
    """Federated-learning server that coordinates a list of clients.

    ``all_clients`` is split into training clients (``val_set is False``) and
    validation clients (``val_set is True``). Each call to ``execute_FL_loop``
    runs one global round and appends the average client losses to the
    ``*_error_log`` lists inherited from ``ModelBase``.

    NOTE(review): ``agg_local_weights`` is called by ``execute_FL_loop`` but is not
    defined in this class; it must be provided elsewhere (e.g. a subclass).
    """
    def __init__(self, ID, D0, opt_method, global_method, all_clients, smoothbatch_lr=0.75, C=0.35, test_split_type='kfoldcv', 
                 num_kfolds=5, test_split_frac=0.3, current_round=0, PCA_comps=64, verbose=False, validate_memory_IDs=True, save_client_loss_logs=True, 
                 sequential=False, current_datatime="Overwritten"):
        """
        Args:
            ID: server identifier.
            D0: initial global decoder, handed to ``ModelBase``.
            opt_method: optimisation method name (stored upper-cased).
            global_method: 'FEDAVG', any name containing 'PFA', or 'NOFL' (case-insensitive).
            all_clients: every client (train, validation, unavailable/queued).
            smoothbatch_lr: weight on ``w_prev`` in the global SmoothBatch blend.
            C: fraction of available clients chosen each round.
            test_split_type, num_kfolds, test_split_frac: stored test-split settings.
            current_round: starting global round.
            PCA_comps: decoder input dimension.
            verbose: stored flag.
            validate_memory_IDs: accepted for interface compatibility; unused here.
            save_client_loss_logs: stored flag.
            sequential: stored flag.
            current_datatime: if None, the results directory is created immediately via
                ``set_save_filename()``. Any other value (including the default) is ignored
                here, so callers must call ``set_save_filename`` themselves.

        Raises:
            ValueError: if ``all_clients`` contains no training clients.
        """
        super().__init__(ID, D0, opt_method, smoothbatch_lr=smoothbatch_lr, current_round=current_round, PCA_comps=PCA_comps, 
                         verbose=verbose, num_clients=14, log_init=0)
        
        self.type = 'Server'
        self.save_client_loss_logs = save_client_loss_logs
        self.sequential = sequential

        # CLIENT LISTS! (placeholders; refreshed by set_available_clients_list / choose_clients)
        self.num_avail_clients = 0
        self.available_clients_lst = [0]*len(all_clients)
        self.num_chosen_clients = 0
        self.chosen_clients_lst = [0]*len(all_clients)
        self.all_clients = all_clients  # ALL TRAIN AND VAL AND STATIC/UNAVAILABLE/QUEUED CLIENTS ARE PASSED IN HERE!!!
        self.num_total_clients = len(self.all_clients)
        # THE BELOW SHOULD NOT CHANGE DURING THE RUN!
        ## Maybe train_clients can/should change, to reflect availability in the seq case? ...
        self.train_clients = [cli for cli in self.all_clients if cli.val_set is False]
        self.val_clients = [cli for cli in self.all_clients if cli.val_set is True]
        self.num_train_clients = len(self.train_clients)
        self.num_test_clients = len(self.val_clients)
        # SET AVAILABLE CLIENT LIST
        self.set_available_clients_list()
        # Init all train clients with simulate_streaming so they have self.s, self.F, and self.p_reference
        for client in self.train_clients:
            client.simulate_data_stream()

        self.global_method = global_method.upper()
        print(f"Running the {self.global_method} algorithm as the global method!")
        self.C = C  # Fraction of clients to use each round

        # TESTING
        self.test_split_type = test_split_type.upper()
        self.num_kfolds = num_kfolds
        self.test_split_frac = test_split_frac

        # SAVE FILE RELATED
        try:
            # Directory of the module defining this class
            self.script_directory = os.path.dirname(os.path.abspath(__file__))
        except NameError:
            # `__file__` is undefined inside a Jupyter kernel: use the working directory instead
            self.script_directory = os.getcwd()
        # Results sub-directory (Windows-style string kept for backward compatibility)
        self.result_path = "\\results\\"
        # NOTE(review): a non-None current_datatime is currently ignored — TODO decide whether
        # it should be forwarded to set_save_filename
        if current_datatime is None:
            self.set_save_filename()

        # EVALUATE INIT MODEL AND SAVE THAT LOSS
        ## See if this fixes each algo having different init losses...
        self._log_average_losses()

    def _log_average_losses(self, reset_chosen_status=False):
        """Evaluate every training client and append the mean losses to the error logs.

        Local losses use each client's own ``w``; global losses (skipped for NOFL) use
        the server's ``self.w``. When ``reset_chosen_status`` is True each client's
        ``chosen_status`` is set to 0 before it is evaluated.

        Raises:
            ValueError: if there are no training clients (the mean would be undefined).
        """
        if not self.train_clients:
            raise ValueError("Server has no training clients to evaluate")
        use_global = self.global_method != 'NOFL'
        local_test_losses, local_train_losses = [], []
        global_test_losses, global_train_losses = [], []
        for my_client in self.train_clients:  # Eg all train-able clients (no witheld val clients from kfoldcv)
            if reset_chosen_status:
                # Reset all clients so no one is chosen for the next round
                my_client.chosen_status = 0
            local_test_loss, _ = my_client.test_metrics(my_client.w, 'local')
            local_train_loss, _ = my_client.train_metrics(my_client.w, 'local')
            local_test_losses.append(local_test_loss)
            local_train_losses.append(local_train_loss)
            if use_global:
                global_test_loss, _ = my_client.test_metrics(self.w, 'global')
                global_train_loss, _ = my_client.train_metrics(self.w, 'global')
                global_test_losses.append(global_test_loss)
                global_train_losses.append(global_train_loss)

        # SERVER AVERAGE PERFORMANCE: average loss per training client
        num_clients = len(self.train_clients)
        self.local_test_error_log.append(sum(local_test_losses) / num_clients)
        self.local_train_error_log.append(sum(local_train_losses) / num_clients)
        if use_global:
            self.global_test_error_log.append(sum(global_test_losses) / num_clients)
            self.global_train_error_log.append(sum(global_train_losses) / num_clients)

    def set_save_filename(self, current_datetime=None):
        """Compute the result-file paths for this trial and create its directory.

        Args:
            current_datetime: label for the trial directory; defaults to the current
                time formatted as ``"%m-%d_%H-%M"``.
        """
        if current_datetime is None:
            current_datetime = datetime.now().strftime("%m-%d_%H-%M")

        self.str_current_datetime = str(current_datetime)
        trial_dir_name = self.str_current_datetime + "_" + self.global_method
        # Relative path string, kept in its original (Windows-style) form for compatibility
        self.relative_path = self.result_path + trial_dir_name
        # Build the absolute path with os.path.join so it is valid on every OS
        # (identical to the old string concatenation on Windows)
        result_parts = [part for part in self.result_path.replace("\\", "/").split("/") if part]
        self.trial_result_path = os.path.join(self.script_directory, *result_parts, trial_dir_name)
        self.h5_file_path = os.path.join(self.trial_result_path, f"{self.opt_method}_{self.global_method}")
        self.paramtxt_file_path = os.path.join(self.trial_result_path, "param_log.txt")
        os.makedirs(self.trial_result_path, exist_ok=True)

    # Main Loop
    def execute_FL_loop(self):
        """Run one global round, then log the decoder and the average losses.

        FEDAVG / *PFA*: choose a fraction ``C`` of the available training clients, send
        each a copy of the global decoder, train them, aggregate with
        ``agg_local_weights`` and blend with ``w_prev`` (global SmoothBatch).
        NOFL: every training client trains locally; the global decoder is unchanged.

        Raises:
            ValueError: unsupported ``global_method``, or an unavailable training client.
        """
        # Update global round number
        self.current_round += 1
        
        if self.global_method == 'FEDAVG' or 'PFA' in self.global_method:
            # Choose fraction C of available clients
            self.set_available_clients_list()
            self.choose_clients()
            for my_client in self.all_clients:  
                if my_client.val_set == True:
                    # This is a val client, so don't log anything
                    continue
                elif my_client.availability == False:
                    raise ValueError("Sequential case not implemented yet")
                elif my_client not in self.chosen_clients_lst:
                    # Just not getting trained this round
                    continue
                # THESE ARE THE CLIENTS WHICH ACTUALLY GET TRAINED!
                my_client.latest_global_round = self.current_round          
                # Send those clients the current global model
                my_client.global_w = copy.deepcopy(self.w)
                my_client.execute_training_loop()
            # AGGREGATION
            self.agg_local_weights()  # This func sets self.w, eg the new decoder
            # GLOBAL SmoothBatch
            # NOTE(review): self.w_prev is never reassigned in this class, so unless
            # agg_local_weights updates it, this blends with the *initial* decoder every round — TODO confirm
            self.w = self.smoothbatch_lr*self.w_prev + ((1 - self.smoothbatch_lr)*self.w)
        elif self.global_method == 'NOFL':
            # TODO: Is NoFL just supposed to be the Local CPHS sims... if so this is fine I think 
            for my_client in self.all_clients:  
                if my_client.val_set == True:
                    # This is a val client, so don't log anything
                    continue
                elif my_client.availability == False:
                    raise ValueError("Sequential case not implemented yet")
                # THESE ARE THE CLIENTS WHICH ACTUALLY GET TRAINED!
                my_client.latest_global_round = self.current_round          
                my_client.execute_training_loop()
        else:
            # Raising a bare string is itself a TypeError; raise a real exception instead
            raise ValueError(f"Method {self.global_method} not currently supported, please reset method to FedAvg")
        
        # Save the new decoder to the log
        self.dec_log.append(copy.deepcopy(self.w))
        # Log performance on training/testing data, clearing chosen_status for the next round
        self._log_average_losses(reset_chosen_status=True)

    def set_available_clients_list(self):
        """Refresh the list (and count) of available training clients.

        ``available_clients_full_idx_lst`` stays aligned with ``train_clients``: it holds
        the client where ``availability`` is truthy and ``0`` otherwise.
        """
        # TODO come back to this depending on how available_clients_full_idx_lst shakes out...
        self.available_clients_full_idx_lst = [cli if cli.availability else 0 for cli in self.train_clients]
        self.available_clients_lst = [cli for cli in self.train_clients if cli.availability]
        self.num_avail_clients = len(self.available_clients_lst)

    def choose_clients(self):
        """Randomly choose ``ceil(C * num_available)`` available clients and mark them chosen.

        Raises:
            ValueError: if no clients are available or fewer than one would be chosen.
        """
        # Check what client are available this round
        self.set_available_clients_list()
        if self.num_avail_clients <= 0:
            raise ValueError(f"ERROR: Number of available clients must be greater than 0: {self.num_avail_clients}")
        # Now choose frac C clients from the resulting available clients
        self.num_chosen_clients = int(np.ceil(self.num_avail_clients*self.C))
        if self.num_chosen_clients < 1:
            raise ValueError(f"ERROR: Chose {self.num_chosen_clients} clients for some reason, must choose at least 1")
        # Shuffle the whole available list and keep the first k. Kept instead of
        # random.sample(lst, k) so the amount of RNG state consumed (and hence every later
        # random draw) is unchanged.
        self.chosen_clients_lst = random.sample(self.available_clients_lst, len(self.available_clients_lst))[:self.num_chosen_clients]
        for my_client in self.chosen_clients_lst:
            my_client.chosen_status = 1


In [10]:
# Machine-specific locations. Set the CPHS_EMG_DATA_DIR / CPHS_MODEL_SAVING_DIR environment
# variables to run on another machine instead of editing these defaults.
path = os.environ.get("CPHS_EMG_DATA_DIR", r'C:\Users\kdmen\Desktop\Research\Data\CPHS_EMG')
model_saving_dir = os.environ.get(
    "CPHS_MODEL_SAVING_DIR",
    r"C:\Users\kdmen\Desktop\Research\personalization-privacy-risk\PythonVersion\PythonSimsRevamp\models")
# Pickle filenames; the leading backslash suggests they are concatenated onto the
# directories above (Windows-style) — TODO confirm at the load site
cond0_filename = r'\cond0_dict_list.p'
all_decs_init_filename = r'\all_decs_init.p'
nofl_decs_filename = r'\nofl_decs.p'
# Plot color per user ID
id2color = {0:'lightcoral', 1:'maroon', 2:'chocolate', 3:'darkorange', 4:'gold', 5:'olive', 6:'olivedrab', 
            7:'lawngreen', 8:'aquamarine', 9:'deepskyblue', 10:'steelblue', 11:'violet', 12:'darkorchid', 13:'deeppink'}
implemented_client_training_methods = ['GD', 'FullScipyMin', 'MaxIterScipyMin']
NUM_USERS = 14
# For exclusion when plotting later on
#bad_nodes = [] #[1,3,13]
# Random initial decoder of shape (2, 64), drawn from the seeded global numpy RNG
D_0 = np.random.rand(2,64)
num_updates = 18
step_indices = list(range(num_updates))

# Only relevant to the MAXITERSCIPYMIN client method (original guidance: set it to 1 there).
# Use FULLSCIPYMIN for complete minimization.
MAX_ITER = None

COLORS_LST = ['red', 'blue', 'magenta', 'orange', 'darkviolet', 'lime']
ALPHA = 0.7  # presumably plot transparency — TODO confirm

# Current date and time as "month-day_hour-minute" (strftime already returns a str)
CURRENT_DATETIME = datetime.now().strftime("%m-%d_%H-%M")

STARTING_UPDATE = 10
DATA_STREAM = 'streaming'
NUM_KFOLDS = 5
USE_HITBOUNDS = False  # passed to Client as use_zvel
PLOT_EACH_FOLD = False
USE_KFOLDCV = True
TEST_SPLIT_TYPE = 'KFOLDCV'

In [11]:

def gradient_cost_l2(F, D, V, alphaD=1e-4, alphaE=1e-6, 
                     Nd=2, Ne=64, flatten=True):
    """Gradient w.r.t. D of  alphaE*||D@F - V[:, 1:]||^2 + alphaD*||D||^2.

    D may be passed flat; it is reshaped to (Nd, Ne). The result is returned
    flattened when ``flatten`` is True, otherwise with shape (Nd, Ne).
    """
    decoder = np.reshape(D, (Nd, Ne))
    next_vel = V[:, 1:]
    residual = decoder @ F - next_vel
    # Same evaluation order as ((2*residual) @ F.T) * alphaE + (2*alphaD) * decoder
    grad = 2 * residual @ F.T * alphaE + 2 * alphaD * decoder
    return grad.flatten() if flatten else grad

def cost_l2(F, D, V, alphaD=1e-4, alphaE=1e-6, Nd=2, Ne=64, return_cost_func_comps=False):
    """L2 decoder cost:  alphaE*||D@F - V[:, 1:]||^2 + alphaD*||D||^2.

    D may be passed flat; it is reshaped to (Nd, Ne). The user-effort term
    (alphaF*||F||^2) is not included. Returns the total cost, or the tuple
    (total, performance_term, decoder_effort_term) when
    ``return_cost_func_comps`` is True.
    """
    decoder = np.reshape(D, (Nd, Ne))
    residual = decoder @ F - V[:, 1:]
    performance = alphaE * (np.linalg.norm(residual) ** 2)   # Performance
    decoder_effort = alphaD * (np.linalg.norm(decoder) ** 2)  # Decoder effort
    total = performance + decoder_effort
    if not return_cost_func_comps:
        return total
    return total, performance, decoder_effort


In [None]:
# Global method: "FedAvg", "PFAFO_GDLS" or "NOFL"
GLOBAL_METHOD = "NOFL"
if GLOBAL_METHOD == "NOFL":
    OPT_METHOD, GLOBAL_ROUNDS, LOCAL_ROUND_THRESHOLD = 'FULLSCIPYMIN', 12, 1
else:
    OPT_METHOD, GLOBAL_ROUNDS, LOCAL_ROUND_THRESHOLD = 'GDLS', 250, 20
# Local epochs, i.e. the number of gradient steps per local round
NUM_STEPS = 3
# Only pertains to PFA; possibly unused with GDLS
BETA = 0.01
# Possibly unused with GDLS
LR = 1

## CROSS-SUBJECT

In [None]:
SCENARIO = "CROSS"

# THIS K FOLD SCHEME IS ONLY FOR CROSS-SUBJECT ANALYSIS!!!
# Contiguous, unshuffled K-fold split over user IDs. np.array_split gives the first
# (NUM_USERS % NUM_KFOLDS) folds one extra user, which reproduces the folds of
# sklearn's KFold(n_splits=NUM_KFOLDS) without shuffling, without relying on the
# KFold import (commented out in the imports cell).
user_ids = list(range(NUM_USERS))
folds = [(np.setdiff1d(user_ids, test_ids), test_ids)
         for test_ids in np.array_split(user_ids, NUM_KFOLDS)]

# NOTE(review): `Client`, `cond0_training_and_labels_lst` and `my_testing_models` are not
# defined in the visible cells; they must come from earlier (pasted) cells.

# One record per (fold, training client, model) so the computed metrics are kept
cross_subject_results = []

for fold_idx, (train_ids, test_ids) in enumerate(folds):
    print(f"Fold {fold_idx+1}/{NUM_KFOLDS}")
    print(f"{len(train_ids)} Train_IDs: {train_ids}")
    print(f"{len(test_ids)} Test_IDs: {test_ids}")

    # Constructor arguments shared by this fold's training and held-out clients
    shared_client_kwargs = dict(
        beta=BETA, scenario=SCENARIO, local_round_threshold=LOCAL_ROUND_THRESHOLD, lr=LR,
        current_fold=fold_idx, num_kfolds=NUM_KFOLDS, global_method=GLOBAL_METHOD,
        max_iter=MAX_ITER, num_steps=NUM_STEPS, use_zvel=USE_HITBOUNDS,
        test_split_type=TEST_SPLIT_TYPE)
    # Initialize clients for training (each gets its own copy of D_0)
    train_clients = [Client(i, copy.deepcopy(D_0), OPT_METHOD, cond0_training_and_labels_lst[i], DATA_STREAM,
                            **shared_client_kwargs) for i in train_ids]
    # Initialize held-out clients: unavailable and flagged as the validation set
    test_clients = [Client(i, copy.deepcopy(D_0), OPT_METHOD, cond0_training_and_labels_lst[i], DATA_STREAM,
                           availability=False, val_set=True, **shared_client_kwargs) for i in test_ids]

    # Hand the held-out users' testing datasets to every training client
    testing_datasets_lst = [test_cli.get_testing_dataset() for test_cli in test_clients]
    for train_cli in train_clients:
        train_cli.set_testset(testing_datasets_lst)

    full_client_lst = train_clients + test_clients

    for client_id, cli in zip(train_ids, train_clients):
        for model_idx, model in enumerate(my_testing_models):
            test_loss, vel_error, dec_error = cli.test_metrics(model, which="global", return_cost_func_comps=True)
            cross_subject_results.append({
                "fold": fold_idx, "client_id": int(client_id), "model_idx": model_idx,
                "test_loss": test_loss, "vel_error": vel_error, "dec_error": dec_error})


## INTRA-SUBJECT

In [None]:
SCENARIO = "INTRA"
# Intra-subject: every user is a client in every fold; `current_fold` is passed to each
# Client (presumably to select a fold of that user's own data — confirm in Client).
for fold_idx in range(NUM_KFOLDS):
    print(f"Fold {fold_idx+1}/{NUM_KFOLDS}")
    # num_kfolds / beta / lr are passed explicitly, as in the cross-subject cell, so this
    # loop over range(NUM_KFOLDS) and the config cell govern the clients instead of
    # whatever defaults Client happens to have.
    full_client_lst = [Client(i, copy.deepcopy(D_0), OPT_METHOD, cond0_training_and_labels_lst[i], DATA_STREAM,
                              beta=BETA, scenario=SCENARIO, local_round_threshold=LOCAL_ROUND_THRESHOLD, lr=LR,
                              current_fold=fold_idx, num_kfolds=NUM_KFOLDS, global_method=GLOBAL_METHOD,
                              max_iter=MAX_ITER, num_steps=NUM_STEPS, use_zvel=USE_HITBOUNDS,
                              test_split_type=TEST_SPLIT_TYPE) for i in range(NUM_USERS)]
    # NOTE(review): full_client_lst is rebuilt each fold and not used afterwards in this cell,
    # so only the last fold's clients survive the loop; the train/eval step appears to be missing.