# Federated Learning + XAI for Parkinson's Detection — Full Notebook (v2)

This **full** notebook includes:
- **Gait classification (PD vs Control)** on *PhysioNet "Gait in Parkinson’s Disease"* with robust evaluation (no one-class AUC warnings) and **balanced accuracy**.
- **Voice severity regression (UPDRS)** on *UCI Parkinson’s Telemonitoring*.
- **Federated Learning (FedAvg)** with **subject-as-client** simulation.
- **XAI**: **Integrated Gradients** for gait, **SHAP** for voice.
- Safer label inference for gait files; skips non-signal files (`format.txt`, `demographics.txt`, hashes).  

> Adjust the paths below to your local folders (defaults reflect your latest message).


## 0) Setup

In [1]:
# If needed, install packages (uncomment):
# !pip install torch numpy pandas scikit-learn shap captum matplotlib tqdm

import os, re, glob, math, random, json, warnings
from dataclasses import dataclass
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, mean_absolute_error, balanced_accuracy_score

# XAI
import shap
from captum.attr import IntegratedGradients

warnings.filterwarnings('ignore')
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE


  from .autonotebook import tqdm as notebook_tqdm


'cpu'

## 1) Configure dataset paths (Windows)

In [2]:
# Edit these if your folders are different
GAIT_ROOT = r"C:\Users\muham\_Projects\PD1\Dataset\gait-in-parkinsons-disease-1.0.0"
VOICE_TELEMONITORING_CSV = r"C:\Users\muham\_Projects\PD1\Dataset\parkinsons_updrs.data"
PADS_ROOT = r"C:\Users\muham\_Projects\PD1\Dataset\pads-parkinsons-disease-smartwatch-dataset-1.0.0"

print("GAIT_ROOT =", GAIT_ROOT)
print("VOICE_TELEMONITORING_CSV =", VOICE_TELEMONITORING_CSV)
print("PADS_ROOT =", PADS_ROOT)

GAIT_ROOT = C:\Users\muham\_Projects\PD1\Dataset\gait-in-parkinsons-disease-1.0.0
VOICE_TELEMONITORING_CSV = C:\Users\muham\_Projects\PD1\Dataset\parkinsons_updrs.data
PADS_ROOT = C:\Users\muham\_Projects\PD1\Dataset\pads-parkinsons-disease-smartwatch-dataset-1.0.0


## 2) Helper functions

In [3]:
def set_seed(s=42):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)

def subject_partition_indices(subject_ids: List[str]):
    mp: Dict[str, List[int]] = {}
    for i,s in enumerate(subject_ids):
        mp.setdefault(s, []).append(i)
    return mp

def make_client_loaders(idx_list: List[int], ds, bs=64, shuffle=True, drop_last=True):
    class _Sub(Dataset):
        def __init__(self, base, idx): self.base=base; self.idx=idx
        def __len__(self): return len(self.idx)
        def __getitem__(self, i): return self.base[self.idx[i]]
    return DataLoader(_Sub(ds, idx_list), batch_size=bs, shuffle=shuffle, drop_last=drop_last)

def clone_sd(sd): return {k:v.clone() for k,v in sd.items()}
def avg_sd(sd_list: List[dict], weights: List[float]):
    out={}
    tot = sum(weights)
    for k in sd_list[0].keys():
        out[k] = sum(w*sd[k] for sd,w in zip(sd_list,weights))/tot
    return out

## 3) Gait dataset loader + model (BiLSTM for PD vs Control)

In [4]:
class GaitDataset(Dataset):
    def __init__(self, X, y, subj):
        self._X = X; self._y = y; self._s = subj
    def __len__(self): return len(self._X)
    def __getitem__(self, i):
        return torch.tensor(self._X[i], dtype=torch.float32), torch.tensor(self._y[i], dtype=torch.long)
    @property
    def subjects(self): return self._s

def infer_label_from_filename(fn_lower: str):
    # Robust markers for PD vs Control
    pd_markers = ["gapt", "jupt", "sipt", "_pt", "-pt", "parkinson", "pd"]
    co_markers = ["gaco", "juco", "sico", "_co", "-co", "control", "healthy"]
    if any(m in fn_lower for m in pd_markers): return 1
    if any(m in fn_lower for m in co_markers): return 0
    return None

def subject_id_from_name(path: str):
    stem = os.path.splitext(os.path.basename(path))[0]
    m = re.match(r"([A-Za-z]+[A-Za-z]*\d+)", stem)
    return m.group(1) if m else stem

def should_skip_file(path: str):
    # Skip metadata/text/hash files
    basename = os.path.basename(path).lower()
    return any(basename.startswith(k) for k in ["format", "demographics", "sha256sums", "notes", "readme"])

def load_gait_windows(root: str, window_sec=4.0, step_sec=2.0, fs=100.0):
    files = glob.glob(os.path.join(root, "*.txt"))
    X, y, sids = [], [], []
    win = int(window_sec*fs); step=int(step_sec*fs)
    for f in tqdm(files, desc="GAIT: loading"):
        try:
            if should_skip_file(f):
                continue
            arr = np.loadtxt(f)
            if arr.ndim!=2 or arr.shape[1] < 19:
                continue
            label = infer_label_from_filename(f.lower())
            if label is None:
                continue
            sid = subject_id_from_name(f)
            T = arr.shape[0]
            for start in range(0, max(1, T - win + 1), step):
                seg = arr[start:start+win, :19]
                if seg.shape[0] < win: break
                X.append(seg); y.append(label); sids.append(sid)
        except Exception as e:
            print(f"[skip] {f}: {e}")
    return X, y, sids

class GaitBiLSTM(nn.Module):
    def __init__(self, in_dim=19, hidden=64, layers=1):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden, num_layers=layers, batch_first=True, bidirectional=True)
        self.head = nn.Sequential(nn.Linear(2*hidden, 64), nn.ReLU(), nn.Linear(64, 2))
    def forward(self, x):  # (B,T,19)
        out,_ = self.lstm(x)
        h = out[:, -1, :]
        return self.head(h)

### 3.1) Train Federated Gait Classifier (FedAvg) — with robust evaluation

In [6]:
set_seed(SEED)

# Load gait data
Xg, yg, subjg = load_gait_windows(GAIT_ROOT, window_sec=4.0, step_sec=2.0, fs=100.0)
gait_ds = GaitDataset(Xg, yg, subjg)
parts_g = subject_partition_indices(gait_ds.subjects)
print(f"GAIT samples={len(gait_ds)}  subjects={len(parts_g)}")

# Init global model
model_g = GaitBiLSTM(in_dim=19, hidden=64).to(DEVICE)
crit = nn.CrossEntropyLoss()

rounds = 5
frac_clients = 0.3
local_epochs = 2
bs = 128

def train_epoch_cls(model, dl, opt, crit):
    model.train(); total=0; correct=0
    for xb, yb in dl:
        xb=xb.to(DEVICE); yb=yb.to(DEVICE)
        opt.zero_grad(); logits = model(xb); loss = crit(logits, yb)
        loss.backward(); opt.step()
        pred = logits.argmax(1); total += yb.size(0); correct += (pred==yb).sum().item()
    return correct/max(1,total)

@torch.no_grad()
def eval_cls(model, dl):
    model.eval()
    y_true, y_score = [], []
    for xb,yb in dl:
        xb=xb.to(DEVICE)
        prob = torch.softmax(model(xb), dim=1)[:,1].cpu().numpy()
        y_true.extend(yb.numpy().tolist()); y_score.extend(prob.tolist())
    y_pred = [1 if s>=0.5 else 0 for s in y_score]
    acc = accuracy_score(y_true,y_pred)
    f1  = f1_score(y_true,y_pred)
    bacc= balanced_accuracy_score(y_true,y_pred)
    # AUC only when both classes present
    auc = roc_auc_score(y_true,y_score) if len(set(y_true))==2 else float("nan")
    return acc, f1, auc, bacc

for r in range(rounds):
    print(f"\n=== GAIT Round {r+1}/{rounds} ===")
    client_ids = list(parts_g.keys())
    m = max(1, int(frac_clients*len(client_ids)))
    chosen = random.sample(client_ids, m)

    sds=[]; ws=[]
    for cid in chosen:
        idx = parts_g[cid]
        if len(idx) < bs:  # need enough for a batch
            continue
        local = GaitBiLSTM(in_dim=19, hidden=64).to(DEVICE)
        local.load_state_dict(clone_sd(model_g.state_dict()))
        dl = make_client_loaders(idx, gait_ds, bs=bs, shuffle=True, drop_last=True)
        opt = torch.optim.Adam(local.parameters(), lr=1e-3)
        for _ in range(local_epochs):
            train_epoch_cls(local, dl, opt, crit)
        sds.append(clone_sd(local.state_dict())); ws.append(len(idx))

    if sds:
        model_g.load_state_dict(avg_sd(sds, ws))

    # Evaluate on full dataset (no drop_last, no shuffle)
    dl_eval = DataLoader(gait_ds, batch_size=bs, shuffle=False, drop_last=False)
    acc,f1,auc,bacc = eval_cls(model_g, dl_eval)
    print(f"[Global] ACC={acc:.3f}  F1={f1:.3f}  AUC={auc:.3f}  BACC={bacc:.3f}")

GAIT: loading: 100%|█████████████████████████████████████████████████████████████████| 309/309 [00:11<00:00, 26.48it/s]


GAIT samples=16117  subjects=165

=== GAIT Round 1/5 ===
[Global] ACC=0.838  F1=0.912  AUC=nan  BACC=0.838

=== GAIT Round 2/5 ===
[Global] ACC=1.000  F1=1.000  AUC=nan  BACC=1.000

=== GAIT Round 3/5 ===
[Global] ACC=1.000  F1=1.000  AUC=nan  BACC=1.000

=== GAIT Round 4/5 ===
[Global] ACC=1.000  F1=1.000  AUC=nan  BACC=1.000

=== GAIT Round 5/5 ===
[Global] ACC=1.000  F1=1.000  AUC=nan  BACC=1.000


### 3.2) XAI on Gait (Integrated Gradients)

In [7]:
# Compute IG on a mini-batch
model_g.eval()
xb,_ = next(iter(DataLoader(gait_ds, batch_size=32, shuffle=True, drop_last=True)))
xb = xb.to(DEVICE)
ig = IntegratedGradients(lambda inp: torch.softmax(model_g(inp), dim=1)[:,1])
attr = ig.attribute(xb, n_steps=32).detach().cpu().numpy()
attr.shape  # (B, T, 19)

(32, 400, 19)

## 4) Voice Telemonitoring UPDRS Regression (MLP + FedAvg)

In [22]:
class VoiceTeleDataset(Dataset):
    def __init__(self, X, y, subj):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32).view(-1,1)
        self.subj = subj
    def __len__(self): return len(self.X)
    def __getitem__(self,i): return self.X[i], self.y[i]
    @property
    def subjects(self): return self.subj

def load_voice_telemonitoring(csv_path: str, target="total_UPDRS"):
    cols = ["subject","age","sex","test_time","motor_UPDRS","total_UPDRS",
            "Jitter(%)","Jitter(Abs)","Jitter:RAP","Jitter:PPQ5","Jitter:DDP",
            "Shimmer","Shimmer(dB)","Shimmer:APQ3","Shimmer:APQ5","Shimmer:APQ11","Shimmer:DDA",
            "NHR","HNR","RPDE","DFA","PPE"]
    df = pd.read_csv(csv_path, header=None, names=cols)
    y = df[target].values.astype(np.float32)
    X = df.drop(columns=["subject", target]).values.astype(np.float32)
    subj = df["subject"].astype(str).tolist()
    scaler = StandardScaler().fit(X)
    Xs = scaler.transform(X)
    return Xs, y, subj, scaler, [c for c in df.columns if c not in ["subject", target]]

class VoiceMLPReg(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )
    def forward(self, x): return self.net(x)

### 4.1) Train Federated Voice Regressor (FedAvg on UPDRS)

In [None]:
set_seed(SEED)
Xv, yv, sub_v, scaler_v, feat_names = load_voice_telemonitoring(VOICE_TELEMONITORING_CSV, target="total_UPDRS")
voice_ds = VoiceTeleDataset(Xv, yv, sub_v)
parts_v = subject_partition_indices(voice_ds.subjects)
print(f"VOICE samples={len(voice_ds)}  subjects={len(parts_v)}")

in_dim = voice_ds[0][0].numel()
model_v = VoiceMLPReg(in_dim).to(DEVICE)

rounds = 5
frac_clients = 0.3
local_epochs = 2
bs = 256

def train_epoch_reg(model, dl, opt):
    model.train(); losses=[]
    for xb,yb in dl:
        xb=xb.to(DEVICE); yb=yb.to(DEVICE)
        opt.zero_grad(); pred = model(xb)
        loss = nn.functional.l1_loss(pred, yb)
        loss.backward(); opt.step(); losses.append(loss.item())
    return float(np.mean(losses)) if losses else 0.0

@torch.no_grad()
def eval_reg(model, dl):
    model.eval(); preds=[]; trues=[]
    for xb,yb in dl:
        xb=xb.to(DEVICE)
        pred = model(xb).cpu().numpy().ravel()
        preds.extend(pred.tolist()); trues.extend(yb.numpy().ravel().tolist())
    return mean_absolute_error(trues, preds)

for r in range(rounds):
    print(f"\n=== VOICE Round {r+1}/{rounds} ===")
    client_ids = list(parts_v.keys())
    m = max(1, int(frac_clients*len(client_ids)))
    chosen = random.sample(client_ids, m)

    sds=[]; ws=[]
    for cid in chosen:
        idx = parts_v[cid]
        if len(idx) < bs: continue
        local = VoiceMLPReg(in_dim).to(DEVICE)
        local.load_state_dict(clone_sd(model_v.state_dict()))
        dl = make_client_loaders(idx, voice_ds, bs=bs, shuffle=True, drop_last=True)
        opt = torch.optim.Adam(local.parameters(), lr=1e-3)
        for _ in range(local_epochs):
            mae = train_epoch_reg(local, dl, opt)
        sds.append(clone_sd(local.state_dict())); ws.append(len(idx))

    if sds:
        model_v.load_state_dict(avg_sd(sds, ws))

    dl_eval = DataLoader(voice_ds, batch_size=bs, shuffle=False, drop_last=False)
    mae = eval_reg(model_v, dl_eval)
    print(f"[Global] total_UPDRS MAE={mae:.3f}")

### 4.2) XAI on Voice (SHAP)

In [None]:
# Compute SHAP on a small sample
model_v.eval()
small = Xv[:50]

def predict_fn(inp):
    with torch.no_grad():
        inp_t = torch.tensor(inp, dtype=torch.float32, device=DEVICE)
        return model_v(inp_t).cpu().numpy().ravel()

explainer = shap.KernelExplainer(predict_fn, small[:10])
sv = explainer.shap_values(small[10:20], nsamples=100)
print("SHAP computed. Visualize locally with:")
print("shap.summary_plot(sv, features=small[10:20], feature_names=feat_names)")

## 5) (Optional) PADS parser scaffold

In [28]:
def parse_pads_example(pads_root=PADS_ROOT):
    patients = glob.glob(os.path.join(pads_root, "patients", "patient_*.json"))
    for pj in patients[:3]:
        with open(pj, "r", encoding="utf-8") as f:
            meta = json.load(f)
        pid = meta.get("patient_id", os.path.basename(pj))
        print(f"[PADS] Found patient {pid} keys: {list(meta.keys())[:8]} ...")

parse_pads_example()

[PADS] Found patient patient_001.json keys: ['resource_type', 'id', 'study_id', 'condition', 'disease_comment', 'age_at_diagnosis', 'age', 'height'] ...
[PADS] Found patient patient_002.json keys: ['resource_type', 'id', 'study_id', 'condition', 'disease_comment', 'age_at_diagnosis', 'age', 'height'] ...
[PADS] Found patient patient_003.json keys: ['resource_type', 'id', 'study_id', 'condition', 'disease_comment', 'age_at_diagnosis', 'age', 'height'] ...
