In [None]:
import os, sys
sys.path.append("..")

import torch
import numpy as np

from src.light_sb import LightSB
from src.distributions import LoaderSampler, TensorSampler
from src.plotters import plot_2D
from tqdm import tqdm
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt

import wandb
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import pairwise_distances

## Config

In [None]:
# --- Bookkeeping -----------------------------------------------------------
SEED = 42          # seeds torch / numpy and is part of the run and checkpoint names
SERIES_ID = 1      # only recorded in the W&B run config
DEVICE = "cpu"     # device that sampled batches are moved to

# --- Data ------------------------------------------------------------------
DIM = 1_000        # selects ../data/full_cite_pcas_{DIM}_day_*.npy and sets the model dim
assert DIM > 1
# Days used as source / held-out evaluation / target marginals.
DAY_START, DAY_EVAL, DAY_END = 2, 3, 4

# --- Model -----------------------------------------------------------------
EPSILON = 0.1                # passed to LightSB as `epsilon`
N_POTENTIALS = 10            # passed to LightSB as `n_potentials`
SAMPLING_BATCH_SIZE = 128    # passed to LightSB as `sampling_batch_size`
IS_DIAGONAL = True           # passed to LightSB as `is_diagonal`
INIT_BY_SAMPLES = True       # if set, init_r_by_samples is called with target samples

# --- Optimisation ----------------------------------------------------------
BATCH_SIZE = 128
D_LR = 0.01                                # Adam learning rate
D_GRADIENT_MAX_NORM = float("inf")         # max_norm for clip_grad_norm_ (inf: norm is measured, nothing is clipped)

# --- Schedule --------------------------------------------------------------
MAX_STEPS = 10_000    # exclusive upper bound of the training loop
EVAL_EVERY = 10_000   # evaluate whenever (step + 1) is a multiple of this
CONTINUE = -1         # last finished step to resume from; -1 starts from scratch

In [None]:
# Seed both global random number generators so repeated runs draw the same numbers.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Aliases of EPSILON under alternative names.
EPSILON_END = EPSILON
EPS = EPSILON

In [None]:
# Run identifier: used as the W&B run name and as the checkpoint sub-directory name.
EXP_NAME = f'LightSB_old_Single_Cell_MAX_STEPS_{MAX_STEPS}_full_CITE_cell_DIM_{DIM}_DAY_EVAL_{DAY_EVAL}_EPSILON_{EPSILON}_SEED_{SEED}'
OUTPUT_PATH = '../checkpoints/{}'.format(EXP_NAME)

# Hyperparameters recorded with the W&B run.
config = dict(
    SERIES_ID=SERIES_ID,
    DAY_START=DAY_START,
    DAY_END=DAY_END,
    DAY_EVAL=DAY_EVAL,
    DIM=DIM,
    D_LR=D_LR,
    BATCH_SIZE=BATCH_SIZE,
    EPSILON=EPSILON,
    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,
    N_POTENTIALS=N_POTENTIALS,
    INIT_BY_SAMPLES=INIT_BY_SAMPLES,
    IS_DIAGONAL=IS_DIAGONAL,
    SEED=SEED,
    # These two also shape the run but were previously missing from the logged config.
    MAX_STEPS=MAX_STEPS,
    SAMPLING_BATCH_SIZE=SAMPLING_BATCH_SIZE,
)

# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(OUTPUT_PATH, exist_ok=True)

## Data loading

In [None]:
# Load one array per day; the file name is keyed by DIM and the day index.
data = {}
for day in [2, 3, 4, 7]:
    data[day] = np.load(f"../data/full_cite_pcas_{DIM}_day_{day}.npy")

eval_data = data[DAY_EVAL]
start_data = data[DAY_START]
end_data = data[DAY_END]

# A single global scalar (mean over features of the per-feature std across the
# three days in use) so that all days are rescaled by the same factor.
constant_scale = np.concatenate([start_data, end_data, eval_data]).std(axis=0).mean()

# NOTE(review): eval_data_scaled is not used by the cells visible here; it is
# kept in case other cells rely on it.
eval_data_scaled = eval_data/constant_scale
start_data_scaled = start_data/constant_scale
end_data_scaled = end_data/constant_scale

# eval_data is intentionally left in original units: the training cell
# multiplies the model's predictions by constant_scale before comparing them
# against it. Only the training marginals are rescaled.
eval_data = torch.tensor(eval_data).float()
start_data = torch.tensor(start_data_scaled).float()
end_data = torch.tensor(end_data_scaled).float()

# The samplers receive independent copies. `.clone()` replaces the previous
# torch.tensor(<tensor>) copy-construction, which emits a UserWarning.
X_sampler = TensorSampler(start_data.clone(), device="cpu")
Y_sampler = TensorSampler(end_data.clone(), device="cpu")

## Model initialisation

In [None]:
# Build the model on the configured device. Every batch fed to it is moved to
# DEVICE, so the model has to live there too (it was previously pinned to CPU).
D = LightSB(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,
            sampling_batch_size=SAMPLING_BATCH_SIZE,
            is_diagonal=IS_DIAGONAL).to(DEVICE)

if INIT_BY_SAMPLES:
    # Initialise from N_POTENTIALS samples drawn from the target-day sampler.
    D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS).to(DEVICE))

D_opt = torch.optim.Adam(D.parameters(), lr=D_LR)

if CONTINUE > -1:
    # Resume: restore the optimizer state saved for step CONTINUE.
    # NOTE(review): only the optimizer state is restored here; the model weights
    # are not reloaded, and no cell visible in this notebook writes this
    # checkpoint file. Confirm where checkpoints are produced before resuming.
    D_opt_checkpoint = os.path.join(OUTPUT_PATH, f'D_opt_{SEED}_{CONTINUE}.pt')
    # map_location keeps loading working when the checkpoint was saved from another device.
    D_opt.load_state_dict(torch.load(D_opt_checkpoint, map_location=DEVICE))

In [None]:
def mmd(x, y):
    """Distance-based discrepancy between two samples.

    With d the pairwise distance returned by ``pairwise_distances`` (its
    default metric), the value returned is

        mean_{i,j} d(x_i, y_j)
        - 0.5 * mean_{i != j} d(x_i, x_j)
        - 0.5 * mean_{i != j} d(y_i, y_j)

    i.e. one half of the unbiased MMD^2 estimate for the kernel k = -d.

    Parameters
    ----------
    x : array-like with m >= 2 rows (one sample per row).
    y : array-like with n >= 2 rows and the same number of columns as ``x``.

    Returns
    -------
    float
        The estimate described above.

    Raises
    ------
    ValueError
        If either sample has fewer than two rows; the within-sample terms
        average over distinct pairs and are undefined otherwise.
    """
    m = x.shape[0]
    n = y.shape[0]
    if m < 2 or n < 2:
        raise ValueError(
            f"mmd needs at least two samples on each side, got m={m}, n={n}"
        )

    Kxx = pairwise_distances(x, x)
    Kyy = pairwise_distances(y, y)
    Kxy = pairwise_distances(x, y)

    # Term I: within-x distances over distinct pairs. Subtracting the trace
    # drops the diagonal without building an m x m temporary.
    c1 = 1 / (m * (m - 1))
    A = np.sum(Kxx) - np.trace(Kxx)

    # Term II: within-y distances over distinct pairs.
    c2 = 1 / (n * (n - 1))
    B = np.sum(Kyy) - np.trace(Kyy)

    # Term III: cross distances over all pairs.
    c3 = 1 / (m * n)
    C = np.sum(Kxy)

    # estimate MMD
    mmd_est = -0.5*c1*A - 0.5*c2*B + c3*C

    return mmd_est

## Model training

In [None]:
# --- Evaluation constants (loop-invariant, so computed once) ---
# Position of DAY_EVAL inside the DAY_START -> DAY_END interval, as a fraction.
eval_time_fraction = (DAY_EVAL - DAY_START) / (DAY_END - DAY_START)
# Reference sets as NumPy arrays in original units: end_data was divided by
# constant_scale when loaded, eval_data was not.
end_data_original_np = end_data.cpu().numpy() * constant_scale
eval_data_np = eval_data.cpu().numpy()

# Using the run as a context manager closes it even if training raises or is
# interrupted, instead of leaving an open W&B run behind.
with wandb.init(name=EXP_NAME, config=config, project="LightSBplus"):
    for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):
        # training cycle
        D_opt.zero_grad()

        X0 = X_sampler.sample(BATCH_SIZE).to(DEVICE)
        X1 = Y_sampler.sample(BATCH_SIZE).to(DEVICE)

        log_potential = D.get_log_potential(X1)
        log_C = D.get_log_C(X0)

        # Minibatch objective: mean of (log_C on source batch - log potential on target batch).
        D_loss = (-log_potential + log_C).mean()
        D_loss.backward()
        # Returns the total gradient norm; with max_norm=inf nothing is clipped.
        D_gradient_norm = torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=D_GRADIENT_MAX_NORM)
        D_opt.step()

        wandb.log(
            {
                'D gradient norm': D_gradient_norm.item(),
                'D_loss_minibatch': D_loss.item(),
            },
            step=step,
        )

        # eval and plots
        if (step + 1) % EVAL_EVERY == 0:
            with torch.no_grad():
                X = X_sampler.sample(start_data.shape[0]).to(DEVICE)
                # NOTE(review): Y is not used below. The draw is kept because the
                # sampler presumably advances the global RNG, so removing it
                # could change seeded results - confirm before deleting.
                Y = Y_sampler.sample(end_data.shape[0]).to(DEVICE)

                # Model output for the source samples, mapped back to original units
                # and compared with the target-day data.
                XN = D(X)
                XN_pred = XN.detach().cpu().numpy()*constant_scale
                MMD_target = mmd(XN_pred, end_data_original_np)

                # Intermediate-time samples compared with the held-out evaluation day.
                # The time tensor is created on DEVICE so it matches X.
                eval_times = torch.ones(X.shape[0], 1, device=DEVICE) * eval_time_fraction
                X_mid_pred = D.sample_at_time_moment(X, eval_times).detach().cpu().numpy()
                X_mid_pred = X_mid_pred*constant_scale
                MMD = mmd(X_mid_pred, eval_data_np)

                wandb.log({'MMD_target': MMD_target, 'MMD': MMD}, step=step)