# ECG-FM Finetuning V38: Domain Adaptation & Robustness

Phiên bản này được thiết kế để giải quyết vấn đề **"Lead Mismatch"** (Sự khác biệt giữa Lead II bệnh viện và Lead I Polar H10) thông qua kỹ thuật Augmentation.

**Tính năng nổi bật:**
1.  **Signal Augmentation:** Tự động co giãn biên độ và thời gian ngẫu nhiên trong lúc train để model học được tính bất biến (Invariance).
2.  **Weighted Loss:** Cân bằng lại sự chú ý của model vào các lớp bệnh hiếm (STE, STD).
3.  **Support Polar Data:** Tối ưu để học từ dữ liệu hỗn hợp (Dataset gốc + Dữ liệu Polar user).

In [None]:
# 1. Setup Environment
# Start from a clean slate: remove earlier checkouts, the micromamba binary
# directory and the environment root left behind by a previous run.
!rm -rf fairseq fairseq-signals bin microroot py39_env

print("‚è≥ Installing Micromamba...")
# Download the latest linux-64 micromamba archive and extract only the
# bin/micromamba executable from it into the working directory.
!curl -Ls https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj bin/micromamba

# Create the `ecg_env` environment under ./microroot with Python 3.9 and the
# PyTorch stack built against CUDA 12.1; `-y` skips the confirmation prompt.
!./bin/micromamba create -r microroot -n ecg_env -c pytorch -c nvidia -c conda-forge \
    python=3.9 \
    pytorch torchvision torchaudio pytorch-cuda=12.1 \
    -y
print("‚úÖ Environment Created.")

In [None]:
%%bash
# Stop at the first failing command (or use of an unset variable) so the final
# "Setup Complete" message is printed only when every step really succeeded.
set -euo pipefail

# Python dependencies of the training script, installed with pip inside the
# ecg_env micromamba environment.
./bin/micromamba run -r microroot -n ecg_env pip install \
    "transformers==4.30.0" "accelerate>=0.20.0" \
    pandas scipy wfdb pyarrow scikit-learn tqdm \
    hydra-core omegaconf bitarray soundfile matplotlib \
    sacrebleu portalocker regex tensorboardX "antlr4-python3-runtime==4.8"

# fairseq is pinned to the v0.12.2 tag; fairseq-signals is taken at its default
# branch. Neither is pip-installed: train_robust.py puts both checkouts on
# sys.path. `git -C` avoids running the checkout in the wrong directory, which
# the former `cd fairseq; ...; cd ..` chain would do if the `cd` failed.
git clone https://github.com/facebookresearch/fairseq.git
git -C fairseq checkout v0.12.2
git clone https://github.com/Jwoo5/fairseq-signals.git

echo "‚úÖ Setup Complete."

In [None]:
%%writefile train_robust.py
import os
import sys
import json
import warnings
import random

# --- ENV SETUP ---
# Select the non-interactive "Agg" matplotlib backend for anything imported later.
os.environ["MPLBACKEND"] = "Agg"
warnings.filterwarnings("ignore")
# The fairseq / fairseq-signals checkouts are used straight from the working
# directory, so they have to be put at the front of the import path.
cwd = os.getcwd()
sys.path.insert(0, os.path.join(cwd, "fairseq"))
sys.path.insert(0, os.path.join(cwd, "fairseq-signals"))

# Resolve the encoder classes. The primary location is tried first; on any
# import-time failure the ECG transformer classes are bound to the same names.
# `Exception` (rather than a bare `except`) keeps Ctrl-C and SystemExit working,
# and both failures are reported before exiting so the cause is visible.
try:
    from fairseq_signals.models.wav2vec.wav2vec2_cmsc_rlm import Wav2Vec2CMSCRLMModel, Wav2Vec2CMSCRLMConfig
except Exception as primary_err:
    try:
        from fairseq_signals.models.ecg_transformer import ECGTransformerModel as Wav2Vec2CMSCRLMModel
        from fairseq_signals.models.ecg_transformer import ECGTransformerConfig as Wav2Vec2CMSCRLMConfig
    except Exception as fallback_err:
        print(
            "Could not import an ECG-FM model class from fairseq_signals.\n"
            f"  wav2vec2_cmsc_rlm: {primary_err!r}\n"
            f"  ecg_transformer:   {fallback_err!r}",
            file=sys.stderr,
        )
        sys.exit(1)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from scipy import signal
from sklearn.metrics import f1_score, classification_report
from tqdm import tqdm

# --- CONFIGURATION ---
# Training runs on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Label table read by RobustECGDataset; it uses the 'split', 'filename' and
# 'labels' columns.
CSV_PATH = "/kaggle/input/dataset-multilabel-v3/dataset_multilabel_500hz/labels.csv"
# Directory holding the .npy signal files named in the 'filename' column.
DATA_DIR = "/kaggle/input/dataset-multilabel-v3/dataset_multilabel_500hz/data"
# Pretrained encoder checkpoint loaded by ECGFM_MultiLabel.
PRETRAINED_PATH = "/kaggle/input/ecg-fm-pretrained-v2/pytorch/default/1/mimic_iv_ecg_physionet_pretrained.pt"
# Output directory for the best checkpoint ("ecg_fm_best.pth").
SAVE_DIR = "/kaggle/working"

BATCH_SIZE = 32   # samples per batch for the train / val / test loaders
EPOCHS = 25       # number of passes over the training split
LR = 3e-5         # AdamW learning rate
# Every signal is padded or cropped to this many samples.
# NOTE(review): presumably 10 s at 500 Hz, going by the dataset path -- confirm.
TARGET_LEN = 5000

# --- ROBUST DATASET (AUGMENTATION) ---
class RobustECGDataset(Dataset):
    """ECG dataset with train-time augmentation for multi-label classification.

    The label CSV must provide ``split``, ``filename`` and ``labels`` columns,
    where ``labels`` is a ';'-separated string of class names. The class
    vocabulary is built from *all* rows of the CSV, so class indices are the
    same for every split.

    Each item is ``(x, y)``: ``x`` is a ``(12, TARGET_LEN)`` float32 tensor made
    by repeating the loaded trace on 12 rows, ``y`` is a multi-hot float32
    vector over ``self.classes``.

    NOTE(review): the code treats each ``.npy`` file as a 1-D trace -- confirm
    against the dataset.

    Args:
        csv_file: Path of the label CSV.
        root_dir: Directory containing the ``.npy`` files.
        split: Value of the ``split`` column to keep (augmentation is applied
            only when this is ``'train'``).
    """

    # Cap on how many load failures each process reports, to keep logs readable
    # (every DataLoader worker holds its own copy of the counter).
    _MAX_LOGGED_FAILURES = 5
    _failures = 0

    def __init__(self, csv_file, root_dir, split='train'):
        self.root_dir = root_dir
        self.split = split
        df = pd.read_csv(csv_file)
        self.data = df[df['split'] == split].reset_index(drop=True)

        # Vocabulary comes from the whole CSV, not just this split.
        vocabulary = set()
        for raw in df['labels'].dropna():
            vocabulary.update(self._split_labels(raw))
        self.classes = sorted(vocabulary)
        self.c2i = {c: i for i, c in enumerate(self.classes)}
        print(f"[{split.upper()}] Samples: {len(self.data)} | Classes: {len(self.classes)}")

    @staticmethod
    def _split_labels(raw):
        """Split a ';'-separated label string into stripped, non-empty names."""
        return [tok.strip() for tok in str(raw).split(';') if tok.strip()]

    def get_pos_weights(self, max_weight=None):
        """Compute per-class ``pos_weight`` values that up-weight rare classes.

        For a class with ``n`` positive rows out of ``total`` rows in this split
        the weight is ``(total - n) / n`` (negatives per positive). A class with
        no positive row in this split gets the neutral weight ``1.0`` instead
        of a division by (almost) zero.

        Args:
            max_weight: Optional upper bound applied to every weight. ``None``
                (the default) leaves the weights unclamped.

        Returns:
            float32 tensor of shape ``(len(self.classes),)``.
        """
        counts = np.zeros(len(self.classes), dtype=np.float64)
        for raw in self.data['labels'].dropna():
            # A set so that a label repeated within one row is counted once.
            for label in set(self._split_labels(raw)):
                idx = self.c2i.get(label)
                if idx is not None:
                    counts[idx] += 1

        total = len(self.data)
        weights = np.ones_like(counts)
        seen = counts > 0
        weights[seen] = (total - counts[seen]) / counts[seen]
        if max_weight is not None:
            weights = np.minimum(weights, max_weight)
        return torch.tensor(weights, dtype=torch.float32)

    def __len__(self): return len(self.data)

    def _augment(self, ecg):
        """Apply the random amplitude, time-warp and noise augmentations."""
        # 1. Random amplitude (targets the low amplitude of Polar recordings).
        # NOTE(review): the z-score normalisation applied afterwards removes a
        # pure rescale, so this mainly changes the signal-to-noise ratio
        # relative to step 3.
        if random.random() > 0.3:
            ecg = ecg * random.uniform(0.7, 1.4)

        # 2. Time warp (targets false RBBB calls caused by wave width).
        if random.random() > 0.5:
            factor = random.uniform(0.9, 1.1)
            ecg = signal.resample(ecg, int(len(ecg) * factor))

        # 3. Additive Gaussian noise (targets sensor noise).
        if random.random() > 0.5:
            ecg = ecg + np.random.normal(0, 0.02, ecg.shape)
        return ecg

    @staticmethod
    def _normalize(ecg):
        """Z-score the signal; a (near-)constant signal becomes all zeros."""
        std = np.std(ecg)
        if std > 1e-6:
            return (ecg - np.mean(ecg)) / std
        return np.zeros_like(ecg)

    def _fix_length(self, ecg):
        """Zero-pad at the end or crop so the signal has TARGET_LEN samples."""
        if len(ecg) < TARGET_LEN:
            return np.pad(ecg, (0, TARGET_LEN - len(ecg)), 'constant')
        if self.split == 'train':
            # Random crop during training exposes different parts of the trace.
            start = random.randint(0, len(ecg) - TARGET_LEN)
            return ecg[start : start + TARGET_LEN]
        return ecg[:TARGET_LEN]

    def _encode_labels(self, raw):
        """Turn a raw ``labels`` cell into a multi-hot float32 vector."""
        y = torch.zeros(len(self.classes), dtype=torch.float32)
        if pd.notna(raw):
            for label in self._split_labels(raw):
                idx = self.c2i.get(label)
                if idx is not None:
                    y[idx] = 1.0
        return y

    def __getitem__(self, idx):
        """Return ``(x, y)`` for row ``idx``.

        If loading or processing fails, the failure is reported on stderr (the
        first ``_MAX_LOGGED_FAILURES`` times per process) and an all-zero
        signal with an all-zero label vector is returned, so that one bad file
        does not abort training.
        """
        row = self.data.iloc[idx]
        try:
            path = os.path.join(self.root_dir, row['filename'])
            ecg = np.nan_to_num(np.load(path).astype(np.float32))

            # Augmentation is applied to the training split only.
            if self.split == 'train':
                ecg = self._augment(ecg)

            ecg = self._normalize(ecg)
            # Length is fixed after the time warp, which changes it.
            ecg = self._fix_length(ecg)

            x = torch.tensor(np.tile(ecg, (12, 1)), dtype=torch.float32)
            y = self._encode_labels(row['labels'])
            return x, y
        except Exception as err:
            self._failures += 1
            if self._failures <= self._MAX_LOGGED_FAILURES:
                print(
                    f"[{self.split.upper()}] Could not load sample {idx} "
                    f"({row.get('filename', '?')}): {err!r}; using zeros.",
                    file=sys.stderr,
                )
            return torch.zeros((12, TARGET_LEN), dtype=torch.float32), torch.zeros(len(self.classes), dtype=torch.float32)

# --- MODEL ---
class ECGFM_MultiLabel(nn.Module):
    """Pretrained ECG encoder followed by an MLP head with one logit per class.

    Args:
        pt_path: Path of the pretrained checkpoint. When it is falsy or the
            file does not exist, the encoder keeps its fresh initialisation.
        n_cls: Number of output logits.
    """

    def __init__(self, pt_path, n_cls):
        super().__init__()
        embed_dim = 768
        cfg = Wav2Vec2CMSCRLMConfig()
        # Some config objects nest the model settings under `.model`.
        model_cfg = cfg.model if hasattr(cfg, 'model') else cfg
        model_cfg.encoder_embed_dim = embed_dim
        model_cfg.conv_feature_layers = "[(256, 2, 2)] * 4"
        self.enc = Wav2Vec2CMSCRLMModel(model_cfg)

        if pt_path and os.path.exists(pt_path):
            self._load_pretrained(pt_path)
        elif pt_path:
            # Fine-tuning from random weights is almost never intended, so say so.
            print(
                f"[ECGFM_MultiLabel] Checkpoint not found: {pt_path}; "
                "the encoder is NOT initialised from pretrained weights.",
                file=sys.stderr,
            )

        self.head = nn.Sequential(
            nn.Linear(embed_dim, 256), nn.ReLU(), nn.Dropout(0.4), nn.Linear(256, n_cls)
        )

    def _load_pretrained(self, pt_path):
        """Load encoder weights from ``pt_path`` and report how well they matched."""
        # NOTE(security): torch.load unpickles the file -- only use checkpoints
        # from a trusted source.
        state = torch.load(pt_path, map_location="cpu")
        if "model" in state:
            state = state["model"]
        # Drop the "module." prefix that wrapped (e.g. data-parallel) models add.
        cleaned = {k.replace("module.", ""): v for k, v in state.items()}
        result = self.enc.load_state_dict(cleaned, strict=False)

        # strict=False hides key mismatches, so surface them. The result object
        # is read defensively in case the encoder's override returns nothing.
        missing = getattr(result, "missing_keys", None)
        unexpected = getattr(result, "unexpected_keys", None)
        if missing is not None and unexpected is not None:
            print(
                f"[ECGFM_MultiLabel] Loaded {len(cleaned)} tensors from {pt_path} "
                f"({len(missing)} missing, {len(unexpected)} unexpected keys)."
            )

    def forward(self, x):
        """Return logits of shape ``(batch, n_cls)`` for the input batch ``x``.

        The encoder's ``'x'`` output is averaged over dim 1 before the head
        (presumably the time axis -- TODO confirm).
        """
        features = self.enc(source=x, padding_mask=None, mask=False)['x']
        return self.head(features.mean(dim=1))

# --- MAIN ---
def collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` and gather targets and sigmoid probabilities.

    The model is switched to eval mode and gradients are disabled.

    Args:
        model: Module mapping an input batch to logits.
        loader: Iterable of ``(x, y)`` batches with ``y`` on the CPU.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(y_true, y_prob)`` numpy arrays, concatenated over all batches.
    """
    model.eval()
    targets, probs = [], []
    with torch.no_grad():
        for x, y in loader:
            probs.append(torch.sigmoid(model(x.to(device))).cpu().numpy())
            targets.append(y.numpy())
    return np.concatenate(targets), np.concatenate(probs)


def main():
    """Fine-tune the model, keep the best checkpoint by validation macro-F1, then test it."""
    if not os.path.exists(CSV_PATH):
        # A missing dataset is an error: explain it and exit with a non-zero status.
        sys.exit(f"Label CSV not found: {CSV_PATH}")

    train_ds = RobustECGDataset(CSV_PATH, DATA_DIR, 'train')
    val_ds = RobustECGDataset(CSV_PATH, DATA_DIR, 'val')
    test_ds = RobustECGDataset(CSV_PATH, DATA_DIR, 'test')

    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # WEIGHTED LOSS: positives of rare classes are weighted up.
    pos_weight = train_ds.get_pos_weights().to(DEVICE)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    model = ECGFM_MultiLabel(PRETRAINED_PATH, len(train_ds.classes)).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    # mode='max' because the scheduler is stepped with the validation F1.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)

    print("üöÄ Training V38 (Robust Mode)...")
    best_path = os.path.join(SAVE_DIR, "ecg_fm_best.pth")
    # Start below any attainable F1 so the first epoch always writes a
    # checkpoint; starting at 0.0 left nothing to load when F1 stayed at 0.
    best_f1 = -1.0

    for ep in range(EPOCHS):
        model.train()
        loss_sum = 0
        for x, y in tqdm(train_dl, desc=f"Ep {ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_sum += loss.item()

        # Val
        all_y, all_p = collect_predictions(model, val_dl, DEVICE)
        curr_f1 = f1_score(all_y, (all_p > 0.5).astype(int), average='macro', zero_division=0)
        print(f"   Loss: {loss_sum/len(train_dl):.4f} | Val F1: {curr_f1:.4f}")
        scheduler.step(curr_f1)

        if curr_f1 > best_f1:
            best_f1 = curr_f1
            torch.save(model.state_dict(), best_path)

    # TEST
    print("\nüß™ TESTING PHASE (Sensitive Thresholds)")
    model.load_state_dict(torch.load(best_path, map_location=DEVICE))
    y_true, y_prob = collect_predictions(model, tqdm(test_dl), DEVICE)

    # Classes reported with a lower decision threshold (0.3 instead of 0.5).
    sensitive_classes = ('STE', 'STD', 'LBBB', 'AFIB')

    print(f"{'CLASS':<10} | {'THRESHOLD':<10} | {'PRECISION':<10} | {'RECALL':<10} | {'F1':<10}")
    print("-"*60)
    for i, cls in enumerate(train_ds.classes):
        thresh = 0.3 if cls in sensitive_classes else 0.5
        y_pred = (y_prob[:, i] > thresh).astype(int)
        # Integer targets plus explicit labels make the positive row always
        # exist under the key '1', even for a class with no positives at all.
        report = classification_report(
            y_true[:, i].astype(int), y_pred, labels=[0, 1], output_dict=True, zero_division=0
        )
        s = report['1']
        print(f"{cls:<10} | {thresh:<10} | {s['precision']:.2f}       | {s['recall']:.2f}     | {s['f1-score']:.2f}")


if __name__ == "__main__":
    main()


In [None]:
# Run the training script written by the %%writefile cell inside the ecg_env
# micromamba environment.
!./bin/micromamba run -r microroot -n ecg_env python train_robust.py