In [1]:
import sys
import yaml
import numpy as np
from pathlib import Path
import hydra
from hydra import compose, initialize_config_dir
from omegaconf import DictConfig
import os

# Make the parent of the current working directory importable; later cells import
# `scripts.*`, which presumably lives there (confirm against the repo layout).
# The path is made absolute so imports keep working if the working directory changes,
# and the membership check keeps re-running this cell from adding duplicate entries.
_repo_root = os.path.abspath('..')
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

In [None]:
from omegaconf import OmegaConf


# Resolve the config directory (../conf relative to the working directory) to an
# absolute path before handing it to Hydra.
config_dir = os.path.abspath('../conf')
# Initialize Hydra inside a `with` block and compose the config named "config";
# `cfg` is reused by every later cell.
with initialize_config_dir(config_dir=config_dir, version_base=None):
    cfg = compose(config_name="config")

# Last expression of the cell: display the composed config.
cfg

{'data': {'signal_dim': 2,
  'noise_dim': 100,
  'num_actions': 4,
  'traj_len': 200,
  'step_size': 0.1,
  'n_train': 300,
  'n_val': 30,
  'n_test': 30,
  'seed': 0,
  'out_dir': 'data/'},
 'model': {'z_dim_vae': 2,
  'z_dim_contrastive': 2,
  'enc_widths': [256, 256],
  'dec_widths': [256, 256],
  'proj_widths': [256, 256],
  'dyn_widths': [256, 256],
  'probe_widths': [64, 64],
  'activation': 'relu'},
 'train': {'batch_size': 256,
  'steps_phase1': 10,
  'steps_phase2': 10,
  'steps_probe': 10,
  'num_workers': 0,
  'ckpt_dir': 'ckpts/',
  'eval_batch_size': 512,
  'wandb_log_freq': 200,
  'device': 'mps',
  'vae': {'lr': 0.002, 'weight_decay': 0.0, 'beta': 0.001},
  'contrastive': {'lr': 0.001, 'weight_decay': 0.0, 'temperature': 0.1},
  'dynamics': {'lr': 0.001, 'weight_decay': 0.0},
  'probe': {'lr': 0.001, 'weight_decay': 0.0}},
 'wandb': {'enabled': True,
  'project': 'repr-world',
  'entity': 'mattia-scardecchia',
  'group': None,
  'mode': 'online',
  'dir': 'logs',
  'tags

In [3]:
# Run the data-generation script's entry point with the composed config.
# In the recorded run this wrote data/train.npz, data/val.npz and data/test.npz.
from scripts.generate_data import main as gen_main
gen_main(cfg)

Wrote data/train.npz with shapes s(60000, 102), a(60000,), sp(60000, 102)
Wrote data/val.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)
Wrote data/test.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)


## Train Phase 1 representation (choose one; this run uses `contrastive`)

In [7]:
from scripts.train import train_phase1_contrastive, _maybe_init_wandb

# Read the device from the composed config (`train.device`) instead of hard-coding it,
# so editing the config is enough to switch device and the two cannot drift apart.
device = cfg.train.device
# Latent space used for the rest of the notebook; later cells pass it as `z_space`.
z_space = 'contrastive'
# Logging handle shared with the later training cells via `wb`.
# NOTE(review): presumably a W&B run, or a no-op/None when `wandb.enabled` is false — confirm.
wb = _maybe_init_wandb(cfg)

# Phase 1: contrastive training. The recorded run saved ckpts/contrastive_phi.pt and
# ckpts/contrastive_g.pt; `phi_path` / `g_path` presumably point at those files.
phi_path, g_path, contrastive_metrics = train_phase1_contrastive(cfg, device, wb)

[34m[1mwandb[0m: Currently logged in as: [33mmattia-scardecchia[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




[CONTRASTIVE] Epoch 1: train_loss=2.5239
[CONTRASTIVE-VAL] Epoch 1: val_loss=2.2710
[CONTRASTIVE] Epoch 2: train_loss=0.8999
[CONTRASTIVE-VAL] Epoch 2: val_loss=1.9363
[CONTRASTIVE] Epoch 3: train_loss=0.5999
[CONTRASTIVE-VAL] Epoch 3: val_loss=1.7528
[CONTRASTIVE] Epoch 4: train_loss=0.5161
[CONTRASTIVE-VAL] Epoch 4: val_loss=1.7313
[CONTRASTIVE] Epoch 5: train_loss=0.4489
[CONTRASTIVE-VAL] Epoch 5: val_loss=1.7587
[CONTRASTIVE] Epoch 6: train_loss=0.4258
[CONTRASTIVE-VAL] Epoch 6: val_loss=1.7538
[CONTRASTIVE] Epoch 7: train_loss=0.3717
[CONTRASTIVE-VAL] Epoch 7: val_loss=1.9741
[CONTRASTIVE] Epoch 8: train_loss=0.3644
[CONTRASTIVE-VAL] Epoch 8: val_loss=1.4004
[CONTRASTIVE] Epoch 9: train_loss=0.3410
[CONTRASTIVE-VAL] Epoch 9: val_loss=1.6027
[CONTRASTIVE] Epoch 10: train_loss=0.3987
[CONTRASTIVE-VAL] Epoch 10: val_loss=1.0876
Saved ckpts/contrastive_phi.pt
Saved ckpts/contrastive_g.pt




## Train Phase 2 dynamics (on selected latent space)

In [8]:
from scripts.train import train_phase2_dynamics

# Phase 2: train the dynamics model on the latent space selected by `z_space`,
# reusing `cfg`, `device` and `wb` from the Phase 1 cell (which must have been run first).
# The recorded run saved ckpts/dyn_contrastive.pt.
dyn_path, dynamics_metrics = train_phase2_dynamics(cfg, device, z_space=z_space, wb=wb)

[DYN-CONTRASTIVE] Epoch 1: train_mse=0.9865
[DYN-CONTRASTIVE-VAL] Epoch 1: val_mse=0.0389
[DYN-CONTRASTIVE] Epoch 2: train_mse=0.0360
[DYN-CONTRASTIVE-VAL] Epoch 2: val_mse=0.0376
[DYN-CONTRASTIVE] Epoch 3: train_mse=0.0353
[DYN-CONTRASTIVE-VAL] Epoch 3: val_mse=0.0373
[DYN-CONTRASTIVE] Epoch 4: train_mse=0.0356
[DYN-CONTRASTIVE-VAL] Epoch 4: val_mse=0.0387
[DYN-CONTRASTIVE] Epoch 5: train_mse=0.0358
[DYN-CONTRASTIVE-VAL] Epoch 5: val_mse=0.0388
[DYN-CONTRASTIVE] Epoch 6: train_mse=0.0358
[DYN-CONTRASTIVE-VAL] Epoch 6: val_mse=0.0413
[DYN-CONTRASTIVE] Epoch 7: train_mse=0.0364
[DYN-CONTRASTIVE-VAL] Epoch 7: val_mse=0.0394
[DYN-CONTRASTIVE] Epoch 8: train_mse=0.0360
[DYN-CONTRASTIVE-VAL] Epoch 8: val_mse=0.0382
[DYN-CONTRASTIVE] Epoch 9: train_mse=0.0362
[DYN-CONTRASTIVE-VAL] Epoch 9: val_mse=0.0404
[DYN-CONTRASTIVE] Epoch 10: train_mse=0.0367
[DYN-CONTRASTIVE-VAL] Epoch 10: val_mse=0.0411
Saved ckpts/dyn_contrastive.pt


## Train probe on frozen latents (z → R^2)

In [9]:
# Train the probe for the latent space selected by `z_space`, reusing `cfg`, `device`
# and `wb` from the Phase 1 cell. The recorded run saved ckpts/probe_contrastive.pt.
from scripts.train import train_probes
probe_path, probe_metrics = train_probes(cfg, device, z_space=z_space, wb=wb)

[PROBE-CONTRASTIVE] Epoch 1: train_mse=0.0501
[PROBE-CONTRASTIVE-VAL] Epoch 1: val_mse=0.0023
[PROBE-CONTRASTIVE] Epoch 2: train_mse=0.0015
[PROBE-CONTRASTIVE-VAL] Epoch 2: val_mse=0.0013
[PROBE-CONTRASTIVE] Epoch 3: train_mse=0.0011
[PROBE-CONTRASTIVE-VAL] Epoch 3: val_mse=0.0011
[PROBE-CONTRASTIVE] Epoch 4: train_mse=0.0010
[PROBE-CONTRASTIVE-VAL] Epoch 4: val_mse=0.0010
[PROBE-CONTRASTIVE] Epoch 5: train_mse=0.0009
[PROBE-CONTRASTIVE-VAL] Epoch 5: val_mse=0.0010
[PROBE-CONTRASTIVE] Epoch 6: train_mse=0.0009
[PROBE-CONTRASTIVE-VAL] Epoch 6: val_mse=0.0009
[PROBE-CONTRASTIVE] Epoch 7: train_mse=0.0009
[PROBE-CONTRASTIVE-VAL] Epoch 7: val_mse=0.0009
[PROBE-CONTRASTIVE] Epoch 8: train_mse=0.0008
[PROBE-CONTRASTIVE-VAL] Epoch 8: val_mse=0.0009
[PROBE-CONTRASTIVE] Epoch 9: train_mse=0.0008
[PROBE-CONTRASTIVE-VAL] Epoch 9: val_mse=0.0008
[PROBE-CONTRASTIVE] Epoch 10: train_mse=0.0008
[PROBE-CONTRASTIVE-VAL] Epoch 10: val_mse=0.0009
Saved ckpts/probe_contrastive.pt


## Evaluate end-to-end next-signal MSE

In [10]:
# Evaluate the pipeline for the selected latent space. The call is the cell's last
# expression, so its return value is displayed; in the recorded run that was a dict
# with the single key 'final_eval_loss'.
from scripts.eval import evaluate
evaluate(cfg, z_space, device)

[EVAL-contrastive] next-signal MSE = 0.001361


{'final_eval_loss': 0.0013614735392232737}