<a href="https://colab.research.google.com/github/karimul/improved_contrastive_divergence/blob/master/improved_contrastive_divergence_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Mounting Google Drive

In [None]:
from google.colab import drive
import os
# Mount Google Drive at /content/drive.
drive.mount('/content/drive')

ROOT = "/content/drive/MyDrive/Colab Notebooks"
sample_dir = os.path.join(ROOT, 'improved_contrastive_divergence')
# exist_ok=True replaces the separate existence check: it is a no-op when the
# folder is already there and avoids the check-then-create race.
os.makedirs(sample_dir, exist_ok=True)
# Make the project folder the working directory so relative paths resolve inside it.
os.chdir(sample_dir)

## Dependencies

In [None]:
from easydict import EasyDict
import math
from tqdm import tqdm
import time
import timeit
import os.path as osp
import pandas as pd
from PIL import Image
import pickle
from imageio import imread
import cv2

import torch.nn as nn

from torch.utils.data import Dataset
import torchvision
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import MNIST
from torch.nn import Dropout
from torch.optim import Adam, SGD
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_


import numpy as np
import random
import matplotlib.pyplot as plt

In [None]:
%load_ext tensorboard

## Configuration

In [None]:
# Global experiment configuration, accessed as flags.<name> or flags['<name>'].
flags = EasyDict()

# Cluster, sampler and device settings
flags['slurm'] = False # whether we are on slurm
flags['repel_im'] = True # maximize entropy by repelling images from each other
flags['hmc'] = False # use the hamiltonian monte carlo sampler
flags['asgld'] = True # use the adaptively preconditioned SGLD sampler
flags['square_energy'] = False # make the energy square
flags['alias'] = False # use the anti-aliased Downsample layer instead of nn.AvgPool2d in CondResBlock
flags['cpu'] = torch.device("cpu")
flags['gpu'] = torch.device("cuda:0")

flags['dataset'] = 'cats' # dataset name (the replay buffers accept e.g. cifar10, celeba, cats)
flags['batch_size'] = 128 #128 # batch size during training
flags['multiscale'] = True # A multiscale EBM
flags['self_attn'] = True #Use self attention in models
flags['sigmoid'] = False # Apply sigmoid on energy (can improve the stability)
flags['anneal'] = False # Decrease noise over Langevin steps
flags['data_workers'] = 4 # Number of different data workers to load data in parallel
flags['buffer_size'] = 10000 # Size of inputs (presumably the replay-buffer capacity - confirm)

# General Experiment Settings
flags['exp'] = 'default' #name of experiments
flags['log_interval'] = 50 #log outputs every so many batches
flags['save_interval'] = 500 # save outputs every so many batches
flags['test_interval'] = 500 # evaluate outputs every so many batches
flags['resume_iter'] = 0 #iteration to resume training from
flags['train'] = True # whether to train or test
flags['transform'] = True # apply data augmentation when sampling from the replay buffer
flags['kl'] = True # apply a KL term to loss
flags['cuda'] = True # move device on cuda
flags['epoch_num'] = 10000 # Number of Epochs to train on
flags['ensembles'] = 1 #Number of ensembles to train models with
flags['lr'] = 2e-4 #Learning rate for training
flags['kl_coeff'] = 1.0 #coefficient for kl

# EBM Specific Experiments Settings
flags['objective'] = 'cd' #use the cd objective
# The original notebook only defined this stray global; it is kept as an alias
# so any code that reads `flags_objective` keeps working.
flags_objective = flags['objective']

# Setting for MCMC sampling
flags['num_steps'] = 40 # Steps of gradient descent for training
flags['step_lr'] = 100.0 # Size of steps for gradient descent
flags['replay_batch'] = True # Use MCMC chains initialized from a replay buffer.
flags['reservoir'] = True # Use a reservoir of past entries
flags['noise_scale'] = 1. # Relative amount of noise for MCMC
flags['init_noise'] = 0.1
flags['momentum'] = 0.9
flags['eps'] = 1e-6
flags['step_size'] = 10

# Architecture Settings
flags['filter_dim'] = 64 #64 #number of filters for conv nets
flags['im_size'] = 32 #32 #size of images
flags['spec_norm'] = False #Whether to use spectral normalization on weights
flags['norm'] = True #Use normalization layers (instance/group norm) in the residual blocks

# Conditional settings
flags['cond'] = False #conditional generation with the model
flags['all_step'] = False #backprop through all langevin steps
flags['log_grad'] = False #log the gradient norm of the kl term
flags['cond_idx'] = 0 #conditioned index

In [None]:
# TensorBoard writer for this run; the comment suffix is "_ASGLD_<dataset name>".
# NOTE(review): the "ASGLD" tag is hard-coded and does not follow flags['hmc'] / flags['asgld'].
writer = SummaryWriter(comment="_ASGLD_{dataset}".format(dataset=flags.dataset))

## Models

In [None]:
class Self_Attn(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    Queries and keys are produced by 1x1 convolutions with ``in_dim // 8``
    output channels, values by a 1x1 convolution with ``in_dim`` channels.
    The attended values are scaled by the learnable scalar ``gamma``
    (initialised to zero) and added back onto the input.
    """

    def __init__(self, in_dim, activation):
        super().__init__()
        self.chanel_in = in_dim
        self.activation = activation  # stored only; not applied in forward

        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply attention to ``x``.

        Args:
            x: feature map of shape (B, C, W, H).

        Returns:
            A pair ``(out, attention)`` where ``out`` has the shape of ``x``
            and ``attention`` has shape (B, N, N) with N = W * H; each row of
            ``attention`` is a softmax distribution over positions.
        """
        batch, channels, width, height = x.size()
        n_pos = width * height

        queries = self.query_conv(x).view(batch, -1, n_pos).transpose(1, 2)  # (B, N, C//8)
        keys = self.key_conv(x).view(batch, -1, n_pos)  # (B, C//8, N)
        attention = self.softmax(torch.bmm(queries, keys))  # (B, N, N)

        values = self.value_conv(x).view(batch, -1, n_pos)  # (B, C, N)
        attended = torch.bmm(values, attention.transpose(1, 2))
        attended = attended.view(batch, channels, width, height)

        return self.gamma * attended + x, attention

In [None]:
class CondResBlock(nn.Module):
    """Two-convolution block with optional conditioning and optional downsampling.

    Each of the two stages is: 3x3 convolution, optional normalisation,
    optional per-channel gain/bias computed from the conditioning vector
    ``y``, then the ``swish`` activation. When ``downsample`` is set, a
    further 3x3 convolution (doubling the channels if ``rescale``) and a
    stride-2 pooling layer follow.

    NOTE(review): despite the name, the block input is never added to the
    output (no skip connection) - confirm this is intended.

    Args:
        args: configuration object; only ``args.alias`` is read here.
        downsample: append the conv + stride-2 pooling stage.
        rescale: with ``downsample``, double the channel count.
        filters: number of input channels.
        latent_dim, im_size: stored on the instance, otherwise unused here.
        classes: size of the conditioning vector ``y``.
        norm: keep the normalisation layers.
        spec_norm: wrap the convolutions in ``spectral_norm`` instead of
            using ``WSConv2d``.
    """

    def __init__(self, args, downsample=True, rescale=True, filters=64, latent_dim=64, im_size=64, classes=512, norm=True, spec_norm=False):
        super().__init__()

        self.filters = filters
        self.latent_dim = latent_dim
        self.im_size = im_size
        self.downsample = downsample

        def build_norm():
            # Instance norm up to 128 channels, 32-group group norm above that.
            if filters <= 128:
                return nn.InstanceNorm2d(filters, affine=True)
            return nn.GroupNorm(32, filters, affine=True)

        def build_conv():
            if spec_norm:
                return spectral_norm(nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1))
            return WSConv2d(filters, filters, kernel_size=3, stride=1, padding=1)

        # Layers are created in the same order as before (norm, conv, norm,
        # conv, ...); the norm is built first and then cleared when disabled.
        self.bn1 = build_norm()
        if not norm:
            self.bn1 = None

        self.args = args
        self.conv1 = build_conv()

        self.bn2 = build_norm()
        if not norm:
            self.bn2 = None

        self.conv2 = build_conv()

        self.dropout = Dropout(0.2)  # created but not used in forward

        # Linear maps from the conditioning vector to (gain, bias) per channel.
        self.latent_map = nn.Linear(classes, 2 * filters)
        self.latent_map_2 = nn.Linear(classes, 2 * filters)

        self.relu = torch.nn.ReLU(inplace=True)  # created but not used in forward
        self.act = swish

        if downsample:
            out_filters = 2 * filters if rescale else filters
            self.conv_downsample = nn.Conv2d(filters, out_filters, kernel_size=3, stride=1, padding=1)

            if args.alias:
                self.avg_pool = Downsample(channels=out_filters)
            else:
                self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x, y):
        """Run the block on ``x``; ``y`` may be None to skip conditioning."""
        stages = (
            (self.conv1, self.bn1, self.latent_map),
            (self.conv2, self.bn2, self.latent_map_2),
        )

        for conv, norm_layer, film in stages:
            conditioned = y is not None
            if conditioned:
                # First half of the mapped vector scales, second half shifts.
                scale_shift = film(y).view(-1, 2 * self.filters, 1, 1)
                gain = scale_shift[:, :self.filters]
                bias = scale_shift[:, self.filters:]

            x = conv(x)

            if norm_layer is not None:
                x = norm_layer(x)

            if conditioned:
                x = gain * x + bias

            x = self.act(x)

        if not self.downsample:
            return x

        return self.act(self.avg_pool(self.conv_downsample(x)))

## MNIST Model

In [None]:
class MNISTModel(nn.Module):
    """Energy model for 28x28 single-channel images.

    A stem convolution is followed by three ``CondResBlock`` stages; the
    result is averaged over both spatial dimensions and mapped to one energy
    value per sample by a linear layer.

    NOTE(review): ``label_map`` produces 256 features while the residual
    blocks are built with CondResBlock's default ``classes`` - verify the two
    agree before enabling ``args.cond``.
    """

    def __init__(self, args):
        super().__init__()
        self.act = swish

        self.args = args
        self.filter_dim = args.filter_dim
        self.init_main_model()
        self.init_label_map()

        self.cond = args.cond
        self.sigmoid = args.sigmoid

    def init_main_model(self):
        """Create the stem convolution, the three residual stages and the energy head."""
        width = self.filter_dim
        side = 28
        self.conv1 = nn.Conv2d(1, width, kernel_size=3, stride=1, padding=1)
        self.res1 = CondResBlock(self.args, filters=width, latent_dim=1, im_size=side)
        self.res2 = CondResBlock(self.args, filters=2 * width, latent_dim=1, im_size=side)

        self.res3 = CondResBlock(self.args, filters=4 * width, latent_dim=1, im_size=side)
        self.energy_map = nn.Linear(width * 8, 1)

    def init_label_map(self):
        """Create the two linear layers that embed a 10-dim label vector."""
        self.map_fc1 = nn.Linear(10, 256)
        self.map_fc2 = nn.Linear(256, 256)

    def main_model(self, x, latent):
        """Return the energy of ``x`` (reshaped to (-1, 1, 28, 28)) given ``latent``."""
        h = self.act(self.conv1(x.view(-1, 1, 28, 28)))
        for stage in (self.res1, self.res2, self.res3):
            h = stage(h, latent)
        h = self.act(h)
        # Average over the two spatial dimensions, one at a time.
        pooled = h.mean(dim=2).mean(dim=2)
        return self.energy_map(pooled)

    def label_map(self, latent):
        """Embed the label vector: linear, activation, linear."""
        hidden = self.act(self.map_fc1(latent))
        return self.map_fc2(hidden)

    def forward(self, x, latent):
        """Return energies for the batch ``x``; ``latent`` is used only when ``cond`` is set."""
        flat = x.view(x.size(0), -1)
        latent = self.label_map(latent) if self.cond else None
        return self.main_model(flat, latent)

## CelebA Model

In [None]:
class CelebAModel(nn.Module):
    """Energy model with a full-resolution branch and optional mid/small branches.

    ``forward`` returns the main-branch energy, or - when ``args.multiscale``
    is set - the small, mid and large energies concatenated along the last
    dimension (in that order).

    Layers are created in the same order and under the same attribute names as
    before, so existing state dicts still load.

    NOTE(review): ``label_map`` reads ``map_fc1``..``map_fc4``, which this class
    never creates; ``forward`` does not call it.
    NOTE(review): only the main branch passes ``args.norm`` / ``args.spec_norm``
    to its blocks; the mid and small branches use CondResBlock's defaults.
    """

    def __init__(self, args, debug=False):
        super(CelebAModel, self).__init__()
        self.act = swish
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.cond = args.cond

        self.args = args
        self.init_main_model()

        if args.multiscale:
            self.init_mid_model()
            self.init_small_model()

        # The layers below are not used by forward but are kept so the set of
        # registered modules/parameters is unchanged.
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = Downsample(channels=3)
        self.heir_weight = nn.Parameter(torch.Tensor([1.0, 1.0, 1.0]))
        self.debug = debug

    def _block(self, filters, **kwargs):
        """Build a CondResBlock with the settings shared by every branch."""
        args = self.args
        return CondResBlock(args, filters=filters, latent_dim=args.filter_dim, im_size=args.im_size, classes=2, **kwargs)

    def _global_pool(self, x):
        """Apply the activation, then average over both spatial dimensions."""
        x = self.act(x)
        return x.mean(dim=2).mean(dim=2)

    def _to_energy(self, feats, energy_map):
        """Map pooled features to an energy, applying the optional square / sigmoid."""
        feats = feats.view(feats.size(0), -1)
        energy = energy_map(feats)

        if self.args.square_energy:
            energy = torch.pow(energy, 2)

        if self.args.sigmoid:
            # torch.sigmoid is the non-deprecated spelling of F.sigmoid.
            energy = torch.sigmoid(energy)

        return energy

    def init_main_model(self):
        """Create the layers of the full-resolution branch."""
        args = self.args
        filter_dim = args.filter_dim
        norm_kw = dict(norm=args.norm, spec_norm=args.spec_norm)

        self.conv1 = nn.Conv2d(3, filter_dim // 2, kernel_size=3, stride=1, padding=1)

        self.res_1a = self._block(filter_dim // 2, downsample=True, **norm_kw)
        self.res_1b = self._block(filter_dim, rescale=False, **norm_kw)

        self.res_2a = self._block(filter_dim, downsample=True, rescale=False, **norm_kw)
        self.res_2b = self._block(filter_dim, rescale=True, **norm_kw)

        self.res_3a = self._block(2 * filter_dim, downsample=False, **norm_kw)
        self.res_3b = self._block(2 * filter_dim, rescale=True, **norm_kw)

        self.res_4a = self._block(4 * filter_dim, downsample=False, **norm_kw)
        self.res_4b = self._block(4 * filter_dim, rescale=True, **norm_kw)

        self.self_attn = Self_Attn(4 * filter_dim, self.act)

        self.energy_map = nn.Linear(filter_dim * 8, 1)

    def init_mid_model(self):
        """Create the layers of the branch that sees the input pooled once."""
        filter_dim = self.args.filter_dim

        self.mid_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)

        self.mid_res_1a = self._block(filter_dim, downsample=True, rescale=False)
        self.mid_res_1b = self._block(filter_dim, rescale=False)

        self.mid_res_2a = self._block(filter_dim, downsample=True, rescale=False)
        self.mid_res_2b = self._block(filter_dim, rescale=True)

        self.mid_res_3a = self._block(2 * filter_dim, downsample=False)
        self.mid_res_3b = self._block(2 * filter_dim, rescale=True)

        self.mid_energy_map = nn.Linear(filter_dim * 4, 1)
        self.avg_pool = Downsample(channels=3)  # not used by forward

    def init_small_model(self):
        """Create the layers of the branch that sees the input pooled twice."""
        filter_dim = self.args.filter_dim

        self.small_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)

        self.small_res_1a = self._block(filter_dim, downsample=True, rescale=False)
        self.small_res_1b = self._block(filter_dim, rescale=False)

        self.small_res_2a = self._block(filter_dim, downsample=True, rescale=False)
        self.small_res_2b = self._block(filter_dim, rescale=True)

        self.small_energy_map = nn.Linear(filter_dim * 2, 1)

    def main_model(self, x, latent):
        """Energy of the full-resolution input."""
        x = self.act(self.conv1(x))

        for block in (self.res_1a, self.res_1b, self.res_2a, self.res_2b, self.res_3a, self.res_3b):
            x = block(x, latent)

        if self.args.self_attn:
            x, _ = self.self_attn(x)

        x = self.res_4a(x, latent)
        x = self.res_4b(x, latent)

        return self._to_energy(self._global_pool(x), self.energy_map)

    def mid_model(self, x, latent):
        """Energy of the input after one stride-2 average pooling."""
        x = F.avg_pool2d(x, 3, stride=2, padding=1)

        x = self.act(self.mid_conv1(x))

        for block in (self.mid_res_1a, self.mid_res_1b, self.mid_res_2a, self.mid_res_2b, self.mid_res_3a, self.mid_res_3b):
            x = block(x, latent)

        return self._to_energy(self._global_pool(x), self.mid_energy_map)

    def small_model(self, x, latent):
        """Energy of the input after two stride-2 average poolings."""
        x = F.avg_pool2d(x, 3, stride=2, padding=1)
        x = F.avg_pool2d(x, 3, stride=2, padding=1)

        x = self.act(self.small_conv1(x))

        for block in (self.small_res_1a, self.small_res_1b, self.small_res_2a, self.small_res_2b):
            x = block(x, latent)

        return self._to_energy(self._global_pool(x), self.small_energy_map)

    def label_map(self, latent):
        """Pass ``latent`` through map_fc1..map_fc4 with activations.

        NOTE(review): these four layers are not defined anywhere in this
        class, so calling this method fails unless they are added elsewhere.
        """
        x = self.act(self.map_fc1(latent))
        x = self.act(self.map_fc2(x))
        x = self.act(self.map_fc3(x))
        x = self.act(self.map_fc4(x))

        return x

    def forward(self, x, latent):
        """Return energies for ``x``.

        ``latent`` is passed to the blocks unchanged when ``cond`` is set and
        replaced by None otherwise.
        """
        if not self.cond:
            latent = None

        energy = self.main_model(x, latent)

        if self.args.multiscale:
            large_energy = energy
            mid_energy = self.mid_model(x, latent)
            small_energy = self.small_model(x, latent)
            energy = torch.cat([small_energy, mid_energy, large_energy], dim=-1)

        return energy

## ResNet Model

In [None]:
class ResNetModel(nn.Module):
    """Energy model with a full-resolution branch and optional mid/small branches.

    ``forward`` returns the main-branch energy, or - when ``args.multiscale``
    is set - the small, mid and large energies concatenated along the last
    dimension (in that order).

    Layers are created in the same order and under the same attribute names as
    before, so existing state dicts still load.

    NOTE(review): with ``args.cond`` set, ``forward`` calls ``self.label_map``,
    which this class does not define; conditional use fails unless it is
    provided elsewhere (e.g. by a subclass).
    """

    def __init__(self, args):
        super(ResNetModel, self).__init__()
        self.act = swish

        self.args = args
        self.spec_norm = args.spec_norm
        self.norm = args.norm
        self.init_main_model()

        if args.multiscale:
            self.init_mid_model()
            self.init_small_model()

        # Not used by forward; kept so the registered modules are unchanged.
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = Downsample(channels=3)

        self.cond = args.cond

    def _block(self, filters, **kwargs):
        """Build a CondResBlock with the settings shared by every branch."""
        args = self.args
        return CondResBlock(args, filters=filters, latent_dim=args.filter_dim, im_size=args.im_size, spec_norm=self.spec_norm, norm=self.norm, **kwargs)

    def _global_pool(self, x):
        """Apply the activation, then average over both spatial dimensions."""
        x = self.act(x)
        return x.mean(dim=2).mean(dim=2)

    def _to_energy(self, feats, energy_map):
        """Map pooled features to an energy, applying the optional square / sigmoid."""
        feats = feats.view(feats.size(0), -1)
        energy = energy_map(feats)

        if self.args.square_energy:
            energy = torch.pow(energy, 2)

        if self.args.sigmoid:
            # torch.sigmoid is the non-deprecated spelling of F.sigmoid.
            energy = torch.sigmoid(energy)

        return energy

    def init_main_model(self):
        """Create the layers of the full-resolution branch."""
        filter_dim = self.args.filter_dim

        self.conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)
        self.res_1a = self._block(filter_dim, downsample=False)
        self.res_1b = self._block(filter_dim, rescale=False)

        self.res_2a = self._block(filter_dim, downsample=False)
        self.res_2b = self._block(filter_dim, rescale=True)

        self.res_3a = self._block(2 * filter_dim, downsample=False)
        self.res_3b = self._block(2 * filter_dim, rescale=True)

        self.res_4a = self._block(4 * filter_dim, downsample=False)
        self.res_4b = self._block(4 * filter_dim, rescale=True)

        self.self_attn = Self_Attn(2 * filter_dim, self.act)

        self.energy_map = nn.Linear(filter_dim * 8, 1)

    def init_mid_model(self):
        """Create the layers of the branch that sees the input pooled once."""
        filter_dim = self.args.filter_dim

        self.mid_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)
        self.mid_res_1a = self._block(filter_dim, downsample=False)
        self.mid_res_1b = self._block(filter_dim, rescale=False)

        self.mid_res_2a = self._block(filter_dim, downsample=False)
        self.mid_res_2b = self._block(filter_dim, rescale=True)

        self.mid_res_3a = self._block(2 * filter_dim, downsample=False)
        self.mid_res_3b = self._block(2 * filter_dim, rescale=True)

        self.mid_energy_map = nn.Linear(filter_dim * 4, 1)
        self.avg_pool = Downsample(channels=3)  # not used by forward

    def init_small_model(self):
        """Create the layers of the branch that sees the input pooled twice."""
        filter_dim = self.args.filter_dim

        self.small_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)
        self.small_res_1a = self._block(filter_dim, downsample=False)
        self.small_res_1b = self._block(filter_dim, rescale=False)

        self.small_res_2a = self._block(filter_dim, downsample=False)
        self.small_res_2b = self._block(filter_dim, rescale=True)

        self.small_energy_map = nn.Linear(filter_dim * 2, 1)

    def main_model(self, x, latent, compute_feat=False):
        """Energy of the full-resolution input, or its pooled features.

        With ``compute_feat`` set, the spatially averaged features are
        returned instead of being mapped to an energy.
        """
        x = self.act(self.conv1(x))

        for block in (self.res_1a, self.res_1b, self.res_2a, self.res_2b):
            x = block(x, latent)

        if self.args.self_attn:
            x, _ = self.self_attn(x)

        for block in (self.res_3a, self.res_3b, self.res_4a, self.res_4b):
            x = block(x, latent)

        feats = self._global_pool(x)

        if compute_feat:
            return feats

        return self._to_energy(feats, self.energy_map)

    def mid_model(self, x, latent):
        """Energy of the input after one stride-2 average pooling."""
        x = F.avg_pool2d(x, 3, stride=2, padding=1)

        x = self.act(self.mid_conv1(x))

        for block in (self.mid_res_1a, self.mid_res_1b, self.mid_res_2a, self.mid_res_2b, self.mid_res_3a, self.mid_res_3b):
            x = block(x, latent)

        return self._to_energy(self._global_pool(x), self.mid_energy_map)

    def small_model(self, x, latent):
        """Energy of the input after two stride-2 average poolings."""
        x = F.avg_pool2d(x, 3, stride=2, padding=1)
        x = F.avg_pool2d(x, 3, stride=2, padding=1)

        x = self.act(self.small_conv1(x))

        for block in (self.small_res_1a, self.small_res_1b, self.small_res_2a, self.small_res_2b):
            x = block(x, latent)

        return self._to_energy(self._global_pool(x), self.small_energy_map)

    def forward(self, x, latent):
        """Return energies for ``x``; ``latent`` is only used when ``cond`` is set."""
        if self.cond:
            # NOTE(review): label_map is not defined in this class.
            latent = self.label_map(latent)
        else:
            latent = None

        energy = self.main_model(x, latent)

        if self.args.multiscale:
            large_energy = energy
            mid_energy = self.mid_model(x, latent)
            small_energy = self.small_model(x, latent)

            # Concatenate the per-scale energies: small, mid, large.
            energy = torch.cat([small_energy, mid_energy, large_energy], dim=-1)

        return energy

    def compute_feat(self, x, latent):
        """Return the pooled main-branch features of ``x``; ``latent`` is ignored."""
        return self.main_model(x, None, compute_feat=True)

## Downsample

In [None]:
class Downsample(nn.Module):
    """Pad, blur with a fixed binomial kernel, and subsample by ``stride``.

    The 2-D kernel is the outer product of a 1-D binomial row with itself,
    normalised to sum to one, and is applied per channel (grouped
    convolution). With ``filt_size == 1`` no blur is applied and the input is
    only (optionally padded and) strided.

    Args:
        pad_type: padding mode name passed to ``get_pad_layer``.
        filt_size: kernel width; sizes 1 to 7 are supported.
        stride: subsampling factor.
        channels: number of channels of the inputs this layer will receive.
        pad_off: extra padding added on every side.

    Raises:
        ValueError: if ``filt_size`` is not one of the supported sizes.
        TypeError: if ``channels`` is not given.
    """

    # 1-D binomial coefficients for each supported kernel width.
    _KERNELS = {
        1: [1.],
        2: [1., 1.],
        3: [1., 2., 1.],
        4: [1., 3., 3., 1.],
        5: [1., 4., 6., 4., 1.],
        6: [1., 5., 10., 10., 5., 1.],
        7: [1., 6., 15., 20., 15., 6., 1.],
    }

    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        # Fail early with a clear message; previously an unsupported size hit
        # an unbound local and a missing channel count failed inside repeat().
        if filt_size not in self._KERNELS:
            raise ValueError('filt_size must be one of %s, got %r' % (sorted(self._KERNELS), filt_size))
        if channels is None:
            raise TypeError('channels must be an integer, got None')

        self.filt_size = filt_size
        self.pad_off = pad_off
        # (left, right, top, bottom): floor / ceil of half the kernel overhang.
        self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
        self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride-1)/2.)
        self.channels = channels

        a = np.array(self._KERNELS[filt_size])

        filt = torch.Tensor(a[:,None]*a[None,:])
        filt = filt/torch.sum(filt)
        # One copy of the kernel per channel, stored as a buffer (not trained).
        self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        """Return the blurred and subsampled version of ``inp`` (B, C, H, W)."""
        if(self.filt_size==1):
            if(self.pad_off==0):
                return inp[:,:,::self.stride,::self.stride]
            else:
                return self.pad(inp)[:,:,::self.stride,::self.stride]
        else:
            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])

def get_pad_layer(pad_type):
    """Return the 2-D padding layer class for a padding mode name.

    Args:
        pad_type: 'refl'/'reflect', 'repl'/'replicate' or 'zero'.

    Returns:
        ``nn.ReflectionPad2d``, ``nn.ReplicationPad2d`` or ``nn.ZeroPad2d``
        (the class, not an instance).

    Raises:
        ValueError: if ``pad_type`` is not recognised. Previously a message
            was printed and the function then failed on an unbound local.
    """
    if pad_type in ('refl', 'reflect'):
        return nn.ReflectionPad2d
    if pad_type in ('repl', 'replicate'):
        return nn.ReplicationPad2d
    if pad_type == 'zero':
        return nn.ZeroPad2d
    raise ValueError('Pad type [%s] not recognized' % pad_type)

## Replay Buffer

In [None]:
class GaussianBlur(object):
    """Callable that Gaussian-blurs an image with probability 0.5.

    When the blur is applied, sigma is drawn uniformly from [min, max) and a
    square kernel of side ``kernel_size`` is used. The input is always
    converted to a NumPy array first, so an array is returned either way.
    """

    def __init__(self, min=0.1, max=2.0, kernel_size=9):
        self.min, self.max, self.kernel_size = min, max, kernel_size

    def __call__(self, sample):
        image = np.array(sample)

        # Coin flip first; sigma is only drawn when the blur is applied.
        if not np.random.random_sample() < 0.5:
            return image

        sigma = self.min + np.random.random_sample() * (self.max - self.min)
        ksize = (self.kernel_size, self.kernel_size)
        return cv2.GaussianBlur(image, ksize, sigma)

In [None]:
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of images with optional augmentation on sampling.

    Images are appended until ``size`` entries are stored; after that the
    oldest entries are overwritten. ``sample`` draws entries uniformly at
    random (with replacement) and, for non-MNIST datasets with a transform
    configured, passes each one through the augmentation pipeline.
    """

    def __init__(self, size, transform, dataset):
        """Create the buffer.

        Parameters
        ----------
        size: int
            Max number of images to store in the buffer. When the buffer
            overflows the old entries are overwritten.
        transform: bool
            Whether to build an augmentation pipeline for ``dataset``.
        dataset: str
            Dataset name; selects the crop size and the pipeline.

        Raises
        ------
        ValueError
            If ``dataset`` is not one of the known names.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self.gaussian_blur = GaussianBlur()

        def get_color_distortion(s=1.0):
            # s is the strength of color distortion.
            color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.4*s)
            rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
            rnd_gray = transforms.RandomGrayscale(p=0.2)
            return transforms.Compose([rnd_color_jitter, rnd_gray])

        color_transform = get_color_distortion()

        im_sizes = {
            "cifar10": 32, "celeba": 32, "cats": 32,
            "continual": 64,
            "celebahq": 128, "object": 128, "imagenet": 128, "lsun": 128,
            "mnist": 28, "moving_mnist": 28,
        }
        if dataset not in im_sizes:
            # An explicit error instead of `assert False`, which is stripped under -O.
            raise ValueError("Unknown dataset %r for ReplayBuffer" % (dataset,))
        im_size = im_sizes[dataset]

        self.dataset = dataset
        if not transform or dataset in ("mnist", "moving_mnist"):
            self.transform = None
        elif dataset == "continual":
            # Milder colour distortion, tighter crops and no horizontal flip.
            color_transform = get_color_distortion(0.1)
            self.transform = transforms.Compose([transforms.RandomResizedCrop(im_size, scale=(0.7, 1.0)), color_transform, transforms.ToTensor()])
        else:
            # imagenet/object allow much smaller crops than the other datasets.
            min_scale = 0.01 if dataset in ("imagenet", "object") else 0.08
            self.transform = transforms.Compose([transforms.RandomResizedCrop(im_size, scale=(min_scale, 1.0)), transforms.RandomHorizontalFlip(), color_transform, transforms.ToTensor()])

    def __len__(self):
        return len(self._storage)

    def add(self, ims):
        """Store a batch of images, overwriting the oldest entries when full.

        ``ims`` is iterated along its first axis. The buffer never holds more
        than ``_maxsize`` entries; previously a batch that crossed the
        capacity boundary during the initial fill was appended whole and the
        storage temporarily grew past the limit.
        """
        for im in list(ims):
            if self._next_idx < len(self._storage):
                self._storage[self._next_idx] = im
            else:
                self._storage.append(im)
            self._next_idx = (self._next_idx + 1) % self._maxsize

    def _encode_sample(self, idxes, no_transform=False, downsample=False):
        """Return the entries at ``idxes`` as one array, each multiplied by 255.

        For datasets other than "mnist", when a transform is configured and
        ``no_transform`` is False, each entry is transposed with axes
        (1, 2, 0), converted to a PIL image, augmented and converted back to
        an array first. ``downsample`` is accepted but unused.

        NOTE(review): the final ``* 255`` is applied on every path; confirm
        the dtype/range of the stored images makes that correct when no
        transform is applied.
        """
        augment = (self.dataset != "mnist") and (self.transform is not None) and (not no_transform)

        ims = []
        for i in idxes:
            im = self._storage[i]

            if augment:
                im = im.transpose((1, 2, 0))
                im = np.array(self.transform(Image.fromarray(np.array(im))))

            ims.append(im * 255)
        return np.array(ims)

    def sample(self, batch_size, no_transform=False, downsample=False):
        """Sample a batch of stored images uniformly at random, with replacement.

        Parameters
        ----------
        batch_size: int
            How many images to sample.
        no_transform: bool
            Skip the augmentation pipeline.
        downsample: bool
            Accepted for compatibility; currently unused.

        Returns
        -------
        ims: np.array
            The sampled (and possibly augmented) images.
        idxes: list of int
            Storage indices of the sampled images, usable with ``set_elms``.

        Raises
        ------
        ValueError
            If the buffer is empty (raised by ``random.randint``).
        """
        idxes = [random.randint(0, len(self._storage) - 1)
                 for _ in range(batch_size)]
        return self._encode_sample(idxes, no_transform=no_transform, downsample=downsample), idxes

    def set_elms(self, data, idxes):
        """Write ``data`` back into the buffer.

        While the buffer is not yet full the data is appended via ``add`` and
        ``idxes`` is ignored; once full, ``data[i]`` replaces the entry at
        ``idxes[i]``.
        """
        if len(self._storage) < self._maxsize:
            self.add(data)
        else:
            for i, ix in enumerate(idxes):
                self._storage[ix] = data[i]

In [None]:
class ReservoirBuffer(object):
    """Image buffer that is filled sequentially and then updated by reservoir sampling.

    ``add`` appends whole batches while the write cursor sits at or past the
    end of the storage; afterwards each incoming image replaces a random slot
    with a probability that shrinks as more images have been offered.
    ``sample`` draws stored images uniformly at random and optionally augments
    them with a dataset-specific torchvision pipeline.
    """

    # Side length passed to RandomResizedCrop for every supported dataset.
    _IM_SIZES = {
        "cifar10": 32,
        "celeba": 32,
        "cats": 32,
        "continual": 64,
        "celebahq": 128,
        "object": 128,
        "mnist": 28,
        "moving_mnist": 28,
        "imagenet": 128,
        "lsun": 128,
        "stl": 48,
    }

    # dataset -> (lower bound of the crop scale, colour-distortion strength,
    #             blur kernel size or None for no blur).
    # Datasets missing here (mnist, moving_mnist) are never augmented.
    _AUGMENTATIONS = {
        "cifar10": (0.08, 1.0, None),
        "celeba": (0.08, 1.0, None),
        "cats": (0.08, 1.0, None),
        "continual": (0.08, 0.5, 5),
        "celebahq": (0.08, 0.5, 5),
        "imagenet": (0.6, 0.5, 11),
        "lsun": (0.08, 0.5, 5),
        "stl": (0.04, 0.5, 11),
        "object": (0.08, 0.5, None),
    }

    def __init__(self, size, transform, dataset):
        """Create the buffer.

        Parameters
        ----------
        size: int
            Capacity used to wrap the internal write cursor.
        transform: bool
            Whether ``_encode_sample`` should augment the images it returns.
        dataset: str
            Dataset name; selects the crop size and augmentation pipeline.

        Raises
        ------
        ValueError
            If ``dataset`` is not one of the supported names.
        """
        if dataset not in self._IM_SIZES:
            raise ValueError("Unsupported dataset: {!r}".format(dataset))

        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self.n = 0  # total number of images offered to ``add`` so far
        self.dataset = dataset
        self.transform = self._build_transform(dataset) if transform else None

    @staticmethod
    def _color_distortion(s=1.0):
        """Random colour jitter followed by random grayscale; ``s`` scales the jitter."""
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.4 * s)
        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.2)
        return transforms.Compose([rnd_color_jitter, rnd_gray])

    @classmethod
    def _build_transform(cls, dataset):
        """Return the augmentation pipeline for ``dataset`` (``None`` if it has none)."""
        params = cls._AUGMENTATIONS.get(dataset)
        if params is None:
            return None

        min_scale, color_strength, blur_kernel = params
        steps = [
            transforms.RandomResizedCrop(cls._IM_SIZES[dataset], scale=(min_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
            cls._color_distortion(color_strength),
        ]
        if blur_kernel is not None:
            # GaussianBlur is expected to be defined elsewhere in this file.
            steps.append(GaussianBlur(kernel_size=blur_kernel))
        steps.append(transforms.ToTensor())
        return transforms.Compose(steps)

    def __len__(self):
        return len(self._storage)

    def add(self, ims):
        """Offer a batch of images (indexed along axis 0) to the buffer."""
        if self._next_idx >= len(self._storage):
            # Still filling: append the whole batch.
            self._storage.extend(list(ims))
            self.n = self.n + ims.shape[0]
        else:
            # Reservoir step: the n-th offered image lands on a uniformly
            # drawn position in [0, n) and is kept only if that position
            # is an existing slot.
            for im in ims:
                self.n = self.n + 1
                ix = random.randint(0, self.n - 1)

                if ix < len(self._storage):
                    self._storage[ix] = im

        self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize

    def _encode_sample(self, idxes, no_transform=False, downsample=False):
        """Return the stored images at ``idxes`` as one array, scaled by 255.

        Images are augmented with ``self.transform`` unless the dataset is
        ``"mnist"``, no transform is configured or ``no_transform`` is set.
        ``downsample`` is accepted but has no effect.
        """
        augment = (
            self.dataset != "mnist"
            and self.transform is not None
            and not no_transform
        )

        ims = []
        for i in idxes:
            im = self._storage[i]
            if augment:
                im = im.transpose((1, 2, 0))
                im = np.array(self.transform(Image.fromarray(im)))
            ims.append(im * 255)
        return np.array(ims)

    def sample(self, batch_size, no_transform=False, downsample=False):
        """Draw ``batch_size`` stored images uniformly at random, with replacement.

        Parameters
        ----------
        batch_size: int
            Number of images to draw.
        no_transform: bool
            Skip the augmentation pipeline when True.
        downsample: bool
            Forwarded to ``_encode_sample`` (currently unused there).

        Returns
        -------
        ims: np.array
            The encoded images.
        idxes: torch.Tensor
            Integer tensor of shape ``(batch_size,)`` with the sampled slots.
        """
        # ``high`` is exclusive in torch.randint, so it must be the storage
        # length itself for the last slot to be reachable.
        idxes = torch.randint(low=0, high=len(self._storage), size=(batch_size,))
        return self._encode_sample(idxes, no_transform=no_transform, downsample=downsample), idxes

## Utils

In [None]:
def swish(x):
    """Swish activation: the input scaled element-wise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
    
def adjust_learning_rate(epoch, opt, optimizer):
    """Apply step-wise learning-rate decay to ``optimizer`` in place.

    ``opt.learning_rate`` is multiplied by ``opt.lr_decay_rate`` once for
    every entry of ``opt.lr_decay_epochs`` that ``epoch`` has strictly
    passed. When no milestone has been passed the optimizer is left untouched.
    """
    milestones = np.asarray(opt.lr_decay_epochs)
    steps = np.sum(epoch > milestones)
    if steps <= 0:
        return

    decayed_lr = opt.learning_rate * (opt.lr_decay_rate ** steps)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each ``k`` in ``topk``.

    Parameters
    ----------
    output: torch.Tensor
        Scores; the top-k is taken along dimension 1.
    target: torch.Tensor
        Ground-truth class index per sample.
    topk: tuple of int
        The values of k to evaluate.

    Returns
    -------
    list of torch.Tensor
        One single-element tensor per requested k, in the order of ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` comes from a transposed tensor and may be
            # non-contiguous, so ``view(-1)`` can fail for k > 1; ``reshape``
            # copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def compute_jacobian_generic(y, x, create_graph=False):
    """Jacobian of ``y`` with respect to ``x``, one output column at a time.

    For every column ``i`` of ``y`` a one-hot ``grad_outputs`` mask selects
    that column for the whole batch and ``torch.autograd.grad`` is called
    once. The per-column gradients are stacked along a new dimension 1.
    ``y`` must be at least 2-D since it is indexed as ``[:, i]``.
    """
    n_columns = y.size(1)
    selector = torch.zeros_like(y)
    per_column = []

    for col in range(n_columns):
        # Switch on column ``col`` only, differentiate, switch it off again.
        selector[:, col] = 1
        grads = torch.autograd.grad(
            y, x, selector, create_graph=create_graph, retain_graph=True)
        per_column.append(grads[0])
        selector[:, col] = 0

    return torch.stack(per_column, dim=1)


def compute_jacobian(model, im_feat, latent, optimize_partition=False, create_graph=False):
    """Max-normalised Jacobian of ``model.feat_energy`` w.r.t. the tiled latent.

    Every row of ``im_feat`` and ``latent`` (both indexed as 2-D here) is
    repeated ``model.energy_dim`` times, so that one ``torch.autograd.grad``
    call with rows of an identity matrix as ``grad_outputs`` produces one
    gradient per energy output.

    Returns ``(jacs, scale_factor, energy)``: the Jacobian divided by its
    per-sample maximum absolute entry, that maximum, and the energy of the
    first copy of each sample.

    NOTE(review): presumably ``model.feat_energy`` returns one row of
    ``energy_dim`` values per input row, and ``latent`` must take part in
    autograd for the gradient call to succeed -- confirm against the model.
    """
    latent_dim = model.energy_dim

    im_shape = im_feat.size()
    latent_shape = latent.size()
    im_feat_raw = im_feat  # untiled features, only used for the partition estimate

    # Tile each sample ``latent_dim`` times and flatten back to 2-D.
    im_feat = im_feat[:, None, :].repeat(1, latent_dim, 1).view(-1, im_shape[1])
    latent = latent[:, None, :].repeat(1, latent_dim, 1).view(-1, latent_shape[1])
    # One identity matrix per sample: the j-th copy selects the j-th output.
    grad_y = torch.eye(latent_dim).to(im_feat.device)[None, :, :].repeat(im_shape[0], 1, 1)
    grad_y = grad_y.view(-1, latent_dim)
    energy = model.feat_energy(im_feat, latent)

    if optimize_partition:
        # Keep at most 32 randomly chosen samples and pair each with every
        # tiled latent; the log-sum-exp of the negated energies over those
        # samples is added to the energy.
        im_feat_raw = im_feat_raw[torch.randperm(im_feat_raw.size(0)).to(im_feat_raw.device)][:32]
        # im_feat_raw = im_feat_raw
        im_feat_partition = im_feat_raw[:, None, :].repeat(1, latent.size(0), 1)
        latent_neg_partition = latent[None, :, :].repeat(im_feat_raw.size(0), 1, 1)
        partition_est = model.feat_energy(im_feat_partition, latent_neg_partition)
        energy = energy + torch.logsumexp(-1 * partition_est, dim=0)

    jacs = torch.autograd.grad(energy, latent, grad_y, create_graph=create_graph)[0]
    s = jacs.size()
    # jacs = jacs.view(im_shape[0], -1)
    # Flatten per original sample to find its largest absolute entry.
    jacs_dense = jacs.view(im_shape[0], -1)
    scale_factor = torch.abs(jacs_dense).max(dim=-1, keepdim=True)[0]

    jacs = jacs_dense.view(im_shape[0], -1) / scale_factor
    jacs = jacs.view(im_shape[0], latent_dim, s[1])

    # All copies of a sample share the same inputs; keep the first copy's row.
    energy = energy.view(-1, latent_dim, latent_dim)
    energy = energy[:, 0, :]

    return jacs, scale_factor, energy

In [None]:
class AverageMeter(object):
    """Tracks the most recent value of a metric and its running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
class WSConv2d(nn.Conv2d):
    """2-D convolution whose kernel is standardised on every forward pass.

    For each output channel the kernel is shifted to zero mean and divided by
    its standard deviation (plus 1e-5) before the convolution is applied. The
    stored ``weight`` parameter itself is not modified.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        kernel = self.weight

        # Mean over input channels and both spatial axes, one axis at a time.
        center = kernel
        for axis in (1, 2, 3):
            center = center.mean(dim=axis, keepdim=True)
        kernel = kernel - center

        n_out = kernel.size(0)
        spread = kernel.view(n_out, -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        kernel = kernel / spread.expand_as(kernel)

        return F.conv2d(x, kernel, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

In [None]:
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a
    1d, 2d or 3d tensor. Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will
            have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    Raises:
        RuntimeError: If ``dim`` is not 1, 2 or 3.
    """
    def __init__(self, channels, kernel_size, sigma, dim=2):
        # Imported here so the class does not rely on a module-level
        # ``numbers`` import being present.
        import numbers

        super(GaussianSmoothing, self).__init__()
        # A scalar size / sigma applies to every dimension.
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                      torch.exp(-((mgrid - mean) / std) ** 2 / 2)

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        # A buffer (not a parameter): it follows the module across devices
        # but is not trained.
        self.register_buffer('weight', kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

    def forward(self, input):
        """
        Apply gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output. No padding is applied,
                so each filtered dimension shrinks by ``kernel_size - 1``.
        """
        return self.conv(input, weight=self.weight, groups=self.groups)

In [None]:
def cutout(mask_color=(0, 0, 0)):
    """Build a Cutout augmentation for channel-first images.

    The returned callable copies its input and, with probability
    ``flags_cutout_prob``, overwrites a square of side
    ``flags_cutout_mask_size`` with ``mask_color`` (one value per channel).
    When ``flags_cutout_inside`` is true the square always lies fully inside
    the image; otherwise it may be clipped at the border.

    The three ``flags_cutout_*`` names are read from module scope.
    """
    mask_size_half = flags_cutout_mask_size // 2
    offset = 1 if flags_cutout_mask_size % 2 == 0 else 0

    def _cutout(image):
        image = np.asarray(image).copy()

        if np.random.random() > flags_cutout_prob:
            return image

        # The mask below is written as image[:, rows, cols], i.e. the image
        # is channel-first, so height and width are the last two axes.
        h, w = image.shape[-2:]

        if flags_cutout_inside:
            cxmin, cxmax = mask_size_half, w + offset - mask_size_half
            cymin, cymax = mask_size_half, h + offset - mask_size_half
        else:
            cxmin, cxmax = 0, w + offset
            cymin, cymax = 0, h + offset

        # Pick the mask centre, then clip the square to the image bounds.
        cx = np.random.randint(cxmin, cxmax)
        cy = np.random.randint(cymin, cymax)
        xmin = cx - mask_size_half
        ymin = cy - mask_size_half
        xmax = xmin + flags_cutout_mask_size
        ymax = ymin + flags_cutout_mask_size
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(w, xmax)
        ymax = min(h, ymax)
        image[:, ymin:ymax, xmin:xmax] = np.array(mask_color)[:, None, None]
        return image

    return _cutout

In [None]:
def compress_x_mod(x_mod):
    """Clip to [0, 1], scale by 255 and cast to ``uint8``."""
    clipped = np.clip(x_mod, 0, 1)
    scaled = 255 * clipped
    return scaled.astype(np.uint8)


def decompress_x_mod(x_mod):
    """Divide by 256 and add uniform noise drawn from [0, 1/256) per element."""
    dequant_noise = np.random.uniform(0, 1 / 256, x_mod.shape)
    return x_mod / 256 + dequant_noise

In [None]:
def sync_model(models):
    """Broadcast every parameter of every model from rank 0 to all ranks.

    Parameters are overwritten in place. A ``torch.distributed`` process
    group must be initialised before any parameter is broadcast.
    """
    # Imported locally so the function does not depend on a module-level
    # ``dist`` alias being defined.
    import torch.distributed as dist

    for model in models:
        for param in model.parameters():
            dist.broadcast(param.data, 0)


def ema_model(models, models_ema, mu=0.99):
    """Move each EMA model's parameters toward its live counterpart, in place.

    Every parameter is updated as ``ema <- mu * ema + (1 - mu) * param``.
    Models and their parameters are paired positionally.
    """
    for live, shadow in zip(models, models_ema):
        pairs = zip(live.parameters(), shadow.parameters())
        for weight, weight_ema in pairs:
            blended = mu * weight_ema.data + (1 - mu) * weight.data
            weight_ema.data[:] = blended



In [None]:
def average_gradients(models):
    """Average the gradients of every model across all distributed ranks.

    Each existing ``param.grad`` is summed over the process group and divided
    by the world size, in place. Parameters without a gradient are skipped.
    Requires an initialised ``torch.distributed`` process group.
    """
    # Imported locally so the function does not depend on a module-level
    # ``dist`` alias being defined.
    import torch.distributed as dist

    size = float(dist.get_world_size())

    for model in models:
        for param in model.parameters():
            if param.grad is None:
                continue

            # ``ReduceOp`` replaces the deprecated ``reduce_op`` alias.
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size

## Dataset

In [None]:
class Mnist(Dataset):
    """MNIST wrapper yielding ``(noise image, dequantised digit, one-hot label)``.

    The underlying dataset is downloaded to ``data/mnist`` if necessary.
    ``rescale`` is accepted but not used.
    """

    def __init__(self, train=True, rescale=1.0):
        self.data = MNIST(
            "data/mnist",
            train=train,
            download=True,
            transform=transforms.ToTensor())
        # Row ``k`` of the identity matrix is the one-hot code of digit ``k``.
        self.labels = np.eye(10)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        side = 28
        digit, target = self.data[index]
        one_hot = self.labels[target]

        # Rescale by 255/256 and add uniform dequantisation noise, then clip.
        pixels = digit.squeeze().numpy() / 256 * 255
        pixels = pixels + np.random.uniform(0, 1. / 256, (side, side))
        pixels = np.clip(pixels, 0, 1)

        # Uniform-noise image with a trailing channel axis.
        noise = np.random.uniform(0, 1, (side, side, 1))
        pixels = pixels[:, :, None]

        return torch.Tensor(noise), torch.Tensor(pixels), one_hot

In [None]:
# Stage the dataset selected in ``flags.dataset`` from Google Drive onto the
# Colab VM's local disk under /content/data. Other dataset names need no staging here.
if flags.dataset == "celebahq":
    # Unpack the 128x128 CelebA-HQ archive straight from Drive.
    !mkdir -p /content/data/celebAHQ
    !unzip -qq '/content/drive/MyDrive/Colab Notebooks/improved_contrastive_divergence/data/celebAHQ/data128x128.zip' -d /content/data/celebAHQ
elif flags.dataset == "celeba":
    # Copy the CelebA folder from the Drive project directory.
    !mkdir -p /content/data
    %cd /content/drive/MyDrive/Colab Notebooks/improved_contrastive_divergence
    %cp -av data/celeba/ /content/data
elif flags.dataset == "cats":
    # Copy the cats folder from Drive, then unpack its archive in place.
    !mkdir -p /content/data
    %cd /content/drive/MyDrive/Colab Notebooks/improved_contrastive_divergence
    %cp -av data/cats/ /content/data
    !unzip -qq /content/data/cats/cats-dataset.zip -d /content/data/cats

In [None]:
 class CelebAHQ(Dataset):

    def __init__(self, cond_idx=1, filter_idx=0):
        self.path = "/content/data/celebAHQ/data128x128/{:05}.jpg"
        self.hq_labels = pd.read_csv(os.path.join(sample_dir, "data/celebAHQ/image_list.txt"), sep="\s+")
        self.labels = pd.read_csv(os.path.join(sample_dir, "data/celebAHQ/list_attr_celeba.txt"), sep="\s+", skiprows=1)
        self.cond_idx = cond_idx
        self.filter_idx = filter_idx

    def __len__(self):
        return self.hq_labels.shape[0] 

    def __getitem__(self, index):
        info = self.hq_labels.iloc[index]
        info = self.labels.iloc[info.orig_idx]

        path = self.path.format(index+1)
        im = np.array(Image.open(path))
        image_size = 128
        # im = imresize(im, (image_size, image_size))
        im = im / 256
        im = im + np.random.uniform(0, 1 / 256., im.shape)

        label = int(info.iloc[self.cond_idx])
        if label == -1:
            label = 0
        label = np.eye(2)[label]

        im_corrupt = np.random.uniform(
            0, 1, size=(image_size, image_size, 3))

        return im_corrupt, im, label

In [None]:
class CelebADataset(Dataset):
    """CelebA at 32x32 yielding ``(noise image, dequantised image, label)``.

    With ``augment`` the images are randomly cropped (padding 4) and flipped;
    otherwise they are resized to 32 and centre-cropped. The dataset is
    downloaded to ``/content/data`` if necessary. ``noise`` and ``FLAGS`` are
    stored but not consulted when building an item.
    """

    def __init__(
            self,
            FLAGS,
            split='train',
            augment=False,
            noise=True,
            rescale=1.0):

        if augment:
            steps = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]
        else:
            steps = [
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor(),
            ]
        pipeline = transforms.Compose(steps)

        self.data = torchvision.datasets.CelebA(
            "/content/data",
            split=split,
            download=True,
            transform=pipeline)
        self.one_hot_map = np.eye(10)
        self.noise = noise
        self.rescale = rescale
        self.FLAGS = FLAGS

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        side = 32
        tensor_im, raw_label = self.data[index]

        # Channel-first tensor -> channel-last numpy array.
        pixels = np.transpose(tensor_im, (1, 2, 0)).numpy()
        label = self.one_hot_map[raw_label]

        # Rescale by 255/256, apply ``rescale`` and add dequantisation noise.
        pixels = pixels * 255 / 256
        dequantised = pixels * self.rescale + \
            np.random.uniform(0, 1 / 256., pixels.shape)

        # Uniform-noise image used as the sampler's starting point.
        corrupt = np.random.uniform(0.0, self.rescale, (side, side, 3))

        return torch.Tensor(corrupt), torch.Tensor(dequantised), label

In [None]:
class Cats(Dataset):
    """Cat images from ``/content/data/cats`` at 32x32.

    Yields ``(noise image, dequantised image, label)``. With ``augment`` the
    images are randomly cropped (padding 4) and flipped; otherwise they are
    resized to 32 and centre-cropped. ``noise`` is stored but not consulted
    when building an item.
    """

    def __init__(
            self,
            augment=False,
            noise=True,
            rescale=1.0):

        if augment:
            steps = [
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ]
        else:
            steps = [
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor(),
            ]
        pipeline = transforms.Compose(steps)

        self.data = torchvision.datasets.ImageFolder('/content/data/cats', transform=pipeline)
        self.one_hot_map = np.eye(10)
        self.noise = noise
        self.rescale = rescale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        side = 32
        tensor_im, raw_label = self.data[index]

        # Channel-first tensor -> channel-last numpy array.
        pixels = np.transpose(tensor_im, (1, 2, 0)).numpy()
        label = self.one_hot_map[raw_label]

        # Rescale by 255/256, apply ``rescale`` and add dequantisation noise.
        pixels = pixels * 255 / 256
        dequantised = pixels * self.rescale + \
            np.random.uniform(0, 1 / 256., pixels.shape)

        # Uniform-noise image used as the sampler's starting point.
        corrupt = np.random.uniform(0.0, self.rescale, (side, side, 3))

        return torch.Tensor(corrupt), torch.Tensor(dequantised), label

In [None]:
class Cifar10(Dataset):
    """CIFAR-10 yielding ``(noise image, dequantised image, one-hot label)``.

    With ``full`` the test split is appended after the split selected by
    ``train``. With ``augment`` the images are randomly cropped (padding 4)
    and flipped. Both splits are downloaded to ``./data/cifar10`` if
    necessary. ``noise`` and ``FLAGS`` are stored but not consulted when
    building an item.
    """

    def __init__(
            self,
            FLAGS,
            train=True,
            full=False,
            augment=False,
            noise=True,
            rescale=1.0):

        if augment:
            pipeline = transforms.Compose([
                torchvision.transforms.RandomCrop(32, padding=4),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
            ])
        else:
            pipeline = transforms.ToTensor()

        self.full = full
        self.data = torchvision.datasets.CIFAR10(
            "./data/cifar10",
            transform=pipeline,
            train=train,
            download=True)
        self.test_data = torchvision.datasets.CIFAR10(
            "./data/cifar10",
            transform=pipeline,
            train=False,
            download=True)
        self.one_hot_map = np.eye(10)
        self.noise = noise
        self.rescale = rescale
        self.FLAGS = FLAGS

    def __len__(self):
        n_items = len(self.data)
        if self.full:
            n_items = n_items + len(self.test_data)
        return n_items

    def __getitem__(self, index):
        side = 32

        # In ``full`` mode, indices past the primary split address the test split.
        if self.full and index >= len(self.data):
            tensor_im, raw_label = self.test_data[index - len(self.data)]
        else:
            tensor_im, raw_label = self.data[index]

        # Channel-first tensor -> channel-last numpy array.
        pixels = np.transpose(tensor_im, (1, 2, 0)).numpy()
        label = self.one_hot_map[raw_label]

        # Rescale by 255/256, apply ``rescale`` and add dequantisation noise.
        pixels = pixels * 255 / 256
        dequantised = pixels * self.rescale + \
            np.random.uniform(0, 1 / 256., pixels.shape)

        # Uniform-noise image used as the sampler's starting point.
        corrupt = np.random.uniform(0.0, self.rescale, (side, side, 3))

        return torch.Tensor(corrupt), torch.Tensor(dequantised), label

## Sampling ##

In [None]:
def rescale_im(image):
    """Clip to [0, 1], scale by 256, clip to [0, 255] and cast to ``uint8``."""
    unit = np.clip(image, 0, 1)
    scaled = np.clip(unit * 256, 0, 255)
    return scaled.astype(np.uint8)


def hamiltonian(x, v, model, label):
    """Per-sample Hamiltonian: kinetic term of ``v`` plus the model energy of ``x``.

    The kinetic term is half the squared velocity summed over dimensions 1-3,
    so ``v`` must be 4-D.
    """
    kinetic = torch.pow(v, 2)
    for _ in range(3):
        # Collapse one trailing axis at a time, leaving only the batch axis.
        kinetic = kinetic.sum(dim=1)
    potential = model.forward(x, label).squeeze()
    return 0.5 * kinetic + potential

In [None]:
def leapfrog_step(x, v, model, step_size, num_steps, label, sample=False):
    """Run a leapfrog-style trajectory on the model energy.

    The velocity first takes a half step along the negative energy gradient.
    Each of the ``num_steps`` iterations then updates the velocity with a full
    gradient step and moves ``x`` along it. Only the last iteration keeps the
    autograd graph (``create_graph=True`` and ``x`` not detached). Every 10th
    iteration the mean Hamiltonian and mean absolute gradient are printed.

    Returns ``(x, v, im_grad)``, or ``(x, im_negs, v, im_grad)`` with the list
    of intermediate positions when ``sample`` is true.
    """
    x.requires_grad_(requires_grad=True)
    initial_energy = model.forward(x, label)
    im_grad = torch.autograd.grad([initial_energy.sum()], [x])[0]
    v = v - 0.5 * step_size * im_grad
    trajectory = []
    final_step = num_steps - 1

    for step in range(num_steps):
        is_final = step == final_step

        x.requires_grad_(requires_grad=True)
        energy = model.forward(x, label)
        im_grad = torch.autograd.grad([energy.sum()], [x], create_graph=is_final)[0]

        v = v - step_size * im_grad
        x = x + step_size * v
        if not is_final:
            x = x.detach()
        v = v.detach()

        if sample:
            trajectory.append(x)

        if step % 10 == 0:
            print(step, hamiltonian(torch.sigmoid(x), v, model, label).mean(), torch.abs(im_grad).mean())

    if sample:
        return x, trajectory, v, im_grad
    return x, v, im_grad

In [None]:
def gen_hmc_image(label, FLAGS, model, im_neg, num_steps, sample=False):
    """Refine ``im_neg`` with one leapfrog trajectory of step size ``FLAGS.step_lr``.

    The initial velocity is Gaussian noise scaled by 0.001. Returns
    ``(im_neg, im_grad, v)``, or ``(im_neg, im_negs, im_grad, v)`` with the
    intermediate positions when ``sample`` is true.
    """
    step_size = FLAGS.step_lr
    velocity = 0.001 * torch.randn_like(im_neg)

    outputs = leapfrog_step(im_neg, velocity, model, step_size, num_steps, label, sample=sample)

    if sample:
        im_neg, im_negs, velocity, im_grad = outputs
        return im_neg, im_negs, im_grad, velocity

    im_neg, velocity, im_grad = outputs
    return im_neg, im_grad, velocity

In [None]:
def gen_image(label, FLAGS, model, im_neg, num_steps, sample=False):
    """Refine ``im_neg`` with ``num_steps`` noisy gradient steps on the model energy.

    Every step adds Gaussian noise of scale 0.001 (linearly annealed to zero
    when ``FLAGS.anneal``) and takes a gradient step of size ``FLAGS.step_lr``
    on ``model.forward(im_neg, label)``. The result is then detached and
    clamped to [0, 1]. On the last step a second, differentiable update of the
    first ``n`` pre-update samples is kept as ``im_neg_kl``.

    Parameters
    ----------
    label: forwarded unchanged to ``model.forward``.
    FLAGS: object providing ``anneal``, ``all_step``, ``step_lr`` and ``dataset``.
    model: energy model exposing ``forward(im, label)``.
    im_neg: torch.Tensor with the starting images.
    num_steps: int, number of sampling steps (at least 1).
    sample: bool, also return every intermediate (unclamped) sample.

    Returns
    -------
    ``(im_neg, im_neg_kl, im_grad)``, or
    ``(im_neg, im_neg_kl, im_negs_samples, im_grad)`` when ``sample`` is true.

    Raises
    ------
    ValueError
        If ``num_steps`` is smaller than 1 or ``FLAGS.dataset`` is unsupported.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1, got {}".format(num_steps))

    # Number of leading samples that receive the differentiable KL update.
    if FLAGS.dataset in ("cifar10", "celeba", "cats", "celebahq"):
        n = 128
    elif FLAGS.dataset in ("lsun", "object", "mnist", "imagenet", "stl"):
        # Save space
        n = 32
    else:
        raise ValueError("Unsupported dataset: {!r}".format(FLAGS.dataset))

    im_noise = torch.randn_like(im_neg).detach()

    im_negs_samples = []

    for i in range(num_steps):
        im_noise.normal_()

        if FLAGS.anneal:
            im_neg = im_neg + 0.001 * (num_steps - i - 1) / num_steps * im_noise
        else:
            im_neg = im_neg + 0.001 * im_noise

        im_neg.requires_grad_(requires_grad=True)
        energy = model.forward(im_neg, label)

        if FLAGS.all_step:
            im_grad = torch.autograd.grad([energy.sum()], [im_neg], create_graph=True)[0]
        else:
            im_grad = torch.autograd.grad([energy.sum()], [im_neg])[0]

        if i == num_steps - 1:
            im_neg_orig = im_neg
            im_neg = im_neg - FLAGS.step_lr * im_grad

            im_neg_kl = im_neg_orig[:n]
            if not sample:
                # Recompute the gradient with a graph so the KL sample stays
                # differentiable with respect to the model parameters.
                energy = model.forward(im_neg_kl, label)
                im_grad = torch.autograd.grad([energy.sum()], [im_neg_kl], create_graph=True)[0]

            im_neg_kl = im_neg_kl - FLAGS.step_lr * im_grad[:n]
            im_neg_kl = torch.clamp(im_neg_kl, 0, 1)
        else:
            im_neg = im_neg - FLAGS.step_lr * im_grad

        im_neg = im_neg.detach()

        if sample:
            im_negs_samples.append(im_neg)

        im_neg = torch.clamp(im_neg, 0, 1)

    if sample:
        return im_neg, im_neg_kl, im_negs_samples, im_grad
    else:
        return im_neg, im_neg_kl, im_grad

In [None]:
def gen_image_asgld(label, FLAGS, model, im_neg, num_steps, sample=False):
    """Refine ``im_neg`` with ``num_steps`` of adaptively preconditioned noisy descent.

    Each step draws noise from a normal distribution parameterised by the
    running mean and standard deviation of the previous iterates, scales it by
    ``FLAGS.init_noise`` (linearly annealed to zero when ``FLAGS.anneal``) and
    adds it to the sample. It then takes a gradient step of size
    ``FLAGS.step_lr`` on ``model.forward(im_neg, label)``, updates the running
    statistics, detaches and clamps to [0, 1]. On the last step the first
    ``n`` pre-update samples receive a separate update and are returned,
    detached, as ``im_neg_kl``.

    Parameters
    ----------
    label: forwarded unchanged to ``model.forward``.
    FLAGS: object providing ``anneal``, ``init_noise``, ``all_step``,
        ``step_lr``, ``dataset``, ``momentum`` and ``eps``.
    model: energy model exposing ``forward(im, label)``.
    im_neg: torch.Tensor with the starting images.
    num_steps: int, number of sampling steps (at least 1).
    sample: bool, also return every intermediate (unclamped) sample.

    Returns
    -------
    ``(im_neg, im_neg_kl, grad_magnitude)``, or
    ``(im_neg, im_neg_kl, im_negs_samples, grad_magnitude)`` when ``sample``
    is true; ``grad_magnitude`` is the mean absolute value of the last
    gradient.

    Raises
    ------
    ValueError
        If ``num_steps`` is smaller than 1 or ``FLAGS.dataset`` is unsupported.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1, got {}".format(num_steps))

    # Number of leading samples that receive the KL update.
    if FLAGS.dataset in ("cifar10", "celeba", "cats", "celebahq"):
        n = 128
    elif FLAGS.dataset in ("lsun", "object", "mnist", "imagenet", "stl"):
        # Save space
        n = 32
    else:
        raise ValueError("Unsupported dataset: {!r}".format(FLAGS.dataset))

    im_negs_samples = []

    # Running statistics of the iterates; both start at zero, so the first
    # step injects no noise.
    mean = torch.zeros_like(im_neg.data)
    std = torch.zeros_like(im_neg.data)

    for i in range(num_steps):
        # Statistics as they were after the previous step.
        old_mean = mean.clone()
        old_std = std.clone()

        im_noise = torch.normal(mean=old_mean, std=old_std)

        if FLAGS.anneal:
            im_neg = im_neg + FLAGS.init_noise * (num_steps - i - 1) / num_steps * im_noise
        else:
            # Out-of-place on purpose. An ``im_neg.data.add(...)`` call here
            # would discard its result and inject no noise at all.
            im_neg = im_neg + FLAGS.init_noise * im_noise

        im_neg.requires_grad_(requires_grad=True)
        energy = model.forward(im_neg, label)

        if FLAGS.all_step:
            im_grad = torch.autograd.grad([energy.sum()], [im_neg], create_graph=True)[0]
        else:
            im_grad = torch.autograd.grad([energy.sum()], [im_neg])[0]

        if i == num_steps - 1:
            im_neg_orig = im_neg
            im_neg = im_neg - FLAGS.step_lr * im_grad

            im_neg_kl = im_neg_orig[:n]
            if not sample:
                energy = model.forward(im_neg_kl, label)
                im_grad = torch.autograd.grad([energy.sum()], [im_neg_kl], create_graph=True)[0]

            im_neg_kl = im_neg_kl - FLAGS.step_lr * im_grad[:n]
            im_neg_kl = torch.clamp(im_neg_kl, 0, 1)
            im_neg_kl = im_neg_kl.detach()
        else:
            im_neg = im_neg - FLAGS.step_lr * im_grad

        # Updating mean
        mean = mean.mul(FLAGS.momentum).add(im_neg.data)

        # Updating std
        part_var1 = im_neg.data.add(-old_mean)
        part_var2 = im_neg.data.add(-mean)

        new_var = torch.pow(old_std, 2).mul(FLAGS.momentum).addcmul(part_var1, part_var2).add(FLAGS.eps)
        # The product of the two deviations can be negative, hence the abs.
        std = torch.pow(torch.abs(new_var), 1 / 2)

        im_neg = im_neg.detach()

        if sample:
            im_negs_samples.append(im_neg)

        im_neg = torch.clamp(im_neg, 0, 1)

    grad_magnitude = np.abs(im_grad.detach().cpu().numpy()).mean()
    if sample:
        return im_neg, im_neg_kl, im_negs_samples, grad_magnitude
    else:
        return im_neg, im_neg_kl, grad_magnitude

## Training

In [None]:
def test(model, logger, dataloader):
    """Placeholder for model evaluation; currently does nothing and returns ``None``."""
    pass

In [None]:
def log_tensorboard(data):
    """Write one training step's metrics and image grids to the module-level ``writer``.

    ``data`` maps metric names to values; ``data["iter"]`` is used as the
    global step for every entry.
    """
    step = data["iter"]

    # (TensorBoard tag, key in ``data``), in logging order.
    scalar_tags = (
        ("repel loss", "loss_repel"),
        ("KL mean", "kl_mean"),
        ("positive energy mean", "e_pos"),
        ("positive energy std", "e_pos_std"),
        ("negative energy mean", "e_neg"),
        ("negative energy std", "e_neg_std"),
        ("energy different", "e_diff"),
        ("x gradient", "x_grad"),
    )
    image_tags = (
        ("positive examples", "positive_samples"),
        ("negative examples", "negative_samples"),
    )

    for tag, key in scalar_tags:
        writer.add_scalar(tag, data[key], step)
    for tag, key in image_tags:
        writer.add_images(tag, data[key], step)


In [None]:
def _sample_eval_noise(FLAGS, data, label):
    """Draw uniform-noise start images and matching labels for a save step.

    Args:
        FLAGS: configuration; ``dataset`` and ``batch_size`` are read.
        data: current batch of real images; only ``data.shape[0]`` is used
            (for the "celebahq" dataset).
        label: labels of the current batch.

    Returns:
        ``(data_corrupt, label)`` where ``data_corrupt`` is a float tensor of
        uniform noise laid out as (N, H, W, C) and ``label`` is trimmed/tiled
        to go with it.

    Raises:
        AssertionError: if ``FLAGS.dataset`` is not one of the known names.
    """
    name = FLAGS.dataset

    if name in ("cifar10", "celeba", "cats"):
        data_corrupt = torch.Tensor(np.random.uniform(0.0, 1.0, (128, 32, 32, 3)))
        # Tile the batch labels so that at least 128 are available, then trim.
        repeat = 128 // FLAGS.batch_size + 1
        label = torch.cat([label] * repeat, dim=0)[:128]
        return data_corrupt, label

    fixed_shapes = {
        "stl": (32, 48, 48, 3),
        "lsun": (32, 128, 128, 3),
        "imagenet": (32, 128, 128, 3),
        "object": (32, 128, 128, 3),
        "mnist": (32, 28, 28, 1),
    }
    if name == "celebahq":
        shape = (data.shape[0], 128, 128, 3)
    elif name in fixed_shapes:
        shape = fixed_shapes[name]
    else:
        raise AssertionError("unsupported dataset: {!r}".format(name))

    data_corrupt = torch.Tensor(np.random.uniform(0.0, 1.0, shape))
    label = label[:shape[0]]
    # The batch may hold fewer labels than requested noise images.
    return data_corrupt[:label.shape[0]], label


def _mean_grad_norm(model):
    """Return the mean L2 norm of the gradients currently stored on ``model``."""
    norms = [
        torch.norm(param.grad.data)
        for param in model.parameters()
        if param.grad is not None
    ]
    return torch.mean(torch.stack(norms, dim=0))


def train(model, optimizer, dataloader, logdir, resume_iter, FLAGS, best_inception):
    """Run the contrastive-divergence training loop.

    For every batch: draw negative samples with the configured sampler
    (``gen_hmc_image``, ``gen_image_asgld`` or ``gen_image``), build the loss
    (energy difference + squared-energy regulariser, optional KL / repel /
    HMC terms), take one optimizer step with gradient-norm clipping at 0.5,
    log every ``FLAGS.log_interval`` iterations and checkpoint every
    ``FLAGS.save_interval`` iterations.

    Args:
        model: energy model, called as ``model.forward(images, label)``.
        optimizer: optimizer over ``model.parameters()``.
        dataloader: yields ``(data_corrupt, data, label)`` batches; both image
            tensors are permuted with ``(0, 3, 1, 2)``, i.e. they are assumed
            to arrive channel-last.
        logdir: directory that receives ``model_<itr>.pth`` checkpoints.
        resume_iter: iteration counter to start from.
        FLAGS: configuration object (attribute access).
        best_inception: stored verbatim in each checkpoint; not updated here.

    Returns:
        None.

    Raises:
        AssertionError: if the mean positive energy becomes NaN or exceeds 10
            in magnitude, or if ``FLAGS.dataset`` is unknown on a save step.
    """
    if FLAGS.cuda:
        model = model.to(FLAGS.gpu)

    replay_buffer = None
    if FLAGS.replay_batch:
        buffer_cls = ReservoirBuffer if FLAGS.reservoir else ReplayBuffer
        replay_buffer = buffer_cls(FLAGS.buffer_size, FLAGS.transform, FLAGS.dataset)

    def buffer_has_history():
        # The repel and gradient-logging terms need more than 1000 stored
        # samples; without a replay buffer they are simply skipped.
        return replay_buffer is not None and len(replay_buffer) > 1000

    itr = resume_iter
    im_neg = None
    num_steps = FLAGS.num_steps

    optimizer.zero_grad()

    for epoch in range(FLAGS.epoch_num):
        print("epoch : ", epoch)
        for data_corrupt, data, label in dataloader:
            label = label.float()
            data = data.permute(0, 3, 1, 2).float().contiguous()

            # Test the interval for zero *before* taking the modulo so that
            # save_interval == 0 disables saving instead of raising
            # ZeroDivisionError.
            is_save_step = FLAGS.save_interval != 0 and itr % FLAGS.save_interval == 0

            # On save steps the chain starts from fresh noise rather than the
            # dataloader's corrupted batch.
            if is_save_step:
                data_corrupt, label = _sample_eval_noise(FLAGS, data, label)

            data_corrupt = data_corrupt.float().permute(0, 3, 1, 2).contiguous()

            if replay_buffer is not None and len(replay_buffer) >= FLAGS.batch_size:
                replay_batch, _ = replay_buffer.sample(data_corrupt.size(0))
                replay_batch = decompress_x_mod(replay_batch)
                # Each start image is replaced by a replayed sample with
                # probability 0.95.
                replay_mask = np.random.uniform(0, 1, data_corrupt.size(0)) > 0.05
                data_corrupt[replay_mask] = torch.Tensor(replay_batch[replay_mask])

            data = data.to(FLAGS.gpu, non_blocking=True)
            data_corrupt = data_corrupt.to(FLAGS.gpu, non_blocking=True)
            label = label.to(FLAGS.gpu, non_blocking=True)

            if FLAGS.hmc:
                # NOTE(review): this branch never assigns im_neg_kl, so
                # FLAGS.hmc combined with FLAGS.kl fails at the first use of
                # im_neg_kl below.
                if is_save_step:
                    im_neg, _, x_grad, v = gen_hmc_image(
                        label, FLAGS, model, data_corrupt, num_steps, sample=True)
                else:
                    im_neg, x_grad, v = gen_hmc_image(
                        label, FLAGS, model, data_corrupt, num_steps)
            else:
                sampler = gen_image_asgld if FLAGS.asgld else gen_image
                if is_save_step:
                    im_neg, im_neg_kl, _, x_grad = sampler(
                        label, FLAGS, model, data_corrupt, num_steps, sample=True)
                else:
                    im_neg, im_neg_kl, x_grad = sampler(
                        label, FLAGS, model, data_corrupt, num_steps)

            energy_neg = model.forward(im_neg, label)
            energy_pos = model.forward(data, label[:data.size(0)])

            if replay_buffer is not None and im_neg is not None:
                replay_buffer.add(compress_x_mod(im_neg.detach().cpu().numpy()))

            # Contrastive-divergence term plus an L2 penalty on both energies.
            loss = energy_pos.mean() - energy_neg.mean()
            loss = loss + (torch.pow(energy_pos, 2).mean() + torch.pow(energy_neg, 2).mean())

            # Placeholders live on the data's device so they can be combined
            # with device tensors in the gradient-logging branch.
            loss_kl = torch.zeros(1, device=data.device)
            loss_repel = torch.zeros(1, device=data.device)

            if FLAGS.kl:
                # Parameters are frozen for this pass, so the KL term only
                # back-propagates through im_neg_kl.
                model.requires_grad_(False)
                loss_kl = model.forward(im_neg_kl, label)
                model.requires_grad_(True)
                loss = loss + FLAGS.kl_coeff * loss_kl.mean()

                if FLAGS.repel_im and buffer_has_history():
                    bs = im_neg_kl.size(0)
                    im_flat = torch.clamp(im_neg_kl.reshape(bs, -1), 0, 1)

                    sample_kwargs = {"no_transform": False}
                    if FLAGS.dataset not in ("cifar10", "celeba", "cats"):
                        # NOTE(review): the comparison batch is requested
                        # downsampled while im_neg_kl is used at full
                        # resolution -- confirm the flattened sizes agree.
                        sample_kwargs["downsample"] = True

                    compare_batch, _ = replay_buffer.sample(100, **sample_kwargs)
                    compare_batch = decompress_x_mod(compare_batch)
                    compare_batch = torch.Tensor(compare_batch).to(FLAGS.gpu, non_blocking=True)
                    compare_flat = compare_batch.view(100, -1)

                    # Reward a large distance to the nearest replayed sample.
                    dist_matrix = torch.norm(
                        im_flat[:, None, :] - compare_flat[None, :, :], p=2, dim=-1)
                    loss_repel = torch.log(dist_matrix.min(dim=1)[0]).mean()
                    loss = loss - 0.3 * loss_repel

            if FLAGS.hmc:
                v_flat = v.view(v.size(0), -1)
                im_grad_flat = x_grad.view(x_grad.size(0), -1)
                dot_product = F.normalize(v_flat, dim=1) * F.normalize(im_grad_flat, dim=1)
                hmc_loss = torch.abs(dot_product.sum(dim=1)).mean()
                loss = loss + 0.01 * hmc_loss

            ml_grad = None
            kl_grad = None
            if FLAGS.log_grad and buffer_has_history():
                # Measure the gradient magnitude contributed by the ML and KL
                # objectives separately; the gradients are discarded afterwards.
                loss_kl = loss_kl - 0.1 * loss_repel
                loss_kl = loss_kl.mean()
                loss_ml = energy_pos.mean() - energy_neg.mean()

                loss_ml.backward(retain_graph=True)
                ml_grad = _mean_grad_norm(model)
                model.zero_grad()

                loss_kl.backward(retain_graph=True)
                kl_grad = _mean_grad_norm(model)
                model.zero_grad()

            loss.backward()
            clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            optimizer.zero_grad()

            # Abort as soon as the positive energy diverges.
            mean_e_pos = energy_pos.mean()
            if torch.isnan(mean_e_pos):
                raise AssertionError(
                    "mean positive energy is NaN at iteration {}".format(itr))
            if torch.abs(mean_e_pos) > 10.0:
                raise AssertionError(
                    "mean positive energy {} exceeds 10 in magnitude at iteration {}".format(
                        mean_e_pos.item(), itr))

            if itr % FLAGS.log_interval == 0:
                kvs = {}
                kvs['e_pos'] = energy_pos.mean().item()
                kvs['e_pos_std'] = energy_pos.std().item()
                kvs['e_neg'] = energy_neg.mean().item()
                kvs['kl_mean'] = loss_kl.mean().item()
                kvs['loss_repel'] = loss_repel.mean().item()
                kvs['e_neg_std'] = energy_neg.std().item()
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                kvs['x_grad'] = x_grad
                kvs['iter'] = itr
                kvs['num_steps'] = num_steps
                kvs['positive_samples'] = data.detach()
                kvs['negative_samples'] = im_neg.detach()

                if FLAGS.replay_batch:
                    kvs['length'] = len(replay_buffer)

                if ml_grad is not None:
                    kvs['kl_grad'] = kl_grad
                    kvs['ml_grad'] = ml_grad

                log_tensorboard(kvs)

            if is_save_step:
                model_path = osp.join(logdir, "model_{}.pth".format(itr))
                ckpt = {'optimizer_state_dict': optimizer.state_dict(),
                        'FLAGS': FLAGS, 'best_inception': best_inception}

                # The same single model is stored under every ensemble slot.
                for i in range(FLAGS.ensembles):
                    ckpt['model_state_dict_{}'.format(i)] = model.state_dict()

                torch.save(ckpt, model_path)

            # NOTE: inception-score evaluation on save steps (and the
            # "model_best.pth" checkpoint that depended on it) is disabled;
            # best_inception is only passed through to the checkpoint.

            itr += 1

In [None]:
def main_single(FLAGS):
    """Build the dataset, model and optimizer for ``FLAGS.dataset`` and train.

    When ``FLAGS.resume_iter`` is non-zero, the checkpoint
    ``<logdir>/model_<resume_iter>.pth`` is loaded once: its stored ``FLAGS``
    replace the ones passed in (keeping the requested ``resume_iter``), and
    its optimizer and model state are restored before training starts.

    Args:
        FLAGS: configuration object (attribute access).

    Returns:
        None.

    Raises:
        AssertionError: if ``FLAGS.dataset`` is not a supported name.
    """
    print("Values of args: ", FLAGS)

    # Kept as an if/elif chain so that only the selected dataset class has to
    # be defined in the notebook.
    if FLAGS.dataset == "cifar10":
        train_dataset = Cifar10(FLAGS)
    elif FLAGS.dataset == "celeba":
        train_dataset = CelebADataset(FLAGS)
    elif FLAGS.dataset == "cats":
        train_dataset = Cats()
    elif FLAGS.dataset == "stl":
        train_dataset = STLDataset(FLAGS)
    elif FLAGS.dataset == "object":
        train_dataset = ObjectDataset(FLAGS.cond_idx)
    elif FLAGS.dataset == "imagenet":
        train_dataset = ImageNet()
    elif FLAGS.dataset == "mnist":
        train_dataset = Mnist(train=True)
    elif FLAGS.dataset == "celebahq":
        train_dataset = CelebAHQ(cond_idx=FLAGS.cond_idx)
    elif FLAGS.dataset == "lsun":
        train_dataset = LSUNBed(cond_idx=FLAGS.cond_idx)
    else:
        raise AssertionError("unsupported dataset: {!r}".format(FLAGS.dataset))

    train_dataloader = DataLoader(
        train_dataset,
        num_workers=FLAGS.data_workers,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        drop_last=True,
    )

    logdir = osp.join(sample_dir, FLAGS.exp, FLAGS.dataset)

    best_inception = 0.0

    checkpoint = None
    if FLAGS.resume_iter != 0:
        # Positional "{}" field: a named "{iter}" field combined with a
        # positional argument raises KeyError.
        model_path = osp.join(logdir, "model_{}.pth".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)
        best_inception = checkpoint['best_inception']

        # Continue with the checkpoint's configuration, but keep the
        # iteration that was requested for this run.
        requested_iter = FLAGS.resume_iter
        FLAGS = checkpoint['FLAGS']
        FLAGS.resume_iter = requested_iter

    # The model class is chosen from the (possibly restored) FLAGS.
    if FLAGS.dataset in ("cifar10", "celeba", "cats", "stl"):
        model_fn = ResNetModel
    elif FLAGS.dataset in ("object", "celebahq", "lsun"):
        model_fn = CelebAModel
    elif FLAGS.dataset == "mnist":
        model_fn = MNISTModel
    elif FLAGS.dataset == "imagenet":
        model_fn = ImagenetModel
    else:
        raise AssertionError("unsupported dataset: {!r}".format(FLAGS.dataset))

    model = model_fn(FLAGS).train()

    optimizer = Adam(model.parameters(), lr=FLAGS.lr, betas=(0.0, 0.9), eps=1e-8)

    os.makedirs(logdir, exist_ok=True)

    if checkpoint is not None:
        print("FLAGS.resume_iter:", FLAGS.resume_iter)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Every ensemble slot is loaded into the same model, so the last slot
        # is the one that remains.
        for i in range(FLAGS.ensembles):
            model.load_state_dict(checkpoint['model_state_dict_{}'.format(i)])

    print("New Values of args: ", FLAGS)

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of parameters for models", pytorch_total_params)

    train(model, optimizer, train_dataloader, logdir, FLAGS.resume_iter, FLAGS, best_inception)

In [None]:
# Call this function with list of images. Each of elements should be a 
# numpy array with values ranging from 0 to 255.
def get_inception_score(images, splits=10):
  """Return a placeholder Inception score.

  Inception-score evaluation is disabled: the function returns
  ``(0.0, 0.0)`` unconditionally and ignores both arguments. The statements
  that previously followed the early ``return`` could never run (and relied
  on ``sess`` / ``softmax`` objects that this function does not define), so
  they are not kept here.

  Args:
    images: list of images (ignored).
    splits: number of splits the score would be averaged over (ignored).

  Returns:
    The tuple ``(0.0, 0.0)`` standing in for ``(mean score, score std)``.
  """
  return 0.0, 0.0

In [None]:
 # NOTE(review): presumably the notebook's TensorBoard line magic
 # (`%tensorboard --logdir runs`) pointing at the "runs" log directory --
 # confirm; this line is not valid plain Python.
 tensorboard --logdir runs

In [None]:
# Entry point: start training with the notebook-level `flags` configuration.
main_single(flags)