# Heart-JEPA Pipeline

0. Clone & Setup
1. Download PhysioNet 2016 data
2. Pretrain (self-supervised)
3. What pretraining produces
4. Fine-tune for classification
5. Fine-tune for segmentation

## 0. Clone & Setup

In [1]:
# Clone the repo
!git clone https://github.com/omar-A-hassan/heart-jepa.git
%cd heart-jepa

Cloning into 'heart-jepa'...
remote: Enumerating objects: 115, done.
remote: Counting objects: 100% (115/115), done.
remote: Compressing objects: 100% (100/100), done.
remote: Total 115 (delta 13), reused 113 (delta 11), pack-reused 0 (from 0)
Receiving objects: 100% (115/115), 3.73 MiB | 38.22 MiB/s, done.
Resolving deltas: 100% (13/13), done.
/content/heart-jepa


In [2]:
# Install dependencies
!pip install ./lejepa
!pip install -e ".[train]"

Processing ./lejepa
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting loguru (from lejepa==0.0.1)
  Downloading loguru-0.7.3-py3-none-any.whl.metadata (22 kB)
Downloading loguru-0.7.3-py3-none-any.whl (61 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 61.6/61.6 kB 3.8 MB/s eta 0:00:00
Building wheels for collected packages: lejepa
  Building wheel for lejepa (pyproject.toml) ... done
  Created wheel for lejepa: filename=lejepa-0.0.1-py3-none-any.whl size=35473 sha256=f57c12f4ae74a4907f3a2903de2e69c599ef1c041b574e780f0d4a6d1a741a54
  Stored in directory: /tmp/pip-ephem-wheel-cache-l5rd1c6f/wheels/06/3e/08/55b53506ec03b2043a3bdcda01de2e1ececa31705d2ee6d70e
Successfully built lejepa
Installing collected packages: loguru, lejepa
Successfully installed lejepa-0.0.1 loguru-0.7.3
Obtaining file:///content

## 1. Download PhysioNet 2016 Data

In [3]:
import wfdb
from pathlib import Path

data_dir = Path("data/physionet2016")
data_dir.mkdir(parents=True, exist_ok=True)

for subset in ['a', 'b', 'c', 'd', 'e', 'f']:
    output_dir = data_dir / f"training-{subset}"
    if output_dir.exists() and any(output_dir.glob("*.wav")):
        print(f"training-{subset}: exists")
        continue
    print(f"Downloading training-{subset}...")
    wfdb.dl_database(f"challenge-2016/training-{subset}", str(output_dir))

Downloading training-a...
Generating record list for: a0001
Generating record list for: a0002
Generating record list for: a0003
Generating record list for: a0004
Generating record list for: a0005
Generating record list for: a0006
Generating record list for: a0007
Generating record list for: a0008
Generating record list for: a0009
Generating record list for: a0010
Generating record list for: a0011
Generating record list for: a0012
Generating record list for: a0013
Generating record list for: a0014
Generating record list for: a0015
Generating record list for: a0016
Generating record list for: a0017
Generating record list for: a0018
Generating record list for: a0019
Generating record list for: a0020
Generating record list for: a0021
Generating record list for: a0022
Generating record list for: a0023
Generating record list for: a0024
Generating record list for: a0025
Generating record list for: a0026
Generating record list for: a0027
Generating record list for: a0028
Generating record list

In [13]:
import urllib.request

# Fetch the label file (REFERENCE.csv) for each training subset
base_url = "https://physionet.org/files/challenge-2016/1.0.0"
for s in ['a', 'b', 'c', 'd', 'e', 'f']:
    url = f"{base_url}/training-{s}/REFERENCE.csv"
    path = f"data/physionet2016/training-{s}/REFERENCE.csv"
    urllib.request.urlretrieve(url, path)
    print(f"Downloaded {path}")

Downloaded data/physionet2016/training-a/REFERENCE.csv
Downloaded data/physionet2016/training-b/REFERENCE.csv
Downloaded data/physionet2016/training-c/REFERENCE.csv
Downloaded data/physionet2016/training-d/REFERENCE.csv
Downloaded data/physionet2016/training-e/REFERENCE.csv
Downloaded data/physionet2016/training-f/REFERENCE.csv
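
The downloaded `REFERENCE.csv` files hold one row per recording: the record name and a label, where PhysioNet 2016 uses `-1` for normal and `1` for abnormal. As a minimal sketch of how these labels could be read and counted, the helper below parses such rows. The function name `read_reference` and the sample rows are illustrative assumptions, not part of the repository.

```python
import csv
import io
from collections import Counter

# PhysioNet 2016 label convention: -1 = normal, 1 = abnormal
LABEL_NAMES = {-1: "normal", 1: "abnormal"}

def read_reference(lines):
    """Parse REFERENCE.csv rows (record name, label) into {record: label name}."""
    labels = {}
    for row in csv.reader(lines):
        if len(row) < 2:
            continue  # skip blank or malformed rows
        labels[row[0].strip()] = LABEL_NAMES[int(row[1])]
    return labels

# Made-up rows for illustration only (hypothetical record names)
sample = io.StringIO("rec0001,1\nrec0002,-1\nrec0003,1\n")
labels = read_reference(sample)
counts = Counter(labels.values())
print(counts)
```

In the notebook, the same helper could be fed an open file, for example `with open(data_dir / "training-a" / "REFERENCE.csv") as f: read_reference(f)`.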


In [4]:
# Verify the download by counting WAV files across all subsets
total = len(list(data_dir.glob("**/*.wav")))
print(f"Total WAV files: {total}")

Total WAV files: 3240


In [32]:
!cd /content/heart-jepa && git pull

remote: Enumerating objects: 7, done.
remote: Counting objects: 100% (7/7), done.
remote: Compressing objects: 100% (2/2), done.
remote: Total 4 (delta 2), reused 4 (delta 2), pack-reused 0 (from 0)
Unpacking objects: 100% (4/4), 449 bytes | 449.00 KiB/s, done.
From https://github.com/omar-A-hassan/heart-jepa
   c87d6c3..def1761  main       -> origin/main
Updating c87d6c3..def1761
Fast-forward
 scripts/train_finetune.py | 3 +++
 1 file changed, 3 insertions(+)


## 2. Pretrain

Self-supervised pretraining with the SIGReg and invariance losses. No labels are needed.
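
To give some intuition for the SIGReg term, here is a rough pure-Python sketch of a sliced Epps-Pulley-style statistic. It assumes that SIGReg pushes embeddings toward an isotropic Gaussian by comparing the empirical characteristic function of random 1-D projections against that of N(0, 1), as the `bstat_*` settings (`epps_pulley`, number of slices, `t_max`, `n_points`) suggest. The function names, the Gaussian weighting window, and the trapezoid integration are assumptions for illustration; this is not the `lejepa` implementation.

```python
import math
import random

def epps_pulley_1d(xs, t_max=3.0, n_points=17):
    """Weighted squared gap between the empirical CF of xs and the N(0,1) CF."""
    n = len(xs)
    dt = 2.0 * t_max / (n_points - 1)
    total = 0.0
    for k in range(n_points):
        t = -t_max + k * dt
        real = sum(math.cos(t * x) for x in xs) / n   # Re of empirical CF
        imag = sum(math.sin(t * x) for x in xs) / n   # Im of empirical CF
        target = math.exp(-0.5 * t * t)               # CF of N(0,1) is real
        err = (real - target) ** 2 + imag ** 2
        trap = 0.5 if k in (0, n_points - 1) else 1.0  # trapezoid end weights
        total += trap * dt * target * err              # Gaussian window (assumed)
    return total

def sliced_statistic(embeddings, num_slices=16, seed=0, **kwargs):
    """Average the 1-D statistic over random unit directions (slices)."""
    rng = random.Random(seed)
    dim = len(embeddings[0])
    stats = []
    for _ in range(num_slices):
        direction = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(v * v for v in direction))
        direction = [v / norm for v in direction]
        projected = [sum(a * b for a, b in zip(row, direction)) for row in embeddings]
        stats.append(epps_pulley_1d(projected, **kwargs))
    return sum(stats) / len(stats)

# Toy comparison: Gaussian embeddings versus fully collapsed embeddings
rng = random.Random(1)
gaussian = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(200)]
collapsed = [[0.0] * 8 for _ in range(200)]
gaussian_stat = sliced_statistic(gaussian)
collapsed_stat = sliced_statistic(collapsed)
print(gaussian_stat, collapsed_stat)
```

The statistic is small for Gaussian-looking embeddings and much larger when all embeddings collapse to a single point, which is the failure mode such a regularizer is meant to penalize.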

In [19]:
# Run pretraining
!python scripts/train_pretrain.py \
    ++data_dir=data/physionet2016 \
    ++max_epochs=5 \
    ++batch_size=32 \
    ++num_workers=2 \
    ++precision=16 \
    ++save_dir=checkpoints

backbone: vit_base_patch16_224
pretrained: true
proj_dim: 256
hidden_dim: 2048
embed_dim: 768
seg_hidden_dim: 256
seg_num_classes: 7
seg_output_frames: 224
cls_hidden_dim: 256
cls_num_classes: 3
cls_dropout: 0.1
bstat_name: epps_pulley
bstat_num_slices: 1000
bstat_t_max: 3.0
bstat_n_points: 17
bstat_lambda: 0.01
invariance_weight: 1.0
invariance_temp: 0.1
data_dir: data/physionet2016
batch_size: 32
num_workers: 2
n_views: 4
sample_rate: 2000
duration: 5.0
n_fft: 512
hop_length: 64
n_mels: 128
fmin: 20
fmax: 500
max_epochs: 5
lr: 0.0001
weight_decay: 1.0e-05
warmup_epochs: 5
optimizer: adamw
scheduler: cosine
accelerator: auto
devices: 1
precision: 16
wandb_project: null
wandb_entity: null
log_every_n_steps: 10
val_check_interval: 1.0
save_dir: checkpoints
save_top_k: 3
seed: 42
task: classification
pretrained_checkpoint: null
freeze_encoder: true
finetune_epochs: 30
finetune_lr: 0.001
finetune_batch_size: 32
seg_ce_weight: 0.5
seg_dice_weight: 0.5

Seed set to 42
[2025-12-31 22:37:56,0
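
The audio settings in the config above (`sample_rate`, `duration`, `n_fft`, `hop_length`, `fmax`) determine the spectrogram size. The arithmetic below is a worked example under two common STFT conventions; whether the repository pads (centers) the signal is an assumption left open here, so both frame counts are shown.

```python
# Values copied from the printed config
sample_rate = 2000   # Hz
duration = 5.0       # seconds
n_fft = 512
hop_length = 64
fmax = 500           # Hz

n_samples = int(sample_rate * duration)

# Frame count if the STFT pads the signal so frames are centered (assumed option A)
frames_centered = 1 + n_samples // hop_length

# Frame count if only full windows inside the signal are used (assumed option B)
frames_uncentered = 1 + (n_samples - n_fft) // hop_length

bin_width_hz = sample_rate / n_fft   # spacing between FFT bins
nyquist_hz = sample_rate / 2         # highest representable frequency

print(n_samples, frames_centered, frames_uncentered, bin_width_hz, nyquist_hz)
```

Each 5-second clip is 10,000 samples, giving 157 frames with centering or 149 without, and `fmax=500` sits well below the 1000 Hz Nyquist limit.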

## 3. What Pretraining Produces

| Component | What it learned | Used for |
|-----------|-----------------|----------|
| **Encoder** | Cardiac features | Backbone for downstream tasks |
| **Projector** | 256-dim embeddings | Discard after pretraining |
| **Heads** | Random init | Fine-tune for tasks |
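
The table implies that only the encoder weights carry over to downstream tasks. As a minimal sketch of that idea, the helper below keeps only the entries of a state dict that belong to the encoder. The key prefixes (`encoder.`, `projector.`, `cls_head.`) are hypothetical; the actual names in the repository's checkpoints may differ.

```python
def encoder_only(state_dict, prefix="encoder."):
    """Keep only entries under `prefix`, stripping the prefix from each key."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

# Stand-in state dict with hypothetical key names and placeholder values
fake_state = {
    "encoder.blocks.0.weight": 1,
    "projector.0.weight": 2,
    "cls_head.weight": 3,
}
kept = encoder_only(fake_state)
print(kept)
```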

In [24]:
import shutil

# Copy a pretraining checkpoint to a fixed path used by the fine-tuning cells
shutil.copy(
    "checkpoints/heart-jepa-epoch=04-val/loss=1.4092.ckpt",
    "checkpoints/pretrained.ckpt",
)

'checkpoints/pretrained.ckpt'
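
The copy above hard-codes the checkpoint name, which changes from run to run. A possible alternative, sketched below, picks the checkpoint with the lowest validation loss by parsing the `loss=...` part of the filename. The helper name `best_checkpoint` is an assumption, and every path in the example other than the one shown above is made up.

```python
import re

# Matches the trailing "loss=<number>.ckpt" part of a checkpoint path
_LOSS_PATTERN = re.compile(r"loss=(\d+(?:\.\d+)?)\.ckpt$")

def best_checkpoint(paths):
    """Return the path whose filename encodes the lowest loss."""
    scored = []
    for path in paths:
        match = _LOSS_PATTERN.search(str(path))
        if match:  # paths without a loss in the name are ignored
            scored.append((float(match.group(1)), str(path)))
    if not scored:
        raise FileNotFoundError("no checkpoint with a parsable loss found")
    return min(scored)[1]

candidates = [
    "checkpoints/heart-jepa-epoch=04-val/loss=1.4092.ckpt",
    "checkpoints/heart-jepa-epoch=02-val/loss=1.6100.ckpt",  # hypothetical
    "checkpoints/pretrained.ckpt",                           # no loss, skipped
]
best = best_checkpoint(candidates)
print(best)
```

In the notebook, the candidates could come from `Path("checkpoints").rglob("*.ckpt")`, since the `val/loss` metric name places the files in a subdirectory.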

## 4. Fine-tune for Classification

Train the classification head to separate normal from abnormal recordings (`cls_num_classes=2`), with the encoder frozen.

In [33]:
# Fine-tune classification (update checkpoint path)
!python scripts/train_finetune.py \
    task=classification \
    ++data_dir=data/physionet2016 \
    ++pretrained_checkpoint=checkpoints/pretrained.ckpt \
    ++freeze_encoder=true \
    ++finetune_epochs=5 \
    ++num_workers=2 \
    ++precision=16 \
    ++cls_num_classes=2

backbone: vit_base_patch16_224
pretrained: true
proj_dim: 256
hidden_dim: 2048
embed_dim: 768
seg_hidden_dim: 256
seg_num_classes: 7
seg_output_frames: 224
cls_hidden_dim: 256
cls_num_classes: 2
cls_dropout: 0.1
bstat_name: epps_pulley
bstat_num_slices: 1000
bstat_t_max: 3.0
bstat_n_points: 17
bstat_lambda: 0.01
invariance_weight: 1.0
invariance_temp: 0.1
data_dir: data/physionet2016
batch_size: 32
num_workers: 2
n_views: 4
sample_rate: 2000
duration: 5.0
n_fft: 512
hop_length: 64
n_mels: 128
fmin: 20
fmax: 500
max_epochs: 100
lr: 0.0001
weight_decay: 1.0e-05
warmup_epochs: 5
optimizer: adamw
scheduler: cosine
accelerator: auto
devices: 1
precision: 16
wandb_project: null
wandb_entity: null
log_every_n_steps: 10
val_check_interval: 1.0
save_dir: checkpoints
save_top_k: 3
seed: 42
task: classification
pretrained_checkpoint: checkpoints/pretrained.ckpt
freeze_encoder: true
finetune_epochs: 5
finetune_lr: 0.001
finetune_batch_size: 32
seg_ce_weight: 0.5
seg_dice_weight: 0.5

Seed set to 4

## 5. Fine-tune for Segmentation

Train the segmentation head to label heart-cycle phases: S1, S2, systole, diastole, etc. (`seg_num_classes=7`).

**Pseudo-labels are auto-generated** from the phonocardiogram (PCG) signal.
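
To make the segmentation output easier to picture, the sketch below turns a sequence of per-frame class indices into timed segments. It assumes the head emits one class index per frame and that the frames (for example `seg_output_frames=224`) evenly span the 5-second clip from the config. The function name and the integer class ids are illustrative; the mapping from ids to S1, S2, and so on is not specified here.

```python
from itertools import groupby

def frames_to_segments(frame_labels, duration=5.0):
    """Collapse runs of equal frame labels into (label, start_sec, end_sec)."""
    frame_sec = duration / len(frame_labels)  # seconds covered by one frame
    segments = []
    start = 0
    for label, run in groupby(frame_labels):
        length = len(list(run))
        segments.append((label, start * frame_sec, (start + length) * frame_sec))
        start += length
    return segments

# Toy prediction: 224 frames, first half class 0, second half class 1
toy_frames = [0] * 112 + [1] * 112
toy_segments = frames_to_segments(toy_frames)
for label, start_sec, end_sec in toy_segments:
    print(label, round(start_sec, 3), round(end_sec, 3))
```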

In [30]:
# Fine-tune segmentation (update checkpoint path)
!python scripts/train_finetune.py \
    task=segmentation \
    ++data_dir=data/physionet2016 \
    ++pretrained_checkpoint=checkpoints/pretrained.ckpt \
    ++freeze_encoder=true \
    ++finetune_epochs=5 \
    ++num_workers=2 \
    ++precision=16 \
    ++seg_num_classes=7

backbone: vit_base_patch16_224
pretrained: true
proj_dim: 256
hidden_dim: 2048
embed_dim: 768
seg_hidden_dim: 256
seg_num_classes: 7
seg_output_frames: 224
cls_hidden_dim: 256
cls_num_classes: 3
cls_dropout: 0.1
bstat_name: epps_pulley
bstat_num_slices: 1000
bstat_t_max: 3.0
bstat_n_points: 17
bstat_lambda: 0.01
invariance_weight: 1.0
invariance_temp: 0.1
data_dir: data/physionet2016
batch_size: 32
num_workers: 2
n_views: 4
sample_rate: 2000
duration: 5.0
n_fft: 512
hop_length: 64
n_mels: 128
fmin: 20
fmax: 500
max_epochs: 100
lr: 0.0001
weight_decay: 1.0e-05
warmup_epochs: 5
optimizer: adamw
scheduler: cosine
accelerator: auto
devices: 1
precision: 16
wandb_project: null
wandb_entity: null
log_every_n_steps: 10
val_check_interval: 1.0
save_dir: checkpoints
save_top_k: 3
seed: 42
task: segmentation
pretrained_checkpoint: checkpoints/pretrained.ckpt
freeze_encoder: true
finetune_epochs: 5
finetune_lr: 0.001
finetune_batch_size: 32
seg_ce_weight: 0.5
seg_dice_weight: 0.5

Seed set to 42
