 Setup & Imports

In [None]:
# General utilities
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import random

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch (for model building and training)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

# Sklearn (for preprocessing and evaluation)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    classification_report, confusion_matrix,
    mean_absolute_error, mean_squared_error, r2_score
)

# Seed every random number generator the notebook relies on (splits, weight init).
def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


set_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


##  Dataset Preparation

In [None]:
# Static catchment descriptors used as model inputs.
# The shared encoder has a fixed 4-column input layer, so every task should build
# its feature matrix from exactly these columns, in exactly this order.
common_features = [
    "Catchment",
    "Mean_annual_precip",
    "Mean_annual_temp",
    "lc_urban",
]

 Flood Year Classification Dataset Preparation

In [None]:
# Load full merged dataset (we used this in previous notebooks)
df = pd.read_csv("cleaned_ADHI.csv")

# Drop rows missing any model input, a target ingredient, or the country key
df = df.dropna(subset=common_features + ["Maxi_q", "q95th", "Country"])

# Binary target: 1 when Maxi_q exceeds q95th, else 0
# (presumably a peak discharge vs. a 95th-percentile threshold — confirm against the data dictionary)
df["flood_year"] = (df["Maxi_q"] > df["q95th"]).astype(int)

# All tasks share one encoder built with input_dim=4, so this task must use the
# same 4 columns in the same order. A 3-column list here makes the encoder's
# first Linear layer fail with a shape mismatch during training.
flood_year_features = list(common_features)

# Standardize the shared inputs in place.
# NOTE(review): the scaler is fit on all rows before the per-country train/test
# split, so test rows influence the normalization statistics (mild leakage).
scaler = StandardScaler()
df[common_features] = scaler.fit_transform(df[common_features])

# Per-country (X_train, y_train, X_test, y_test) tuples; each country is one federated client.
# groupby(sort=False) visits countries in first-appearance order, like df["Country"].unique().
flood_year_data = {}
for country, country_df in df.groupby("Country", sort=False):
    X = country_df[flood_year_features].values
    y = country_df["flood_year"].values

    if len(y) >= 30:  # Keep countries with enough data for an 80/20 split
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        flood_year_data[country] = (X_train, y_train, X_test, y_test)


 Flood Magnitude Regression Dataset Preparation

In [None]:
# Drop rows with missing inputs, target, or country.
# .copy() makes df_magnitude an explicit, independent frame so the column writes
# below don't trigger a SettingWithCopyWarning.
df_magnitude = df.dropna(subset=common_features + ["Maxi_q", "Country"]).copy()

# Log-transform the target; log1p keeps zero discharge finite (log1p(0) == 0)
df_magnitude["log_max_q"] = np.log1p(df_magnitude["Maxi_q"])

# Same 4 inputs, in the same order, as the other tasks: the shared encoder's input
# layer is fixed at 4 columns. A column outside common_features (e.g. "lc_forest")
# would be neither NaN-filtered nor standardized by this cell.
magnitude_features = list(common_features)

# Standardize inputs. If `df` was already standardized in place by the flood-year
# cell, this re-fit is close to an identity transform.
# NOTE(review): the scaler is fit before the per-country train/test split (mild leakage).
scaler_mag = StandardScaler()
df_magnitude[common_features] = scaler_mag.fit_transform(df_magnitude[common_features])

# Per-country (X_train, y_train, X_test, y_test) tuples, countries in first-appearance order
flood_magnitude_data = {}
for country, country_df in df_magnitude.groupby("Country", sort=False):
    X = country_df[magnitude_features].values
    y = country_df["log_max_q"].values

    if len(y) >= 30:  # Keep countries with enough data for an 80/20 split
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        flood_magnitude_data[country] = (X_train, y_train, X_test, y_test)


Flood Seasonality Classification Dataset

In [None]:
import zipfile
import os

# Archive location and extraction target (Colab /content paths)
zip_path = "/content/ADHI_MonthlySeries.zip"
extract_dir = "/content/ADHI/ADHI_MonthlySeries"

# Create the destination if needed (no error when it already exists), then unpack everything
os.makedirs(extract_dir, exist_ok=True)
with zipfile.ZipFile(zip_path, mode="r") as archive:
    archive.extractall(path=extract_dir)

print("Unzipped to:", extract_dir)

Unzipped to: /content/ADHI/ADHI_MonthlySeries


In [None]:
import os
import pandas as pd
import glob

# Directory holding one monthly-series .txt file per station
monthly_dir = "/content/ADHI/ADHI_MonthlySeries/MonthlySeries"

# glob returns files in filesystem order, which can differ between machines and runs.
# Sort so the concatenation order (and everything derived from it) is reproducible.
monthly_files = sorted(glob.glob(os.path.join(monthly_dir, "*.txt")))

print(f"Found {len(monthly_files)} monthly files.")

Found 1466 monthly files.


In [None]:
# Load every station's monthly series and stack them into one DataFrame.
# The station files have no header row; these are their columns, in order.
MONTHLY_COLUMNS = ["Year", "Month", "Mean_Q", "Max_Q", "Min_Q", "Missing_days"]

monthly_data = []
for file in monthly_files:
    # Station number = last "_"-separated token of the file name, minus ".txt"
    station_id = os.path.basename(file).split("_")[-1].replace(".txt", "")

    # Use a dedicated name: reusing `df` here would overwrite the ADHI summary
    # table loaded earlier in the notebook.
    station_df = pd.read_csv(file, header=None, names=MONTHLY_COLUMNS)
    station_df["Station_ID"] = f"ADHI_{station_id}"
    monthly_data.append(station_df)

# Concatenate once (not inside the loop) into a single DataFrame
monthly_df = pd.concat(monthly_data, ignore_index=True)
print("Combined monthly data shape:", monthly_df.shape)
monthly_df.head()

Combined monthly data shape: (1213848, 7)


Unnamed: 0,Year,Month,Mean_Q,Max_Q,Min_Q,Missing_days,Station_ID
0,1950,1,,,,31,ADHI_1223
1,1950,2,,,,28,ADHI_1223
2,1950,3,,,,31,ADHI_1223
3,1950,4,,,,30,ADHI_1223
4,1950,5,,,,31,ADHI_1223


In [None]:
# Discard months where Mean_Q, Max_Q and Min_Q are all missing
monthly_df = monthly_df.dropna(how="all", subset=["Mean_Q", "Max_Q", "Min_Q"])

# Build a first-of-month timestamp from the Year and Month columns
date_parts = {"year": monthly_df["Year"], "month": monthly_df["Month"], "day": 1}
monthly_df["Date"] = pd.to_datetime(date_parts)

# Order chronologically within each station, then renumber the index from 0
monthly_df = monthly_df.sort_values(by=["Station_ID", "Date"])
monthly_df = monthly_df.reset_index(drop=True)

print("Cleaned data shape:", monthly_df.shape)
monthly_df.head()


Cleaned data shape: (510762, 8)


Unnamed: 0,Year,Month,Mean_Q,Max_Q,Min_Q,Missing_days,Station_ID,Date
0,1963,5,10.31,16.29,5.46,7,ADHI_1,1963-05-01
1,1963,6,18.98,33.29,7.46,0,ADHI_1,1963-06-01
2,1963,7,20.55,27.6,14.69,0,ADHI_1,1963-07-01
3,1963,8,56.82,101.0,22.0,0,ADHI_1,1963-08-01
4,1963,9,57.0,96.3,39.9,0,ADHI_1,1963-09-01


In [None]:
stations_df = pd.read_csv("ADHI_stations.csv")

# Station -> country lookup; the ID column is renamed to ADHI_ID to match the ADHI summary table
station_country_df = (
    stations_df.loc[:, ["ID", "Country"]]
    .rename(columns={"ID": "ADHI_ID"})
)

# Persist the lookup for reuse outside this notebook
station_country_df.to_csv("stations_by_country.csv", index=False)

print("Saved stations_by_country.csv with the following columns:")
display(station_country_df.head())

Saved stations_by_country.csv with the following columns:


Unnamed: 0,ADHI_ID,Country
0,ADHI_1,Cameroon
1,ADHI_2,Cameroon
2,ADHI_3,Cameroon
3,ADHI_4,Cameroon
4,ADHI_5,Cameroon


In [None]:
# Rename for consistency
station_country_df = station_country_df.rename(columns={"ADHI_ID": "Station_ID"})

# Merge into monthly_df
monthly_df = monthly_df.merge(station_country_df, on="Station_ID", how="left")

# Preview to verify merge
monthly_df.drop_duplicates().head()

Unnamed: 0,Year,Month,Mean_Q,Max_Q,Min_Q,Missing_days,Station_ID,Date,Country
0,1963,5,10.31,16.29,5.46,7,ADHI_1,1963-05-01,Cameroon
1,1963,6,18.98,33.29,7.46,0,ADHI_1,1963-06-01,Cameroon
2,1963,7,20.55,27.6,14.69,0,ADHI_1,1963-07-01,Cameroon
3,1963,8,56.82,101.0,22.0,0,ADHI_1,1963-08-01,Cameroon
4,1963,9,57.0,96.3,39.9,0,ADHI_1,1963-09-01,Cameroon


In [None]:
# Normalize the country column name (a repeated merge can leave "Country_x")
# and add a monthly Period key, YearMonth, derived from Date.
monthly_df = monthly_df.rename(columns={"Country_x": "Country"}).assign(
    YearMonth=lambda frame: frame["Date"].dt.to_period("M")
)


In [None]:
# Find the peak month each year per station.
# Year is re-derived from Date so it always agrees with the timestamp.
monthly_df["Year"] = monthly_df["Date"].dt.year

# idxmax on a station-year whose Max_Q values are all NaN returns NaN or raises
# (depending on the pandas version), which then breaks .loc. Drop months without
# a Max_Q first so every remaining group has a valid peak.
months_with_max = monthly_df.dropna(subset=["Max_Q"])
peak_idx = months_with_max.groupby(["Station_ID", "Year"])["Max_Q"].idxmax()

# One row per station-year: the month holding that year's highest Max_Q
peak_months = (
    months_with_max.loc[peak_idx, ["Station_ID", "Year", "Month", "Max_Q"]]
    .rename(columns={"Month": "Peak_Month"})
)

# preview
peak_months.head()


Unnamed: 0,Station_ID,Year,Peak_Month,Max_Q
3,ADHI_1,1963,8,101.0
16,ADHI_1,1964,9,126.0
28,ADHI_1,1965,9,218.0
40,ADHI_1,1966,9,147.0
51,ADHI_1,1967,8,70.09


In [None]:
# # Merge peak months with static station features
# peak_features_df = pd.merge(peak_months, stations_df, left_on="Station_ID", right_on="ID", how="inner")
# peak_features_df.info()

In [None]:
from collections import Counter
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# --- Map peak months into seasons ---
def month_to_season(month):
    """Map a calendar month number to this notebook's flood-season label.

    Dec-Feb -> "Dry", Mar-May -> "PreFlood", Jun-Sep -> "Flood",
    Oct-Nov -> "PostFlood". Any other value (e.g. 13 or NaN) maps to np.nan,
    so invalid months can later be removed with dropna().
    """
    season_months = (
        ("Dry", (12, 1, 2)),
        ("PreFlood", (3, 4, 5)),
        ("Flood", (6, 7, 8, 9)),
        ("PostFlood", (10, 11)),
    )
    for season, months in season_months:
        if month in months:
            return season
    return np.nan  # not a valid month number

# Apply mapping: peak month -> season name (NaN for values outside 1-12)
peak_months["Peak_Season"] = peak_months["Peak_Month"].apply(month_to_season)

# Attach static station attributes by station ID (inner join drops unmatched stations)
season_df = pd.merge(
    peak_months, stations_df, left_on="Station_ID", right_on="ID", how="inner"
)

# Drop rows missing any model input, the season, or the country
season_df = season_df.dropna(subset=common_features + ["Peak_Season", "Country"])

# Encode season labels as class indices for CrossEntropyLoss
season_mapping = {"Dry": 0, "PreFlood": 1, "Flood": 2, "PostFlood": 3}
season_df["season_label"] = season_df["Peak_Season"].map(season_mapping)

# Same columns AND same order as the other tasks. Do not reorder: the shared encoder
# must see the same physical quantity in each input slot for every task, and swapping
# temperature and precipitation here would silently break that.
season_features = list(common_features)

# Normalize feature values
# NOTE(review): the scaler is fit before the per-country train/test split (mild leakage).
scaler_season = StandardScaler()
season_df[common_features] = scaler_season.fit_transform(season_df[common_features])

# Per-country (X_train, y_train, X_test, y_test) tuples, countries in first-appearance order
flood_season_data = {}
for country, country_df in season_df.groupby("Country", sort=False):
    X = country_df[season_features].values
    y = country_df["season_label"].values
    label_counts = Counter(y)

    # Need >= 30 samples overall, and stratify=y requires >= 2 samples in every class present
    if len(y) >= 30 and all(count >= 2 for count in label_counts.values()):
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, stratify=y, test_size=0.2, random_state=42
        )
        flood_season_data[country] = (X_train, y_train, X_test, y_test)
    else:
        print(f" Skipping {country} due to class imbalance or low sample count: {label_counts}")


 Skipping Angola due to class imbalance or low sample count: Counter({np.int64(1): 72, np.int64(0): 36, np.int64(2): 1})
 Skipping Liberia due to class imbalance or low sample count: Counter({np.int64(2): 91, np.int64(3): 47, np.int64(1): 1, np.int64(0): 1})
 Skipping Uganda due to class imbalance or low sample count: Counter({np.int64(1): 10, np.int64(0): 6, np.int64(3): 2})
 Skipping Burundi due to class imbalance or low sample count: Counter({np.int64(1): 20, np.int64(0): 7, np.int64(3): 2})


## Define Shared Encoder

In [None]:
import torch.nn as nn

class SharedEncoder(nn.Module):
    """MLP mapping static catchment features to an embedding shared by all task heads.

    Layers inside ``self.encoder`` (their indices appear in saved state_dict keys):
    Linear -> ReLU -> BatchNorm1d -> Linear -> ReLU.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of the hidden layer.
        output_dim: size of the returned embedding.
    """

    def __init__(self, input_dim=4, hidden_dim=64, output_dim=32):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Embed a batch; in this notebook x is a 2-D (batch, input_dim) float tensor."""
        return self.encoder(x)


## Define Task Heads

 Flood Year Classification Head

In [None]:
class FloodYearHead(nn.Module):
    """Binary flood-year head: embedding -> one raw logit per sample (no sigmoid).

    Args:
        input_dim: size of the shared embedding fed into the head.
    """

    def __init__(self, input_dim=32):
        super().__init__()
        hidden_units = 16
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),  # single logit; pair with BCEWithLogitsLoss
        )

    def forward(self, x):
        """Return logits of shape (batch, 1)."""
        return self.classifier(x)


 Flood Magnitude Regression Head

In [None]:
class FloodMagnitudeHead(nn.Module):
    """Regression head: embedding -> one unbounded continuous prediction per sample.

    Args:
        input_dim: size of the shared embedding fed into the head.
    """

    def __init__(self, input_dim=32):
        super().__init__()
        hidden_units = 16
        self.regressor = nn.Sequential(
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),  # no output activation: plain regression value
        )

    def forward(self, x):
        """Return predictions of shape (batch, 1)."""
        return self.regressor(x)

Flood Seasonality Classification Head

In [None]:
class FloodSeasonHead(nn.Module):
    """Multiclass season head: embedding -> raw class logits (no softmax).

    Args:
        input_dim: size of the shared embedding fed into the head.
        num_classes: number of season classes (default 4: Dry, PreFlood, Flood, PostFlood).
    """

    def __init__(self, input_dim=32, num_classes=4):
        super().__init__()
        hidden_units = 32
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),  # logits; pair with CrossEntropyLoss
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        return self.classifier(x)

 Define Full Multi-Head Model

In [None]:
class MultiTaskModel(nn.Module):
    """Shared encoder feeding three task heads.

    forward(x) returns a dict:
        "flood_year":      year-head output with the trailing size-1 dim squeezed (logits)
        "flood_magnitude": magnitude-head output with the trailing size-1 dim squeezed
        "flood_season":    season-head output unchanged (raw class logits)
    """

    def __init__(self, encoder, year_head, magnitude_head, season_head):
        super().__init__()
        # Registration order fixes state_dict key order and parameters() order.
        self.encoder = encoder
        self.year_head = year_head
        self.magnitude_head = magnitude_head
        self.season_head = season_head

    def forward(self, x):
        features = self.encoder(x)
        return {
            "flood_year": self.year_head(features).squeeze(-1),
            "flood_magnitude": self.magnitude_head(features).squeeze(-1),
            "flood_season": self.season_head(features),
        }

## Loss Functions

In [None]:
# One loss per task. The heads emit raw scores, so each loss applies its own activation.
loss_fn_year = nn.BCEWithLogitsLoss()     # binary flood year: sigmoid folded into the loss
loss_fn_magnitude = nn.MSELoss()          # regression on the continuous magnitude target
loss_fn_season = nn.CrossEntropyLoss()    # multiclass season: log-softmax folded into the loss


## Federated Training Loop

 Average All Encoders (FedAvg Function)

In [None]:
def average_encoders(encoders):
    """Average the parameters and buffers of several encoders (FedAvg, equal weights).

    Floating-point entries of the state_dict (weights, biases, BatchNorm running
    statistics) are averaged element-wise. Integer entries, such as BatchNorm's
    ``num_batches_tracked`` counter, are averaged in float and rounded back to
    their dtype, because ``torch.mean`` rejects integer tensors.

    Args:
        encoders: non-empty sequence of modules with identical architecture.

    Returns:
        A new module of the same class and device as ``encoders[0]``, holding the
        averaged state. The inputs are not modified.

    Raises:
        ValueError: if ``encoders`` is empty.
    """
    import copy

    if not encoders:
        raise ValueError("average_encoders() needs at least one encoder")

    # Build each state_dict once, not once per key
    state_dicts = [enc.state_dict() for enc in encoders]

    averaged = {}
    for key, reference in state_dicts[0].items():
        stacked = torch.stack([sd[key] for sd in state_dicts])
        if torch.is_floating_point(reference):
            averaged[key] = stacked.mean(dim=0)
        else:
            averaged[key] = stacked.float().mean(dim=0).round().to(reference.dtype)

    # Copy the first encoder so the result has the same architecture and device
    # without hard-coding constructor arguments such as input_dim.
    new_encoder = copy.deepcopy(encoders[0])
    for param in new_encoder.parameters():
        param.grad = None  # don't carry over gradients left from local training
    new_encoder.load_state_dict(averaged)
    return new_encoder

Federated Training Function

In [None]:
def train_federated(
    countries=None,
    num_rounds=5,
    local_epochs=2,
    lr=1e-3
):
    """Train the shared encoder with federated averaging (FedAvg) across countries.

    In every communication round, each country does the following:
    - it starts from a copy of the current global encoder;
    - it trains that copy jointly with its own three task heads on its local
      training split (full batch, ``local_epochs`` optimizer steps).

    The resulting local encoders are then averaged into the next global encoder.
    Task heads stay local: each country keeps its own heads across rounds, and
    they are never averaged or returned.

    Args:
        countries: country names to train on. Defaults to every country present
            in all three task datasets.
        num_rounds: number of communication (aggregation) rounds.
        local_epochs: full-batch optimizer steps per country per round.
        lr: Adam learning rate for the local optimizers.

    Returns:
        The aggregated global ``SharedEncoder``.

    Raises:
        ValueError: if there are no countries to train on, or a task's feature
            matrix does not have the encoder's input width.
    """
    input_dim = 4  # width of the SharedEncoder input layer used throughout this notebook

    if countries is None:
        countries = set(flood_year_data) & set(flood_magnitude_data) & set(flood_season_data)
    # Sort for a reproducible order: the iteration order of a set of strings
    # changes between interpreter sessions (hash randomization).
    countries = sorted(countries)
    if not countries:
        raise ValueError("train_federated: no country has data for all three tasks.")

    def to_device(array, dtype=torch.float32):
        """Convert an array to a tensor of `dtype` on the notebook's device."""
        return torch.tensor(array, dtype=dtype, device=device)

    # Convert each country's training split to device tensors once, not every round,
    # and fail early with a clear message if a task has the wrong number of features.
    local_data = {}
    for country in countries:
        X_year, y_year = flood_year_data[country][:2]
        X_mag, y_mag = flood_magnitude_data[country][:2]
        X_season, y_season = flood_season_data[country][:2]
        for task, X in (("flood_year", X_year), ("flood_magnitude", X_mag), ("flood_season", X_season)):
            if X.shape[1] != input_dim:
                raise ValueError(
                    f"{country}/{task}: expected {input_dim} feature columns, got {X.shape[1]}"
                )
        local_data[country] = (
            to_device(X_year), to_device(y_year),
            to_device(X_mag), to_device(y_mag),
            to_device(X_season), to_device(y_season, dtype=torch.long),
        )

    global_encoder = SharedEncoder(input_dim=input_dim).to(device)

    # Independent heads per country, kept across rounds so they keep improving.
    # Recreating them every round would make each local encoder update chase
    # freshly randomized heads.
    local_heads = {
        country: (
            FloodYearHead().to(device),
            FloodMagnitudeHead().to(device),
            FloodSeasonHead().to(device),
        )
        for country in countries
    }

    for round_num in range(num_rounds):
        print(f"\n Communication Round {round_num+1}/{num_rounds}")
        local_encoders = []
        round_losses = []

        for country in countries:
            print(f"  - Training for {country}")

            # Local copy of the current global encoder
            encoder = SharedEncoder(input_dim=input_dim).to(device)
            encoder.load_state_dict(global_encoder.state_dict())

            model = MultiTaskModel(encoder, *local_heads[country]).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            X_year, y_year, X_mag, y_mag, X_season, y_season = local_data[country]

            model.train()
            for _ in range(local_epochs):
                optimizer.zero_grad()
                # Each task has its own inputs, so the model runs once per task;
                # the three losses are summed with equal weight.
                loss = (
                    loss_fn_year(model(X_year)["flood_year"], y_year)
                    + loss_fn_magnitude(model(X_mag)["flood_magnitude"], y_mag)
                    + loss_fn_season(model(X_season)["flood_season"], y_season)
                )
                loss.backward()
                optimizer.step()
                round_losses.append(loss.item())

            # Save the updated local encoder
            local_encoders.append(model.encoder)

        # Aggregate all encoders
        global_encoder = average_encoders(local_encoders)
        if round_losses:
            mean_loss = sum(round_losses) / len(round_losses)
            print(f"Aggregated new global encoder. Mean local loss: {mean_loss:.4f}")
        else:
            print("Aggregated new global encoder.")

    return global_encoder

Saving Final Encoder

In [None]:
# Run federated training with default settings, then checkpoint the encoder weights.
trained_encoder = train_federated()

encoder_checkpoint = "global_multitask_encoder.pt"
torch.save(trained_encoder.state_dict(), encoder_checkpoint)
print(" Saved global encoder.")

Saving and Evaluating the Models

In [None]:
# Ensure output directory exists
os.makedirs("saved_models", exist_ok=True)

ENCODER_PATH = "saved_models/global_multitask_encoder.pt"
FULL_MODEL_PATH = "saved_models/global_multitask_full_model.pt"

# Save the global encoder after federated training
torch.save(trained_encoder.state_dict(), ENCODER_PATH)
print(f" Global encoder saved to '{ENCODER_PATH}'")

# -------------------------------
#  Restore Encoder for Reuse
# -------------------------------
encoder = SharedEncoder(input_dim=4).to(device)
# map_location: place tensors on this session's device even if the checkpoint was
# written from another one (e.g. saved on GPU, reloaded on a CPU-only machine).
# weights_only: unpickle only tensors/primitive containers, never arbitrary objects.
encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=device, weights_only=True))

# Task heads are freshly initialized (untrained). global_model's predictions are
# not meaningful until these heads are fit on task data.
year_head = FloodYearHead().to(device)
mag_head = FloodMagnitudeHead().to(device)
season_head = FloodSeasonHead().to(device)

# Compose full model
global_model = MultiTaskModel(encoder, year_head, mag_head, season_head).to(device)

# Save the complete model: trained encoder + untrained (fresh) task heads
torch.save(global_model.state_dict(), FULL_MODEL_PATH)
print(f"Full model saved (w/ fresh heads) to '{FULL_MODEL_PATH}'")


In [None]:
# Example evaluation on one country (Kenya).
# global_model's task heads start from a fresh initialization, so they are first fit
# on this country's TRAINING split with the federated encoder frozen. Test-split
# metrics are only meaningful after that step.
country = "Kenya"
HEAD_FIT_STEPS = 200  # full-batch Adam steps used to fit the heads (tunable)
HEAD_FIT_LR = 1e-3

if country in flood_year_data and country in flood_magnitude_data and country in flood_season_data:

    def to_device(array, dtype=torch.float32):
        """Convert an array to a tensor of `dtype` on the notebook's device."""
        return torch.tensor(array, dtype=dtype, device=device)

    X_train_year, y_train_year, X_test_year, y_test_year = flood_year_data[country]
    X_train_mag, y_train_mag, X_test_mag, y_test_mag = flood_magnitude_data[country]
    X_train_season, y_train_season, X_test_season, y_test_season = flood_season_data[country]

    # eval(): the encoder's BatchNorm uses its aggregated running statistics. The heads
    # have no dropout or BatchNorm, so eval mode does not change how they train.
    global_model.eval()
    for param in global_model.encoder.parameters():
        param.requires_grad_(False)  # freeze the federated encoder; only the heads are fit
    head_optimizer = torch.optim.Adam(
        [p for p in global_model.parameters() if p.requires_grad], lr=HEAD_FIT_LR
    )

    fit_X_year, fit_y_year = to_device(X_train_year), to_device(y_train_year)
    fit_X_mag, fit_y_mag = to_device(X_train_mag), to_device(y_train_mag)
    fit_X_season, fit_y_season = to_device(X_train_season), to_device(y_train_season, torch.long)

    for _ in range(HEAD_FIT_STEPS):
        head_optimizer.zero_grad()
        loss = (
            loss_fn_year(global_model(fit_X_year)["flood_year"], fit_y_year)
            + loss_fn_magnitude(global_model(fit_X_mag)["flood_magnitude"], fit_y_mag)
            + loss_fn_season(global_model(fit_X_season)["flood_season"], fit_y_season)
        )
        loss.backward()
        head_optimizer.step()

    with torch.no_grad():
        year_logits = global_model(to_device(X_test_year))["flood_year"]
        y_pred_year = (torch.sigmoid(year_logits) > 0.5).long().cpu().numpy()
        y_pred_mag = global_model(to_device(X_test_mag))["flood_magnitude"].cpu().numpy()
        y_pred_season = global_model(to_device(X_test_season))["flood_season"].argmax(dim=1).cpu().numpy()

    # Year classification: integer 0/1 labels on both sides; zero_division silences
    # warnings when a class is never predicted.
    print("Classification Report (Flood Year):")
    print(classification_report(np.asarray(y_test_year, dtype=int), y_pred_year, zero_division=0))

    # Magnitude regression: metrics are in the model's target space, log1p(Maxi_q)
    print("\nRegression Metrics (Flood Magnitude):")
    print("MAE:", mean_absolute_error(y_test_mag, y_pred_mag))
    print("RMSE:", np.sqrt(mean_squared_error(y_test_mag, y_pred_mag)))
    print("R²:", r2_score(y_test_mag, y_pred_mag))

    # Season classification
    print("\nClassification Report (Flood Season):")
    print(classification_report(y_test_season, y_pred_season, zero_division=0))
else:
    print(f"Skipping evaluation: {country} is missing from at least one task dataset.")