In [1]:
# Load IPython's autoreload extension so the %autoreload magic is available.
%load_ext autoreload

In [23]:
# Mode 2: re-import all modules before each cell runs, so edits to the
# imported project code are picked up without restarting the kernel.
%autoreload 2

import torch
from torch import nn, optim
from torchdrive.tasks.diff_traj import XYMLPEncoder


# Seed PyTorch's RNGs so the model initialisation and the randomly sampled
# training batches/noise are repeatable across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_DIST = 128  # training points are sampled from [-MAX_DIST, MAX_DIST) per axis
DIM = 1024  # passed to XYMLPEncoder as `dim`
BS = 16000  # number of random XY points per training batch

# Encoder under training, moved to the selected device.
m = XYMLPEncoder(dim=DIM, max_dist=MAX_DIST).to(device)

optimizer = optim.AdamW(m.parameters(), lr=1e-4)

In [27]:
# Theoretical bucket width: the full coordinate span (2 * MAX_DIST) divided
# into DIM / 2 buckets.
# NOTE(review): presumably half of the embedding dims are used per axis —
# confirm against XYMLPEncoder.
coord_span = 2 * MAX_DIST
buckets_per_axis = DIM / 2
bucket_m = coord_span / buckets_per_axis
bucket_m

0.5

In [38]:
BATCHES = 2000  # number of optimisation steps
LOG_EVERY = 100  # print metrics every LOG_EVERY steps

# Train the encoder to recover XY points both from its clean encoding and
# from an encoding corrupted with randomly scaled Gaussian noise.
for i in range(BATCHES):
    should_log = i % LOG_EVERY == 0
    optimizer.zero_grad()

    # Uniform random points in [-MAX_DIST, MAX_DIST) on both coordinates.
    batch = (torch.rand(BS, 1, 2, device=device) - 0.5) * (2 * MAX_DIST)

    encoded = m(batch)

    # Mean L2 norm (over the last dim) of the encodings in this batch.
    encoded_mag = torch.linalg.vector_norm(encoded, dim=-1).mean()
    # Per-sample noise scale: encoded_mag times a uniform factor in [0, 1).
    # NOTE(review): encoded_mag is not detached, so gradients also flow
    # through the noise scale — confirm this is intended.
    noise_scale = encoded_mag * torch.rand(BS, device=device)
    # Reshape (BS,) -> (BS, 1, 1) so it broadcasts over `encoded`
    # (assumes `encoded` is 3-D with BS as its first dim — TODO confirm).
    noise_scale = noise_scale.unsqueeze(1).unsqueeze(1)

    noise = torch.randn_like(encoded) * noise_scale

    if should_log:
        print(f"{i} - magnitude {encoded_mag}, scale {noise_scale.mean().item()}")

    noise_encoded = encoded + noise

    variations = {
        "normal": encoded,
        "noise": noise_encoded,
    }
    losses = {}
    for k, enc in variations.items():
        loss = m.loss(enc, batch).mean()
        losses[k] = loss

        if should_log:
            # The MAE is a logging-only metric: compute it only when it is
            # printed, and without building an autograd graph for it.
            with torch.no_grad():
                mae = torch.linalg.vector_norm(m.decode(enc) - batch, dim=-1).mean()
            print(f"{i} - {k}: loss {loss.item()}, mae {mae.item()}")

    # Optimise the clean and noisy reconstruction losses jointly.
    total_loss = sum(losses.values())
    total_loss.backward()

    optimizer.step()

    if should_log:
        print(i, "batch", total_loss.item())
        print()
    
    

0 - magnitude 16.623794555664062, scale 8.333362579345703
0 - normal: loss 0.011047407984733582, mae 0.3838900923728943
0 - noise: loss 9.734607696533203, mae 110.0469741821289
0 batch 9.745655059814453

100 - magnitude 16.00202178955078, scale 8.027247428894043
100 - normal: loss 0.01364201307296753, mae 0.3835380971431732
100 - noise: loss 9.665282249450684, mae 109.15914154052734
100 batch 9.678924560546875

200 - magnitude 15.547680854797363, scale 7.71797513961792
200 - normal: loss 0.013021253049373627, mae 0.3811604678630829
200 - noise: loss 9.588726043701172, mae 108.6501693725586
200 batch 9.601747512817383

300 - magnitude 15.0445556640625, scale 7.536011219024658
300 - normal: loss 0.013400120660662651, mae 0.3808329105377197
300 - noise: loss 9.600756645202637, mae 109.33332061767578
300 batch 9.614156723022461

400 - magnitude 14.528416633605957, scale 7.243579387664795
400 - normal: loss 0.013808717019855976, mae 0.38157016038894653
400 - noise: loss 9.557799339294434, m

KeyboardInterrupt: 

In [41]:
from safetensors.torch import save_file

# Write the encoder's current parameters to a safetensors file in the
# working directory.
encoder_weights = m.state_dict()
save_file(encoder_weights, "xy_mlp_vae.safetensors")