In [14]:
!pip install pyro-ppl
!pip install pot


Collecting pot
  Downloading POT-0.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.0/823.0 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pot
Successfully installed pot-0.9.3


In [11]:
import numpy as np
import torch
import pandas as pd
from torch import nn
from torch.distributions import MultivariateNormal as MNormal
from torch.distributions import Normal, Cauchy, Categorical
import pyro
from pyro.infer import MCMC, HMC as pyro_hmc, NUTS as pyro_nuts
from matplotlib import pyplot as plt
import seaborn as sns
from typing import Optional, List, Tuple, Iterable, Callable, Union
from tqdm.notebook import tqdm, trange
from scipy.stats import gaussian_kde


# Notebook-wide seaborn theme ("talk" context, white style) and the torch device string.
sns.set_theme(context='talk', style="white")


device = 'cpu'

In [7]:


import torch
from torch import nn, optim
from torch.nn import KLDivLoss
from torch.distributions import MultivariateNormal as MNormal
from torch.distributions import Categorical, Normal

import pyro
from pyro.infer import MCMC, HMC as pyro_hmc, NUTS as pyro_nuts

from tqdm.notebook import tqdm, trange
import numpy as np

import matplotlib.pyplot as plt

In [8]:
from scipy.stats import gaussian_kde


class Funnel(object):
    """
    Funnel distribution.

    “Slice sampling”. R. Neal, Annals of statistics, 705 (2003) https://doi.org/10.1214/aos/1056562461

    x_0 ~ N(0, 1) and, given x_0, every other coordinate x_i ~ N(0, exp(x_0)),
    i.e. has standard deviation exp(x_0 / 2). Both ``log_prob`` and ``sample``
    follow this definition, so samples are exact draws from the scored density.

    Args:
        num_dims - dimension
    """
    def __init__(self, num_dims=2):
        self.num_dims = num_dims
        self.normal_first = Normal(0, 1)

    @property
    def dim(self) -> int:
        return self.num_dims

    def log_prob(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """
        Returns:
            log p(x), reduced over the last (coordinate) dimension of ``x``
        """
        # Conditional law of x_1..x_{d-1} given x_0; zeros_like keeps x's dtype and device.
        normal_last = Normal(torch.zeros_like(x[..., 0]), torch.exp(x[..., 0] / 2.))
        # Move the coordinate axis to the front so each slice broadcasts against the batch-shaped Normal.
        tail = x[..., 1:].permute(-1, *range(x.ndim - 1))
        return normal_last.log_prob(tail).sum(0) + self.normal_first.log_prob(x[..., 0])

    def likelihood(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """
        Returns:
            p(x)
        """
        return torch.exp(self.log_prob(x))

    def plot_2d_countour(self, ax):
        """
        Visualizes contour plot of Funnel distribution using log p(x).

        Evaluates the 2-D density on a grid (x_0 in [-15, 15], x_1 in [-10, 10]) and
        draws x_1 on the horizontal axis, x_0 on the vertical axis.
        """
        x = np.linspace(-15, 15, 100)
        y = np.linspace(-10, 10, 100)
        X, Y = np.meshgrid(x, y)
        inp = torch.from_numpy(np.stack([X, Y], -1))
        Z = self.log_prob(inp.reshape(-1, 2)).reshape(inp.shape[:-1])

        ax.contour(Y, X, Z.exp(),
                   levels=3,
                   alpha=1., colors='midnightblue', linewidths=1)

    def visualize_dist(self, s=10000):
        """
        Visualizes Funnel distribution using sampled points, coloured by KDE density
        (x_1 horizontal, x_0 vertical).
        """
        # Generate points from funnel distribution; rows of `points` are coordinates.
        points = self.sample(s).numpy().T
        Y = points[0]
        X = points[1]

        # Calculate the point density
        XY = np.vstack([X, Y])
        Z = gaussian_kde(XY)(XY)

        # Sort the points by density, so that the densest points are plotted last
        idx = Z.argsort()
        X, Y, Z = X[idx], Y[idx], Z[idx]

        plt.scatter(X, Y, c=Z)
        plt.colorbar()
        plt.show()
        plt.close()

    def sample(self, num_samples: int) -> torch.Tensor:
        """
        Sample from the Funnel distribution.

        Returns:
            tensor of shape [num_samples, dim]; column 0 is x_0 ~ N(0, 1), the other
            columns are scaled by exp(x_0 / 2), the same conditional std as in ``log_prob``.
        """
        all_c = torch.randn((num_samples, self.dim))
        # Previously scaled by exp(x_0), which sampled a different funnel than log_prob scores.
        all_c[:, 1:] = all_c[:, 1:] * torch.exp(all_c[:, 0] / 2.)[:, None]
        return all_c

    def estimate_dist(self, s=100000):
        """
        Estimates mean and standard deviation of the Funnel distribution
        by sampling from it

        Returns:
            [per-coordinate mean, per-coordinate std] as numpy arrays
        """
        target_samp = self.sample(s)
        std = torch.std(target_samp, dim=0).numpy()
        m = torch.mean(target_samp, dim=0).numpy()
        return [m, std]


In [4]:

class Banana:
    """
    Banana-shaped distribution.

    Coordinates are used in (even, odd) pairs: x_{2k} ~ N(0, sigma^2) and
    x_{2k+1} | x_{2k} ~ N(b * x_{2k}^2 - b * sigma^2, 1).

    Args:
        b - curvature of each banana pair
        dim - dimension; should be even, because ``log_prob`` and ``sample`` pair
            every even coordinate with the following odd one
    """

    def __init__(self, b=0.02, dim=2):
        self.b = b
        self.dim = dim
        self.sigma = 10 # can be changed

    def log_prob(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """
        Returns:
            log p(x) up to an additive constant, summed over the last dimension of ``x``
        """
        even = np.arange(0, x.shape[-1], 2)
        odd = np.arange(1, x.shape[-1], 2)
        ll = -0.5 * (x[..., odd] - self.b * x[..., even]**2 + (self.sigma**2) * self.b)**2 - ((x[..., even])**2)/(2 * self.sigma**2)
        return ll.sum(-1)

    def likelihood(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """
        Returns:
            p(x) up to a multiplicative constant
        """
        return torch.exp(self.log_prob(x))

    def sample(self, s, seed: Optional[int] = 926) -> torch.Tensor:
        """
        Sample from the Banana distribution.

        Args:
            s - number of samples
            seed - seed for a reproducible draw (default 926, the value always used
                previously). The caller's global torch RNG state is left unchanged.
                ``None`` draws from (and advances) the current global RNG stream instead.

        Returns:
            tensor of shape [s, dim]
        """
        even = np.arange(0, self.dim, 2)
        odd = np.arange(1, self.dim, 2)
        var = torch.ones(self.dim)
        var[..., even] = self.sigma**2
        base_dist = MNormal(torch.zeros(self.dim), torch.diag(var))
        if seed is None:
            samples = base_dist.sample((s,))
        else:
            # Seeded draw inside a forked RNG: same samples as seeding globally, but without
            # resetting randomness for everything that runs afterwards in the notebook.
            with torch.random.fork_rng():
                torch.manual_seed(seed)
                samples = base_dist.sample((s,))
        samples[..., odd] += self.b * samples[..., even]**2 - self.b * self.sigma**2
        return samples

    def plot_2d_countour(self, ax):
        """
        Visualizes contour plot of Banana distribution using log p(x)
        on the grid x_0 in [-20, 20], x_1 in [-10, 10].
        """
        x = np.linspace(-20, 20, 100)
        y = np.linspace(-10, 10, 100)
        X, Y = np.meshgrid(x, y)
        inp = torch.from_numpy(np.stack([X, Y], -1))
        Z = self.log_prob(inp.reshape(-1, 2)).reshape(inp.shape[:-1])
        ax.contour(X, Y, Z.exp(),
                   levels=5,
                   alpha=1., colors='midnightblue', linewidths=1)

    def visualize_dist(self, s=1000):
        """
        Visualizes Banana distribution using sampled points (first two coordinates),
        coloured by KDE density.
        """
        # Generate points from distribution
        points = self.sample(s)
        X = points[:, 0]
        Y = points[:, 1]

        # Calculate the point density
        XY = torch.stack([points[:, 0], points[:, 1]], dim=0).numpy()
        Z = gaussian_kde(XY)(XY)

        # Sort the points by density, so that the densest points are plotted last
        idx = Z.argsort()
        X, Y, Z = X[idx], Y[idx], Z[idx]

        plt.scatter(X, Y, c=Z)
        plt.colorbar()
        plt.show()
        plt.close()

    def estimate_dist(self, s=10000000):
        """
        Estimates mean and standard deviation of the Banana distribution
        by sampling from it

        Returns:
            [per-coordinate mean, per-coordinate std] as numpy arrays
        """
        target_samp = self.sample(s)
        std = torch.std(target_samp, dim=0).numpy()
        m = torch.mean(target_samp, dim=0).numpy()
        return [m, std]

In [12]:
def MALA(start: torch.FloatTensor,
        target,
        n_samples: int,
        burn_in: int,
        *,
        step_size: float,
        verbose: bool=False) -> Tuple[torch.FloatTensor, List]:
    """
    Metropolis-Adjusted Langevin Algorithm with Normal proposal

    Proposal: y = x + step_size * grad log p(x) + sqrt(2 * step_size) * xi, xi ~ N(0, I),
    followed by an independent Metropolis-Hastings accept/reject step for each chain.

    Args:
        start - starting points of shape [n_chains, dim]
        target - target distribution instance with a differentiable method "log_prob"
            returning one value per chain
        n_samples - number of last samples from each chain to return
        burn_in - number of first samples from each chain to throw away
        step_size - step size for drift term
        verbose - whether to show iterations' bar

    Returns:
        tensor of chains with shape [n_samples, n_chains, dim] (dtype/device of ``start``),
        acceptance rates for each iteration (burn-in iterations included)
    """
    def _logp_and_grad(z):
        # log p and its gradient at z, detached so no autograd graph outlives the call.
        z = z.detach().requires_grad_(True)
        logp = target.log_prob(z)
        grad = torch.autograd.grad(logp.sum(), z)[0]
        return logp.detach(), grad

    noise_scale = (2 * step_size) ** .5
    chains = []
    acceptance_rate = []

    x = start.detach().clone()
    logp_x, grad_x = _logp_and_grad(x)

    range_ = trange if verbose else range
    for step_id in range_(n_samples + burn_in):
        noise = torch.randn_like(x)
        y = x + step_size * grad_x + noise * noise_scale
        logp_y, grad_y = _logp_and_grad(y)

        # Standardised noise that would map y back to x under the reverse proposal.
        noise_back = (x - y - step_size * grad_y) / noise_scale
        # log q(x|y) - log q(y|x): the Gaussian normalising constants cancel, and computing it
        # directly works for any dtype/device of `start` (a fixed float32 MNormal did not).
        log_q_ratio = 0.5 * (noise.pow(2).sum(-1) - noise_back.pow(2).sum(-1))

        accept_prob = torch.clamp((logp_y + log_q_ratio - logp_x).exp(), max=1)
        mask = torch.rand_like(accept_prob) < accept_prob

        x[mask] = y[mask]
        logp_x[mask] = logp_y[mask]
        grad_x[mask] = grad_y[mask]

        acceptance_rate.append(mask.float().mean().item())
        if step_id >= burn_in:
            chains.append(x.clone())
    if not chains:
        # n_samples == 0: torch.stack would raise on an empty list.
        return x.new_empty((0, *x.shape)), acceptance_rate
    return torch.stack(chains, 0), acceptance_rate

In [17]:
# Fetch the course helper repository into the working directory; it presumably provides the
# `metrics` and `total_variation` modules imported in the next cell.
# NOTE(review): the clone is unpinned (default branch) and `mv` overwrites same-named local
# files without asking; pin a commit for reproducible results.
!git clone https://github.com/svsamsonov/Practical_task
!mv Practical_task/* .
!rm -r Practical_task

Cloning into 'Practical_task'...
remote: Enumerating objects: 16, done.[K
remote: Counting objects: 100% (16/16), done.[K
remote: Compressing objects: 100% (12/12), done.[K
remote: Total 16 (delta 1), reused 6 (delta 0), pack-reused 0[K
Receiving objects: 100% (16/16), 69.74 KiB | 11.62 MiB/s, done.
Resolving deltas: 100% (1/1), done.


In [18]:

# for metrics
# import jax
# import ot
from metrics import ESS, acl_spectrum
from total_variation import (
    average_total_variation,
)

In [19]:
def compute_metrics(
    xs_true,
    xs_pred,
    name=None,
    n_samples=300,
    n_steps = 50,
    scale=1.0,
    trunc_chain_len=None,
    ess_rar=1,
):
    """
    Calculates metrics:
    ESS (Effective sample size),
    EMD (Earth mover’s distance),
    ESTV mean and std (Empirical sliced total variation distance)

    Args:
        xs_true - reference samples, indexed as [sample, dim] for the EMD
        xs_pred - generated chains, indexed as [step, chain, dim]
        name - optional header printed before the results
        n_samples, n_steps - forwarded to ``average_total_variation``
        scale - both sample sets are divided by this before the EMD
        trunc_chain_len - keep only the last ``trunc_chain_len`` steps for TV/EMD
            (``None`` keeps the whole chain)
        ess_rar - thinning stride applied before the ESS computation

    Returns:
        dict with "ess"; also "tv_mean", "tv_conf_sigma" and "emd" when those succeed
    """
    # Imported locally: the notebook-level imports for these are commented out, so a fresh
    # kernel would otherwise raise NameError here.
    import jax
    import ot

    torch.manual_seed(926)
    metrics = dict()
    key = jax.random.PRNGKey(0)

    # ESS on the (optionally thinned) chain, centred per coordinate
    xs_thin = xs_pred[::ess_rar]
    ess = ESS(
        acl_spectrum(
            xs_thin - xs_thin.mean(0)[None, ...],
        ),
    ).mean()
    metrics["ess"] = ess

    # ESTV on the last `trunc_chain_len` steps (`-None` would raise, so skip when unset)
    if trunc_chain_len is not None:
        xs_pred = xs_pred[-trunc_chain_len:]
    try:
        tracker = average_total_variation(
            key,
            xs_true,
            xs_pred,
            n_steps=n_steps,
            n_samples=n_samples,
        )
        metrics["tv_mean"] = tracker.mean()
        metrics["tv_conf_sigma"] = tracker.std_of_mean()

        # EMD, averaged over chains
        n_chains = xs_pred.shape[1]
        metrics["emd"] = 0
        for b in range(n_chains):
            M = ot.dist(xs_true / scale, xs_pred[:, b, :] / scale)
            emd = ot.lp.emd2([], [], M, numItermax=1_000_000)
            metrics["emd"] += emd / n_chains

        # Print results
        mean = metrics["tv_mean"]
        std = metrics["tv_conf_sigma"]
        emd = metrics["emd"]

        if name is not None:
            print(f"===={name}====")
        print(
            f"TV distance. Mean: {mean:.3f}, Std: {std:.3f}. \nESS: {ess:.3f} \nEMD: {emd:.3f}",
        )

    except Exception as exc:
        # Best-effort: keep the ESS result, but report the real failure instead of hiding it.
        print("During this try, only one distinct point is generated.")
        print(f"(metric computation failed: {type(exc).__name__}: {exc})")

    return metrics

In [21]:
import normflows as nf

In [103]:
device = "cpu"
N_CHAINS = 1
dim = 10

# Fixed (trainable=False) 2-mode Gaussian mixture in `dim` dimensions, used below to draw the MALA starting points.
proposal = nf.distributions.base.GaussianMixture(n_modes=2, dim=dim, trainable=False)

start_gmm = (
    proposal.sample(N_CHAINS)
    .detach()
    .cpu()
)

In [104]:
# Samples kept per chain, and initial iterations discarded as burn-in.
N_SAMPLES = 1_000
BURN_IN = 200

In [105]:
target = Banana(dim=dim, b=0.02)  # banana target with curvature b=0.02


In [106]:
MALA_STEP_SIZE = 0.2  # Langevin step size passed to MALA
chains, acceptance_rates_mala = MALA(start_gmm, target, N_SAMPLES, BURN_IN, step_size=MALA_STEP_SIZE)
# These are MALA (not iSIR) acceptance rates; the old name is kept as an alias in case later cells use it.
acceptance_rates_isir = acceptance_rates_mala


In [107]:
# Calculate metrics against samples drawn directly from the target
true_chain = target.sample(N_SAMPLES)
metrics = compute_metrics(
    true_chain.numpy(),  # True samples
    chains.numpy(),  # Generated samples
    name=f"Banana MALA dim={dim}",  # interpolate `dim` so the label cannot drift from the config
    n_samples=N_SAMPLES,
    trunc_chain_len=N_SAMPLES,
    ess_rar=1,
)

  0%|          | 0/50 [00:00<?, ?it/s]

====Banana MALA dim=10====
TV distance. Mean: 0.316, Std: 0.012. 
ESS: 0.011 
EMD: 338.175
