# MobiML FL demo

A federated learning (FL) demo using Flower and MobiML.




In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import pickle
import warnings
import pandas as pd
import geopandas as gpd
import numpy as np
from datetime import datetime, timedelta
from copy import deepcopy
from typing import Dict, List, Tuple
from pathlib import Path
from sklearn.metrics import log_loss
from sklearn.preprocessing import MultiLabelBinarizer
import pymeos

import flwr as fl
from flwr.common import Metrics, Context

import sys
sys.path.append("../src")
from mobiml.datasets import AISDK, MOVER_ID, SHIPTYPE
from mobiml.transforms import TrajectoryCreator, TrajectoryAggregator
from mobiml.preprocessing import StationaryClientExtractor, MobileClientExtractor
from mobiml.models.trajclassifier.ais_trajectory_classifier import SummarizedAISTrajectoryClassifier, AISLoader, get_evaluate_fn, fit_round, weighted_average
from mobiml.utils import convert_wgs_to_utm


## Set FL simulation output verbosity settings


`RAY_BACKEND_LOG_LEVEL`: set to a low-volume log level to declutter the log output in shells. The level has been chosen from [the Ray source code](https://github.com/ray-project/ray/blob/master/src/ray/util/logging.cc#L273).

`RAY_DEDUP_LOGS`: set to `0` to see the logs from all clients instead of just one.

In [None]:
# Ray logging configuration for the FL simulation:
#   RAY_DEDUP_LOGS=0            -> show the logs of all clients, not just one
#   RAY_BACKEND_LOG_LEVEL=fatal -> low-volume backend logging
os.environ.update({
    'RAY_DEDUP_LOGS': '0',
    'RAY_BACKEND_LOG_LEVEL': 'fatal',
})


In [None]:
# AISDK input file (1 month of data, 14G); download it from
# http://web.ais.dk/aisdata/aisdk-2018-02.zip
path = "./data/aisdk-2018-02.zip"



## Extract stationary client (antenna) data

In [None]:
# Antenna positions as WKT points; each antenna is one stationary client.
antennas = [
    'Point (11.96524 57.70730)',
    'Point (11.63979 57.71941)',
    'Point (11.78460 57.57255)',
]
# Buffer radius applied around every antenna position.
antenna_radius_meters = 25000


In [None]:
# Projected CRS for buffering, looked up from the coordinates of the first antenna.
epsg_code = convert_wgs_to_utm(11.96524, 57.70730)

# One row per antenna; the position in the list becomes the client id.
df = pd.DataFrame({'client': list(range(len(antennas)))})
df['geometry'] = gpd.GeoSeries.from_wkt(antennas)

# Buffer in the projected CRS, then go back to EPSG:4326 for the bounds.
gdf = gpd.GeoDataFrame(df, geometry=df.geometry, crs=4326).to_crs(epsg_code)
gdf['geometry'] = gdf.buffer(antenna_radius_meters)
buffered_antennas = gdf.to_crs(4326)

# Overall bounding box of all buffered antenna areas.
min_lon, min_lat, max_lon, max_lat = buffered_antennas.geometry.total_bounds

In [None]:
# Directory that receives all intermediate files written by this notebook.
out_dir = "temp"
if not os.path.exists(out_dir):
    print(f"{datetime.now()} Creating output directory {out_dir} ...")
    # exist_ok guards against the directory appearing between check and creation
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# Load the AISDK input file. The four bounds are the total bounds of the
# buffered antenna areas (presumably used by AISDK to crop the data -- confirm).
print(f"{datetime.now()} Loading data from {path}")
aisdk = AISDK(path, min_lon, min_lat, max_lon, max_lat)

In [None]:
# Extract the stationary client data from the loaded AIS data, passing the
# buffered antenna areas to the extractor
# (presumably one client per antenna area -- see StationaryClientExtractor).
print(f"{datetime.now()} Extracting client data ...")
antenna_gdf = StationaryClientExtractor(aisdk).extract(buffered_antennas)

In [None]:
# Feather file that stores the extracted stationary client (antenna) data.
stationary_feather_path = f"{out_dir}/ais-antenna.feather"


In [None]:
# Persist the extracted antenna data; the training data preparation reads
# this feather file again later.
print(f"{datetime.now()} Writing output to {stationary_feather_path}")
antenna_gdf.to_feather(stationary_feather_path)

## Extract mobile client (vessel) data

In [None]:
# Vessels of this ship type act as the mobile clients.
ship_type = 'Towing'
antenna_radius_meters = 25000

# Area of interest, kept both as individual bounds and as a list
# ordered (min_lat, min_lon, max_lat, max_lon).
min_lat, min_lon, max_lat, max_lon = 57.273, 11.196, 57.998, 12.223
bbox = [min_lat, min_lon, max_lat, max_lon]

In [None]:
# Feather file that stores the extracted mobile client (vessel) data.
mobile_feather_path = f"{out_dir}/ais-vessels.feather"

In [None]:
print(f"{datetime.now()} Loading data from {path}")
aisdk = AISDK(path, min_lon, min_lat, max_lon, max_lat)

# Filter a copy so that `aisdk` keeps the records of all ship types.
# (Original note: AISDK(path, min_lon, min_lat, max_lon, max_lat, vessel_type)
# may be an alternative way to load only one type -- unverified.)
vessels = deepcopy(aisdk)
is_client_type = vessels.df["ship_type"] == ship_type
vessels.df = vessels.df.loc[is_client_type]

In [None]:
# Extract the mobile client data: the extractor is built on the full AIS
# dataset, while the ship-type-filtered copy provides the client vessels.
# NOTE(review): the radius is presumably the reception range around each
# client vessel -- confirm in MobileClientExtractor.
print(f"{datetime.now()} Extracting client data ...")
vessel_gdf = MobileClientExtractor(aisdk).extract(vessels, antenna_radius_meters)

In [None]:
# Persist the extracted vessel data; the training data preparation reads
# this feather file again later.
print(f"{datetime.now()} Writing output to {mobile_feather_path}")
vessel_gdf.to_feather(mobile_feather_path)

## Prepare stationary and mobile client training data

In [None]:
# H3 resolution handed to the trajectory aggregation step.
h3_resolution = 8

# Pickle files holding the prepared training data of each scenario.
# (Plain strings: there is nothing to interpolate, so no f-string is needed.)
stationary_training_file = "temp/training-data-stationary.pickle"
mobile_training_file = "temp/training-data-mobile.pickle"

In [None]:
def prepare_training_data(client_feather, outfile):
    """Turn extracted client data into pickled trajectory training data.

    Reads the client GeoDataFrame from ``client_feather``, creates
    trajectories from it (gap duration: 60 minutes) and aggregates them into
    trajectory features using the notebook-level ``h3_resolution``.

    Two pickle files are written: the trajectory features go to ``outfile``;
    the per-mover ``ship_type``/``Name`` modes go to the same path with
    "training-data" replaced by "vessel".

    Args:
        client_feather: Path of the feather file with the client data.
        outfile: Path of the pickle file for the trajectory features.
    """
    print(f"{datetime.now()} Loading data from {client_feather} ...")
    client_gdf = gpd.read_feather(client_feather)
    # Most frequent ship type and name recorded for each mover
    vessel_info = client_gdf.groupby(MOVER_ID)[["ship_type", "Name"]].agg(pd.Series.mode)

    print(f"{datetime.now()} Extracting trips ...")
    trip_creator = TrajectoryCreator(client_gdf)
    trips = trip_creator.get_trajs(gap_duration=timedelta(minutes=60))

    print(f"{datetime.now()} Computing trajectory features ...")
    aggregator = TrajectoryAggregator(trips, vessel_info)
    features = aggregator.aggregate_trajs(h3_resolution)

    vessel_file = outfile.replace("training-data", "vessel")
    with open(vessel_file, "wb") as fh:
        pickle.dump(vessel_info, fh)

    with open(outfile, "wb") as fh:
        pickle.dump(features, fh)
    print(f"{datetime.now()} training data written to {outfile}")


In [None]:
# Build the training data of the stationary (antenna) scenario.
prepare_training_data(
    client_feather=stationary_feather_path,
    outfile=stationary_training_file,
)


In [None]:
# Build the training data of the mobile (vessel) scenario.
prepare_training_data(
    client_feather=mobile_feather_path,
    outfile=mobile_training_file,
)

# Federated Learning



In [None]:
# Seed NumPy's global RNG (the clients draw their training partition with
# np.random.choice).
np.random.seed(0)

# Ship types the classifier distinguishes.
vessel_types = ['Cargo', 'Passenger', 'Tanker']
# Trajectory features used as model input.
traj_features = ['speed_max', 'speed_median', 'x_start', 'y_start', 'x_end', 'y_end', 'length']
# Derived from the feature list so that the two cannot get out of sync.
n_features = len(traj_features)
# Value passed to the data loaders as test_size.
test_size = 0.33

trajectory_classifier_model = SummarizedAISTrajectoryClassifier(vessel_types, n_features)


In [None]:
from sklearn.metrics import accuracy_score, confusion_matrix
import json
import os
import numpy as np
from pandas import DataFrame
from mobiml.utils import XYList


def display_confusion_matrix(y_test, predictions, labels):
    """Print the confusion matrix of ``predictions`` against ``y_test``.

    The matrix is printed as a table whose rows and columns are both
    labelled with ``labels``, in the given order.
    """
    matrix = confusion_matrix(y_test, predictions, labels=labels)
    print(DataFrame(matrix, index=labels, columns=labels))


def save_metrics(predictions, y_test, scenario_name):
    """Compute the accuracy of ``predictions`` and store it as JSON.

    The result is written to ``output/fl-global-metrics-<scenario_name>.json``
    as ``{"accuracy": <value>}``; the ``output`` directory is created if
    it does not exist yet.

    Args:
        predictions: Predicted labels.
        y_test: True labels.
        scenario_name: Name used in the output file name.
    """
    # exist_ok avoids failing when the directory is created concurrently
    os.makedirs("output", exist_ok=True)

    metrics = {"accuracy": accuracy_score(y_test, predictions)}

    out_path = f"output/fl-global-metrics-{scenario_name}.json"
    print(f"Saving metrics to {out_path}")
    with open(out_path, "w") as fd:
        json.dump(metrics, fd)


def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList:
    """Split X and y into ``num_partitions`` aligned (X, y) chunks.

    Both arrays are cut with ``np.array_split``, so the chunks may differ in
    size when the length is not divisible by ``num_partitions``.
    """
    X_chunks = np.array_split(X, num_partitions)
    y_chunks = np.array_split(y, num_partitions)
    return [(X_part, y_part) for X_part, y_part in zip(X_chunks, y_chunks)]

## Define FL client

In [None]:
class AISClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates the trajectory model.

    On start-up the client loads its train/test split from the data loader,
    cuts the training set into ``partitions`` pieces and keeps one randomly
    drawn piece for local training. The test set is kept as loaded.
    """

    def __init__(self, cid, model, partitions, data_loader) -> None:
        super().__init__()
        self.cid = int(cid)
        self.model = model

        train_split, test_split = data_loader.load(client_id=self.cid)
        X_train, y_train = train_split
        self.X_test, self.y_test = test_split

        # Draw the index of the training partition this client will use.
        partition_id = np.random.choice(partitions)
        chunks = partition(X_train, y_train, partitions)
        self.X_train, self.y_train = chunks[partition_id]
        print(f"CLIENT {self.cid} started up, will use partition {partition_id} of {partitions} partitions for training")

    def get_parameters(self, config):  # type: ignore
        """Return the current parameters of the local model."""
        return self.model.get_model_parameters()

    def fit(self, parameters, config):  # type: ignore
        """Train on the local partition, starting from ``parameters``.

        Returns the updated model parameters, the number of training samples
        and the accuracy on the local training data.
        """
        self.model.set_model_params(parameters)
        with warnings.catch_warnings():
            # Few local epochs can cause convergence warnings; silence them.
            warnings.simplefilter("ignore")
            self.model.fit(self.X_train, self.y_train)
            train_accuracy = self.model.score(self.X_train, self.y_train)

        print(f"CLIENT {self.cid} Training finished for round {config['server_round']}")
        fit_metrics = {"accuracy": train_accuracy}
        return self.model.get_model_parameters(), len(self.X_train), fit_metrics

    def evaluate(self, parameters, config):  # type: ignore
        """Evaluate ``parameters`` on the local test set.

        Returns the log loss, the number of test samples and the accuracy.
        """
        self.model.set_model_params(parameters)
        class_labels = self.model.classes
        probabilities = self.model.predict_proba(self.X_test)
        loss = log_loss(self.y_test, probabilities, labels=class_labels)
        test_accuracy = self.model.score(self.X_test, self.y_test)
        print(f"CLIENT {self.cid} accuracy {test_accuracy}")
        return loss, len(self.X_test), {"accuracy": test_accuracy}

In [None]:
def generate_client_fn(model, lookup, data_partitions, dloader):
    """Create the factory that is called to initialise new client objects.

    Args:
        model: Model handed to every client created by the factory.
        lookup: Maps the client ids passed to the factory (strings) to the
            client ids used by the data loader.
        data_partitions: Number of partitions each client's training set is
            split into.
        dloader: Data loader handed to every client.

    Returns:
        A ``client_fn(cid)`` callable that builds the client for ``cid``.
    """

    def client_fn(cid: str):
        client_id = int(lookup[cid])
        print(f"******* {client_id} **********")
        # Use the model given to the factory, not a notebook-level global.
        return AISClient(client_id, model, data_partitions, dloader).to_client()

    return client_fn

## Start FL 

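Flower documentation: [How to implement strategies](https://flower.ai/docs/framework/how-to-implement-strategies.html)

Flower API reference: [`flwr.simulation.start_simulation`](https://flower.ai/docs/framework/ref-api/flwr.simulation.start_simulation.html)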

### FL with static data


In [None]:
# Loader for the stationary (antenna) training data.
static_data_loader = AISLoader(vessel_types, traj_features, test_size, path=stationary_training_file)
# Scenario name: file stem with "training-data-" removed.
static_scenario_name = Path(stationary_training_file).stem.replace("training-data-", "")

# Evaluation function handed to the strategy.
static_evaluate_fn = get_evaluate_fn(trajectory_classifier_model, static_data_loader, static_scenario_name)

static_strategy = fl.server.strategy.FedAvg(
    evaluate_fn=static_evaluate_fn,
    on_fit_config_fn=fit_round,
    fit_metrics_aggregation_fn=weighted_average,
    evaluate_metrics_aggregation_fn=weighted_average,
    min_available_clients=2,
)

In [None]:
clients_per_round = 3
rounds = 10
client_data_partitions = 2

# flwr client id -> stationary client id
# (presumably the `client` index assigned to the antennas -- confirm in AISLoader)
client_mapping = {str(client_id): client_id for client_id in range(3)}

client_fn = generate_client_fn(
    trajectory_classifier_model,
    client_mapping,
    client_data_partitions,
    static_data_loader,
)

print(f"{datetime.now()} Starting training")
server_config = fl.server.ServerConfig(num_rounds=rounds)
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=clients_per_round,
    config=server_config,
    strategy=static_strategy,
)
print(f"{datetime.now()} Training done")

### FL with mobile data

In [None]:
# Loader for the mobile (vessel) training data.
mobile_data_loader = AISLoader(vessel_types, traj_features, test_size, path=mobile_training_file)
# Scenario name: file stem with "training-data-" removed.
mobile_scenario_name = Path(mobile_training_file).stem.replace("training-data-", "")

# Evaluation function handed to the strategy.
mobile_evaluate_fn = get_evaluate_fn(trajectory_classifier_model, mobile_data_loader, mobile_scenario_name)

mobile_strategy = fl.server.strategy.FedAvg(
    evaluate_fn=mobile_evaluate_fn,
    on_fit_config_fn=fit_round,
    fit_metrics_aggregation_fn=weighted_average,
    evaluate_metrics_aggregation_fn=weighted_average,
    min_available_clients=2,
)

In [None]:
clients_per_round = 4
rounds = 10
client_data_partitions = 3

# MMSIs of the client vessels; the position in the list is the flwr client id.
mobile_client_mmsis = [236111925, 219012959, 235662000, 265737220]
# flwr client id -> MMSI
client_mapping = {str(idx): mmsi for idx, mmsi in enumerate(mobile_client_mmsis)}

client_fn = generate_client_fn(
    trajectory_classifier_model,
    client_mapping,
    client_data_partitions,
    mobile_data_loader,
)

print(f"{datetime.now()} Starting training")
server_config = fl.server.ServerConfig(num_rounds=rounds)
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=clients_per_round,
    config=server_config,
    strategy=mobile_strategy,
)
print(f"{datetime.now()} Training done")