In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
import pytorch_lightning as pl
import seaborn as sns

# Description of the target variable (number of pickup)

In [2]:
df = pd.read_csv('dataset/locker_nyc_engineered.csv')
# Convert Date Hour to datetime
df['Date Hour'] = pd.to_datetime(df['Date Hour'])
df['Locker Name'] = "Total Locker"
df.head()

Unnamed: 0,Locker Name,Date Hour,IsIndoor,size_L_delivery,size_M_delivery,size_S_delivery,size_XL_delivery,size_L_withdraw,size_M_withdraw,size_S_withdraw,size_XL_withdraw,IsHoliday,DBSCAN Cluster,KMeans Cluster
0,Total Locker,2024-04-10 18:00:00,1,0.0,1.0,6.0,0.0,0.0,0.0,0.0,0.0,0,0,4
1,Total Locker,2024-04-10 19:00:00,1,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,4
2,Total Locker,2024-04-11 12:00:00,1,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0,0,4
3,Total Locker,2024-04-11 14:00:00,1,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0,0,4
4,Total Locker,2024-04-11 15:00:00,1,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0,0,4


In [3]:
# Cluster Dataframe: Split the records by "KMeans Cluster" feature
clustered_dfs = {cluster: df[df['KMeans Cluster'] == cluster] for cluster in df['KMeans Cluster'].unique()}
clustered_dfs.keys()

dict_keys([np.int64(4), np.int64(3), np.int64(1), np.int64(0), np.int64(2)])

In [4]:
def _resample(df):
    # Group by Locker Name, resample to 3 hours, and take the sum of numerical columns
    df = df.set_index('Date Hour')
    df = df.groupby('Locker Name').resample('3h').sum().drop(columns=["Locker Name"])
    df = df.reset_index()
    df = df.fillna(0)
    return df

def _rolling_window_sum(df):
    # Apply rolling window
    for window in [1, 2]:
        df[f'withdraw_{window}'] = df.groupby('Locker Name')['size_S_withdraw'].transform(lambda x: x.shift(1).rolling(window=window, min_periods=1).sum())
        # Fill NaN values with 0
        df[f'withdraw_{window}'] = df[f'withdraw_{window}'].fillna(0)

    # Apply rolling window
    for window in [1, 2, 8, 16]:
        df[f'delivery_{window}'] = df.groupby('Locker Name')['size_S_delivery'].transform(lambda x: x.rolling(window=window, min_periods=1).sum())

    return df

def _transform_target(df):
    # Proxy inventory by applying rolling window sum of withdraw - delivery
    df[f'inventory'] = df[f'delivery_1'] - df[f'withdraw_1']
    # Apply cumulative sum to get a proxy inventory level
    df[f'inventory'] = df.groupby('Locker Name')[f'inventory'].cumsum()

    # Target variable: proportion_withdraw: size_S_withdraw / inventory
    df['proportion_withdraw'] = df.apply(lambda row: row['size_S_withdraw'] / row['inventory'] if row['inventory'] > 0 else 0, axis=1)
    # df.drop(columns=['size_S_withdraw'], inplace=True)
    return df

def _cyclical_features(df):
    # Create cyclical features for Hour, Day of Week, Month
    # MONTHS_IN_YEAR = 12
    HOURS_IN_DAY = 24
    # DAYS_IN_WEEK = 7
    # QUARTERS_IN_YEAR = 4

    df['hour_sin'] = np.sin(2 * np.pi * df['Hour'] / HOURS_IN_DAY)
    df['hour_cos'] = np.cos(2 * np.pi * df['Hour'] / HOURS_IN_DAY)
    # df['month_sin'] = np.sin(2 * np.pi * df['Month'] / MONTHS_IN_YEAR)
    # df['month_cos'] = np.cos(2 * np.pi * df['Month'] / MONTHS_IN_YEAR)
    # df['day_of_week_sin'] = np.sin(2 * np.pi * df['Day of Week'] / DAYS_IN_WEEK)
    # df['day_of_week_cos'] = np.cos(2 * np.pi * df['Day of Week'] / DAYS_IN_WEEK)
    return df

def preprocess(df):
    # Extract dict Locker Name -> IsIndoor
    locker_indoor_dict = df.set_index('Locker Name')['IsIndoor'].to_dict()
    # Cache IsHoliday per date
    holiday_dict = df.groupby(df['Date Hour'].dt.date)['IsHoliday'].first().to_dict()
    df = df.drop(columns=['IsIndoor'])

    df = _resample(df)

    # Map IsHoliday back to the resampled dataframe
    df['IsHoliday'] = df['Date Hour'].dt.date.map(holiday_dict)

    df = _rolling_window_sum(df)
    df = _transform_target(df)

    # Ignore all size_*_delivery and size_*_withdraw columns except size_S_withdraw
    cols_to_drop = [col for col in df.columns if ('size_' in col and col not in ['size_S_withdraw'])]
    df = df.drop(columns=cols_to_drop)

    # Apply the IsIndoor feature back to the resampled dataframe
    df['IsIndoor'] = df['Locker Name'].map(locker_indoor_dict)
    df['Hour'] = df['Date Hour'].dt.hour
    # df['Day of Week'] = df['Date Hour'].dt.dayofweek
    df['Month'] = df['Date Hour'].dt.month
    df['IsPeakHour'] = df['Date Hour'].apply(lambda x: True if 17 <= x.hour <= 20 else False)
    df['IsWeekend'] = df['Date Hour'].dt.dayofweek.apply(lambda x: True if x >= 5 else False)

    # Create a time index as incremental integer
    df['time_idx'] = df.groupby('Locker Name')['Date Hour'].transform(lambda x: (x - x.min()).dt.total_seconds() // 3600 / 3).astype(int)

    df = _cyclical_features(df)

    df = df.drop(columns=['Hour', 'Month', 'IsHoliday'])

    # Remove the cluster columns
    df = df.drop(columns=['DBSCAN Cluster', 'KMeans Cluster'])
    return df

# Apply preprocessing to each clustered dataframe and concatenate them back
clustered_dfs = {cluster: preprocess(clustered_dfs[cluster]) for cluster in clustered_dfs}

In [5]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from sklearn.preprocessing import StandardScaler
from torchmetrics import MeanSquaredError
import math

# Custom Dataset for time series data from DataFrame
class TimeSeriesDataset(Dataset):
    def __init__(self, df, encoder_length, decoder_length, target_col, feature_cols, from_train_scaler=None):
        self.encoder_length = encoder_length
        self.decoder_length = decoder_length
        self.target_col = target_col
        self.feature_cols = feature_cols
        
        # Scale features
        if from_train_scaler:
            self.scaler = from_train_scaler
            self.features = self.scaler.transform(df[feature_cols].values)
        else:
             self.scaler = StandardScaler()
             self.features = self.scaler.fit_transform(df[feature_cols].values)
        self.indices = df['time_idx'].values
        self.dates = df['Date Hour'].values
        self.targets = df[target_col].values
        
    def __len__(self):
        return len(self.features) - self.encoder_length - self.decoder_length

    def __getitem__(self, idx):
        # Get sequence of features and target
        x = self.features[idx:idx+self.encoder_length]
        y = self.targets[idx+self.encoder_length:idx+self.encoder_length+self.decoder_length]
        index = self.indices[idx+self.encoder_length:idx+self.encoder_length+self.decoder_length]
        
        return (torch.tensor(x, dtype=torch.float32), 
                torch.tensor(y, dtype=torch.float32),
                torch.tensor(index, dtype=torch.int64)
        )

class DecayMSE(nn.Module):
    def __init__(self, lambda_decay=0.1):
        super(DecayMSE, self).__init__()
        self.lambda_decay = lambda_decay
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, y_pred, y_true):
        # Calculate the squared errors
        squared_errors = self.mse(y_pred, y_true)
        H = squared_errors.size(1)
        # Create decay weights: w_h = exp(-λ * h) for h = 1 to H
        decay_factors = torch.tensor([math.exp(-self.lambda_decay * i) for i in range(H)], device=squared_errors.device)
        # Apply the decay factors to the squared errors
        weighted_errors = squared_errors * decay_factors
        # Return the mean of the weighted errors
        return weighted_errors.mean()

# Plain PyTorch LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, input_size, output_horizon, hidden_size, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size, output_horizon)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        out = self.fc(lstm_out[:, -1, :])
        out = self.sigmoid(out)
        return out

# Function to create DataLoader from DataFrame
def create_dataloader(df, encoder_length, decoder_length, target_col, feature_cols, batch_size=32, train_val_test_split=(0.7, 0.2, 0.1), num_workers=0):
    train_split, val_split, test_split = train_val_test_split

    # Split data into train and validation
    train_size = int(train_split * len(df))
    train_df = df[:train_size]
    val_df = df[train_size:train_size + int(val_split * len(df))]
    test_df = df[train_size + int(val_split * len(df)):]
    print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")
    # Create datasets
    train_dataset = TimeSeriesDataset(train_df, encoder_length, decoder_length, target_col, feature_cols)
    val_dataset = TimeSeriesDataset(val_df, encoder_length, decoder_length, target_col, feature_cols, from_train_scaler=train_dataset.scaler)
    test_dataset = TimeSeriesDataset(test_df, encoder_length, decoder_length, target_col, feature_cols, from_train_scaler=train_dataset.scaler)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader, train_df, val_df, test_df

In [6]:
df = preprocess(df)

In [7]:
ENCODER_LENGTH = 16
DECODER_LENGTH = 4
TARGET_COL = 'proportion_withdraw'
FEATURE_COLS = [col for col in df.columns if col not in ['Date Hour', 'Locker Name', 'proportion_withdraw', 'time_idx', 'size_S_withdraw']]
BATCH_SIZE = 128
INPUT_SIZE = len(FEATURE_COLS)
HIDDEN_SIZE = 256
NUM_LAYERS = 8
WORKERS = 0

train_loader, val_loader, test_loader, train_df, val_df, test_df = create_dataloader(
    df,
    encoder_length=ENCODER_LENGTH,
    decoder_length=DECODER_LENGTH,
    target_col=TARGET_COL,
    feature_cols=FEATURE_COLS,
    batch_size=BATCH_SIZE,
    train_val_test_split=(0.7, 0.2, 0.1),
    num_workers=WORKERS
)

Train size: 2516, Validation size: 719, Test size: 360


In [8]:
df.head()

Unnamed: 0,Locker Name,Date Hour,size_S_withdraw,withdraw_1,withdraw_2,delivery_1,delivery_2,delivery_8,delivery_16,inventory,proportion_withdraw,IsIndoor,IsPeakHour,IsWeekend,time_idx,hour_sin,hour_cos
0,Total Locker,2024-04-10 12:00:00,2.0,0.0,0.0,2.0,2.0,2.0,2.0,2.0,1.0,0,False,False,0,1.224647e-16,-1.0
1,Total Locker,2024-04-10 15:00:00,0.0,2.0,2.0,0.0,2.0,2.0,2.0,0.0,0.0,0,False,False,1,-0.7071068,-0.7071068
2,Total Locker,2024-04-10 18:00:00,1.0,0.0,2.0,10.0,10.0,12.0,12.0,10.0,0.1,0,True,False,2,-1.0,-1.83697e-16
3,Total Locker,2024-04-10 21:00:00,0.0,1.0,1.0,0.0,10.0,12.0,12.0,9.0,0.0,0,False,False,3,-0.7071068,0.7071068
4,Total Locker,2024-04-11 00:00:00,0.0,0.0,1.0,0.0,0.0,12.0,12.0,9.0,0.0,0,False,False,4,0.0,1.0


In [9]:
model = LSTMModel(
    input_size=INPUT_SIZE,
    output_horizon=DECODER_LENGTH,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
)
print(f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters")

Model has 3961860 trainable parameters


In [10]:
# Load the first batch to check for correct matching of input and output shapes
for batch in train_loader:
    x, y, idx = batch
    print(f"Input shape: {x.shape}, Target shape: {y.shape}")
    break

Input shape: torch.Size([128, 16, 12]), Target shape: torch.Size([128, 4])


In [11]:
# from pytorch_lightning.callbacks import EarlyStopping

# trainer = pl.Trainer(
#     accelerator="gpu",
#     devices=1,
#     min_epochs=10,
#     max_epochs=500,
#     callbacks=[EarlyStopping(monitor="val_loss", patience=5, mode="min")],
# )

In [12]:
# # Print the number of batches in each loader
# print(f"Number of training batches: {len(train_loader)}")
# print(f"Number of validation batches: {len(val_loader)}")

In [13]:
# trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [14]:
# trainer.test(model, test_loader)

In [15]:
# # Enhanced Residual Analysis: Regular and Studentized Residuals

# import scipy.stats as stats
# import warnings
# warnings.filterwarnings('ignore')

# # Calculate residuals
# residuals = evaluation_df['true_values'] - evaluation_df['predictions']

# # Calculate studentized residuals
# # Studentized residuals = residuals / sqrt(MSE * (1 - h_ii))
# # where h_ii is the leverage (diagonal of hat matrix)

# # For simplicity, we'll use a rolling standard deviation approach
# # This approximates studentized residuals for time series data
# residual_std = residuals.rolling(window=20, center=True, min_periods=5).std()
# studentized_residuals = residuals / (residual_std + 1e-8)  # Add small epsilon to avoid division by zero

# # Alternative: Use overall standard deviation for studentized residuals
# overall_std = residuals.std()
# studentized_residuals_simple = residuals / overall_std

# # Create comprehensive residual analysis plots
# fig, axes = plt.subplots(3, 2, figsize=(20, 15))
# fig.suptitle('Comprehensive Residual Analysis', fontsize=16, fontweight='bold')

# # 1. Regular Residuals Time Series
# axes[0, 0].plot(evaluation_df['time_idx'], residuals, 'b-', alpha=0.7, linewidth=1)
# axes[0, 0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
# axes[0, 0].set_title('Regular Residuals Over Time', fontweight='bold')
# axes[0, 0].set_xlabel('Time Index')
# axes[0, 0].set_ylabel('Residuals')
# axes[0, 0].grid(True, alpha=0.3)

# # 2. Regular Residuals Scatter Plot
# axes[0, 1].scatter(evaluation_df['time_idx'], residuals, alpha=0.6, s=20, color='blue')
# axes[0, 1].axhline(y=0, color='red', linestyle='--', alpha=0.8)
# axes[0, 1].axhline(y=2*np.std(residuals), color='orange', linestyle='--', alpha=0.6, label=f'+2σ ({2*np.std(residuals):.2f})')
# axes[0, 1].axhline(y=-2*np.std(residuals), color='orange', linestyle='--', alpha=0.6, label=f'-2σ ({-2*np.std(residuals):.2f})')
# axes[0, 1].set_title('Regular Residuals Scatter Plot', fontweight='bold')
# axes[0, 1].set_xlabel('Time Index')
# axes[0, 1].set_ylabel('Residuals')
# axes[0, 1].legend()
# axes[0, 1].grid(True, alpha=0.3)

# # 3. Studentized Residuals (Rolling Std) Time Series
# axes[1, 0].plot(evaluation_df['time_idx'], studentized_residuals, 'g-', alpha=0.7, linewidth=1)
# axes[1, 0].axhline(y=0, color='red', linestyle='--', alpha=0.8)
# axes[1, 0].axhline(y=2, color='orange', linestyle='--', alpha=0.6, label='+2')
# axes[1, 0].axhline(y=-2, color='orange', linestyle='--', alpha=0.6, label='-2')
# axes[1, 0].axhline(y=3, color='red', linestyle='--', alpha=0.4, label='+3')
# axes[1, 0].axhline(y=-3, color='red', linestyle='--', alpha=0.4, label='-3')
# axes[1, 0].set_title('Studentized Residuals (Rolling Std) Over Time', fontweight='bold')
# axes[1, 0].set_xlabel('Time Index')
# axes[1, 0].set_ylabel('Studentized Residuals')
# axes[1, 0].legend()
# axes[1, 0].grid(True, alpha=0.3)

# # 4. Studentized Residuals (Rolling Std) Scatter Plot
# axes[1, 1].scatter(evaluation_df['time_idx'], studentized_residuals, alpha=0.6, s=20, color='green')
# axes[1, 1].axhline(y=0, color='red', linestyle='--', alpha=0.8)
# axes[1, 1].axhline(y=2, color='orange', linestyle='--', alpha=0.6, label='±2 (Outlier threshold)')
# axes[1, 1].axhline(y=-2, color='orange', linestyle='--', alpha=0.6)
# axes[1, 1].axhline(y=3, color='red', linestyle='--', alpha=0.4, label='±3 (Extreme outlier)')
# axes[1, 1].axhline(y=-3, color='red', linestyle='--', alpha=0.4)
# axes[1, 1].set_title('Studentized Residuals (Rolling Std) Scatter Plot', fontweight='bold')
# axes[1, 1].set_xlabel('Time Index')
# axes[1, 1].set_ylabel('Studentized Residuals')
# axes[1, 1].legend()
# axes[1, 1].grid(True, alpha=0.3)

# # 5. Residuals Distribution Histogram
# axes[2, 0].hist(residuals, bins=50, alpha=0.7, color='blue', edgecolor='black')
# axes[2, 0].axvline(x=0, color='red', linestyle='--', alpha=0.8)
# axes[2, 0].axvline(x=np.mean(residuals), color='green', linestyle='-', alpha=0.8, label=f'Mean: {np.mean(residuals):.3f}')
# axes[2, 0].set_title('Residuals Distribution', fontweight='bold')
# axes[2, 0].set_xlabel('Residual Value')
# axes[2, 0].set_ylabel('Frequency')
# axes[2, 0].legend()
# axes[2, 0].grid(True, alpha=0.3)

# # 6. Q-Q Plot for Normality Check
# stats.probplot(residuals, dist="norm", plot=axes[2, 1])
# axes[2, 1].set_title('Q-Q Plot (Normality Check)', fontweight='bold')
# axes[2, 1].grid(True, alpha=0.3)

# plt.tight_layout()
# plt.show()

# # Statistical Summary
# print("=" * 60)
# print("RESIDUAL ANALYSIS SUMMARY")
# print("=" * 60)
# print(f"Number of observations: {len(residuals)}")
# print(f"Mean of residuals: {np.mean(residuals):.4f}")
# print(f"Standard deviation of residuals: {np.std(residuals):.4f}")
# print(f"Min residual: {np.min(residuals):.4f}")
# print(f"Max residual: {np.max(residuals):.4f}")
# print(f"Skewness: {stats.skew(residuals):.4f}")
# print(f"Kurtosis: {stats.kurtosis(residuals):.4f}")

# # Outlier detection using studentized residuals
# outliers_2sigma = np.abs(studentized_residuals) > 2
# outliers_3sigma = np.abs(studentized_residuals) > 3

# print(f"\nOutlier Analysis (Studentized Residuals):")
# print(f"Points beyond ±2σ: {np.sum(outliers_2sigma)} ({100*np.mean(outliers_2sigma):.1f}%)")
# print(f"Points beyond ±3σ: {np.sum(outliers_3sigma)} ({100*np.mean(outliers_3sigma):.1f}%)")

# # Identify potential outlier time indices
# if np.sum(outliers_3sigma) > 0:
#     outlier_indices = evaluation_df['time_idx'][outliers_3sigma].values
#     print(f"\nExtreme outlier time indices: {outlier_indices[:10]}{'...' if len(outlier_indices) > 10 else ''}")

# # Autocorrelation of residuals (for time series)
# from statsmodels.graphics.tsaplots import plot_acf
# fig, ax = plt.subplots(figsize=(12, 6))
# plot_acf(residuals, lags=50, ax=ax, alpha=0.05)
# ax.set_title('Autocorrelation of Residuals', fontweight='bold')
# ax.grid(True, alpha=0.3)
# plt.show()

# print("\nAutocorrelation analysis helps identify if residuals have temporal patterns.")
# print("Significant autocorrelation indicates the model may not capture all temporal dependencies.")

# Federated Learning Simulation Setup
We'll simulate cross-silo FL with each KMeans cluster as a separate client. We'll use Flower for orchestration and a simple FedAvg strategy. If the LightningModule is incompatible, we'll reuse the same model architecture in plain PyTorch for the training loop.

In [16]:
# Install and import Flower (if not already installed)
from flwr.client import NumPyClient, Client, ClientApp
from flwr.common import Metrics, Context
from flwr.server.strategy import FedAvg
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.simulation import run_simulation

import torch
from torch.utils.data import DataLoader
import numpy as np
from typing import Dict, List, Tuple

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# DEVICE = "cpu" # Force CPU for testing
print(f"Using device: {DEVICE}")
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)

# Reuse model and dataset utilities from earlier cells
# LSTMModel, DecayMSE, TimeSeriesDataset, create_dataloader are assumed to be defined


Using device: cuda


In [17]:
# Build per-client dataloaders from clustered_dfs
ENCODER_LENGTH = 16
DECODER_LENGTH = 4
TARGET_COL = 'proportion_withdraw'
FEATURE_COLS = [col for col in df.columns if col not in ['Date Hour', 'Locker Name', 'proportion_withdraw', 'time_idx', 'size_S_withdraw']]
BATCH_SIZE = 128
WORKERS = 0  # set to 0 for portability in notebooks

client_splits = {}
for cluster, df_client in clustered_dfs.items():
    # Temporal split: 70/20/10 per client
    n = len(df_client)
    train_size = int(0.7 * n)
    val_size = int(0.2 * n)
    train_df = df_client.iloc[:train_size]
    val_df = df_client.iloc[train_size: train_size + val_size]
    test_df = df_client.iloc[train_size + val_size:]

    # Use one scaler per client, fitted on train only; share to val/test
    train_dataset = TimeSeriesDataset(train_df, ENCODER_LENGTH, DECODER_LENGTH, TARGET_COL, FEATURE_COLS)
    val_dataset = TimeSeriesDataset(val_df, ENCODER_LENGTH, DECODER_LENGTH, TARGET_COL, FEATURE_COLS, from_train_scaler=train_dataset.scaler)
    test_dataset = TimeSeriesDataset(test_df, ENCODER_LENGTH, DECODER_LENGTH, TARGET_COL, FEATURE_COLS, from_train_scaler=train_dataset.scaler)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)

    client_splits[cluster] = {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'train_len': len(train_loader),
        'val_len': len(val_loader),
        'test_len': len(test_loader),
    }

print(f"Prepared {len(client_splits)} clients from clusters: {list(client_splits.keys())}")

Prepared 5 clients from clusters: [np.int64(4), np.int64(3), np.int64(1), np.int64(0), np.int64(2)]


In [18]:
# Plain PyTorch training/evaluation loops
import torch.nn as nn

def train(
        model: nn.Module, 
        loader: DataLoader, 
        epochs: int, 
        device) -> float:
    
    criterion = DecayMSE(lambda_decay=0.95)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion.to(device)
    model.to(device)
    model.train()
    for epoch in range(epochs):
        total, epoch_loss = 0.0, 0.0
        for x, y, _ in loader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_hat = model(x)
            assert y_hat.shape == y.shape, f"y_hat {y_hat.shape} vs y {y.shape}"
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            # Metric
            epoch_loss += loss.item()
            total += x.size(0)
        epoch_loss /= len(loader.dataset)
        # print(f"Epoch: {epoch+1}, Loss: {epoch_loss:.6f}, Samples: {total}")

def evaluate(model: nn.Module, loader: DataLoader, device) -> Tuple[float, float]:
    total_loss = 0.0
    total_mse = 0.0
    n_batches = 0
    criterion = DecayMSE(lambda_decay=0.95)
    criterion.to(device)
    model.to(device)
    model.eval()
    mse = nn.MSELoss()
    with torch.inference_mode():
        for x, y, _ in loader:
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            total_loss += loss.item()
            total_mse += mse(y_hat, y).item()
            n_batches += 1
    avg_loss = total_loss / max(1, n_batches)
    avg_mse = total_mse / max(1, n_batches)
    return avg_loss, avg_mse

trainloader, valloader, testloader = client_splits[0]['train_loader'], client_splits[0]['val_loader'], client_splits[0]['test_loader']
net = LSTMModel(
    input_size=INPUT_SIZE,
    output_horizon=DECODER_LENGTH,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
)

for epoch in range(5):
    train(net, trainloader, 3, DEVICE)
    loss, mse = evaluate(net, valloader, DEVICE)
    print(f"Epoch {epoch+1}: validation loss {loss}, mse {mse}")

loss, mse = evaluate(net, testloader, DEVICE)
print(f"Final test set performance:\n\tloss {loss}\n\tmse {mse}")

Epoch 1: validation loss 0.03439931385219097, mse 0.08461722855766614
Epoch 2: validation loss 0.03231489968796571, mse 0.08159168809652328
Epoch 3: validation loss 0.031362141172091164, mse 0.08036963765819867
Epoch 4: validation loss 0.03066612035036087, mse 0.07993355020880699
Epoch 5: validation loss 0.03094364112863938, mse 0.08020769680539767
Final test set performance:
	loss 0.02644224651157856
	mse 0.06955851862827937


In [19]:
# Flower NumPyClient wrapping local training
from collections import OrderedDict

class FLClient(NumPyClient):
    def __init__(self, net: nn.Module, loaders: Dict[str, DataLoader], epochs: int, lr: float = 1e-4):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.epochs = epochs
        self.model = net
        self.loaders = loaders
        self.model

    def get_parameters(self, config: Dict[str, str] = None) -> List[np.ndarray]:
        params = [val.cpu().numpy() for _, val in self.model.state_dict().items()]
        return params

    def set_parameters(self, parameters: List[np.ndarray]):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(self.model, self.loaders['train_loader'], self.epochs, self.device)
        return self.get_parameters({}), self.loaders['train_len'], {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, mse = evaluate(self.model, self.loaders['test_loader'], self.device)
        num_examples = self.loaders['test_len']
        return float(loss), num_examples, {"mse": float(mse)}

In [None]:
# Define model constructor compatible with FL clients
INPUT_SIZE = len(FEATURE_COLS)
HIDDEN_SIZE = 256
NUM_LAYERS = 8
NUM_ROUNDS = 7
cluster_ids = list(sorted(client_splits.keys()))
NUM_CLIENTS = len(cluster_ids)


def make_model():
    m = LSTMModel(
        input_size=INPUT_SIZE,
        output_horizon=DECODER_LENGTH,
        hidden_size=HIDDEN_SIZE,
        num_layers=NUM_LAYERS
    )
    return m

# Map cluster index to client id list for deterministic ordering

def client_fn(context: Context) -> Client:
    partition_id = context.node_config["partition-id"]
    net = make_model()
    net.to(DEVICE)
    cluster = cluster_ids[partition_id]
    loaders = client_splits[cluster]
    return FLClient(net, loaders, epochs=3).to_client()

client_app = ClientApp(client_fn=client_fn)

In [21]:

strategy = FedAvg(
    fraction_fit=1.0,  # use all clients each round (few clusters)
    fraction_evaluate=1.0,
    min_fit_clients=2,
    min_evaluate_clients=5,
    min_available_clients=len(cluster_ids),
)
# Server config
def server_fn(context: Context) -> ServerAppComponents:
    # Strategy and simulation config

    config = ServerConfig(num_rounds=NUM_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=config)

server = ServerApp(server_fn=server_fn)

In [22]:

backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

# When running on GPU, assign an entire GPU for each client
if DEVICE == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}
    print("Using GPU for clients")


# Start simulation
run_simulation( 
    server_app=server,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=5, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client


Using GPU for clients


[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)
[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)
[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)
[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)
[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strate

In [23]:
# Aggregate global parameters from the last round and evaluate on concatenated test set
# Note: Flower simulation returns metrics in `history`; for simplicity, we'll re-run a global model
# initialized and averaged by FedAvg inside the server by querying one finished client.

# Build a fresh global model and set to parameters from client 0 after final round
# (In real setups, pull from strategy weights; Flower 1.x simplifies via client evaluation path.)

# For demo, request parameters from one client by re-instantiating and evaluating with averaged server params is non-trivial.
# Instead, we can simulate global evaluation by averaging client weights explicitly if needed.

# Optional: Evaluate a locally trained model (non-FL) to compare baselines
baseline_model = make_model().to(DEVICE)
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-4)
criterion = DecayMSE(lambda_decay=0.95)

# Quick single-epoch baseline on concatenated train loaders (not required, but useful)
# Concatenate all client train loaders into one iterable
concat_batches = []
for cid in cluster_ids:
    for batch in client_splits[cid]['train_loader']:
        concat_batches.append(batch)

baseline_model.train()
for x, y, _ in concat_batches:
    x = x.to(DEVICE)
    y = y.to(DEVICE)
    optimizer.zero_grad()
    y_hat = baseline_model(x)
    loss = criterion(y_hat, y)
    loss.backward()
    optimizer.step()

# Evaluate on concatenated test loaders
test_loader_all = []
for cid in cluster_ids:
    test_loader_all.append(client_splits[cid]['test_loader'])

# Simple evaluation loop over all test loaders
@torch.no_grad()
def eval_all_loaders(model, loaders: List[DataLoader]):
    model.eval()
    mse = nn.MSELoss()
    total_loss, total_mse, n_batches = 0.0, 0.0, 0
    for loader in loaders:
        for x, y, _ in loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            y_hat = model(x)
            total_loss += criterion(y_hat, y).item()
            total_mse += mse(y_hat, y).item()
            n_batches += 1
    return total_loss / max(1, n_batches), total_mse / max(1, n_batches)

b_loss, b_mse = eval_all_loaders(baseline_model, test_loader_all)
print(f"Baseline single-epoch model — loss: {b_loss:.6f}, mse: {b_mse:.6f}")

Baseline single-epoch model — loss: 0.026127, mse: 0.065338
