# MobiML federated learning (FL) demo

This notebook trains a vessel-type classifier on AIS trajectories with federated learning, using Flower and MobiML.

In [None]:
import os
import sys
import pickle
import pandas as pd
import geopandas as gpd
import numpy as np
from datetime import datetime, timedelta
from copy import deepcopy
from typing import Dict, List, Tuple
from pathlib import Path
from sklearn.metrics import log_loss
from sklearn.preprocessing import MultiLabelBinarizer

import flwr as fl
from flwr.common import Metrics

import sys
sys.path.append("..")
from mobiml.datasets import AISDK, MOVER_ID, SHIPTYPE
from mobiml.transforms import TripExtractor, TrajectoryAggregator
from mobiml.preprocessing import StationaryClientExtractor
from mobiml.models import SummarizedAISTrajectoryClassifier
from mobiml.models.ais_trajectory_classifier import AISLoader, get_evaluate_fn, fit_round, weighted_average
from mobiml.utils import convert_wgs_to_utm

## Extract stationary client (antenna) data

In [None]:
# Raw AIS archive to load.
path = "./data/aisdk-2018-02.zip"

# Antenna (stationary client) locations as WKT points, x/y = lon/lat (WGS84).
antennas = [
    'Point (11.96524 57.70730)',
    'Point (11.63979 57.71941)',
    'Point (11.78460 57.57255)',
]

# Coverage radius around each antenna, in meters (buffered in a UTM CRS).
antenna_radius_meters = 25_000

In [None]:
# Build one coverage polygon per antenna: project the points to a metric UTM
# CRS so the buffer radius is in meters, buffer, then go back to WGS84.
antenna_points = gpd.GeoSeries.from_wkt(antennas)

# Derive the UTM zone from the first antenna instead of repeating its
# coordinates as literals, so editing `antennas` keeps the CRS in sync.
# NOTE(review): assumes all antennas fall in the same UTM zone as the first.
first_antenna = antenna_points.iloc[0]
epsg_code = convert_wgs_to_utm(first_antenna.x, first_antenna.y)

ids = [{'client': i} for i in range(len(antennas))]
df = pd.DataFrame(ids)
df['geometry'] = antenna_points
gdf = gpd.GeoDataFrame(df, geometry=df.geometry, crs=4326)
gdf = gdf.to_crs(epsg_code)
gdf['geometry'] = gdf.buffer(antenna_radius_meters)

buffered_antennas = gdf.to_crs(4326)
# Overall lon/lat bounding box of all coverage areas, used to load the AIS data.
min_lon, min_lat, max_lon, max_lat = buffered_antennas.geometry.total_bounds

In [None]:
# Directory that receives all intermediate outputs of this notebook.
out_dir = "temp"
if not os.path.isdir(out_dir):
    print(f"{datetime.now()} Creating output directory {out_dir} ...")
    # exist_ok=True avoids a FileExistsError if the directory is created
    # concurrently between the check above and this call.
    os.makedirs(out_dir, exist_ok=True)

In [None]:
print(f"{datetime.now()} Loading data from {path}")
# Bounding box of the buffered antenna areas; presumably AISDK uses it to keep
# only records inside this extent — confirm in AISDK.
bbox = (min_lon, min_lat, max_lon, max_lat)
aisdk = AISDK(path, *bbox)

In [None]:
print(f"{datetime.now()} Extracting client data ...")
# Split the loaded AIS data into per-antenna (stationary client) data.
extractor = StationaryClientExtractor(aisdk)
client_data = extractor.extract(buffered_antennas)

In [None]:
# Write into the output directory created above rather than a hard-coded
# "temp/", so changing `out_dir` cannot point the write at a missing folder.
client_feather_path = f"{out_dir}/ais-antenna.feather"
print(f"{datetime.now()} Writing output to {client_feather_path}")
client_data.to_feather(client_feather_path)

## Prepare training data

In [None]:
# H3 grid resolution handed to TrajectoryAggregator.aggregate_trajs.
h3_resolution: int = 8

In [None]:
def _first_mode(values: pd.Series) -> object:
    """Return the most frequent value of *values*, or NaN if there is none.

    ``pd.Series.mode`` returns every tied value (in sorted order), so using it
    directly as a groupby aggregation yields an array instead of a scalar when
    a vessel reports several equally frequent values, and an empty array when
    all of its values are missing. This picks the first mode deterministically.
    """
    modes = values.mode()
    return modes.iat[0] if len(modes) else np.nan


print(f"{datetime.now()} Loading data from {client_feather_path} ...")
gdf = gpd.read_feather(client_feather_path)
# One row per vessel (indexed by MOVER_ID) with its most common ship type and name.
vessels = gdf.groupby(MOVER_ID)[["ship_type", "Name"]].agg(_first_mode)

In [None]:
print(f"{datetime.now()} Extracting trips ...")
# Gap threshold handed to TripExtractor; presumably tracks are split into
# separate trips at time gaps longer than this — see TripExtractor.get_trips.
max_gap = timedelta(minutes=60)
trajs = TripExtractor(gdf).get_trips(gap_duration=max_gap)

In [None]:
print(f"{datetime.now()} Computing trajectory features ...")
# Summarize each trip into features, joining in the per-vessel attributes.
aggregator = TrajectoryAggregator(trajs, vessels)
trajs = aggregator.aggregate_trajs(h3_resolution)

In [None]:
# Persist the per-vessel attribute table into the notebook's output directory
# (instead of a hard-coded "temp/", which would break if `out_dir` changes).
with open(f"{out_dir}/vessels-stationary.pickle", "wb") as out_file:
    pickle.dump(vessels, out_file)

In [None]:
# Persist the aggregated trajectories (training data for the Flower server
# section) into the notebook's output directory rather than a hard-coded "temp/".
with open(f"{out_dir}/training-data-stationary.pickle", "wb") as out_file:
    pickle.dump(trajs, out_file)

## Start Flower server for federated learning

Reference: Flower "in 30 minutes" tutorial — https://github.com/adap/flower/blob/main/examples/flower-in-30-minutes/tutorial.ipynb

In [None]:
# Fix NumPy's global RNG so repeated runs are reproducible.
np.random.seed(0)

# Training data pickled in the "Prepare training data" section.
data_path = f"{out_dir}/training-data-stationary.pickle"
# Scenario label derived from the file name, e.g. "stationary".
scenario_name = Path(data_path).stem.replace("training-data-", "")

vessel_types = ['Cargo', 'Passenger', 'Tanker']
# Per-trajectory summary features fed to the classifier.
# Alternative feature names that were previously listed here:
#   ['SOG_max', 'SOG_median', 'LON_start', 'LAT_start', 'LON_end', 'LAT_end', 'length']
# 'H3_seq' is another candidate feature (see the note on n_features).
traj_features = ['speed_max', 'speed_median', 'x_start', 'y_start', 'x_end', 'y_end', 'length']
# Derive the model input width from the feature list so the two cannot drift
# apart. NOTE: when 'H3_seq' is used, the width depends on the number of H3
# cells in the sequence (e.g. 1804) and must be set explicitly instead.
n_features = len(traj_features)
# Passed to AISLoader; presumably the fraction of trajectories held out for
# the server-side evaluation function — confirm in AISLoader.
test_size = 0.33

data_loader = AISLoader(vessel_types, traj_features, test_size, path=data_path)

model = SummarizedAISTrajectoryClassifier(vessel_types, n_features)

# Federated averaging. The global model is evaluated centrally via
# evaluate_fn, and client-reported fit/evaluate metrics are combined with
# weighted_average.
strategy = fl.server.strategy.FedAvg(
    # Minimum number of connected clients Flower requires before sampling.
    min_available_clients=2,
    evaluate_fn=get_evaluate_fn(model, data_loader, scenario_name),
    on_fit_config_fn=fit_round,
    evaluate_metrics_aggregation_fn=weighted_average,
    fit_metrics_aggregation_fn=weighted_average,
)

# Listens on all network interfaces, port 8080. This call blocks until all
# configured rounds have completed.
fl.server.start_server(
    server_address="0.0.0.0:8080",
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=10),
)


In [None]:
# Intentionally empty scratch cell (a stray `1+2` debug expression was removed).