In [1]:
import sys
import yaml
import numpy as np
import torch
from pathlib import Path

# Make the parent directory importable so the later `from scripts... import ...`
# cells resolve. Guarded so that re-running this cell in a live kernel does not
# keep appending duplicate entries to sys.path.
if "../" not in sys.path:
    sys.path.append("../")

In [2]:
# Parse the experiment configuration that lives one directory up into a plain
# Python object, then echo it as the cell's output for reference.
with Path('../config.yaml').open('r') as f:
    cfg = yaml.safe_load(f)
cfg


{'data': {'D': 102,
  'signal_dim': 2,
  'noise_dim': 100,
  'traj_len': 200,
  'step_size': 0.1,
  'n_train': 300,
  'n_val': 30,
  'n_test': 30,
  'seed': 0,
  'out_dir': 'data/'},
 'model': {'z_dim_vae': 2,
  'z_dim_contrastive': 2,
  'enc_widths': [256, 256],
  'dec_widths': [256, 256],
  'proj_widths': [256, 256],
  'dyn_widths': [256, 256],
  'probe_widths': [64, 64],
  'activation': 'relu'},
 'train': {'batch_size': 256,
  'epochs_phase1': 20,
  'epochs_phase2': 20,
  'epochs_probe': 20,
  'num_workers': 0,
  'ckpt_dir': 'ckpts/',
  'vae': {'lr': 0.002, 'weight_decay': 0.0, 'beta': 0.001},
  'contrastive': {'lr': 0.001, 'weight_decay': 0.0, 'temperature': 0.1},
  'dynamics': {'lr': 0.001, 'weight_decay': 0.0},
  'probe': {'lr': 0.001, 'weight_decay': 0.0}},
 'wandb': {'enabled': False,
  'project': 'repr-world',
  'entity': None,
  'group': None,
  'mode': 'online',
  'dir': 'logs',
  'tags': ['toy', 'repr']}}

In [3]:
# Run the project's data-generation entry point with the loaded config.
# Per the recorded output, this writes data/train.npz, data/val.npz and
# data/test.npz, each holding arrays named s, a and sp.
from scripts.generate_data import main as gen_main
gen_main(cfg)

Wrote data/train.npz with shapes s(60000, 102), a(60000,), sp(60000, 102)
Wrote data/val.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)
Wrote data/test.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)


## Quick data exploration

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Load the three splits produced by the generation step. Each archive holds the
# arrays `s`, `a` and `sp` (see the shapes printed by the generation cell).
data_dir = Path(cfg['data']['out_dir'])
train = np.load(data_dir / 'train.npz')
val = np.load(data_dir / 'val.npz')
test = np.load(data_dir / 'test.npz')

s, a, sp = train['s'], train['a'], train['sp']
print("train shapes:", s.shape, a.shape, sp.shape)

# The leading `signal_dim` columns are treated as signal and the remaining
# columns as noise; read the split point from the config instead of
# hard-coding 2 so the summary stays correct if the config changes.
sig_dim = cfg['data']['signal_dim']

print("state mean/std:", s.mean(), s.std())
print("signal mean/std:", s[:, :sig_dim].mean(), s[:, :sig_dim].std())
print("noise mean/std:", s[:, sig_dim:].mean(), s[:, sig_dim:].std())
print("actions min/max/unique:", a.min(), a.max(), np.unique(a, return_counts=True))

# Random subsample used by the optional plots below. Seeded from the config so
# a Restart & Run All draws the same points every time.
N = min(10000, s.shape[0])
sample_rng = np.random.default_rng(cfg['data']['seed'])
idx = sample_rng.choice(s.shape[0], N, replace=False)

# The diagnostic figures are off by default (the recorded run shows no figures);
# flip this flag to render them instead of un-commenting code.
SHOW_PLOTS = False

if SHOW_PLOTS:
    # Scatter of the first two (signal) coordinates.
    plt.figure()
    plt.scatter(s[idx,0], s[idx,1], s=2, alpha=0.3)
    plt.title("Signal component scatter (N random points)")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()

    # Marginal histograms of the two signal coordinates.
    plt.figure()
    plt.hist(s[idx,0], bins=50, alpha=0.5, label='signal-x')
    plt.hist(s[idx,1], bins=50, alpha=0.5, label='signal-y')
    plt.title("Signal dims hist (x,y) [random sample]")
    plt.xlabel("value")
    plt.ylabel("freq")
    plt.legend()
    plt.show()

    # Histogram of a single noise coordinate (column 2) for comparison.
    plt.figure()
    plt.hist(s[idx,2], bins=50)
    plt.title("One noise dim hist (s[:,2]) [random sample]")
    plt.xlabel("value")
    plt.ylabel("freq")
    plt.show()

train shapes: (60000, 102) (60000,) (60000, 102)
state mean/std: -0.000257664 1.0058259
signal mean/std: -0.028032197 1.2757925
noise mean/std: 0.0002978269 0.9996753
actions min/max/unique: 0 3 (array([0, 1, 2, 3]), array([15118, 14927, 14967, 14988]))


In [5]:
# Reference point for the signal columns (the first two): a predictor that
# ignores its input can do no better, in MSE terms, than always emitting the
# per-column mean. The error of that constant guess is the variance, computed here.

np.mean(np.square(s[:, :2] - s[:, :2].mean(axis=0)))

np.float32(1.6275393)

In [6]:
# The same constant-mean baseline, now over every column (signal and noise
# together): the MSE of predicting each column's mean for the full state.

np.mean(np.square(s - s.mean(axis=0)))

np.float32(1.0116515)

## Train Phase 1 (choose one)

In [7]:
import wandb
from scripts.train import train_phase1_contrastive, _maybe_init_wandb
import torch

# Pick the best available accelerator rather than hard-coding Apple's MPS
# backend, so the notebook also runs on CUDA or CPU-only machines. On a machine
# where MPS is available this still resolves to 'mps', as before.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Latent space used by every later stage (dynamics, probe, evaluation).
z_space = 'contrastive'

# wandb handle built from the config and passed to the training functions.
# NOTE(review): presumably a no-op/None when cfg['wandb']['enabled'] is False
# — confirm in scripts.train.
wb = _maybe_init_wandb(cfg)

# Phase 1: contrastive training. Per the recorded output this saves two
# checkpoints (ckpts/contrastive_phi.pt and ckpts/contrastive_g.pt); the call
# returns two values, unpacked here as phi_path and g_path.
phi_path, g_path = train_phase1_contrastive(cfg, device, wb)



[CONTRASTIVE] Epoch 1: loss=2.4592
[CONTRASTIVE] Epoch 2: loss=0.8990
[CONTRASTIVE] Epoch 3: loss=0.6245
[CONTRASTIVE] Epoch 4: loss=0.5218
[CONTRASTIVE] Epoch 5: loss=0.4607
[CONTRASTIVE] Epoch 6: loss=0.4327
[CONTRASTIVE] Epoch 7: loss=0.3447
[CONTRASTIVE] Epoch 8: loss=0.4316
[CONTRASTIVE] Epoch 9: loss=0.2974
[CONTRASTIVE] Epoch 10: loss=0.3488
[CONTRASTIVE] Epoch 11: loss=0.4512
[CONTRASTIVE] Epoch 12: loss=0.2297
[CONTRASTIVE] Epoch 13: loss=0.2885
[CONTRASTIVE] Epoch 14: loss=0.3023
[CONTRASTIVE] Epoch 15: loss=0.3611
[CONTRASTIVE] Epoch 16: loss=0.3395
[CONTRASTIVE] Epoch 17: loss=0.1810
[CONTRASTIVE] Epoch 18: loss=0.4450
[CONTRASTIVE] Epoch 19: loss=0.1908
[CONTRASTIVE] Epoch 20: loss=0.2149
Saved ckpts/contrastive_phi.pt
Saved ckpts/contrastive_g.pt


## Train Phase 2 dynamics (on selected latent space)

In [8]:
# Phase 2: train the dynamics model on the latent space selected by `z_space`,
# reusing the device and wandb handle set up in the phase-1 cell.
from scripts.train import train_phase2_dynamics

# NOTE: despite its name, `dyn_path` is not just a path. The recorded output
# shows a (checkpoint path, metrics dict) tuple.
dyn_path = train_phase2_dynamics(cfg, device, z_space=z_space, wb=wb)
dyn_path

[DYN-CONTRASTIVE] Epoch 1: mse=1.2058
[DYN-CONTRASTIVE] Epoch 2: mse=0.0624
[DYN-CONTRASTIVE] Epoch 3: mse=0.0629
[DYN-CONTRASTIVE] Epoch 4: mse=0.0629
[DYN-CONTRASTIVE] Epoch 5: mse=0.0625
[DYN-CONTRASTIVE] Epoch 6: mse=0.0632
[DYN-CONTRASTIVE] Epoch 7: mse=0.0645
[DYN-CONTRASTIVE] Epoch 8: mse=0.0634
[DYN-CONTRASTIVE] Epoch 9: mse=0.0642
[DYN-CONTRASTIVE] Epoch 10: mse=0.0635
[DYN-CONTRASTIVE] Epoch 11: mse=0.0642
[DYN-CONTRASTIVE] Epoch 12: mse=0.0646
[DYN-CONTRASTIVE] Epoch 13: mse=0.0638
[DYN-CONTRASTIVE] Epoch 14: mse=0.0643
[DYN-CONTRASTIVE] Epoch 15: mse=0.0648
[DYN-CONTRASTIVE] Epoch 16: mse=0.0659
[DYN-CONTRASTIVE] Epoch 17: mse=0.0640
[DYN-CONTRASTIVE] Epoch 18: mse=0.0659
[DYN-CONTRASTIVE] Epoch 19: mse=0.0641
[DYN-CONTRASTIVE] Epoch 20: mse=0.0658
Saved ckpts/dyn_contrastive.pt


('ckpts/dyn_contrastive.pt',
 {'final_dyn-contrastive_mse': 0.06576383593082429})

## Train probe on frozen latents (z → R^2)

In [9]:
# Train the probe for the selected latent space (`z_space`), with the same
# device and wandb handle as the earlier training cells.
from scripts.train import train_probes
# NOTE: despite its name, `probe_path` holds a (checkpoint path, metrics dict)
# tuple, as the recorded output shows.
probe_path = train_probes(cfg, device, z_space=z_space, wb=wb)
probe_path

[PROBE-CONTRASTIVE] Epoch 1: mse=0.0335
[PROBE-CONTRASTIVE] Epoch 2: mse=0.0013
[PROBE-CONTRASTIVE] Epoch 3: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 4: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 5: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 6: mse=0.0009
[PROBE-CONTRASTIVE] Epoch 7: mse=0.0009
[PROBE-CONTRASTIVE] Epoch 8: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 9: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 10: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 11: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 12: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 13: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 14: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 15: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 16: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 17: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 18: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 19: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 20: mse=0.0010
Saved ckpts/probe_contrastive.pt


('ckpts/probe_contrastive.pt',
 {'final_probe-contrastive_mse': 0.0009764376486341158})

## Evaluate end-to-end next-signal MSE

In [10]:
# Evaluate for the chosen latent space. Per the recorded output, this prints
# the next-signal MSE and returns a dict with the key 'final_eval_loss'.
# Note the argument order differs from the training calls: (cfg, z_space, device).
from scripts.eval import evaluate
evaluate(cfg, z_space, device)

[EVAL-contrastive] next-signal MSE = 0.001231


{'final_eval_loss': 0.001231414869427681}