<a href="https://colab.research.google.com/github/karimul/Riset-EBM/blob/main/improved_contrastive_divergence_v5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Mounting to Google Drive

In [None]:
import os
from google.colab import drive

# Mount Google Drive so that everything written below persists across Colab sessions.
drive.mount('/content/drive')

ROOT = "/content/drive/MyDrive/Colab Notebooks"
sample_dir = os.path.join(ROOT, 'improved_contrastive_divergence.v5')

# Create the experiment directory on first use, then run from inside it so
# relative paths (logs, checkpoints, samples) land in the experiment folder.
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
os.chdir(sample_dir)

## Dependencies

In [None]:
!pip install geomloss

In [None]:
from easydict import EasyDict
from tqdm import tqdm
import time
import timeit
import os.path as osp
import pandas as pd
from PIL import Image
import pickle
from imageio import imread
import cv2
import scipy.spatial as ss

import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset
import torchvision
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import MNIST
from torch.nn import Dropout
from torch.optim import Adam, SGD
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torchvision import models


import numpy as np
import random
import matplotlib.pyplot as plt
from scipy import linalg
from math import exp, log
from geomloss import SamplesLoss

from autograd.numpy import sqrt, sin, cos, exp, pi, prod
from autograd.numpy.random import normal

In [None]:
%load_ext tensorboard

## Configuration

In [None]:
flags = EasyDict()

# Configurations for distributed training
flags['slurm'] = False # whether we are on slurm
flags['repel_im'] = True # maximize entropy by repelling images from each other
flags['hmc'] = False # use the hamiltonian monte carlo sampler
flags['sampler'] = 'csgld' # MCMC sampler name (NOTE(review): originally described as an adaptively preconditioned SGLD sampler; confirm)
flags['square_energy'] = False # square the energy output of the models (energy ** 2)
flags['alias'] = False # use the anti-aliased Downsample (blur + stride) instead of AvgPool2d inside CondResBlock
flags['cpu'] = torch.device("cpu")
flags['gpu'] = torch.device("cuda:0")

flags['dataset'] = 'mnist' # dataset name, e.g. mnist, cifar10, celeba
flags['batch_size'] = 128 #128 # batch size during training
flags['multiscale'] = False # A multiscale EBM
flags['self_attn'] = True #Use self attention in models
flags['sigmoid'] = False # Apply sigmoid on energy (can improve the stability)
flags['anneal'] = False # Decrease noise over Langevin steps
flags['data_workers'] = 4 # Number of different data workers to load data in parallel
flags['buffer_size'] = 10000 # presumably the replay-buffer capacity (number of stored samples) -- TODO confirm

# General Experiment Settings
flags['exp'] = 'default' #name of experiments
flags['log_interval'] = 50 #log outputs every so many batches
flags['save_interval'] = 500 # save outputs every so many batches
flags['test_interval'] = 500 # evaluate outputs every so many batches
flags['resume_iter'] = 0 #iteration to resume training from
flags['train'] = True # whether to train or test
flags['transform'] = True # apply data augmentation when sampling from the replay buffer
flags['kl'] = True # apply a KL term to loss
flags['entropy'] = 'kl' # entropy term variant (also appears in the TensorBoard run name)
flags['cuda'] = True # move device on cuda (falls back to CPU below when CUDA is unavailable)
flags['epoch_num'] = 10 # Number of Epochs to train on
flags['ensembles'] = 1 #Number of ensembles to train models with
flags['lr'] = 2e-4 #Learning for training
flags['kl_coeff'] = 1.0 #coefficient for kl

# EBM Specific Experiments Settings
flags['objective'] = 'cd' #use the cd objective

# Setting for MCMC sampling
flags['num_steps'] = 40 # Number of MCMC (Langevin) steps per sampling run during training
flags['step_lr'] = 10.0 # Size of steps for gradient descent
flags['replay_batch'] = True # Use MCMC chains initialized from a replay buffer.
flags['reservoir'] = True # Use a reservoir of past entries
flags['noise_scale'] = 1. # Relative amount of noise for MCMC
flags['init_noise'] = 0.1
flags['momentum'] = 0.9
flags['eps'] = 1e-6
flags['step_size'] = 10

# Architecture Settings
flags['filter_dim'] = 64 #64 #number of filters for conv nets
flags['im_size'] = 32 #32 #size of images
flags['spec_norm'] = False #Whether to use spectral normalization on weights
flags['norm'] = True # normalize inside CondResBlock (InstanceNorm for <=128 filters, GroupNorm otherwise)

# Conditional settings
flags['cond'] = False #conditional generation with the model
flags['all_step'] = False #backprop through all langevin steps
flags['log_grad'] = False #log the gradient norm of the kl term
flags['cond_idx'] = 0 #conditioned index

DIM = 2048 # NOTE(review): presumably the Inception feature size used for FID -- confirm
# Use the configured GPU only when CUDA is requested and actually present.
device = flags.gpu if (flags.cuda and torch.cuda.is_available()) else flags.cpu

In [None]:
# TensorBoard run whose log-dir suffix records the sampler, entropy term and dataset.
writer = SummaryWriter(
    comment=f"_{flags.sampler}_{flags.entropy}_{flags.dataset}"
)

## Utils

In [None]:
# Functions for adaptations with PyTorch:
def to_np_array(*arrays):
    """Convert torch tensors into numpy arrays, passing other objects through.

    Args:
        *arrays: any mix of objects. Every ``torch.Tensor`` (any dtype, any
            device, with or without ``requires_grad``) is detached, moved to
            the CPU and converted with ``.numpy()``. Anything else is
            returned unchanged.

    Returns:
        The single converted object when exactly one argument is given,
        otherwise a list of the converted objects in argument order.
    """
    def _convert(array):
        # isinstance(..., torch.Tensor) also covers legacy ``Variable`` objects.
        # The previous per-type checks (Float/Long/ByteTensor) silently left
        # float64, int32, bool, half, ... tensors unconverted.
        if not isinstance(array, torch.Tensor):
            return array
        tensor = array.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # numpy has no bfloat16; upcast losslessly to float32.
            tensor = tensor.float()
        return tensor.numpy()

    array_list = [_convert(array) for array in arrays]
    if len(array_list) == 1:
        return array_list[0]
    return array_list

In [None]:
def kldiv(x, xp, k=3, base=2):
    """k-nearest-neighbour estimate of KL(p || q) from samples x ~ p(x), xp ~ q(x).

    With Chebyshev (max-norm) distances, rho_k(i) = distance from x_i to its
    k-th nearest neighbour among the other points of x, and nu_k(i) = distance
    from x_i to its k-th nearest neighbour in xp:

        KL ~= [log(m) - log(n - 1) + d * mean(log nu_k) - d * mean(log rho_k)] / log(base)

    Args:
        x: samples from p, shape (n, d) -- a nested list, numpy array or
            torch tensor, e.g. ``[[1.3], [3.7], [5.1], [2.4]]`` for four
            one-dimensional samples.
        xp: samples from q, shape (m, d) with the same d.
        k: neighbour index; requires 1 <= k <= n - 1 and k <= m - 1.
        base: logarithm base of the result (2 -> bits).

    Returns:
        The estimate as a numpy float. Duplicate points give zero distances
        and therefore infinite log terms.

    Raises:
        AssertionError: if the inputs are not 2-D, k is out of range, or the
            two sample sets have different dimensionality.
    """
    def _as_2d_float(samples):
        # Accept torch tensors (any device / requires_grad), arrays and lists.
        if isinstance(samples, torch.Tensor):
            samples = samples.detach().cpu().numpy()
        return np.asarray(samples, dtype=float)

    x, xp = _as_2d_float(x), _as_2d_float(xp)
    assert x.ndim == 2 and xp.ndim == 2, "x and xp must be 2-D: (num_samples, dim)"
    assert 1 <= k <= len(x) - 1, "Set k smaller than num. samples - 1"
    assert k <= len(xp) - 1, "Set k smaller than num. samples - 1"
    assert x.shape[1] == xp.shape[1], "Two distributions must have same dim."

    n, d = x.shape
    m = len(xp)
    const = log(m) - log(n - 1)

    # Query every point at once. Reshaping to (n, -1) keeps k == 1 working:
    # cKDTree.query squeezes the neighbour axis when a single neighbour is
    # requested, which broke the old per-point ``[0][k - 1]`` indexing.
    self_dists = ss.cKDTree(x).query(x, k + 1, p=np.inf)[0]
    rho = np.reshape(self_dists, (n, -1))[:, k]  # column 0 is the point itself (distance 0)
    cross_dists = ss.cKDTree(xp).query(x, k, p=np.inf)[0]
    nu = np.reshape(cross_dists, (n, -1))[:, k - 1]

    return (const + d * np.mean(np.log(nu)) - d * np.mean(np.log(rho))) / log(base)

In [None]:
def swish(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``, applied elementwise."""
    gate = torch.sigmoid(x)
    return x * gate

In [None]:
class WSConv2d(nn.Conv2d):
    """``nn.Conv2d`` with weight standardization applied on every forward pass.

    Each output filter's weights are shifted to zero mean and divided by
    their (unbiased) standard deviation plus 1e-5 before convolving. The
    stored ``self.weight`` is left untouched.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, x):
        w = self.weight
        # Per-filter mean, reduced one axis at a time (in, kh, kw).
        centre = w.mean(dim=1, keepdim=True)
        centre = centre.mean(dim=2, keepdim=True)
        centre = centre.mean(dim=3, keepdim=True)
        w_centered = w - centre

        n_out = w_centered.size(0)
        scale = w_centered.view(n_out, -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        w_standardized = w_centered / scale.expand_as(w_centered)

        return F.conv2d(x, w_standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

In [None]:
def compress_x_mod(x_mod):
    """Quantize an array of values in [0, 1] to uint8 bin indices 0..255.

    Bin ``c`` covers ``[c / 256, (c + 1) / 256)``, the same bins that
    ``decompress_x_mod`` samples from. Values are clipped to [0, 1] first,
    and 1.0 falls into the top bin (255).
    """
    # Scale by 256 (not 255) so the bin index agrees with decompress_x_mod's
    # ``x / 256 + U(0, 1/256)``. With 255-scaling each compress/decompress
    # round trip shrank values toward 0 by roughly a factor 255/256.
    bins = np.floor(256 * np.clip(x_mod, 0, 1))
    return np.clip(bins, 0, 255).astype(np.uint8)


def decompress_x_mod(x_mod):
    """Dequantize uint8 bin indices back to floats in [0, 1].

    Each index ``c`` becomes a uniform random sample from its bin
    ``[c / 256, (c + 1) / 256)`` (uniform dequantization noise).
    """
    return x_mod / 256 + np.random.uniform(0, 1 / 256, x_mod.shape)

In [None]:
def ema_model(models, models_ema, mu=0.99):
    """Update EMA copies in place: ``ema <- mu * ema + (1 - mu) * online``.

    Args:
        models: iterable of online ``nn.Module``s.
        models_ema: iterable of modules with matching parameters. Their
            parameters are overwritten in place. Pairs are formed with
            ``zip``, so any surplus models on either side are ignored.
        mu: decay; values closer to 1 make the EMA weights move more slowly.
    """
    with torch.no_grad():
        for model, model_ema in zip(models, models_ema):
            for param, param_ema in zip(model.parameters(), model_ema.parameters()):
                # copy_ instead of ``.data[:] =``: slicing fails on 0-dim parameters.
                param_ema.copy_(mu * param_ema + (1 - mu) * param)

## Downsample

In [None]:
class Downsample(nn.Module):
    """Anti-aliased downsampling: depthwise binomial blur followed by striding.

    Args:
        pad_type: padding mode accepted by ``get_pad_layer``
            ('reflect'/'refl', 'replicate'/'repl', 'zero').
        filt_size: side length of the square binomial blur kernel; any
            int >= 1 (e.g. 3 -> [1, 2, 1]). ``filt_size == 1`` means plain
            strided subsampling.
        stride: subsampling factor.
        channels: number of input channels (required); the kernel is
            replicated once per channel and applied with ``groups=channels``.
        pad_off: extra padding added on every side.

    Raises:
        ValueError: if ``channels`` is None or ``filt_size`` < 1.
    """
    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        if channels is None:
            raise ValueError('Downsample requires `channels` (number of input channels)')
        if filt_size < 1:
            raise ValueError('filt_size must be >= 1, got %r' % (filt_size,))
        self.filt_size = filt_size
        self.pad_off = pad_off
        # [left, right, top, bottom]; for even filters the extra pixel goes right/bottom.
        lo = (filt_size - 1) // 2
        hi = int(np.ceil((filt_size - 1) / 2.))
        self.pad_sizes = [lo + pad_off, hi + pad_off, lo + pad_off, hi + pad_off]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        # Row (filt_size - 1) of Pascal's triangle via repeated convolution with
        # [1, 1]; reproduces the previous hard-coded tables for sizes 1..7 and
        # also supports larger sizes.
        a = np.array([1.])
        for _ in range(filt_size - 1):
            a = np.convolve(a, [1., 1.])

        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if self.filt_size == 1:
            if self.pad_off == 0:
                return inp[:, :, ::self.stride, ::self.stride]
            return self.pad(inp)[:, :, ::self.stride, ::self.stride]
        return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])

def get_pad_layer(pad_type):
    """Return the ``nn`` padding-module class for a padding name.

    Raises:
        ValueError: for an unrecognized ``pad_type``.
    """
    if pad_type in ['refl', 'reflect']:
        return nn.ReflectionPad2d
    if pad_type in ['repl', 'replicate']:
        return nn.ReplicationPad2d
    if pad_type == 'zero':
        return nn.ZeroPad2d
    # Previously this printed a message and then crashed with UnboundLocalError.
    raise ValueError('Pad type [%s] not recognized' % pad_type)

## Models

In [None]:
class Self_Attn(nn.Module):
    """Self-attention over spatial positions of a feature map.

    Queries/keys are 1x1 projections to ``in_dim // 8`` channels and values a
    1x1 projection to ``in_dim`` channels. The attended features are scaled
    by a learnable ``gamma`` (initialised to 0) and added to the input.
    ``activation`` is stored but not used.
    """
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        reduced = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: feature map of shape (B, C, H, W).
        Returns:
            (gamma * attended + x, attention) where attention is (B, N, N)
            with N = H * W and each row summing to 1.
        """
        batch, channels, dim_a, dim_b = x.size()
        positions = dim_a * dim_b

        queries = self.query_conv(x).view(batch, -1, positions).permute(0, 2, 1)  # B x N x C'
        keys = self.key_conv(x).view(batch, -1, positions)                        # B x C' x N
        attention = self.softmax(torch.bmm(queries, keys))                        # B x N x N
        values = self.value_conv(x).view(batch, -1, positions)                    # B x C x N

        attended = torch.bmm(values, attention.permute(0, 2, 1))
        attended = attended.view(batch, channels, dim_a, dim_b)
        return self.gamma * attended + x, attention

In [None]:
class CondResBlock(nn.Module):
    """Two-conv block with optional FiLM conditioning and optional downsampling.

    conv1 -> norm -> (FiLM from y) -> swish -> conv2 -> norm -> (FiLM from y)
    -> swish, then, when ``downsample`` is set, a 3x3 conv (to ``2 * filters``
    channels if ``rescale`` else ``filters``) followed by stride-2 pooling
    (anti-aliased ``Downsample`` when ``args.alias``, else ``AvgPool2d``) and
    swish.

    Args:
        args: config object; only ``args.alias`` is read here.
        downsample: halve the spatial resolution at the end of the block.
        rescale: when downsampling, output ``2 * filters`` channels.
        filters: input (and inner) channel count.
        latent_dim, im_size: stored on the module; not used by this block.
        classes: width of the conditioning vector ``y`` given to ``forward``.
        norm: normalize after each conv (InstanceNorm for <= 128 filters,
            GroupNorm(32) otherwise); when False no normalization is applied.
        spec_norm: use spectrally-normalized ``nn.Conv2d`` instead of
            ``WSConv2d`` for the two inner convs.
    """
    def __init__(self, args, downsample=True, rescale=True, filters=64, latent_dim=64, im_size=64, classes=512, norm=True, spec_norm=False):
        super(CondResBlock, self).__init__()

        self.filters = filters
        self.latent_dim = latent_dim
        self.im_size = im_size
        self.downsample = downsample

        if filters <= 128:
            self.bn1 = nn.InstanceNorm2d(filters, affine=True)
        else:
            self.bn1 = nn.GroupNorm(32, filters)

        if not norm:
            self.bn1 = None

        self.args = args

        # nn.utils.spectral_norm: a bare ``spectral_norm`` name is not imported
        # in the notebook's import cell, so spec_norm=True used to raise NameError.
        if spec_norm:
            self.conv1 = nn.utils.spectral_norm(nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1))
        else:
            self.conv1 = WSConv2d(filters, filters, kernel_size=3, stride=1, padding=1)

        if filters <= 128:
            self.bn2 = nn.InstanceNorm2d(filters, affine=True)
        else:
            self.bn2 = nn.GroupNorm(32, filters, affine=True)

        if not norm:
            self.bn2 = None

        if spec_norm:
            self.conv2 = nn.utils.spectral_norm(nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1))
        else:
            self.conv2 = WSConv2d(filters, filters, kernel_size=3, stride=1, padding=1)

        # NOTE(review): dropout and relu are not used in forward.
        self.dropout = Dropout(0.2)

        # Conditioning vector -> per-channel (gain, bias) for each conv.
        self.latent_map = nn.Linear(classes, 2*filters)
        self.latent_map_2 = nn.Linear(classes, 2*filters)

        self.relu = torch.nn.ReLU(inplace=True)
        self.act = swish

        if downsample:
            if rescale:
                self.conv_downsample = nn.Conv2d(filters, 2 * filters, kernel_size=3, stride=1, padding=1)

                if args.alias:
                    self.avg_pool = Downsample(channels=2*filters)
                else:
                    self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
            else:
                self.conv_downsample = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)

                if args.alias:
                    self.avg_pool = Downsample(channels=filters)
                else:
                    self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)

    def _film_params(self, layer, y):
        """Split ``layer(y)`` into per-channel (gain, bias), each (B, filters, 1, 1)."""
        latent_map = layer(y).view(-1, 2 * self.filters, 1, 1)
        return latent_map[:, :self.filters], latent_map[:, self.filters:]

    def forward(self, x, y):
        """Apply the block.

        Args:
            x: feature map with ``filters`` channels.
            y: conditioning tensor of width ``classes``, or None for an
                unconditional pass (FiLM is skipped).

        Returns:
            The transformed feature map (downsampled/rescaled if configured).
        """
        x = self.conv1(x)
        if self.bn1 is not None:
            x = self.bn1(x)
        if y is not None:
            gain, bias = self._film_params(self.latent_map, y)
            x = gain * x + bias
        x = self.act(x)

        x = self.conv2(x)
        if self.bn2 is not None:
            x = self.bn2(x)
        if y is not None:
            gain, bias = self._film_params(self.latent_map_2, y)
            x = gain * x + bias
        x = self.act(x)

        # NOTE(review): there is no skip connection -- the block input is not
        # added back (the old code kept an unused ``x_orig``); confirm intended.
        if self.downsample:
            x = self.conv_downsample(x)
            x = self.act(self.avg_pool(x))

        return x

## MNIST Model

In [None]:
class MNISTModel(nn.Module):
    """Energy network for 28x28 single-channel images.

    A 3x3 conv (1 -> filter_dim) is followed by three downsampling
    ``CondResBlock``s, each doubling the channels (ending at 8 * filter_dim).
    Global average pooling and a linear head then give one energy per sample.
    ``args.square_energy`` / ``args.sigmoid`` are applied to the energy in
    the same way as in the other models.
    """
    def __init__(self, args):
        super(MNISTModel, self).__init__()
        self.act = swish

        self.args = args
        self.filter_dim = args.filter_dim
        self.init_main_model()
        self.init_label_map()

        self.cond = args.cond
        self.sigmoid = args.sigmoid

    def init_main_model(self):
        args = self.args
        filter_dim = self.filter_dim
        im_size = 28
        self.conv1 = nn.Conv2d(1, filter_dim, kernel_size=3, stride=1, padding=1)
        # NOTE(review): these blocks use CondResBlock's default classes=512, but
        # label_map below outputs 256 features, so cond=True would fail with a
        # shape mismatch in latent_map. Fixing it changes parameter shapes
        # (checkpoint compatibility), so it is left as-is here.
        self.res1 = CondResBlock(args, filters=filter_dim, latent_dim=1, im_size=im_size)
        self.res2 = CondResBlock(args, filters=2*filter_dim, latent_dim=1, im_size=im_size)

        self.res3 = CondResBlock(args, filters=4*filter_dim, latent_dim=1, im_size=im_size)
        self.energy_map = nn.Linear(filter_dim*8, 1)

    def init_label_map(self):
        # Maps a 10-dim label vector to a 256-dim conditioning vector.
        self.map_fc1 = nn.Linear(10, 256)
        self.map_fc2 = nn.Linear(256, 256)

    def main_model(self, x, latent):
        x = x.view(-1, 1, 28, 28)
        x = self.act(self.conv1(x))
        x = self.res1(x, latent)
        x = self.res2(x, latent)
        x = self.res3(x, latent)
        x = self.act(x)
        x = x.mean(dim=2).mean(dim=2)  # global average pool -> (B, 8 * filter_dim)
        energy = self.energy_map(x)

        # Same optional post-processing as the CelebA / ResNet models; the
        # sigmoid flag used to be stored but silently ignored here.
        if getattr(self.args, 'square_energy', False):
            energy = torch.pow(energy, 2)
        if self.sigmoid:
            energy = torch.sigmoid(energy)

        return energy

    def label_map(self, latent):
        x = self.act(self.map_fc1(latent))
        x = self.map_fc2(x)

        return x

    def forward(self, x, latent):
        """Return energies of shape (B, 1) for images ``x`` (any shape with 784 values per sample)."""
        x = x.view(x.size(0), -1)

        if self.cond:
            latent = self.label_map(latent)
        else:
            latent = None

        return self.main_model(x, latent)

## Standard CNN Model

In [None]:
class StandardCNN(nn.Module):
    def __init__(self):
        super(StandardCNN, self).__init__()
        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))

        self.conv3 = nn.utils.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
        self.conv4 = nn.utils.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))

        self.conv5 = nn.utils.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
        self.conv6 = nn.utils.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))

        self.conv7 = nn.utils.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))

        self.pool = nn.MaxPool2d(2, 2)
        self.act = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.dense = nn.utils.spectral_norm(nn.Linear(512 * 4 * 4, 1))

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        # x = self.pool(x)
        x = self.act(self.conv3(x))
        x = self.act(self.conv4(x))
        # x = self.pool(x)
        x = self.act(self.conv5(x))
        x = self.act(self.conv6(x))
        # x = self.pool(x)
        x = self.act(self.conv7(x))

        x = self.dense(x.view(x.shape[0], -1))

        return x

## CelebA Model

In [None]:
class CelebAModel(nn.Module):
    """Energy network for 3-channel images (CelebA setting).

    ``main_model`` stacks eight ``CondResBlock``s, optionally with
    self-attention after the third stage, globally average-pools the features
    and maps them to one energy per sample. With ``args.multiscale``, two more
    heads score 2x and 4x average-pooled copies of the input, and ``forward``
    returns ``[small, mid, large]`` energies concatenated on the last dim.
    When ``args.cond`` is False the latent is ignored (passed as None).
    """
    def __init__(self, args, debug=False):
        super(CelebAModel, self).__init__()
        self.act = swish
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.cond = args.cond

        self.args = args
        self.init_main_model()

        if args.multiscale:
            self.init_mid_model()
            self.init_small_model()

        # NOTE(review): max_pool, relu, downsample, heir_weight and debug are not
        # used by this class's forward; downsample.filt and heir_weight are in the
        # state_dict, so they are kept for checkpoint compatibility.
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = Downsample(channels=3)
        self.heir_weight = nn.Parameter(torch.Tensor([1.0, 1.0, 1.0]))
        self.debug = debug

    def init_main_model(self):
        args = self.args
        filter_dim = args.filter_dim
        latent_dim = args.filter_dim
        im_size = args.im_size

        self.conv1 = nn.Conv2d(3, filter_dim // 2, kernel_size=3, stride=1, padding=1)

        self.res_1a = CondResBlock(args, filters=filter_dim // 2, latent_dim=latent_dim, im_size=im_size, downsample=True, classes=2, norm=args.norm, spec_norm=args.spec_norm)
        self.res_1b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=False, classes=2, norm=args.norm, spec_norm=args.spec_norm)

        self.res_2a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=True, rescale=False, classes=2, norm=args.norm, spec_norm=args.spec_norm)
        self.res_2b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, classes=2, norm=args.norm, spec_norm=args.spec_norm)

        self.res_3a = CondResBlock(args, filters=2*filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, classes=2, norm=args.norm, spec_norm=args.spec_norm)
        self.res_3b = CondResBlock(args, filters=2*filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, classes=2, norm=args.norm, spec_norm=args.spec_norm)

        self.res_4a = CondResBlock(args, filters=4*filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, classes=2, norm=args.norm, spec_norm=args.spec_norm)
        self.res_4b = CondResBlock(args, filters=4*filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, classes=2, norm=args.norm, spec_norm=args.spec_norm)

        self.self_attn = Self_Attn(4 * filter_dim, self.act)

        self.energy_map = nn.Linear(filter_dim*8, 1)

    def init_mid_model(self):
        args = self.args
        filter_dim = args.filter_dim
        latent_dim = args.filter_dim
        im_size = args.im_size

        self.mid_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)

        self.mid_res_1a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=True, rescale=False, classes=2)
        self.mid_res_1b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=False, classes=2)

        self.mid_res_2a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=True, rescale=False, classes=2)
        self.mid_res_2b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, classes=2)

        self.mid_res_3a = CondResBlock(args, filters=2*filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, classes=2)
        self.mid_res_3b = CondResBlock(args, filters=2*filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, classes=2)

        self.mid_energy_map = nn.Linear(filter_dim*4, 1)
        self.avg_pool = Downsample(channels=3)

    def init_small_model(self):
        args = self.args
        filter_dim = args.filter_dim
        latent_dim = args.filter_dim
        im_size = args.im_size

        self.small_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)

        self.small_res_1a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=True, rescale=False, classes=2)
        self.small_res_1b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=False, classes=2)

        self.small_res_2a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=True, rescale=False, classes=2)
        self.small_res_2b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, classes=2)

        self.small_energy_map = nn.Linear(filter_dim*2, 1)

    def _finalize_energy(self, energy):
        """Apply the optional ``args.square_energy`` / ``args.sigmoid`` transforms."""
        if self.args.square_energy:
            energy = torch.pow(energy, 2)
        if self.args.sigmoid:
            energy = torch.sigmoid(energy)  # F.sigmoid is deprecated
        return energy

    def main_model(self, x, latent):
        x = self.act(self.conv1(x))

        x = self.res_1a(x, latent)
        x = self.res_1b(x, latent)

        x = self.res_2a(x, latent)
        x = self.res_2b(x, latent)

        x = self.res_3a(x, latent)
        x = self.res_3b(x, latent)

        if self.args.self_attn:
            x, _ = self.self_attn(x)

        x = self.res_4a(x, latent)
        x = self.res_4b(x, latent)
        x = self.act(x)

        x = x.mean(dim=2).mean(dim=2)  # global average pool

        x = x.view(x.size(0), -1)
        return self._finalize_energy(self.energy_map(x))

    def mid_model(self, x, latent):
        # Score a 2x-downsampled copy of the input.
        x = F.avg_pool2d(x, 3, stride=2, padding=1)

        x = self.act(self.mid_conv1(x))

        x = self.mid_res_1a(x, latent)
        x = self.mid_res_1b(x, latent)

        x = self.mid_res_2a(x, latent)
        x = self.mid_res_2b(x, latent)

        x = self.mid_res_3a(x, latent)
        x = self.mid_res_3b(x, latent)
        x = self.act(x)

        x = x.mean(dim=2).mean(dim=2)

        x = x.view(x.size(0), -1)
        return self._finalize_energy(self.mid_energy_map(x))

    def small_model(self, x, latent):
        # Score a 4x-downsampled copy of the input.
        x = F.avg_pool2d(x, 3, stride=2, padding=1)
        x = F.avg_pool2d(x, 3, stride=2, padding=1)

        x = self.act(self.small_conv1(x))

        x = self.small_res_1a(x, latent)
        x = self.small_res_1b(x, latent)

        x = self.small_res_2a(x, latent)
        x = self.small_res_2b(x, latent)
        x = self.act(x)

        x = x.mean(dim=2).mean(dim=2)

        x = x.view(x.size(0), -1)
        return self._finalize_energy(self.small_energy_map(x))

    def label_map(self, latent):
        # NOTE(review): map_fc1..map_fc4 are never created in this class, so
        # calling this raises AttributeError; forward() does not call it.
        x = self.act(self.map_fc1(latent))
        x = self.act(self.map_fc2(x))
        x = self.act(self.map_fc3(x))
        x = self.act(self.map_fc4(x))

        return x

    def forward(self, x, latent):
        args = self.args

        if not self.cond:
            latent = None

        energy = self.main_model(x, latent)

        if args.multiscale:
            large_energy = energy
            mid_energy = self.mid_model(x, latent)
            small_energy = self.small_model(x, latent)
            energy = torch.cat([small_energy, mid_energy, large_energy], dim=-1)

        return energy

## ResNet Model

In [None]:
class ResNetModel(nn.Module):
    """Energy network for 3-channel images built from ``CondResBlock``s.

    ``main_model`` runs eight blocks (optional self-attention after the
    second stage), globally average-pools the features and maps them to one
    energy per sample. With ``args.multiscale``, two more heads score 2x and
    4x average-pooled inputs, and ``forward`` returns ``[small, mid, large]``
    energies concatenated on the last dim.
    """
    def __init__(self, args):
        super(ResNetModel, self).__init__()
        self.act = swish

        self.args = args
        self.spec_norm = args.spec_norm
        self.norm = args.norm
        self.init_main_model()

        if args.multiscale:
            self.init_mid_model()
            self.init_small_model()

        # NOTE(review): relu and downsample are not used by forward; downsample's
        # 'filt' buffer is in the state_dict, so it is kept for checkpoints.
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = Downsample(channels=3)

        self.cond = args.cond

    def init_main_model(self):
        args = self.args
        filter_dim = args.filter_dim
        latent_dim = args.filter_dim
        im_size = args.im_size

        self.conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)
        self.res_1a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, spec_norm=self.spec_norm, norm=self.norm)
        self.res_1b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=False, spec_norm=self.spec_norm, norm=self.norm)

        self.res_2a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, spec_norm=self.spec_norm, norm=self.norm)
        self.res_2b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, spec_norm=self.spec_norm, norm=self.norm)

        self.res_3a = CondResBlock(args, filters=2*filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, spec_norm=self.spec_norm, norm=self.norm)
        self.res_3b = CondResBlock(args, filters=2*filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, spec_norm=self.spec_norm, norm=self.norm)

        self.res_4a = CondResBlock(args, filters=4*filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, spec_norm=self.spec_norm, norm=self.norm)
        self.res_4b = CondResBlock(args, filters=4*filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, spec_norm=self.spec_norm, norm=self.norm)

        self.self_attn = Self_Attn(2 * filter_dim, self.act)

        self.energy_map = nn.Linear(filter_dim*8, 1)

    def init_mid_model(self):
        args = self.args
        filter_dim = args.filter_dim
        latent_dim = args.filter_dim
        im_size = args.im_size

        self.mid_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)
        self.mid_res_1a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, spec_norm=self.spec_norm, norm=self.norm)
        self.mid_res_1b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=False, spec_norm=self.spec_norm, norm=self.norm)

        self.mid_res_2a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, spec_norm=self.spec_norm, norm=self.norm)
        self.mid_res_2b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, spec_norm=self.spec_norm, norm=self.norm)

        self.mid_res_3a = CondResBlock(args, filters=2*filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, spec_norm=self.spec_norm, norm=self.norm)
        self.mid_res_3b = CondResBlock(args, filters=2*filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, spec_norm=self.spec_norm, norm=self.norm)

        self.mid_energy_map = nn.Linear(filter_dim*4, 1)
        self.avg_pool = Downsample(channels=3)

    def init_small_model(self):
        args = self.args
        filter_dim = args.filter_dim
        latent_dim = args.filter_dim
        im_size = args.im_size

        self.small_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)
        self.small_res_1a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, spec_norm=self.spec_norm, norm=self.norm)
        self.small_res_1b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=False, spec_norm=self.spec_norm, norm=self.norm)

        self.small_res_2a = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, downsample=False, spec_norm=self.spec_norm, norm=self.norm)
        self.small_res_2b = CondResBlock(args, filters=filter_dim, latent_dim=latent_dim, im_size=im_size, rescale=True, spec_norm=self.spec_norm, norm=self.norm)

        self.small_energy_map = nn.Linear(filter_dim*2, 1)

    def _finalize_energy(self, energy):
        """Apply the optional ``args.square_energy`` / ``args.sigmoid`` transforms."""
        if self.args.square_energy:
            energy = torch.pow(energy, 2)
        if self.args.sigmoid:
            energy = torch.sigmoid(energy)  # F.sigmoid is deprecated
        return energy

    def main_model(self, x, latent, compute_feat=False):
        """Full-resolution energy head.

        When ``compute_feat`` is True, returns the pooled (B, 8 * filter_dim)
        features instead of the energy.
        """
        x = self.act(self.conv1(x))

        x = self.res_1a(x, latent)
        x = self.res_1b(x, latent)

        x = self.res_2a(x, latent)
        x = self.res_2b(x, latent)

        if self.args.self_attn:
            x, _ = self.self_attn(x)

        x = self.res_3a(x, latent)
        x = self.res_3b(x, latent)

        x = self.res_4a(x, latent)
        x = self.res_4b(x, latent)
        x = self.act(x)

        x = x.mean(dim=2).mean(dim=2)  # global average pool

        if compute_feat:
            return x

        x = x.view(x.size(0), -1)
        return self._finalize_energy(self.energy_map(x))

    def mid_model(self, x, latent):
        # Score a 2x-downsampled copy of the input.
        x = F.avg_pool2d(x, 3, stride=2, padding=1)

        x = self.act(self.mid_conv1(x))

        x = self.mid_res_1a(x, latent)
        x = self.mid_res_1b(x, latent)

        x = self.mid_res_2a(x, latent)
        x = self.mid_res_2b(x, latent)

        x = self.mid_res_3a(x, latent)
        x = self.mid_res_3b(x, latent)
        x = self.act(x)

        x = x.mean(dim=2).mean(dim=2)

        x = x.view(x.size(0), -1)
        return self._finalize_energy(self.mid_energy_map(x))

    def small_model(self, x, latent):
        # Score a 4x-downsampled copy of the input.
        x = F.avg_pool2d(x, 3, stride=2, padding=1)
        x = F.avg_pool2d(x, 3, stride=2, padding=1)

        x = self.act(self.small_conv1(x))

        x = self.small_res_1a(x, latent)
        x = self.small_res_1b(x, latent)

        x = self.small_res_2a(x, latent)
        x = self.small_res_2b(x, latent)
        x = self.act(x)

        x = x.mean(dim=2).mean(dim=2)

        x = x.view(x.size(0), -1)
        return self._finalize_energy(self.small_energy_map(x))

    def forward(self, x, latent):
        args = self.args

        if self.cond:
            # NOTE(review): this class never defines ``label_map``, so cond=True
            # raises AttributeError here; confirm the intended label mapping.
            latent = self.label_map(latent)
        else:
            latent = None

        energy = self.main_model(x, latent)

        if args.multiscale:
            large_energy = energy
            mid_energy = self.mid_model(x, latent)
            small_energy = self.small_model(x, latent)

            # Keep the per-scale energies side by side: [small, mid, large].
            energy = torch.cat([small_energy, mid_energy, large_energy], dim=-1)

        return energy

    def compute_feat(self, x, latent):
        """Pooled main-model features for ``x``; ``latent`` is ignored (unconditional)."""
        return self.main_model(x, None, compute_feat=True)

## Replay Buffer

In [None]:
class GaussianBlur(object):
    """Randomly blur an image with a Gaussian kernel, half of the time.

    Args:
        min: lower bound of the uniformly drawn sigma.
        max: upper bound of the uniformly drawn sigma.
        kernel_size: side length of the square kernel given to ``cv2.GaussianBlur``.
    """

    def __init__(self, min=0.1, max=2.0, kernel_size=9):
        self.min, self.max, self.kernel_size = min, max, kernel_size

    def __call__(self, sample):
        image = np.array(sample)

        # First draw decides whether to blur; only values below 0.5 do.
        if np.random.random_sample() >= 0.5:
            return image

        sigma = self.min + (self.max - self.min) * np.random.random_sample()
        ksize = (self.kernel_size, self.kernel_size)
        return cv2.GaussianBlur(image, ksize, sigma)

In [None]:
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of past negative samples.

    Images are written at the cursor ``_next_idx``, which wraps around. Once
    the buffer is full, the oldest entries are overwritten.
    """

    def __init__(self, size, transform, dataset):
        """Create the replay buffer.

        Parameters
        ----------
        size: int
            Max number of images to store. When the buffer overflows, the
            oldest images are overwritten.
        transform: bool
            If truthy, ``sample`` applies dataset-specific augmentation
            (random resized crop, optional flip, color distortion).
        dataset: str
            Dataset name; selects the crop size and the augmentation pipeline.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self.gaussian_blur = GaussianBlur()

        def get_color_distortion(s=1.0):
            # s is the strength of color distortion.
            color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.4*s)
            rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
            rnd_gray = transforms.RandomGrayscale(p=0.2)
            return transforms.Compose([rnd_color_jitter, rnd_gray])

        color_transform = get_color_distortion()

        im_sizes = {
            "cifar10": 32, "celeba": 32, "cats": 32,
            "continual": 64,
            "celebahq": 128, "object": 128, "imagenet": 128, "lsun": 128,
            "mnist": 28, "moving_mnist": 28,
        }
        assert dataset in im_sizes, "Unsupported dataset: {!r}".format(dataset)
        im_size = im_sizes[dataset]

        self.dataset = dataset
        self.transform = None
        if transform and dataset not in ("mnist", "moving_mnist"):
            if dataset == "continual":
                # Milder crops and weaker color distortion, no horizontal flip.
                self.transform = transforms.Compose([
                    transforms.RandomResizedCrop(im_size, scale=(0.7, 1.0)),
                    get_color_distortion(0.1),
                    transforms.ToTensor()])
            else:
                min_scale = 0.01 if dataset in ("imagenet", "object") else 0.08
                self.transform = transforms.Compose([
                    transforms.RandomResizedCrop(im_size, scale=(min_scale, 1.0)),
                    transforms.RandomHorizontalFlip(),
                    color_transform,
                    transforms.ToTensor()])

    def __len__(self):
        return len(self._storage)

    def add(self, ims):
        """Insert a batch of images at the ring cursor.

        Images are written one at a time. This keeps the storage capped at
        ``_maxsize`` even when a batch straddles the capacity boundary while
        the buffer is still filling.
        """
        for im in ims:
            if self._next_idx >= len(self._storage):
                self._storage.append(im)
            else:
                self._storage[self._next_idx] = im
            self._next_idx = (self._next_idx + 1) % self._maxsize

    def _encode_sample(self, idxes, no_transform=False, downsample=False):
        """Gather stored images at ``idxes``, optionally augmented, scaled by 255.

        ``downsample`` is accepted for API compatibility and ignored.
        """
        ims = []
        for i in idxes:
            im = self._storage[i]

            if self.dataset != "mnist":
                if (self.transform is not None) and (not no_transform):
                    # Stored images are CHW; PIL transforms expect HWC.
                    im = im.transpose((1, 2, 0))
                    im = np.array(self.transform(Image.fromarray(np.array(im))))

            im = im * 255
            ims.append(im)
        return np.array(ims)

    def sample(self, batch_size, no_transform=False, downsample=False):
        """Draw ``batch_size`` images uniformly at random, with replacement.

        Parameters
        ----------
        batch_size: int
            Number of images to draw.
        no_transform: bool
            Skip the augmentation pipeline even if one is configured.
        downsample: bool
            Accepted for API compatibility; ignored.

        Returns
        -------
        ims: np.array
            The sampled images multiplied by 255.
        idxes: list of int
            Storage indices of the samples, suitable for ``set_elms``.

        Raises ValueError (from ``random.randint``) if the buffer is empty.
        """
        idxes = [random.randint(0, len(self._storage) - 1)
                 for _ in range(batch_size)]
        return self._encode_sample(idxes, no_transform=no_transform, downsample=downsample), idxes

    def set_elms(self, data, idxes):
        """Write ``data[i]`` back to slot ``idxes[i]``.

        While the buffer is not yet full, the batch is appended via ``add``
        instead and ``idxes`` is ignored.
        """
        if len(self._storage) < self._maxsize:
            self.add(data)
        else:
            for i, ix in enumerate(idxes):
                self._storage[ix] = data[i]

In [None]:
class ReservoirBuffer(object):
    """Fixed-capacity buffer of past negative samples, maintained by reservoir sampling."""

    def __init__(self, size, transform, dataset):
        """Create the reservoir buffer.

        Parameters
        ----------
        size: int
            Max number of images to store. Once full, each new image replaces
            a random slot with probability ``size / n`` (``n`` = images seen).
        transform: bool
            If truthy, ``sample`` applies dataset-specific augmentation.
        dataset: str
            Dataset name; selects the crop size and augmentation pipeline.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self.n = 0  # total number of images ever passed to add()

        def get_color_distortion(s=1.0):
            # s is the strength of color distortion.
            color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.4*s)
            rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
            rnd_gray = transforms.RandomGrayscale(p=0.2)
            return transforms.Compose([rnd_color_jitter, rnd_gray])

        im_sizes = {
            "cifar10": 32, "celeba": 32, "cats": 32,
            "continual": 64,
            "celebahq": 128, "object": 128, "imagenet": 128, "lsun": 128,
            "mnist": 28, "moving_mnist": 28,
            "stl": 48,
        }
        assert dataset in im_sizes, "Unsupported dataset: {!r}".format(dataset)
        im_size = im_sizes[dataset]
        self.dataset = dataset

        # dataset -> (min crop scale, color-distortion strength, blur kernel or None).
        # "celebahq" uses the settings of the old, unreachable second "celeba" branch.
        aug_params = {
            "cifar10": (0.08, 1.0, None),
            "celeba": (0.08, 1.0, None),
            "cats": (0.08, 1.0, None),
            "continual": (0.08, 0.5, 5),
            "celebahq": (0.08, 0.5, 5),
            "imagenet": (0.6, 0.5, 11),
            "lsun": (0.08, 0.5, 5),
            "stl": (0.04, 0.5, 11),
            "object": (0.08, 0.5, None),
        }

        self.transform = None
        if transform and dataset in aug_params:
            min_scale, strength, blur_kernel = aug_params[dataset]
            steps = [
                transforms.RandomResizedCrop(im_size, scale=(min_scale, 1.0)),
                transforms.RandomHorizontalFlip(),
                get_color_distortion(strength),
            ]
            if blur_kernel is not None:
                steps.append(GaussianBlur(kernel_size=blur_kernel))
            steps.append(transforms.ToTensor())
            self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self._storage)

    def add(self, ims):
        """Insert a batch of images using reservoir sampling (Algorithm R).

        The storage is filled up to ``_maxsize``. After that, the n-th image
        seen overwrites slot ``randint(0, n - 1)`` if that index is in range.
        The storage therefore never exceeds ``_maxsize``.
        """
        for im in ims:
            self.n = self.n + 1
            if len(self._storage) < self._maxsize:
                self._storage.append(im)
            else:
                ix = random.randint(0, self.n - 1)
                if ix < self._maxsize:
                    self._storage[ix] = im

        # Kept for attribute compatibility; not used to decide insertion.
        self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize

    def _encode_sample(self, idxes, no_transform=False, downsample=False):
        """Gather stored images at ``idxes``, optionally augmented, scaled by 255.

        ``downsample`` is accepted for API compatibility and ignored.
        """
        ims = []
        for i in idxes:
            im = self._storage[i]

            if self.dataset != "mnist":
                if (self.transform is not None) and (not no_transform):
                    # Stored images are CHW; PIL transforms expect HWC.
                    im = im.transpose((1, 2, 0))
                    im = np.array(self.transform(Image.fromarray(np.array(im))))

            im = im * 255

            ims.append(im)
        return np.array(ims)

    def sample(self, batch_size, no_transform=False, downsample=False):
        """Draw ``batch_size`` images uniformly at random, with replacement.

        Parameters
        ----------
        batch_size: int
            Number of images to draw.
        no_transform: bool
            Skip the augmentation pipeline even if one is configured.
        downsample: bool
            Accepted for API compatibility; ignored.

        Returns
        -------
        ims: np.array
            The sampled images multiplied by 255.
        idxes: list of int
            Storage indices of the samples.

        Raises ValueError (from ``random.randint``) if the buffer is empty.
        """
        idxes = [random.randint(0, len(self._storage) - 1)
                 for _ in range(batch_size)]
        return self._encode_sample(idxes, no_transform=no_transform, downsample=downsample), idxes

## Dataset

In [None]:
class Mnist(Dataset):
    """MNIST served as ``(noise_init, dequantized_image, one_hot_label)`` triples.

    Both images are returned as (28, 28, 1) tensors.
    """

    def __init__(self, train=True, rescale=1.0):
        # `rescale` is accepted for parity with the other datasets but unused.
        self.data = MNIST(
            "data/mnist",
            transform=transforms.ToTensor(),
            download=True, train=train)
        self.labels = np.eye(10)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        side = 28
        image, digit = self.data[index]
        one_hot = self.labels[digit]

        # Map [0, 1] pixels onto the 256-level grid and add uniform dequantization noise.
        pixels = image.squeeze().numpy() / 256 * 255
        dequantized = np.clip(pixels + np.random.uniform(0, 1. / 256, (28, 28)), 0, 1)

        noise_init = np.random.uniform(0, 1, (side, side, 1))
        return torch.Tensor(noise_init), torch.Tensor(dequantized[:, :, None]), one_hot

In [None]:
 class CelebAHQ(Dataset):

    def __init__(self, cond_idx=1, filter_idx=0):
        self.path = "/content/data/celebAHQ/data128x128/{:05}.jpg"
        self.hq_labels = pd.read_csv(os.path.join(sample_dir, "data/celebAHQ/image_list.txt"), sep="\s+")
        self.labels = pd.read_csv(os.path.join(sample_dir, "data/celebAHQ/list_attr_celeba.txt"), sep="\s+", skiprows=1)
        self.cond_idx = cond_idx
        self.filter_idx = filter_idx

    def __len__(self):
        return self.hq_labels.shape[0] 

    def __getitem__(self, index):
        info = self.hq_labels.iloc[index]
        info = self.labels.iloc[info.orig_idx]

        path = self.path.format(index+1)
        im = np.array(Image.open(path))
        image_size = 128
        # im = imresize(im, (image_size, image_size))
        im = im / 256
        im = im + np.random.uniform(0, 1 / 256., im.shape)

        label = int(info.iloc[self.cond_idx])
        if label == -1:
            label = 0
        label = np.eye(2)[label]

        im_corrupt = np.random.uniform(
            0, 1, size=(image_size, image_size, 3))

        return im_corrupt, im, label

In [None]:
class CelebADataset(Dataset):
    """CelebA resized to 32x32, served as ``(noise_init, dequantized_image, label)``.

    Images are returned HWC as float tensors.

    NOTE(review): torchvision's ``CelebA`` returns the attribute vector as the
    target by default, so indexing ``one_hot_map`` with it does not produce a
    one-hot vector -- confirm the intended label.
    """

    def __init__(
            self,
            FLAGS,
            split='train',
            augment=False,
            noise=True,
            rescale=1.0):
        if augment:
            pipeline = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]
        else:
            # Resize the shorter side to 32, then keep the central 32x32 window.
            pipeline = [
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor(),
            ]

        self.data = torchvision.datasets.CelebA(
            "/content/data",
            transform=transforms.Compose(pipeline),
            split=split,
            download=True)
        self.one_hot_map = np.eye(10)
        self.noise = noise  # stored; not read by __getitem__
        self.rescale = rescale
        self.FLAGS = FLAGS

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        side = 32
        image, target = self.data[index]
        hwc = np.transpose(image, (1, 2, 0)).numpy()
        one_hot = self.one_hot_map[target]

        # Scale onto the 256-level grid, apply rescale, then add dequantization noise.
        dequantized = (hwc * 255 / 256) * self.rescale + np.random.uniform(0, 1 / 256., hwc.shape)

        noise_init = np.random.uniform(0.0, self.rescale, (side, side, 3))
        return torch.Tensor(noise_init), torch.Tensor(dequantized), one_hot

In [None]:
class Cats(Dataset):
    """Cat images (ImageFolder at /content/data/cats) served as 32x32 triples.

    Each item is ``(noise_init, dequantized_image, one_hot_label)``, with both
    images laid out HWC.
    """

    def __init__(
            self,
            augment=False,
            noise=True,
            rescale=1.0):
        if augment:
            steps = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]
        else:
            # Resize the shorter side to 32, then keep the central 32x32 window.
            steps = [
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor(),
            ]

        self.data = torchvision.datasets.ImageFolder('/content/data/cats', transform=transforms.Compose(steps))
        self.one_hot_map = np.eye(10)
        self.noise = noise  # stored; not read by __getitem__
        self.rescale = rescale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        side = 32
        image, target = self.data[index]
        hwc = np.transpose(image, (1, 2, 0)).numpy()
        one_hot = self.one_hot_map[target]

        # Scale onto the 256-level grid, apply rescale, then add dequantization noise.
        dequantized = (hwc * 255 / 256) * self.rescale + np.random.uniform(0, 1 / 256., hwc.shape)

        noise_init = np.random.uniform(0.0, self.rescale, (side, side, 3))
        return torch.Tensor(noise_init), torch.Tensor(dequantized), one_hot

In [None]:
class Cifar10(Dataset):
    """CIFAR-10 served as ``(noise_init, dequantized_image, one_hot_label)`` triples.

    With ``full=True``, the test split is appended after ``self.data``, so
    indices past ``len(self.data)`` read from the test split.
    """

    def __init__(
            self,
            FLAGS,
            train=True,
            full=False,
            augment=False,
            noise=True,
            rescale=1.0):
        if augment:
            transform = transforms.Compose([
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ])
        else:
            transform = transforms.ToTensor()

        def load(split_is_train):
            return torchvision.datasets.CIFAR10(
                "./data/cifar10",
                transform=transform,
                train=split_is_train,
                download=True)

        self.full = full
        self.data = load(train)
        self.test_data = load(False)
        self.one_hot_map = np.eye(10)
        self.noise = noise  # stored; not read by __getitem__
        self.rescale = rescale
        self.FLAGS = FLAGS

    def __len__(self):
        extra = len(self.test_data) if self.full else 0
        return len(self.data) + extra

    def __getitem__(self, index):
        side = 32
        n_primary = len(self.data)
        if self.full and index >= n_primary:
            image, target = self.test_data[index - n_primary]
        else:
            image, target = self.data[index]

        hwc = np.transpose(image, (1, 2, 0)).numpy()
        one_hot = self.one_hot_map[target]

        # Scale onto the 256-level grid, apply rescale, then add dequantization noise.
        dequantized = (hwc * 255 / 256) * self.rescale + np.random.uniform(0, 1 / 256., hwc.shape)

        noise_init = np.random.uniform(0.0, self.rescale, (side, side, 3))
        return torch.Tensor(noise_init), torch.Tensor(dequantized), one_hot

## Sampling ##

In [None]:
def stochastic_f(energy):
    """Return ``energy`` as a numpy array plus one shared Gaussian jitter (std 0.32).

    A single scalar draw is broadcast over every element.
    """
    energy_np = energy.detach().cpu().numpy()
    jitter = 0.32 * normal(size=1)
    return energy_np + jitter

In [None]:
def gen_image_csgld(label, FLAGS, model, im_neg, num_steps, sample=False):
    """Langevin sampling with an adaptively rescaled gradient (CSGLD-style).

    The energy range of the current batch is split into ``parts`` sub-regions.
    ``Gcum`` holds running weights for those sub-regions. The gradient
    multiplier ``grad_mul`` is built from the log-ratio of neighbouring
    ``Gcum`` entries. NOTE(review): this appears to follow contour SGLD --
    confirm against the reference algorithm.

    Args:
        label: conditioning passed through to ``model.forward``.
        FLAGS: reads ``anneal``, ``all_step``, ``step_lr`` and ``dataset``.
        model: energy model, called as ``model.forward(x, label)``.
        im_neg: initial negative samples.
        num_steps: number of Langevin steps. With ``num_steps < 1`` the return
            statement hits unassigned locals (``im_neg_kl``, ``im_grad``).
        sample: if True, also collect every intermediate sample and skip the
            extra forward pass on the KL subset.

    Returns:
        ``(im_neg, im_neg_kl, mean_abs_grad)``, or
        ``(im_neg, im_neg_kl, im_negs_samples, mean_abs_grad)`` when ``sample``.
        ``im_neg`` is detached and clamped to [0, 1]. ``im_neg_kl`` is the first
        ``n`` pre-step samples after one more gradient step, clamped to [0, 1].
    """
    im_noise = torch.randn_like(im_neg).detach()

    im_negs_samples = []

    parts = 100  # number of energy sub-regions
    # Initial sub-region weights: linearly decreasing from index 0, summing to 1.
    Gcum = np.array(range(parts, 0, -1)) * 1.0 / sum(range(parts, 0, -1))
    J = parts - 1  # sub-region index of the latest noisy mean energy
    bouncy_move = 0  # number of updates that produced a negative grad_mul
    grad_mul = 1.  # gradient multiplier applied on the *next* Langevin step
    zeta = 0.75
    T = 1
    # Cap on the Gcum step size; never binding because 10./(i**0.8+100) <= 0.1.
    decay_lr = 100.0

    for i in range(num_steps):
        im_noise.normal_()

        # Gaussian perturbation (std 0.001), linearly annealed to 0 by the last step if FLAGS.anneal.
        if FLAGS.anneal:
            im_neg = im_neg + 0.001 * (num_steps - i - 1) / num_steps * im_noise
        else:
            im_neg = im_neg + 0.001 * im_noise

        im_neg.requires_grad_(requires_grad=True)
        energy = model.forward(im_neg, label)
        # print("energy : ", energy)
        # Energy range of the current batch, padded by 1 on each side (recomputed every step).
        lower_bound, upper_bound = np.min(energy.detach().cpu().numpy()) - 1, np.max(energy.detach().cpu().numpy()) + 1
        partition=[lower_bound, upper_bound]

        if FLAGS.all_step:
            # create_graph=True keeps every step's gradient differentiable.
            im_grad = torch.autograd.grad([energy.sum()], [im_neg], create_graph=True)[0]
        else:
            im_grad = torch.autograd.grad([energy.sum()], [im_neg])[0]

        if i == num_steps - 1:
            im_neg_orig = im_neg
            im_neg = im_neg - FLAGS.step_lr * grad_mul * im_grad

            # Size of the subset (im_neg_kl) that gets the final differentiable step.
            # NOTE(review): a dataset not listed here leaves `n` unassigned (UnboundLocalError).
            if FLAGS.dataset in ("cifar10", "celeba", "cats"):
                n = 128
            elif FLAGS.dataset == "celebahq":
                # Save space
                n = 128
            elif FLAGS.dataset == "lsun":
                # Save space
                n = 32
            elif FLAGS.dataset == "object":
                # Save space
                n = 32
            elif FLAGS.dataset == "mnist":
                n = 128
            elif FLAGS.dataset == "imagenet":
                n = 32
            elif FLAGS.dataset == "stl":
                n = 32

            im_neg_kl = im_neg_orig[:n]
            if sample:
                pass
            else:
                # Recompute the gradient on the subset only, with create_graph=True.
                energy = model.forward(im_neg_kl, label)
                im_grad = torch.autograd.grad([energy.sum()], [im_neg_kl], create_graph=True)[0]

            im_neg_kl = im_neg_kl - FLAGS.step_lr * grad_mul * im_grad[:n]
            im_neg_kl = torch.clamp(im_neg_kl, 0, 1)
        else:
            im_neg = im_neg - FLAGS.step_lr * grad_mul * im_grad

        print("\n grad_mul: ", grad_mul)
        # Width of one energy sub-region.
        div_f = (partition[1] - partition[0]) / parts
        # Multiplier for the next step, from the weights of sub-region J (previous step's index) and J-1.
        grad_mul = 1 + zeta * T * (np.log(Gcum[J]) - np.log(Gcum[J-1])) / div_f
      
        # Sub-region index of the noisy mean energy, clipped to [1, parts - 1].
        # NOTE(review): on the last step with sample=False, `energy` here is the KL-subset energy.
        J = (min(max(int((stochastic_f(energy).mean() - partition[0]) / div_f + 1), 1), parts - 1))
        step_size = min(decay_lr, 10./(i**0.8+100))
        # Stochastic-approximation update: pull Gcum[J] toward 1 and every other entry toward 0.
        # The third line uses Gcum[J] *after* the second line has updated it.
        Gcum[:J] = Gcum[:J] + step_size * Gcum[J]**zeta * (-Gcum[:J])
        Gcum[J] = Gcum[J] + step_size * Gcum[J]**zeta * (1 - Gcum[J])
        Gcum[(J+1):] = Gcum[(J+1):] + step_size * Gcum[J]**zeta * (-Gcum[(J+1):])

        if grad_mul < 0:
            bouncy_move = bouncy_move + 1
            print("\n bouncy_move : ", bouncy_move)

        im_neg = im_neg.detach()

        if sample:
            # Stored detached but before clamping to [0, 1].
            im_negs_samples.append(im_neg)

        im_neg = torch.clamp(im_neg, 0, 1)

    if sample:
        return im_neg, im_neg_kl, im_negs_samples, np.abs(im_grad.detach().cpu().numpy()).mean()
    else:
        return im_neg, im_neg_kl, np.abs(im_grad.detach().cpu().numpy()).mean()

In [None]:
def gen_image_resgld(label, FLAGS, model, im_neg, num_steps, sample=False):
    """Two-chain replica-exchange Langevin sampling (ResGLD-style).

    Two copies of ``im_neg`` (``resgld_beta_low`` and ``resgld_beta_high``) are
    driven by the *same* noise draw each step. The low chain uses std
    ``noise_scale = sqrt(2e-6 * FLAGS.step_lr * T)`` and the high chain
    ``noise_scale * T_multiply``. After each step the chains swap if a uniform
    draw falls below ``intensity_r * swap_rate``. ``swap_rate`` is the batch
    mean of ``exp(dT * (E_low - E_high - dT * var))`` using the energies
    computed before the gradient step. Only the low chain is returned.

    NOTE(review): the chains are never detached, so each step's autograd graph
    stays attached to the previous one for the whole run (memory grows with
    ``num_steps``). ``FLAGS.anneal`` and ``FLAGS.all_step`` are not read.

    Args:
        label: conditioning passed through to ``model.forward``.
        FLAGS: reads ``step_lr`` and ``dataset``.
        model: energy model, called as ``model.forward(x, label)``.
        im_neg: initial samples; both chains start from it.
        num_steps: number of steps. With ``num_steps < 1``, ``im_neg_kl`` and
            ``im_grad_low`` are never assigned.
        sample: if True, also collect every intermediate low-chain sample.

    Returns:
        ``(im_neg, im_neg_kl, mean_abs_grad_low)``, or
        ``(im_neg, im_neg_kl, im_negs_samples, mean_abs_grad_low)`` when ``sample``.
        ``im_neg_kl`` only gets the extra gradient step and clamp when
        ``sample`` is False.
    """

    im_noise = torch.randn_like(im_neg).detach()

    T_multiply=0.9
    T = 0.9
    var=0.1
    resgld_beta_high = im_neg
    resgld_beta_low = im_neg
    swaps = 0

    noise_scale = sqrt(2e-6 * FLAGS.step_lr * T)

    print("noise_scale : ", noise_scale)
    print("noise_scale * T_multiply: ", noise_scale* T_multiply)

    im_negs_samples = []

    for i in range(num_steps):
        im_noise.normal_()

        # Both chains receive the same noise draw, scaled differently.
        resgld_beta_low = resgld_beta_low + noise_scale * im_noise
        resgld_beta_high = resgld_beta_high + noise_scale * T_multiply * im_noise

        resgld_beta_high.requires_grad_(requires_grad=True)
        energy_high = model.forward(resgld_beta_high, label)

        resgld_beta_low.requires_grad_(requires_grad=True)
        energy_low = model.forward(resgld_beta_low, label)

        im_grad_low = torch.autograd.grad([energy_low.sum()], [resgld_beta_low])[0]
        im_grad_high = torch.autograd.grad([energy_high.sum()], [resgld_beta_high])[0]
      
        if i == num_steps - 1:
            im_neg_orig = resgld_beta_low
            # NOTE(review): unlike other steps, the high chain's last step omits the T_multiply factor.
            resgld_beta_low = resgld_beta_low - FLAGS.step_lr * im_grad_low 
            resgld_beta_high = resgld_beta_high - FLAGS.step_lr * im_grad_high 

            # Size of the subset (im_neg_kl) that gets the final differentiable step.
            # NOTE(review): a dataset not listed here leaves `n` unassigned (UnboundLocalError).
            if FLAGS.dataset in ("cifar10", "celeba", "cats"):
                n = 128
            elif FLAGS.dataset == "celebahq":
                # Save space
                n = 128
            elif FLAGS.dataset == "lsun":
                # Save space
                n = 32
            elif FLAGS.dataset == "object":
                # Save space
                n = 32
            elif FLAGS.dataset == "mnist":
                n = 128
            elif FLAGS.dataset == "imagenet":
                n = 32
            elif FLAGS.dataset == "stl":
                n = 32

            im_neg_kl = im_neg_orig[:n]
            if sample:
                pass
            else:
                energy = model.forward(im_neg_kl, label)
                im_grad = torch.autograd.grad([energy.sum()], [im_neg_kl], create_graph=True)[0]

                im_neg_kl = im_neg_kl - FLAGS.step_lr * im_grad[:n]
                im_neg_kl = torch.clamp(im_neg_kl, 0, 1)
        else:
            resgld_beta_low = resgld_beta_low - FLAGS.step_lr * im_grad_low
            resgld_beta_high = resgld_beta_high - FLAGS.step_lr * im_grad_high * T_multiply

        # With the defaults, dT = 1/0.9 - 1/0.81 (about -0.123).
        dT = 1 / T - 1 / (T * T_multiply)
        swap_rate = torch.exp(dT * (energy_low - energy_high - dT * var))
        intensity_r = 0.1
        # print("swap_rate", swap_rate)
        swap_rate = swap_rate.mean().item()
        print("swap_rate", swap_rate)
        # NOTE(review): this local name shadows the `random` module inside this function.
        random = np.random.uniform(0, 1)
        print("random", random)
        if random < intensity_r * swap_rate:
            resgld_beta_high, resgld_beta_low = resgld_beta_low, resgld_beta_high
            swaps += 1
            print("swaps : ", swaps)

        im_neg = resgld_beta_low.detach()

        if sample:
            # Stored detached but before clamping to [0, 1].
            im_negs_samples.append(im_neg)

        im_neg = torch.clamp(im_neg, 0, 1)

    if sample:
        return im_neg, im_neg_kl, im_negs_samples, np.abs(im_grad_low.detach().cpu().numpy()).mean()
    else:
        return im_neg, im_neg_kl, np.abs(im_grad_low.detach().cpu().numpy()).mean()

In [None]:
def rescale_im(image):
    """Convert a float image with values nominally in [0, 1] to uint8 pixels.

    Values are clipped to [0, 1] and scaled by 256. The result is capped at 255
    and truncated to ``np.uint8``.
    """
    scaled = np.clip(image, 0, 1) * 256
    return np.clip(scaled, 0, 255).astype(np.uint8)

In [None]:
def _kl_batch_size(dataset):
    """Return how many samples get the final differentiable (KL) step for ``dataset``.

    Raises ValueError for datasets without a configured size.
    """
    if dataset in ("cifar10", "celeba", "cats", "celebahq", "mnist"):
        return 128
    if dataset in ("lsun", "object", "imagenet", "stl"):
        # Larger images: keep fewer samples to save memory.
        return 32
    raise ValueError("gen_image: no KL batch size configured for dataset {!r}".format(dataset))


def gen_image(label, FLAGS, model, im_neg, num_steps, sample=False):
    """Run Langevin dynamics on ``im_neg`` under the energy ``model``.

    Each step adds Gaussian noise (std 0.001, linearly annealed to 0 by the
    last step when ``FLAGS.anneal``) and takes a gradient step of size
    ``FLAGS.step_lr`` down the energy.

    Args:
        label: conditioning passed through to ``model.forward``.
        FLAGS: reads ``anneal``, ``all_step``, ``step_lr`` and ``dataset``.
        model: energy model, called as ``model.forward(x, label)``.
        im_neg: initial negative samples.
        num_steps: number of Langevin steps; must be >= 1.
        sample: if True, also collect every intermediate sample and skip the
            extra forward pass on the KL subset.

    Returns:
        ``(im_neg, im_neg_kl, mean_abs_grad)``, or
        ``(im_neg, im_neg_kl, im_negs_samples, mean_abs_grad)`` when ``sample``.
        ``im_neg`` is detached and clamped to [0, 1]. ``im_neg_kl`` is the first
        ``n`` pre-step samples of the last step after one more gradient step,
        clamped to [0, 1]. It is differentiable when ``sample`` is False.

    Raises:
        ValueError: if ``num_steps < 1`` or ``FLAGS.dataset`` has no configured
            KL batch size. Previously both cases surfaced as UnboundLocalError.
    """
    if num_steps < 1:
        raise ValueError("gen_image: num_steps must be >= 1, got {}".format(num_steps))
    # Resolved up front so an unsupported dataset fails before any model work.
    n = _kl_batch_size(FLAGS.dataset)

    im_noise = torch.randn_like(im_neg).detach()
    im_negs_samples = []

    for i in range(num_steps):
        im_noise.normal_()

        if FLAGS.anneal:
            im_neg = im_neg + 0.001 * (num_steps - i - 1) / num_steps * im_noise
        else:
            im_neg = im_neg + 0.001 * im_noise

        im_neg.requires_grad_(requires_grad=True)
        energy = model.forward(im_neg, label)

        if FLAGS.all_step:
            # create_graph=True keeps every step's gradient differentiable.
            im_grad = torch.autograd.grad([energy.sum()], [im_neg], create_graph=True)[0]
        else:
            im_grad = torch.autograd.grad([energy.sum()], [im_neg])[0]

        if i == num_steps - 1:
            im_neg_orig = im_neg
            im_neg = im_neg - FLAGS.step_lr * im_grad

            im_neg_kl = im_neg_orig[:n]
            if not sample:
                # Recompute the gradient on the subset only, with create_graph=True.
                energy = model.forward(im_neg_kl, label)
                im_grad = torch.autograd.grad([energy.sum()], [im_neg_kl], create_graph=True)[0]

            im_neg_kl = im_neg_kl - FLAGS.step_lr * im_grad[:n]
            im_neg_kl = torch.clamp(im_neg_kl, 0, 1)
        else:
            im_neg = im_neg - FLAGS.step_lr * im_grad

        im_neg = im_neg.detach()

        if sample:
            # Stored detached but before clamping to [0, 1].
            im_negs_samples.append(im_neg)

        im_neg = torch.clamp(im_neg, 0, 1)

    grad_mag = np.abs(im_grad.detach().cpu().numpy()).mean()
    if sample:
        return im_neg, im_neg_kl, im_negs_samples, grad_mag
    return im_neg, im_neg_kl, grad_mag

In [None]:
def gen_image_cycsgld(label, FLAGS, model, im_neg, num_steps, sample=False):
    """Langevin sampling with a cyclical (cosine) step size.

    Step ``i`` uses ``cyc_lr = FLAGS.step_lr * 5 / 2 * (cos(pi * r) + 1)`` with
    ``r = (i % sub_total) / sub_total``. The step size starts at
    ``5 * FLAGS.step_lr`` and decays toward 0 over each cycle of ``sub_total``
    steps. Without ``FLAGS.anneal``, the injected noise has std
    ``sqrt(2 * cyc_lr * T)``. With it, the std is 0.001, linearly annealed to 0.

    NOTE(review): with ``total=1e6`` and ``cycles=20`` a cycle is 50000 steps, and
    ``i`` restarts at 0 on every call, so for short chains ``cyc_lr`` stays
    close to ``5 * FLAGS.step_lr``.

    Args:
        label: conditioning passed through to ``model.forward``.
        FLAGS: reads ``anneal``, ``all_step``, ``step_lr`` and ``dataset``.
        model: energy model, called as ``model.forward(x, label)``.
        im_neg: initial negative samples.
        num_steps: number of steps. With ``num_steps < 1``, ``im_neg_kl`` and
            ``im_grad`` are never assigned.
        sample: if True, also collect every intermediate sample and skip the
            extra forward pass on the KL subset.

    Returns:
        ``(im_neg, im_neg_kl, mean_abs_grad)``, or
        ``(im_neg, im_neg_kl, im_negs_samples, mean_abs_grad)`` when ``sample``.
    """
    im_noise = torch.randn_like(im_neg).detach()
    # total=1000
    # cycles=20
    # sub_total = total / cycles
    # T = 1e-7
    total=1e6
    cycles=20
    sub_total = total / cycles  # steps per cosine cycle
    T = 1e-6  # temperature scaling the injected noise
    
    im_negs_samples = []

    for i in range(num_steps):
        im_noise.normal_()
        iters = i
        # Position within the current cycle, in [0, 1).
        r_remainder = (iters % sub_total) * 1.0 / sub_total
        cyc_lr = FLAGS.step_lr * 5 / 2 * (cos(pi * r_remainder) + 1)
        print("\ncyc_lr", cyc_lr)

        if FLAGS.anneal:
            im_neg = im_neg + 0.001 * (num_steps - i - 1) / num_steps * im_noise
        else:
            # im_neg = im_neg + 0.001 * im_noise
            im_neg = im_neg + sqrt(2 * cyc_lr * T) * im_noise
        print("\nnoise_cyc_lr", sqrt(2 * cyc_lr * T))
        im_neg.requires_grad_(requires_grad=True)
        energy = model.forward(im_neg, label)

        if FLAGS.all_step:
            # create_graph=True keeps every step's gradient differentiable.
            im_grad = torch.autograd.grad([energy.sum()], [im_neg], create_graph=True)[0]
        else:
            im_grad = torch.autograd.grad([energy.sum()], [im_neg])[0]

        if i == num_steps - 1:
            im_neg_orig = im_neg
            im_neg = im_neg - cyc_lr * im_grad

            # Size of the subset (im_neg_kl) that gets the final differentiable step.
            # NOTE(review): a dataset not listed here leaves `n` unassigned (UnboundLocalError).
            if FLAGS.dataset in ("cifar10", "celeba", "cats"):
                n = 128
            elif FLAGS.dataset == "celebahq":
                # Save space
                n = 128
            elif FLAGS.dataset == "lsun":
                # Save space
                n = 32
            elif FLAGS.dataset == "object":
                # Save space
                n = 32
            elif FLAGS.dataset == "mnist":
                n = 128
            elif FLAGS.dataset == "imagenet":
                n = 32
            elif FLAGS.dataset == "stl":
                n = 32

            im_neg_kl = im_neg_orig[:n]
            if sample:
                pass
            else:
                # Recompute the gradient on the subset only, with create_graph=True.
                energy = model.forward(im_neg_kl, label)
                im_grad = torch.autograd.grad([energy.sum()], [im_neg_kl], create_graph=True)[0]

            im_neg_kl = im_neg_kl - cyc_lr * im_grad[:n]
            im_neg_kl = torch.clamp(im_neg_kl, 0, 1)
        else:
            im_neg = im_neg - cyc_lr * im_grad

        im_neg = im_neg.detach()

        if sample:
            # Stored detached but before clamping to [0, 1].
            im_negs_samples.append(im_neg)

        im_neg = torch.clamp(im_neg, 0, 1)

    if sample:
        return im_neg, im_neg_kl, im_negs_samples, np.abs(im_grad.detach().cpu().numpy()).mean()
    else:
        return im_neg, im_neg_kl, np.abs(im_grad.detach().cpu().numpy()).mean()

In [None]:
def gen_image_psgld(label, FLAGS, model, im_neg, num_steps, sample=False):
    """Draw negative samples with the "psgld" Langevin sampler.

    Each step adds Gaussian noise whose std is ``sqrt(square_avg) + FLAGS.eps``
    (scaled by 0.001 and optionally annealed towards zero), takes a gradient
    step on the energy and clamps the result to [0, 1]. On the final step the
    first ``n`` (dataset-dependent) chains are additionally advanced through a
    differentiable step to build ``im_neg_kl`` for the KL loss.

    Args:
        label: conditioning forwarded to ``model.forward``.
        FLAGS: config; reads ``eps``, ``momentum``, ``anneal``, ``all_step``,
            ``step_lr`` and ``dataset``.
        model: energy model; ``model.forward(x, label)`` returns energies.
        im_neg: starting samples.
        num_steps: number of Langevin steps (must be >= 1, otherwise the
            returned ``im_neg_kl`` is never defined).
        sample: when True, also return the per-step (pre-clamp) samples and
            reuse the full-batch gradient for the KL step instead of
            recomputing it.

    Returns:
        ``(im_neg, im_neg_kl, mean_abs_grad)``, or
        ``(im_neg, im_neg_kl, im_negs_samples, mean_abs_grad)`` when ``sample``.

    Raises:
        ValueError: if ``FLAGS.dataset`` has no configured KL batch size.
    """
    # Number of chains kept for the differentiable KL step (smaller for large images to save memory).
    kl_batch_sizes = {
        "cifar10": 128, "celeba": 128, "cats": 128, "celebahq": 128, "mnist": 128,
        "lsun": 32, "object": 32, "imagenet": 32, "stl": 32,
    }
    if FLAGS.dataset not in kl_batch_sizes:
        raise ValueError("No KL batch size configured for dataset {!r}".format(FLAGS.dataset))
    n = kl_batch_sizes[FLAGS.dataset]

    square_avg = torch.zeros_like(im_neg)
    im_negs_samples = []

    for i in range(num_steps):
        # Noise scale from the running second moment; eps keeps it positive while square_avg is 0.
        avg = square_avg.sqrt().add_(FLAGS.eps)
        im_noise = torch.normal(mean=0, std=avg)

        if FLAGS.anneal:
            im_neg = im_neg + 0.001 * (num_steps - i - 1) / num_steps * im_noise
        else:
            im_neg = im_neg + 0.001 * im_noise

        im_neg.requires_grad_(requires_grad=True)
        energy = model.forward(im_neg, label)
        im_grad = torch.autograd.grad([energy.sum()], [im_neg], create_graph=bool(FLAGS.all_step))[0]

        # NOTE(review): pSGLD (Li et al., 2016) accumulates the squared *gradient* here;
        # this accumulates the squared sample values. Kept as-is — confirm intent.
        sample_values = im_neg.detach()
        square_avg.mul_(FLAGS.momentum).addcmul_(sample_values, sample_values, value=1 - FLAGS.momentum)

        if i == num_steps - 1:
            im_neg_orig = im_neg
            # NOTE(review): only this final step divides by the preconditioner `avg`;
            # the intermediate steps below use the raw gradient — confirm intent.
            im_neg = im_neg - FLAGS.step_lr * im_grad / avg

            im_neg_kl = im_neg_orig[:n]
            if not sample:
                # Differentiable re-evaluation on the KL sub-batch so the loss can backprop through it.
                energy = model.forward(im_neg_kl, label)
                im_grad = torch.autograd.grad([energy.sum()], [im_neg_kl], create_graph=True)[0]

            im_neg_kl = im_neg_kl - FLAGS.step_lr * im_grad[:n]
            im_neg_kl = torch.clamp(im_neg_kl, 0, 1)
        else:
            im_neg = im_neg - FLAGS.step_lr * im_grad

        im_neg = im_neg.detach()

        if sample:
            im_negs_samples.append(im_neg)

        im_neg = torch.clamp(im_neg, 0, 1)

    mean_abs_grad = np.abs(im_grad.detach().cpu().numpy()).mean()
    if sample:
        return im_neg, im_neg_kl, im_negs_samples, mean_abs_grad
    return im_neg, im_neg_kl, mean_abs_grad

In [None]:
def gen_image_asgld(label, FLAGS, model, im_neg, num_steps, sample=False):
    """Draw negative samples with the "asgld" Langevin sampler.

    Noise for each step is drawn from ``N(running_mean, running_std)``. These are
    two running statistics of the visited samples; both start at zero, so the
    first step adds no noise. The noise is scaled by 0.001 (optionally annealed)
    before a gradient step on the energy and a clamp to [0, 1]. On the final step
    the first ``n`` (dataset-dependent) chains are advanced through a
    differentiable step to build ``im_neg_kl`` for the KL loss.

    The running statistics are updated from detached samples, so they do not
    build an autograd graph spanning all ``num_steps`` iterations.

    Args:
        label: conditioning forwarded to ``model.forward``.
        FLAGS: config; reads ``anneal``, ``all_step``, ``step_lr`` and ``dataset``.
        model: energy model; ``model.forward(x, label)`` returns energies.
        im_neg: starting samples.
        num_steps: number of Langevin steps (must be >= 1).
        sample: when True, also return the per-step (pre-clamp) samples and
            reuse the full-batch gradient for the KL step.

    Returns:
        ``(im_neg, im_neg_kl, mean_abs_grad)``, or
        ``(im_neg, im_neg_kl, im_negs_samples, mean_abs_grad)`` when ``sample``.

    Raises:
        ValueError: if ``FLAGS.dataset`` has no configured KL batch size.
    """
    # Number of chains kept for the differentiable KL step (smaller for large images to save memory).
    kl_batch_sizes = {
        "cifar10": 128, "celeba": 128, "cats": 128, "celebahq": 128, "mnist": 128,
        "lsun": 32, "object": 32, "imagenet": 32, "stl": 32,
    }
    if FLAGS.dataset not in kl_batch_sizes:
        raise ValueError("No KL batch size configured for dataset {!r}".format(FLAGS.dataset))
    n = kl_batch_sizes[FLAGS.dataset]

    momentum = 0.9
    eps = 1e-6

    running_mean = torch.zeros_like(im_neg.detach())
    running_std = torch.zeros_like(im_neg.detach())
    im_negs_samples = []

    for i in range(num_steps):
        old_mean, old_std = running_mean, running_std

        im_noise = torch.normal(mean=old_mean, std=old_std)

        if FLAGS.anneal:
            im_neg = im_neg + 0.001 * (num_steps - i - 1) / num_steps * im_noise
        else:
            im_neg = im_neg + 0.001 * im_noise

        im_neg.requires_grad_(requires_grad=True)
        energy = model.forward(im_neg, label)
        im_grad = torch.autograd.grad([energy.sum()], [im_neg], create_graph=bool(FLAGS.all_step))[0]

        # Update the statistics from detached values: otherwise (torch.normal being
        # differentiable) each step's noise would hang onto the previous steps' graph.
        # NOTE(review): `running_mean` is a momentum-weighted *sum* (0.9 * m + x), which for a
        # constant input tends to 10x that input rather than to its mean — confirm intended.
        sample_values = im_neg.detach()
        running_mean = old_mean.mul(momentum).add(sample_values)
        variance = (old_std.pow(2).mul(momentum)
                    .addcmul(sample_values - old_mean, sample_values - running_mean, value=1)
                    .add(eps))
        running_std = variance.abs().pow(0.5)

        if i == num_steps - 1:
            im_neg_orig = im_neg
            im_neg = im_neg - FLAGS.step_lr * im_grad

            im_neg_kl = im_neg_orig[:n]
            if not sample:
                # Differentiable re-evaluation on the KL sub-batch so the loss can backprop through it.
                energy = model.forward(im_neg_kl, label)
                im_grad = torch.autograd.grad([energy.sum()], [im_neg_kl], create_graph=True)[0]

            im_neg_kl = im_neg_kl - FLAGS.step_lr * im_grad[:n]
            im_neg_kl = torch.clamp(im_neg_kl, 0, 1)
        else:
            im_neg = im_neg - FLAGS.step_lr * im_grad

        im_neg = im_neg.detach()

        if sample:
            im_negs_samples.append(im_neg)

        im_neg = torch.clamp(im_neg, 0, 1)

    mean_abs_grad = np.abs(im_grad.detach().cpu().numpy()).mean()
    if sample:
        return im_neg, im_neg_kl, im_negs_samples, mean_abs_grad
    return im_neg, im_neg_kl, mean_abs_grad

## Training

In [None]:
def test(model, logger, dataloader):
    pass

In [None]:
def log_tensorboard(data, summary_writer=None):
    """Write one iteration's training metrics and sample images to TensorBoard.

    Args:
        data: dict of metrics as assembled by ``train``. ``data["iter"]`` is used
            as the global step for every entry. ``"length_replay_buffer"`` is
            optional: ``train`` only sets it when a replay buffer is used, so it
            is skipped when absent. All other keys listed below are required.
        summary_writer: object exposing ``add_scalar`` / ``add_images``. Defaults
            to the notebook-global ``writer`` (presumably a ``SummaryWriter``
            created elsewhere).
    """
    if summary_writer is None:
        summary_writer = writer
    step = data["iter"]

    # Absent when training without a replay buffer.
    if "length_replay_buffer" in data:
        summary_writer.add_scalar("replay buffer length", data["length_replay_buffer"], step)

    # (TensorBoard tag, key in `data`) — tags are kept stable so existing runs stay comparable.
    scalar_tags = (
        ("repel loss", "loss_repel"),
        ("batch loss", "loss"),
        ("average loss", "avg_loss"),
        ("KL mean loss", "kl_mean"),
        ("FID", "fid"),
        ("IS mean", "is_mean"),
        ("IS std", "is_std"),
        ("SSIM", "ssim"),
        ("positive energy mean", "e_pos"),
        ("positive energy std", "e_pos_std"),
        ("negative energy mean", "e_neg"),
        ("negative energy std", "e_neg_std"),
        ("energy different", "e_diff"),
        ("x gradient", "x_grad"),
    )
    for tag, key in scalar_tags:
        summary_writer.add_scalar(tag, data[key], step)

    summary_writer.add_images("positive examples", data["positive_samples"], step)
    summary_writer.add_images("negative examples", data["negative_samples"], step)


In [None]:
def _noise_eval_batch(data, label, FLAGS):
    # Build the uniform-noise starting batch (NHWC) and matching labels used on save iterations.
    dataset = FLAGS.dataset
    if dataset in ("cifar10", "celeba", "cats"):
        noise = torch.Tensor(np.random.uniform(0.0, 1.0, (128, 32, 32, 3)))
        repeat = 128 // FLAGS.batch_size + 1
        label = torch.cat([label] * repeat, axis=0)[:128]
        return noise, label

    noise_shapes = {
        "celebahq": (data.shape[0], 128, 128, 3),
        "stl": (32, 48, 48, 3),
        "lsun": (32, 128, 128, 3),
        "imagenet": (32, 128, 128, 3),
        "object": (32, 128, 128, 3),
        "mnist": (128, 28, 28, 1),
    }
    if dataset not in noise_shapes:
        raise ValueError("Unsupported dataset: {!r}".format(dataset))
    shape = noise_shapes[dataset]
    noise = torch.Tensor(np.random.uniform(0.0, 1.0, shape))
    label = label[:shape[0]]
    return noise[:label.shape[0]], label


def _select_negative_sampler(sampler_name):
    # Resolve FLAGS.sampler to its gen_image_* function (only the chosen name is looked up).
    if sampler_name == "psgld":
        return gen_image_psgld
    if sampler_name == "asgld":
        return gen_image_asgld
    if sampler_name == "sgld":
        return gen_image
    if sampler_name == "cycsgld":
        return gen_image_cycsgld
    if sampler_name == "resgld":
        return gen_image_resgld
    if sampler_name == "csgld":
        return gen_image_csgld
    raise ValueError("Unknown sampler: {!r}".format(sampler_name))


def _entropy_repel_loss(im_neg_kl, replay_buffer, dist_sinkhorn, FLAGS):
    # Entropy term comparing the KL negatives against 100 replay-buffer samples.
    """Return ``(loss_repel, weight)``; the caller subtracts ``weight * loss_repel``.

    ``weight`` is 0.3 for ``FLAGS.entropy == 'kl'`` (mean log nearest-neighbour
    distance) and 0.03 for ``'sinkhorn'`` (log Sinkhorn divergence). When there is
    no replay buffer or it holds <= 1000 entries, a zero loss and weight 0.0 are
    returned.

    Raises:
        ValueError: for an unknown ``FLAGS.entropy``.
    """
    if FLAGS.entropy == 'kl':
        weight = 0.3
    elif FLAGS.entropy == 'sinkhorn':
        weight = 0.03
    else:
        raise ValueError("Unknown entropy estimator: {!r}".format(FLAGS.entropy))

    if replay_buffer is None or len(replay_buffer) <= 1000:
        return torch.zeros(1, device=im_neg_kl.device), 0.0

    bs = im_neg_kl.size(0)
    im_flat = torch.clamp(im_neg_kl.contiguous().view(bs, -1), 0, 1)

    if FLAGS.dataset in ("cifar10", "celeba", "cats"):
        compare_batch, _ = replay_buffer.sample(100, no_transform=False)
    else:
        compare_batch, _ = replay_buffer.sample(100, no_transform=False, downsample=True)
    compare_batch = decompress_x_mod(compare_batch)
    compare_batch = torch.Tensor(compare_batch).to(FLAGS.gpu, non_blocking=True)
    compare_flat = compare_batch.view(100, -1)

    if FLAGS.entropy == 'kl':
        dist_matrix = torch.norm(im_flat[:, None, :] - compare_flat[None, :, :], p=2, dim=-1)
        loss_repel = torch.log(dist_matrix.min(dim=1)[0]).mean()
    else:
        dist = dist_sinkhorn(im_flat, compare_flat)
        loss_repel = torch.log(dist).sum()
    return loss_repel, weight


def _mean_param_grad_norm(model):
    # Mean of the per-parameter gradient L2 norms (parameters without a gradient are skipped).
    norms = [torch.norm(param.grad.detach()) for param in model.parameters() if param.grad is not None]
    return torch.mean(torch.stack(norms, dim=0))


def train(model, optimizer, dataloader, logdir, resume_iter, FLAGS, best_inception):
    """Run the contrastive-divergence training loop, logging and checkpointing periodically.

    Each iteration draws negatives with the sampler named by ``FLAGS.sampler``.
    Sampling starts from the corrupted batch, mostly re-initialised from the
    replay buffer when one is in use; on save iterations it starts from uniform
    noise instead. The loss minimised is ``E(pos) - E(neg)`` plus an L2 penalty
    on both energies. When ``FLAGS.kl`` is set, a KL term is added, plus an
    entropy/repel term if ``FLAGS.repel_im`` is also set.

    Args:
        model: energy model; ``model.forward(x, label)`` returns per-sample energies.
        optimizer: optimizer over ``model.parameters()``.
        dataloader: yields ``(data_corrupt, data, label)`` with images in NHWC layout.
        logdir: directory that receives ``model_{itr}.pth`` checkpoints.
        resume_iter: global iteration counter to start from.
        FLAGS: run configuration.
        best_inception: value stored in every checkpoint.

    Raises:
        ValueError: for an unknown sampler, dataset or entropy estimator.
        AssertionError: if the positive energy becomes NaN or exceeds 10 in magnitude.
    """
    replay_buffer = None
    if FLAGS.replay_batch:
        buffer_cls = ReservoirBuffer if FLAGS.reservoir else ReplayBuffer
        replay_buffer = buffer_cls(FLAGS.buffer_size, FLAGS.transform, FLAGS.dataset)

    dist_sinkhorn = SamplesLoss('sinkhorn')
    sampler_fn = _select_negative_sampler(FLAGS.sampler)
    num_steps = FLAGS.num_steps
    itr = resume_iter

    optimizer.zero_grad()

    for epoch in range(FLAGS.epoch_num):
        print("epoch : ", epoch)
        average_loss = 0.0  # running mean of the batch loss over the current epoch
        for epoch_step, (data_corrupt, data, label) in enumerate(tqdm(dataloader)):
            # Check the interval before the modulo so save_interval == 0 disables saving instead of crashing.
            is_save_iter = FLAGS.save_interval != 0 and itr % FLAGS.save_interval == 0

            label = label.float().to(FLAGS.gpu, non_blocking=True)
            data = data.permute(0, 3, 1, 2).float().contiguous()

            if is_save_iter:
                # Generate samples from noise to evaluate sample quality.
                data_corrupt, label = _noise_eval_batch(data, label, FLAGS)

            data_corrupt = data_corrupt.float().permute(0, 3, 1, 2).contiguous()
            data = data.to(FLAGS.gpu, non_blocking=True)
            data_corrupt = data_corrupt.to(FLAGS.gpu, non_blocking=True)

            if replay_buffer is not None and len(replay_buffer) >= FLAGS.batch_size:
                # Re-initialise ~99.9% of the chains from the replay buffer.
                replay_batch, _ = replay_buffer.sample(data_corrupt.size(0))
                replay_batch = decompress_x_mod(replay_batch)
                replay_mask = np.random.uniform(0, 1, data_corrupt.size(0)) > 0.001
                data_corrupt[replay_mask] = torch.Tensor(replay_batch[replay_mask]).to(FLAGS.gpu, non_blocking=True)

            if is_save_iter:
                im_neg, im_neg_kl, im_samples, x_grad = sampler_fn(label, FLAGS, model, data_corrupt, num_steps, sample=True)
            else:
                im_neg, im_neg_kl, x_grad = sampler_fn(label, FLAGS, model, data_corrupt, num_steps)
            data_corrupt = None

            energy_pos = model.forward(data, label[:data.size(0)])
            energy_neg = model.forward(im_neg, label)

            if replay_buffer is not None:
                replay_buffer.add(compress_x_mod(im_neg.detach().cpu().numpy()))

            # Contrastive term plus an L2 penalty keeping both energies near zero.
            loss = energy_pos.mean() - energy_neg.mean()
            loss = loss + (torch.pow(energy_pos, 2).mean() + torch.pow(energy_neg, 2).mean())

            if FLAGS.kl:
                # Freeze the weights for this forward so the KL loss reaches them only through im_neg_kl.
                model.requires_grad_(False)
                loss_kl = model.forward(im_neg_kl, label)
                model.requires_grad_(True)
                loss = loss + FLAGS.kl_coeff * loss_kl.mean()

                if FLAGS.repel_im:
                    loss_repel, repel_weight = _entropy_repel_loss(im_neg_kl, replay_buffer, dist_sinkhorn, FLAGS)
                    if repel_weight:
                        loss = loss - repel_weight * loss_repel
                else:
                    loss_repel = torch.zeros(1, device=im_neg.device)
            else:
                loss_kl = torch.zeros(1, device=im_neg.device)
                loss_repel = torch.zeros(1, device=im_neg.device)

            ml_grad = kl_grad = None
            if FLAGS.log_grad and replay_buffer is not None and len(replay_buffer) > 1000:
                # Diagnostic only: compare gradient magnitudes of the ML and KL objectives.
                loss_kl = (loss_kl - 0.1 * loss_repel).mean()
                loss_ml = energy_pos.mean() - energy_neg.mean()

                loss_ml.backward(retain_graph=True)
                ml_grad = _mean_param_grad_norm(model)
                model.zero_grad()

                loss_kl.backward(retain_graph=True)
                kl_grad = _mean_param_grad_norm(model)
                model.zero_grad()

            loss.backward()
            clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            optimizer.zero_grad()

            if torch.isnan(energy_pos.mean()):
                raise AssertionError("positive energy is NaN at iteration {}".format(itr))
            if torch.abs(energy_pos.mean()) > 10.0:
                raise AssertionError("positive energy diverged (|E| > 10) at iteration {}".format(itr))

            # Use a plain float so no autograd graph is kept alive across iterations.
            loss_value = loss.item()
            average_loss += (loss_value - average_loss) / (epoch_step + 1)

            if FLAGS.log_interval != 0 and itr % FLAGS.log_interval == 0:
                if FLAGS.dataset == "mnist":
                    IS, FID = (0, 0), 0
                else:
                    # NOTE(review): defaults to the cats FID statistics for every dataset;
                    # set FLAGS.fid_stats_path to the matching stats file.
                    fid_stats_path = getattr(FLAGS, 'fid_stats_path', './cats_test.npz')
                    IS, FID = get_inception_score_and_fid(im_neg, fid_stats_path, verbose=True)

                ssim_value = ssim(im_neg.to(FLAGS.gpu, non_blocking=True), data.to(FLAGS.gpu, non_blocking=True))

                kvs = {
                    'fid': FID,
                    'is_mean': IS[0],
                    'is_std': IS[1],
                    'ssim': ssim_value,
                    'e_pos': energy_pos.mean().item(),
                    'e_pos_std': energy_pos.std().item(),
                    'e_neg': energy_neg.mean().item(),
                    'e_neg_std': energy_neg.std().item(),
                    'kl_mean': loss_kl.mean().item(),
                    'loss_repel': loss_repel.mean().item(),
                    'loss': loss_value,
                    'avg_loss': average_loss,
                    'x_grad': x_grad,
                    'iter': itr,
                    'num_steps': num_steps,
                    'positive_samples': data.detach(),
                    'negative_samples': im_neg.detach(),
                }
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                if replay_buffer is not None:
                    kvs['length_replay_buffer'] = len(replay_buffer)
                if ml_grad is not None:
                    kvs['ml_grad'] = ml_grad.item()
                    kvs['kl_grad'] = kl_grad.item()

                log_tensorboard(kvs)

            if is_save_iter:
                ckpt = {'optimizer_state_dict': optimizer.state_dict(),
                        'FLAGS': FLAGS, 'best_inception': best_inception}
                for i in range(FLAGS.ensembles):
                    ckpt['model_state_dict_{}'.format(i)] = model.state_dict()
                torch.save(ckpt, osp.join(logdir, "model_{}.pth".format(itr)))

            itr += 1

In [None]:
def _build_train_dataset(FLAGS):
    # Instantiate the training split for FLAGS.dataset.
    dataset = FLAGS.dataset
    if dataset == "cifar10":
        return Cifar10(FLAGS)
    if dataset == "celeba":
        return CelebADataset(FLAGS)
    if dataset == "cats":
        return Cats()
    if dataset == "stl":
        return STLDataset(FLAGS)
    if dataset == "object":
        return ObjectDataset(FLAGS.cond_idx)
    if dataset == "imagenet":
        return ImageNet()
    if dataset == "mnist":
        return Mnist(train=True)
    if dataset == "celebahq":
        return CelebAHQ(cond_idx=FLAGS.cond_idx)
    if dataset == "lsun":
        return LSUNBed(cond_idx=FLAGS.cond_idx)
    raise ValueError("Unsupported dataset: {!r}".format(dataset))


def _model_class_for(dataset):
    # Map a dataset name to the energy-model class trained on it.
    if dataset in ("cifar10", "celeba", "cats", "stl"):
        return ResNetModel
    if dataset in ("object", "celebahq", "lsun"):
        return CelebAModel
    if dataset == "mnist":
        return MNISTModel
    if dataset == "imagenet":
        return ImagenetModel
    raise ValueError("Unsupported dataset: {!r}".format(dataset))


def _load_training_checkpoint(model_path):
    # Load a checkpoint written by `train`; it pickles the EasyDict FLAGS, so full unpickling is required.
    import inspect
    # torch >= 2.6 defaults to weights_only=True, which rejects the pickled FLAGS.
    # Only use this on checkpoints you produced yourself (pickle can execute code).
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(model_path, weights_only=False)
    return torch.load(model_path)


def main_single(FLAGS):
    """Build the dataset, model and optimizer (optionally resuming) and run `train`.

    When ``FLAGS.resume_iter`` is non-zero, ``model_{resume_iter}.pth`` is read
    from ``<sample_dir>/<exp>/<dataset>``. The run then continues with the FLAGS
    stored in that checkpoint (keeping the requested ``resume_iter``), and the
    optimizer and model states are restored from it.

    Raises:
        ValueError: for an unsupported ``FLAGS.dataset``.
    """
    print("Values of args: ", FLAGS)

    train_dataset = _build_train_dataset(FLAGS)
    train_dataloader = DataLoader(train_dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size, shuffle=True, drop_last=True)

    logdir = osp.join(sample_dir, FLAGS.exp, FLAGS.dataset)

    best_inception = 0.0
    checkpoint = None
    if FLAGS.resume_iter != 0:
        # Load the checkpoint once; it provides FLAGS, best_inception and all states.
        model_path = osp.join(logdir, "model_{}.pth".format(FLAGS.resume_iter))
        checkpoint = _load_training_checkpoint(model_path)
        best_inception = checkpoint['best_inception']
        resume_iter = FLAGS.resume_iter
        FLAGS = checkpoint['FLAGS']
        FLAGS.resume_iter = resume_iter

    model = _model_class_for(FLAGS.dataset)(FLAGS).train()
    if FLAGS.cuda:
        model = model.to(FLAGS.gpu)

    optimizer = Adam(model.parameters(), lr=FLAGS.lr, betas=(0.0, 0.9), eps=1e-8)

    os.makedirs(logdir, exist_ok=True)

    if checkpoint is not None:
        print("FLAGS.resume_iter:", FLAGS.resume_iter)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        for i in range(FLAGS.ensembles):
            model.load_state_dict(checkpoint['model_state_dict_{}'.format(i)])

    print("New Values of args: ", FLAGS)

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of parameters for models", pytorch_total_params)

    train(model, optimizer, train_dataloader, logdir, FLAGS.resume_iter, FLAGS, best_inception)

## Calculate FID and IS

In [None]:
try:
    # Older torchvision releases re-export the download helper here.
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # Not available (newer torchvision, or torchvision absent): use torch's own helper.
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# TF-ported Inception weights used for FID, as published by pytorch-fid.
FID_WEIGHTS_URL = (
    'https://github.com/mseitzer/pytorch-fid/releases/download/'
    'fid_weights/pt_inception-2015-12-05-6726825d.pth'
)

In [None]:
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 returning intermediate feature maps and, optionally, class probabilities.

    Output block indices:
        0 - after the first max pooling (64 features)
        1 - after the second max pooling (192 features)
        2 - input of the auxiliary classifier (768 features)
        3 - after the final average pooling (2048 features)
        4 - softmax over the FC layer, computed with the FC bias removed
    """

    # Default output block: the final average pooling.
    DEFAULT_BLOCK_INDEX = 3

    # Feature dimensionality (or 'prob') -> output block index.
    BLOCK_INDEX_BY_DIM = {
        64: 0,      # first max pooling
        192: 1,     # second max pooling
        768: 2,     # pre-aux-classifier
        2048: 3,    # final average pooling
        'prob': 4,  # softmax layer
    }

    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
        """Assemble the Inception stages needed for the requested outputs.

        Parameters
        ----------
        output_blocks : list of int
            Block indices (0-4, see the class docstring) whose outputs
            ``forward`` returns.
        resize_input : bool
            Bilinearly resize inputs to 299x299 before the network.
        normalize_input : bool
            Map inputs from the range (0, 1) to (-1, 1).
        requires_grad : bool
            Value assigned to ``requires_grad`` of every parameter.
        use_fid_inception : bool
            Use the TF-FID-compatible network from ``fid_inception_v3``
            (recommended for comparable FID scores); otherwise torchvision's
            pretrained Inception v3.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = models.inception_v3(pretrained=True, init_weights=False)

        # Layer names of the Inception network grouped into output blocks 0-3.
        stage_layer_names = (
            ('Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3'),
            ('Conv2d_3b_1x1', 'Conv2d_4a_3x3'),
            ('Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_6a',
             'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e'),
            ('Mixed_7a', 'Mixed_7b', 'Mixed_7c'),
        )
        for stage_idx, layer_names in enumerate(stage_layer_names):
            # Block 0 is always built; later blocks only up to the last requested one.
            if stage_idx > 0 and stage_idx > self.last_needed_block:
                break
            layers = [getattr(inception, name) for name in layer_names]
            if stage_idx <= 1:
                layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
            elif stage_idx == 3:
                layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
            self.blocks.append(nn.Sequential(*layers))

        if self.last_needed_block >= 4:
            self.fc = inception.fc
            # Probabilities are computed from the weight projection only (no bias).
            self.fc.bias = None

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Run the network and collect the requested outputs.

        Parameters
        ----------
        inp : torch.Tensor
            Input of shape Bx3xHxW with values expected in range (0, 1).

        Returns
        -------
        list of torch.Tensor
            One tensor per selected output block, in ascending block order.
        """
        h = inp
        if self.resize_input:
            h = F.interpolate(h, size=(299, 299), mode='bilinear', align_corners=False)
        if self.normalize_input:
            h = 2 * h - 1  # (0, 1) -> (-1, 1)

        features = []
        for block_idx, block in enumerate(self.blocks):
            h = block(h)
            if block_idx in self.output_blocks:
                features.append(h)
            if block_idx == self.last_needed_block:
                break

        if self.last_needed_block >= 4:
            # N x 2048 x 1 x 1 -> N x 2048 -> class probabilities.
            h = torch.flatten(F.dropout(h, training=self.training), 1)
            features.append(F.softmax(self.fc(h), dim=1))

        return features

In [None]:
def fid_inception_v3():
    """Build the pretrained Inception network used for FID computation.

    The FID Inception model uses different weights and a slightly different
    structure from torchvision's Inception. A torchvision Inception is built
    first. Its Mixed_5b..Mixed_7c blocks are then swapped for the FID-patched
    variants. Finally, the FID weights are loaded from ``FID_WEIGHTS_URL``.

    Returns:
        The torchvision Inception3 module with patched blocks and FID weights.
    """
    inception = models.inception_v3(num_classes=1008,
                                    aux_logits=False,
                                    pretrained=False,
                                    init_weights=False)

    # (attribute, replacement) pairs, applied in network order.
    replacements = (
        ('Mixed_5b', FIDInceptionA(192, pool_features=32)),
        ('Mixed_5c', FIDInceptionA(256, pool_features=64)),
        ('Mixed_5d', FIDInceptionA(288, pool_features=64)),
        ('Mixed_6b', FIDInceptionC(768, channels_7x7=128)),
        ('Mixed_6c', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6d', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6e', FIDInceptionC(768, channels_7x7=192)),
        ('Mixed_7b', FIDInceptionE_1(1280)),
        ('Mixed_7c', FIDInceptionE_2(2048)),
    )
    for attr_name, patched_block in replacements:
        setattr(inception, attr_name, patched_block)

    weights = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(weights)
    return inception

In [None]:
class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched for FID computation.

    It matches torchvision's ``InceptionA`` except in the pooling branch. There,
    the average pool excludes padded zeros, as TensorFlow's average pool does.
    """

    def __init__(self, in_channels, pool_features):
        super().__init__(in_channels, pool_features)

    def forward(self, x):
        # 1x1 branch.
        out_1x1 = self.branch1x1(x)

        # 1x1 -> 5x5 branch.
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        # 1x1 -> 3x3 -> 3x3 branch.
        out_3x3dbl = self.branch3x3dbl_3(
            self.branch3x3dbl_2(self.branch3x3dbl_1(x)))

        # Patch: TensorFlow's average pool does not count padded zeros in the
        # average, hence count_include_pad=False.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=False)
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_5x5, out_3x3dbl, out_pool], 1)


In [None]:
class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched for FID computation.

    It matches torchvision's ``InceptionC`` except in the pooling branch. There,
    the average pool excludes padded zeros, as TensorFlow's average pool does.
    """

    def __init__(self, in_channels, channels_7x7):
        super().__init__(in_channels, channels_7x7)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        # Factorised 7x7 branch: three convolutions applied in sequence.
        out_7x7 = x
        for layer in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
            out_7x7 = layer(out_7x7)

        # Double factorised 7x7 branch: five convolutions applied in sequence.
        out_7x7dbl = x
        for layer in (self.branch7x7dbl_1, self.branch7x7dbl_2,
                      self.branch7x7dbl_3, self.branch7x7dbl_4,
                      self.branch7x7dbl_5):
            out_7x7dbl = layer(out_7x7dbl)

        # Patch: TensorFlow's average pool does not count padded zeros in the
        # average, hence count_include_pad=False.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=False)
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_7x7, out_7x7dbl, out_pool], 1)

In [None]:
class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched for FID computation.

    It matches torchvision's ``InceptionE`` except in the pooling branch. There,
    the average pool excludes padded zeros, as TensorFlow's average pool does.
    """

    def __init__(self, in_channels):
        super().__init__(in_channels)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        # 1x1 stem, then parallel 1x3 / 3x1 heads concatenated on channels.
        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3),
                             self.branch3x3_2b(stem_3x3)], 1)

        # Two-conv stem, then parallel heads concatenated on channels.
        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(stem_dbl),
                                self.branch3x3dbl_3b(stem_dbl)], 1)

        # Patch: TensorFlow's average pool does not count padded zeros in the
        # average, hence count_include_pad=False.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=False)
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], 1)

In [None]:
class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched for FID computation.

    It matches torchvision's ``InceptionE`` except that the pooling branch
    uses max pooling instead of average pooling.
    """

    def __init__(self, in_channels):
        super().__init__(in_channels)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        # 1x1 stem, then parallel 1x3 / 3x1 heads concatenated on channels.
        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3),
                             self.branch3x3_2b(stem_3x3)], 1)

        # Two-conv stem, then parallel heads concatenated on channels.
        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(stem_dbl),
                                self.branch3x3dbl_3b(stem_dbl)], 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in that specific Inception
        # implementation, since other Inception models use average pooling
        # here (which matches the description in the paper).
        pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], 1)

In [None]:
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6,
                               use_torch=False):
    """Frechet distance between two multivariate Gaussians.

    For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) the squared distance is

        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1      : Mean of the Inception activations of the generated samples.
    -- sigma1   : Covariance of the activations of the generated samples.
    -- mu2      : Mean of the activations, precalculated on a representative
                  data set.
    -- sigma2   : Covariance of the activations, precalculated on a
                  representative data set.
    -- eps      : Value added to both covariance diagonals when the product is
                  singular (numpy backend only).
    -- use_torch: If True, inputs are torch tensors and the matrix square root
                  comes from 50 Newton-Schulz iterations
                  (``sqrt_newton_schulz``). Otherwise, numpy/scipy is used.

    Returns:
    --   : The Frechet distance. With the torch backend this is a Python float
           (or ``float('nan')`` if the square root contains NaNs). With the
           numpy backend it is a numpy scalar.

    Raises:
    -- AssertionError if the means or covariances have mismatched shapes.
    -- ValueError if the numpy square root has a significant imaginary part.
    """
    if use_torch:
        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'

        mean_delta = mu1 - mu2
        # Matrix square root of sigma1 @ sigma2 via 50 Newton-Schulz steps.
        sqrt_prod = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50)
        if torch.any(torch.isnan(sqrt_prod)):
            return float('nan')
        sqrt_prod = sqrt_prod.squeeze()
        distance = (mean_delta.dot(mean_delta) +
                    torch.trace(sigma1) +
                    torch.trace(sigma2) -
                    2 * torch.trace(sqrt_prod))
        return distance.cpu().item()

    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    mean_delta = mu1 - mu2

    # The product may be close to singular; retry with a small diagonal
    # jitter if the first square root is not finite.
    sqrt_prod, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(sqrt_prod).all():
        print(('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps)
        jitter = np.eye(sigma1.shape[0]) * eps
        sqrt_prod = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

    # Numerical error can leave a small imaginary component; drop it if it is
    # negligible on the diagonal, otherwise fail loudly.
    if np.iscomplexobj(sqrt_prod):
        if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
            largest_imag = np.max(np.abs(sqrt_prod.imag))
            raise ValueError('Imaginary component {}'.format(largest_imag))
        sqrt_prod = sqrt_prod.real

    return (mean_delta.dot(mean_delta) +
            np.trace(sigma1) +
            np.trace(sigma2) -
            2 * np.trace(sqrt_prod))

In [None]:
def get_inception_score_and_fid(
        images,
        fid_stats_path,
        splits=10,
        batch_size=50,
        is_dataloader=False,
        use_torch=False,
        verbose=False):
    """Calculate Inception Score and FID.

    Each image needs only one forward pass. It produces both the features for
    FID and the class probabilities for the Inception Score.

    Args:
        images: A float tensor of images, a list/tuple of per-image tensors
                (stacked into batches here), or a
                torch.utils.data.DataLoader yielding image batches. Values
                must be float in range [0, 1].
        fid_stats_path: str, path to the pre-calculated statistics (an
                ``.npz`` file with ``mu`` and ``sigma`` entries).
        splits: int, number of splits for the Inception Score. Default is 10.
        batch_size: int, batch size for calculating activations. Ignored when
                `images` is a DataLoader; its own batch size is used instead.
        is_dataloader: bool, whether `images` is a DataLoader.
        use_torch: bool. False (default) uses the numpy backend, the same as
                the official implementation. If enabled, the linear algebra
                runs in torch. Results are not guaranteed to match numpy
                exactly, but the computation can be accelerated by the GPU.
        verbose: bool. If falsy, the progress bar is disabled. Otherwise,
                it is shown while activations are computed.
    Returns:
        inception_score: float tuple, (mean, std) over the splits
        fid: float
    """
    if is_dataloader:
        assert isinstance(images, DataLoader)
        num_images = min(len(images.dataset), images.batch_size * len(images))
        batch_size = images.batch_size
    else:
        num_images = len(images)

    # A single model exposes both the 2048-d features and the 'prob' output.
    block_idx1 = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    block_idx2 = InceptionV3.BLOCK_INDEX_BY_DIM['prob']
    model = InceptionV3([block_idx1, block_idx2]).to(device)
    model.eval()

    if use_torch:
        fid_acts = torch.empty((num_images, 2048), device=device)
        is_probs = torch.empty((num_images, 1008), device=device)
    else:
        fid_acts = np.empty((num_images, 2048))
        is_probs = np.empty((num_images, 1008))

    looper = iter(images) if is_dataloader else None
    start = 0
    with tqdm(total=num_images, dynamic_ncols=True, leave=False,
              disable=not verbose,
              desc="get_inception_score_and_fid") as pbar:
        while start < num_images:
            # get a batch of images
            if is_dataloader:
                batch_images = next(looper)
            else:
                batch_images = images[start: start + batch_size]
                # Slicing a Python list/tuple yields a list, which has no
                # .to(); stack the per-image tensors into one batch tensor.
                if isinstance(batch_images, (list, tuple)):
                    batch_images = torch.stack(list(batch_images))
            end = start + len(batch_images)

            # calculate inception features
            batch_images = batch_images.to(device)
            with torch.no_grad():
                pred = model(batch_images)
                if use_torch:
                    fid_acts[start: end] = pred[0].view(-1, 2048)
                    is_probs[start: end] = pred[1]
                else:
                    fid_acts[start: end] = pred[0].view(-1, 2048).cpu().numpy()
                    is_probs[start: end] = pred[1].cpu().numpy()
            start = end
            pbar.update(len(batch_images))

    # Inception Score: exp of the mean KL(p(y|x) || p(y)) within each split.
    scores = []
    for i in range(splits):
        part = is_probs[
            (i * is_probs.shape[0] // splits):
            ((i + 1) * is_probs.shape[0] // splits), :]
        if use_torch:
            kl = part * (
                torch.log(part) -
                torch.log(torch.unsqueeze(torch.mean(part, 0), 0)))
            kl = torch.mean(torch.sum(kl, 1))
            scores.append(torch.exp(kl))
        else:
            kl = part * (
                np.log(part) -
                np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
    if use_torch:
        scores = torch.stack(scores)
        # torch.std defaults to the unbiased (N-1) estimator; np.std uses N.
        # unbiased=False keeps both backends reporting the same statistic.
        is_score = (torch.mean(scores).cpu().item(),
                    torch.std(scores, unbiased=False).cpu().item())
    else:
        is_score = (np.mean(scores), np.std(scores))

    # FID Score. The NpzFile holds an open file handle; the context manager
    # closes it even if a key lookup fails.
    with np.load(fid_stats_path) as f:
        m2, s2 = f['mu'][:], f['sigma'][:]
    if use_torch:
        m1 = torch.mean(fid_acts, axis=0)
        s1 = torch_cov(fid_acts, rowvar=False)
        m2 = torch.tensor(m2).to(m1.dtype).to(device)
        s2 = torch.tensor(s2).to(s1.dtype).to(device)
    else:
        m1 = np.mean(fid_acts, axis=0)
        s1 = np.cov(fid_acts, rowvar=False)
    fid_score = calculate_frechet_distance(m1, s1, m2, s2, use_torch=use_torch)

    del fid_acts, is_probs, scores, model
    return is_score, fid_score


## SSIM

In [None]:
def gaussian(window_size, sigma):
    """Normalized 1-D Gaussian of length ``window_size`` centred at ``window_size // 2``.

    Returns a float32 tensor whose entries sum to 1.
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-(i - center) ** 2 / denom)
                            for i in range(window_size)])
    return weights / weights.sum()

In [None]:
def create_window(window_size, channel, sigma=1.5):
    """Build a per-channel 2-D Gaussian kernel for SSIM.

    Args:
        window_size: side length of the square kernel.
        channel: number of channels. The kernel is replicated once per
            channel for use with a ``groups=channel`` convolution.
        sigma: standard deviation of the Gaussian. Defaults to 1.5, the value
            that was previously hard-coded.

    Returns:
        Float tensor of shape (channel, 1, window_size, window_size).
    """
    # The outer product of the normalized 1-D Gaussian with itself gives a
    # separable 2-D kernel that also sums to 1.
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # torch.autograd.Variable is deprecated and is a no-op wrapper around a
    # tensor, so return the plain tensor instead.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

In [None]:
def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

In [None]:
def ssim(img1, img2, window_size = 11, size_average = True):
    """Structural similarity between two image batches.

    Args:
        img1, img2: 4-D tensors (N, C, H, W). The channel count is read from
            ``img1``. NOTE(review): the SSIM constants in ``_ssim``
            (0.01**2, 0.03**2) correspond to a data range of 1, so inputs
            are presumably expected in [0, 1]; confirm with callers.
        window_size: side length of the Gaussian window.
        size_average: if True, return a scalar mean; otherwise, return one
            value per batch element.

    Returns:
        SSIM as a tensor (scalar or shape (N,)).
    """
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    # Place the window on img1's device with img1's dtype in one call. This
    # works for any device type, unlike the CUDA-only .cuda(get_device())
    # followed by type_as().
    window = window.to(device=img1.device, dtype=img1.dtype)

    return _ssim(img1, img2, window, window_size, channel, size_average)

In [None]:
# Copy/unpack the dataset selected by flags.dataset into /content/data.
# NOTE(review): any other flags.dataset value stages nothing in this cell;
# confirm those datasets are obtained elsewhere.
if flags.dataset == "celebahq":
    # Unpack the CelebA-HQ archive (presumably 128x128 images, per its name).
    # NOTE(review): this source path uses the "improved_contrastive_divergence"
    # folder, not the ".v5" folder used by the other branches; confirm.
    !mkdir -p /content/data/celebAHQ
    !unzip -qq '/content/drive/MyDrive/Colab Notebooks/improved_contrastive_divergence/data/celebAHQ/data128x128.zip' -d /content/data/celebAHQ
elif flags.dataset == "celeba":
    !mkdir -p /content/data
    # %cd into the project folder so the relative "data/celeba/" below resolves.
    %cd /content/drive/MyDrive/Colab Notebooks/improved_contrastive_divergence.v5
    %cp -av data/celeba/ /content/data
elif flags.dataset == "cats":
    !mkdir -p /content/data
    # %cd into the project folder so the relative "data/cats/" below resolves.
    %cd /content/drive/MyDrive/Colab Notebooks/improved_contrastive_divergence.v5
    %cp -av data/cats/ /content/data
    # Extract the copied zip next to itself under /content/data/cats.
    !unzip -qq /content/data/cats/cats-dataset.zip -d /content/data/cats

In [None]:
 tensorboard --logdir runs

In [None]:
# Run the experiment configured in `flags`.
# NOTE(review): main_single is defined elsewhere in this notebook; presumably
# it runs single-process training. Confirm.
main_single(flags)