In [None]:
import torch
from torchdiffeq import odeint_adjoint as odeint
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import math
import time
import sys
sys.path.append('../')
from typing import Tuple, Any


# import os



%load_ext autoreload
%autoreload 2


import interflow as itf
import interflow.prior as prior
import interflow.fabrics
import interflow.stochastic_interpolant as stochastic_interpolant
from torch import autograd
from functorch import jacfwd, vmap


# Route newly created tensors to the GPU when one is present; otherwise keep the default device.
if not torch.cuda.is_available():
    print('No CUDA device found!')
else:
    print('CUDA available, setting default tensor residence to GPU.')
    itf.util.set_torch_device('cuda')

print(itf.util.get_torch_device())


print("Torch version:", torch.__version__)

## Utility functions

In [None]:
def grab(var):
    """Return `var` as a NumPy array: detached from autograd and copied to host memory."""
    host_copy = var.detach().cpu()
    return host_copy.numpy()


def _compute_mu(i):
    """Compute the ith mean for a GMM."""
    return 10.0 * torch.Tensor([[
        torch.tensor(i * math.pi / 4).sin(),
        torch.tensor(i * math.pi / 4).cos()
    ]])


def compute_likelihoods(
    v: torch.nn.Module,
    s: torch.nn.Module,
    interpolant: stochastic_interpolant.Interpolant,
    n_save: int,
    n_likelihood: int,
    eps: torch.Tensor,
    bs: int
) -> Tuple[np.ndarray, torch.Tensor, np.ndarray, torch.Tensor]:
    """Draw samples from the probability flow and SDE models, and compute likelihoods.

    Reads the notebook globals `base` (called to draw samples and queried through
    `log_prob`) and `ndim` (used to flatten the backward-SDE endpoints).

    Args:
        v, s: the two networks handed to both integrators.
        interpolant: interpolant handed to both integrators.
        n_save: forwarded to `SDEIntegrator` as `n_save`.
        n_likelihood: forwarded to `SDEIntegrator`; also the leading dimension of the
            endpoints returned by `rollout_likelihood`.
        eps: SDE coefficient forwarded to `SDEIntegrator` (a tensor in this notebook).
        bs: number of base samples drawn and pushed through both models.

    Returns:
        A 4-tuple:
          xf_sde: final SDE samples as a NumPy array (squeezed).
          logpx_sdeflow: per-sample SDE log-likelihood estimate, shape [bs].
          xf_pflow: final probability-flow samples as a NumPy array (squeezed).
          logpx_pflow: per-sample probability-flow log-likelihood (tensor).
    """
    sde_flow = stochastic_interpolant.SDEIntegrator(
        v=v, s=s, dt=torch.tensor(1e-2), eps=eps, interpolant=interpolant, n_save=n_save, n_likelihood=n_likelihood
    )

    pflow = stochastic_interpolant.PFlowIntegrator(v=v, s=s,
                                                   method='dopri5',
                                                   interpolant=interpolant,
                                                   n_step=3)

    with torch.no_grad():
        x0_tests  = base(bs)
        xfs_sde   = sde_flow.rollout_forward(x0_tests) # [n_save x bs x dim]
        xf_sde    = grab(xfs_sde[-1].squeeze())        # [bs x dim]

        # Integrate back from the SDE endpoints: ([n_likelihood, bs, dim], [bs])
        x0s_sdeflow, dlogps_sdeflow = sde_flow.rollout_likelihood(xfs_sde[-1])
        # Evaluate the base density on every trajectory endpoint, then restore the
        # [n_likelihood, bs] layout so the estimate can be averaged over trajectories.
        log_p0s = torch.reshape(
            base.log_prob(x0s_sdeflow.reshape((n_likelihood*bs, ndim))),
            (n_likelihood, bs)
        )
        logpx_sdeflow = torch.mean(log_p0s, axis=0) - dlogps_sdeflow

    # The probability-flow rollout is deliberately kept outside torch.no_grad();
    # NOTE(review): presumably because its log-density term relies on autograd — confirm.
    logp0                  = base.log_prob(x0_tests)            # [bs]
    xfs_pflow, dlogp_pflow = pflow.rollout(x0_tests)            # [n_save x bs x dim], [n_save x bs]
    logpx_pflow            = logp0 + dlogp_pflow[-1].squeeze()  # [bs]
    xf_pflow               = grab(xfs_pflow[-1].squeeze())      # [bs x dim]

    return xf_sde, logpx_sdeflow, xf_pflow, logpx_pflow


def log_metrics(
    v: torch.nn.Module,
    s: torch.nn.Module,
    interpolant: stochastic_interpolant.Interpolant,
    n_save: int,
    n_likelihood: int,
    likelihood_bs: int,
    v_loss: torch.tensor,
    s_loss: torch.tensor,
    loss: torch.tensor,
    v_grad: torch.tensor,
    s_grad: torch.tensor,
    eps: torch.tensor,
    data_dict: dict
) -> None:
    """Append the current training statistics and fresh likelihood estimates to `data_dict`.

    Each loss / gradient tensor is converted with `grab` and averaged to a scalar before
    being appended to its history list. `compute_likelihoods` is then run on
    `likelihood_bs` samples, and the mean SDE and probability-flow log-likelihoods are
    appended to 'logps_sdeflow' and 'logps_pflow'.
    """
    training_stats = (
        ('v_losses', v_loss),
        ('s_losses', s_loss),
        ('losses', loss),
        ('v_grads', v_grad),
        ('s_grads', s_grad),
    )
    for history_key, stat in training_stats:
        data_dict[history_key].append(grab(stat).mean())

    _, sde_logp, _, pflow_logp = compute_likelihoods(
        v, s, interpolant, n_save, n_likelihood, eps, likelihood_bs
    )
    data_dict['logps_sdeflow'].append(grab(sde_logp).mean())
    data_dict['logps_pflow'].append(grab(pflow_logp).mean())
    
    
def make_plots(
    v: torch.nn.Module,
    s: torch.nn.Module,
    interpolant: stochastic_interpolant.Interpolant,
    n_save: int,
    n_likelihood: int,
    likelihood_bs: int,
    counter: int,
    metrics_freq: int,
    eps: torch.Tensor,
    data_dict: dict,
    save_path: str = "figs/training-checker.pdf",
) -> None:
    """Make plots to visualize samples and evolution of the likelihood.

    Draws `likelihood_bs` fresh samples from both the SDE and the probability-flow
    models (via `compute_likelihoods`) and draws a 1x4 figure: the logged loss curves,
    SDE samples coloured by exp(log-likelihood), probability-flow samples coloured the
    same way, and the logged log-likelihood histories. The figure is saved to
    `save_path` (missing parent directories are created) and then shown.

    Args:
        v, s: networks forwarded to `compute_likelihoods`.
        interpolant, n_save, n_likelihood, eps: forwarded to `compute_likelihoods`.
        likelihood_bs: number of samples drawn for the scatter panels.
        counter: current training step, printed as the epoch.
        metrics_freq: spacing (in steps) between entries of the `data_dict` histories;
            used to build the x-axis of the history panels.
        data_dict: histories filled by `log_metrics` ('losses', 'v_losses', 's_losses',
            'v_grads', 's_grads', 'logps_pflow', 'logps_sdeflow').
        save_path: file the figure is written to.
    """
    from pathlib import Path

    # compute likelihood and samples for SDE and probability flow.
    xf_sde, logpx_sdeflow, xf_pflow, logpx_pflow = compute_likelihoods(
        v, s, interpolant, n_save, n_likelihood, eps, likelihood_bs
    )

    def _latest(key):
        # Most recently logged value for `key`, or None if nothing has been logged yet.
        history = data_dict.get(key, [])
        return history[-1] if history else None

    ### plot the loss, test logp, and samples from interpolant flow
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    print("EPOCH:", counter)
    # Report the latest values recorded by log_metrics rather than reading the training
    # loop's module-level `loss` / `v_grad` / `s_grad`, which are not parameters here.
    print("LOSS, GRAD:", _latest('losses'), _latest('v_grads'), _latest('s_grads'))

    # plot loss over time; entries were logged every `metrics_freq` steps starting at 0.
    epochs = np.arange(len(data_dict['losses'])) * metrics_freq
    axes[0].plot(epochs, data_dict['losses'], label=" v + s")
    axes[0].plot(epochs, data_dict['v_losses'], label="v")
    axes[0].plot(epochs, data_dict['s_losses'], label="s")
    axes[0].set_title("LOSS")
    axes[0].legend()

    # plot samples from SDE, coloured by the estimated density (grab already detaches).
    axes[1].scatter(
        xf_sde[:, 0], xf_sde[:, 1], vmin=0.0, vmax=0.05, alpha=0.2, c=grab(torch.exp(logpx_sdeflow)))
    axes[1].set_xlim(-5, 5)
    axes[1].set_ylim(-6.5, 6.5)
    axes[1].set_title("Samples from SDE", fontsize=14)

    # plot samples from pflow
    axes[2].scatter(
        xf_pflow[:, 0], xf_pflow[:, 1], vmin=0.0, vmax=0.05, alpha=0.2, c=grab(torch.exp(logpx_pflow)))
    axes[2].set_xlim(-5, 5)
    axes[2].set_ylim(-6.5, 6.5)
    axes[2].set_title("Samples from PFlow", fontsize=14)

    # plot likelihood estimates.
    axes[3].plot(epochs, data_dict['logps_pflow'],   label='pflow', color='purple')
    axes[3].plot(epochs, data_dict['logps_sdeflow'], label='sde',   color='red')
    axes[3].set_title(r"$\log p$ from PFlow and SDE")
    axes[3].legend(loc='best')
    axes[3].set_ylim(-7, 0)

    fig.suptitle(r"$\epsilon = $" + str(grab(eps)) + r" $n_{likelihood} = $" + str(n_likelihood), fontsize=16, y = 1.05)
    # savefig does not create directories; make sure the target folder exists.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, bbox_inches='tight')
    plt.show()
    
    


def train_step(
    prior_bs: int,
    target_bs: int,
    N_t: int,
    interpolant: stochastic_interpolant.Interpolant,
    opt: Any,
    sched: Any
):
    """
    Perform one optimisation step of `v` and `s` on a freshly sampled minibatch.

    Reads the notebook globals `base`, `target`, `v`, `s`, `loss_fac` and `counter`.
    Draws `prior_bs` base samples, `target_bs` target samples and `N_t` uniform times,
    evaluates `stochastic_interpolant.loss_sv`, backpropagates, records the gradient
    norms of `v` and `s`, and then steps the optimiser and the scheduler. While
    `counter < 5`, the wall-clock time of each phase is printed.

    Returns:
        Detached (total loss, v loss, s loss, v grad norm, s grad norm).
    """
    opt.zero_grad()

    # Minibatch: samples from both endpoint distributions plus random times.
    batch_x0 = base(prior_bs)
    batch_x1 = target(target_bs)
    batch_t = torch.rand(size=(N_t,))

    # Forward pass.
    tic = time.perf_counter()
    total_loss, (v_term, s_term) = stochastic_interpolant.loss_sv(
        v, s, batch_x0, batch_x1, batch_t, interpolant, loss_fac=loss_fac
    )
    loss_time = time.perf_counter() - tic

    # Backward pass; clipping with an infinite max-norm leaves gradients untouched and
    # just returns the total gradient norm.
    tic = time.perf_counter()
    total_loss.backward()
    v_norm = torch.tensor([torch.nn.utils.clip_grad_norm_(v.parameters(), float('inf'))])
    s_norm = torch.tensor([torch.nn.utils.clip_grad_norm_(s.parameters(), float('inf'))])
    backprop_time = time.perf_counter() - tic

    # Parameter update followed by the learning-rate schedule.
    tic = time.perf_counter()
    opt.step()
    sched.step()
    update_time = time.perf_counter() - tic

    if counter < 5:
        print(f'[Loss: {loss_time}], [Backprop: {backprop_time}], [Update: {update_time}].')

    return tuple(t.detach() for t in (total_loss, v_term, s_term, v_norm, s_norm))

### Define target

In [None]:

ndim = 2


def target(bs):
    """Draw `bs` samples from the 2-D checkerboard target as a (bs, 2) tensor.

    The first coordinate is uniform on [-2, 2). The second lies in [0, 1) or
    [-2, -1) (chosen by a fair coin) and is shifted up by 1 on every other unit
    column of the first coordinate. Both coordinates are finally scaled by 2.
    """
    col = torch.rand(bs) * 4 - 2                          # first coordinate, [-2, 2)
    band = torch.rand(bs) - torch.randint(2, (bs,)) * 2   # [0, 1) or [-2, -1)
    row = band + (torch.floor(col) % 2)                   # alternate columns shifted by +1
    return torch.stack((col, row), dim=1) * 2


# A large sample from the checkerboard target; reused by later plotting cells.
target_samples = grab(target(10000))

fig, ax = plt.subplots(figsize=(6, 6))
ax.hist2d(target_samples[:, 0], target_samples[:, 1], bins=100, range=[[-4, 4], [-4, 4]])
ax.set_title("Checker Target")
plt.show()

print("Batch Shape:", target_samples.shape)
# NOTE: `target` is only a sampler (it has no log_prob), so no exact target
# log-likelihood reference is computed for the likelihood plots.

### Define Base Distribution

In [None]:
# Base distribution: zero location and a second parameter of ones in each of the ndim coordinates.
base_loc = torch.zeros(ndim)
base_var = torch.ones(ndim)
base = prior.SimpleNormal(base_loc, base_var * 1.0)
base_samples = grab(base(20000))

In [None]:
# Overlay base and target samples. A single title is set (the original cell set a title
# twice, so the first was silently overwritten), and a legend is drawn so the labels
# are actually displayed.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(base_samples[:, 0], base_samples[:, 1], label='base', alpha=0.2)
ax.scatter(target_samples[:, 0], target_samples[:, 1], label='target', alpha=0.2)
ax.set_title("Base vs Target")
ax.legend()
plt.show()

### Define Interpolant

In [None]:
def gamma(t):
    """Noise scale gamma(t) = t (1 - t); zero at both endpoints t = 0 and t = 1."""
    return t * (1 - t)


def gamma_dot(t):
    """Time derivative of gamma: d/dt [t (1 - t)] = 1 - 2t."""
    return 1 - 2 * t


def gg_dot(t):
    """The product gamma(t) * gamma_dot(t)."""
    return gamma(t) * gamma_dot(t)

# gamma = lambda t: torch.sqrt(t*(1-t))
# gamma_dot = lambda t: (1/(2*torch.sqrt(t*(1-t)))) * (1 -2*t)
# gg_dot = lambda t: (1/2)*(1-2*t)

# gamma = lambda t: torch.sin(math.pi * t)**2
# gamma_dot = lambda t: 2*math.pi*torch.sin(math.pi * t)*torch.cos(math.pi*t)

# Linear interpolant with the gamma schedule defined above; It / dtIt are left as None.
interpolant = stochastic_interpolant.Interpolant(
    path='linear',
    gamma=gamma,
    gamma_dot=gamma_dot,
    gg_dot=gg_dot,
    It=None,
    dtIt=None,
)

### Define the networks $v$ and $s$, the optimizer and scheduler, and the training hyperparameters

In [None]:
# --- network architecture and optimiser ---
base_lr      = 2e-3
hidden_sizes = [150, 150, 150, 150]
in_size      = ndim + 1   # NOTE(review): presumably x (ndim coords) concatenated with t — confirm against make_fc_net usage
out_size     = ndim
inner_act    = 'relu'
final_act    = 'none'
print_model  = False


# Two identically shaped fully connected nets, `v` and `s`, trained jointly by one Adam optimiser.
v     = itf.fabrics.make_fc_net(hidden_sizes=hidden_sizes, in_size=in_size, out_size=out_size, inner_act=inner_act, final_act=final_act)
s     = itf.fabrics.make_fc_net(hidden_sizes=hidden_sizes, in_size=in_size, out_size=out_size, inner_act=inner_act, final_act=final_act)
opt   = torch.optim.Adam([*v.parameters(), *s.parameters()], lr=base_lr)
# Multiply the learning rate by 0.4 every 1500 scheduler steps (train_step calls sched.step() once per step).
sched = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=1500, gamma=0.4)


# --- training / evaluation settings ---
eps          = torch.tensor(0.5)  # eps forwarded to the SDE integrator
N_era        = 14    # outer loop iterations
N_epoch      = 500   # optimisation steps per era (N_era * N_epoch steps in total)
N_t          = 50    # number of times t sampled per training batch
plot_bs      = 5000  # number of samples to use when plotting
prior_bs     = 25    # number of samples from rho_0 (base) in batch
target_bs    = 100   # number of samples from rho_1 (target) in batch
metrics_freq = 50    # log metrics every `metrics_freq` steps; the likelihood estimates are not cheap
plot_freq    = 500   # plot every `plot_freq` steps
n_save       = 10    # forwarded to SDEIntegrator as `n_save`; presumably the number of saved states along a trajectory — confirm
loss_fac     = 4.0   # forwarded to loss_sv as `loss_fac`; presumably the relative weight of the s loss vs the v loss — confirm
n_likelihood = 20    # number of trajectories used to compute the SDE likelihood


if print_model:
    print("Here's the model v, s:", v, s)

In [None]:
# Histories appended to by log_metrics and plotted by make_plots.
data_dict = {
    'losses': [],
    'v_losses': [],
    's_losses': [],
    'v_grads': [],
    's_grads': [],
    'times': [],
    'logps_pflow': [],
    'logps_sdeflow': []
}

# 1-based global step. It must stay a module-level name because train_step reads it
# to decide whether to print per-phase timings.
counter = 1
for era in range(N_era):
    for epoch in range(N_epoch):
        # Bound at module level so the last step's values remain inspectable after the loop.
        loss, v_loss, s_loss, v_grad, s_grad = train_step(
            prior_bs, target_bs, N_t, interpolant, opt, sched
        )

        # (counter - 1) is the 0-based step, so metrics and plots also fire on the first step.
        if (counter - 1) % metrics_freq == 0:
            # NOTE(review): the likelihood batch size here is `prior_bs`, not `plot_bs` — intentional?
            log_metrics(v, s, interpolant, n_save, n_likelihood, prior_bs, v_loss,
                        s_loss, loss, v_grad, s_grad, eps, data_dict)

        if (counter - 1) % plot_freq == 0:
            make_plots(v, s, interpolant, n_save, n_likelihood, plot_bs, counter, metrics_freq, eps, data_dict)

        counter += 1

In [None]:
# Final evaluation, using 10x more SDE likelihood trajectories than during training.
eps = torch.tensor(0.5)

make_plots(
    v, s, interpolant,
    n_save=n_save,
    n_likelihood=10 * n_likelihood,
    likelihood_bs=plot_bs,
    counter=counter,
    metrics_freq=metrics_freq,
    eps=eps,
    data_dict=data_dict,
)


### Save models for plotting later

In [None]:
# logdir = "/mnt/home/malbergo/InterpolantFlow/notebooks/figs/models_for_figs/"

# model_fname_sd = os.path.join(logdir, f'checker_state_dict_v-2.pt')
# model_fname = os.path.join(logdir, f'checker_v-2.pt')
# torch.save(v.state_dict(), model_fname_sd)
# torch.save(v, model_fname)

# model_fname_sd = os.path.join(logdir, f'checker_state_dict_s-2.pt')
# model_fname = os.path.join(logdir, f'checker_s-2.pt')
# torch.save(s.state_dict(), model_fname_sd)
# torch.save(s, model_fname)

In [None]:
# torch.manual_seed(111)

# Sample a large batch from the trained SDE model and scatter-plot the final states.
with torch.no_grad():
    # Same constructor arguments as in compute_likelihoods (including the interpolant),
    # so this figure uses the SDE integrator configured the same way as the training diagnostics.
    sde_flow = itf.stochastic_interpolant.SDEIntegrator(
        v=v, s=s, dt=torch.tensor(1e-2), eps=eps, interpolant=interpolant, n_save=n_save
    )
    bs  = 75000
    x0s = base(bs)
    xfs = sde_flow.rollout_forward(x0s)
    xf = grab(xfs[-1].squeeze())

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    # No `c=` colour array is passed, so vmin/vmax would be ignored (matplotlib warns); omit them.
    ax.scatter(xf[:, 0], xf[:, 1], alpha=0.01)
    ax.set_xlim(-5, 5)
    ax.set_ylim(-6, 6)
    ax.set_title("Samples from v,w SDE Model, bs = " + str(bs) + ", eps = " + str(grab(eps)), fontsize=14)

In [None]:
# torch.manual_seed(111)

# Sample a large batch by integrating the probability-flow ODE and scatter-plot the final states.
with torch.no_grad():
    # Same constructor arguments as in compute_likelihoods. The Interpolant above is
    # built without an `eps`, so `interpolant` is passed in rather than `interpolant.eps`.
    # NOTE(review): compute_likelihoods calls pflow.rollout outside torch.no_grad();
    # if the log-density term needs autograd, this block may need the same treatment.
    pflow = itf.stochastic_interpolant.PFlowIntegrator(
        v=v, s=s, method='dopri5', interpolant=interpolant, n_step=3
    )
    bs  = 75000
    x0s = base(bs)
    xfs, _ = pflow.rollout(x0s)
    print(xfs.shape)
    xf = grab(xfs[-1].squeeze())
    print(xf.shape)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    # No `c=` colour array is passed, so vmin/vmax would be ignored (matplotlib warns); omit them.
    ax.scatter(xf[:, 0], xf[:, 1], alpha=0.01)
    ax.set_xlim(-5, 5)
    ax.set_ylim(-6, 6)
    ax.set_title("Samples from v,w ODE Model, bs = " + str(bs) + ", eps = " + str(grab(eps)), fontsize=14)

### Plot $x_t$ over time along the probability-flow trajectories

In [None]:
skip = 2  # stride between plotted frames of the saved trajectory

n_frames = len(xfs)
assert n_frames > 0, "no saved frames to plot"
# NOTE(review): assumes the saved frames are evenly spaced on [0, 1] — confirm against the integrator.
ts = np.linspace(0, 1, n_frames)

# Frames 1, 1 + skip, 1 + 2*skip, ...; the final frame is appended if the stride skips it.
frame_ids = list(range(1, n_frames, skip))
if not frame_ids or frame_ids[-1] != n_frames - 1:
    frame_ids.append(n_frames - 1)

ncol = len(frame_ids)
# squeeze=False keeps `axes` indexable even when only one panel is drawn.
fig, axes = plt.subplots(1, ncol, figsize=(ncol * 4, 4), squeeze=False)
axes = axes[0]

# Explicit colours so the text "legend" boxes below match the scatter points.
model_color, true_color = 'tab:blue', 'tab:orange'

for ind, i in enumerate(frame_ids):
    ax = axes[ind]
    time_slice = grab(xfs[i])
    ax.scatter(time_slice[:, 0], time_slice[:, 1], label='diffused from $N(0, I_2)$', alpha=0.02, color=model_color)
    ax.scatter(target_samples[:, 0], target_samples[:, 1], alpha=0.02, label=' true target mixture', color=true_color)
    ax.set_xticks([])
    ax.set_title('$t = %.2f$' % ts[i])
    if ind == 0:
        ax.set_ylabel("$x_2$", fontsize=16)
    else:
        ax.set_yticks([])

fig.text(x=0.5, y=0.0, s="$x_1$", fontsize=16)
# Raw strings: '\c' is not a valid Python escape sequence.
fig.text(0.91, 0.5, r'$\cdot$ true', bbox={'facecolor': true_color, 'pad': 4})
fig.text(0.91, 0.4, r'$\cdot$ model', bbox={'facecolor': model_color, 'pad': 4})
plt.show()