# Haptic Signal VAE — Colab Training

This notebook runs the full training pipeline on Google Colab.

**Steps:**
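1. Mount Google Drive
2. Clone the GitHub repo (or pull the latest changes)
3. Install dependencies
4. Configure data, output, and config paths
5. Run training via CLI
6. Load the latest run and plot the loss curves
7. Evaluate reconstructions and listen to results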

In [None]:
# 1. Mount Google Drive
# The dataset and output directories configured later in this notebook live
# under /content/drive/MyDrive, so Drive must be mounted at /content/drive
# before any cell that reads data or writes results.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# 2. Clone repo (or pull latest if already cloned)
import os

REPO_URL = "https://github.com/cindy-77jiayi/thesis_hapticAE.git"
REPO_DIR = "/content/thesis_hapticAE"

if os.path.exists(REPO_DIR):
    !cd {REPO_DIR} && git pull
else:
    !git clone {REPO_URL} {REPO_DIR}

os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")

In [None]:
# 3. Install dependencies
# requirements.txt is resolved relative to the current working directory,
# which the clone cell sets to the repository root. `-q` keeps pip output quiet.
!pip install -q -r requirements.txt

In [None]:
# 4. Configure paths
# CONFIG is relative to the repository root; DATA_DIR and OUTPUT_DIR point
# into the mounted Google Drive. Edit these to match your own Drive layout.
CONFIG = "configs/vae_default.yaml"
OUTPUT_DIR = "/content/drive/MyDrive/thesis/outputs"
DATA_DIR = "/content/drive/MyDrive/hapticgen-dataset/expertvoted"

# Echo the effective settings so a wrong path is noticed before training starts.
print("\n".join(
    f"{label} {value}"
    for label, value in (
        ("Data:  ", DATA_DIR),
        ("Output:", OUTPUT_DIR),
        ("Config:", CONFIG),
    )
))

In [None]:
# 5. Run training
!python scripts/train.py --config {CONFIG} --data_dir {DATA_DIR} --output_dir {OUTPUT_DIR}

In [None]:
# 6. Evaluate: load results and visualize
import sys
sys.path.insert(0, REPO_DIR)

import glob
import numpy as np
import torch
from torch.utils.data import DataLoader

from src.utils.config import load_config
from src.utils.seed import set_seed
from src.data.preprocessing import collect_clean_wavs, estimate_global_rms
from src.data.dataset import HapticWavDataset
from src.models.conv_vae import ConvVAE
from src.eval.evaluate import evaluate_reconstruction, print_metrics
from src.eval.visualize import plot_loss_curves, plot_waveform_comparison
from src.eval.audio import play_ab_comparison

# Load the experiment configuration and apply its seed.
config = load_config(CONFIG)
set_seed(config['seed'])

# Find latest run: every run directory is expected to contain a best_model.pt.
# The candidates are sorted lexicographically by path, so "latest" means the
# directory whose name sorts last.
# NOTE(review): assumes run directory names sort in creation order
# (e.g. timestamps) — confirm against scripts/train.py.
run_dirs = sorted(glob.glob(f"{OUTPUT_DIR}/*/best_model.pt"))
if not run_dirs:
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(f"No trained models found under {OUTPUT_DIR}")
ckpt_path = run_dirs[-1]
run_dir = os.path.dirname(ckpt_path)
print(f"Using checkpoint: {ckpt_path}")

# Load metrics. np.load on an .npz returns a lazy NpzFile that keeps the
# archive open; copy the arrays into a plain dict inside a `with` block so the
# file handle is released while `metrics[...]` lookups keep working.
with np.load(os.path.join(run_dir, 'metrics.npz')) as npz:
    metrics = {key: npz[key] for key in npz.files}
plot_loss_curves(metrics['train_losses'].tolist(), metrics['val_losses'].tolist())

In [None]:
# 7. Reconstruction evaluation + audio
# Rebuilds the validation split, restores the best checkpoint, and reports
# reconstruction metrics, waveform plots, and A/B audio for a few samples.
data_cfg = config['data']
model_cfg = config['model']

# Re-apply the seed before building the split. The permutation below draws
# from NumPy's global RNG, which is otherwise seeded only in the previous cell;
# without this, re-running the cell yields a different train/val partition, so
# "validation" files may be ones the model was trained on.
# NOTE(review): assumes set_seed seeds NumPy's global RNG and that
# scripts/train.py derives its split the same way — confirm.
set_seed(config['seed'])

wav_files = collect_clean_wavs(DATA_DIR)
N = len(wav_files)
perm = np.random.permutation(N)
split = int(data_cfg['train_split'] * N)
val_files = [wav_files[i] for i in perm[split:]]
train_files = [wav_files[i] for i in perm[:split]]

# The normalisation statistic is estimated from training files only.
global_rms = estimate_global_rms(train_files, n=200, sr_expect=data_cfg['sr'])

val_ds = HapticWavDataset(
    val_files,
    T=data_cfg['T'],
    sr_expect=data_cfg['sr'],
    global_rms=global_rms,
    scale=data_cfg['scale'],
)
# shuffle=False keeps the evaluated samples in a fixed order across runs.
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

# Rebuild the model from the same config values and restore the checkpoint;
# map_location lets a GPU-trained checkpoint load on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ConvVAE(
    T=data_cfg['T'], latent_dim=model_cfg['latent_dim'],
    channels=tuple(model_cfg['channels']),
    first_kernel=model_cfg.get('first_kernel', 25),
    kernel_size=model_cfg.get('kernel_size', 9),
).to(device)
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

result = evaluate_reconstruction(model, val_loader, device, n_samples=10)
print_metrics(result)
plot_waveform_comparison(result['x_np'], result['xhat_np'])
play_ab_comparison(result['x_np'], result['xhat_np'], sr=data_cfg['sr'])