## NOTE: To run this notebook, first download the dataset from:
###    https://www.kaggle.com/datasets/nomihsa965/traffic-signs-dataset-mapillary-and-dfg
##    then place it in the main project directory (Argos/) and name it "data".

In [1]:
from Argos.settings import CLASSES_JSON_FILE , DATASET_PATH , NUMBER_OF_CLIENTS , CLIENT_BATCH_SIZE , CLIENT_LEARNING_RATE

In [2]:
from Argos.Dataset_utils import extract_label_mapping , MTSDDataset , partition_dataset, get_dataset_for_client

In [3]:
from torch.utils.data import Subset

# Number of samples kept for quick experiments (the cell previously hard-coded
# 10). Set to None to train on every sample of the dataset.
# NOTE(review): with NUMBER_OF_CLIENTS clients, 10 samples leaves only a few
# indices per client (see the partition output below) — a likely contributor
# to empty/degenerate train/val splits; confirm before drawing conclusions.
DEBUG_SUBSET_SIZE = 10

label_mapping = extract_label_mapping(CLASSES_JSON_FILE)
number_of_classes = len(label_mapping)
dataset = MTSDDataset(root_dir=DATASET_PATH)
if DEBUG_SUBSET_SIZE is not None:
    # Clamp to the real dataset size: Subset does not validate its indices,
    # so out-of-range ones would only fail later, on first access.
    subset_size = min(DEBUG_SUBSET_SIZE, len(dataset))
    dataset = Subset(dataset, list(range(subset_size)))
# NOTE(review): the captured output of this partition contains repeated
# indices (e.g. client 0 gets 7 twice; 0 goes to clients 2 and 3), which
# suggests partition_dataset samples with replacement — verify that is intended.
partitioned_dataset_indices = partition_dataset(dataset=dataset, num_clients=NUMBER_OF_CLIENTS)


In [4]:
# Display the client -> sample-index mapping produced by partition_dataset.
partitioned_dataset_indices

defaultdict(list,
            {0: [7, 7, 8],
             1: [1, 4, 3, 8],
             2: [6, 0, 6],
             3: [2, 0],
             4: [5, 9]})

## Client App

In [5]:
from Argos.Model import get_model
from Argos.Client import Client
from flwr.common import Context
from Argos.settings import DEVICE
from flwr.client import ClientApp


def new_client(context: Context) -> Client:
    """Build the Flower client for the data partition assigned to this node.

    A fresh model sized to ``number_of_classes`` is moved to ``DEVICE``; the
    partition selected by ``context.node_config["partition-id"]`` is split by
    ``get_dataset_for_client`` and its train/validation parts are handed to an
    Argos ``Client`` (the test part is not used here).

    Args:
        context: Flower context of the node being started.

    Returns:
        The value of ``Client(...).to_client()``.
        NOTE(review): the annotation names Argos' ``Client``, but
        ``to_client()`` presumably returns Flower's client wrapper — confirm
        and adjust the annotation if so.
    """
    # Built first, as before, so any random weight initialisation happens
    # before the dataset split.
    model = get_model(num_classes=number_of_classes).to(DEVICE)

    train_split, eval_split, _ = get_dataset_for_client(
        partition_id=context.node_config["partition-id"],
        full_dataset=dataset,
        partitioned_dataset_indices=partitioned_dataset_indices,
    )

    argos_client = Client(
        model=model,
        train_dataset=train_split,
        eval_dataset=eval_split,
        learning_rate=CLIENT_LEARNING_RATE,
        batch_size=CLIENT_BATCH_SIZE,
    )
    return argos_client.to_client()


# Flower ClientApp that builds one client per simulated node via new_client.
client = ClientApp(client_fn=new_client)

## Server App

In [6]:

from flwr.server.strategy import FedAvg

# Fraction of available clients sampled each round for fitting / evaluation.
_FIT_FRACTION = 1.0       # every available client trains
_EVALUATE_FRACTION = 0.5  # half of them evaluate

strategy = FedAvg(fraction_fit=_FIT_FRACTION, fraction_evaluate=_EVALUATE_FRACTION)

In [7]:
from flwr.common import Context
from flwr.server import ServerAppComponents, ServerConfig, ServerApp


def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    Optional overrides are read from ``context.run_config``:

    * ``"num-server-rounds"`` - number of federated rounds (default 5).
    * ``"round-timeout"`` - per-round timeout in seconds (default 30.0).

    When neither key is present (e.g. ``run_simulation`` without a run
    config), this is exactly the previous hard-coded 5 rounds / 30 s setup.

    Args:
        context: Flower context supplied to the ServerApp.

    Returns:
        ServerAppComponents wrapping the module-level ``strategy`` and a
        ``ServerConfig`` built from the values above.
    """
    # Fall back to an empty mapping if no run config was provided.
    run_config = getattr(context, "run_config", None) or {}
    num_rounds = int(run_config.get("num-server-rounds", 5))
    # NOTE(review): the captured run shows every fit failing in round 1; if
    # local training exceeds this timeout on CPU, raise it via run_config.
    round_timeout = float(run_config.get("round-timeout", 30.0))

    config = ServerConfig(num_rounds=num_rounds, round_timeout=round_timeout)

    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

In [8]:
# Resources reserved for each simulated client. These values equal the
# default allocation (2 CPUs, no GPU); spelling them out makes them easy
# to tune.
# NOTE(review): num_gpus is 0.0 — if settings.DEVICE selects a GPU, clients
# would use it without the simulation backend reserving one; confirm.
backend_config = dict(
    client_resources=dict(num_cpus=2, num_gpus=0.0),
)



In [9]:
from flwr.simulation import run_simulation

# Launch the simulation with one supernode per data partition, so every
# partition id produced above is served by exactly one client.
run_simulation(
    client_app=client,
    server_app=server,
    backend_config=backend_config,
    num_supernodes=NUMBER_OF_CLIENTS,
)

DEBUG:flwr:Asyncio event loop already running.
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=5, round_timeout=30.0s
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)
Training:   0%|          | 0/1 [00:00<?, ?it/s]
[92mINFO [0m:      aggregate_fit: received 0 results and 5 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 2 clients (out of 5)
[91mERROR [0m:     An exception was raised when processing a message by RayBackend
[91mERROR [0m:     [36mray::ClientAppActor.run()[39m (pid=1634, ip=127.0.0.1, actor_id=de87f141da9b3b9cb95d47ec01000000, re