In [1]:
import sys
import yaml
import numpy as np
from pathlib import Path

sys.path.append('../')

In [2]:
# Load the experiment configuration shared by the data generator, the trainers
# and the evaluator below (the path is relative to this notebook's directory).
config_path = Path('../config.yaml')
cfg = yaml.safe_load(config_path.read_text())
cfg

{'data': {'D': 102,
  'signal_dim': 2,
  'noise_dim': 100,
  'traj_len': 200,
  'step_size': 0.1,
  'n_train': 300,
  'n_val': 30,
  'n_test': 30,
  'seed': 0,
  'out_dir': 'data/'},
 'model': {'z_dim_vae': 2,
  'z_dim_contrastive': 2,
  'enc_widths': [256, 256],
  'dec_widths': [256, 256],
  'proj_widths': [256, 256],
  'dyn_widths': [256, 256],
  'probe_widths': [64, 64],
  'activation': 'relu'},
 'train': {'batch_size': 256,
  'epochs_phase1': 10,
  'epochs_phase2': 10,
  'epochs_probe': 10,
  'num_workers': 0,
  'ckpt_dir': 'ckpts/',
  'vae': {'lr': 0.002, 'weight_decay': 0.0, 'beta': 0.001},
  'contrastive': {'lr': 0.001, 'weight_decay': 0.0, 'temperature': 0.1},
  'dynamics': {'lr': 0.001, 'weight_decay': 0.0},
  'probe': {'lr': 0.001, 'weight_decay': 0.0}},
 'wandb': {'enabled': True,
  'project': 'repr-world',
  'entity': None,
  'group': None,
  'mode': 'online',
  'dir': 'logs',
  'tags': ['toy', 'repr']}}

In [3]:
# Generate the dataset splits from cfg; the log below reports one .npz file
# (under cfg['data']['out_dir']) per split with its array shapes.
from scripts.generate_data import main as gen_main

gen_main(cfg)

Wrote data/train.npz with shapes s(60000, 102), a(60000,), sp(60000, 102)
Wrote data/val.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)
Wrote data/test.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)


## Quick data exploration

In [4]:
# Load the generated splits and sanity-check the state layout: the first
# `signal_dim` columns of each state are the signal, the remaining ones are noise.
import matplotlib.pyplot as plt

SHOW_PLOTS = False  # set True to render the exploration figures at the bottom of this cell
N_PLOT_SAMPLES = 10_000  # max number of random train states drawn for the figures

data_dir = Path(cfg['data']['out_dir'])
signal_dim = cfg['data']['signal_dim']
# Seeded generator so the plotted subsample is reproducible across runs.
rng = np.random.default_rng(cfg['data']['seed'])


def load_split(path: Path) -> dict:
    """Eagerly read every array stored in an ``.npz`` split and close the archive."""
    with np.load(path) as npz:
        return {key: npz[key] for key in npz.files}


train, val, test = (load_split(data_dir / f'{split}.npz') for split in ('train', 'val', 'test'))

s, a, sp = train['s'], train['a'], train['sp']
# The state width must equal the configured total (signal + noise) dimensionality.
assert s.shape[1] == cfg['data']['D'] == signal_dim + cfg['data']['noise_dim'], s.shape
print("train shapes:", s.shape, a.shape, sp.shape)

print("state mean/std:", s.mean(), s.std())
print("signal mean/std:", s[:, :signal_dim].mean(), s[:, :signal_dim].std())
print("noise mean/std:", s[:, signal_dim:].mean(), s[:, signal_dim:].std())
print("actions min/max/unique:", a.min(), a.max(), np.unique(a, return_counts=True))

if SHOW_PLOTS:
    n_plot = min(N_PLOT_SAMPLES, s.shape[0])
    idx = rng.choice(s.shape[0], n_plot, replace=False)
    fig, (ax_scatter, ax_signal, ax_noise) = plt.subplots(1, 3, figsize=(15, 4))

    # The scatter and signal histograms use the first two signal dims (assumes signal_dim >= 2).
    ax_scatter.scatter(s[idx, 0], s[idx, 1], s=2, alpha=0.3)
    ax_scatter.set(title="Signal component scatter (N random points)", xlabel="x", ylabel="y")

    ax_signal.hist(s[idx, 0], bins=50, alpha=0.5, label='signal-x')
    ax_signal.hist(s[idx, 1], bins=50, alpha=0.5, label='signal-y')
    ax_signal.set(title="Signal dims hist (x,y) [random sample]", xlabel="value", ylabel="freq")
    ax_signal.legend()

    # First noise column.
    ax_noise.hist(s[idx, signal_dim], bins=50)
    ax_noise.set(
        title=f"One noise dim hist (s[:,{signal_dim}]) [random sample]",
        xlabel="value",
        ylabel="freq",
    )
    plt.show()

train shapes: (60000, 102) (60000,) (60000, 102)
state mean/std: -0.000257664 1.0058259
signal mean/std: -0.028032197 1.2757925
noise mean/std: 0.0002978269 0.9996753
actions min/max/unique: 0 3 (array([0, 1, 2, 3]), array([15118, 14927, 14967, 14988]))


In [5]:
# Constant-predictor baseline for the signal dims. Under MSE the best constant
# prediction is the per-dimension mean (for any distribution, not only Gaussian
# ones), and its MSE is the per-dimension variance averaged over the signal dims.
# Reconstruction / next-signal models should be compared against this number.
signal_dim = cfg['data']['signal_dim']
signal = s[:, :signal_dim]
((signal - signal.mean(axis=0)) ** 2).mean()

np.float32(1.6275393)

In [6]:
# The same constant-mean baseline over every state dimension (signal + noise):
# the MSE obtained by always predicting the per-dimension training mean, a
# reference point for the VAE reconstruction loss reported below.
per_dim_mean = s.mean(axis=0)
((s - per_dim_mean) ** 2).mean()

np.float32(1.0116515)

## Train Phase 1: representation learning (choose one latent space; this run uses the VAE)

In [7]:
import torch

from scripts.train import train_phase1_contrastive, train_phase1_vae, _maybe_init_wandb

# Pick the best available accelerator instead of hard-coding Apple's MPS backend,
# so the notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

z_space = 'vae'  # latent space reused by phase 2, the probe and the evaluation cells
wb = _maybe_init_wandb(cfg)

# Only the VAE phase-1 path is wired up in this cell
# (train_phase1_contrastive is imported but not called here).
assert z_space == 'vae', f"this cell only trains the VAE phase-1 model, got z_space={z_space!r}"
enc_path, vae_metrics = train_phase1_vae(cfg, device, wb)
print("VAE training metrics:", vae_metrics)

[34m[1mwandb[0m: Currently logged in as: [33mmattia-scardecchia[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training VAE. Epochs: 10; Batch size: 256; Steps per epoch: 235.




[VAE] Epoch 1: loss=0.9886 (recon=0.9866, kl=1.9974)
[VAE] Epoch 2: loss=0.9827 (recon=0.9806, kl=2.1695)
[VAE] Epoch 3: loss=0.9817 (recon=0.9795, kl=2.2874)
[VAE] Epoch 4: loss=0.9810 (recon=0.9786, kl=2.4247)
[VAE] Epoch 5: loss=0.9806 (recon=0.9780, kl=2.5227)
[VAE] Epoch 6: loss=0.9802 (recon=0.9776, kl=2.6111)
[VAE] Epoch 7: loss=0.9799 (recon=0.9772, kl=2.6901)
[VAE] Epoch 8: loss=0.9796 (recon=0.9768, kl=2.7303)
[VAE] Epoch 9: loss=0.9793 (recon=0.9765, kl=2.8001)
[VAE] Epoch 10: loss=0.9791 (recon=0.9763, kl=2.8453)
Saved ckpts/vae.pt
VAE training metrics: {'final_vae_loss': 0.9791400657335917, 'final_vae_recon': 0.9762948108990988, 'final_vae_kl': 2.8452555030822753}


## Train Phase 2: latent dynamics model (on the selected latent space)

In [8]:
# Phase 2: fit the dynamics model on the latent space selected by `z_space`,
# reusing the device and wandb handle from the phase-1 cell.
from scripts.train import train_phase2_dynamics

dyn_path, dynamics_metrics = train_phase2_dynamics(cfg, device, wb=wb, z_space=z_space)
print("Dynamics training metrics:", dynamics_metrics)
dyn_path

[DYN-VAE] Epoch 1: mse=0.5117
[DYN-VAE] Epoch 2: mse=0.4898
[DYN-VAE] Epoch 3: mse=0.4851
[DYN-VAE] Epoch 4: mse=0.4837
[DYN-VAE] Epoch 5: mse=0.4831
[DYN-VAE] Epoch 6: mse=0.4820
[DYN-VAE] Epoch 7: mse=0.4820
[DYN-VAE] Epoch 8: mse=0.4807
[DYN-VAE] Epoch 9: mse=0.4804
[DYN-VAE] Epoch 10: mse=0.4811
Saved ckpts/dyn_vae.pt
Dynamics training metrics: {'final_dyn-vae_mse': 0.4811010751247406}


'ckpts/dyn_vae.pt'

## Train a probe on frozen latents ($z \to \mathbb{R}^2$)

In [9]:
# Fit a probe on top of the frozen latents of the selected `z_space`;
# the returned path is the saved probe checkpoint.
from scripts.train import train_probes

probe_path, probe_metrics = train_probes(cfg, device, wb=wb, z_space=z_space)
print("Probe training metrics:", probe_metrics)
probe_path

[PROBE-VAE] Epoch 1: mse=0.4810
[PROBE-VAE] Epoch 2: mse=0.3379
[PROBE-VAE] Epoch 3: mse=0.3270
[PROBE-VAE] Epoch 4: mse=0.3233
[PROBE-VAE] Epoch 5: mse=0.3207
[PROBE-VAE] Epoch 6: mse=0.3201
[PROBE-VAE] Epoch 7: mse=0.3195
[PROBE-VAE] Epoch 8: mse=0.3192
[PROBE-VAE] Epoch 9: mse=0.3193
[PROBE-VAE] Epoch 10: mse=0.3186
Saved ckpts/probe_vae.pt
Probe training metrics: {'final_probe-vae_mse': 0.31855003072420757}


'ckpts/probe_vae.pt'

## Evaluate end-to-end next-signal MSE

In [10]:
# End-to-end evaluation of the selected latent pipeline; the evaluator reports
# the next-signal MSE, comparable to the constant-mean signal baseline above.
from scripts.eval import evaluate

eval_metrics = evaluate(cfg, z_space, device)
print("Evaluation metrics:", eval_metrics)

[EVAL-vae] next-signal MSE = 0.489901
Evaluation metrics: {'final_eval_loss': 0.48990092452367145}
