In [None]:
import torch
from torchdiffeq import odeint_adjoint as odeint
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import math
import time
import sys
sys.path.append('../')
from typing import Tuple, Any

%load_ext autoreload
%autoreload 2


import interflow as itf
import interflow.prior as prior
import interflow.fabrics
import interflow.stochastic_interpolant as stochastic_interpolant
import interflow.gmm as gmm

from torch import autograd
from functorch import jacfwd, vmap


# Pick the compute device: interflow is switched to CUDA only when a GPU is visible;
# otherwise nothing is changed and a notice is printed.
if not torch.cuda.is_available():
    print('No CUDA device found!')
else:
    print('CUDA available, setting default tensor residence to GPU.')
    itf.util.set_torch_device('cuda')

# Report the device interflow reports as current, then the PyTorch build in use.
print(itf.util.get_torch_device())
print("Torch version:", torch.__version__)

In [None]:
%matplotlib inline
import matplotlib as mpl

# Notebook-wide matplotlib styling, applied in a single call.
# Each key appears once here; the values are the ones this cell has always set.
mpl.rcParams.update({
    # grid on both major and minor ticks, with minor ticks visible on both axes
    'axes.grid': True,
    'axes.grid.which': 'both',
    'xtick.minor.visible': True,
    'ytick.minor.visible': True,
    # colours
    'axes.facecolor': 'white',
    'grid.color': '0.8',
    'grid.alpha': '0.5',
    # figure geometry and resolution
    'figure.figsize': (8, 4),
    'figure.dpi': 200,
    # typography
    'figure.titlesize': 12.5,
    'font.size': 12.5,
    'legend.fontsize': 12.5,
    'font.family': 'serif',
    'text.usetex': False,
})


## Utility functions

In [None]:
def grab(var):
    """Return the contents of a tensor as a NumPy array held in host (CPU) memory.

    The tensor is first detached from the autograd graph, then moved to the CPU,
    and finally converted with `.numpy()`.
    """
    detached = var.detach()
    on_host = detached.cpu()
    return on_host.numpy()


def compute_likelihoods(
    v: torch.nn.Module,
    s: torch.nn.Module,
    interpolant: stochastic_interpolant.Interpolant,
    n_save: int,
    n_likelihood: int,
    eps: torch.Tensor,
    bs: int
) -> Tuple[np.ndarray, torch.Tensor, np.ndarray, torch.Tensor]:
    """Draw samples from the probability flow and SDE models, and compute likelihoods.

    Relies on the notebook-level sampler `base` (called as `base(bs)` and `base.log_prob(x)`).

    Args:
        v, s: networks handed to both integrators.
        interpolant: interpolant handed to both integrators.
        n_save: forwarded to the SDE integrator as `n_save`.
        n_likelihood: number of SDE trajectories averaged in the SDE likelihood.
        eps: forwarded to the SDE integrator as `eps`.
        bs: number of base samples to push forward.

    Returns:
        (xf_sde, logpx_sdeflow, xf_pflow, logpx_pflow):
        final SDE samples (NumPy), SDE log-likelihoods (tensor),
        final probability-flow samples (NumPy), probability-flow log-likelihoods (tensor).
        Shapes follow the inline shape comments below.
    """
    sde_flow = stochastic_interpolant.SDEIntegrator(
        v=v, s=s, dt=torch.tensor(1e-2), eps=eps, interpolant=interpolant, n_save=n_save, n_likelihood=n_likelihood
    )

    pflow = stochastic_interpolant.PFlowIntegrator(v=v, s=s,
                                                  method='dopri5',
                                                  interpolant=interpolant,
                                                  n_step=3)

    with torch.no_grad():
        x0_tests  = base(bs)
        xfs_sde   = sde_flow.rollout_forward(x0_tests) # [n_save x bs x dim]
        xf_sde    = grab(xfs_sde[-1].squeeze())        # [bs x dim]

        # ([n_likelihood, bs, dim], [bs])
        x0s_sdeflow, dlogps_sdeflow = sde_flow.rollout_likelihood(xfs_sde[-1])
        # Flatten the trajectory and batch axes for log_prob. The trailing dimension is
        # inferred (-1) instead of being read from the notebook global `ndim`.
        log_p0s = torch.reshape(
            base.log_prob(x0s_sdeflow.reshape((n_likelihood*bs, -1))),
            (n_likelihood, bs)
        )
        # Average the base log-density over the n_likelihood trajectories.
        logpx_sdeflow = torch.mean(log_p0s, axis=0) - dlogps_sdeflow

    # NOTE(review): the probability-flow rollout is deliberately left outside no_grad, as in
    # the original; presumably the integrator needs autograd for its log-density change —
    # confirm before moving it.
    logp0                  = base.log_prob(x0_tests)            # [bs]
    xfs_pflow, dlogp_pflow = pflow.rollout(x0_tests)            # [n_save x bs x dim], [n_save x bs]
    logpx_pflow            = logp0 + dlogp_pflow[-1].squeeze()  # [bs]
    xf_pflow               = grab(xfs_pflow[-1].squeeze())      # [bs x dim]

    return xf_sde, logpx_sdeflow, xf_pflow, logpx_pflow


def log_metrics(
    v: torch.nn.Module,
    s: torch.nn.Module,
    exact_interpolant: gmm.GMMInterpolant,
    interpolant: stochastic_interpolant.Interpolant,
    n_save: int,
    n_likelihood: int,
    likelihood_bs: int,
    v_loss: torch.Tensor,
    s_loss: torch.Tensor,
    loss: torch.Tensor,
    v_grad: torch.Tensor,
    s_grad: torch.Tensor,
    eps: torch.Tensor,
    data_dict: dict
) -> None:
    """Append the current training and evaluation metrics to `data_dict`.

    Every quantity is reduced to its mean (as a NumPy scalar) before being appended.

    Args:
        v, s: networks forwarded to `compute_likelihoods` and `compute_kl`.
        exact_interpolant: exact GMM interpolant forwarded to `compute_kl`.
        interpolant: interpolant forwarded to both evaluation helpers.
        n_save: forwarded to `compute_likelihoods`.
        n_likelihood: number of SDE trajectories used by BOTH evaluation helpers.
        likelihood_bs: batch size used by both evaluation helpers.
        v_loss, s_loss, loss, v_grad, s_grad: tensors from the latest training step.
        eps: forwarded to both evaluation helpers.
        data_dict: metric store; must contain the list-valued keys 'v_losses', 's_losses',
            'losses', 'v_grads', 's_grads', 'kl_pflow', 'kl_sdeflow', 'logps_sdeflow'
            and 'logps_pflow'. It is mutated in place.
    """
    # log loss and gradient data
    training_metrics = (
        ('v_losses', v_loss),
        ('s_losses', s_loss),
        ('losses',   loss),
        ('v_grads',  v_grad),
        ('s_grads',  s_grad),
    )
    for key, value in training_metrics:
        data_dict[key].append(grab(value).mean())

    # compute likelihood data (the returned samples are not needed here)
    _, logpx_sdeflow, _, logpx_pflow = compute_likelihoods(
        v, s, interpolant, n_save, n_likelihood, eps, likelihood_bs)

    # compute kl; n_likelihood is forwarded explicitly so that the SDE KL uses the same
    # number of trajectories as the likelihoods above instead of compute_kl's default.
    kl_pflow, kl_sdeflow = compute_kl(
        v, s, exact_interpolant, interpolant, eps, likelihood_bs, n_likelihood=n_likelihood
    )

    evaluation_metrics = (
        ('kl_pflow',      kl_pflow),
        ('kl_sdeflow',    kl_sdeflow),
        ('logps_sdeflow', logpx_sdeflow),
        ('logps_pflow',   logpx_pflow),
    )
    for key, value in evaluation_metrics:
        data_dict[key].append(grab(value).mean())
    
    
def make_plots(
    v: torch.nn.Module,
    s: torch.nn.Module,
    interpolant: stochastic_interpolant.Interpolant,
    n_save: int,
    n_likelihood: int,
    likelihood_bs: int,
    counter: int,
    metrics_freq: int,
    eps: torch.Tensor,
    data_dict: dict
) -> None:
    """Make plots to visualize samples and evolution of the likelihood.

    Draws a 1x4 figure: loss history, SDE samples, probability-flow samples (both coloured
    by exp(log-likelihood)), and the logged KL history.

    Args:
        v, s, interpolant, n_save, n_likelihood, eps: forwarded to `compute_likelihoods`.
        likelihood_bs: number of samples drawn for the scatter plots.
        counter: current training step, printed for reference.
        metrics_freq: spacing (in steps) between logged metrics; used for the x axis.
        data_dict: metric store filled by `log_metrics`; only read here.
    """
    # compute likelihood and samples for SDE and probability flow.
    xf_sde, logpx_sdeflow, xf_pflow, logpx_pflow = compute_likelihoods(
        v, s, interpolant, n_save, n_likelihood, eps, likelihood_bs
    )

    ### plot the loss, test logp, and samples from interpolant flow
    fig, axes = plt.subplots(1,4, figsize=(16,4))
    print("EPOCH:", counter)
    # Report the most recently logged values instead of reading the training loop's
    # globals, so the function also works when called outside that loop.
    latest = {
        key: (data_dict[key][-1] if data_dict[key] else None)
        for key in ('losses', 'v_grads', 's_grads')
    }
    print("LOSS, GRAD:", latest['losses'], latest['v_grads'], latest['s_grads'])

    # plot loss over time.
    nsaves = len(data_dict['losses'])
    epochs = np.arange(nsaves)*metrics_freq
    axes[0].plot(epochs, data_dict['losses'], label=" v + s")
    axes[0].plot(epochs, data_dict['v_losses'], label="v")
    axes[0].plot(epochs, data_dict['s_losses'], label = "s" )
    axes[0].set_title("LOSS")
    axes[0].legend()

    # plot samples from SDE (grab already detaches, so no extra .detach() is needed).
    axes[1].scatter(
        xf_sde[:,0], xf_sde[:,1], vmin=0.0, vmax=0.05, alpha = 0.2, c=grab(torch.exp(logpx_sdeflow)))
    axes[1].set_xlim(-10,10)
    axes[1].set_ylim(-10,10)
    axes[1].set_title("Dims 0,1 of Samples from SDE", fontsize=14)

    # plot samples from pflow
    axes[2].scatter(
        xf_pflow[:,0], xf_pflow[:,1], vmin=0.0, vmax=0.05, alpha = 0.2, c=grab(torch.exp(logpx_pflow)))
    axes[2].set_xlim(-10,10)
    axes[2].set_ylim(-10,10)
    axes[2].set_title("Dims 0,1 of Samples from PFlow", fontsize=14)

    # plot the KL estimates.
    print( data_dict['kl_pflow'])
    axes[3].plot(epochs, data_dict['kl_pflow'],   label='pflow', color='purple')
    axes[3].plot(epochs, data_dict['kl_sdeflow'], label='sde',   color='red')
    axes[3].set_title(r"$KL(\rho_1(x) | \hat\rho(1,x))$")
    axes[3].legend(loc='best')
    # The upper limit follows the probability-flow KL only; skip it when nothing is logged yet.
    if data_dict['kl_pflow']:
        ymax = max(data_dict['kl_pflow'])
        axes[3].set_ylim(-5,ymax + ymax*.01)

    fig.suptitle(r"$\epsilon = $" + str(grab(eps)) + r" $n_{likelihood} = $" + str(n_likelihood), fontsize=16, y = 1.05)
    plt.show()
    
    


def train_step(
    prior_bs: int,
    target_bs: int,
    N_t: int,
    interpolant: stochastic_interpolant.Interpolant,
    opt: Any,
    sched: Any
):
    """
    Run one optimizer update on a freshly sampled batch.

    Besides its arguments, this reads the notebook-level names `base`, `target`, `v`, `s`,
    `loss_fac` and `counter`.

    Returns the detached tuple (total loss, v loss, s loss, v gradient norm, s gradient norm).
    """
    opt.zero_grad()

    # Fresh batch: base samples, target samples, and N_t uniform random times.
    x0s, x1s = base(prior_bs), target(target_bs)
    ts = torch.rand(size=(N_t,))

    # Forward pass (timed).
    loss_start = time.perf_counter()
    loss_val, (loss_v, loss_s) = stochastic_interpolant.loss_sv(
        v, s, x0s, x1s, ts, interpolant, loss_fac=loss_fac
    )
    loss_end = time.perf_counter()

    # Backward pass (timed). clip_grad_norm_ is called with an infinite threshold,
    # i.e. only for the total gradient norm it returns.
    backprop_start = time.perf_counter()
    loss_val.backward()
    v_norm = torch.nn.utils.clip_grad_norm_(v.parameters(), float('inf'))
    v_grad = torch.tensor([v_norm])
    s_norm = torch.nn.utils.clip_grad_norm_(s.parameters(), float('inf'))
    s_grad = torch.tensor([s_norm])
    backprop_end = time.perf_counter()

    # Parameter and learning-rate update (timed).
    update_start = time.perf_counter()
    opt.step()
    sched.step()
    update_end = time.perf_counter()

    # Timings are only reported for the first few steps.
    if counter < 5:
        print(f'[Loss: {loss_end - loss_start}], [Backprop: {backprop_end-backprop_start}], [Update: {update_end-update_start}].')

    outputs = (loss_val, loss_v, loss_s, v_grad, s_grad)
    return tuple(item.detach() for item in outputs)

In [None]:
def compute_kl(v,s, exact_interpolant, interpolant, eps = torch.tensor(2.0), bs = 500, n_likelihood=5):
    """
    Monte-Carlo estimate of KL(rho_1 || rho_hat_1) for the probability flow and the SDE model.

    `bs` samples are drawn from the exact target via `exact_interpolant.sample_rho1`, and
    the mean of `log rho_1(x) - log rho_hat_1(x)` over those samples is returned for each
    model. Relies on the notebook-level `base` for the base log-density.

    Args:
        v, s: networks handed to both integrators.
        exact_interpolant: provides `sample_rho1` and `log_rho1` for the exact target.
        interpolant: interpolant handed to both integrators.
        eps: forwarded to the SDE integrator as `eps`.
        bs: number of target samples.
        n_likelihood: number of SDE trajectories averaged in the SDE likelihood.

    Returns:
        (kl_pflow, kl_sdeflow): two scalar tensors.
    """
    sde_flow = stochastic_interpolant.SDEIntegrator(
        v=v, s=s, dt=torch.tensor(1e-2), eps=eps, interpolant=interpolant, n_save=5, n_likelihood=n_likelihood
    )

    pflow = stochastic_interpolant.PFlowIntegrator(v=v, s=s,
                                                  method='dopri5',
                                                  interpolant=interpolant,
                                                  n_step=3)

    x1s = exact_interpolant.sample_rho1(bs)
    log_rho1 = exact_interpolant.log_rho1(x1s)

    # Probability flow: integrate the target samples in reverse and combine the base
    # log-density of the end points with the accumulated log-density change.
    x0s_pflow, dlogp_pflow = pflow.rollout(x1s, reverse=True)   # [n_save x bs x dim], [n_save x bs]
    logp0                  = base.log_prob(x0s_pflow[-1])       # [bs]
    log_rho1_hat_ode       = logp0 - dlogp_pflow[-1].squeeze()  # [bs]

    # SDE: ([n_likelihood, bs, dim], [bs])
    with torch.no_grad():
        x0s_sdeflow, dlogps_sdeflow = sde_flow.rollout_likelihood(x1s)
        # The trailing dimension is inferred (-1) rather than read from the global `ndim`.
        log_p0s = torch.reshape(
            base.log_prob(x0s_sdeflow.reshape((n_likelihood*bs, -1))),
            (n_likelihood, bs)
        )
        log_rho1_hat_sde = torch.mean(log_p0s, axis=0) - dlogps_sdeflow

    return (log_rho1 - log_rho1_hat_ode).mean(), (log_rho1 - log_rho1_hat_sde).mean()


# Sanity check of compute_kl. The objects it needs are only created in later cells, so on a
# fresh kernel (Restart & Run All) an unconditional call here raises NameError. Run the
# check only once everything exists; re-run this cell after the model cells to see the values.
kl_check_requires = ('v', 's', 'exact_interpolant', 'interpolant', 'base')
kl_check_missing = [name for name in kl_check_requires if name not in globals()]
if kl_check_missing:
    print("Skipping compute_kl sanity check; not defined yet:", kl_check_missing)
else:
    print(compute_kl(v,s,exact_interpolant,interpolant))

In [None]:
# def compute_v_diff(v,s, exact_interpolant):
    

### Define target GMM

In [None]:
def setup_random_covs(N: int, d: int):
    """Draw N random symmetric positive-definite d x d matrices.

    Each matrix is (C^T C + 0.5 I) / sqrt(d), with C a d x d standard-normal draw.
    Returns a tensor of shape (N, d, d).
    """
    # Loop-invariant pieces: the diagonal shift and the 1/sqrt(d) normalisation.
    ridge = 0.5*torch.eye(d)
    normalizer = torch.sqrt(torch.tensor(d))

    covs = torch.zeros(N, d, d)
    for idx in range(N):
        factor = torch.randn(d, d)
        gram = factor.T @ factor
        covs[idx] = (gram + ridge)/normalizer

    return covs


def setup_random_means(N: int, d: int, scale: float):
    """Draw N random d-dimensional means: standard-normal entries multiplied by `scale`.

    Returns a tensor of shape (N, d).
    """
    standard_normal = torch.randn((N, d))
    return scale*standard_normal

In [None]:
# Problem definition: a single-component Gaussian base and a 5-component Gaussian-mixture
# target in `ndim` dimensions, with equal mixture weights inside each mixture.
ndim        = 50                                   # dimension of the sample space
N0          = 1                                    # number of components in the base mixture
N1          = 5                                    # number of components in the target mixture
scale       = 4                                    # multiplier applied to the random component means
gamma_type  = 'brownian'                           # passed to both interpolants
path        = 'linear'                             # passed to both interpolants
p0s         = (torch.ones(N0) / N0)                # uniform base weights
p1s         = (torch.ones(N1) / N1)                # uniform target weights
mu0s        = setup_random_means(N0, ndim, scale)
mu1s        = setup_random_means(N1, ndim, scale)
C0s         = torch.eye(ndim).unsqueeze(0)         # identity covariance, shape (1, ndim, ndim)
C1s         = setup_random_covs(N1, ndim)
print(C0s.shape)

# Use the GPU only when one is present, matching the CUDA check in the first cell; a
# hard-coded 'cuda' breaks on CPU-only machines.
# NOTE(review): assumes GMMInterpolant accepts device='cpu' — confirm.
exact_interpolant = gmm.GMMInterpolant(
    p0s, p1s, mu0s, mu1s, C0s, C1s, path, gamma_type,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

In [None]:


# Sampler for the target distribution: target(bs) returns exact_interpolant.sample_rho1(bs).
target = lambda bs: exact_interpolant.sample_rho1(bs)



from matplotlib import gridspec
# The figure is sized from ndim (2*ndim by ndim inches) and the grid is ndim x ndim,
# although only a 5 x 5 corner of that grid is filled in below.
unit_size = ndim / 1
fig = plt.figure(figsize = (unit_size * 2,unit_size * 1), constrained_layout=True)
gs = gridspec.GridSpec(ndim, ndim, figure=fig)
gs.update(wspace=0.5)

# Samples used for the scatter plots (also reused by the "Base vs Target" cell further down).
target_samples = grab(target(10000))
# Lower-triangular pair plot over the first five coordinates: panel (i, j), j < i, shows
# coordinate j on the x axis against coordinate i on the y axis.
# NOTE(review): the hard-coded 5 assumes the samples have at least 5 columns (ndim >= 5).
for i in range(5):
    for j in range(i):
        ax = plt.subplot(gs[i, j], )
        
        ax.scatter(target_samples[:,j],target_samples[:,i], alpha=0.02)
        ax.set_xlim(-15,15)
        ax.set_ylim(-15,15)
        # x ticks are hidden; y ticks are left visible.
        ax.set_xticks([])
        
# Manual subplot margins; right = 0.5 keeps the panels in the left half of the figure.
bottom = 0.01; left=0.01
top=1.-bottom; right = 1.-0.5
wspace= 0.0  # set to zero for no spacing
hspace= 0.2
plt.subplots_adjust(top=top, bottom=bottom, left=left, right=right, 
                    wspace=wspace, hspace=hspace)
# Caption placed in figure coordinates near the top-left corner.
fig.text(x = 0.01, y = 0.974, s= "Cross sections of " + str(ndim) + "-dimensional target GMM", fontsize = 12)

### Define Base Distribution

In [None]:
class GMMbase(itf.prior.Prior):
    """Base (rho_0) distribution backed by an exact GMM interpolant.

    Sampling and log-density evaluation are delegated to the wrapped interpolant's
    `sample_rho0` and `log_rho0`.
    """

    def __init__(self, exact_interpolant):
        """Store the interpolant that supplies the rho_0 sampler and log-density."""
        super().__init__()
        
        self.exact_interpolant = exact_interpolant
            
    def log_prob(self, x):
        """Return `exact_interpolant.log_rho0(x)`."""
        return self.exact_interpolant.log_rho0(x)
    
    def forward(self, bs):
        """Return `exact_interpolant.sample_rho0(bs)`, i.e. `bs` samples from rho_0."""
        return self.exact_interpolant.sample_rho0(bs)
    
    
# Notebook-level base distribution; the utility functions above read it as the global `base`.
base = GMMbase(exact_interpolant)
# Base samples as a NumPy array, used by the comparison scatter plot.
base_samples = grab(base(10000))



In [None]:
# Scatter the first two coordinates of the base and target samples on one set of axes.
fig = plt.figure(figsize=(3,3,))
plt.scatter(base_samples[:,0], base_samples[:,1], label='base', alpha=0.1)
plt.scatter(target_samples[:,0], target_samples[:,1], label='target', alpha=0.1)
# A single title: the earlier "Bimodal Target" title was always overwritten by this one.
plt.title("Base vs Target")
# Show the labels; without a legend the 'base' label was never displayed.
plt.legend()
plt.show()

### Define Interpolant

In [None]:
# Interpolant used for training and sampling, built with the same `path` and `gamma_type`
# that were passed to the exact GMM interpolant.
interpolant  = stochastic_interpolant.Interpolant(path=path, gamma_type=gamma_type)

### Define velocity field and optimizers

In [None]:
# Network architecture and optimizer settings.
base_lr      = 2e-3                  # initial Adam learning rate
hidden_sizes = [100, 100, 100, 100]  # widths of the hidden layers
in_size      = (ndim+1)              # presumably state plus one time input — TODO confirm
out_size     = (ndim)
inner_act    = 'relu'
final_act    = 'none'
print_model  = False                 # set True to print both networks below


# v and s are two networks with identical architecture, trained jointly by a single Adam
# optimizer. The scheduler multiplies the learning rate by 0.4 every 1500 scheduler steps
# (train_step calls sched.step() once per optimization step).
v     = itf.fabrics.make_fc_net(hidden_sizes=hidden_sizes, in_size=in_size, out_size=out_size, inner_act=inner_act, final_act=final_act)
s     = itf.fabrics.make_fc_net(hidden_sizes=hidden_sizes, in_size=in_size, out_size=out_size, inner_act=inner_act, final_act=final_act)
opt   = torch.optim.Adam([*v.parameters(), *s.parameters()], lr=base_lr)
sched = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=1500, gamma=0.4)


# Training and evaluation settings (the training loop runs N_era * N_epoch steps).
eps          = torch.tensor(0.5)  # passed to the SDE integrators as `eps`
N_era        = 14    # outer loop length
N_epoch      = 500   # inner loop length (optimization steps per era)
N_t          = 50    # number of random times drawn per batch (e.g. to make samples from rho_t)
plot_bs      = 2000  # number of samples to use when plotting
prior_bs     = 25    # number of samples from rho_0 in batch (also passed to log_metrics as its batch size)
target_bs    = 100   # number of samples from rho_1 in batch
metrics_freq = 50    # how often (in steps) to log metrics, e.g. if logp is not super cheap don't do it every time
plot_freq    = 500   # how often (in steps) to plot
n_save       = 10    # forwarded to the SDE integrator as n_save (original note: how often to checkpoint)
loss_fac     = 4.0   # forwarded to loss_sv; original note: "ratio of learning rates for w to v" — NOTE(review): 'w' presumably means s, confirm
n_likelihood = 20    # number of trajectories used to compute the SDE likelihood


if print_model:
    print("Here's the model v, s:", v, s)

In [None]:
# Metric store: one (initially empty) history list per tracked quantity.
data_dict = {
    key: []
    for key in (
        'losses', 'v_losses', 's_losses',
        'v_grads', 's_grads',
        'times',
        'logps_pflow', 'logps_sdeflow',
        'kl_pflow', 'kl_sdeflow',
    )
}

# `counter` is the 1-based global step; train_step reads it to decide whether to print timings.
counter = 1
for era in range(N_era):
    for epoch in range(N_epoch):
        loss, v_loss, s_loss, v_grad, s_grad = train_step(
            prior_bs, target_bs, N_t, interpolant, opt, sched
        )

        # Metrics are logged on steps 1, 1 + metrics_freq, 1 + 2*metrics_freq, ...
        if not (counter - 1) % metrics_freq:
            log_metrics(
                v, s, exact_interpolant, interpolant, n_save, n_likelihood, prior_bs,
                v_loss, s_loss, loss, v_grad, s_grad, eps, data_dict
            )

        # Plots follow the same pattern with plot_freq.
        if not (counter - 1) % plot_freq:
            make_plots(
                v, s, interpolant, n_save, n_likelihood, plot_bs,
                counter, metrics_freq, eps, data_dict
            )

        counter += 1