# ECG Multi-task GNN (SHD + Reconstruction) â€” Colab Quick Start

This notebook is a **minimal, Colab-friendly** quick start for your project.

It focuses on:
- mounting Drive
- making sure `Dataset/` is visible
- installing **PyTorch Geometric fast (prebuilt wheels)**
- running training + evaluation


In [None]:
import sys
import os

# Check if running on Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running on Google Colab")
    from google.colab import drive

    # Mount Google Drive
    drive.mount('/content/drive')

    # Change to project directory (uncomment and adjust path as needed)
    # For paths with spaces, use quotes or access via: /content/drive/Shareddrives/
    # Example:
    %cd "/content/drive/Shareddrives/CS224W Project/Project/Testing_final_submission"
    # or
    # %cd /content/drive/MyDrive/your-project-folder

    print(f"Drive mounted! Current directory: {os.getcwd()}")
    print("Remember to uncomment and adjust the %cd command above to navigate to your project folder")
else:
    print("Running locally")
    print(f"Current directory: {os.getcwd()}")


Running on Google Colab
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/Shareddrives/CS224W Project/Project/Testing_final_submission
Drive mounted! Current directory: /content/drive/Shareddrives/CS224W Project/Project/Testing_final_submission
Remember to uncomment and adjust the %cd command above to navigate to your project folder


### 1.1 Install PyTorch

Choose the appropriate installation based on your system:


In [None]:
import torch

# Check if PyTorch is already installed
try:
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU device: {torch.cuda.get_device_name(0)}")
except ImportError:
    print("PyTorch not installed. Installing...")

    # Uncomment the appropriate line based on your system:

    # For CPU only:
    # !pip install torch torchvision torchaudio

    # For CUDA 11.8:
    # !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

    # For CUDA 12.1:
    # !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

    pass


PyTorch version: 2.9.0+cu126
CUDA available: True
CUDA version: 12.6
GPU device: NVIDIA A100-SXM4-40GB


### 1.2 Install PyTorch Geometric


In [None]:
# Check if PyTorch Geometric is installed
try:
    import torch_geometric
    print(f"PyTorch Geometric version: {torch_geometric.__version__}")
except ImportError:
    print("Installing PyTorch Geometric...")
    import sys
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch-geometric"])

    # If the above fails, try this alternative in a separate cell:
    # import subprocess
    # subprocess.check_call([sys.executable, "-m", "pip", "install", "pyg_lib", "torch_scatter", "torch_sparse", "torch_cluster", "torch_spline_conv", "-f", "https://data.pyg.org/whl/torch-2.0.0+cpu.html"])

    print("PyTorch Geometric installed!")


PyTorch Geometric version: 2.7.0


In [None]:
# --- 3) Ensure Dataset/ is visible to the scripts ---
# train_multitask_shd.py and evaluate_multitask.py expect:
#   DATA_DIR = <directory_of_script>/Dataset
# So if your scripts are in PROJECT_DIR, Dataset must be at:
#   PROJECT_DIR/Dataset

import os, glob

required = [
    "echonext_metadata_100k.csv",
    "EchoNext_train_waveforms.npy",
    "EchoNext_val_waveforms.npy",
    "EchoNext_train_tabular_features.npy",
    "EchoNext_val_tabular_features.npy",
    "EchoNext_test_waveforms.npy",
    "EchoNext_test_tabular_features.npy",
]

data_dir = os.path.join(os.getcwd(), "Dataset")
print("Dataset dir:", data_dir)
print("Exists:", os.path.isdir(data_dir))

missing = [f for f in required if not os.path.exists(os.path.join(data_dir, f))]
print("Missing:", missing)

# If Dataset is on Drive elsewhere, symlink it here:
# Example:
#   !ln -s "/content/drive/MyDrive/CS224W_Project/Dataset" "{data_dir}"


Dataset dir: /content/drive/Shareddrives/CS224W Project/Project/Testing_final_submission/Dataset
Exists: True
Missing: []


In [None]:
# --- 4) Run a *short* training smoke test (adjust epochs/batch_size as needed) ---
# Tip: start tiny to validate paths + GPU, then increase epochs.
!python train_multitask_shd.py --epochs 1 --batch_size 32 --lr 1e-4 --save_dir checkpoints_multitask


Multi-Task SHD Classification + Reconstruction
Classification weight: 1.0
Reconstruction weight: 0.1
Lead dropout: 6-12 leads, prob=0.5
Reconstruction targets: 13 primary + negated (prob=0.2, n=3)
Device: cuda

Loading datasets...
Loading waveforms from /content/drive/Shareddrives/CS224W Project/Project/Testing_final_submission/Dataset/EchoNext_train_waveforms.npy...
  Shape: (72475, 1, 2500, 12)
Loading tabular from /content/drive/Shareddrives/CS224W Project/Project/Testing_final_submission/Dataset/EchoNext_train_tabular_features.npy...
  Shape: (72475, 7)
Loading labels from /content/drive/Shareddrives/CS224W Project/Project/Testing_final_submission/Dataset/echonext_metadata_100k.csv...
  train set: 72475 samples, positive rate: 52.37%
ECGGraphBuilderV2: 13 electrodes, 156 possible edges
Loading waveforms from /content/drive/Shareddrives/CS224W Project/Project/Testing_final_submission/Dataset/EchoNext_val_waveforms.npy...
  Shape: (4626, 1, 2500, 12)
Loading tabular from /content/dri

In [19]:
# --- 5) Pick the latest checkpoint produced by training ---
import os, re, glob

ckpt_dir = "checkpoints_multitask"
paths = glob.glob(os.path.join(ckpt_dir, "best_multitask_model*.pt"))
assert paths, f"No checkpoints found in {ckpt_dir}. Did training run?"

def epoch_num(p):
    m = re.search(r"checkpoint_epoch(\d+)\.pt$", os.path.basename(p))
    return int(m.group(1)) if m else -1

latest = sorted(paths, key=epoch_num)[-1]
print("Latest checkpoint:", latest)


Latest checkpoint: checkpoints_multitask/best_multitask_model.pt


In [20]:
# --- 6) Evaluate on the test split ---
!python evaluate_multitask.py --checkpoint "{latest}" --batch_size 16 --output_dir evaluation_multitask


Multi-Task Model Evaluation
Device: cuda

Loading checkpoint: checkpoints_multitask/best_multitask_model.pt
  Loaded from epoch 1
  Val AUROC: 0.8097917508635744
ECGGraphBuilderV2: 13 electrodes, 156 possible edges

EVALUATION: 12-Lead Input
ECGGraphBuilderV2: 13 electrodes, 156 possible edges
Loaded test set: 5442 samples
  Using leads: ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
Evaluating classification: 100% 341/341 [00:24<00:00, 14.11it/s]

12-Lead Classification Metrics:
  AUROC:    0.8174
  AUPRC:    0.7785
  Accuracy: 0.7501
  F1 Score: 0.6963

EVALUATION: 6 Limb Leads Only (I, II, III, aVR, aVL, aVF)
ECGGraphBuilderV2: 13 electrodes, 156 possible edges
Loaded test set: 5442 samples
  Using leads: ['I', 'II', 'III', 'aVR', 'aVL', 'aVF']
Evaluating classification: 100% 341/341 [00:20<00:00, 17.00it/s]

6-Limb-Lead Classification Metrics:
  AUROC:    0.8004
  AUPRC:    0.7556
  Accuracy: 0.7299
  F1 Score: 0.6700

RECONSTRUCTION: 6 Limb Leads -> 13

In [21]:
# --- 7) Peek at evaluation outputs ---
import os, glob, pandas as pd

out_dir = "evaluation_multitask"
print("Output files:", glob.glob(out_dir + "/*")[:20])

# If metrics are saved as CSV, this will show them (safe if file doesn't exist)
csvs = glob.glob(out_dir + "/*.csv")
if csvs:
    df = pd.read_csv(csvs[0])
    display(df.head())
else:
    print("No CSV found (this is OK if your eval script writes JSON/NPY instead).")


Output files: ['evaluation_multitask/roc_curves.pdf', 'evaluation_multitask/pr_curves.pdf', 'evaluation_multitask/confusion_matrices.pdf', 'evaluation_multitask/reconstruction_quality.pdf', 'evaluation_multitask/reconstruction_example_sample1186.pdf', 'evaluation_multitask/reconstruction_example_sample4763.pdf', 'evaluation_multitask/reconstruction_example_sample410.pdf', 'evaluation_multitask/reconstruction_example_sample1945.pdf', 'evaluation_multitask/reconstruction_example_sample5087.pdf', 'evaluation_multitask/metrics.txt', 'evaluation_multitask/predictions.csv']


Unnamed: 0,label,prob_12lead,pred_12lead,prob_6limb,pred_6limb
0,1.0,0.86149,1,0.742749,1
1,0.0,0.571211,1,0.379952,0
2,1.0,0.332494,0,0.258973,0
3,1.0,0.828494,1,0.694085,1
4,0.0,0.417697,0,0.452478,0
