In [None]:
import copy
import random

import matplotlib as plt  # NOTE: immediately rebound below; `plt` is matplotlib.pyplot
import matplotlib.pyplot as plt
import numpy as np
import torch
from zod import ZodFrames
from zod import constants

from data_loader import load_datasets
from data_partitioner import partition_train_data, PartitionStrategy
from fleet_aggregation import BaseStrategy
from models import Net
from utilities import train, get_parameters, set_parameters

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device} is available")

# SETTINGS 

In [None]:
NO_CLIENTS = 40 #40
CLIENTS_PER_ROUND = 10 #10
PERCENT_DATA = 0.01
GLOBAL_ROUNDS = 40 #40

FILE_NAME = f"spc_{NO_CLIENTS}_{CLIENTS_PER_ROUND}_{int(PERCENT_DATA * 100)}"
PLOT_NAME = f"plt_rndm_{NO_CLIENTS}_{CLIENTS_PER_ROUND}_{int(PERCENT_DATA * 100)}"


def select_client(clients, k=None):
    """Sample ``k`` distinct clients uniformly at random, without replacement.

    Args:
        clients: sequence of client ids; it is not modified.
        k: number of clients to pick; defaults to ``CLIENTS_PER_ROUND``.

    Returns:
        A new list with ``k`` client ids.

    Raises:
        ValueError: if ``k`` is larger than ``len(clients)`` (from ``random.sample``).
    """
    if k is None:
        k = CLIENTS_PER_ROUND
    return random.sample(clients, k)


def order_client(clients, k=None):
    """Take the next ``k`` clients from ``clients`` in order (round-robin).

    Deterministic alternative to ``select_client``: the returned ids are
    removed from ``clients`` in place. When fewer than ``k`` ids remain, the
    caller's list is refilled in place with all ``NO_CLIENTS`` ids (leftover
    ids are discarded), so repeated calls keep cycling through the clients.

    Args:
        clients: mutable list of client ids; consumed in place.
        k: batch size; defaults to ``CLIENTS_PER_ROUND``.

    Returns:
        List with the next ``k`` client ids.
    """
    if k is None:
        k = CLIENTS_PER_ROUND
    if len(clients) < k:
        print(f"Only {len(clients)} client(s) left (need {k}); restarting client order from 0")
        # Slice-assign so the CALLER's list is refilled; rebinding the local
        # name would leave the caller's list exhausted on every later call.
        clients[:] = [str(i) for i in range(NO_CLIENTS)]
    batch = clients[:k]
    del clients[:k]
    return batch


# LOAD PARAMETERS AND MODEL

In [None]:
def setup(zod_root="/mnt/ZOD", zod_version="full"):
    """Initialise the state used by ``main()`` and the plotting cell.

    Loads the ZOD frames, builds the aggregation strategy and the global model,
    partitions the training data across ``NO_CLIENTS`` clients, and builds a
    loader over the validation split that is used as the test set. Everything
    is stored in module-level globals (so later cells can use it) and also
    returned.

    Args:
        zod_root: dataset root directory passed to ``ZodFrames``.
        zod_version: dataset version passed to ``ZodFrames``.

    Returns:
        ``(testloader, clients, partitions, device, strategy,
        round_test_losses, zod_frames, round_train_losses, net)``.
    """
    global testloader, clients, partitions, device, strategy, round_test_losses, zod_frames, round_train_losses, net

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"{device} is available")

    zod_frames = ZodFrames(zod_root, version=zod_version)
    strategy = BaseStrategy()
    net = Net().to(device)
    clients = [str(i) for i in range(NO_CLIENTS)]

    partitions = partition_train_data(
        PartitionStrategy.RANDOM,
        NO_CLIENTS,
        zod_frames,
        PERCENT_DATA,
    )

    print("-------------------------------------------------")
    print("-------------------------------------------------")
    print(partitions)
    print("-------------------------------------------------")
    print("-------------------------------------------------")

    # Only the second loader returned by load_datasets is kept: it iterates the
    # validation split and serves as the held-out test set in main().
    _, testloader = load_datasets(zod_frames.get_split(constants.VAL), zod_frames)
    print("Testloader:", len(testloader))

    # Fresh loss histories; main() appends one entry per global round.
    round_test_losses = []
    round_train_losses = []
    return testloader, clients, partitions, device, strategy, round_test_losses, zod_frames, round_train_losses, net

setup()

# TRAIN | MAIN CODE

In [None]:
def _train_selected_clients(selected, local_epochs=5):
    """Train an independent copy of the global model on each selected client.

    Args:
        selected: client ids; ``str(client_id)`` must be a key of ``partitions``.
        local_epochs: number of epochs passed to ``train`` for every client.

    Returns:
        ``(fit_results, client_train_losses)``. ``fit_results`` is a list of
        ``(parameters, 1)`` tuples for ``strategy.aggregate_fit_fedavg``.
        ``client_train_losses`` holds each client's mean epoch training loss.
    """
    fit_results = []
    client_train_losses = []
    for client_idx in selected:
        # nn.Module.to() moves parameters in place and returns the SAME module,
        # so `net.to(device)` would alias the global model and each client would
        # continue from the previous client's weights. Deep-copy so every
        # client starts from the current global weights.
        local_net = copy.deepcopy(net).to(device)
        local_net.train()

        trainloader, valloader = load_datasets(partitions[str(client_idx)], zod_frames)
        epoch_train_losses, epoch_val_losses = train(local_net, trainloader, valloader, epochs=local_epochs)
        print(f"Client: {client_idx:>2} Train losses: {epoch_train_losses}, Val losses: {epoch_val_losses}")

        # Weight 1 per client -- presumably an unweighted average; confirm in
        # BaseStrategy.aggregate_fit_fedavg.
        fit_results.append((get_parameters(local_net), 1))
        client_train_losses.append(sum(epoch_train_losses) / len(epoch_train_losses))

    return fit_results, client_train_losses


def _evaluate_global_model():
    """Return the mean per-batch ``net.loss_fn`` of the global model on ``testloader``."""
    net.eval()
    batch_test_losses = []
    # Evaluation never calls backward(); skipping autograd saves memory and time.
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            pred = net(data)
            batch_test_losses.append(net.loss_fn(pred, target).item())
    return sum(batch_test_losses) / len(batch_test_losses)


def _save_global_model(filename):
    """Save the global model's parameters as one 1-D object array in an ``.npz`` file."""
    model_params = get_parameters(net)
    # np.array(list_of_arrays, dtype=object) tries to stack the arrays and can
    # raise (or build a higher-dimensional array) when their leading dimensions
    # match; fill a 1-D object array element by element instead.
    params_array = np.empty(len(model_params), dtype=object)
    for i, param in enumerate(model_params):
        params_array[i] = param
    np.savez(filename, params_array)


def main(local_epochs=5, patience=5, train_plateau_tol=0.05, test_rise_tol=0.15) -> None:
    """Run federated training for up to ``GLOBAL_ROUNDS`` global rounds.

    Each round picks clients with ``select_client``, trains an independent
    copy of the global model on each one, aggregates the client parameters
    with ``strategy.aggregate_fit_fedavg``, and evaluates the new global
    model on ``testloader``. Per-round losses are appended to the
    module-level ``round_test_losses`` / ``round_train_losses`` lists. The
    final parameters are saved to ``{FILE_NAME}_{len(round_test_losses)}.npz``.

    Early stopping ends training after ``patience`` consecutive rounds in
    which either the mean client training loss changed by at most
    ``train_plateau_tol`` (plateau) or the test loss rose by more than
    ``test_rise_tol`` (divergence).

    Args:
        local_epochs: epochs each client trains per round.
        patience: consecutive qualifying rounds that trigger early stopping.
        train_plateau_tol: max |change| in training loss counted as a plateau.
        test_rise_tol: min increase in test loss counted as divergence.
    """
    plateau_streak = 0  # consecutive rounds with a (near-)flat training loss
    rising_streak = 0   # consecutive rounds with a sharply rising test loss

    for round_idx in range(1, GLOBAL_ROUNDS + 1):
        print(" ")
        print("ROUND", round_idx)
        selected = select_client(clients)

        fit_results, client_train_losses = _train_selected_clients(selected, local_epochs)
        agg_weights = strategy.aggregate_fit_fedavg(fit_results)
        set_parameters(net, agg_weights[0])

        round_test_losses.append(_evaluate_global_model())
        # Mean over ALL clients trained this round, not just the last one.
        round_train_losses.append(sum(client_train_losses) / len(client_train_losses))

        # EARLY STOPPING: the streaks persist across rounds and reset to 0 as
        # soon as a round breaks them.
        if round_idx > 1:
            train_delta = abs(round_train_losses[-1] - round_train_losses[-2])
            plateau_streak = plateau_streak + 1 if train_delta <= train_plateau_tol else 0
            test_rise = round_test_losses[-1] - round_test_losses[-2]
            rising_streak = rising_streak + 1 if test_rise > test_rise_tol else 0

        print(f"Test loss: {round_test_losses[-1]}")
        print(f"Training loss: {round_train_losses[-1]}")

        if plateau_streak >= patience or rising_streak >= patience:
            print(f"Early stopping after round {round_idx}")
            break

    print('==========================================================')
    print(round_test_losses, 'TEST')
    print(round_train_losses, 'TRAIN')
    print('==========================================================')

    length_of_round = len(round_test_losses)
    _save_global_model(f"{FILE_NAME}_{length_of_round}.npz")

if __name__ == "__main__":

    main()

# PLOT

In [None]:
# Plot the per-round train and test loss recorded by main() and save it as PNG.
length_of_round = len(round_test_losses)
rounds_axis = range(1, length_of_round + 1)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(rounds_axis, round_test_losses, marker='o', linestyle='-', label='Test Loss')
ax.plot(rounds_axis, round_train_losses, marker='o', linestyle='-', label='Train Loss')
# Both curves are drawn, so title and y-label must not say only "Test Loss".
ax.set_title('Train and Test Loss Over Global Rounds')
ax.set_xlabel('Global Rounds')
ax.set_ylabel('Loss')
ax.grid(True)
ax.legend()
fig.savefig(f"{PLOT_NAME}_{length_of_round}.png")



In [None]:
# Show the raw per-round test losses recorded by main().
print(f"{round_test_losses}")

In [None]:
# Optional reset: uncomment to empty the loss histories that main() appends to,
# e.g. before calling main() again in the same kernel without re-running setup().
# round_test_losses.clear()
# round_train_losses.clear()