In [1]:
import sys
import yaml
import numpy as np
import torch
from pathlib import Path

# Make the repository root (one level up from this notebook) importable so that
# `scripts.*` resolves. Resolving to an absolute path keeps imports working even
# if the working directory changes later, and the membership guard keeps this
# cell idempotent (re-running it no longer keeps appending to sys.path).
REPO_ROOT = Path("..").resolve()
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

In [2]:
# Experiment configuration lives at the repository root, one level above this notebook.
# safe_load only builds plain Python objects (dicts/lists/scalars) from the YAML.
CONFIG_PATH = '../config.yaml'
with open(CONFIG_PATH, 'r') as config_file:
    cfg = yaml.safe_load(config_file)

cfg


{'data': {'D': 102,
  'signal_dim': 2,
  'noise_dim': 100,
  'traj_len': 200,
  'step_size': 0.1,
  'n_train': 300,
  'n_val': 30,
  'n_test': 30,
  'seed': 0,
  'out_dir': 'data/'},
 'model': {'z_dim_vae': 2,
  'z_dim_contrastive': 2,
  'enc_widths': [256, 256],
  'dec_widths': [256, 256],
  'proj_widths': [256, 256],
  'dyn_widths': [256, 256],
  'probe_widths': [64, 64],
  'activation': 'relu'},
 'train': {'batch_size': 256,
  'epochs_phase1': 10,
  'epochs_phase2': 10,
  'epochs_probe': 10,
  'num_workers': 0,
  'ckpt_dir': 'ckpts/',
  'vae': {'lr': 0.002, 'weight_decay': 0.0, 'beta': 0.001},
  'contrastive': {'lr': 0.001, 'weight_decay': 0.0, 'temperature': 0.1},
  'dynamics': {'lr': 0.001, 'weight_decay': 0.0},
  'probe': {'lr': 0.001, 'weight_decay': 0.0}},
 'wandb': {'enabled': True,
  'project': 'repr-world',
  'entity': None,
  'group': None,
  'mode': 'online',
  'dir': 'logs',
  'tags': ['toy', 'repr']}}

In [3]:
# Generate the synthetic train/val/test splits described by cfg['data'].
# Judging by the log below, this writes one .npz per split under cfg['data']['out_dir'].
from scripts.generate_data import main as gen_main

gen_main(cfg)

Wrote data/train.npz with shapes s(60000, 102), a(60000,), sp(60000, 102)
Wrote data/val.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)
Wrote data/test.npz with shapes s(6000, 102), a(6000,), sp(6000, 102)


## Quick data exploration

In [4]:
import matplotlib.pyplot as plt

SHOW_PLOTS = False  # set True to render the exploratory figures at the end of this cell

data_dir = Path(cfg['data']['out_dir'])
train = np.load(data_dir / 'train.npz')
val = np.load(data_dir / 'val.npz')
test = np.load(data_dir / 'test.npz')

# Take the signal/noise split from the config instead of hard-coding 2, so these
# statistics stay correct if `signal_dim` is changed.
signal_dim = cfg['data']['signal_dim']

s, a, sp = train['s'], train['a'], train['sp']
print("train shapes:", s.shape, a.shape, sp.shape)

print("state mean/std:", s.mean(), s.std())
print("signal mean/std:", s[:, :signal_dim].mean(), s[:, :signal_dim].std())
print("noise mean/std:", s[:, signal_dim:].mean(), s[:, signal_dim:].std())
print("actions min/max/unique:", a.min(), a.max(), np.unique(a, return_counts=True))

# Seeded subsample for plotting, so figures are reproducible under Restart & Run All.
rng = np.random.default_rng(cfg['data']['seed'])
N = min(10000, s.shape[0])
idx = rng.choice(s.shape[0], N, replace=False)

if SHOW_PLOTS:
    # x/y scatter assumes signal_dim >= 2.
    fig, ax = plt.subplots()
    ax.scatter(s[idx, 0], s[idx, 1], s=2, alpha=0.3)
    ax.set(title="Signal component scatter (N random points)", xlabel="x", ylabel="y")
    plt.show()

    fig, ax = plt.subplots()
    ax.hist(s[idx, 0], bins=50, alpha=0.5, label='signal-x')
    ax.hist(s[idx, 1], bins=50, alpha=0.5, label='signal-y')
    ax.set(title="Signal dims hist (x,y) [random sample]", xlabel="value", ylabel="freq")
    ax.legend()
    plt.show()

    # First noise coordinate sits right after the signal block.
    fig, ax = plt.subplots()
    ax.hist(s[idx, signal_dim], bins=50)
    ax.set(
        title=f"One noise dim hist (s[:,{signal_dim}]) [random sample]",
        xlabel="value",
        ylabel="freq",
    )
    plt.show()

train shapes: (60000, 102) (60000,) (60000, 102)
state mean/std: -0.000257664 1.0058259
signal mean/std: -0.028032197 1.2757925
noise mean/std: 0.0002978269 0.9996753
actions min/max/unique: 0 3 (array([0, 1, 2, 3]), array([15118, 14927, 14967, 14988]))


In [5]:
# Trivial baseline for the signal dims: under MSE, the best constant predictor is
# the per-dimension mean (for any distribution, not only Gaussian ones), and its MSE
# is the per-dimension (population) variance averaged over the signal dims.
# Learned next-signal MSE should be judged against this number.
signal_dim = cfg['data']['signal_dim']
s_signal = s[:, :signal_dim]
((s_signal - s_signal.mean(axis=0)) ** 2).mean()

np.float32(1.6275393)

In [6]:
# Same constant-mean baseline, but over *all* D dims (signal + noise): the MSE of
# predicting every coordinate by its training-set mean. The noise dims have ~unit
# variance (see the stats printed above), so this lands close to 1.
np.square(s - s.mean(axis=0)).mean()

np.float32(1.0116515)

## Train Phase 1 (choose one encoder: VAE or contrastive — this run uses contrastive)

In [7]:
from scripts.train import train_phase1_contrastive, _maybe_init_wandb

# Pick the best available accelerator instead of hard-coding 'mps', which fails on
# machines without Apple-silicon GPU support. On an MPS machine this is still 'mps'.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Seed torch so training runs are reproducible. This reuses the data seed;
# NOTE(review): the training code may reseed internally — confirm in scripts/train.py.
torch.manual_seed(cfg['data']['seed'])

# Latent space consumed by the phase-2, probe and eval cells below.
z_space = 'contrastive'
# NOTE(review): presumably a wandb run, or None when cfg['wandb']['enabled'] is False — confirm.
wb = _maybe_init_wandb(cfg)

phi_path, g_path, contrastive_metrics = train_phase1_contrastive(cfg, device, wb)

[34m[1mwandb[0m: Currently logged in as: [33mmattia-scardecchia[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




[CONTRASTIVE] Epoch 1: loss=2.5640
[CONTRASTIVE] Epoch 2: loss=0.9152
[CONTRASTIVE] Epoch 3: loss=0.6395
[CONTRASTIVE] Epoch 4: loss=0.5176
[CONTRASTIVE] Epoch 5: loss=0.4570
[CONTRASTIVE] Epoch 6: loss=0.3891
[CONTRASTIVE] Epoch 7: loss=0.4120
[CONTRASTIVE] Epoch 8: loss=0.3330
[CONTRASTIVE] Epoch 9: loss=0.4311
[CONTRASTIVE] Epoch 10: loss=0.2609
Saved ckpts/contrastive_phi.pt
Saved ckpts/contrastive_g.pt


## Train Phase 2: latent dynamics (on the selected `z_space`)

In [8]:
from scripts.train import train_phase2_dynamics

# Phase 2: train the dynamics model in the latent space chosen by `z_space`
# (set in the phase-1 cell). Judging by the log, `dyn_path` is the saved checkpoint.
dyn_path, dynamics_metrics = train_phase2_dynamics(
    cfg,
    device,
    z_space=z_space,
    wb=wb,
)

[DYN-CONTRASTIVE] Epoch 1: mse=1.0759
[DYN-CONTRASTIVE] Epoch 2: mse=0.0547
[DYN-CONTRASTIVE] Epoch 3: mse=0.0543
[DYN-CONTRASTIVE] Epoch 4: mse=0.0545
[DYN-CONTRASTIVE] Epoch 5: mse=0.0552
[DYN-CONTRASTIVE] Epoch 6: mse=0.0548
[DYN-CONTRASTIVE] Epoch 7: mse=0.0554
[DYN-CONTRASTIVE] Epoch 8: mse=0.0562
[DYN-CONTRASTIVE] Epoch 9: mse=0.0556
[DYN-CONTRASTIVE] Epoch 10: mse=0.0560
Saved ckpts/dyn_contrastive.pt


## Train probe on frozen latents (z → R^2)

In [9]:
from scripts.train import train_probes

# Fit the probe that maps frozen `z_space` latents back to the signal.
# Judging by the log, `probe_path` is the saved checkpoint.
probe_path, probe_metrics = train_probes(
    cfg,
    device,
    z_space=z_space,
    wb=wb,
)

[PROBE-CONTRASTIVE] Epoch 1: mse=0.0443
[PROBE-CONTRASTIVE] Epoch 2: mse=0.0015
[PROBE-CONTRASTIVE] Epoch 3: mse=0.0011
[PROBE-CONTRASTIVE] Epoch 4: mse=0.0011
[PROBE-CONTRASTIVE] Epoch 5: mse=0.0011
[PROBE-CONTRASTIVE] Epoch 6: mse=0.0011
[PROBE-CONTRASTIVE] Epoch 7: mse=0.0011
[PROBE-CONTRASTIVE] Epoch 8: mse=0.0011
[PROBE-CONTRASTIVE] Epoch 9: mse=0.0010
[PROBE-CONTRASTIVE] Epoch 10: mse=0.0011
Saved ckpts/probe_contrastive.pt


## Evaluate end-to-end next-signal MSE

In [10]:
from scripts.eval import evaluate

# End-to-end next-signal MSE for the `z_space` pipeline; the returned metrics dict is
# displayed as the cell output. Compare it with the constant-mean signal baseline
# computed in the data-exploration section.
evaluate(
    cfg,
    z_space,
    device,
)

[EVAL-contrastive] next-signal MSE = 0.001272


{'final_eval_loss': 0.0012724682620416086}