# Phase 4: Self-Supervised Learning (SSL) for PPG

## 1. Objective & Motivation
This phase learns robust physiological representations from **unlabeled** photoplethysmography (PPG) data using contrastive learning (SimCLR).

**Why SSL?**
- **Label Scarcity**: Ground-truth heart rate (HR) requires a simultaneous ECG reference, which is hard to collect at scale.
- **Noise Robustness**: Contrastive learning pushes the model toward the intrinsic "pulse" structure that stays invariant under the augmentations, which stand in for real-world noise.

## 2. Methodology
- **Architecture**: 1D-CNN Encoder (from Phase 3) + MLP Projection Head.
- **Task**: SimCLR (Maximize similarity between two augmented views of the same window).
- **Augmentations**: Physiology-aware transforms (Jitter, Scaling, Baseline Wander, Masking).
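
As a rough sketch of how the pieces listed above fit together, both augmented views can be passed through one shared encoder and one shared MLP projection head. Everything below is illustrative: `TinyEncoder`, `SimCLRSketch`, the 256/128 dimensions, and the 512-sample window are assumptions, not the project's `CNNBaseline` or `SimCLRWrapper`.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the Phase 3 1D-CNN encoder."""

    def __init__(self, feat_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Conv1d(32, feat_dim, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # collapse the time axis
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 1, samples) -> (batch, feat_dim)
        return self.net(x).squeeze(-1)


class SimCLRSketch(nn.Module):
    """Shared encoder + MLP projection head applied to both views."""

    def __init__(self, encoder: nn.Module, feat_dim: int = 256, proj_dim: int = 128):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, proj_dim),
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        # The same weights process both views, so the loss compares like with like.
        z1 = self.head(self.encoder(x1))
        z2 = self.head(self.encoder(x2))
        return z1, z2


model = SimCLRSketch(TinyEncoder())
x1 = torch.randn(4, 1, 512)  # four windows of 512 samples (assumed 8 s at 64 Hz)
x2 = torch.randn(4, 1, 512)
z1, z2 = model(x1, x2)
print(z1.shape, z2.shape)
```

In SimCLR the projection head is normally used only during pre-training; downstream fine-tuning would typically reuse the encoder output instead.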

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from pathlib import Path
from tqdm import tqdm

# Import project modules
from models import CNNBaseline
from ppg_ssl.augmentations import PPGTransforms, ContrastiveTransform
from ppg_ssl.dataset import get_ssl_dataloader
from ppg_ssl.models import SimCLRWrapper, nt_xent_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on {device}")

## 3. Physiology-Preserving Augmentations
We use a specialized augmentation pipeline designed not to destroy the vascular information in the PPG signal.
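
To make the four transforms concrete, here is a minimal NumPy sketch of what jitter, scaling, baseline wander, and masking could look like. It is not the project's `PPGTransforms`; the function names, the parameter values (noise level, gain range, drift frequency, mask length), and the synthetic test signal are all assumptions chosen for illustration.

```python
import numpy as np


def jitter(x, rng, sigma=0.02):
    """Add zero-mean Gaussian noise to every sample."""
    return x + rng.normal(0.0, sigma, size=x.shape)


def scale(x, rng, low=0.8, high=1.2):
    """Multiply the whole window by one random gain."""
    return x * rng.uniform(low, high)


def baseline_wander(x, rng, fs=64.0, max_freq=0.3, amp=0.1):
    """Add a slow sinusoidal drift below the heart-rate band."""
    t = np.arange(x.size) / fs
    freq = rng.uniform(0.05, max_freq)
    phase = rng.uniform(0.0, 2.0 * np.pi)
    return x + amp * np.sin(2.0 * np.pi * freq * t + phase)


def mask(x, rng, max_frac=0.1):
    """Zero out one short contiguous segment (brief dropout)."""
    out = x.copy()
    length = int(rng.integers(1, max(2, int(max_frac * x.size))))
    start = int(rng.integers(0, x.size - length + 1))
    out[start:start + length] = 0.0
    return out


def augment(x, rng):
    """Chain the four transforms to produce one random view."""
    for fn in (jitter, scale, baseline_wander, mask):
        x = fn(x, rng)
    return x


def dominant_hz(x, fs):
    """Frequency of the largest spectral peak, as a crude pulse-rate check."""
    spectrum = np.abs(np.fft.rfft(x))
    return float(np.fft.rfftfreq(x.size, d=1.0 / fs)[np.argmax(spectrum)])


rng = np.random.default_rng(0)
fs = 64.0
t = np.arange(int(8 * fs)) / fs
window = np.sin(2.0 * np.pi * 1.25 * t)  # synthetic 75 BPM pulse, 8 s
view_1, view_2 = augment(window, rng), augment(window, rng)
print(dominant_hz(window, fs), dominant_hz(view_1, fs), dominant_hz(view_2, fs))
```

None of these transforms stretches or shifts the time axis, so the dominant pulse frequency of each view should match the original window, which is the property the two positive views need to share.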

In [None]:
# Visualize Augmentations
dataset = get_ssl_dataloader(['ppg_dalia'], batch_size=1, transform=None).dataset
raw_signal = dataset[0].numpy().flatten()

augmenter = PPGTransforms(fs=64.0, domain_type='wearable')
aug_signal_1 = augmenter(raw_signal).numpy().flatten()
aug_signal_2 = augmenter(raw_signal).numpy().flatten()

plt.figure(figsize=(12, 4))
plt.plot(raw_signal, label='Original', alpha=0.8)
plt.plot(aug_signal_1, label='View 1 (Augmented)', alpha=0.6)
plt.plot(aug_signal_2, label='View 2 (Augmented)', alpha=0.6)
plt.title("SimCLR Data Views: Positive Pairs")
plt.legend()
plt.show()

## 4. SSL Pre-training Loop (SimCLR)
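We train the model to minimize the NT-Xent (normalized temperature-scaled cross-entropy) loss, which pulls the two views of the same window (the positive pair) together and pushes views of different windows in the batch (the negatives) apart.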
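
For readers who want to see the arithmetic, the following is a forward-only NumPy sketch of NT-Xent for a batch of N positive pairs. It is written in NumPy so each step is easy to inspect; the project's `nt_xent_loss` is assumed to be a differentiable PyTorch equivalent, and `nt_xent_numpy` is a hypothetical name used only here.

```python
import numpy as np


def nt_xent_numpy(z1, z2, temperature=0.5):
    """Forward-only NT-Xent loss for N positive pairs (z1[i], z2[i])."""
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)              # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit length, so dot = cosine
    sim = z @ z.T / temperature                       # (2N, 2N) scaled similarities
    np.fill_diagonal(sim, -np.inf)                    # exclude self-similarity
    # Row i's positive sits N rows away: i <-> i + N.
    pos_idx = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Cross-entropy per row, using a numerically stable log-sum-exp.
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(2 * n), pos_idx] - log_denom
    return float(-log_prob_pos.mean())


rng = np.random.default_rng(0)
z1 = rng.normal(size=(8, 16))
z_aligned = z1 + 0.01 * rng.normal(size=(8, 16))  # views that nearly agree
z_random = rng.normal(size=(8, 16))               # unrelated embeddings
loss_aligned = nt_xent_numpy(z1, z_aligned)
loss_random = nt_xent_numpy(z1, z_random)
print(loss_aligned, loss_random)
```

The aligned pairs should give the lower loss. A smaller `temperature` sharpens the softmax and penalizes hard negatives more strongly; the training loop in this section uses 0.5.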

In [None]:
def run_pretraining(epochs=5, batch_size=64):
    # Initialize Model
    model = SimCLRWrapper(CNNBaseline(), input_dim=256).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    
    # Data Loader (SimCLR Transform): each item yields two augmented views of one window
    loader = get_ssl_dataloader(['ppg_dalia'], batch_size=batch_size,
                                transform=ContrastiveTransform(PPGTransforms(fs=64.0, domain_type='wearable')))
    
    history = []
    print("Starting Pre-training...")
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x1, x2 in loader:
            x1, x2 = x1.to(device), x2.to(device)
            optimizer.zero_grad()
            z1, z2 = model(x1, x2)
            loss = nt_xent_loss(z1, z2, temperature=0.5)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        avg_loss = total_loss / len(loader)
        history.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")
    
    return model, history

# Uncomment to run a 1-epoch demo (the full pre-training run used 50 epochs)
# model, loss_curve = run_pretraining(epochs=1)

## 5. Results & Validation
This section compares the supervised baseline with the SSL fine-tuned model on PPG-DaLiA, using test-set MAE in BPM.
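
The metric in this comparison is mean absolute error (MAE) in beats per minute. As a minimal sketch, assuming per-window HR predictions and reference values are available as arrays, it can be computed as below. `mae_bpm` is a hypothetical helper and the numbers are toy values, not results from this project.

```python
import numpy as np


def mae_bpm(hr_pred, hr_true):
    """Mean absolute error between predicted and reference HR, in BPM."""
    hr_pred = np.asarray(hr_pred, dtype=float)
    hr_true = np.asarray(hr_true, dtype=float)
    return float(np.mean(np.abs(hr_pred - hr_true)))


# Toy values for illustration only.
toy_mae = mae_bpm([72.0, 80.0, 95.0], [70.0, 85.0, 95.0])
print(f"MAE: {toy_mae:.2f} BPM")
```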

In [None]:
# Load Pre-computed Results
results = []

# 1. Supervised Baseline
with open('result_ppg_dalia_cnn.json', 'r') as f:
    base_res = json.load(f)
    results.append({'Method': 'Supervised CNN (No SSL)', 'MAE': base_res['test_mae']})

# 2. Specialized SSL (Phase 4.2): SSL pre-training followed by fine-tuning
with open('result_ssl_specialized_ppg_dalia.json', 'r') as f:
    ssl_res = json.load(f)
    results.append({'Method': 'Specialized SSL (Fine-tuned)', 'MAE': ssl_res['test_mae']})

df = pd.DataFrame(results)
print(df)

# Plot
plt.figure(figsize=(8, 5))
plt.bar(df['Method'], df['MAE'], color=['salmon', 'royalblue'])
plt.title("Impact of SSL on PPG-DaLiA Performance")
plt.ylabel("MAE (BPM) - Lower is Better")
plt.grid(axis='y', linestyle='--', alpha=0.5)
plt.show()

## 6. Conclusion
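The SSL framework learns meaningful representations from unlabeled data. The specialized pre-training reduced the error from roughly 95 BPM at random initialization to **13.2 BPM**, which suggests the encoder captures heart-rate-relevant physiological features. Random initialization is a weak reference point, so the comparison with the supervised CNN in Section 5 is the better measure of what SSL adds.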