# Train GRBMs on GMM, MNIST, CelebA, and CIFAR10 

Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Required packages:

*   pytorch_fid
*   seaborn
*   pandas
*   matplotlib

To change the hyperparameters, edit the `get_config` function below.

Important hyperparameters include: 

*   CD_step
*   hidden_size
*   inference_method: ['Gibbs', 'Langevin', 'Langevin-Gibbs'] 

In [None]:
from __future__ import print_function
import logging
from collections import defaultdict
import seaborn as sns
import pandas as pd
import os
import pdb
import PIL
import imageio
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torchvision.datasets.utils import download_file_from_google_drive, check_integrity, verify_str_arg
import numpy as np
from tqdm import tqdm
from PIL import Image  # NOQA
import matplotlib
from typing import Any, List, Union, Optional, Tuple, Callable
from functools import partial
from threading import Thread
from mpl_toolkits.axes_grid1 import make_axes_locatable

matplotlib.use('Agg')
import matplotlib.pyplot as plt  # NOQA

# Commented out IPython magic to ensure Python compatibility.
# %matplotlib inline
# Small positive constant (not referenced in the visible part of this file;
# presumably a numerical-stability epsilon -- TODO confirm usage further down).
EPS = 1e-7
# Every floating-point tensor created without an explicit dtype is float64
# from this point on (model parameters, sampled noise, ...).
torch.set_default_dtype(torch.float64)
# Global seaborn look used by all plots produced by the helpers below.
sns.set_theme(style="darkgrid")
"""Here are some helper functions for training and visualization."""


class TensorEncoder(json.JSONEncoder):
  """JSON encoder that can serialize ``torch.Tensor`` values.

  Tensors are moved to the CPU and emitted as (nested) Python lists/scalars;
  every other object is handled by the stock ``json.JSONEncoder``.
  """

  def default(self, obj):
    """Return a JSON-serializable stand-in for ``obj``.

    Raises:
      TypeError: (from the base class) if ``obj`` is neither a tensor nor a
        natively serializable object.
    """
    if not isinstance(obj, torch.Tensor):
      return super().default(obj)
    as_array = obj.cpu().numpy()
    return as_array.tolist()


def setup_logging(log_level, log_file, logger_name="exp_logger"):
  """Configure file + console logging and return the named logger.

  The root logger is configured (via ``logging.basicConfig``) to write to
  ``log_file`` (opened in overwrite mode), and a console ``StreamHandler`` is
  attached to the logger called ``logger_name``.

  Args:
    log_level: name of a ``logging`` level, case-insensitive (e.g. "info").
    log_file: path of the log file handed to ``basicConfig``.
    logger_name: name of the logger that receives the console handler.

  Returns:
    The logger registered under ``logger_name``.

  Raises:
    ValueError: if ``log_level`` does not name an integer logging level.
  """
  level = getattr(logging, log_level.upper(), None)
  if not isinstance(level, int):
    raise ValueError("Invalid log level: %s" % log_level)

  file_format = ("%(levelname)-5s | %(asctime)s | File %(filename)-20s | "
                 "Line %(lineno)-5d | %(message)s")
  logging.basicConfig(filename=log_file,
                      filemode="w",
                      format=file_format,
                      datefmt="%m/%d/%Y %I:%M:%S %p",
                      level=level)

  # Console output uses a slightly different layout than the log file.
  console_format = ("%(levelname)-5s | %(asctime)s | %(filename)-25s | "
                    "line %(lineno)-5d: %(message)s")
  stream_handler = logging.StreamHandler()
  stream_handler.setLevel(level)
  stream_handler.setFormatter(logging.Formatter(console_format))

  # The handler goes on the named logger, not on the root logger.
  logging.getLogger(logger_name).addHandler(stream_handler)

  return get_logger(logger_name)


def get_logger(logger_name="exp_logger"):
  """Return the ``logging.Logger`` registered under ``logger_name``."""
  logger = logging.getLogger(logger_name)
  return logger


def cosine_schedule(eta_min=0, eta_max=1, T=10):
  """Build a half-cosine decay schedule of length ``T``.

  Step ``t`` (``0 <= t < T``) gets the value
  ``eta_min + (eta_max - eta_min) * (1 + cos(t * pi / T)) / 2``, so the first
  entry equals ``eta_max`` and the values decrease towards (but never reach)
  ``eta_min``.

  Returns:
    A list with ``T`` floats (empty when ``T`` is 0).
  """
  span = eta_max - eta_min
  schedule = []
  for step in range(T):
    cos_term = 1 + math.cos(step * math.pi / T)
    schedule.append(eta_min + span * cos_term / 2)
  return schedule


def fig2img(fig):
  """Render a Matplotlib figure into an in-memory PIL image.

  The figure is saved (tight bounding box) into a bytes buffer which is then
  opened with ``PIL.Image.open``; the returned image is backed by that buffer.
  """
  from io import BytesIO

  buffer = BytesIO()
  fig.savefig(buffer, bbox_inches='tight')
  buffer.seek(0)  # rewind so PIL reads from the start of the written data
  return Image.open(buffer)


def show_img(matrix, title):
  """Draw ``matrix`` as a titled, axis-free image and return it as a PIL image.

  Args:
    matrix: array-like image data; converted to a float64 NumPy array.
    title: text placed above the image.

  Returns:
    The rendered figure as produced by ``fig2img``.
  """
  plt.figure()
  plt.axis('off')
  plt.gray()
  plt.imshow(np.array(matrix, np.float64))
  plt.title(title)

  rendered = fig2img(plt.gcf())
  plt.close()
  return rendered


def save_gif_fancy(imgs, nrow, save_name):
  """Write an animated GIF with one titled image grid per sampling step.

  Args:
    imgs: iterable of ``(step, batch)`` pairs; ``batch`` is passed to
      ``torchvision.utils.make_grid``.
    nrow: ``nrow`` argument forwarded to ``make_grid``.
    save_name: output path of the GIF.
  """

  def _render_frames():
    # Frames are produced lazily, one per (step, batch) entry.
    for entry in imgs:
      grid = utils.make_grid(entry[1],
                             nrow=nrow,
                             normalize=False,
                             padding=1,
                             pad_value=1.0)
      # make_grid returns C x H x W; the plotting helper gets H x W x C.
      yield show_img(
          grid.permute(1, 2, 0).cpu().numpy(), f'sample at {entry[0]:03d} step')

  frames = _render_frames()
  first_frame = next(frames)
  # The remaining frames are consumed by PIL while saving.
  first_frame.save(fp=save_name,
                   format='GIF',
                   append_images=frames,
                   save_all=True,
                   duration=400,
                   loop=0)


def save(model, results_folder, epoch):
  """Store the epoch number and the model's state dict as a checkpoint.

  The file is written to ``{results_folder}/model-{epoch}.pt`` and contains a
  dict with the keys ``'epoch'`` and ``'model'``.
  """
  checkpoint = {'epoch': epoch, 'model': model.state_dict()}
  checkpoint_path = f'{results_folder}/model-{epoch}.pt'
  torch.save(checkpoint, checkpoint_path)


def load(model, results_folder, epoch):
  """Restore model weights from ``{results_folder}/model-{epoch}.pt``.

  The checkpoint is expected to be a dict whose ``'model'`` entry is a state
  dict (the layout written by ``save``). The model is updated in place.
  """
  checkpoint = torch.load(f'{results_folder}/model-{epoch}.pt')
  state_dict = checkpoint['model']
  model.load_state_dict(state_dict)


def train(model,
          train_loader,
          optimizer,
          config,
          is_anneal_data_noise=False,
          std=0.0):
  """Run one training epoch and return the last batch's reconstruction loss.

  For every batch, the gradients are set by ``model.CD_grad`` (not by
  autograd ``backward``), optionally clipped, and applied by ``optimizer``.

  Args:
    model: model exposing ``train()``, ``CD_grad(data)``, ``parameters()`` and
      ``reconstruction(data)``.
    train_loader: sized iterable yielding ``(data, label)`` pairs; labels are
      ignored.
    optimizer: optimizer stepping the model parameters.
    config: dict; reads ``config['cuda']`` (move batches to the GPU) and
      ``config['clip_norm']`` (gradient-norm clipping when > 0).
    is_anneal_data_noise: if True, Gaussian noise with standard deviation
      ``std`` is added in place to every batch.
    std: standard deviation of the optional data noise.

  Returns:
    The reconstruction loss (Python float) measured on the final batch after
    its optimizer step, or ``nan`` when ``train_loader`` yields no batches.
  """
  model.train()
  # Defined up front so an empty loader returns NaN instead of raising
  # UnboundLocalError at the return statement.
  recon_loss = float('nan')
  for ii, (data, _) in enumerate(tqdm(train_loader)):
    if config['cuda']:
      data = data.cuda()

    if is_anneal_data_noise:
      data += torch.randn_like(data) * std

    optimizer.zero_grad()
    # CD_grad writes the gradient estimate directly into param.grad.
    model.CD_grad(data)
    if config['clip_norm'] > 0:
      nn.utils.clip_grad_norm_(model.parameters(), config['clip_norm'])
    optimizer.step()

    # Only the last batch is evaluated, after its parameter update.
    if ii == len(train_loader) - 1:
      recon_loss = model.reconstruction(data).item()

  return recon_loss


def unnormalize_img_tuple(img_tuple, mean, std):
  """Undo ``(x - mean) / std`` normalization for a list of ``(step, image)``.

  Args:
    img_tuple: sequence of ``(step, tensor)`` pairs.
    mean: scalar, or tensor of per-channel means.
    std: scalar, or tensor of per-channel standard deviations. When it is a
      tensor, both ``mean`` and ``std`` are reshaped to ``(1, C, 1, 1)`` and
      moved to the device of the first image.

  Returns:
    A new list of ``(step, tensor)`` pairs with values clamped to ``[0, 1]``.
  """
  if isinstance(std, torch.Tensor):
    device = img_tuple[0][1].device
    mean = mean.view(1, -1, 1, 1).to(device)
    std = std.view(1, -1, 1, 1).to(device)

  restored = []
  for entry in img_tuple:
    pixels = (entry[1] * std + mean).clamp(min=0, max=1)
    restored.append((entry[0], pixels))
  return restored


def visualize_sampling(model, epoch, config, tag=None, is_show_gif=True):
  """Draw samples from ``model`` and save visualizations for this epoch.

  For the 'GMM' dataset a 2-D scatter plot of the final samples and the GRBM
  energy landscape are saved. For every other dataset the final samples are
  un-normalized, clamped to ``[0, 1]`` and saved as a PNG grid; when
  ``is_show_gif`` is True an animated GIF of the sampling chain is saved too.

  Args:
    model: model exposing ``sampling(v_init, num_steps, save_gap)`` that
      returns a list of ``(step, tensor)`` pairs.
    epoch: integer epoch index used in the output file names.
    config: dict; reads 'sampling_batch_size', 'channel', 'height', 'width',
      'neg_MC_init_var', 'sampling_steps', 'sampling_gap', 'dataset',
      'exp_folder' and, for image datasets, 'img_mean', 'img_std',
      'sampling_nrow'. 'cuda' is optional (defaults to True).
    tag: optional suffix appended to image file names.
    is_show_gif: whether to also write the animated GIF.
  """
  tag = '' if tag is None else tag
  B, C, H, W = config['sampling_batch_size'], config['channel'], config[
      'height'], config['width']
  # Honor config['cuda'] like `train` does instead of always calling .cuda();
  # the default keeps the old GPU behavior for configs without that key.
  device = torch.device('cuda' if config.get('cuda', True) else 'cpu')
  v_init = torch.randn(B, C, H, W).to(device) * config['neg_MC_init_var']
  v_list = model.sampling(v_init,
                          num_steps=config['sampling_steps'],
                          save_gap=config['sampling_gap'])

  if config['dataset'] == 'GMM':
    samples = v_list[-1][1].view(B, -1).cpu().numpy()
    vis_2D_samples(samples, config, tags=f'{epoch:05d}')
    vis_density_GRBM(model, config, epoch=epoch)
  else:
    if is_show_gif:
      v_list = unnormalize_img_tuple(v_list, config['img_mean'],
                                     config['img_std'])
      save_gif_fancy(
          v_list, config['sampling_nrow'],
          f"{config['exp_folder']}/sample_imgs_epoch_{epoch:05d}{tag}.gif")
      img_vis = v_list[-1][1]
    else:
      # Only the final sample is needed, so un-normalize just that one.
      final_sample = v_list[-1][1]
      if isinstance(config['img_std'], torch.Tensor):
        mean = config['img_mean'].view(1, -1, 1, 1).to(final_sample.device)
        std = config['img_std'].view(1, -1, 1, 1).to(final_sample.device)
      else:
        mean = config['img_mean']
        std = config['img_std']

      img_vis = (final_sample * std + mean).clamp(min=0, max=1)

    utils.save_image(
        utils.make_grid(img_vis,
                        nrow=config['sampling_nrow'],
                        normalize=False,
                        padding=1,
                        pad_value=1.0).cpu(),
        f"{config['exp_folder']}/sample_imgs_epoch_{epoch:05d}{tag}.png")


def vis_2D_samples(samples, config, tags=None):
  """Scatter-plot 2-D samples and save the figure.

  Args:
    samples: array indexed as ``samples[:, 0]`` (x) and ``samples[:, 1]`` (y).
    config: dict; ``config['exp_folder']`` is the output directory.
    tags: value interpolated into the file name ``samples_{tags}.png``.
  """
  fig, ax = plt.subplots(figsize=(6, 6))
  # sns.scatterplot(x=samples[:, 0], y=samples[:, 1], color="#4CB391", s=5)
  sns.scatterplot(x=samples[:, 0], y=samples[:, 1], color="#4CB391", ax=ax)
  # sns.histplot(x=samples[:, 0], y=samples[:, 1],
  #              bins=50, pthresh=.1, cmap="mako")
  # sns.kdeplot(x=samples[:, 0], y=samples[:, 1],
  #             levels=5, cmap="crest",
  #             fill=True,
  #             alpha=0.5,
  #             cut=2)
  # sns.kdeplot(x=samples[:, 0],
  #             y=samples[:, 1],
  #             levels=5,
  #             color="b",
  #             linewidths=1)
  ax.set(xlim=(-10, 10))
  ax.set(ylim=(-10, 10))
  # Save before show(): on interactive backends show() may leave no current
  # figure behind, so saving afterwards can write an empty image.
  fig.savefig(f"{config['exp_folder']}/samples_{tags}.png", bbox_inches='tight')
  plt.show()
  plt.close(fig)


def vis_density_GMM(model, config):
  """Plot the probability density of ``model`` on ``[-10, 10]^2`` and save it.

  Args:
    model: object exposing ``log_prob(xy)`` for an ``(N, 2)`` tensor of points.
    config: dict; ``config['exp_folder']`` is the output directory and the
      optional ``config['cuda']`` (default True) selects the device of the
      evaluation grid.
  """
  fig, ax = plt.subplots()
  x_density, y_density = 500, 500
  xses = np.linspace(-10, 10, x_density)
  yses = np.linspace(-10, 10, y_density)
  # Follow config['cuda'] instead of unconditionally calling .cuda().
  device = torch.device('cuda' if config.get('cuda', True) else 'cpu')
  # Rows of the grid follow y, columns follow x.
  xy = torch.tensor([[[x, y] for x in xses] for y in yses]).view(-1,
                                                                 2).to(device)
  log_density_values = model.log_prob(xy)
  # Reshape as (rows=y, cols=x) to match how the grid was built.
  log_density_values = log_density_values.detach().view(
      y_density, x_density).cpu().numpy()
  # Extend the extent by half a cell so pixel centers sit on the grid points.
  dx = (xses[1] - xses[0]) / 2
  dy = (yses[1] - yses[0]) / 2
  extent = [xses[0] - dx, xses[-1] + dx, yses[0] - dy, yses[-1] + dy]
  im = ax.imshow(np.exp(log_density_values),
                 extent=extent,
                 origin='lower',
                 cmap='viridis')
  ax.set_xlabel('x')
  ax.set_ylabel('y')
  divider = make_axes_locatable(ax)
  cax = divider.append_axes('right', size='5%', pad=0.05)
  cb = fig.colorbar(im, cax=cax)
  cb.set_label('probability density')
  # Save before show() so interactive backends do not write an empty figure.
  fig.savefig(f"{config['exp_folder']}/GMM_density.png", bbox_inches='tight')
  plt.show()
  plt.close(fig)


def vis_density_GRBM(model, config, epoch=None):
  """Plot the negative marginal energy of ``model`` on ``[-10, 10]^2``.

  Args:
    model: object exposing ``marginal_energy(xy)`` for an ``(N, 2)`` tensor.
    config: dict; ``config['exp_folder']`` is the output directory and the
      optional ``config['cuda']`` (default True) selects the device of the
      evaluation grid.
    epoch: integer epoch used in the file name ``GRBM_density_{epoch:05d}.png``;
      when None the file is named ``GRBM_density.png``.
  """
  fig, ax = plt.subplots()
  x_density, y_density = 500, 500
  xses = np.linspace(-10, 10, x_density)
  yses = np.linspace(-10, 10, y_density)
  # Follow config['cuda'] instead of unconditionally calling .cuda().
  device = torch.device('cuda' if config.get('cuda', True) else 'cpu')
  # Rows of the grid follow y, columns follow x.
  xy = torch.tensor([[[x, y] for x in xses] for y in yses]).view(-1,
                                                                 2).to(device)
  eng_val = -model.marginal_energy(xy)
  # Reshape as (rows=y, cols=x) to match how the grid was built.
  eng_val = eng_val.detach().view(y_density, x_density).cpu().numpy()
  # Extend the extent by half a cell so pixel centers sit on the grid points.
  dx = (xses[1] - xses[0]) / 2
  dy = (yses[1] - yses[0]) / 2
  extent = [xses[0] - dx, xses[-1] + dx, yses[0] - dy, yses[-1] + dy]
  im = ax.imshow(eng_val, extent=extent, origin='lower', cmap='viridis')
  ax.set_xlabel('x')
  ax.set_ylabel('y')
  divider = make_axes_locatable(ax)
  cax = divider.append_axes('right', size='5%', pad=0.05)
  cb = fig.colorbar(im, cax=cax)
  cb.set_label('negative energy')
  # The default epoch=None cannot be formatted with ':05d'; use a plain name.
  if epoch is None:
    out_path = f"{config['exp_folder']}/GRBM_density.png"
  else:
    out_path = f"{config['exp_folder']}/GRBM_density_{epoch:05d}.png"
  # Save before show() so interactive backends do not write an empty figure.
  fig.savefig(out_path, bbox_inches='tight')
  plt.show()
  plt.close(fig)


def plot_curves(stats_dict, config):
  """Save line plots of the tracked training statistics against the epoch.

  Three PNG files are written to ``config['exp_folder']``, one each for the
  'Variance', 'Log Variance' and 'Reconstruction MSE Loss' columns; the file
  names carry ``config['resume']`` and ``config['epochs']`` as 5-digit tags.

  Args:
    stats_dict: data accepted by ``pd.DataFrame`` containing an 'Epoch' column
      plus the three statistic columns named above.
    config: dict; reads 'exp_folder', 'resume' and 'epochs'.
  """
  df = pd.DataFrame(data=stats_dict)

  # (column to plot, file-name prefix): variance curves first, then the
  # reconstruction loss.
  curves = (
      ('Variance', 'Var_curve'),
      ('Log Variance', 'Log_var_curve'),
      ('Reconstruction MSE Loss', 'Recon_curve'),
  )
  for column, prefix in curves:
    axes = sns.lineplot(data=df, x='Epoch', y=column)
    figure = axes.get_figure()
    figure.savefig(
        f"{config['exp_folder']}/{prefix}_{config['resume']:05d}_{config['epochs']:05d}.png",
        bbox_inches='tight')
    # Clear the shared figure so the next curve starts from a blank canvas.
    plt.clf()


class ReplayBuffer(object):
  """Fixed-capacity FIFO buffer of tensor tuples with random row sampling.

  Every stored entry is a sequence of tensors sharing the same first
  dimension; once ``buffer_size`` entries are held, the oldest one is dropped
  on each insertion.
  """

  def __init__(self, buffer_size=5) -> None:
    self.buffer_size = buffer_size
    self._queue = []

  def enqueue(self, x):
    """Append ``x``, evicting the oldest entry when the buffer is full."""
    if len(self._queue) >= self.buffer_size:
      self._queue.pop(0)
    self._queue.append(x)

  def empty(self):
    """Return True when nothing has been stored."""
    return not self._queue

  def sample(self, num_sample=128):
    """Draw up to ``num_sample`` random rows (without replacement).

    The same row indices are used for every field of the stored entries, so
    rows stay aligned across the returned tensors.

    Returns:
      A list with one tensor per field of the stored entries.
    """
    assert len(self._queue) > 0
    total_rows = sum(entry[0].shape[0] for entry in self._queue)
    chosen = torch.randperm(total_rows)[:num_sample]

    num_fields = len(self._queue[0])
    return [
        torch.concat([entry[field] for entry in self._queue], dim=0)[chosen, :]
        for field in range(num_fields)
    ]


"""Here are some datasets."""


class GMMDataset(torch.utils.data.Dataset):
  """Dataset wrapper around a 2-D tensor of samples (one sample per row).

  Each item is ``(row, label)`` where the label is a dummy tensor ``[1.]``
  placed on the same device as the samples.
  """

  def __init__(self, samples):
    self.samples = samples

  def __len__(self):
    num_rows = self.samples.shape[0]
    return num_rows

  def __getitem__(self, idx):
    row = self.samples[idx, :]
    dummy_label = torch.ones(1).to(self.samples.device)
    return row, dummy_label


class GMM(nn.Module):
  """ Gaussian Mixture Models
      N.B.: covariance is assumed to be diagonal
  """

  def __init__(self, w, mu, sigma):
    """
      p(x) = sum_i w[i] N(mu[i], diag(sigma[i]^2))

      Args:
        w: mixture coefficients, one per component, must sum to 1. It is
          indexed as w[k] and handed directly to ``Categorical`` in
          ``sampling``, so a 1-D tensor of length K is assumed
          (TODO confirm against callers).
        mu: shape K X D, component means
        sigma: shape K X D, per-dimension standard deviations (the code
          squares it to obtain the diagonal variance)
    """
    super().__init__()
    self.register_buffer('w', w)
    self.register_buffer('mu', mu)
    self.register_buffer('sigma', sigma)
    self.K = w.shape[0]
    self.D = mu.shape[1]

  def _log_gaussian(self, x, mu, sigma):
    """ Grad-capable log density of one diagonal Gaussian; x is N X D """
    var = sigma**2
    quad = ((x - mu)**2 / var).sum(dim=1)
    # Sum of logs instead of log of a product: same value, but it does not
    # underflow/overflow when many small/large variances are multiplied.
    log_norm = self.D * math.log(2 * math.pi) + torch.log(var).sum()
    return -0.5 * quad - 0.5 * log_norm

  def _log_prob(self, x):
    """ Grad-capable mixture log density, shape N """
    component_log_probs = [
        torch.log(self.w[kk]) +
        self._log_gaussian(x, self.mu[kk], self.sigma[kk])
        for kk in range(self.K)
    ]
    return torch.logsumexp(torch.stack(component_log_probs), 0)

  @torch.no_grad()
  def log_gaussian(self, x, mu, sigma):
    """ log density of single (diagonal-covariance) multivariate Gaussian"""
    return self._log_gaussian(x, mu, sigma)

  @torch.no_grad()
  def log_prob(self, x):
    """ log p(x) for each row of x (N X D); computed without autograd """
    return self._log_prob(x)

  @torch.no_grad()
  def sampling(self, num_samples):
    """ Draw num_samples points from the mixture, shape num_samples X D """
    m = torch.distributions.Categorical(self.w)
    idx = m.sample((num_samples,))
    return self.mu[idx, :] + torch.randn(num_samples, self.D).to(
        self.w.device) * self.sigma[idx, :]

  @torch.no_grad()
  def langevin_sampling(self, x, num_steps=10, eta=1.0e+0, is_anneal=False):
    """ Run unadjusted Langevin dynamics on the energy -log p(x).

      Args:
        x: initial points, shape N X D
        num_steps: number of Langevin updates
        eta: step size (maximum step size when annealing)
        is_anneal: use a cosine-decayed step size instead of a constant one

      Returns:
        the final points, detached from any autograd graph
    """
    # The schedule is only needed (and only built) when annealing.
    eta_list = cosine_schedule(eta_max=eta,
                               T=num_steps) if is_anneal else None
    for ii in range(num_steps):
      eta_ii = eta_list[ii] if is_anneal else eta
      # This method runs under no_grad and the public log_prob is no_grad as
      # well, so autograd must be re-enabled locally and the grad-capable
      # helper used; otherwise torch.autograd.grad raises because the energy
      # has no graph.
      with torch.enable_grad():
        x = x.detach().requires_grad_(True)
        eng = -self._log_prob(x).sum()
        grad = torch.autograd.grad(eng, x)[0]
      x = x.detach() - eta_ii * grad + torch.randn_like(x) * np.sqrt(
          eta_ii * 2)

    return x.detach()


class CelebA(datasets.VisionDataset):
  """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    N.B.: this variant only loads the aligned images. ``__getitem__`` always
    returns an all-zero placeholder target; ``target_type`` is stored but the
    attribute/identity/bbox/landmark files are not read.

    Args:
        root (string): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test', 'all'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:
                ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
                ``identity`` (int): label for each person (data points with the same identity are the same person)
                ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
                ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                    righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
            Defaults to ``attr``. If empty, ``None`` will be returned as target.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

  base_folder = "celeba"
  # There currently does not appear to be a easy way to extract 7z in python (without introducing additional
  # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
  # right now.
  file_list = [
      # File ID                         MD5 Hash                            Filename
      ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb",
       "img_align_celeba.zip"),
      # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
      # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
      ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c",
       "list_attr_celeba.txt"),
      ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2",
       "identity_CelebA.txt"),
      ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16",
       "list_bbox_celeba.txt"),
      ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c",
       "list_landmarks_align_celeba.txt"),
      # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
      ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668",
       "list_eval_partition.txt"),
  ]

  def __init__(
      self,
      root: str,
      split: str = "train",
      target_type: Union[List[str], str] = "attr",
      transform: Optional[Callable] = None,
      target_transform: Optional[Callable] = None,
      download: bool = False,
  ) -> None:
    # Flag toggled by label2imgid(); it is not read anywhere in this class.
    self.return_imgid = 0
    import pandas
    super(CelebA, self).__init__(root,
                                 transform=transform,
                                 target_transform=target_transform)
    self.split = split
    if isinstance(target_type, list):
      self.target_type = target_type
    else:
      self.target_type = [target_type]

    if not self.target_type and self.target_transform is not None:
      raise RuntimeError(
          'target_transform is specified but target_type is empty')

    if download:
      self.download()

    split_map = {
        "train": 0,
        "valid": 1,
        "test": 2,
        "all": None,
    }
    split_ = split_map[verify_str_arg(split.lower(), "split",
                                      ("train", "valid", "test", "all"))]

    fn = partial(os.path.join, self.root, self.base_folder)
    # sep=r"\s+" is the documented replacement for the deprecated
    # delim_whitespace=True and parses the file identically.
    splits = pandas.read_csv(fn("list_eval_partition.txt"),
                             sep=r"\s+",
                             header=None,
                             index_col=0)

    # Column 1 holds the partition id; "all" keeps every row.
    mask = slice(None) if split_ is None else (splits[1] == split_)

    # The index (column 0) holds the image file names of the selected split.
    self.filename = splits[mask].index.values

  def _check_integrity(self) -> bool:
    """Return True if the annotation files match their MD5 and the image folder exists."""
    for (_, md5, filename) in self.file_list:
      fpath = os.path.join(self.root, self.base_folder, filename)
      _, ext = os.path.splitext(filename)
      # Allow original archive to be deleted (zip and 7z)
      # Only need the extracted images
      if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
        return False

    # Should check a hash of the images
    return os.path.isdir(
        os.path.join(self.root, self.base_folder, "img_align_celeba"))

  def download(self) -> None:
    """Download every file in ``file_list`` and extract the image archive.

    Does nothing (besides printing a message) when the integrity check
    already passes.
    """
    import zipfile

    if self._check_integrity():
      print('Files already downloaded and verified')
      return

    for (file_id, md5, filename) in self.file_list:
      download_file_from_google_drive(file_id,
                                      os.path.join(self.root, self.base_folder),
                                      filename, md5)

    with zipfile.ZipFile(
        os.path.join(self.root, self.base_folder, "img_align_celeba.zip"),
        "r") as f:
      f.extractall(os.path.join(self.root, self.base_folder))

  def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """Return ``(image, target)`` for the ``index``-th file of the split.

    The target is a zero tensor whose length is ``X.shape[0]``, so
    ``transform`` must turn the PIL image into something with a ``shape``
    (e.g. a tensor); without a transform the PIL image has no ``shape``.
    """
    X = PIL.Image.open(
        os.path.join(self.root, self.base_folder, "img_align_celeba",
                     self.filename[index]))

    if self.transform is not None:
      X = self.transform(X)

    # Placeholder target; no real labels are loaded by this class.
    target = torch.zeros(X.shape[0])
    return X, target

  def label2imgid(self):
    """Set the ``return_imgid`` flag to 1."""
    self.return_imgid = 1

  def __len__(self) -> int:
    return len(self.filename)

  def extra_repr(self) -> str:
    lines = ["Target type: {target_type}", "Split: {split}"]
    return '\n'.join(lines).format(**self.__dict__)


"""Here is the model."""


class GRBM(nn.Module):
  """ Gaussian-Bernoulli Restricted Boltzmann Machines (GRBM)

      Visible units v are real-valued, hidden units h are binary. With
      var = exp(log_var) (per visible unit), the joint energy implemented in
      `energy` is

        E(v, h) = 0.5 * sum((v - mu)^2 / var) - (v / var)^T W h - b^T h

      Parameters: W (visible_size X hidden_size), b (hidden_size),
      mu (visible_size), log_var (visible_size).

      All methods run under torch.no_grad(); parameter gradients are computed
      in closed form and written to `param.grad` by `CD_grad`.
  """

  def __init__(self,
               visible_size,
               hidden_size,
               CD_step=1,
               CD_burnin=0,
               init_var=1e-0,
               neg_MC_init_var=1e-0,
               inference_method='Gibbs',
               Langevin_step=10,
               Langevin_eta=1.0,
               is_anneal_Langevin=True,
               Langevin_adjust_step=0) -> None:
    """
      Args:
        visible_size: number of visible units (flattened input size)
        hidden_size: number of hidden units
        CD_step: number of sampler steps used for the negative phase
        CD_burnin: samples from steps [CD_burnin, CD_step) are kept
        init_var: initial variance of the visible units (also scales the
          initial std of W)
        neg_MC_init_var: scale multiplied onto the standard-normal noise that
          initializes the negative chain in `CD_grad`
        inference_method: one of 'Gibbs', 'Langevin', 'Langevin-Gibbs'
        Langevin_step: inner Langevin steps per outer step ('Langevin-Gibbs')
        Langevin_eta: base Langevin step size (scaled by the mean variance)
        is_anneal_Langevin: use a cosine-annealed Langevin step size
        Langevin_adjust_step: Metropolis adjustment is applied from this
          step index onwards
    """
    super().__init__()
    # we use samples in [CD_burnin, CD_step) steps
    assert CD_burnin >= 0 and CD_burnin <= CD_step
    assert inference_method in ['Gibbs', 'Langevin', 'Langevin-Gibbs']

    self.visible_size = visible_size
    self.hidden_size = hidden_size
    self.CD_step = CD_step
    self.CD_burnin = CD_burnin
    self.init_var = init_var
    self.neg_MC_init_var = neg_MC_init_var
    self.inference_method = inference_method
    self.Langevin_step = Langevin_step
    self.Langevin_eta = Langevin_eta
    self.is_anneal_Langevin = is_anneal_Langevin
    self.Langevin_adjust_step = Langevin_adjust_step

    # Allocated uninitialized; reset_parameters() fills in the values.
    self.W = nn.Parameter(torch.Tensor(visible_size, hidden_size))
    self.b = nn.Parameter(torch.Tensor(hidden_size))
    self.mu = nn.Parameter(torch.Tensor(visible_size))
    self.log_var = nn.Parameter(torch.Tensor(visible_size))
    self.reset_parameters()

  def reset_parameters(self):
    """ Initialize W ~ N(0, std^2), b = 0, mu = 0, log_var = log(init_var) """
    nn.init.normal_(self.W,
                    std=1.0 * self.init_var /
                    np.sqrt(self.visible_size + self.hidden_size))
    nn.init.constant_(self.b, 0.0)
    nn.init.constant_(self.mu, 0.0)
    nn.init.constant_(self.log_var,
                      np.log(self.init_var))  # init variance = init_var

  def get_var(self):
    """ Per-visible-unit variance exp(log_var), floored at 1e-8 """
    return self.log_var.exp().clip(min=1e-8)

  def set_Langevin_adjust_step(self, step):
    """ Set the step index from which Metropolis adjustment is applied """
    self.Langevin_adjust_step = step

  @torch.no_grad()
  def energy(self, v, h):
    """ Joint energy E(v, h) for each row of v (B X V) and h (B X H) """
    var = self.get_var()
    eng = 0.5 * ((v - self.mu)**2 / var).sum(dim=1)
    eng -= ((v / var).mm(self.W) * h).sum(dim=1) + h.mv(self.b)
    return eng

  @torch.no_grad()
  def marginal_energy(self, v):
    """ Energy of v with the hidden units summed out (softplus form) """
    var = self.get_var()
    eng = 0.5 * ((v - self.mu)**2 / var).sum(dim=1)
    eng -= F.softplus((v / var).mm(self.W) + self.b).sum(dim=1)
    return eng

  @torch.no_grad()
  def energy_grad_v(self, v, h):
    """ Gradient of the joint energy w.r.t. v, divided by the batch size """
    B = v.shape[0]
    var = self.get_var()
    return ((v - self.mu) / var - h.mm(self.W.T) / var) / B

  @torch.no_grad()
  def marginal_energy_grad_v(self, v):
    """ Gradient of the marginal energy w.r.t. v, divided by the batch size """
    B = v.shape[0]
    var = self.get_var()
    return ((v - self.mu) / var -
            torch.sigmoid((v / var).mm(self.W) + self.b).mm(self.W.T) / var) / B

  @torch.no_grad()
  def energy_grad_param(self, v, h):
    """ Batch-averaged gradient of the joint energy w.r.t. each parameter.

      Returns:
        dict keyed by parameter name ('W', 'b', 'mu', 'log_var')
    """
    var = self.get_var()
    grad = {}
    grad['W'] = -torch.einsum("bi,bj->ij", v / var, h) / v.shape[0]
    grad['b'] = -h.mean(dim=0)
    grad['mu'] = ((self.mu - v) / var).mean(dim=0)
    grad['log_var'] = (-0.5 * (v - self.mu)**2 / var +
                       ((v / var) * h.mm(self.W.T))).mean(dim=0)
    return grad

  @torch.no_grad()
  def marginal_energy_grad_param(self, v):
    """ Same as `energy_grad_param` but with h replaced by p(h = 1 | v) """
    var = self.get_var()
    vv = v / var
    tmp = torch.sigmoid(vv.mm(self.W) + self.b)
    grad = {}
    grad['W'] = -torch.einsum("bi,bj->ij", vv, tmp) / v.shape[0]
    grad['b'] = -tmp.mean(dim=0)
    grad['mu'] = ((self.mu - v) / var).mean(dim=0)
    grad['log_var'] = (-0.5 * (v - self.mu)**2 / var +
                       (vv * tmp.mm(self.W.T))).mean(dim=0)
    return grad

  @torch.no_grad()
  def prob_h_given_v(self, v, var):
    """ p(h = 1 | v) = sigmoid((v / var) W + b) """
    return torch.sigmoid((v / var).mm(self.W) + self.b)

  @torch.no_grad()
  def prob_v_given_h(self, h):
    """ Conditional mean of v given h: h W^T + mu (not a probability) """
    return h.mm(self.W.T) + self.mu

  @torch.no_grad()
  def log_metropolis_ratio_v_given_h(self, v_old, v_new, h, grad_old, eta):
    """ Metropolis-Hasting ratio of accepting the move from old to new state """
    eng_diff = -self.energy(v_new, h) + self.energy(v_old, h)
    proposal_eng_new = - \
        torch.pow(v_old - v_new + eta *
                  self.energy_grad_v(v_new, h), 2.0).sum(dim=1) / (4 * eta)
    proposal_eng_old = - \
        torch.pow(v_new - v_old + eta * grad_old,
                  2.0).sum(dim=1) / (4 * eta)

    return eng_diff + proposal_eng_new - proposal_eng_old

  @torch.no_grad()
  def log_metropolis_ratio_Gibbs_joint(self, v_old, h_old, v_new, h_new, var):
    """ Metropolis-Hasting ratio of accepting the move from old to new state """
    marginal_eng_diff = -self.marginal_energy(v_new) + self.marginal_energy(
        v_old)
    eng_diff = -0.5 * (
        (v_old - self.mu - h_new.mm(self.W.T))**2 / var).sum(dim=1) + 0.5 * (
            (v_new - self.mu - h_old.mm(self.W.T))**2 / var).sum(dim=1)
    return eng_diff + marginal_eng_diff

  @torch.no_grad()
  def log_metropolis_ratio_Langevin_Gibbs_joint(self, v_old, h_old, v_new,
                                                h_new, eta):
    """ Metropolis-Hasting ratio of accepting the move from old to new state """
    eng_diff = -self.marginal_energy(v_new) + self.marginal_energy(v_old)
    proposal_eng_new = - \
        torch.pow(v_old - v_new + eta * self.energy_grad_v(v_new, h_new), 2.0).sum(dim=1) / (4 * eta)
    proposal_eng_old = - \
        torch.pow(v_new - v_old + eta * self.energy_grad_v(v_old, h_old), 2.0).sum(dim=1) / (4 * eta)

    return eng_diff + proposal_eng_new - proposal_eng_old

  @torch.no_grad()
  def log_metropolis_ratio_marginal(self, v_old, v_new, grad_old, eta):
    """ Metropolis-Hasting ratio of accepting the move from old to new state """
    eng_diff = -self.marginal_energy(v_new) + self.marginal_energy(v_old)
    proposal_eng_new = - \
        torch.pow(v_old - v_new + eta *
                  self.marginal_energy_grad_v(v_new), 2.0).sum(dim=1) / (4 * eta)
    proposal_eng_old = - \
        torch.pow(v_new - v_old + eta * grad_old,
                  2.0).sum(dim=1) / (4 * eta)

    return eng_diff + proposal_eng_new - proposal_eng_old

  @torch.no_grad()
  def Gibbs_sampling_vh(self, v, num_steps=10, burn_in=0):
    """ Block Gibbs sampling alternating v | h and h | v.

      Returns:
        list of (v, h) pairs for steps [burn_in, num_steps)
    """
    samples, var = [], self.get_var()
    std = var.sqrt()
    h = torch.bernoulli(self.prob_h_given_v(v, var))
    for ii in range(num_steps):
      # backward sampling
      mu = self.prob_v_given_h(h)
      v = mu + torch.randn_like(mu) * std

      # forward sampling
      h = torch.bernoulli(self.prob_h_given_v(v, var))

      if ii >= burn_in:
        samples += [(v, h)]

    return samples

  @torch.no_grad()
  def Langevin_sampling_v(self,
                          v,
                          num_steps=10,
                          eta=1.0e+0,
                          burn_in=0,
                          is_anneal=True,
                          adjust_step=0):
    """ Langevin sampling of v on the marginal energy.

      From step index `adjust_step` onwards each proposal goes through a
      Metropolis accept/reject test (per row of v).

      Returns:
        list of v tensors for steps [burn_in, num_steps)
    """
    eta_list = cosine_schedule(eta_max=eta, T=num_steps)
    samples = []

    for ii in range(num_steps):
      eta_ii = eta_list[ii] if is_anneal else eta
      grad_v = self.marginal_energy_grad_v(v)

      v_new = v - eta_ii * grad_v + \
          torch.randn_like(v) * np.sqrt(eta_ii * 2)

      if ii >= adjust_step:
        tmp_u = torch.rand(v.shape[0]).to(v.device)
        log_ratio = self.log_metropolis_ratio_marginal(v, v_new, grad_v, eta_ii)
        ratio = torch.minimum(torch.ones_like(log_ratio), log_ratio.exp())
        # keep v_new where the move is accepted, the old v otherwise
        v = v_new * (tmp_u < ratio).float().view(
            -1, 1) + v * (tmp_u >= ratio).float().view(-1, 1)
      else:
        v = v_new

      if ii >= burn_in:
        samples += [v]

    return samples

  @torch.no_grad()
  def Langevin_Gibbs_sampling_vh(self,
                                 v,
                                 num_steps=10,
                                 num_steps_Langevin=10,
                                 eta=1.0e+0,
                                 burn_in=0,
                                 is_anneal=True,
                                 adjust_step=0):
    """ Gibbs sampling where v | h is drawn by inner Langevin steps.

      Each outer step runs `num_steps_Langevin` Langevin updates of v with h
      fixed, then resamples h | v. From outer step `adjust_step` onwards the
      joint move (v, h) is accepted/rejected per row; the ratio is evaluated
      with the base step size `eta`.

      Returns:
        list of (v, h) pairs for outer steps [burn_in, num_steps)
    """
    samples, var = [], self.get_var()
    eta_list = cosine_schedule(eta_max=eta, T=num_steps_Langevin)

    h = torch.bernoulli(self.prob_h_given_v(v, var))
    for ii in range(num_steps):
      v_old, h_old = v, h
      # backward sampling
      for jj in range(num_steps_Langevin):
        eta_jj = eta_list[jj] if is_anneal else eta
        grad_v = self.energy_grad_v(v, h)
        v = v - eta_jj * grad_v + torch.randn_like(v) * np.sqrt(eta_jj * 2)

      # forward sampling
      h = torch.bernoulli(self.prob_h_given_v(v, var))

      if ii >= adjust_step:
        tmp_u = torch.rand(v.shape[0]).to(v.device)
        log_ratio = self.log_metropolis_ratio_Langevin_Gibbs_joint(
            v_old, h_old, v, h, eta)
        ratio = torch.minimum(torch.ones_like(log_ratio), log_ratio.exp())
        v = v * (tmp_u < ratio).float().view(
            -1, 1) + v_old * (tmp_u >= ratio).float().view(-1, 1)
        h = h * (tmp_u < ratio).float().view(
            -1, 1) + h_old * (tmp_u >= ratio).float().view(-1, 1)

      if ii >= burn_in:
        samples += [(v, h)]

    return samples

  @torch.no_grad()
  def reconstruction(self, v):
    """ MSE between v (flattened per sample) and its mean-field reconstruction """
    v, var = v.view(v.shape[0], -1), self.get_var()
    prob_h = self.prob_h_given_v(v, var)
    v_bar = self.prob_v_given_h(prob_h)
    return F.mse_loss(v, v_bar)

  @torch.no_grad()
  def sampling(self, v_init, num_steps=1, save_gap=1):
    """ Run the configured sampler from v_init and return the trajectory.

      `num_steps - 1` sampler steps are run; the final entry is the
      conditional mean of v given an h sampled from the last state.

      Args:
        v_init: initial visible state; dim 0 is the batch, the rest is
          flattened internally and restored in the outputs
        num_steps: step index assigned to the final (conditional-mean) entry
        save_gap: intermediate step k is kept only if k % save_gap == 0

      Returns:
        list of (step, tensor) pairs, starting with (0, v_init) and ending
        with (num_steps, conditional mean)
    """
    v_shape = v_init.shape
    v = v_init.view(v_shape[0], -1)
    var = self.get_var()
    var_mean = var.mean().item()

    if self.inference_method == 'Gibbs':
      samples = self.Gibbs_sampling_vh(v, num_steps=num_steps - 1)
      samples = [xx[0] for xx in samples]  # extract v
    elif self.inference_method == 'Langevin':
      samples = self.Langevin_sampling_v(v,
                                         num_steps=num_steps - 1,
                                         eta=self.Langevin_eta * var_mean,
                                         is_anneal=self.is_anneal_Langevin,
                                         adjust_step=self.Langevin_adjust_step)
    elif self.inference_method == 'Langevin-Gibbs':
      samples = self.Langevin_Gibbs_sampling_vh(
          v,
          num_steps=num_steps - 1,
          num_steps_Langevin=self.Langevin_step,
          eta=self.Langevin_eta * var_mean,
          is_anneal=self.is_anneal_Langevin,
          adjust_step=self.Langevin_adjust_step)
      samples = [xx[0] for xx in samples]  # extract v

    # With num_steps <= 1 (the default is 1) the sampler runs zero steps and
    # returns no samples; fall back to the initial state instead of indexing
    # an empty list.
    last_v = samples[-1] if samples else v

    # use conditional mean as the last sample
    h = torch.bernoulli(self.prob_h_given_v(last_v, var))
    mu = self.prob_v_given_h(h)
    v_list = [(0, v_init)] + [(ii + 1, samples[ii].view(v_shape).detach())
                              for ii in range(num_steps - 1)
                              if (ii + 1) % save_gap == 0
                             ] + [(num_steps, mu.view(v_shape).detach())]

    return v_list

  @torch.no_grad()
  def positive_grad(self, v):
    """ Data-phase gradient: h is sampled from p(h | v) for the data v """
    h = torch.bernoulli(self.prob_h_given_v(v, self.get_var()))
    grad = self.energy_grad_param(v, h)
    return grad

  @torch.no_grad()
  def negative_grad(self, v):
    """ Model-phase gradient from samples of a chain started at v.

      Samples from steps [CD_burnin, CD_step) are concatenated along the
      batch dimension before the gradient is averaged.
    """
    var = self.get_var()
    var_mean = var.mean().item()
    if self.inference_method == 'Gibbs':
      samples = self.Gibbs_sampling_vh(v,
                                       num_steps=self.CD_step,
                                       burn_in=self.CD_burnin)
      v_neg = torch.cat([xx[0] for xx in samples], dim=0)
      h_neg = torch.cat([xx[1] for xx in samples], dim=0)
      grad = self.energy_grad_param(v_neg, h_neg)
    elif self.inference_method == 'Langevin':
      samples = self.Langevin_sampling_v(v,
                                         num_steps=self.CD_step,
                                         burn_in=self.CD_burnin,
                                         eta=self.Langevin_eta * var_mean,
                                         is_anneal=self.is_anneal_Langevin,
                                         adjust_step=self.Langevin_adjust_step)
      v_neg = torch.cat(samples, dim=0)
      grad = self.marginal_energy_grad_param(v_neg)

    elif self.inference_method == 'Langevin-Gibbs':
      samples = self.Langevin_Gibbs_sampling_vh(
          v,
          num_steps=self.CD_step,
          burn_in=self.CD_burnin,
          num_steps_Langevin=self.Langevin_step,
          eta=self.Langevin_eta * var_mean,
          is_anneal=self.is_anneal_Langevin,
          adjust_step=self.Langevin_adjust_step)
      v_neg = torch.cat([xx[0] for xx in samples], dim=0)
      h_neg = torch.cat([xx[1] for xx in samples], dim=0)
      grad = self.energy_grad_param(v_neg, h_neg)

    return grad

  @torch.no_grad()
  def CD_grad(self, v):
    """ Write the gradient estimate (positive - negative phase) into .grad.

      The negative chain starts from Gaussian noise scaled by
      neg_MC_init_var, not from the data batch.
    """
    v = v.view(v.shape[0], -1)
    # positive gradient
    grad_pos = self.positive_grad(v)

    # negative gradient
    v_neg = torch.randn_like(v) * self.neg_MC_init_var
    grad_neg = self.negative_grad(v_neg)

    # compute update
    for name, param in self.named_parameters():
      param.grad = grad_pos[name] - grad_neg[name]


"""Here is the config of running experiments."""


def get_config(pid):
  """Assemble the experiment configuration dictionary.

  Starts from defaults shared by all datasets, applies the overrides of the
  selected dataset, derives `visible_size`, and finally builds `exp_folder`
  from the key hyperparameters and `pid`.

  Args:
    pid: identifier embedded in the experiment folder name.

  Returns:
    dict with all hyperparameters. `exp_folder` is only set when
    `inference_method` is 'Gibbs', 'Langevin' or 'Langevin-Gibbs'.
  """
  config = {
      # Alternatives: 'MNIST', 'CIFAR10', 'CelebA'
      'dataset': 'GMM',
      'cuda': True,
      'model': 'GRBM',
      'batch_size': 128,
      'epochs': 10000,
      'lr': 1.0e-2,
      'clip_norm': 10.0,
      'wd': 0.0e-4,
      'seed': 1,  # random seed
      'optimizer': 'SGD',
      'resume': 0,
      # visualize sampling process, filters, hiddens if True
      'is_vis_verbose': True,
      'init_var': 1e-0,  # init variance of GRBM
      'hidden_size': 4096,
      'CD_step': 100,
      'CD_burnin': 0,
      'Langevin_step': 1,
      'Langevin_eta': 1.0e+1,
      'is_anneal_Langevin': True,
      'Langevin_adjust_step': 100,
      'use_replay_buffer': False,  # ignore for now
      'is_anneal_data_noise': False,  # ignore for now
      # Alternatives: 'Langevin', 'Langevin-Gibbs'
      'inference_method': 'Gibbs',
      'neg_MC_init_var': 1.0,  # ignore for now
      'data_init_var': 2.0,  # ignore for now
      'sampling_batch_size': 100,
  }
  config['sampling_steps'] = config['CD_step']
  config['sampling_gap'] = min(5, config['sampling_steps'])
  config['sampling_nrow'] = 10

  # Dataset-specific settings; some override the defaults above.
  dataset = config['dataset']
  if dataset == 'GMM':
    config.update({
        'batch_size': 100,
        'num_samples': 1000,
        'height': 1,
        'width': 1,
        'channel': 2,
        'log_interval': 100,
        'save_interval': 50000,
        'epochs': 50000,
        'hidden_size': 256,
    })
  elif dataset == 'MNIST':
    config.update({
        'height': 28,
        'width': 28,
        'channel': 1,
        'img_mean': torch.tensor([0.1307]),
        'img_std': torch.tensor([0.3081]),
        'log_interval': 10,
        'save_interval': 100,
    })
  elif dataset == 'CIFAR10':
    config.update({
        'height': 32,
        'width': 32,
        'channel': 3,
        'img_mean': torch.tensor([0.4914, 0.4822, 0.4465]),
        'img_std': torch.tensor([0.2470, 0.2435, 0.2616]),
        'log_interval': 10,
        'save_interval': 10000,
    })
  elif dataset == 'CelebA':
    config.update({
        'height': 32,
        'width': 32,
        'channel': 3,
        'crop_size': 140,
        'img_mean': torch.tensor([0.5240, 0.4152, 0.3590]),
        'img_std': torch.tensor([0.2868, 0.2530, 0.2453]),
        'log_interval': 1,
        'save_interval': 5,
    })

  config['visible_size'] = (config['height'] * config['width'] *
                            config['channel'])

  # The folder name is "<head><sampler-specific part><tail>"; only the middle
  # part differs between inference methods.
  method = config['inference_method']
  head = f"exp/{config['dataset']}_{config['model']}_{pid}_inference={method}"
  tail = f"_H={config['hidden_size']}_lr={config['lr']}_B={config['batch_size']}_initvar={config['init_var']}_CD={config['CD_step']}_burnin={config['CD_burnin']}_is_anneal_data_noise={config['is_anneal_data_noise']}_{config['optimizer']}"
  sampler_part = {
      'Gibbs':
          '',
      'Langevin':
          f"_is_anneal_Langevin={config['is_anneal_Langevin']}_Langevin_adjust_step={config['Langevin_adjust_step']}_Langevin_eta={config['Langevin_eta']}",
      'Langevin-Gibbs':
          f"_is_anneal_Langevin={config['is_anneal_Langevin']}_Langevin_adjust_step={config['Langevin_adjust_step']}_Langevin_step={config['Langevin_step']}_Langevin_eta={config['Langevin_eta']}",
  }
  if method in sampler_part:
    config['exp_folder'] = head + sampler_part[method] + tail

  return config


"""Here is the main training loop."""


def train_model():
  """Train a GRBM on the dataset selected in `get_config`.

  Creates the experiment folder, sets up the log file and dumps the config as
  JSON, then builds the dataset, model, optimizer and cosine LR schedule and
  runs the epoch loop. Every `log_interval` epochs statistics are logged and
  visualizations are written; every `save_interval` epochs a checkpoint is
  saved and the accumulated curves are plotted in a background thread.
  """
  pid = os.getpid()
  config = get_config(pid)
  os.makedirs(config['exp_folder'], exist_ok=True)

  log_file = os.path.join(config['exp_folder'], f'log_exp_{pid}.txt')
  logger = setup_logging('INFO', log_file)
  logger.info('Writing log file to {}'.format(log_file))

  with open(os.path.join(config['exp_folder'], f'config_{pid}.json'),
            'w') as outfile:
    json.dump(config, outfile, cls=TensorEncoder, indent=4)

  if config['dataset'] == 'GMM':
    gmm_model = GMM(torch.tensor([0.33, 0.33, 0.34]),
                    torch.tensor([[-5, -5], [5, -5], [0, 5]]),
                    torch.tensor([[1.25, 0.5], [1.25, 0.5], [0.5, 1.25]]))
    # gmm_model = GMM(torch.tensor([0.33, 0.33, 0.34]),
    #                 torch.tensor([[-5, -5], [5, -5], [0, 5]]),
    #                 torch.tensor([[1, 1], [1, 1], [1, 1]]))
    # gmm_model = GMM(torch.tensor([1.0]), torch.tensor([[2, 2]]),
    #              torch.tensor([[2.0, 0.5]]))
    # Honor the 'cuda' switch instead of moving to the GPU unconditionally.
    if config['cuda']:
      gmm_model = gmm_model.cuda()

    vis_density_GMM(gmm_model, config)
    samples = gmm_model.sampling(config['num_samples'])
    vis_2D_samples(samples.cpu().numpy(), config, tags='ground_truth')
    train_set = GMMDataset(samples)
  elif config['dataset'] == 'MNIST':
    train_set = datasets.MNIST('./data',
                               train=True,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize(config['img_mean'],
                                                        config['img_std'])
                               ]))
  elif config['dataset'] == 'CIFAR10':
    train_set = datasets.CIFAR10('./data',
                                 train=True,
                                 download=True,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         config['img_mean'], config['img_std'])
                                 ]))
  elif config['dataset'] == 'CelebA':
    train_set = CelebA('./data',
                       split='train',
                       download=False,
                       transform=transforms.Compose([
                           transforms.CenterCrop(config['crop_size']),
                           transforms.Resize(config['height']),
                           transforms.ToTensor(),
                           transforms.Normalize(config['img_mean'],
                                                config['img_std'])
                       ]))

  # use a subset to debug
  # train_set = torch.utils.data.Subset(train_set, range(128))

  train_loader = torch.utils.data.DataLoader(train_set,
                                             batch_size=config['batch_size'],
                                             shuffle=True)

  model = GRBM(config['visible_size'],
               config['hidden_size'],
               CD_step=config['CD_step'],
               CD_burnin=config['CD_burnin'],
               init_var=config['init_var'],
               neg_MC_init_var=config['neg_MC_init_var'],
               inference_method=config['inference_method'],
               Langevin_step=config['Langevin_step'],
               Langevin_eta=config['Langevin_eta'],
               is_anneal_Langevin=config['is_anneal_Langevin'],
               Langevin_adjust_step=config['Langevin_adjust_step'])

  if config['cuda']:
    model.cuda()

  # Weight decay is applied only to parameters whose name contains 'W'.
  param_wd, param_no_wd = [], []
  for param_name, param in model.named_parameters():
    if 'W' in param_name:
      param_wd += [param]
    else:
      param_no_wd += [param]

  param_groups = [{
      'params': param_no_wd,
      'weight_decay': 0
  }, {
      'params': param_wd
  }]
  if config['optimizer'] == 'Adam':
    optimizer = optim.Adam(param_groups,
                           lr=config['lr'],
                           weight_decay=config['wd'])
  else:
    optimizer = optim.SGD(param_groups,
                          lr=config['lr'],
                          momentum=0.0,
                          weight_decay=config['wd'])

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
      optimizer, config['epochs'])

  if config['resume'] > 0:
    load(model, config['exp_folder'], config['resume'])

  stats_dict = defaultdict(list)
  # Fast-forward the LR schedule to the epoch we resume from.
  for _ in range(config['resume']):
    scheduler.step()

  is_show_training_data = False
  # Cosine schedule for the data-noise std, one entry per epoch.
  data_std = [
      math.sqrt(config['data_init_var']) *
      (1 + math.cos(tt * math.pi / config['epochs'])) / 2
      for tt in range(config['epochs'])
  ]
  for epoch in range(config['resume'] + 1, config['epochs'] + 1):
    recon_loss = train(model,
                       train_loader,
                       optimizer,
                       config,
                       is_anneal_data_noise=config['is_anneal_data_noise'],
                       std=data_std[epoch - 1])

    # One row per visible unit and epoch, in long format for plotting.
    var = model.get_var().detach().cpu().numpy()
    var_list = var.tolist()
    log_var_list = np.log(var).tolist()
    stats_dict['Variance'] += var_list
    stats_dict['Log Variance'] += log_var_list
    stats_dict['Epoch'] += [epoch] * len(var_list)
    stats_dict['Units'] += range(1, len(var_list) + 1)
    stats_dict['Reconstruction MSE Loss'] += [recon_loss] * len(var_list)

    if epoch % config['log_interval'] == 0:
      if config['dataset'] == 'GMM':
        logger.info(
            f'PID={pid} || {epoch} epoch || mean = {model.mu.detach().cpu().numpy()} || var={model.get_var().detach().cpu().numpy()} || Reconstruction Loss = {recon_loss}'
        )
      else:
        logger.info(
            f'PID={pid} || {epoch} epoch || var={model.get_var().mean().item()} || Reconstruction Loss = {recon_loss}'
        )

      visualize_sampling(model,
                         epoch,
                         config,
                         is_show_gif=config['is_vis_verbose'])
      if config['is_vis_verbose']:
        # visualize filters
        filters = model.W.T.view(model.W.shape[1], config['channel'],
                                 config['height'], config['width'])
        utils.save_image(
            filters,
            f"{config['exp_folder']}/filters_epoch_{epoch:05d}.png",
            nrow=8,
            normalize=True,
            padding=1,
            pad_value=1.0)

        # visualize hidden states
        data, _ = next(iter(train_loader))
        # Follow the model's device rather than assuming a GPU.
        h_pos = model.prob_h_given_v(
            data.view(data.shape[0], -1).to(model.W.device), model.get_var())
        utils.save_image(h_pos.view(1, 1, -1, config['hidden_size']),
                         f"{config['exp_folder']}/hidden_epoch_{epoch:05d}.png",
                         normalize=True)

        # visualize one mini-batch of training data (only once)
        if not is_show_training_data and config['dataset'] != 'GMM':
          data, _ = next(iter(train_loader))
          mean = config['img_mean'].view(1, -1, 1, 1).to(data.device)
          std = config['img_std'].view(1, -1, 1, 1).to(data.device)
          vis_data = (data * std + mean).clamp(min=0, max=1)
          utils.save_image(
              utils.make_grid(vis_data,
                              nrow=config['sampling_nrow'],
                              normalize=False,
                              padding=1,
                              pad_value=1.0).cpu(),
              f"{config['exp_folder']}/training_imgs.png")
          is_show_training_data = True

    # save checkpoints and curves periodically.
    if epoch % config['save_interval'] == 0:
      save(model, config['exp_folder'], epoch)
      # Plotting a lot of points is slow so we use another thread. The thread
      # gets a snapshot: the live lists keep growing in this loop, and columns
      # read at different moments could end up with different lengths.
      stats_snapshot = defaultdict(
          list, {key: list(values) for key, values in stats_dict.items()})
      thread = Thread(target=plot_curves, args=(stats_snapshot, config))
      thread.start()

    scheduler.step()


# Script entry point: run training when this file is executed directly.
if __name__ == '__main__':
  train_model()
  # test_model()


# Evaluate Trained Models

Required packages:

*   pytorch_fid
*   seaborn
*   pandas
*   matplotlib






In [None]:
import os
import pdb
import glob
import json
import subprocess
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import imageio

import numpy as np
from tqdm import tqdm
import subprocess
import seaborn as sns
import pandas as pd
from collections import defaultdict
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt  # NOQA

sns.set_theme(style="darkgrid")

from scipy.stats import entropy

import torch
from torchvision import datasets, transforms, utils
from torch import nn
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3
from model.grbm_minimal import CelebA, GRBM

def _str2bool(value):
  """Convert a command-line string to a bool.

  `type=bool` treats every non-empty string (including 'False') as True, so
  the flag could never be switched off from the command line.

  Args:
    value: raw argument string.

  Returns:
    True for 'true'/'t'/'yes'/'y'/'1', False for 'false'/'f'/'no'/'n'/'0' or
    the empty string (case-insensitive).

  Raises:
    ValueError: for any other value; argparse reports it as a usage error.
  """
  lowered = str(value).strip().lower()
  if lowered in ('true', 't', 'yes', 'y', '1'):
    return True
  if lowered in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise ValueError(f'expected a boolean value, got {value!r}')


# Command-line interface of the evaluation script.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--path',
                    type=str,
                    default=None,
                    help='path of model folder')
parser.add_argument('--is_test',
                    type=_str2bool,
                    default=True,
                    help='evaluate using training or testing set')
parser.add_argument('--gpu_list', nargs='+', default=['0', '1'], help='gpu ID')


class CustomImageDataset(torch.utils.data.Dataset):
  """Dataset over all `*.png` files found directly inside a directory."""

  def __init__(self, root_dir, transform=None):
    """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """

    self.root_dir = root_dir
    # glob returns files in arbitrary order; sort so indices are reproducible.
    self.img_names = sorted(glob.glob(os.path.join(self.root_dir, '*.png')))
    self.transform = transform

  def __len__(self):
    """Return the number of PNG files found."""
    return len(self.img_names)

  def __getitem__(self, idx):
    """Load image `idx`.

    Returns the raw image as read by imageio when no transform is set.
    Otherwise returns the transformed sample; a sample whose first dim is 1
    is repeated three times along that dim.
    """
    image = imageio.imread(self.img_names[idx])

    if not self.transform:
      # Previously `sample` was never assigned on this path.
      return image

    sample = self.transform(image)
    if sample.shape[0] == 1:
      sample = sample.repeat((3, 1, 1))

    return sample


def inception_score(img_path, cuda=True, batch_size=32, resize=False, splits=1):
  """Compute the Inception Score of the PNG images stored in a folder.

  Images are loaded with `CustomImageDataset`, converted to tensors and
  normalized with mean 0.5 and std 0.5, then classified by a pretrained
  Inception v3 network.

    img_path -- folder containing the generated `*.png` images
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    resize -- if True, bilinearly upsample every batch to 299x299 first
    splits -- number of splits

  Returns the mean and standard deviation of the per-split scores.
  """
  custom_transform = transforms.Compose(
      [transforms.ToTensor(),
       transforms.Normalize((0.5,), (0.5,))])
  dataset = CustomImageDataset(img_path, transform=custom_transform)

  N = len(dataset)

  assert batch_size > 0
  assert N > batch_size

  # Set up dtype
  if cuda:
    dtype = torch.cuda.FloatTensor
  else:
    if torch.cuda.is_available():
      print(
          "WARNING: You have a CUDA device, so you should probably set cuda=True"
      )
    dtype = torch.FloatTensor

  # Set up dataloader
  dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

  # Load inception model
  inception_model = inception_v3(pretrained=True,
                                 transform_input=False).type(dtype)
  inception_model.eval()
  up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)

  def get_pred(x):
    """Return the softmax class probabilities of a batch as a numpy array."""
    if resize:
      x = up(x)
    x = inception_model(x)
    return F.softmax(x, dim=1).cpu().numpy()

  # Get predictions
  preds = np.zeros((N, 1000))

  # Pure inference: no autograd graph is needed for the forward passes.
  with torch.no_grad():
    for i, batch in enumerate(dataloader):
      batchv = batch.type(dtype)
      batch_size_i = batch.size()[0]

      preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)

  # Now compute the mean kl-div
  split_scores = []
  split_len = N // splits

  for k in range(splits):
    part = preds[k * split_len:(k + 1) * split_len, :]
    py = np.mean(part, axis=0)
    # KL(p(y|x) || p(y)) for every image in the split.
    scores = [entropy(part[i, :], py) for i in range(part.shape[0])]
    split_scores.append(np.exp(np.mean(scores)))

  return np.mean(split_scores), np.std(split_scores)


def load(model, model_name):
  """Restore `model` in place from the checkpoint file `model_name`.

  The checkpoint is read with `torch.load`; its 'model' entry is passed to
  `model.load_state_dict`.
  """
  checkpoint = torch.load(model_name)
  state = checkpoint['model']
  model.load_state_dict(state)


def eval_model():
  """Evaluate every checkpoint in `--path` with FID and Inception Score.

  Loads the first `config*.json` found in the model folder, rebuilds the
  dataset and the GRBM, then for each `*.pt` checkpoint writes ground-truth
  and sampled images to disk, runs `pytorch_fid` on the two folders and
  computes the Inception Score of the samples. Prints per-epoch scores and
  the best epoch by FID, and saves the FID curve to `FID_curve.png`.
  """
  import sys  # only needed here, to launch pytorch_fid with this interpreter

  input_args = parser.parse_args()
  config_file = glob.glob(os.path.join(input_args.path, 'config*.json'))[0]
  with open(config_file, 'r') as infile:
    config = json.load(infile)
  config['img_mean'] = torch.from_numpy(np.array(config['img_mean']))
  config['img_std'] = torch.from_numpy(np.array(config['img_std']))
  config['exp_folder'] = input_args.path
  config['cuda'] = True
  config['split'] = 'test' if input_args.is_test else 'train'
  config['batch_size'] = 100
  config['sampling_steps'] = 100
  config['sampling_gap'] = 5
  config['sampling_nrow'] = 10
  config['model_img_folder'] = f"{config['exp_folder']}/gen_imgs"

  # NOTE(review): the datasets below are always built from the training
  # split; config['split'] only changes the name of the ground-truth folder.
  # Confirm whether --is_test is meant to select the test data.
  if config['dataset'] == 'MNIST':
    data_set = datasets.MNIST('./data',
                              train=True,
                              download=True,
                              transform=transforms.Compose(
                                  [transforms.ToTensor()]))
    config['gt_img_folder'] = f"data/MNIST/imgs_{config['split']}"
  elif config['dataset'] == 'CIFAR10':
    data_set = datasets.CIFAR10('./data',
                                train=True,
                                download=True,
                                transform=transforms.Compose(
                                    [transforms.ToTensor()]))
    config['gt_img_folder'] = f"data/CIFAR10/imgs_{config['split']}"
  elif config['dataset'] == 'CelebA':
    data_set = CelebA('./data',
                      split='train',
                      download=False,
                      transform=transforms.Compose([
                          transforms.CenterCrop(config['crop_size']),
                          transforms.Resize(config['height']),
                          transforms.ToTensor()
                      ]))
    config['gt_img_folder'] = f"data/celeba/imgs_{config['split']}"

  os.makedirs(config['gt_img_folder'], exist_ok=True)
  os.makedirs(config['model_img_folder'], exist_ok=True)

  # NOTE(review): only the first 128 images are evaluated; this looks like a
  # debugging leftover -- confirm before reporting scores.
  data_set = torch.utils.data.Subset(data_set, range(128))

  data_loader = torch.utils.data.DataLoader(data_set,
                                            batch_size=config['batch_size'])

  model = GRBM(config['visible_size'],
               config['hidden_size'],
               CD_step=config['CD_step'],
               CD_burnin=config['CD_burnin'],
               init_var=config['init_var'],
               neg_MC_init_var=config['neg_MC_init_var'],
               inference_method=config['inference_method'],
               Langevin_step=config['Langevin_step'],
               Langevin_eta=config['Langevin_eta'],
               is_anneal_Langevin=config['is_anneal_Langevin'],
               Langevin_adjust_step=config['Langevin_adjust_step'])

  if config['cuda']:
    model.cuda()

  model_names = glob.glob(os.path.join(config['exp_folder'], '*.pt'))

  score_dict = defaultdict(list)
  best_epoch, best_FID_score, best_IS_score = 0, 1e+5, 0
  if isinstance(config['img_mean'], torch.Tensor):
    # Reshape so the statistics broadcast over (batch, channel, H, W).
    config['img_mean'] = config['img_mean'].view(1, -1, 1, 1).cuda()
    config['img_std'] = config['img_std'].view(1, -1, 1, 1).cuda()

  for name in model_names:
    load(model, name)
    count = 0
    for data, _ in tqdm(data_loader):
      # Ground-truth images for this batch.
      imgs = (data.permute(0, 2, 3, 1).numpy() * 255.0).astype(np.uint8)
      for ii in range(data.shape[0]):
        imageio.imsave(f"{config['gt_img_folder']}/img_{count+ii+1:06d}.png",
                       imgs[ii])

      # Model samples, started from scaled Gaussian noise.
      v_init = torch.randn(config['batch_size'], config['channel'],
                           config['height'],
                           config['width']).cuda() * config['neg_MC_init_var']
      v_list = model.sampling(v_init,
                              num_steps=config['sampling_steps'],
                              save_gap=config['sampling_gap'])

      # Undo the training normalization before converting to 8-bit.
      imgs = ((v_list[-1][1] * config['img_std'] + config['img_mean']).clamp(
          min=0, max=1).cpu().permute(0, 2, 3, 1).numpy() * 255.0).astype(
              np.uint8)

      # Save as many samples as there are real images in this batch.
      for ii in range(data.shape[0]):
        imageio.imsave(f"{config['model_img_folder']}/img_{count+ii+1:06d}.png",
                       imgs[ii])

      count += data.shape[0]

    # Argument list (no shell) so folder names with spaces or shell
    # metacharacters are passed through verbatim; use the interpreter running
    # this script rather than a hard-coded conda path.
    cmd = [
        sys.executable, '-m', 'pytorch_fid', config['model_img_folder'],
        config['gt_img_folder'], '--device', 'cuda:0'
    ]
    completed = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
    output = completed.stdout
    # The epoch sits between the '-' and the '.pt' suffix of the file name.
    epoch = os.path.basename(name).split('-')[1][:-3]
    # The score is parsed from the text after the first ':' of the output.
    FID_score = float(output.decode('UTF-8').split(':')[1].strip())

    IS_score_mean, IS_score_std = inception_score(config['model_img_folder'],
                                                  cuda=True,
                                                  batch_size=32,
                                                  resize=True,
                                                  splits=10)

    score_dict['FID'] += [FID_score]
    score_dict['Epoch'] += [int(epoch)]

    print(
        f'Epoch = {epoch} || FID score = {FID_score} || Inception Score = {IS_score_mean}'
    )
    if FID_score < best_FID_score:
      best_epoch = epoch
      best_FID_score = FID_score
      best_IS_score = IS_score_mean

  print(
      f'Best Epoch = {best_epoch} || Best FID score = {best_FID_score} || Best Inception score = {best_IS_score}'
  )
  print(config['exp_folder'])

  df = pd.DataFrame(data=score_dict)
  sns_plot = sns.lineplot(data=df, x="Epoch", y="FID", sort=True)
  fig = sns_plot.get_figure()
  fig.savefig(f"{config['exp_folder']}/FID_curve.png")
  plt.clf()

# ignored
def parallel_eval_model():
  """Compute FID for every `epoch*` folder in `--path`, one job per GPU.

  Launches up to `len(--gpu_list)` `pytorch_fid` processes at a time, each
  comparing one folder against `data/MNIST/test`. Scores are parsed from the
  first stdout line of each process, written sorted (best first) to
  `test_results.txt` in `--path`, and the best one is printed.
  """

  def last_7chars(x):
    return (x[-7:])

  input_args = parser.parse_args()
  model_names = sorted(glob.glob(os.path.join(input_args.path, 'epoch*')),
                       key=last_7chars)
  num_jobs_parallel = len(input_args.gpu_list)
  results = {}

  for ii in range(0, len(model_names), num_jobs_parallel):
    batch_names = model_names[ii:ii + num_jobs_parallel]
    process_list = [
        subprocess.Popen([
            'python', '-m', 'pytorch_fid', name, 'data/MNIST/test', '--device',
            'cuda:{}'.format(gpu_id)
        ],
                         encoding="utf8",
                         stdout=subprocess.PIPE)
        for name, gpu_id in zip(batch_names, input_args.gpu_list)
    ]

    for name, process in zip(batch_names, process_list):
      # communicate() drains the pipe while waiting; wait() on a piped
      # process can deadlock once the child fills the pipe buffer.
      stdout_text, _ = process.communicate()
      first_line = stdout_text.split('\n', 1)[0]
      # Skip the 4-character prefix in front of the number.
      results[name[-7:]] = float(first_line[4:])

  sorted_results = sorted(results.items(), key=lambda item: item[1])
  with open(os.path.join(input_args.path, 'test_results.txt'), 'w') as ff:
    ff.write('{}\n'.format(input_args.path))
    for epoch, score in sorted_results:
      ff.write('FID score @ Epoch = {} & {}\n'.format(score, epoch))

  if not sorted_results:
    # Nothing matched 'epoch*'; there is no best score to report.
    print('No model folders found in {}'.format(input_args.path))
    return

  print('Best FID score & Epoch = {} & {}'.format(sorted_results[0][1],
                                                  sorted_results[0][0]))


# Script entry point: run the evaluation when this file is executed directly.
if __name__ == '__main__':
  eval_model()