In [None]:
# # when running for the first time

!pip install git+https://github.com/kwotsin/mimicry.git
!pip3 install pyro-ppl

In [None]:
import numpy as np
import torch
from torch import nn
from torch.distributions import MultivariateNormal as MNormal
from torch.distributions import Categorical
import pyro
from pyro.infer import MCMC, HMC as pyro_hmc, NUTS as pyro_nuts
from matplotlib import pyplot as plt
import seaborn as sns
from typing import Optional, List, Tuple, Iterable, Callable, Union
from tqdm import tqdm, trange
from scipy.stats import gaussian_kde
from scipy.special import logsumexp
from pathlib import Path


from torchvision.utils import make_grid

# from torch_mimicry.nets import sngan

sns.set_theme('talk', style="white")

In [None]:
# Sampling configuration.
# NOTE(review): none of these three constants is referenced in the cells
# visible here; judging by the names they are presumably the number of
# parallel chains, the number of kept samples per chain and the number of
# discarded warm-up iterations -- confirm against the sampling cell.
N_CHAINS = 10
N_SAMPLES = 1000
BURN_IN = 100

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def ema(series: Iterable, n: int) -> List:
    """
    Returns an n-period exponential moving average of the time series.

    The first output value is the simple moving average (SMA) of the first
    ``n`` observations; every later value follows the recursion

        EMA(current) = (value(current) - EMA(prev)) * multiplier + EMA(prev)

    with ``multiplier = 2 / (n + 1)``.

    Args:
        series - iterable of numeric observations (lists, arrays and one-shot
            iterators such as generators are accepted)
        n - averaging period, a positive integer not larger than the series length

    Returns:
        list with ``len(series) - n + 1`` values: the seed SMA followed by one
        EMA value per remaining observation

    Raises:
        ValueError - if ``n`` is not positive or the series has fewer than
            ``n`` observations
    """
    # np.array() cannot consume a one-shot iterator (it would produce a 0-d
    # object array), so materialise anything that has no length first.
    if not hasattr(series, "__len__"):
        series = list(series)
    values = np.array(series)

    if n < 1:
        raise ValueError("n must be a positive integer, got {}".format(n))
    if len(values) < n:
        raise ValueError(
            "series must contain at least n={} values, got {}".format(n, len(values))
        )

    multiplier = 2 / float(1 + n)

    # Seed the recursion with the SMA of the first n observations.
    averages = [sum(values[:n]) / n]

    # A series of exactly n values simply yields the seed (the loop is empty),
    # instead of indexing past the end of the data.
    for value in values[n:]:
        previous = averages[-1]
        averages.append(((value - previous) * multiplier) + previous)

    return averages

In [None]:
def HMC(start,
        target,
        n_samples: int,
        burn_in: int,
        *,
        step_size: float,
        num_leapfrog_steps: int = 1,
        verbose: bool = False) -> torch.FloatTensor:
    """
    Hamiltonian Monte Carlo (thin wrapper around Pyro's HMC kernel)

    All rows of ``start`` are sampled jointly as one Pyro parameter named
    "points"; the potential energy is minus the sum of ``target.log_prob``
    over those rows.

    Args:
        start - starting points of shape [n_chains x dim]
        target - target distribution instance with method "log_prob"
        n_samples - number of samples to draw after warm-up
        burn_in - number of warm-up steps handed to Pyro (``warmup_steps``)
        step_size - initial leapfrog step size passed to the kernel
        num_leapfrog_steps - number of leapfrog steps per proposal (``num_steps``)
        verbose - whether to show Pyro's progress bar

    Returns:
        tensor of samples on the CPU, reshaped to [-1, *start.shape]
        (i.e. [n_samples, n_chains, dim] for a 2-D ``start``)
    """
    # Detach before cloning: calling requires_grad_(False) on a clone of a
    # tensor that requires grad fails, because such a clone is not a leaf.
    x = start.detach().clone()

    def energy(z):
        # Pyro passes the parameters as a dict keyed like `initial_params`.
        return -target.log_prob(z["points"]).sum()

    # Only the arguments below are set explicitly; every other kernel option
    # keeps Pyro's default.
    kernel = pyro_hmc(
        potential_fn=energy,
        step_size=step_size,
        num_steps=num_leapfrog_steps,
        full_mass=False,
    )

    mcmc_true = MCMC(
        kernel=kernel,
        num_samples=n_samples,
        initial_params={"points": x},
        warmup_steps=burn_in,
        # honour `verbose`, which used to be silently ignored
        disable_progbar=not verbose,
    )
    mcmc_true.run()

    q_true = mcmc_true.get_samples(group_by_chain=True)["points"]
    # Fold Pyro's leading chain axis into the sample axis.
    samples_true = q_true.reshape(-1, *start.shape).detach().cpu()

    return samples_true

In [None]:
def MALA(start: torch.FloatTensor,
        target,
        n_samples: int,
        burn_in: int,
        *,
        step_size: float,
        verbose: bool=False) -> Tuple[torch.FloatTensor, List]:
    """
    Metropolis-Adjusted Langevin Algorithm with Normal proposal

    Each iteration proposes ``y = x + step_size * grad log p(x) + sqrt(2 * step_size) * noise``
    independently for every chain and accepts it with the Metropolis-Hastings
    probability; rejected chains keep their current state.

    Args:
        start - starting points of shape [n_chains x dim]
        target - target distribution instance with method "log_prob"
        n_samples - number of last states from each chain to return
        burn_in - number of first states from each chain to throw away
        step_size - step size for drift term
        verbose - whether to show iterations' bar

    Returns:
        tensor of kept states with shape [n_samples, n_chains, dim] (on CPU),
        list with the mean acceptance rate of every iteration (burn-in included)
    """
    dim = start.shape[-1]
    std_normal = MNormal(
        torch.zeros(dim, device=start.device),
        torch.eye(dim, device=start.device),
    )
    noise_scale = (2 * step_size) ** .5

    # Current state of all chains, its log-density and the score at it.
    x = start.clone()
    x.requires_grad_(True)
    x.grad = None
    logp_x = target.log_prob(x)
    grad_x = torch.autograd.grad(logp_x.sum(), x)[0]

    kept_states = []
    acceptance_rate = []

    total_steps = n_samples + burn_in
    iterations = trange(total_steps) if verbose else range(total_steps)
    for step_id in iterations:
        # Langevin proposal
        noise = torch.randn_like(x)
        y = x + step_size * grad_x + noise * noise_scale

        logp_y = target.log_prob(y)
        grad_y = torch.autograd.grad(logp_y.sum(), y)[0]

        # Forward (x -> y) and backward (y -> x) proposal log-densities,
        # both expressed through the standardised displacement.
        log_q_forward = std_normal.log_prob(noise)
        log_q_backward = std_normal.log_prob(
            (x - y - step_size * grad_y) / noise_scale
        )

        log_ratio = logp_y + log_q_backward - logp_x - log_q_forward
        accept_prob = torch.clamp(log_ratio.exp(), max=1)
        accepted = torch.rand_like(accept_prob) < accept_prob

        # Overwrite only the accepted chains, keeping the cached log-density
        # and gradient in sync with the state.
        with torch.no_grad():
            x[accepted, :] = y[accepted, :]
            logp_x[accepted] = logp_y[accepted]
            grad_x[accepted] = grad_y[accepted]

        acceptance_rate.append(accepted.float().mean().item())
        if step_id >= burn_in:
            kept_states.append(x.detach().data.cpu().clone())

    return torch.stack(kept_states, 0), acceptance_rate

In [None]:
def ISIR(start: torch.FloatTensor,
        target,
        proposal,
        n_samples: int,
        burn_in: int,
        *,
        n_particles: int,
        verbose: bool=False) -> Tuple[torch.FloatTensor, List]:
    """
    Iterated Sampling Importance Resampling

    Every iteration draws ``n_particles - 1`` fresh particles per chain from
    the proposal, puts the current state in slot 0 and resamples the new state
    with probability proportional to the importance weights
    ``target / proposal``.

    Args:
        start - starting points of shape [n_chains x dim]
        target - target distribution instance with method "log_prob"
        proposal - proposal distribution instance with methods "log_prob" and "sample"
        n_samples - number of last samples from each chain to return
        burn_in - number of first samples from each chain to throw away
        n_particles - number of particles including one from previous step
        verbose - whether to show iterations' bar

    Returns:
        tensor of chains with shape [n_samples, n_chains, dim] (on CPU),
        list with, per iteration, the fraction of chains that moved to a
        freshly drawn particle

    Raises:
        ValueError - if ``n_particles`` is smaller than 1
    """
    if n_particles < 1:
        raise ValueError(
            "n_particles must be at least 1 (the current state counts as a particle)"
        )

    chains = []
    acceptance_rate = []
    range_ = trange if verbose else range

    # The sampler never differentiates through the target or the proposal, so
    # all evaluations run without autograd bookkeeping; otherwise a
    # neural-network target builds a graph for every particle batch.
    with torch.no_grad():
        x = start.clone()
        logp_x = target.log_prob(x)
        logq_x = proposal.log_prob(x)

        for step_id in range_(n_samples + burn_in):
            particles = proposal.sample((x.shape[0], n_particles - 1))
            # Slot 0 always holds the current state of each chain.
            logqs = torch.cat([logq_x[:, None], proposal.log_prob(particles)], 1)
            logps = torch.cat([logp_x[:, None], target.log_prob(particles)], 1)
            particles = torch.cat([x[:, None, :], particles], 1)

            # Target and proposal values may live on different devices; the
            # explicit .to() calls below keep each indexing op on one device.
            log_weights = logps - logqs.to(logps.device)
            indices = Categorical(logits=log_weights).sample()

            rows = np.arange(x.shape[0])
            x = particles[rows, indices.to(x.device)].detach()
            logp_x = logps[rows, indices].detach()
            logq_x = logqs[rows, indices.to(logq_x.device)].detach()

            # Index 0 means the chain kept its previous state.
            acceptance_rate.append((indices != 0).float().mean().item())
            if step_id >= burn_in:
                chains.append(x.detach().data.cpu().clone())

    chains = torch.stack(chains, 0)
    return chains, acceptance_rate

In [None]:
class TargetGAN(object):
    """
    Unnormalised latent-space target built from a generator/discriminator pair:

        log_prob(z) = gen.prior.log_prob(z) + dis_weight * dis(gen(z))

    The generator must expose ``latent_dim`` and a ``prior`` with a
    ``log_prob`` method.

    Args:
        gen - generator network mapping latents to samples
        dis - discriminator network scoring generated samples
        device - device on which the latents are evaluated
        batch_size - default number of latents pushed through the networks at once
        dis_weight - multiplier on the discriminator term (default 1.0,
            which reproduces the previously hard-coded constant)
    """

    def __init__(
            self,
            gen: nn.Module,
            dis: nn.Module,
            device: Union[str, int, torch.device],
            batch_size: int = 128,
            dis_weight: float = 1.0,
            ):
        self.gen = gen
        self.dim = gen.latent_dim
        self.dis = dis
        self.device = device
        self.batch_size = batch_size
        self.dis_weight = dis_weight

    def batch_log_prob(self, z: torch.FloatTensor):
        """Evaluates the log-density for one batch of latents of shape [batch, dim]."""
        z = z.to(self.device)  # move once, reuse for both terms
        # reshape(-1) gives one score per latent for discriminator outputs of
        # shape [batch] or [batch, 1], also when batch == 1.
        dis_score = self.dis(self.gen(z)).reshape(-1)
        return self.gen.prior.log_prob(z) + self.dis_weight * dis_score

    def log_prob(self, z: torch.FloatTensor, batch_size: Optional[int] = None):
        """
        Evaluates the log-density for latents of shape [..., dim].

        The leading dimensions are flattened, processed in chunks of
        ``batch_size`` (falling back to the instance default) and the result
        is reshaped back to ``z.shape[:-1]``.
        """
        z_flat = z.reshape(-1, self.dim)
        batch_size = batch_size or self.batch_size
        chunk_values = [
            self.batch_log_prob(chunk) for chunk in z_flat.split(batch_size, 0)
        ]
        return torch.cat(chunk_values, 0).reshape(z.shape[:-1])


Download the pre-trained DC-GAN model; the architecture is taken from the Mimicry repo.

In [None]:
# Install (or upgrade) gdown, which the two commands below invoke.
!pip install --upgrade --no-cache-dir gdown

# Run gdown for each checkpoint only when the corresponding .pth file is not
# already present in the working directory, so re-running the cell is cheap.
! if [ ! -f "cifar10_dcgan_G.pth" ]; then gdown 1k_DwjkprptuCXK_gJzFF9eUCHTD9vOUH; fi
! if [ ! -f "cifar10_dcgan_D.pth" ]; then gdown 1o6cXzpAPhGZxJOJseic8oGhsuilcLdLK; fi

DC-GAN architecture here

In [None]:
# DCGAN

from torch import nn


class DCGANGenerator(nn.Module):
    """
    DCGAN generator: a stack of transposed convolutions that turns a latent
    vector of size ``nz`` into an ``nc``-channel image with values in (-1, 1).

    Starting from a 1 x 1 latent "pixel" the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32; the final 1 x 1 transposed convolution only maps
    ``ngf`` channels to ``nc`` and keeps the 32 x 32 resolution.

    Args:
        ngpu - number of GPUs used for data-parallel forward passes
        nc - number of output image channels
        nz - latent dimension (also stored as ``z_dim`` / ``latent_dim``)
        ngf - base number of feature maps
        mean, std - accepted for interface compatibility; not used by this class
    """

    def __init__(
        self,
        ngpu=1,
        nc=3,
        nz=100,
        ngf=64,
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ):
        super().__init__()
        self.z_dim = self.latent_dim = nz
        self.ngpu = ngpu

        # (in_channels, out_channels, kernel, stride, padding) per upsampling stage
        stage_specs = [
            (nz, ngf * 8, 4, 1, 0),       # 1 x 1   -> 4 x 4
            (ngf * 8, ngf * 4, 4, 2, 1),  # 4 x 4   -> 8 x 8
            (ngf * 4, ngf * 2, 4, 2, 1),  # 8 x 8   -> 16 x 16
            (ngf * 2, ngf, 4, 2, 1),      # 16 x 16 -> 32 x 32
        ]

        # The layer order fixes the `main.<index>` parameter names that a
        # saved state_dict refers to, so it must stay exactly as it is.
        layers = []
        for in_ch, out_ch, kernel, stride, padding in stage_specs:
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))
        layers.append(
            nn.ConvTranspose2d(ngf, nc, kernel_size=1, stride=1, padding=0, bias=False)
        )
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Generates images from latents of shape [batch, nz] or [batch, nz, 1, 1]."""
        # Flat latents get two trailing singleton spatial dimensions.
        latent = input[..., None, None] if input.ndim == 2 else input

        if latent.is_cuda and self.ngpu > 1:
            return nn.parallel.data_parallel(self.main, latent, range(self.ngpu))
        return self.main(latent)


class DCGANDiscriminator(nn.Module):
    """
    DCGAN discriminator: four stride-2 convolutions followed by a final
    2 x 2, stride-2 convolution with a single output channel.

    Each kernel-4 / stride-2 / padding-1 convolution halves the spatial size.
    For the 32 x 32 images produced by DCGANGenerator this gives
    32 -> 16 -> 8 -> 4 -> 2, and the last convolution reduces 2 x 2 to one
    value per image. No activation follows the last convolution, so
    ``forward`` returns its raw outputs.

    Args:
        ngpu - number of GPUs used for data-parallel forward passes
        nc - number of input image channels
        ndf - base number of feature maps
        mean, std, output_layer - accepted for interface compatibility; not
            used by this class
    """

    def __init__(
        self,
        ngpu=1,
        nc=3,
        ndf=64,
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        output_layer="sigmoid",
    ):
        super().__init__()
        self.ngpu = ngpu

        # The layer order fixes the `main.<index>` parameter names that a
        # saved state_dict refers to, so it must stay exactly as it is.
        # First stage has no batch norm.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three further downsampling stages, doubling the width each time:
        # ndf -> 2*ndf -> 4*ndf -> 8*ndf
        width = ndf
        for _ in range(3):
            layers.append(nn.Conv2d(width, width * 2, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(width * 2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            width = width * 2

        # Collapse the remaining 2 x 2 window into a single-channel score.
        layers.append(nn.Conv2d(width, 1, 2, 2, 0, bias=False))

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Scores a batch of images and returns the outputs flattened to 1-D."""
        if input.is_cuda and self.ngpu > 1:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            scores = self.main(input)
        flat = scores.view(-1, 1)
        return flat.squeeze(1)


batch_size = 64
gen = DCGANGenerator().to(device)
dis = DCGANDiscriminator().to(device)

# map_location makes the checkpoints loadable on whichever device was picked
# above; without it, tensors saved from a GPU cannot be restored on a
# CPU-only machine.
# NOTE(review): torch.load unpickles the file contents -- only load
# checkpoints obtained from a trusted source.
gen.load_state_dict(torch.load('cifar10_dcgan_G.pth', map_location=device))
dis.load_state_dict(torch.load('cifar10_dcgan_D.pth', map_location=device))

# Inference only: put both networks (and their batch-norm layers) in eval mode.
gen.eval()
dis.eval();

# Standard normal prior over the latent space, attached to the generator
# because TargetGAN reads it from `gen.prior`.
prior = MNormal(
    torch.zeros(gen.latent_dim).to(device), torch.eye(gen.latent_dim).to(device)
    )

gen.prior = prior
gan_target = TargetGAN(gen, dis, device=device, batch_size=batch_size)

In [None]:
#implement Metropolis-Hastings here

In [None]:
#sample here

Code to compute the Inception score:

In [None]:
# Evaluation settings.
# NOTE(review): neither constant is referenced in the cells visible here;
# presumably the number of generated images used for the metrics and the
# number of sampler steps per image -- confirm against the sampling cell.
N_FID_SAMPLES = 1000
N_STEPS = 50

In [None]:
import os
import random
import time

import numpy as np
import tensorflow as tf
import torch

from torch_mimicry.datasets.image_loader import get_dataset_images
from torch_mimicry.metrics.fid import fid_utils
from torch_mimicry.metrics.inception_model import inception_utils
from torch_mimicry.metrics.compute_fid import compute_real_dist_stats


def _normalize_images(images):
    """
    Given a tensor of images, uses the torchvision
    normalization method to convert floating point data to integers. See reference
    at: https://pytorch.org/docs/stable/_modules/torchvision/utils.html#save_image
    The function uses the normalization from make_grid and save_image functions.
    Args:
        images (Tensor): Batch of images of shape (N, 3, H, W).
    Returns:
        ndarray: Batch of normalized images of shape (N, H, W, 3).
    """
    # Shift the image from [-1, 1] range to [0, 1] range.
    min_val = float(images.min())
    max_val = float(images.max())
    images.clamp_(min=min_val, max=max_val)
    images.add_(-min_val).div_(max_val - min_val + 1e-5)

    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    images = images.mul_(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 1).to(
        'cpu', torch.uint8).numpy()

    return images

In [None]:
from torch_mimicry.metrics.inception_model import inception_utils
from torch_mimicry.metrics.inception_score import inception_score_utils as tf_inception_score


def inception_score(images,
                    device=None,
                    splits=10,
                    log_dir='./log',
                    seed=0,
                    print_every=20):
    """
    Computes the inception score of a batch of already generated images.

    Args:
        images: Generated images, forwarded unchanged to
            ``tf_inception_score.get_inception_score``.
        device: Device forwarded to the scorer; when None, 'cuda:0' is used
            if CUDA is available and 'cpu' otherwise.
        splits (int): The number of splits to use for computing IS.
        log_dir (str): Directory under which the inception graph is created
            (in the 'metrics/inception_model' subfolder).
        seed (int): Seed applied to torch, random and numpy before scoring.
        print_every (int): Accepted for interface compatibility; not used here.
    Returns:
        Tuple (mean, standard deviation) of the inception score.
    """
    started_at = time.time()

    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    # Fix every random number generator involved so repeated runs agree.
    for set_seed in (torch.manual_seed, random.seed, np.random.seed):
        set_seed(seed)

    # Create the inception graph under log_dir before scoring.
    inception_utils.create_inception_graph(
        os.path.join(log_dir, 'metrics/inception_model'))

    score_mean, score_std = tf_inception_score.get_inception_score(
        images, splits=splits, device=device)

    elapsed = time.time() - started_at
    report = "INFO: Inception Score: {:.4f} ± {:.4f} [Time Taken: {:.4f} secs]"
    print(report.format(score_mean, score_std, elapsed))

    return score_mean, score_std