<h1>Benchmarking Catastrophic Forgetting Mitigation Methods in Federated Time Series Forecasting </h1> 

Dataset link: https://archive.ics.uci.edu/dataset/501/beijing+multi+site+air+quality+data

How to run:

pip install -r requirements.txt

Run each of the steps in the following order: 

1. Imports & Global Config (modify the hyperparameters)

2. Data Loading, Basic Cleaning, Feature Engineering

3. Splits, Global Robust Normalization, Lagged Samples

4. Dataset & DataLoaders

5. Model & Utilities (Base training or loading)

6. Train Offline Base Model

7. Initialize Replay Buffers & Fisher

8. Continual Online learning function

9. Run Ablation & Evaluate all the proposed methods (Naive, Replay, KD, EWC, O-EWC, SI)

   

<h2> 1. Imports & Global Config </h2> 

In [26]:
# ------------------------------------------------------------
# Benchmarking Catastrophic Forgetting in Federated TS Forecasting
# CONFIG + PRACTICAL TIPS FOR REPRO USERS
# ------------------------------------------------------------

import os
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt
plt.rcParams.update({"figure.figsize": (12, 6), "axes.grid": True})

# ---- Device & deterministic seeds ----
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def set_all_seeds(seed: int = 42) -> None:
    import random
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_all_seeds(42)

# ---- Paths ----
# TIP: We intentionally do NOT commit raw data. Put CSVs under this folder.
# If you change repo layout, update this path accordingly.
DATASET_DIR = Path("../dataset/AirQuality")

# TIP: Expected 12 station files (UCI dataset). If you subset to run faster,
# you can comment out a few lines below; the pipeline adapts automatically.
DATASET_FILES = [
    "Aotizhongxin.csv","Changping.csv","Dingling.csv","Dongsi.csv","Guanyuan.csv","Gucheng.csv",
    "Huairou.csv","Nongzhanguan.csv","Shunyi.csv","Tiantan.csv","Wanliu.csv","Wanshouxigong.csv"
]

# ------------------- Forecasting Task & Shapes -----------------
# TARGET_COL: what we predict. Options commonly tested: "WSPM", "PM2.5", "TEMP".
# TIP: If you switch TARGET_COL, you do NOT need to change features below;
# the pipeline includes target in the model input by default.
TARGET_COL = "WSPM"   # e.g., set to "PM2.5" or "TEMP" to reproduce other tasks

# N_LAGS: input window length (hours). PRED_LEN: how many steps ahead (hours) to predict.
# • Larger N_LAGS gives the model more context but increases memory/compute.
# • Larger PRED_LEN makes the task harder (multi-step horizon).
N_LAGS   = 12
PRED_LEN = 6

# SPLIT_RATIO within each continual task: 80% train / 20% test by default.
# TIP: Keep this ratio fixed when comparing methods; changing it shifts metrics.
SPLIT_RATIO = 0.8      

# Batch size: set lower on CPU or memory-constrained GPUs.
BATCH_SIZE  = 32

# ------------------- Base Federated Pretraining ----------------
# Base FL pretraining stabilizes the initial model before continual learning.
# • NUM_ROUNDS_BASE: communication rounds over base data (increase for better base).
# • LR_BASE, LOCAL_EPOCHS_BASE: local optimization settings per client.
# TIP: If you are doing a quick smoke run, reduce NUM_ROUNDS_BASE (e.g., 10–20).
NUM_ROUNDS_BASE  = 200     # paper-grade: can be 200–500+
LR_BASE          = 1e-5
LOCAL_EPOCHS_BASE = 1

# --------------------- Continual Learning Loop -----------------
# NUM_ROUNDS_CL: rounds per task during continual learning.
# LR: learning rate used during CL (can differ from base LR).
# TIP: For fast runs, set NUM_ROUNDS_CL=5–10; for paper fidelity, use 20–30+.
NUM_ROUNDS_CL = 25
LOCAL_EPOCHS  = 1
LR            = 1e-5

# ----------------------- Method Coefficients -------------------
# These are regularization/auxiliary-loss weights. Defaults reflect paper tuning.
# • SI_COEFF: Synaptic Intelligence λ
# • KD_COEFF: Distillation strength (if KD is enabled)
# • EWC_COEFF: Classic EWC penalty weight
# • ONLINE_EWC_COEFF: Online EWC (moving Fisher) penalty weight
# • REPLAY_COEFF: Weight on replay loss when buffers are used
# • REPLAY_RATIO: Fraction of task samples selected into buffer via KMeans
# • BUFFER_CAPACITY: Max stored (x,y) sequences per client
#
# Tuning guidance:
#  - If training diverges, lower LR and/or the largest active coefficient.
#  - If AF (forgetting) is high, try increasing EWC/SI or enabling replay.
#  - If AP (plasticity) is too low (model underfits new tasks), reduce regularization.
SI_COEFF          = 14
KD_COEFF          = 120
EWC_COEFF         = 1e8
ONLINE_EWC_COEFF  = 1.3e6
REPLAY_COEFF      = 0.8
REPLAY_RATIO      = 0.15
BUFFER_CAPACITY   = 100000


# -------------------------- Practical Notes --------------------
# • GPU vs CPU: All code runs on CPU but will be slow. If a GPU is present, it is used automatically.
# • Memory: KMeans for replay runs on flattened sequences; REPLAY_RATIO scales time/memory.
#   Replay is only applied in "replay" mode; choose "Naive"/"kd"/"ewc" to skip it. With
#   REPLAY_RATIO=0 the base buffer still keeps one sample per client; BUFFER_CAPACITY=0 keeps none.
# • Repro table: Our evaluation matches the paper's "legacy" protocol (P_{N,1..N-1} for AvgPerf,
#   AF via P_{N,j}-P_{j,j}, AP via diag mean). For "strict" CL metrics with best-past baseline,
#   see the alternative evaluator in comments.
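
The "legacy" protocol named in the practical notes can be sketched on a small performance matrix. This is only an illustration, assuming `P[i][j]` holds the score of the snapshot taken after task `i+1`, evaluated on task `j+1`. The name `legacy_cl_metrics`, the averaging of AF over the earlier tasks, and the toy numbers are assumptions, not values from the paper.

```python
from typing import Dict, List

def legacy_cl_metrics(P: List[List[float]]) -> Dict[str, float]:
    """Compute AvgPerf, AF and AP from an N x N matrix P.

    Assumption: P[i][j] is the score of the snapshot after task i+1
    evaluated on task j+1 (indices are 0-based in the code).
    """
    n = len(P)
    if n < 2:
        raise ValueError("need at least two tasks")
    last = P[n - 1]
    # AvgPerf: final snapshot averaged over the earlier tasks 1..N-1
    avg_perf = sum(last[j] for j in range(n - 1)) / (n - 1)
    # AF: final score minus the score right after learning, P[N][j] - P[j][j]
    af = sum(last[j] - P[j][j] for j in range(n - 1)) / (n - 1)
    # AP: mean of the diagonal (score on each task right after learning it)
    ap = sum(P[j][j] for j in range(n)) / n
    return {"AvgPerf": avg_perf, "AF": af, "AP": ap}

# Toy 3-task example with error-like scores (lower is better)
P_toy = [
    [0.10, 0.30, 0.40],
    [0.15, 0.12, 0.35],
    [0.20, 0.18, 0.11],
]
metrics = legacy_cl_metrics(P_toy)
print(metrics)
```

With error-like scores such as MSE, a positive AF here means the error on earlier tasks grew after later tasks were learned.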

<h2>2. Data Loading, Basic Cleaning, Feature Engineering</h2>

In [8]:
# 2) Load raw CSVs ‚Üí dict of per-station DataFrames
# -------------------------------------------------
# Tips for reusers:
# • We don't ship the data. Ensure DATASET_DIR points to the folder holding the 12 CSVs.
# • If you subset stations for a quick run, just remove their filenames from DATASET_FILES.
# • All timestamps are interpreted as naive local time; we sort by index immediately.

def load_clients(dataset_dir: Path, files: List[str]) -> Dict[str, pd.DataFrame]:
    clients = {}
    for f in files:
        df = pd.read_csv(dataset_dir / f)
        df["time"] = pd.to_datetime(df[["year","month","day","hour"]])
        df = df.set_index("time").sort_index()
        clients[f.replace(".csv","")] = df
    return clients

client_dfs = load_clients(DATASET_DIR, DATASET_FILES)

# -------------------------------------------------
# Missing-value handling
# -------------------------------------------------
# Strategy:
# • Numeric columns: time interpolation (linear in time) with both-sided fill at edges.
# • Categorical 'wd' (wind direction): forward/backward fill.
# • Finally, drop any row still containing NaNs (rare after the two steps).
# Why this matters:
# • Interpolated gaps keep temporal coherence and avoid data leakage from future tasks.
# • Strict drop at the end prevents NaNs from breaking metrics later.

def smart_imputation(df: pd.DataFrame) -> pd.DataFrame:
    out = df.copy()
    cont_cols = out.select_dtypes(include=[np.number]).columns
    out[cont_cols] = out[cont_cols].interpolate(method="time", limit_direction="both")
    if "wd" in out.columns:
        out["wd"] = out["wd"].ffill().bfill()
    out = out.dropna()
    return out

client_dfs = {k: smart_imputation(v) for k, v in client_dfs.items()}

# -------------------------------------------------
# Wind direction encoding (circular features)
# -------------------------------------------------
# We map 16-point compass directions to angles (deg), then to sin/cos.
# Unknown tokens (if any) map to angle 0, i.e. they are encoded like 'N', to avoid NaNs.

WD2ANG = {'N':0,'NNE':22.5,'NE':45,'ENE':67.5,'E':90,'ESE':112.5,'SE':135,'SSE':157.5,
          'S':180,'SSW':202.5,'SW':225,'WSW':247.5,'W':270,'WNW':292.5,'NW':315,'NNW':337.5}

def encode_wind_direction(df: pd.DataFrame) -> pd.DataFrame:
    out = df.copy()
    if "wd" in out.columns:
        ang = out["wd"].map(WD2ANG).fillna(0.0)
        rad = np.deg2rad(ang)
        out["wd_sin"] = np.sin(rad); out["wd_cos"] = np.cos(rad)
    return out

client_dfs = {k: encode_wind_direction(v) for k, v in client_dfs.items()}

# -------------------------------------------------
# Cyclical time features (hour/month/day)
# -------------------------------------------------
# Tip:
# • We retain the original integer columns elsewhere for potential ablations.
# • If your downstream only uses the encodings, that's fine; keeping both is harmless.

def add_time_cycles(df: pd.DataFrame) -> pd.DataFrame:
    out = df.copy()
    if all(c in out.columns for c in ["hour","month","day"]):
        out["hour_sin"]  = np.sin(2*np.pi*out["hour"]/24)
        out["hour_cos"]  = np.cos(2*np.pi*out["hour"]/24)
        out["month_sin"] = np.sin(2*np.pi*out["month"]/12)
        out["month_cos"] = np.cos(2*np.pi*out["month"]/12)
        out["day_sin"]   = np.sin(2*np.pi*out["day"]/31)
        out["day_cos"]   = np.cos(2*np.pi*out["day"]/31)
    return out

client_dfs = {k: add_time_cycles(v) for k, v in client_dfs.items()}


# -------------------------------------------------
# Final feature set
# -------------------------------------------------
# Include 'year' to allow robust global [1,99] percentile normalization per feature.
# Tip for custom tasks:
# • You can change TARGET_COL elsewhere (e.g., "PM2.5") without modifying this list;
#   the modeling code handles inclusion/exclusion of target correctly.

FEATURES = [
    "year",
    "PM2.5","PM10","SO2","NO2","CO","O3",
    "TEMP","PRES","DEWP","RAIN","WSPM",
    "wd_sin","wd_cos",
    "hour_sin","hour_cos","month_sin","month_cos","day_sin","day_cos",
]
client_dfs = {k: v[[c for c in FEATURES if c in v.columns]].copy() for k, v in client_dfs.items()}
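
To see why the sin/cos encoding in `add_time_cycles` is useful, the sketch below compares hours as points on the unit circle. It uses only the standard library; `encode_cycle` and `circle_distance` are hypothetical helpers that mirror the formula above and are not part of the notebook.

```python
import math
from typing import Tuple

def encode_cycle(value: float, period: float) -> Tuple[float, float]:
    """Map a periodic value to (sin, cos) on the unit circle."""
    angle = 2 * math.pi * value / period
    return math.sin(angle), math.cos(angle)

def circle_distance(a: float, b: float, period: float) -> float:
    """Euclidean distance between two encoded values."""
    sa, ca = encode_cycle(a, period)
    sb, cb = encode_cycle(b, period)
    return math.hypot(sa - sb, ca - cb)

# As integers, 23:00 and 00:00 are 23 apart; on the circle they are neighbours.
d_wrap = circle_distance(23, 0, 24)   # one hour apart across midnight
d_step = circle_distance(11, 12, 24)  # one hour apart at midday
d_half = circle_distance(0, 12, 24)   # opposite sides of the day
print(round(d_wrap, 4), round(d_step, 4), round(d_half, 4))
```

Because the day-of-month encoding divides by 31, the last day of a shorter month and the first day of the next one are close on the circle but not exactly one step apart.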


<h2>3. Splits, Global Robust Normalization, Lagged Samples</h2>

In [21]:
# 3) Chronological splits, global robust normalization, lagged samples
# -------------------------------------------------------------------
# Why this matters:
# • Splits follow the paper's protocol: a base window, a tiny base_test day, then 11 tasks.
# • GLOBAL_BOUNDS is computed once across all clients and the full timeline, so every client
#   and task shares one scaling.
# • create_lagged_samples builds (X,y) windows with N_LAGS history and PRED_LEN horizon.

# ---- Task schedule (paper protocol) ----
TASK_RANGES = [
    ("2014-05-05","2014-08-06"),("2014-08-07","2014-11-06"),("2014-11-07","2015-02-02"),
    ("2015-02-03","2015-05-04"),("2015-05-05","2015-08-06"),("2015-08-07","2015-11-06"),
    ("2015-11-07","2016-02-02"),("2016-02-03","2016-05-04"),("2016-05-05","2016-08-06"),
    ("2016-08-07","2016-11-06"),("2016-11-07","2017-02-02"),
]

def split_ranges(df: pd.DataFrame) -> Tuple[pd.DataFrame,pd.DataFrame,List[pd.DataFrame]]:
    """
    Returns base, base_test, and a list of task DataFrames (chronological).
    Guardrails:
    • Drops empty task windows (rare per station if raw dumps differ).
    • Keeps exact time bounds to match the paper.
    """
    base = df.loc["2013-05-01":"2014-05-03 23:00:00"]
    base_test = df.loc["2014-05-04":"2014-05-04 23:00:00"]
    tasks = [df.loc[s:e].copy() for (s,e) in TASK_RANGES if not df.loc[s:e].empty]
    return base, base_test, tasks

# Global robust [1,99] percentile bounds across all clients
def compute_global_bounds(clients: Dict[str, pd.DataFrame], cols: List[str]) -> Dict[str, Tuple[float,float]]:
    percs = {c: [] for c in cols}
    for _, df in clients.items():
        for c in cols:
            if c in df and pd.api.types.is_numeric_dtype(df[c]):
                arr = df[c].dropna().values
                if arr.size:
                    percs[c].append((np.percentile(arr,1), np.percentile(arr,99)))
    bounds = {c: (min(x for x,_ in percs[c]), max(y for _,y in percs[c])) for c in cols if percs[c]}
    return bounds

GLOBAL_BOUNDS = compute_global_bounds(client_dfs, FEATURES)

def normalize_df_globally(df: pd.DataFrame, bounds: Dict[str,Tuple[float,float]], ordered_cols: List[str]) -> pd.DataFrame:
    """
    Apply robust min–max using precomputed global bounds.
    • Column order is enforced by `ordered_cols` to keep model input consistent.
    • If a column has no bounds (e.g., missing in some clients), it's left unchanged.
    • Constant features (hi==lo) become 0.0.
    """
    # keep only available columns in the defined order
    cols = [c for c in ordered_cols if c in df.columns]
    out = df[cols].copy()
    for c in out.columns:
        if c in bounds and pd.api.types.is_numeric_dtype(out[c]):
            lo, hi = bounds[c]
            out[c] = 0.0 if hi == lo else (out[c]-lo)/(hi-lo)
            out[c] = out[c].clip(0,1)
    return out

def create_lagged_samples(df: pd.DataFrame, n_lags: int, pred_len: int, target_col: str = "WSPM",
                          include_target_in_input: bool = True) -> Tuple[np.ndarray,np.ndarray]:

    """
    Build overlapping windows:
      X: [n_samples, n_lags, n_features], Y: [n_samples, pred_len]
    • include_target_in_input=True matches our paper's setup (target also in X).
    • NaN windows are skipped defensively (should be rare post-cleaning).
    """
    
    cols = df.columns.tolist() if include_target_in_input else [c for c in df.columns if c != target_col]
    Xv = df[cols].astype(np.float64).values
    Yv = df[[target_col]].astype(np.float64).values
    X, Y = [], []
    for i in range(n_lags, len(df)-pred_len+1):
        xw = Xv[i-n_lags:i]
        yw = Yv[i:i+pred_len].flatten()
        if np.isnan(xw).any() or np.isnan(yw).any(): 
            continue
        X.append(xw); Y.append(yw)
    return np.asarray(X), np.asarray(Y)

# Build all lagged splits per client
base_lagged, task_lagged = {}, {"base_test": {}}

for client, df in client_dfs.items():
    base, base_test, tasks = split_ranges(df)
    base_n   = normalize_df_globally(base, GLOBAL_BOUNDS, FEATURES)
    baseT_n  = normalize_df_globally(base_test, GLOBAL_BOUNDS, FEATURES)
    Xb, yb   = create_lagged_samples(base_n, N_LAGS, PRED_LEN, TARGET_COL)
    Xbt, ybt = create_lagged_samples(baseT_n, N_LAGS, PRED_LEN, TARGET_COL)

    base_lagged[client] = {"X": Xb, "y": yb}
    task_lagged["base_test"][client] = {"X": Xbt, "y": ybt}

    for t_idx, tdf in enumerate(tasks, 1):
        tn = normalize_df_globally(tdf, GLOBAL_BOUNDS, FEATURES)
        Xt, yt = create_lagged_samples(tn, N_LAGS, PRED_LEN, TARGET_COL)
        task_lagged.setdefault(f"task_{t_idx}", {})[client] = {"X": Xt, "y": yt}


NUM_TASKS = len([k for k in task_lagged if k.startswith("task_")])
assert NUM_TASKS > 0, "No tasks found; check TASK_RANGES."
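
The loop in `create_lagged_samples` yields `len(df) - n_lags - pred_len + 1` samples when no window contains NaNs. The sketch below reproduces that indexing on a plain list to make the shapes concrete. `make_windows` is a hypothetical stand-in that handles a single univariate series only.

```python
from typing import List, Tuple

def make_windows(series: List[float], n_lags: int, pred_len: int
                 ) -> Tuple[List[List[float]], List[List[float]]]:
    """Pair each n_lags-long history with the pred_len values that follow it."""
    X, Y = [], []
    for i in range(n_lags, len(series) - pred_len + 1):
        X.append(series[i - n_lags:i])    # inputs: steps i-n_lags .. i-1
        Y.append(series[i:i + pred_len])  # targets: steps i .. i+pred_len-1
    return X, Y

series = [float(t) for t in range(24)]    # 24 hourly values: 0, 1, ..., 23
X, Y = make_windows(series, n_lags=12, pred_len=6)
print(len(X))        # number of windows
print(X[0], Y[0])    # first history and its targets
```

With the defaults (`N_LAGS=12`, `PRED_LEN=6`), a complete 24-hour `base_test` day therefore gives 24 - 12 - 6 + 1 = 7 windows per station.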

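
The scaling applied by `normalize_df_globally` can be illustrated on a few numbers. This is a minimal sketch with hand-picked bounds; `robust_minmax` is a hypothetical helper, and the bounds are assumed rather than computed from the dataset.

```python
from typing import List, Tuple

def robust_minmax(values: List[float], bounds: Tuple[float, float]) -> List[float]:
    """Scale with precomputed (lo, hi) bounds and clip the result to [0, 1]."""
    lo, hi = bounds
    if hi == lo:                       # constant feature
        return [0.0 for _ in values]
    return [min(1.0, max(0.0, (v - lo) / (hi - lo))) for v in values]

# Assumed bounds, standing in for the 1st and 99th percentile of one column
bounds = (0.0, 10.0)
scaled = robust_minmax([-1.0, 0.0, 2.5, 10.0, 14.0], bounds)
print(scaled)
```

Since the target column is normalized together with the inputs, target values outside the percentile bounds are clipped as well, so extreme observations saturate at 0 or 1.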

<h2>4. Dataset & DataLoaders</h2>

In [None]:
# 4) Dataset & DataLoaders
# ------------------------
# Tips for readers:
# • Determinism: shuffling draws from the global RNG seeded by set_all_seeds; pass a fixed
#   torch.Generator to DataLoader if you need shuffling isolated from other random calls.
# • Empty windows: if a (client, task) yields 0 samples after lagging, we still
#   register an *empty* DataLoader so downstream code won't KeyError; it will
#   simply have 0 batches.
# • You can increase num_workers once things run locally; we keep 0 in notebooks
#   to avoid multiprocessing issues on some platforms.

class TimeSeriesDataset(Dataset):
    def __init__(self, X: np.ndarray, y: np.ndarray):
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.float32)
    def __len__(self): return len(self.X)
    def __getitem__(self, i): return self.X[i], self.y[i]

client_dls = {"base_train": {}, "base_test": {}}
for i in range(NUM_TASKS):
    client_dls[f"task_{i+1}_train"] = {}
    client_dls[f"task_{i+1}_test"]  = {}

for c in base_lagged:
    Xb, yb = base_lagged[c]["X"], base_lagged[c]["y"]
    client_dls["base_train"][c] = DataLoader(TimeSeriesDataset(Xb,yb), batch_size=BATCH_SIZE, shuffle=True)

    Xbt, ybt = task_lagged["base_test"][c]["X"], task_lagged["base_test"][c]["y"]
    client_dls["base_test"][c] = DataLoader(TimeSeriesDataset(Xbt,ybt), batch_size=BATCH_SIZE, shuffle=False)

    for i in range(NUM_TASKS):
        key = f"task_{i+1}"
        Xt, yt = task_lagged[key][c]["X"], task_lagged[key][c]["y"]
        print(f"📎 {c} - {key}: total = {len(Xt)}")
        
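        # Chronological split: the first SPLIT_RATIO share of windows trains, the rest tests.
        split = int(len(Xt)*SPLIT_RATIO)
        # NOTE: assumed completion (the source cell is cut off after `split`). The loaders are
        # registered under the "<task>_train"/"<task>_test" keys created above; shuffling is
        # switched off for an empty split so the sampler never sees 0 samples.
        client_dls[f"{key}_train"][c] = DataLoader(TimeSeriesDataset(Xt[:split], yt[:split]),
                                                   batch_size=BATCH_SIZE, shuffle=split > 0)
        client_dls[f"{key}_test"][c]  = DataLoader(TimeSeriesDataset(Xt[split:], yt[split:]),
                                                   batch_size=BATCH_SIZE, shuffle=False)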

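
Each task's windows are divided in time order using `split = int(len(Xt)*SPLIT_RATIO)`. The sketch below shows that rule on stand-in data, including the empty case mentioned in the tips. `chrono_split` is a hypothetical helper, not a function from the notebook.

```python
from typing import List, Sequence, Tuple

def chrono_split(samples: Sequence, ratio: float) -> Tuple[List, List]:
    """Split time-ordered samples: the first `ratio` share trains, the rest tests."""
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must be in [0, 1]")
    cut = int(len(samples) * ratio)
    return list(samples[:cut]), list(samples[cut:])

windows = list(range(10))                 # stand-in for 10 time-ordered windows
train, test = chrono_split(windows, 0.8)
print(train, test)

empty_train, empty_test = chrono_split([], 0.8)  # a (client, task) with no windows
print(empty_train, empty_test)
```

Because consecutive windows overlap, the last training windows and the first test windows share some timestamps; leaving a gap of `N_LAGS + PRED_LEN - 1` windows at the cut would be one way to separate them if that matters for a given comparison.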


<h2>5. Model & Utilities (Base training or loading)</h2>

In [11]:
# 5) Model & Utilities (base model + Fisher + replay buffer)
# ----------------------------------------------------------
# Tips for readers:
# • LSTM head: we read the last hidden state only (common for seq2one / seq2many).
# • Keep INPUT_DIM/HIDDEN_DIM/OUTPUT_DIM small for quick runs; scale up for paper runs.
# • Fisher is estimated on a few batches (num_batches) for speed; increase for stability.
# • Replay buffer uses K-Means coresets over flattened windows; deterministic via random_state.

class LSTMPredictor(nn.Module):
    """
    Minimal LSTM forecaster:
      input:  [B, T, D]
      output: [B, PRED_LEN]
    We project only the last timestep's hidden state.
    """
    
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc   = nn.Linear(hidden_dim, output_dim)
    def forward(self, x):  # x: [B, T, D]
        out,_ = self.lstm(x)
        return self.fc(out[:, -1, :])
        
# Auto-detect model dims from prepared windows
INPUT_DIM = base_lagged[next(iter(base_lagged))]["X"].shape[2]
HIDDEN_DIM = 64
OUTPUT_DIM = PRED_LEN

class ReplayBuffer:
    """
    Simple (x,y) FIFO replay buffer.
    • Stores CPU tensors to keep GPU memory free.
    • add(): ignores additions if capacity==0 (i.e., replay disabled).
    • sample(): returns a batch; if buffer smaller than asked, returns all.
    """
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.buffer: List[Tuple[torch.Tensor,torch.Tensor]] = []
    def add(self, x, y):
        if self.capacity == 0: return
        if len(self.buffer) >= self.capacity: self.buffer.pop(0)
        self.buffer.append((x.detach().cpu(), y.detach().cpu()))
    def sample(self, batch_size: int):
        if not self.buffer: return (None,None)
        idx = np.random.choice(len(self.buffer), min(batch_size, len(self.buffer)), replace=False)
        X, y = zip(*[self.buffer[i] for i in idx])
        return torch.stack(X), torch.stack(y)
    def __len__(self): return len(self.buffer)

def compute_ewc_fisher(model: nn.Module, loader: DataLoader, device=DEVICE, num_batches=10):
    """
    Diagonal Fisher approximation via gradient^2 of MSE loss.
    Notes:
    • num_batches controls speed/variance trade-off; increase for more stable EWC.
    • Model stays in eval() since we only need gradients of the loss w.r.t. parameters.
    """
    
    model.eval().to(device)
    fis = {n: torch.zeros_like(p, device=device) for n,p in model.named_parameters()}
    count = 0
    for X,y in loader:
        X, y = X.to(device), y.to(device)
        model.zero_grad()
        loss = nn.functional.mse_loss(model(X), y)
        loss.backward()
        for n,p in model.named_parameters():
            if p.grad is not None:
                fis[n] += (p.grad.detach()**2)
        count += 1
        if count >= num_batches: break
    for n in fis: fis[n] /= max(count,1)
    return fis

def init_base_buffer(predictor: nn.Module, loader: DataLoader, capacity=BUFFER_CAPACITY, sample_fraction=0.05):
    """
    Build an initial replay coreset from base data via K-Means on flattened inputs.
    Steps:
      1) Collect all (X,y) from loader on CPU.
      2) Flatten X: [N, T*D] for clustering (fast & simple).
      3) Select closest sample to each centroid ⇒ diverse set.
    """
    predictor.eval()
    Xs, Ys = [], []
    with torch.no_grad():
        for X,y in loader:
            Xs.append(X.cpu()); Ys.append(y.cpu())
    Xall = torch.cat(Xs, 0); Yall = torch.cat(Ys, 0)
    embed = Xall.view(Xall.size(0), -1).numpy()
    n_sel = max(1, int(sample_fraction * len(Xall)))
    if n_sel >= len(embed):
        sel_idx = np.arange(len(embed))
    else:
        km = KMeans(n_clusters=n_sel, random_state=42).fit(embed)
        sel_idx = []
        for k in range(n_sel):
            ids = np.where(km.labels_ == k)[0]
            if ids.size == 0: continue
            D = np.linalg.norm(embed[ids] - km.cluster_centers_[k], axis=1)
            sel_idx.append(ids[np.argmin(D)])
    buf = ReplayBuffer(capacity)
    for i in sel_idx: buf.add(Xall[i], Yall[i])
    return buf
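
`ReplayBuffer.add` evicts the oldest entry with `list.pop(0)`, which shifts every remaining element. As an alternative sketch (not the notebook's implementation), the same FIFO behaviour is available from `collections.deque` with `maxlen`. `DequeReplayBuffer` is a hypothetical name, and the stored items here are plain tuples rather than tensors.

```python
import random
from collections import deque
from typing import Deque, List, Tuple

class DequeReplayBuffer:
    """FIFO buffer: once full, adding a sample silently drops the oldest one."""
    def __init__(self, capacity: int, seed: int = 42):
        self.capacity = capacity
        self.buffer: Deque[Tuple] = deque(maxlen=capacity)
        self.rng = random.Random(seed)      # private RNG for reproducible sampling
    def add(self, x, y) -> None:
        if self.capacity == 0:              # replay disabled
            return
        self.buffer.append((x, y))
    def sample(self, batch_size: int) -> List[Tuple]:
        k = min(batch_size, len(self.buffer))
        return self.rng.sample(list(self.buffer), k)  # without replacement
    def __len__(self) -> int:
        return len(self.buffer)

buf = DequeReplayBuffer(capacity=3)
for t in range(5):
    buf.add(t, t * 10)
print(list(buf.buffer))    # the two oldest samples (0 and 1) were evicted
print(len(buf.sample(8)))  # asking for more than stored returns everything
```

With `BUFFER_CAPACITY = 100000` eviction is rarely triggered, so this matters mainly when the capacity is set low.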


<h2>6. Base Model: Train (optional) or Load (default for speed)</h2>

In [13]:
# 6) Base Model: Train (optional) or Load
# ---------------------------------------
# Tips for readers:
# • NUM_ROUNDS_BASE / LR_BASE / LOCAL_EPOCHS_BASE control the offline FL pretraining budget.
# • We accumulate Synaptic Intelligence (SI) contributions per mini-batch:
#       W += (theta_{t+1} - theta_t) * (-grad_t)
#   and compute per-parameter omega:
#       omega = W / ( (theta_T - theta_star)^2 + xi )
#   with theta_star = params BEFORE local training on this client.
# • FedAvg averages client weights each round.

def train_federated_base(client_dls, input_dim, output_dim, hidden_dim=HIDDEN_DIM,
                         num_rounds=NUM_ROUNDS_BASE, local_epochs=LOCAL_EPOCHS_BASE, lr=LR_BASE, device=DEVICE):
    global_model = LSTMPredictor(input_dim, hidden_dim, output_dim).to(device)
    gW = global_model.state_dict()

    client_models = {}
    si_omegas, si_prev, si_W = {}, {}, {}

    for r in range(num_rounds):
        local_states = []
        for c, loader in client_dls["base_train"].items():
            m = LSTMPredictor(input_dim, hidden_dim, output_dim).to(device)
            m.load_state_dict(gW); m.train()
            opt = torch.optim.Adam(m.parameters(), lr=lr)
            prev = {n: p.clone().detach() for n,p in m.named_parameters()}
            
            # SI accumulator W initialized to zeros (same shapes)
            W    = {n: torch.zeros_like(p) for n,p in m.named_parameters()}
            for _ in range(local_epochs):
                for X,y in loader:
                    X,y = X.to(device), y.to(device)
                    opt.zero_grad(); loss = nn.functional.mse_loss(m(X), y)
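                    loss.backward()
                    # Snapshot theta_t before the step so delta is the per-step update
                    # (theta_{t+1} - theta_t), as in the SI formula above.
                    pre_step = {n: p.detach().clone() for n,p in m.named_parameters()}
                    opt.step()
                    for n,p in m.named_parameters():
                        if p.grad is not None:
                            delta = p.detach() - pre_step[n]
                            W[n] += delta * (-p.grad.detach())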
            local_states.append({k: v.detach().clone() for k,v in m.state_dict().items()})
            client_models[c] = m
            # SI omega
            omega = {}
            for n,p in m.named_parameters():
                delta = p.detach() - prev[n]
                omega[n] = W[n] / (delta.pow(2) + 1e-3)
            si_omegas[c] = omega; si_prev[c] = {n: p.clone().detach() for n,p in m.named_parameters()}
            si_W[c] = W

        # FedAvg
        newW = {k: sum(ls[k] for ls in local_states)/len(local_states) for k in local_states[0].keys()}
        gW = newW

    global_model.load_state_dict(gW)
    return global_model, client_models, si_omegas, si_prev, si_W

# --- Run base FL pretraining  ---
base_model, base_clients, si_omegas, si_prev, si_W = train_federated_base(
        client_dls, INPUT_DIM, OUTPUT_DIM, hidden_dim=HIDDEN_DIM, num_rounds=NUM_ROUNDS_BASE,
        local_epochs=LOCAL_EPOCHS_BASE, lr=LR_BASE, device=DEVICE
    )

# Aliases used later in the notebook (keeps naming consistent)
CLIENT_DATALOADERS = client_dls           # your dataloaders dict
BASE_MODEL        = base_model            # global base model
BASE_CLIENTS      = base_clients          # per-client base models
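
The FedAvg step in `train_federated_base` is an unweighted, element-wise mean of the client state dicts. The sketch below shows the same arithmetic on plain lists instead of tensors; `fed_avg` and the parameter names are illustrative only.

```python
from typing import Dict, List

State = Dict[str, List[float]]

def fed_avg(states: List[State]) -> State:
    """Element-wise unweighted mean of client parameter dicts."""
    if not states:
        raise ValueError("need at least one client state")
    n = len(states)
    return {
        name: [sum(vals) / n for vals in zip(*(s[name] for s in states))]
        for name in states[0]
    }

# Three clients, each holding the same parameter names and shapes
clients = [
    {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]},
    {"fc.weight": [3.0, 4.0], "fc.bias": [1.0]},
    {"fc.weight": [5.0, 6.0], "fc.bias": [2.0]},
]
global_state = fed_avg(clients)
print(global_state)
```

Every client counts equally here, as in the notebook. The original FedAvg formulation weights each client by its number of local samples, which would differ if stations had very different amounts of data.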

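
The SI quantities described in the tips (`W` and `omega`) can be traced by hand on a single scalar parameter. This is a toy sketch under simplifying assumptions: plain gradient descent instead of Adam, a quadratic loss, and a hypothetical helper `si_importance`.

```python
from typing import Callable, Tuple

def si_importance(grad_fn: Callable[[float], float], theta_start: float,
                  lr: float, steps: int, xi: float = 1e-3) -> Tuple[float, float, float]:
    """Run gradient descent on one scalar parameter and return (theta_T, W, omega)."""
    theta = theta_start
    W = 0.0
    for _ in range(steps):
        g = grad_fn(theta)               # gradient at theta_t
        new_theta = theta - lr * g       # plain SGD step (the notebook uses Adam)
        W += (new_theta - theta) * (-g)  # per-step term: delta * (-grad)
        theta = new_theta
    delta_total = theta - theta_start    # theta_T - theta_star
    omega = W / (delta_total ** 2 + xi)
    return theta, W, omega

# Toy loss L(theta) = (theta - 2)^2, so dL/dtheta = 2 * (theta - 2)
theta_T, W, omega = si_importance(lambda th: 2.0 * (th - 2.0),
                                  theta_start=0.0, lr=0.1, steps=20)
print(round(theta_T, 4), round(W, 4), round(omega, 4))
```

With plain gradient descent each term equals `lr * g**2`, so `W` is non-negative. With Adam the step direction can differ from the raw gradient, so individual terms may be negative.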

<h2>7. Initialize Replay Buffers & Fisher (once)</h2>

In [35]:
# 7) Initialize Replay Buffers & Fisher (once)
# -------------------------------------------
# Tips for readers:
# • Replay coreset is built from the BASE train split using K-Means over flattened windows.
# • At least one sample is kept per client; if the requested count reaches the client's
#   sample count, all of them are kept.
# • Fisher is a diagonal approximation (grad^2 of MSE) computed on a few batches
#   for speed. Increase the internal num_batches in `compute_ewc_fisher` for stability.

initial_buffers = {}
for c in client_dls["base_train"]:
    initial_buffers[c] = init_base_buffer(base_model, client_dls["base_train"][c],
                                          capacity=BUFFER_CAPACITY, sample_fraction=REPLAY_RATIO)

# Compute per-client Fisher wrt their base reference
fisher = {}
for c in client_dls["base_train"]:
    ref = base_clients.get(c, base_model)  # fall back to global if no per-client base
    fisher[c] = compute_ewc_fisher(ref, client_dls["base_train"][c], device=DEVICE)
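
Once a diagonal Fisher estimate exists, the EWC modes use it in two ways: as a per-parameter weight on squared drift from an anchor, and, for online EWC, as a moving average refreshed after each task. The sketch below shows both on plain lists. `ewc_penalty` and `ema_fisher` are hypothetical helpers, and the 0.9/0.1 blend mirrors the constants used in the continual-learning loop.

```python
from typing import Dict, List

Params = Dict[str, List[float]]

def ewc_penalty(params: Params, anchor: Params, fisher: Params) -> float:
    """Sum over parameters of F * (theta - theta_anchor)^2."""
    total = 0.0
    for name, values in params.items():
        if name not in fisher:          # parameters without a Fisher entry are skipped
            continue
        for p, p0, f in zip(values, anchor[name], fisher[name]):
            total += f * (p - p0) ** 2
    return total

def ema_fisher(old: Params, new: Params, decay: float = 0.9) -> Params:
    """Blend the previous Fisher with a fresh estimate: decay*old + (1-decay)*new."""
    return {name: [decay * o + (1.0 - decay) * n for o, n in zip(old[name], new[name])]
            for name in new}

anchor = {"w": [1.0, -1.0]}
params = {"w": [1.5, -1.0]}
fisher = {"w": [4.0, 0.5]}
penalty = ewc_penalty(params, anchor, fisher)    # 4.0 * 0.5**2 + 0.5 * 0.0
updated = ema_fisher(fisher, {"w": [0.0, 2.5]})  # per-entry moving average
print(penalty, updated)
```

In training the penalty is multiplied by `EWC_COEFF` or `ONLINE_EWC_COEFF` before being added to the MSE loss, so the large default coefficients compensate for small Fisher values.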


<h2>8. Continual Learning</h2>

In [29]:
def continual_learning_hybrid(
    client_dataloaders,
    client_buffers,
    base_predictor,
    base_clients,           
    num_tasks,
    num_rounds_dict,
    local_epochs,
    lr,
    device,
    mode,
    replay_ratio,
    distil_coef,
    EwcCoeff,
    online_ewc_coeff,
    replay_coeff,
    fisher_matrices=None,
    si_coeff=None,               # SI λ
    si_omegas=None,              # per-client per-param ω
    si_contributions=None,       # per-client per-param W accumulator
    si_prev_params=None          # per-client θ* (previous task snapshot)
):

    # TIP: This routine performs task-by-task federated updates with optional CF mitigations.
    # It returns a list of immutable model snapshots, one after each task (for evaluation matrix P).

    import torch, copy, numpy as np
    import torch.nn as nn
    from sklearn.cluster import KMeans
    import torch._dynamo
    torch._dynamo.disable()

    # Reconstruct dims from base model
    in_dim  = base_predictor.lstm.input_size
    hid_dim = base_predictor.lstm.hidden_size
    out_dim = base_predictor.fc.out_features

    global_predictor = copy.deepcopy(base_predictor).to(device)
    global_weights = global_predictor.state_dict()

    clients = list(client_dataloaders["base_train"].keys())
    replay_buffers = {c: copy.deepcopy(client_buffers[c]) for c in clients}
    predictors = {c: copy.deepcopy(global_predictor) for c in clients}
    checkpoints = []

    # Online EWC reference weights
    ewc_reference_weights = {c: copy.deepcopy(base_predictor) for c in clients}
    # HINT: For "online_ewc", these references are updated after each task using EMA Fisher.

    for task_id in range(1, num_tasks + 1):
        print(f"\n🔁 Task {task_id}   mode = {mode}")
        task_train_key = f"task_{task_id}_train"

        # --- SI: initialize per-task accumulators ---
        if mode in ["si"]:
            assert si_contributions is not None and si_omegas is not None and si_prev_params is not None, \
                "Provide SI structures for SI modes."
            for c in clients:
                si_contributions.setdefault(c, {})
                # zero tensors with correct shapes
                for name, p in global_predictor.named_parameters():
                    si_contributions[c][name] = torch.zeros_like(p, device=device)
            # TIP: We re-zero per-task accumulators here; omegas and θ* persist across tasks.

        assert isinstance(num_rounds_dict, dict) and mode in num_rounds_dict
        num_rounds = num_rounds_dict[mode]
        print(f"num_rounds = {num_rounds}")

        for r in range(num_rounds):
            print(f"🌐 Communication Round {r+1}/{num_rounds}")
            local_weights = []

            for c in clients:
                predictor = copy.deepcopy(global_predictor).to(device)
                old_predictor = copy.deepcopy(predictors[c]).to(device)
                loader = client_dataloaders[task_train_key][c]
                buffer = replay_buffers[c]

                optimizer = torch.optim.Adam(predictor.parameters(), lr=lr)
                predictor.train()

                for _ in range(local_epochs):
                    for X, y in loader:
                        X, y = X.to(device), y.to(device)

                        # ---- forward & base loss ----
                        y_pred = predictor(X)
                        reg_loss = nn.functional.mse_loss(y_pred, y)

                        # ---- optional losses ----
                        replay_loss = 0.0
                        kd_loss = 0.0
                        ewc_loss = 0.0
                        ewc_loss_online = 0.0
                        si_loss = 0.0

                        if mode in ["replay"]:
                            rX, rY = buffer.sample(X.shape[0])
                            if rX is not None:
                                rX, rY = rX.to(device), rY.to(device)
                                rY_pred = predictor(rX)
                                replay_loss = nn.functional.mse_loss(rY_pred, rY)

                        if mode in ["kd"]:
                            # KD teacher = previous local predictor for this client (stability signal)
                            with torch.no_grad():
                                teacher_output = old_predictor(X).detach()
                            kd_loss = nn.functional.mse_loss(y_pred, teacher_output)

                        if fisher_matrices is not None:
                           
                            for name, p in predictor.named_parameters():
                                F = fisher_matrices.get(c, {}).get(name, None)
                                if F is None: 
                                    continue
                                if mode in ["ewc"]:
                                    # Classic EWC: anchor to the client's base model (or global fallback)
                                    ref_sd = base_clients[c].state_dict() if c in base_clients else base_predictor.state_dict()
                                    p0 = ref_sd[name].to(device)
                                    ewc_loss += (F * (p - p0).pow(2)).sum()
                                if mode in ["online_ewc"]:
                                    # Online EWC: anchor to moving reference after each task
                                    pref = ewc_reference_weights[c].state_dict()[name].to(device)
                                    ewc_loss_online += (F * (p - pref).pow(2)).sum()

                        if mode in ["si"]:
                            for name, p in predictor.named_parameters():
                                if name in si_omegas[c]:
                                    omega = si_omegas[c][name]
                                    theta_star = si_prev_params[c][name]
                                    si_loss += (omega * (p - theta_star).pow(2)).sum()

                        # ---- backprop ----
                        loss = (reg_loss
                                + replay_coeff * replay_loss
                                + distil_coef * kd_loss
                                + EwcCoeff * ewc_loss
                                + online_ewc_coeff * ewc_loss_online
                                + (si_coeff or 0.0) * si_loss)

                        optimizer.zero_grad()
                        loss.backward()

                        # ====== SI contribution update (CORRECT) ======
                        # snapshot θ(t) BEFORE the optimizer step
                        pre_step = {n: p.detach().clone() for n, p in predictor.named_parameters()}
                        optimizer.step()
                        # now θ(t+1) is in-place; accumulate W += Δθ * (−g)
                        if mode in ["si"]:
                            for name, p in predictor.named_parameters():
                                g = p.grad  # grad from current loss (still present)
                                if g is not None:
                                    delta = p.detach() - pre_step[name]
                                    si_contributions[c][name] += delta * (-g.detach())
                        # ==============================================

                local_weights.append(copy.deepcopy(predictor.state_dict()))

            # FedAvg
            newW = copy.deepcopy(local_weights[0])
            for k in newW:
                for i in range(1, len(local_weights)):
                    newW[k] += local_weights[i][k]
                newW[k] /= len(local_weights)

            global_predictor.load_state_dict(newW)
            for c in clients:
                predictors[c] = copy.deepcopy(global_predictor)

        # ---- Immutable checkpoint after finishing this task ----
        snap = LSTMPredictor(in_dim, hid_dim, out_dim).to(device)
        snap.load_state_dict({k: v.detach().clone() for k, v in global_predictor.state_dict().items()}, strict=True)
        for p in snap.parameters(): p.requires_grad_(False)
        checkpoints.append(snap)

        # ---- Online EWC updates ----
        if mode in ["online_ewc"] and fisher_matrices is not None:
            for c in clients:
                new_loader = client_dataloaders[task_train_key][c]
                new_fisher = compute_ewc_fisher(predictors[c], new_loader, device=device)
                for name, val in new_fisher.items():
                    old_val = fisher_matrices[c].get(name, 0.0)
                    fisher_matrices[c][name] = 0.9 * old_val + 0.1 * val
                ewc_reference_weights[c] = copy.deepcopy(predictors[c])

        # ---- SI omega update & θ* snapshot ----
        if mode in ["si"]:
            xi = 1e-3
            for c in clients:
                final_params = predictors[c]
                for name, p in final_params.named_parameters():
                    delta_total = p.detach() - si_prev_params[c][name]  # θ(T) - θ*
                    W = si_contributions[c][name]
                    si_omegas[c][name] += W / (delta_total.pow(2) + xi)
                    si_prev_params[c][name] = p.detach().clone()
                # reset accumulators
                for name in si_contributions[c]:
                    si_contributions[c][name].zero_()

        # ---- K-means replay update ----
        for c in clients:
            if getattr(replay_buffers[c], "capacity", 0) <= 0:
                continue
            loader = client_dataloaders[task_train_key][c]
            all_X, all_y = [], []
            global_predictor.eval()
            with torch.no_grad():
                for X, y in loader:
                    all_X.append(X.cpu()); all_y.append(y.cpu())
            if not all_X:
                continue
            all_X = torch.cat(all_X, 0); all_y = torch.cat(all_y, 0)
            # K-means runs on the flattened raw input windows; no model embedding is used here
            X_embed = all_X.view(all_X.size(0), -1).numpy()
            n_total = len(all_X); n_sel = int(replay_ratio * n_total)
            if n_sel <= 0: 
                continue
            if n_sel >= len(X_embed):
                sel = np.arange(n_total)
            else:
                km = KMeans(n_clusters=n_sel, random_state=42).fit(X_embed)
                centers, labels = km.cluster_centers_, km.labels_
                sel = []
                for k in range(n_sel):
                    ids = np.where(labels == k)[0]
                    if ids.size == 0: continue
                    D = np.linalg.norm(X_embed[ids] - centers[k], axis=1)
                    sel.append(ids[np.argmin(D)])
            for i in sel:
                x_i, y_i = all_X[i], all_y[i]
                if x_i.ndim == 2:
                    replay_buffers[c].add(x_i, y_i)
                elif x_i.ndim == 3:
                    for j in range(x_i.shape[0]):
                        replay_buffers[c].add(x_i[j], y_i[j])

    return checkpoints
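
The SI branch above accumulates a per-parameter contribution `W += Δθ * (−g)` at every optimizer step and, at the end of a task, converts it to an importance `omega = W / (Δ_total² + xi)`. The following is a minimal NumPy sketch of that bookkeeping on a toy quadratic loss, not the notebook's implementation: the function names `si_importance_demo` and `si_penalty`, the quadratic loss, and the plain gradient-descent update are all assumptions made for illustration.

```python
import numpy as np

def si_importance_demo(curvature, target, lr=0.1, steps=50, xi=1e-3):
    """Toy SI bookkeeping on L(theta) = 0.5 * sum(curvature * (theta - target)**2).

    Hypothetical helper: plain gradient descent stands in for the optimizer.
    """
    theta = np.zeros_like(target)
    theta_start = theta.copy()            # plays the role of si_prev_params
    W = np.zeros_like(target)             # plays the role of si_contributions
    for _ in range(steps):
        g = curvature * (theta - target)  # gradient at theta(t), before the step
        new_theta = theta - lr * g        # optimizer step
        W += (new_theta - theta) * (-g)   # W += delta_theta * (-g)
        theta = new_theta
    delta_total = theta - theta_start     # theta(T) - theta*
    omega = W / (delta_total ** 2 + xi)   # per-parameter importance
    return W, omega, theta

def si_penalty(theta, theta_star, omega):
    """Quadratic SI regularizer: sum(omega * (theta - theta_star)**2)."""
    return float(np.sum(omega * (theta - theta_star) ** 2))

W, omega, theta_star = si_importance_demo(
    curvature=np.array([4.0, 0.25]),      # first parameter matters more to the loss
    target=np.array([1.0, 1.0]),
)
print("W =", W, "omega =", omega)
print("penalty after a 0.1 shift:", si_penalty(theta_star + 0.1, theta_star, omega))
```

In this toy setting the high-curvature parameter ends up with the larger `omega`, so moving it away from `theta_star` in a later task is penalized more strongly. With plain gradient descent each contribution equals `lr * g**2` and is non-negative; with other optimizers or noisy gradients `W` can go negative, which the sketch does not handle.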


<h2> 9. Run Ablation & Evaluate </h2>

In [None]:
def run_combined_ablation(
    MODES,
    REPLAY_RATIO,
    KD_COEFF,
    EWC_COEFF,
    ONLINE_EWC_COEFF,
    REPLAY_COEFF,
    FISHER_MATRICES,
    SI_OMEGAS,
    SI_CONTRIBUTIONS,
    SI_PREV_PARAMS,
    SI_COEFF,
    CLIENT_DATALOADERS,
    BASE_MODEL,
    BASE_CLIENTS,
    NUM_TASKS,
    NUM_ROUNDS_CL,
    LOCAL_EPOCHS,
    LR,
    DEVICE
):
    import copy, time
    import numpy as np
    from sklearn.metrics import mean_squared_error

    checkpoints_per_mode = {}

    cpu_times = {}

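    # Deep-copy so later in-place changes to BASE_MODEL cannot leak into subsequent modes
    saved_base_weights = copy.deepcopy(BASE_MODEL.state_dict())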

    for mode in MODES:
        print(f"\n======= Running Mode: {mode} =======")

        # SI payloads per mode
        if mode in ["si", "kd+si", "si+replay", "si+ewc"]:
            si_omegas_mode = copy.deepcopy(SI_OMEGAS)
            si_prev_params_mode = copy.deepcopy(SI_PREV_PARAMS)
            si_contributions_mode = {}
            for c in CLIENT_DATALOADERS["base_train"]:
                ref = BASE_CLIENTS.get(c, BASE_MODEL)   # <-- fallback ensures shapes exist
                si_contributions_mode[c] = {name: torch.zeros_like(p) for name, p in ref.named_parameters()}
        else:
            si_omegas_mode = si_prev_params_mode = si_contributions_mode = None


        # Replay buffers per mode, copied from the notebook-level `initial_buffers`
        # (not a function argument, so it must be defined before this function is called)
        mode_replay_buffers = {}
        for c in CLIENT_DATALOADERS["base_train"]:
            if mode in ["replay", "ewc+replay", "kd+replay", "si+replay"]:
                original = initial_buffers[c]
                buf = ReplayBuffer(capacity=original.capacity)
                for x, y in original.buffer:
                    if isinstance(x, torch.Tensor) and x.ndim == 2:
                        buf.add(x, y)
            else:
                buf = ReplayBuffer(capacity=0)
            mode_replay_buffers[c] = buf

        # Reinit base for this mode
        base_copy = LSTMPredictor(
            input_dim=BASE_MODEL.lstm.input_size,
            hidden_dim=BASE_MODEL.lstm.hidden_size,
            output_dim=BASE_MODEL.fc.out_features
        ).to(DEVICE)
        base_copy.load_state_dict(saved_base_weights)

        rounds_config = {
            "Naive": NUM_ROUNDS_CL,
            "replay": NUM_ROUNDS_CL,
            "kd": NUM_ROUNDS_CL,
            "online_ewc": NUM_ROUNDS_CL,
            "ewc": NUM_ROUNDS_CL,
            # keep combos for future use
            "si": NUM_ROUNDS_CL, "ewc+replay": NUM_ROUNDS_CL, "kd+si": NUM_ROUNDS_CL,
            "si+ewc": NUM_ROUNDS_CL, "kd+replay": NUM_ROUNDS_CL, "ewc+kd": NUM_ROUNDS_CL
        }

        kd_for_mode = KD_COEFF if mode in ["kd"] else 0.0

        t0 = time.time()
        checkpoints = continual_learning_hybrid(
            client_dataloaders=CLIENT_DATALOADERS,
            client_buffers=mode_replay_buffers,
            base_predictor=base_copy,
            base_clients=BASE_CLIENTS,
            num_tasks=NUM_TASKS,
            num_rounds_dict=rounds_config,
            local_epochs=LOCAL_EPOCHS,
            lr=LR,
            device=DEVICE,
            mode=mode,
            replay_ratio=REPLAY_RATIO,
            distil_coef=kd_for_mode,
            EwcCoeff=EWC_COEFF,
            online_ewc_coeff=ONLINE_EWC_COEFF,
            replay_coeff=REPLAY_COEFF,
            fisher_matrices=FISHER_MATRICES if mode in ["ewc", "online_ewc"] else None,
            si_coeff=SI_COEFF,
            si_omegas=si_omegas_mode,
            si_contributions=si_contributions_mode,
            si_prev_params=si_prev_params_mode
        )
        cpu_times[mode] = time.time() - t0
        checkpoints_per_mode[mode] = checkpoints



    return checkpoints_per_mode, cpu_times




def evaluate_rmse_matrix_and_metrics(ablation_checkpoints, client_dataloaders, cpu_times, device=DEVICE):
    import numpy as np
    from sklearn.metrics import mean_squared_error, mean_absolute_error

    final_metrics = {}

    for mode, checkpoints in ablation_checkpoints.items():
        N = len(checkpoints)
        # Legacy: zero-init matrices (unseen cells stay 0 instead of NaN)
        rmse_matrix = np.zeros((N, N), dtype=float)
        mae_matrix  = np.zeros((N, N), dtype=float)

        # Fill lower triangle
        for i, predictor in enumerate(checkpoints):
            predictor.eval()
            for j in range(i + 1):
                task_key = f"task_{j+1}_test"
                preds, trues = [], []
                for _, loader in client_dataloaders[task_key].items():
                    with torch.no_grad():
                        for X, y in loader:
                            X, y = X.to(device), y.to(device)
                            preds.append(predictor(X).cpu().numpy().ravel())
                            trues.append(y.cpu().numpy().ravel())
                if preds and trues:
                    P = np.concatenate(preds); Y = np.concatenate(trues)
                    # Guard just in case
                    if P.size and Y.size and np.isfinite(P).all() and np.isfinite(Y).all():
                        rmse_matrix[i, j] = np.sqrt(mean_squared_error(Y, P))
                        mae_matrix[i, j]  = mean_absolute_error(Y, P)
                # else: leave zeros (legacy behavior)

        # === Legacy metrics (same definitions as the original notebook) ===
        # AvgPerf over tasks 1..N-1
        last_row_rmse = rmse_matrix[-1, :-1] if N > 1 else rmse_matrix[-1:, -1:]
        avg_perf = float(np.mean(last_row_rmse))

        # AF over first N-1 tasks: P_{N,j} - P_{j,j}
        diag_rmse = np.diag(rmse_matrix)
        if N > 1:
            af = float(np.mean(rmse_matrix[-1, :-1] - diag_rmse[:-1]))
        else:
            af = 0.0

        # AP (average plasticity): mean of the diagonal, i.e. RMSE on each task right after learning it
        ap = float(np.mean(diag_rmse))

        final_metrics[mode] = {
            "rmse_matrix": rmse_matrix,
            "mae_matrix": mae_matrix,
            "avg_perf": avg_perf,
            "avg_forget_best": float('nan'),     # placeholder; not reported in the summary table
            "avg_forget_taskwise": af,
            "avg_plasticity": ap,
            "last_model_rmse": rmse_matrix[-1, :],
            "last_model_mae":  mae_matrix[-1, :],
            "forgetting_vector_best": np.array([]),
            "forgetting_vector_taskwise": (rmse_matrix[-1, :-1] - diag_rmse[:-1]) if N > 1 else np.array([]),
            "cpu_time": cpu_times[mode],
        }

        print(f"🔢 [{mode}] AvgPerf={avg_perf:.6f} | AF={af:.6f} | AP={ap:.6f}")

    return final_metrics





# ===== Run Ablation Study on different modes =====
set_all_seeds(42)
MODES = ["Naive","replay", "kd", "online_ewc", "ewc", "si"]

ablation_checkpoints, cpu_times = run_combined_ablation(
    MODES=MODES,
    REPLAY_RATIO=REPLAY_RATIO,
    KD_COEFF=KD_COEFF,
    EWC_COEFF=EWC_COEFF,
    ONLINE_EWC_COEFF=ONLINE_EWC_COEFF,
    REPLAY_COEFF=REPLAY_COEFF,
    FISHER_MATRICES=fisher,
    SI_OMEGAS=si_omegas,
    SI_CONTRIBUTIONS=si_W,
    SI_PREV_PARAMS=si_prev,
    SI_COEFF=SI_COEFF,
    CLIENT_DATALOADERS=CLIENT_DATALOADERS,
    BASE_MODEL=BASE_MODEL,
    BASE_CLIENTS=BASE_CLIENTS,
    NUM_TASKS=NUM_TASKS,
    NUM_ROUNDS_CL=NUM_ROUNDS_CL,
    LOCAL_EPOCHS=LOCAL_EPOCHS,
    LR=LR,
    DEVICE=DEVICE
)

# Evaluate every mode with the legacy metric definitions
final_metrics = evaluate_rmse_matrix_and_metrics(
    ablation_checkpoints=ablation_checkpoints,
    client_dataloaders=CLIENT_DATALOADERS,
    cpu_times=cpu_times,
    device=DEVICE
)

summary_df = pd.DataFrame({
    mode: {
        "AvgForgetting": metrics["avg_forget_taskwise"] * 1000,  # reported ×1000
        "AvgPlasticity": metrics["avg_plasticity"],
        "AvgPerformance": metrics["avg_perf"],
        "CPUTime(s)": metrics["cpu_time"]
    }
    for mode, metrics in final_metrics.items()
}).T

print("\n📊 Final Results Table:")
print(summary_df.round(7).to_string())
