# MobiML FL demo

Using Flower and MobiML




In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import pickle
import warnings
import pandas as pd
import geopandas as gpd
import numpy as np
from datetime import datetime, timedelta
from copy import deepcopy
from typing import Dict, List, Tuple
from pathlib import Path
from sklearn.metrics import log_loss
from sklearn.preprocessing import MultiLabelBinarizer
import pymeos

import flwr as fl
from flwr.common import Metrics, Context

from mobiml.datasets import AISDK, MOVER_ID, SHIPTYPE
from mobiml.transforms import StationaryClientExtractor, TripExtractor, TrajectoryAggregator, MobileClientExtractor
from mobiml.models import SummarizedAISTrajectoryClassifier
from mobiml.models.ais_trajectory_classifier import AISLoader, get_evaluate_fn, fit_round, weighted_average
from mobiml.utils import convert_wgs_to_utm


## Configure FL simulation log verbosity


`RAY_BACKEND_LOG_LEVEL`: set to a low-volume level (`fatal`, one of the levels defined in [the Ray source code](https://github.com/ray-project/ray/blob/master/src/ray/util/logging.cc#L273)) to declutter the log output in the shell.

`RAY_DEDUP_LOGS`: set to `0` so that log lines from all clients are shown, instead of Ray collapsing repeated lines into a single deduplicated copy.

In [None]:
os.environ['RAY_DEDUP_LOGS'] = '0'
os.environ['RAY_BACKEND_LOG_LEVEL'] = 'fatal'


In [None]:
path = "./data/aisdk-2018-02.zip" # For a real-world experience, download AISDK input file (1 month of data, 14G) from here http://web.ais.dk/aisdata/aisdk-2018-02.zip

#path = "./data/aisdk_20180208_sample.zip"

## Extract stationary client (antenna) data

In [None]:
antennas = ['Point (11.96524 57.70730)', 'Point (11.63979 57.71941)', 'Point (11.78460 57.57255)']
antenna_radius_meters = 25000


In [None]:
# Parse the antenna WKT points (lon/lat, WGS84) once, and derive the UTM zone
# from the first antenna instead of a second, hand-copied coordinate pair that
# would go stale when the antenna list is edited.
antenna_points = gpd.GeoSeries.from_wkt(antennas)
epsg_code = convert_wgs_to_utm(antenna_points.iloc[0].x, antenna_points.iloc[0].y)

# One row per stationary client (antenna), numbered 0..n-1.
ids = [{'client': i} for i in range(len(antennas))]
df = pd.DataFrame(ids)
df['geometry'] = antenna_points
gdf = gpd.GeoDataFrame(df, geometry=df.geometry, crs=4326)
# Buffer in the projected (UTM) CRS so the radius is applied in metres,
# then convert the coverage polygons back to WGS84.
gdf = gdf.to_crs(epsg_code)
gdf['geometry'] = gdf.buffer(antenna_radius_meters)

buffered_antennas = gdf.to_crs(4326)
# Overall bounding box of all antenna coverage areas (passed to the AISDK loader).
min_lon, min_lat, max_lon, max_lat = buffered_antennas.geometry.total_bounds

In [None]:
out_dir = "temp"
os.path.dirname(out_dir)
if not os.path.exists(out_dir):
    print(f"{datetime.now()} Creating output directory {out_dir} ...")
    os.makedirs(out_dir)

In [None]:
print(f"{datetime.now()} Loading data from {path}")
aisdk = AISDK(path, min_lon, min_lat, max_lon, max_lat)

In [None]:
print(f"{datetime.now()} Extracting client data ...")
antenna_gdf = StationaryClientExtractor(aisdk, buffered_antennas)

In [None]:
stationary_feather_path = f"{out_dir}/ais-antenna.feather"


In [None]:
print(f"{datetime.now()} Writing output to {stationary_feather_path}")
antenna_gdf.to_feather(stationary_feather_path)

## Extract mobile client (vessel) data

In [None]:
ship_type = 'Towing' 
antenna_radius_meters = 25000  
bbox = [57.273, 11.196, 57.998, 12.223]  
min_lat, min_lon, max_lat, max_lon = bbox

In [None]:
mobile_feather_path = f"{out_dir}/ais-vessels.feather"

In [None]:
print(f"{datetime.now()} Loading data from {path}")
aisdk = AISDK(path, min_lon, min_lat, max_lon, max_lat)
vessels = deepcopy(aisdk)   # AISDK(path, min_lon, min_lat, max_lon, max_lat, vessel_type)
vessels.df = vessels.df[vessels.df.ship_type == ship_type]

In [None]:
print(f"{datetime.now()} Extracting client data ...")
vessel_gdf = MobileClientExtractor(aisdk, vessels, antenna_radius_meters)

In [None]:
print(f"{datetime.now()} Writing output to {mobile_feather_path}")
vessel_gdf.to_feather(mobile_feather_path)

## Prepare stationary and mobile client training data

In [None]:
h3_resolution = 8  # H3 resolution passed to TrajectoryAggregator.aggregate_trajs

# Derive the pickle paths from out_dir so the training data lands in the same
# directory as the extracted feather files.
stationary_training_file = f"{out_dir}/training-data-stationary.pickle"
mobile_training_file = f"{out_dir}/training-data-mobile.pickle"

In [None]:
def prepare_training_data(client_feather, outfile):
    """Turn extracted client AIS points into pickled trajectory training data.

    Reads the client GeoDataFrame from ``client_feather``, splits it into
    trips, aggregates per-trip features at the module-level ``h3_resolution``
    and pickles two artefacts:

    * ``outfile`` - the aggregated trajectory feature table, and
    * a sibling file whose *name* has ``training-data`` replaced by
      ``vessel`` (or is prefixed with ``vessel-`` if the name does not contain
      ``training-data``) - the per-mover ``ship_type`` / ``Name`` table.

    Parameters
    ----------
    client_feather : str or path-like
        Feather file written by one of the client extractors.
    outfile : str or path-like
        Destination pickle for the trajectory feature table.
    """
    out_path = Path(outfile)
    # Rename only the file itself: replacing on the full path string would
    # also rewrite any directory component containing "training-data".
    vessel_name = out_path.name.replace("training-data", "vessel")
    if vessel_name == out_path.name:
        # Without the marker the vessel table would map onto ``outfile``
        # itself and be silently overwritten by the feature table below.
        vessel_name = f"vessel-{out_path.name}"
    vessel_file = out_path.with_name(vessel_name)

    print(f"{datetime.now()} Loading data from {client_feather} ...")
    gdf = gpd.read_feather(client_feather)
    # Most frequent ship type / name per mover.
    # NOTE(review): Series.mode returns several values on ties, which then end
    # up as arrays in this table - confirm downstream code tolerates that.
    vessels = gdf.groupby(MOVER_ID)[["ship_type", "Name"]].agg(pd.Series.mode)

    print(f"{datetime.now()} Extracting trips ...")
    # Presumably a time gap above this threshold starts a new trip - confirm in TripExtractor.
    trajs = TripExtractor(gdf).get_trips(gap_duration=timedelta(minutes=60))

    print(f"{datetime.now()} Computing trajectory features ...")
    t_df = TrajectoryAggregator(trajs, vessels).aggregate_trajs(h3_resolution)

    with open(vessel_file, "wb") as out_file:
        pickle.dump(vessels, out_file)

    with open(outfile, "wb") as out_file:
        pickle.dump(t_df, out_file)
    print(f"{datetime.now()} training data written to {outfile}")


In [None]:
prepare_training_data(stationary_feather_path, stationary_training_file)


In [None]:
prepare_training_data(mobile_feather_path, mobile_training_file)

# Federated Learning



In [54]:
np.random.seed(0)  # seed NumPy's global RNG in this process

vessel_types = ['Cargo', 'Passenger', 'Tanker']
traj_features = ['speed_max', 'speed_median', 'x_start', 'y_start', 'x_end', 'y_end', 'length']
# Derived from the feature list so the model input size cannot drift from it.
n_features = len(traj_features)
test_size = 0.33

trajectory_classifier_model = SummarizedAISTrajectoryClassifier(vessel_types, n_features)


In [55]:
from sklearn.metrics import accuracy_score, confusion_matrix
import json
import os
import numpy as np
from pandas import DataFrame
from mobiml.utils import XYList


def display_confusion_matrix(y_test, predictions, labels):
    """Print the confusion matrix of ``predictions`` vs. ``y_test``.

    Rows are the true labels and columns the predicted labels, both ordered
    as in ``labels``.
    """
    matrix = confusion_matrix(y_test, predictions, labels=labels)
    print(DataFrame(matrix, index=labels, columns=labels))


def save_metrics(predictions, y_test, scenario_name, out_dir="output"):
    """Save the accuracy of ``predictions`` against ``y_test`` as JSON.

    Parameters
    ----------
    predictions : array-like
        Predicted labels.
    y_test : array-like
        True labels.
    scenario_name : str
        Inserted into the file name ``fl-global-metrics-<scenario_name>.json``.
    out_dir : str, optional
        Target directory, created if missing. Defaults to ``"output"``.
    """
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    metrics = {"accuracy": accuracy_score(y_test, predictions)}

    out_path = f"{out_dir}/fl-global-metrics-{scenario_name}.json"
    print(f"Saving metrics to {out_path}")
    with open(out_path, "w", encoding="utf-8") as fd:
        json.dump(metrics, fd)


def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList:
    """Cut X and y into ``num_partitions`` aligned, contiguous chunks.

    Chunking follows ``np.array_split`` (chunk sizes differ by at most one);
    the i-th returned tuple pairs the i-th chunk of X with the i-th chunk of y.
    """
    x_chunks = np.array_split(X, num_partitions)
    y_chunks = np.array_split(y, num_partitions)
    return [(x_part, y_part) for x_part, y_part in zip(x_chunks, y_chunks)]

## Define FL client

In [56]:
class AISClient(fl.client.NumPyClient):
    """Flower NumPy client around one AIS trajectory classifier.

    The client loads its own train/test split through ``data_loader.load``
    and trains on a single, randomly chosen slice of its training data.
    """

    def __init__(self, cid, model, partitions, data_loader) -> None:
        """Load this client's data and pick its training slice.

        Parameters
        ----------
        cid : int or str
            Client id; converted to ``int`` and passed to the loader as ``client_id``.
        model
            Classifier wrapper shared with the server-side evaluation.
        partitions : int
            Number of slices the local training set is cut into.
        data_loader
            Object whose ``load(client_id=...)`` returns
            ``((X_train, y_train), (X_test, y_test))``.
        """
        super().__init__()
        self.cid = int(cid)
        self.model = model

        (features_train, labels_train), (features_test, labels_test) = data_loader.load(client_id=self.cid)
        self.X_test, self.y_test = features_test, labels_test

        # Train on one randomly chosen slice only (drawn from NumPy's global RNG).
        partition_id = np.random.choice(partitions)
        slices = partition(features_train, labels_train, partitions)
        self.X_train, self.y_train = slices[partition_id]
        print(f"CLIENT {self.cid} started up, will use partition {partition_id} of {partitions} partitions for training")

    def get_parameters(self, config):  # type: ignore
        """Return the current local model parameters."""
        return self.model.get_model_parameters()

    def fit(self, parameters, config):  # type: ignore
        """Train on the local slice starting from the global ``parameters``.

        Returns the updated parameters, the number of training samples and
        the training accuracy as a metric.
        """
        self.model.set_model_params(parameters)
        with warnings.catch_warnings():
            # Few local iterations per round, so convergence warnings are expected noise.
            warnings.simplefilter("ignore")
            self.model.fit(self.X_train, self.y_train)
            train_accuracy = self.model.score(self.X_train, self.y_train)

        print(f"CLIENT {self.cid} Training finished for round {config['server_round']}")
        fit_metrics = {"accuracy": train_accuracy}
        return self.model.get_model_parameters(), len(self.X_train), fit_metrics

    def evaluate(self, parameters, config):  # type: ignore
        """Score the global ``parameters`` on this client's test split.

        Returns the log loss, the number of test samples and the accuracy.
        """
        self.model.set_model_params(parameters)
        known_classes = self.model.classes
        probabilities = self.model.predict_proba(self.X_test)
        loss = log_loss(self.y_test, probabilities, labels=known_classes)
        test_accuracy = self.model.score(self.X_test, self.y_test)
        print(f"CLIENT {self.cid} accuracy {test_accuracy}")
        return loss, len(self.X_test), {"accuracy": test_accuracy}

In [57]:
def generate_client_fn(model, lookup, data_partitions, dloader):
    """Build the ``client_fn`` that Flower calls to create a client for a given id.

    Parameters
    ----------
    model
        Classifier wrapper handed to every ``AISClient``. This argument used
        to be ignored in favour of the global ``trajectory_classifier_model``.
    lookup : dict
        Maps Flower client id strings (``'0'``, ``'1'``, ...) to the client
        ids used in the training data.
    data_partitions : int
        Number of slices each client cuts its training data into.
    dloader
        Data loader passed through to ``AISClient``.
    """

    def client_fn(cid: str):
        """Create the Flower client for Flower id ``cid``."""
        client_id = int(lookup[cid])
        print(f"******* {client_id} **********")
        return AISClient(client_id, model, data_partitions, dloader).to_client()

    return client_fn

## Start FL 

https://flower.ai/docs/framework/how-to-implement-strategies.html

https://flower.ai/docs/framework/ref-api/flwr.simulation.start_simulation.html

### FL with static data


In [58]:
static_data_loader = AISLoader(vessel_types, traj_features, test_size, path=stationary_training_file)
static_scenario_name = Path(stationary_training_file).stem.replace("training-data-", "")


static_strategy = fl.server.strategy.FedAvg(
        min_available_clients=2,
        evaluate_fn=get_evaluate_fn(trajectory_classifier_model, static_data_loader, static_scenario_name),
        on_fit_config_fn=fit_round,
        evaluate_metrics_aggregation_fn=weighted_average,
        fit_metrics_aggregation_fn=weighted_average,
    )

Vessel types: ['Cargo', 'Passenger', 'Tanker']
Trajectory features: ['speed_max', 'speed_median', 'x_start', 'y_start', 'x_end', 'y_end', 'length']
Test size: 0.33
Filtering ship_type to ['Cargo', 'Passenger', 'Tanker'] ...
... 3960 found.
Available trajectory columns: Index(['traj_id', 'start_t', 'end_t', 'geometry', 'length', 'direction',
       'client', 'mover_id', 'speed_max', 'speed_median', 'H3_seq',
       'speed_start', 'direction_start', 'x_start', 'y_start', 'speed_end',
       'direction_end', 'x_end', 'y_end', 'ship_type'],
      dtype='object')
2024-09-02 10:31:52.126697 Splitting dataset ...
Using 807 movers for training and 398 for testing ...
(2647 trajectories for training and 1313 for testing)


In [61]:
# Flower client id (string) -> client id in the stationary training data
# (the antenna indices 0..n-1).
client_mapping = {
    '0': 0,
    '1': 1,
    '2': 2,
}
# Flower creates client ids '0'..str(num_clients - 1); derive the count from the
# mapping so that every generated id has an entry (otherwise client_fn raises KeyError).
clients_per_round = len(client_mapping)
rounds = 10
client_data_partitions = 2  # each client trains on one of this many slices of its data

client_fn = generate_client_fn(trajectory_classifier_model, client_mapping, client_data_partitions, static_data_loader)

print(f"{datetime.now()} Starting training")

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=clients_per_round,
    config=fl.server.ServerConfig(num_rounds=rounds),
    strategy=static_strategy,
)
print(f"{datetime.now()} Training done")

[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout


2024-09-02 10:33:01.494999 Starting training


2024-09-02 10:33:06,797	INFO worker.py:1621 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 12.0, 'node:__internal_head__': 1.0, 'node:10.103.41.38': 1.0, 'object_store_memory': 3302055936.0, 'memory': 6604111872.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for clients.
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 12 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[2m[33m(raylet)[0m [2024-09-02 10:33:08,815 I 55937 55937] logging.cc:230: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to 3
[2m[33m(raylet)[0m [2024-09-02 10:33:08,815 I 55943 55943] logging.cc:

[2m[36m(ClientAppActor pid=55945)[0m ******* 0 **********
[2m[36m(ClientAppActor pid=55945)[0m Vessel types: ['Cargo', 'Passenger', 'Tanker']
[2m[36m(ClientAppActor pid=55945)[0m Trajectory features: ['speed_max', 'speed_median', 'x_start', 'y_start', 'x_end', 'y_end', 'length']
[2m[36m(ClientAppActor pid=55945)[0m Test size: 0.33
[2m[36m(ClientAppActor pid=55945)[0m Filtering ship_type to ['Cargo', 'Passenger', 'Tanker'] ...
[2m[36m(ClientAppActor pid=55945)[0m ... 3960 found.


[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Evaluating initial global parameters
[92mINFO [0m:      initial parameters (loss, other metrics): 0.9587302956179196, {'accuracy': 0.5102817974105103}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[2m[36m(ClientAppActor pid=55945)[0m Available trajectory columns: Index(['traj_id', 'start_t', 'end_t', 'geometry', 'length', 'direction',
[2m[36m(ClientAppActor pid=55945)[0m        'client', 'mover_id', 'speed_max', 'speed_median', 'H3_seq',
[2m[36m(ClientAppActor pid=55945)[0m        'speed_start', 'direction_start', 'x_start', 'y_start', 'speed_end',
[2m[36m(ClientAppActor pid=55945)[0m        'direction_end', 'x_end', 'y_end', 'ship_type'],
[2m[36m(ClientAppActor pid=55945)[0m       dtype='object')
Accuracy 0.5102817974105103
[2m[36m(ClientAppActor pid=55945)[0m 2024-09-02 10:33:11.823832 Splitting dataset ...
[2m[36m(ClientAppActor pid=55945)[0m Using 807 movers for training and 398 for testing ...
[2m[36m(ClientAppActor pid=55945)[0m (2647 trajectories for training and 1313 for testing)
[2m[36m(ClientAppActor pid=55945)[0m CLIENT 0 started up, will use partition 1 of 2 partitions for training
[2m[36m(ClientAppActor pid=55945)[0m ******* 0 *********

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (1, 0.9852104183330533, {'accuracy': 0.4607768469154608}, 2.4913640079994366)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[2m[36m(ClientAppActor pid=55943)[0m Client id: 2
[2m[36m(ClientAppActor pid=55943)[0m Filtering client to 2 ...
Accuracy 0.4607768469154608


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[2m[36m(ClientAppActor pid=55944)[0m CLIENT 1 accuracy 0.0


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (2, 0.992255203604182, {'accuracy': 0.4607768469154608}, 3.0244784539991088)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Accuracy 0.4607768469154608


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (3, 0.9982404688754756, {'accuracy': 0.456968773800457}, 3.5387142170002335)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Accuracy 0.456968773800457


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (4, 0.9958393173488802, {'accuracy': 0.45773038842345776}, 4.125544176999028)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Accuracy 0.45773038842345776


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (5, 0.9896437601082371, {'accuracy': 0.46534653465346537}, 4.628641829000117)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Accuracy 0.46534653465346537
[2m[36m(ClientAppActor pid=55945)[0m ******* 1 **********[32m [repeated 29x across cluster][0m
[2m[36m(ClientAppActor pid=55945)[0m Vessel types: ['Cargo', 'Passenger', 'Tanker'][32m [repeated 29x across cluster][0m
[2m[36m(ClientAppActor pid=55945)[0m Trajectory features: ['speed_max', 'speed_median', 'x_start', 'y_start', 'x_end', 'y_end', 'length'][32m [repeated 29x across cluster][0m
[2m[36m(ClientAppActor pid=55945)[0m Test size: 0.33[32m [repeated 29x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[2m[36m(ClientAppActor pid=55945)[0m Filtering ship_type to ['Cargo', 'Passenger', 'Tanker'] ...[32m [repeated 29x across cluster][0m
[2m[36m(ClientAppActor pid=55945)[0m ... 707 found.[32m [repeated 49x across cluster][0m
[2m[36m(ClientAppActor pid=55945)[0m Available trajectory columns: Index(['traj_id', 'start_t', 'end_t', 'geometry', 'length', 'direction',[32m [repeated 29x across cluster][0m
[2m[36m(ClientAppActor pid=55945)[0m        'client', 'mover_id', 'speed_max', 'speed_median', 'H3_seq',[32m [repeated 29x across cluster][0m
[2m[36m(ClientAppActor pid=55945)[0m        'speed_start', 'direction_start', 'x_start', 'y_start', 'speed_end',[32m [repeated 29x across cluster][0m
[2m[36m(ClientAppActor pid=55945)[0m        'direction_end', 'x_end', 'y_end', 'ship_type'],[32m [repeated 29x across cluster][0m
[2m[36m(ClientAppActor pid=55945)[0m       dtype='object')[32m [repeated 29x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (6, 0.9936430281784375, {'accuracy': 0.46001523229246}, 5.156842242000494)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


[2m[36m(ClientAppActor pid=55943)[0m 2024-09-02 10:33:16.907672 Splitting dataset ...[32m [repeated 30x across cluster][0m
[2m[36m(ClientAppActor pid=55943)[0m Using 205 movers for training and 102 for testing ...[32m [repeated 30x across cluster][0m
[2m[36m(ClientAppActor pid=55943)[0m (520 trajectories for training and 187 for testing)[32m [repeated 30x across cluster][0m
[2m[36m(ClientAppActor pid=55943)[0m CLIENT 1 started up, will use partition 1 of 2 partitions for training[32m [repeated 30x across cluster][0m
Accuracy 0.46001523229246


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[2m[36m(ClientAppActor pid=55945)[0m CLIENT 2 Training finished for round 6[32m [repeated 17x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (7, 0.9900339826970359, {'accuracy': 0.46153846153846156}, 5.667923245000566)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Accuracy 0.46153846153846156


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (8, 0.9919529683778454, {'accuracy': 0.46153846153846156}, 6.177679098000226)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Accuracy 0.46153846153846156


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (9, 0.9846198428348729, {'accuracy': 0.4607768469154608}, 6.833323708000535)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Accuracy 0.4607768469154608


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (10, 0.9848967683119845, {'accuracy': 0.4668697638994669}, 7.346820446999118)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


Accuracy 0.4668697638994669
[2m[36m(ClientAppActor pid=55945)[0m Client id: 2[32m [repeated 37x across cluster][0m
[2m[36m(ClientAppActor pid=55944)[0m Filtering client to 1 ...[32m [repeated 38x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 7.60s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.9429875792908189
[92mINFO [0m:      		round 2: 0.9440935158012854
[92mINFO [0m:      		round 3: 0.9453010630892211
[92mINFO [0m:      		round 4: 0.9434104017673705
[92mINFO [0m:      		round 5: 0.9398809106754858
[92mINFO [0m:      		round 6: 0.9398957438494122
[92mINFO [0m:      		round 7: 0.9374208287189824
[92mINFO [0m:      		round 8: 0.936680760427797
[92mINFO [0m:      		round 9: 0.9329039073167604
[92mINFO [0m:      		round 10: 0.9317513933846465
[92mINFO [0m:      	History (loss, centralized):
[92mINFO [0m:      		round 0: 0.9587302956179196
[92mINFO [0m:      		round 1: 0.9852104183330533
[92mINFO [0m:      		round 2: 0.992255203604182
[92mINFO [0m:      		round 3: 0.9982404688754756
[

2024-09-02 10:33:19.487282 Training done
[2m[36m(ClientAppActor pid=55943)[0m 2024-09-02 10:33:19.429542 Splitting dataset ...
[2m[36m(ClientAppActor pid=55943)[0m Using 807 movers for training and 398 for testing ...
[2m[36m(ClientAppActor pid=55943)[0m (2647 trajectories for training and 1313 for testing)
[2m[36m(ClientAppActor pid=55943)[0m CLIENT 0 started up, will use partition 1 of 2 partitions for training
[2m[36m(ClientAppActor pid=55943)[0m CLIENT 0 accuracy 0.0


In [62]:
tdf = pd.read_pickle(mobile_training_file)

In [65]:
tdf.client.unique()

array([219012959, 236111925, 235662000, 265737220])

### FL with mobile data

In [67]:
mobile_data_loader = AISLoader(vessel_types, traj_features, test_size, path=mobile_training_file)
mobile_scenario_name = Path(mobile_training_file).stem.replace("training-data-", "")


mobile_strategy = fl.server.strategy.FedAvg(
        min_available_clients=2,
        evaluate_fn=get_evaluate_fn(trajectory_classifier_model, mobile_data_loader, mobile_scenario_name),
        on_fit_config_fn=fit_round,
        evaluate_metrics_aggregation_fn=weighted_average,
        fit_metrics_aggregation_fn=weighted_average,
    )

Vessel types: ['Cargo', 'Passenger', 'Tanker']
Trajectory features: ['speed_max', 'speed_median', 'x_start', 'y_start', 'x_end', 'y_end', 'length']
Test size: 0.33
Filtering ship_type to ['Cargo', 'Passenger', 'Tanker'] ...
... 3263 found.
Available trajectory columns: Index(['traj_id', 'start_t', 'end_t', 'geometry', 'length', 'direction',
       'client', 'mover_id', 'speed_max', 'speed_median', 'H3_seq',
       'speed_start', 'direction_start', 'x_start', 'y_start', 'speed_end',
       'direction_end', 'x_end', 'y_end', 'ship_type'],
      dtype='object')
2024-09-02 10:36:28.861506 Splitting dataset ...
Using 776 movers for training and 383 for testing ...
(2242 trajectories for training and 1021 for testing)


In [None]:
# Flower client id (string) -> MMSI of the vessel acting as mobile client
# (the values found in the mobile training data's `client` column).
client_mapping = {
    '0': 236111925,
    '1': 219012959,
    '2': 235662000,
    '3': 265737220,
}
# Flower creates client ids '0'..str(num_clients - 1); derive the count from the
# mapping so that every generated id has an entry (otherwise client_fn raises KeyError).
clients_per_round = len(client_mapping)
rounds = 10
client_data_partitions = 3  # each client trains on one of this many slices of its data

client_fn = generate_client_fn(trajectory_classifier_model, client_mapping, client_data_partitions, mobile_data_loader)

print(f"{datetime.now()} Starting training")
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=clients_per_round,
    config=fl.server.ServerConfig(num_rounds=rounds),
    strategy=mobile_strategy,
)
print(f"{datetime.now()} Training done")

[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout


2024-09-02 10:36:55.382806 Starting training


2024-09-02 10:37:00,599	INFO worker.py:1621 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'node:10.103.41.38': 1.0, 'object_store_memory': 3244908134.0, 'memory': 6489816270.0, 'CPU': 12.0, 'node:__internal_head__': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for clients.
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 12 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[2m[33m(raylet)[0m [2024-09-02 10:37:01,779 I 58546 58546] logging.cc:230: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to 3
[2m[33m(raylet)[0m [2024-09-02 10:37:02,669 I 58809 58809] logging.cc:

[2m[36m(ClientAppActor pid=58817)[0m ******* 235662000 **********
[2m[36m(ClientAppActor pid=58817)[0m Client id: 235662000
[2m[36m(ClientAppActor pid=58817)[0m Vessel types: ['Cargo', 'Passenger', 'Tanker']
[2m[36m(ClientAppActor pid=58817)[0m Trajectory features: ['speed_max', 'speed_median', 'x_start', 'y_start', 'x_end', 'y_end', 'length']
[2m[36m(ClientAppActor pid=58817)[0m Test size: 0.33
[2m[36m(ClientAppActor pid=58817)[0m Filtering ship_type to ['Cargo', 'Passenger', 'Tanker'] ...
[2m[36m(ClientAppActor pid=58817)[0m ... 3263 found.
[2m[36m(ClientAppActor pid=58817)[0m Filtering client to 235662000 ...
[2m[36m(ClientAppActor pid=58817)[0m ... 30 found.
Accuracy 0.6150832517140059
[2m[36m(ClientAppActor pid=58817)[0m Available trajectory columns: Index(['traj_id', 'start_t', 'end_t', 'geometry', 'length', 'direction',
[2m[36m(ClientAppActor pid=58817)[0m        'client', 'mover_id', 'speed_max', 'speed_median', 'H3_seq',
[2m[36m(ClientAppActo