# Heart-JEPA Pipeline

0. Clone & Setup
1. Download data
2. Pretrain (self-supervised)
3. Inspect what pretraining produces
4. Fine-tune for classification
5. Fine-tune for segmentation

## 0. Clone & Setup

In [None]:
# Clone the repo
!git clone https://github.com/omar-A-hassan/heart-jepa.git
%cd heart-jepa

In [None]:
# Install dependencies: LeJEPA from GitHub, then this repo in editable mode with its [train] extras
!pip install git+https://github.com/rbalestr-lab/lejepa.git
!pip install -e ".[train]"

## 1. Download PhysioNet 2016 Data

In [None]:
import wfdb
from pathlib import Path

data_dir = Path("data/physionet2016")
data_dir.mkdir(parents=True, exist_ok=True)

# Download each subset, skipping any that already contain WAV files
for subset in ['a', 'b', 'c', 'd', 'e', 'f']:
    output_dir = data_dir / f"training-{subset}"
    if output_dir.exists() and any(output_dir.glob("*.wav")):
        print(f"training-{subset}: already downloaded, skipping")
        continue
    print(f"Downloading training-{subset}...")
    wfdb.dl_database(f"challenge-2016/training-{subset}", str(output_dir))

In [None]:
# Verify the download: count WAV files across all subsets
total = len(list(data_dir.glob("**/*.wav")))
print(f"Total WAV files: {total}")
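A single total can hide a subset that failed to download. A per-subset breakdown makes that easier to spot. This is a minimal sketch that assumes the `data/physionet2016/training-{a..f}` layout created above. `count_wavs_per_subset` is a hypothetical helper and is not part of the repo.

```python
from pathlib import Path


def count_wavs_per_subset(root, subsets="abcdef"):
    """Return {subset_name: number of .wav files} for each training-{x} directory."""
    root = Path(root)
    counts = {}
    for s in subsets:
        subset_dir = root / f"training-{s}"
        # A missing directory counts as zero files instead of raising
        counts[f"training-{s}"] = len(list(subset_dir.glob("*.wav"))) if subset_dir.is_dir() else 0
    return counts


if __name__ == "__main__":
    for name, n in count_wavs_per_subset("data/physionet2016").items():
        flag = "" if n else "  <- empty or missing, re-run the download cell"
        print(f"{name}: {n}{flag}")
```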

## 2. Pretrain

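Self-supervised pretraining with two loss terms. SIGReg (Sketched Isotropic Gaussian Regularization) pushes the embedding distribution toward an isotropic Gaussian. An invariance loss pulls together the embeddings of different views of the same recording. No labels are needed.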

In [None]:
# Run pretraining (Hydra overrides: ++key=value sets a config key, adding it if absent)
!python scripts/train_pretrain.py \
    ++data_dir=data/physionet2016 \
    ++max_epochs=50 \
    ++batch_size=16 \
    ++save_dir=checkpoints
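The NumPy sketch below illustrates the two pretraining terms. It is not the LeJEPA library's implementation and rests on these assumptions:

- SIGReg projects a batch of embeddings onto random unit directions and scores each 1-D projection against N(0, 1) with an Epps-Pulley-style characteristic-function test.
- The invariance term is the squared distance from each view's embedding to the mean over views.
- The grid `t`, `num_slices`, and the weight `lam` are illustrative choices.
- `pretrain_loss` is hypothetical.

```python
import numpy as np


def epps_pulley(x, t):
    """Epps-Pulley-style statistic for 1-D samples x against N(0, 1).

    Compares the empirical characteristic function (ECF) of x with exp(-t^2 / 2),
    weights the squared gap by a Gaussian window, and integrates over t (trapezoid rule).
    """
    xt = np.outer(x, t)                    # (N, T)
    ecf_re = np.cos(xt).mean(axis=0)       # real part of the ECF
    ecf_im = np.sin(xt).mean(axis=0)       # imaginary part (0 for a symmetric target)
    target = np.exp(-0.5 * t ** 2)         # characteristic function of N(0, 1)
    weight = np.exp(-0.5 * t ** 2)         # integration window
    integrand = ((ecf_re - target) ** 2 + ecf_im ** 2) * weight
    integral = np.sum(0.5 * (integrand[1:] + integrand[:-1]) * np.diff(t))
    return len(x) * integral


def sigreg(z, num_slices=64, rng=None):
    """Average the 1-D test over random unit directions ("slices") of embeddings z (N, D)."""
    rng = np.random.default_rng(0) if rng is None else rng
    dirs = rng.standard_normal((z.shape[1], num_slices))
    dirs /= np.linalg.norm(dirs, axis=0, keepdims=True)   # unit-norm columns
    proj = z @ dirs                                        # (N, num_slices)
    t = np.linspace(-5.0, 5.0, 33)
    return float(np.mean([epps_pulley(proj[:, j], t) for j in range(num_slices)]))


def invariance(views):
    """Mean squared distance of each view's embedding to the per-sample mean over views.

    views: (V, N, D) embeddings of V views of the same N recordings.
    """
    center = views.mean(axis=0, keepdims=True)
    return float(((views - center) ** 2).sum(axis=-1).mean())


def pretrain_loss(views, lam=0.05, num_slices=64, rng=None):
    """Hypothetical weighting: lam * SIGReg (averaged over views) + (1 - lam) * invariance."""
    reg = np.mean([sigreg(v, num_slices, rng) for v in views])
    return float(lam * reg + (1.0 - lam) * invariance(views))
```

In this sketch, SIGReg tests many cheap 1-D projections instead of one full D-dimensional Gaussianity test. That is what keeps its cost linear in the batch size.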

## 3. What Pretraining Produces

| Component | State after pretraining | Used for |
|-----------|-------------------------|----------|
| **Encoder** | Learned representations of heart-sound (PCG) recordings | Backbone for downstream tasks |
| **Projector** | Learned mapping from encoder features to 256-dim embeddings | Discarded after pretraining |
| **Heads** | Untrained (random initialization) | Fine-tuned for each task |
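To reuse only the encoder, its weights have to be separated from the projector's. This is a minimal sketch that assumes the checkpoint is a PyTorch Lightning-style `.ckpt` with weights under `"state_dict"`, and that encoder parameters share the prefix `encoder.`. The prefix, `extract_prefixed_state`, and `my_encoder` are all hypothetical. Check the actual key names in the repo's checkpoints first.

```python
def extract_prefixed_state(state_dict, prefix="encoder."):
    """Keep only keys that start with `prefix` and strip it.

    Example: 'encoder.conv1.weight' -> 'conv1.weight'; 'projector.*' keys are dropped.
    """
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


# Typical usage (requires torch; the "encoder." prefix is an assumption):
#   import torch
#   ckpt = torch.load("checkpoints/YOUR_CHECKPOINT.ckpt", map_location="cpu",
#                     weights_only=False)  # only for checkpoints you trust
#   encoder_state = extract_prefixed_state(ckpt["state_dict"], prefix="encoder.")
#   my_encoder.load_state_dict(encoder_state)
```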

In [None]:
# Find the most recent checkpoint (by modification time; not necessarily the best by validation metric)
from pathlib import Path
checkpoints = list(Path("checkpoints").glob("heart-jepa-*.ckpt"))
if checkpoints:
    latest = max(checkpoints, key=lambda p: p.stat().st_mtime)
    print(f"Pretrained checkpoint: {latest}")
else:
    print("No checkpoint found - run pretraining first")

## 4. Fine-tune for Classification

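Train a classification head to label each recording as Normal or Abnormal. With `++freeze_encoder=true`, the pretrained encoder weights stay fixed and only the head is trained.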

In [None]:
# Fine-tune for classification: replace YOUR_CHECKPOINT.ckpt with the checkpoint path printed above
!python scripts/train_finetune.py \
    task=classification \
    ++data_dir=data/physionet2016 \
    ++pretrained_checkpoint=checkpoints/YOUR_CHECKPOINT.ckpt \
    ++freeze_encoder=true \
    ++finetune_epochs=20 \
    ++cls_num_classes=2
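The PhysioNet/CinC 2016 challenge ships Normal/Abnormal labels in a `REFERENCE.csv` per subset, with one `record,label` row per recording, where `-1` means normal and `1` means abnormal. The minimal sketch below maps those labels to the two class indices implied by `++cls_num_classes=2`. The 0/1 mapping and `read_reference` are assumptions, not necessarily what `train_finetune.py` does. It also assumes `REFERENCE.csv` is present in each subset directory.

```python
import csv

# Challenge convention: -1 = normal, 1 = abnormal. The 0/1 index mapping is an assumption.
LABEL_TO_CLASS = {-1: 0, 1: 1}
CLASS_NAMES = ["Normal", "Abnormal"]


def read_reference(fileobj):
    """Parse 'record,label' rows into {record_name: class_index}."""
    labels = {}
    for row in csv.reader(fileobj):
        if not row:
            continue  # skip blank lines
        record, raw = row[0].strip(), int(row[1])
        labels[record] = LABEL_TO_CLASS[raw]
    return labels


# Usage:
#   with open("data/physionet2016/training-a/REFERENCE.csv", newline="") as f:
#       labels = read_reference(f)
```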

## 5. Fine-tune for Segmentation

Train a segmentation head that labels each time step with a cardiac-cycle state: S1, systole, S2, diastole, S3, S4, or background (7 classes in total).

**Pseudo-labels are auto-generated** from the PCG signal.

In [None]:
# Fine-tune for segmentation: replace YOUR_CHECKPOINT.ckpt with the checkpoint path printed above
!python scripts/train_finetune.py \
    task=segmentation \
    ++data_dir=data/physionet2016 \
    ++pretrained_checkpoint=checkpoints/YOUR_CHECKPOINT.ckpt \
    ++freeze_encoder=true \
    ++finetune_epochs=30 \
    ++seg_num_classes=7
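The notebook does not say how the pseudo-labels are derived. Whatever detector finds S1 and S2, one step is usually needed: turning detected sound boundaries into a per-sample label array using the class indices in the "Segmentation classes" list. This is a minimal sketch under that assumption. `sounds_to_mask` and its `(start, end, kind)` input format are hypothetical, and S3/S4 are not handled.

```python
import numpy as np

# Class indices from the "Segmentation classes" list in this notebook.
BACKGROUND, S1, SYSTOLE, S2, DIASTOLE = 0, 1, 2, 3, 4


def sounds_to_mask(sounds, n_samples):
    """Convert detected heart sounds into a per-sample label array.

    sounds: list of (start, end, kind) sample ranges, end exclusive, kind in {"S1", "S2"}.
    Gaps from S1 to the next S2 become systole; gaps from S2 to the next S1 become
    diastole; everything else (e.g. before the first sound) stays background.
    """
    mask = np.full(n_samples, BACKGROUND, dtype=np.int64)
    sounds = sorted(sounds)
    for start, end, kind in sounds:
        mask[start:end] = S1 if kind == "S1" else S2
    for (_, prev_end, prev_kind), (next_start, _, next_kind) in zip(sounds, sounds[1:]):
        if prev_kind == "S1" and next_kind == "S2":
            mask[prev_end:next_start] = SYSTOLE
        elif prev_kind == "S2" and next_kind == "S1":
            mask[prev_end:next_start] = DIASTOLE
    return mask
```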

## Summary

| Stage | Script | Data | Output |
|-------|--------|------|--------|
| Pretrain | `train_pretrain.py` | Unlabeled PCG | Encoder weights |
| Fine-tune (cls) | `train_finetune.py task=classification` | Normal/Abnormal labels | Classification model |
| Fine-tune (seg) | `train_finetune.py task=segmentation` | Auto pseudo-labels | Segmentation model |

### Segmentation classes

| Index | Class |
|-------|-------|
| 0 | Background |
| 1 | S1 |
| 2 | Systole |
| 3 | S2 |
| 4 | Diastole |
| 5 | S3 |
| 6 | S4 |
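After inference, the class indices above can be turned back into readable segments by run-length encoding the per-frame argmax output. This is a minimal NumPy sketch. `mask_to_segments` is hypothetical and assumes predictions are already a 1-D array of class indices in the listed order.

```python
import numpy as np

SEG_CLASSES = ["Background", "S1", "Systole", "S2", "Diastole", "S3", "S4"]


def mask_to_segments(labels):
    """Collapse a per-frame label sequence into (start, end, class_name) runs, end exclusive."""
    labels = np.asarray(labels)
    if labels.size == 0:
        return []
    # Positions where the label changes, plus the sequence boundaries
    change = np.flatnonzero(labels[1:] != labels[:-1]) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [labels.size]))
    return [(int(s), int(e), SEG_CLASSES[int(labels[s])]) for s, e in zip(starts, ends)]
```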