In [1]:
import sys
import yaml
import numpy as np
from pathlib import Path

# Put the parent directory on the import path so the `scripts` package used by
# later cells (e.g. `scripts.generate_data`, `scripts.train`) can be imported.
# NOTE(review): the path is relative to the kernel's working directory; this
# presumably assumes the notebook is run from its own folder -- confirm.
sys.path.append('../')

In [2]:
# Load the experiment configuration that is passed to every later stage of this
# notebook (data generation, training, probing, evaluation).
# The file is opened explicitly as UTF-8 so that parsing does not depend on the
# platform's default locale encoding. `safe_load` only builds plain Python
# containers and scalars; it never instantiates arbitrary objects.
with open('../config.yaml', 'r', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)
cfg  # last expression: display the parsed configuration

{'data': {'D': 102,
  'signal_dim': 2,
  'noise_dim': 100,
  'traj_len': 200,
  'step_size': 0.1,
  'n_train': 300,
  'n_val': 30,
  'n_test': 30,
  'seed': 0,
  'out_dir': 'data/'},
 'model': {'z_dim_vae': 2,
  'z_dim_contrastive': 2,
  'enc_widths': [256, 256],
  'dec_widths': [256, 256],
  'proj_widths': [256, 256],
  'dyn_widths': [256, 256],
  'probe_widths': [64, 64],
  'activation': 'relu'},
 'train': {'batch_size': 256,
  'epochs_phase1': 10,
  'epochs_phase2': 10,
  'epochs_probe': 10,
  'num_workers': 0,
  'ckpt_dir': 'ckpts/',
  'vae': {'lr': 0.002, 'weight_decay': 0.0, 'beta': 0.001},
  'contrastive': {'lr': 0.001, 'weight_decay': 0.0, 'temperature': 0.1},
  'dynamics': {'lr': 0.001, 'weight_decay': 0.0},
  'probe': {'lr': 0.001, 'weight_decay': 0.0}},
 'wandb': {'enabled': True,
  'project': 'repr-world',
  'entity': None,
  'group': None,
  'mode': 'online',
  'dir': 'logs',
  'tags': ['toy', 'repr']}}

In [3]:
# Generate the datasets described by the loaded configuration.
# Per the recorded output of this cell, the call writes train.npz, val.npz and
# test.npz under data/, each holding arrays named s, a and sp
# (presumably state, action and next state -- confirm in scripts/generate_data.py).
from scripts.generate_data import main as gen_main

gen_main(cfg)

Wrote data/train.npz with shapes s(60000, 102), a(60000,), sp(60000, 102)
Wrote data/val.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)
Wrote data/test.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)


## Quick data exploration

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Toggle for the exploratory figures at the end of this cell. Off by default,
# so the cell only prints summary statistics.
SHOW_PLOTS = False

# Number of leading state columns treated as signal; the remaining columns are
# treated as noise. Read from the config instead of hard-coding 2.
# NOTE(review): assumes states are laid out as [signal | noise] -- confirm
# against scripts/generate_data.py.
signal_dim = cfg['data']['signal_dim']

# Load the three splits written by the data-generation cell.
data_dir = Path(cfg['data']['out_dir'])
train = np.load(data_dir / 'train.npz')
val = np.load(data_dir / 'val.npz')
test = np.load(data_dir / 'test.npz')

s, a, sp = train['s'], train['a'], train['sp']
print("train shapes:", s.shape, a.shape, sp.shape)

# Summary statistics over the whole state, then signal and noise separately.
print("state mean/std:", s.mean(), s.std())
print("signal mean/std:", s[:, :signal_dim].mean(), s[:, :signal_dim].std())
print("noise mean/std:", s[:, signal_dim:].mean(), s[:, signal_dim:].std())
print("actions min/max/unique:", a.min(), a.max(), np.unique(a, return_counts=True))

# Random subsample of rows used only for plotting. A dedicated generator seeded
# from the config makes the subsample identical on every Restart & Run All,
# without touching NumPy's global random state.
N = min(10000, s.shape[0])
rng = np.random.default_rng(cfg['data']['seed'])
idx = rng.choice(s.shape[0], N, replace=False)

if SHOW_PLOTS:
    # Scatter of the first two state coordinates (the signal plane).
    plt.figure()
    plt.scatter(s[idx, 0], s[idx, 1], s=2, alpha=0.3)
    plt.title("Signal component scatter (N random points)")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()

    # Marginal histograms of the two signal coordinates.
    plt.figure()
    plt.hist(s[idx, 0], bins=50, alpha=0.5, label='signal-x')
    plt.hist(s[idx, 1], bins=50, alpha=0.5, label='signal-y')
    plt.title("Signal dims hist (x,y) [random sample]")
    plt.xlabel("value")
    plt.ylabel("freq")
    plt.legend()
    plt.show()

    # Histogram of a single noise coordinate for comparison.
    plt.figure()
    plt.hist(s[idx, 2], bins=50)
    plt.title("One noise dim hist (s[:,2]) [random sample]")
    plt.xlabel("value")
    plt.ylabel("freq")
    plt.show()

train shapes: (60000, 102) (60000,) (60000, 102)
state mean/std: -0.000257664 1.0058259
signal mean/std: -0.028032197 1.2757925
noise mean/std: 0.0002978269 0.9996753
actions min/max/unique: 0 3 (array([0, 1, 2, 3]), array([15118, 14927, 14967, 14988]))


In [5]:
# Baseline for the two signal coordinates: their histogram is roughly Gaussian,
# so the best constant (input-independent) reconstruction is the per-dimension
# mean. The MSE of that guess is the variance of the signal.

np.mean(np.square(s[:, :2] - np.mean(s[:, :2], axis=0)))

np.float32(1.6275393)

In [6]:
# The same argument holds for the noise coordinates, so the mean-predictor MSE
# over the full state vector (signal and noise together) is:

np.mean(np.square(s - np.mean(s, axis=0)))

np.float32(1.0116515)

## Train Phase 1 (choose one latent space: VAE or contrastive; this run uses the VAE)

In [7]:
from scripts.train import train_phase1_contrastive, train_phase1_vae, _maybe_init_wandb

import torch

# Pick the compute device instead of hard-coding 'mps', so the notebook also
# runs on machines without Apple's Metal backend. MPS is tried first, which
# keeps the original behaviour wherever it is available.
# NOTE(review): assumes the training scripts accept PyTorch device strings
# ('mps' / 'cuda' / 'cpu') -- confirm in scripts/train.py.
_mps_backend = getattr(torch.backends, 'mps', None)  # absent on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Latent space used by this cell and by the Phase 2, probe and evaluation cells.
z_space = 'vae'

# Experiment-tracking handle, configured from cfg and passed to the training calls.
wb = _maybe_init_wandb(cfg)

# Only the VAE branch is wired up in this cell; fail loudly if z_space is changed
# without also switching the training call below.
assert z_space == 'vae', (
    f"This cell only trains the VAE encoder, got z_space={z_space!r}"
)
enc_path, vae_metrics = train_phase1_vae(cfg, device, wb)
print("VAE training metrics:", vae_metrics)

[34m[1mwandb[0m: Currently logged in as: [33mmattia-scardecchia[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training VAE. Epochs: 10; Batch size: 256; Steps per epoch: 235.




[VAE] Epoch 1: train_loss=0.9889, train_recon=0.9869, train_kl=1.9792
[VAE-VAL] Epoch 1: val_loss=0.9837, val_recon=0.9816, val_kl=2.0385
[VAE] Epoch 2: train_loss=0.9828, train_recon=0.9807, train_kl=2.1483
[VAE-VAL] Epoch 2: val_loss=0.9829, val_recon=0.9808, val_kl=2.0965
[VAE] Epoch 3: train_loss=0.9817, train_recon=0.9795, train_kl=2.2630
[VAE-VAL] Epoch 3: val_loss=0.9823, val_recon=0.9801, val_kl=2.2315
[VAE] Epoch 4: train_loss=0.9810, train_recon=0.9786, train_kl=2.3724
[VAE-VAL] Epoch 4: val_loss=0.9818, val_recon=0.9796, val_kl=2.2471
[VAE] Epoch 5: train_loss=0.9806, train_recon=0.9781, train_kl=2.4845
[VAE-VAL] Epoch 5: val_loss=0.9814, val_recon=0.9791, val_kl=2.3371
[VAE] Epoch 6: train_loss=0.9802, train_recon=0.9776, train_kl=2.5688
[VAE-VAL] Epoch 6: val_loss=0.9816, val_recon=0.9791, val_kl=2.5205
[VAE] Epoch 7: train_loss=0.9799, train_recon=0.9773, train_kl=2.6140
[VAE-VAL] Epoch 7: val_loss=0.9813, val_recon=0.9788, val_kl=2.4781
[VAE] Epoch 8: train_loss=0.9796, 

## Train Phase 2 dynamics (on selected latent space)

In [8]:
# Phase 2: train the dynamics model in the latent space selected by `z_space`,
# reusing `device` and the wandb handle `wb` from the Phase 1 cell.
# Per the recorded output, the call logs per-epoch train/val MSE, saves a
# checkpoint, and returns the checkpoint path together with a dict of final
# metrics.
from scripts.train import train_phase2_dynamics

dyn_path, dynamics_metrics = train_phase2_dynamics(cfg, device, z_space=z_space, wb=wb)
print("Dynamics training metrics:", dynamics_metrics)
dyn_path  # last expression: display the checkpoint path

[DYN-VAE] Epoch 1: train_mse=0.5386
[DYN-VAE-VAL] Epoch 1: val_mse=0.4423
[DYN-VAE] Epoch 2: train_mse=0.5174
[DYN-VAE-VAL] Epoch 2: val_mse=0.4404
[DYN-VAE] Epoch 3: train_mse=0.5161
[DYN-VAE-VAL] Epoch 3: val_mse=0.4398
[DYN-VAE] Epoch 4: train_mse=0.5142
[DYN-VAE-VAL] Epoch 4: val_mse=0.4347
[DYN-VAE] Epoch 5: train_mse=0.5129
[DYN-VAE-VAL] Epoch 5: val_mse=0.4400
[DYN-VAE] Epoch 6: train_mse=0.5124
[DYN-VAE-VAL] Epoch 6: val_mse=0.4404
[DYN-VAE] Epoch 7: train_mse=0.5122
[DYN-VAE-VAL] Epoch 7: val_mse=0.4465
[DYN-VAE] Epoch 8: train_mse=0.5121
[DYN-VAE-VAL] Epoch 8: val_mse=0.4306
[DYN-VAE] Epoch 9: train_mse=0.5114
[DYN-VAE-VAL] Epoch 9: val_mse=0.4362
[DYN-VAE] Epoch 10: train_mse=0.5109
[DYN-VAE-VAL] Epoch 10: val_mse=0.4333
Saved ckpts/dyn_vae.pt
Dynamics training metrics: {'final_dyn-vae_train_train_mse': 0.5108871298631033, 'final_dyn-vae_val_val_mse': 0.43334845765431723}


'ckpts/dyn_vae.pt'

## Train probe on frozen latents (z → R^2)

In [9]:
# Train the probe for the latent space selected by `z_space`, reusing `device`
# and the wandb handle `wb` from the Phase 1 cell.
# Per the recorded output, the call logs per-epoch train/val MSE, saves a
# checkpoint, and returns the checkpoint path together with a dict of final
# metrics.
# NOTE(review): the section heading says the latents are frozen; that is not
# visible from this cell -- confirm in scripts/train.py.
from scripts.train import train_probes
probe_path, probe_metrics = train_probes(cfg, device, z_space=z_space, wb=wb)
print("Probe training metrics:", probe_metrics)
probe_path  # last expression: display the checkpoint path

[PROBE-VAE] Epoch 1: train_mse=0.4972
[PROBE-VAE-VAL] Epoch 1: val_mse=0.3440
[PROBE-VAE] Epoch 2: train_mse=0.3759
[PROBE-VAE-VAL] Epoch 2: val_mse=0.3537
[PROBE-VAE] Epoch 3: train_mse=0.3691
[PROBE-VAE-VAL] Epoch 3: val_mse=0.3418
[PROBE-VAE] Epoch 4: train_mse=0.3657
[PROBE-VAE-VAL] Epoch 4: val_mse=0.3570
[PROBE-VAE] Epoch 5: train_mse=0.3640
[PROBE-VAE-VAL] Epoch 5: val_mse=0.3416
[PROBE-VAE] Epoch 6: train_mse=0.3629
[PROBE-VAE-VAL] Epoch 6: val_mse=0.3583
[PROBE-VAE] Epoch 7: train_mse=0.3622
[PROBE-VAE-VAL] Epoch 7: val_mse=0.3512
[PROBE-VAE] Epoch 8: train_mse=0.3619
[PROBE-VAE-VAL] Epoch 8: val_mse=0.3500
[PROBE-VAE] Epoch 9: train_mse=0.3616
[PROBE-VAE-VAL] Epoch 9: val_mse=0.3464
[PROBE-VAE] Epoch 10: train_mse=0.3611
[PROBE-VAE-VAL] Epoch 10: val_mse=0.3502
Saved ckpts/probe_vae.pt
Probe training metrics: {'final_probe-vae_train_train_mse': 0.3611258747895559, 'final_probe-vae_val_val_mse': 0.3502463915348053}


'ckpts/probe_vae.pt'

## Evaluate end-to-end next-signal MSE

In [10]:
# Run the evaluation for the selected latent space. Per the recorded output,
# this reports a next-signal MSE and returns a dict with the key
# 'final_eval_loss'.
# NOTE(review): this number is worth comparing with the mean-predictor baseline
# MSE on the signal dims computed earlier in this notebook. Whether the two are
# measured on the same split and scale is not visible here -- confirm in
# scripts/eval.py.
from scripts.eval import evaluate
eval_metrics = evaluate(cfg, z_space, device)
print("Evaluation metrics:", eval_metrics)

[EVAL-vae] next-signal MSE = 0.537184
Evaluation metrics: {'final_eval_loss': 0.5371844995816548}
