## 1. Imports

In [None]:
import os, sys
sys.path.append("..")

import torch
import wandb
import numpy as np
from tqdm import tqdm
from sklearn.datasets import make_swiss_roll, make_moons

from src.light_sb import LightSB
from src.plotters import plot_2D, plot_2D_mapping, plot_2D_trajectory
from src.distributions import StandardNormalSampler, SwissRollSampler

## 2. Config

In [None]:
# --- Problem ---------------------------------------------------------------
# Data dimensionality; passed to LightSB as `dim`.
DIM = 2
assert DIM > 1

# Seed applied to both the torch and numpy global RNGs.
OUTPUT_SEED = 42

# --- Model -----------------------------------------------------------------
# Passed to LightSB as `n_potentials`; also the number of samples drawn for
# the optional sample-based initialization.
N_POTENTIALS = 500
# When True, D.init_r_by_samples(...) is called right after construction.
INIT_BY_SAMPLES = True
# Passed to LightSB as `is_diagonal`.
IS_DIAGONAL = True

# --- Batching --------------------------------------------------------------
# Number of samples drawn from each sampler per training step.
BATCH_SIZE = 128
# Passed to LightSB as `sampling_batch_size`.
SAMPLING_BATCH_SIZE = 128

# Passed to LightSB as `epsilon`; also embedded in the experiment name.
EPSILON = 0.002

# --- Optimization ----------------------------------------------------------
D_LR = 3e-4  # 1e-3 for eps 0.1, 0.01 and 3e-4 for eps 0.002
# max_norm handed to clip_grad_norm_; infinity means the gradient norm is
# still measured (and logged) but never scaled down.
D_GRADIENT_MAX_NORM = float("inf")

# --- Schedule --------------------------------------------------------------
# NOTE(review): not referenced in the visible cells; presumably a plotting
# interval used further down -- confirm.
PLOT_EVERY = 500
# Exclusive upper bound of the training step range.
MAX_STEPS = 20_000
# -1 starts from scratch; a value > -1 loads the saved optimizer state and
# resumes the step counter at CONTINUE + 1.
CONTINUE = -1

In [None]:
# Seed the torch and numpy global RNGs with the same value.
torch.manual_seed(OUTPUT_SEED)
np.random.seed(OUTPUT_SEED)

# Short alias for the regularization coefficient.
EPS = EPSILON

In [None]:
# Experiment identifier: used as the wandb run name and as the name of the
# checkpoint sub-directory.
EXP_NAME = f'LightSB_Swiss_Roll_EPSILON_{EPSILON}'
OUTPUT_PATH = f'../checkpoints/{EXP_NAME}'

# Hyperparameters recorded with the wandb run.
config = dict(
    DIM=DIM,
    D_LR=D_LR,
    BATCH_SIZE=BATCH_SIZE,
    EPSILON=EPSILON,
    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,
    N_POTENTIALS=N_POTENTIALS,
    INIT_BY_SAMPLES=INIT_BY_SAMPLES,
    IS_DIAGONAL=IS_DIAGONAL,
)

# exist_ok=True creates the directory (and parents) only when missing, without
# the check-then-create race of a separate os.path.exists() test.
os.makedirs(OUTPUT_PATH, exist_ok=True)

## 3. Create samplers

In [None]:
# X_sampler provides the X0 batches and Y_sampler the X1 batches consumed by
# the training loop. Both take their dimensionality from DIM so they cannot
# drift out of sync with the model, which is also built with dim=DIM.
X_sampler = StandardNormalSampler(dim=DIM, device="cpu")
Y_sampler = SwissRollSampler(dim=DIM, device="cpu")

## 4. Model initialization

In [None]:
# Build the LightSB model from the values in the config cell.
D = LightSB(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,
            sampling_batch_size=SAMPLING_BATCH_SIZE, is_diagonal=IS_DIAGONAL)

if INIT_BY_SAMPLES:
    # Initialize the model from N_POTENTIALS draws of Y_sampler.
    D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS))

D_opt = torch.optim.Adam(D.parameters(), lr=D_LR)

if CONTINUE > -1:
    # Resume: restore the optimizer state stored for this seed at step
    # CONTINUE. OUTPUT_SEED is the only seed constant defined in this file.
    # NOTE(review): only the optimizer state is restored here (the model
    # weights are not), and the training cell writes 'D.pt' / 'D_opt.pt'
    # without a seed/step suffix -- confirm which checkpoint naming is
    # intended and whether D's state should be loaded as well.
    D_opt.load_state_dict(
        torch.load(os.path.join(OUTPUT_PATH, f'D_opt_{OUTPUT_SEED}_{CONTINUE}.pt'))
    )

## 5. Model training

In [None]:
# Open a wandb run named after the experiment and attach the hyperparameters.
wandb.init(name=EXP_NAME, config=config)

for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):
    # One optimizer update on a fresh batch from each sampler.
    D_opt.zero_grad()

    X0 = X_sampler.sample(BATCH_SIZE)
    X1 = Y_sampler.sample(BATCH_SIZE)

    log_potential = D.get_log_potential(X1)
    log_C = D.get_log_C(X0)

    # Batch mean of log_C(X0) - log_potential(X1).
    D_loss = (log_C - log_potential).mean()
    D_loss.backward()
    # clip_grad_norm_ returns the total gradient norm, which is logged below.
    D_gradient_norm = torch.nn.utils.clip_grad_norm_(
        D.parameters(), max_norm=D_GRADIENT_MAX_NORM
    )
    D_opt.step()

    # Both metrics belong to the same step, so they go out in a single call.
    wandb.log(
        {
            'D gradient norm': D_gradient_norm.item(),
            'D_loss': D_loss.item(),
        },
        step=step,
    )

# Persist the final model and optimizer states.
torch.save(D.state_dict(), os.path.join(OUTPUT_PATH, 'D.pt'))
torch.save(D_opt.state_dict(), os.path.join(OUTPUT_PATH, 'D_opt.pt'))

wandb.finish()