In [None]:
import shutil
import os
import gc
import time
import copy
from enum import Enum

os.chdir('../')

from femnist import download_femnist
from utils import CLIENT_MODEL_DIR, ThresholdType, CurriculumType

from ditto.base.ditto_server import DittoStrategy as BaseDittoStrategy, fit_config_fn_generator as base_fit_config_fn_generator
from ditto.base.ditto_client import ditto_client_fn as base_ditto_client_fn

from ditto.curriculum_learning.ditto_server import DittoStrategy as CurriculumDittoStrategy, fit_config_fn_generator as curriculum_fit_config_fn_generator
from ditto.curriculum_learning.ditto_client import ditto_client_fn as curriculum_ditto_client_fn

import flwr as fl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ray

In [None]:
# Download FEMNIST data (does nothing if already present).
# Must run before the configuration cell, which lists ./femnist/client_data_mappings/fed_natural
# to decide how many clients can be sampled.
download_femnist()

In [None]:
class DittoType(Enum):
    """Which Ditto client/server implementation an experiment should use."""

    BASE = 0        # the ditto.base client, strategy and fit-config generator
    CURRICULUM = 1  # the ditto.curriculum_learning client, strategy and fit-config generator

def plot_results(results: dict, var_name: str):
    """
    Plot results generated by experiment: mean global accuracy, mean local accuracy and
    their difference, one line per tried value of the varied variable.

    :param results: dictionary of results provided by experiment, mapping each tried value
        to the history returned by ``fl.simulation.start_simulation``; every history must
        contain the distributed metrics ``'avg_global_accuracy'`` and ``'avg_local_accuracy'``
        as lists of ``(round, value)`` pairs
    :param var_name: name of variable being plotted (used as the legend title)
    """
    def metric_frame(metric: str) -> pd.DataFrame:
        # Index each run by the round number stored alongside the metric instead of by list
        # position, so the x-axis shows real round numbers and runs of unequal length align.
        return pd.DataFrame({
            value: pd.Series(dict(history.metrics_distributed[metric]))
            for value, history in results.items()
        })

    global_accuracies = metric_frame('avg_global_accuracy')
    local_accuracies = metric_frame('avg_local_accuracy')
    panels = [
        (global_accuracies, 'Mean global accuracy'),
        (local_accuracies, 'Mean local accuracy'),
        (global_accuracies - local_accuracies, 'Mean global - local accuracies'),
    ]

    _, axs = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (frame, ylabel) in zip(axs, panels):
        frame.plot(ax=ax)
        ax.set_xlabel('Round')
        ax.set_ylabel(ylabel)
        ax.legend(title=var_name)
        ax.grid()

    # Zero reference on the differential panel: above it the global model beats the local one.
    axs[2].axhline(y=0, color='black', linestyle='--', linewidth=2.5, alpha=0.5)

    plt.tight_layout()
    plt.show()
    
def ditto_experiment(ditto_type: DittoType, var_name: str, values: list, config: dict = None) -> dict:
    """
    Run ditto experiment
    
    :param ditto_type: The type of ditto clients to run (base or curriculum learning)
    :param var_name: The name of the variable in the config dictionary to vary
    :param values: The range of values to try for the variable
    :param config: The base config settings to use for every trial
    """
    base_config = {} if config is None else copy.deepcopy(config)
    
    start_time = time.time()

    match ditto_type:
        case DittoType.BASE:
            ditto_client_fn = base_ditto_client_fn
            fit_config_fn_generator = base_fit_config_fn_generator
            DittoStrategy = BaseDittoStrategy
        case DittoType.CURRICULUM:
            ditto_client_fn = curriculum_ditto_client_fn
            fit_config_fn_generator = curriculum_fit_config_fn_generator
            DittoStrategy = CurriculumDittoStrategy
        case _:
            raise Exception('Invalid DittoType')
       
    
    results = {}
    
    for value in values:
        print(f'{var_name} =', value)
        # Reset client models
        if os.path.exists(CLIENT_MODEL_DIR):
            shutil.rmtree(CLIENT_MODEL_DIR)  # Delete client model directory
        os.makedirs(CLIENT_MODEL_DIR)  # Recreate client model directory
    
        # https://flower.ai/docs/framework/tutorial-series-customize-the-client-pytorch.html
        results[value] = fl.simulation.start_simulation(
            # num_clients=num_clients,
            clients_ids=clients_ids,
            client_fn=ditto_client_fn,
            config=fl.server.ServerConfig(num_rounds=num_rounds),
            strategy=DittoStrategy(
                log_accuracy=True, 
                on_fit_config_fn=fit_config_fn_generator(base_config|{var_name: value})
            ),
            client_resources={
                'num_cpus': max(os.cpu_count()//num_clients, 1)
            }
        )
        
        # Clean up after simulation to prevent memory leakage
        ray.shutdown()
        gc.collect()
        
    elapsed_time = time.time() - start_time
    print('Elapsed time:', time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
    
    return results

In [None]:
# Experiment configuration shared by every cell below
num_clients = 8        # number of simulated federated clients
num_rounds = 50        # federated rounds per simulation
epochs_per_round = 25  # passed as 'local_epochs' to the curriculum-learning runs

# Seed the global NumPy RNG so the sampled clients are reproducible
np.random.seed(42)
FEMNIST_CLIENT_DIR = './femnist/client_data_mappings/fed_natural'
num_available_clients = len(next(os.walk(FEMNIST_CLIENT_DIR))[1])
# NOTE(review): assumes the client sub-directories are named 0..N-1 — TODO confirm.
# randint samples WITH replacement; the draw is kept unchanged so earlier results stay
# reproducible, but a duplicate ID would silently simulate fewer distinct clients.
clients_ids = list(map(str, np.random.randint(0, num_available_clients, size=num_clients)))
assert len(set(clients_ids)) == num_clients, f'Duplicate client IDs sampled: {clients_ids}'

# Base $\texttt{Ditto}$
### Base $\texttt{Ditto}$ with $\lambda \in [0,0.05]$

In [None]:
# NOTE(review): unlike the curriculum runs below, no 'local_epochs' is passed here, so these
# runs use whatever default the base fit-config generator applies — confirm it matches
# epochs_per_round if the two families are compared directly.
results = ditto_experiment(
    ditto_type=DittoType.BASE,
    var_name='lambda',
    values=[0.0, 0.01, 0.02, 0.03, 0.04, 0.05],
)

In [None]:
plot_results(results, r'$\lambda$')  # raw string: '\l' is an invalid escape sequence

### Base $\texttt{Ditto}$ with $\lambda \in [0,1]$

In [None]:
# NOTE(review): no 'local_epochs' is passed for the base runs — see the first base sweep.
results = ditto_experiment(
    ditto_type=DittoType.BASE,
    var_name='lambda',
    values=[0.00, 0.25, 0.50, 0.75, 1.00],
)

In [None]:
plot_results(results, r'$\lambda$')  # raw string: '\l' is an invalid escape sequence

# Curriculum Learning


## Self-Paced Learning

### Self-paced learning: sweeping $\lambda \in [0, 0.05]$ with a fixed quantile cutoff $K = 0.95$

In [None]:
# Sweep lambda; keep the self-paced quantile cutoff fixed at 0.95
results = ditto_experiment(
    ditto_type=DittoType.CURRICULUM,
    var_name='lambda',
    values=[0.00, 0.01, 0.02, 0.03, 0.04, 0.05],
    config={
        'local_epochs': epochs_per_round,
        'loss_threshold': 0.95,
        'threshold_type': ThresholdType.QUANTILE,
        'percentile_type': 'linear',
        'curriculum_type': CurriculumType.SELF_PACED,
    },
)

In [None]:
plot_results(results, r'$\lambda$')  # raw string: '\l' is an invalid escape sequence

### Self-paced learning: sweeping the quantile cutoff $K \in [0.5, 1]$ with a fixed $\lambda = 0.1$

In [None]:
# Sweep the self-paced quantile cutoff; keep lambda fixed at 0.1
results = ditto_experiment(
    ditto_type=DittoType.CURRICULUM,
    var_name='loss_threshold',
    values=[0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    config={
        'local_epochs': epochs_per_round,
        'threshold_type': ThresholdType.QUANTILE,
        'percentile_type': 'linear',
        'curriculum_type': CurriculumType.SELF_PACED,
        'lambda': 0.1,
    },
)

In [None]:
plot_results(results, var_name='Threshold')

## Transfer-Teacher Learning

### Transfer-teacher learning: sweeping $\lambda \in [0, 0.05]$ with a fixed quantile cutoff $K = 0.95$

In [None]:
# Sweep lambda; keep the transfer-teacher quantile cutoff fixed at 0.95
results = ditto_experiment(
    ditto_type=DittoType.CURRICULUM,
    var_name='lambda',
    values=[0.00, 0.01, 0.02, 0.03, 0.04, 0.05],
    config={
        'local_epochs': epochs_per_round,
        'loss_threshold': 0.95,
        'threshold_type': ThresholdType.QUANTILE,
        'percentile_type': 'linear',
        'curriculum_type': CurriculumType.TRANSFER_TEACHER,
    },
)

In [None]:
plot_results(results, r'$\lambda$')  # raw string: '\l' is an invalid escape sequence

### Transfer-teacher learning: sweeping the quantile cutoff $K \in [0.5, 1]$ with a fixed $\lambda = 0.1$

In [None]:
# Sweep the transfer-teacher quantile cutoff; keep lambda fixed at 0.1
results = ditto_experiment(
    ditto_type=DittoType.CURRICULUM,
    var_name='loss_threshold',
    values=[0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    config={
        'local_epochs': epochs_per_round,
        'threshold_type': ThresholdType.QUANTILE,
        'percentile_type': 'linear',
        'curriculum_type': CurriculumType.TRANSFER_TEACHER,
        'lambda': 0.1,
    },
)

In [None]:
plot_results(results, var_name='Threshold')