In [17]:
import numpy as np

import torch
from torch import nn
from torch import optim

import tqdm

import tkinter

from IPython.display import clear_output

import sys
sys.path.append('..')

import pyro
from samplers import mala, i_sir, ex2_mcmc

from cifar10_experiments.models import Generator, Discriminator

from sampling_utils.adaptive_mc import CISIR, Ex2MCMC, FlowMCMC
from sampling_utils.adaptive_sir_loss import MixKLLoss
from sampling_utils.distributions import (
    Banana,
    CauchyMixture,
    Distribution,
    Funnel,
    HalfBanana,
    IndependentNormal,
)
from sampling_utils.ebm_sampling import MALA
from sampling_utils.flows import RNVP
from sampling_utils.metrics import ESS, acl_spectrum
from sampling_utils.total_variation import (
    average_total_variation,
)

In [18]:
# Device on which the networks and all latent tensors are placed.
device = 'cpu'
# Dimensionality of the generator's latent space.
lat_size = 100

In [19]:
# Build the generator / discriminator pair and place both on `device`.
gen_cifar10 = Generator(lat_size).to(device)
discr_cifar10 = Discriminator().to(device)

# Latent prior: zero-mean Gaussian with identity covariance.
prior_cifar10 = torch.distributions.MultivariateNormal(
    loc=torch.zeros(lat_size).to(device),
    covariance_matrix=torch.eye(lat_size).to(device),
)

In [20]:
# Restore the pretrained weights and switch each network to inference mode.
# NOTE: torch.load unpickles the checkpoint file, so only load weights that
# come from a trusted source.
gen_cifar10.load_state_dict(
    torch.load('./weights/generator.pkl', map_location='cpu'),
)
gen_cifar10.eval()

discr_cifar10.load_state_dict(
    torch.load('./weights/discriminator.pkl', map_location='cpu'),
)
# Trailing semicolon keeps the notebook from echoing the module repr.
discr_cifar10.eval();

In [21]:
def get_energy_wgan(z):
    """
    Energy of the latent point(s) z:

        E(z) = -0.1 * D(G(z)) - log p_prior(z)

    where G is gen_cifar10, D is discr_cifar10 and p_prior is prior_cifar10.
    Both terms are squeezed before being combined.
    """
    generated = gen_cifar10(z)
    critic_term = -0.1 * discr_cifar10(generated).squeeze()
    log_prior = prior_cifar10.log_prob(z).squeeze()
    return critic_term - log_prior

def log_target_dens(x):
    """
    Evaluate the target log-density at the latent points x.

    The value is the negative of get_energy_wgan(x).

    Input:
    x - array-like of latent points; converted to a float32 tensor on `device`
    Output:
    numpy array holding -get_energy_wgan(x)
    """
    z = torch.as_tensor(x, dtype=torch.float32, device=device)
    # Only the values are needed here, so skip building the autograd graph.
    with torch.no_grad():
        log_dens = -get_energy_wgan(z)
    return log_dens.detach().cpu().numpy()

def grad_log_target_dens(x):
    """
    Gradient of the target log-density with respect to the latent points x.

    Input:
    x - array-like of latent points; converted to a float32 tensor on `device`
    Output:
    numpy array with the same shape as x holding d(-get_energy_wgan(x)) / dx
    """
    z = (
        torch.as_tensor(x, dtype=torch.float32, device=device)
        .detach()
        .clone()
        .requires_grad_(True)
    )
    log_dens = -get_energy_wgan(z)
    # Differentiating the sum is equivalent to back-propagating a vector of
    # ones, but it also works when the output has been squeezed to a scalar
    # (batch of one) and needs no helper tensor on the right device.
    # torch.autograd.grad returns the gradient directly instead of
    # accumulating into the .grad buffers of the network parameters.
    (grad,) = torch.autograd.grad(log_dens.sum(), z)
    return grad.detach().cpu().numpy()

In [22]:
# Smoke test: evaluate the log-density on two random latent vectors.
log_target_dens(np.random.randn(2, lat_size))

array([-292.15216, -312.44003], dtype=float32)

In [23]:
# Smoke test: the gradient should have the same shape as the input batch.
grad_log_target_dens(np.random.randn(2, lat_size)).shape

(2, 100)

### Target distribution

In [24]:
class distr:
    """
    Target distribution defined through the energy function get_energy_wgan.

    Instances are callable; calling one is the same as calling log_prob.
    """

    def __init__(self, beta = 1.0):
        super().__init__()
        # Stored on the instance; none of the methods of this class read it.
        self.beta = beta

    def log_prob(self, z):
        """
        The method returns target logdensity, estimated at point z
        Input:
        z - datapoint
        Output:
        log_density: log p(z) = -get_energy_wgan(z)
        """
        return -get_energy_wgan(z)

    def energy(self, z):
        """
        The method returns target energy, estimated at point z
        Input:
        z - datapoint
        Output:
        energy = -log p(z) = get_energy_wgan(z)
        """
        # The energy is the negative log-density, so it must have the
        # opposite sign of log_prob.
        return get_energy_wgan(z)

    def __call__(self, z):
        """Alias for log_prob(z)."""
        return self.log_prob(z)

### Flex2MCMC parameters

In [33]:
# Sampler hyperparameters; the nested "flow" entry holds the settings of the
# normalizing flow and of its training run.
params_flex = {
    "N": 5,
    "grad_step": 0.2,
    "adapt_stepsize": False,
    "corr_coef": 0.0,
    "bernoulli_prob_corr": 0.0,
    "mala_steps": 0,
    "flow": {
        "num_flows": 5,  # number of normalizing layers
        "lr": 1e-3,  # learning rate
        "batch_size": 5,
        "n_steps": 100,
    },
}

# Target distribution.
beta = 1.0
target = distr(beta)

# Gaussian proposal: zero location and a constant scale in every coordinate.
scale_proposal = 1.0
loc_proposal = torch.zeros(lat_size).to(device)
scale_proposal = scale_proposal * torch.ones(lat_size).to(device)
proposal = IndependentNormal(
    dim=lat_size,
    loc=loc_proposal,
    scale=scale_proposal,
    device=device,
)

In [None]:
# Train the normalizing flow that serves as the sampler's adaptive proposal.
pyro.set_rng_seed(42)
mcmc = Ex2MCMC(**params_flex, dim=lat_size)
# Silence the sampler while training; the previous setting is restored below.
verbose = mcmc.verbose
mcmc.verbose = False
flow = RNVP(params_flex["flow"]["num_flows"], dim=lat_size)
flow_mcmc = FlowMCMC(
    target,
    proposal,
    flow,
    mcmc,
    batch_size=params_flex["flow"]["batch_size"],
    lr=params_flex["flow"]["lr"],
)
flow.train()
try:
    out_samples, nll = flow_mcmc.train(
        n_steps=params_flex["flow"]["n_steps"],
    )
finally:
    # Restore the verbosity flag even if training raises, so a failed run
    # does not leave the sampler silently muted.
    mcmc.verbose = verbose

# Check every parameter tensor of the flow for NaNs, not just one entry.
assert not any(
    torch.isnan(param).any().item() for param in flow.parameters()
), "NaN encountered in the flow parameters after training"

flow.eval()
mcmc.flow = flow

 10%|█         | 10/100 [00:10<01:30,  1.00s/it]

In [None]:
# Sample with the sampler that uses the trained normalizing flow.
n_steps_flex2 = 1000

# `batch_size` and `dim` are not defined by any earlier cell of this notebook,
# so define them here to keep the cell runnable on a fresh kernel.
# The chain batch size reuses the flow-training batch size; adjust if needed.
batch_size = params_flex["flow"]["batch_size"]
dim = lat_size


def _chain_to_array(mcmc_out):
    """Stack the sampler output into an array of shape (-1, batch_size, dim)."""
    # The sampler may return either the chain itself or a tuple whose first
    # element is the chain.
    chain = mcmc_out[0] if isinstance(mcmc_out, tuple) else mcmc_out
    return np.array(
        [state.detach().cpu().numpy() for state in chain],
    ).reshape(-1, batch_size, dim)


pyro.set_rng_seed(42)
start = proposal.sample((batch_size,))
mcmc.N = 20
mcmc.grad_step = 0.1

# First run: no MALA steps.
mcmc.mala_steps = 0
out = mcmc(start, target, proposal, n_steps=n_steps_flex2)
sample = _chain_to_array(out)
sample_flex2_new = sample

# Second run from the same starting points, now with 50 MALA steps.
mcmc.mala_steps = 50
out_new = _chain_to_array(
    mcmc(start, target, proposal, n_steps=n_steps_flex2),
)
sample_flex2_final = out_new
print(sample_flex2_final.shape)