# Imports

In [None]:
# IPython shell escape (notebook only): report the visible NVIDIA GPU(s).
!nvidia-smi

In [None]:
import numpy
import torch
import torch.nn as nn
from collections import OrderedDict
import os
import sys
import warnings
from torch.utils.data import DataLoader
import argparse
import time
import copy
import math
import torchvision.utils as vision_utils
import json
import numpy as np
from torch.distributions import bernoulli
from scipy import linalg
import torchvision.datasets as _datasets
import torchvision.transforms as _transforms
import matplotlib.pyplot as plt
import shutil
import random

In [None]:
# Colab only: mount Google Drive at /content/drive. The pretrained classifier
# and the experiment output directories used in this notebook live under it.
from google.colab import drive
drive.mount('/content/drive')

# Models

In [None]:
# Default length of the generator's input noise vector.
_NOISE_DIM = 128
# Default base number of convolution filters; deeper layers use multiples of it.
_H_FILTERS = 64


class DiscriminatorCNN28(nn.Module):
    """DCGAN-style convolutional discriminator for 28x28 images.

    Three stride-2 4x4 convolutions (LeakyReLU(0.2); BatchNorm after the
    second and third) and a final 3x3 valid convolution map a
    (N, img_channels, 28, 28) batch to (N, n_outputs, 1, 1); `forward`
    flattens this to (N, n_outputs). No output activation is applied, so the
    outputs are logits.

    Args:
        img_channels: number of input image channels.
        h_filters: width of the first convolution; doubled at each later stage.
        spectral_norm: wrap every convolution in ``nn.utils.spectral_norm``.
        img_size: if given, `forward` reshapes inputs to
            (-1, img_channels, img_size, img_size) after checking that the
            element count is divisible by img_channels * img_size ** 2.
        n_outputs: number of logits per sample.

    Raises:
        TypeError: if img_channels, h_filters or n_outputs is not an int, or
            spectral_norm is not a bool.
        ValueError: if img_channels, h_filters or n_outputs is not positive.
    """

    def __init__(self, img_channels=1, h_filters=_H_FILTERS,
                 spectral_norm=False, img_size=None, n_outputs=1):
        if any(not isinstance(_arg, int) for _arg in [img_channels, h_filters, n_outputs]):
            raise TypeError("Unsupported operand type. Expected integer.")
        if not isinstance(spectral_norm, bool):
            raise TypeError(f"Unsupported operand type: {type(spectral_norm)}. "
                            "Expected bool.")
        if min([img_channels, h_filters, n_outputs]) <= 0:
            raise ValueError("Expected strictly positive input arguments for the "
                             "number of input channels, the number of convolution "
                             "filters, and the number of outputs.")
        super(DiscriminatorCNN28, self).__init__()

        def _apply_sn(layer):
            # Optionally reparameterise the layer's weight with spectral normalisation.
            return nn.utils.spectral_norm(layer) if spectral_norm else layer

        self.img_channels = img_channels
        self.img_size = img_size
        self.n_outputs = n_outputs
        self.main = nn.Sequential(
            _apply_sn(nn.Conv2d(img_channels, h_filters, 4, 2, 1, bias=False)),            # 28 -> 14
            nn.LeakyReLU(0.2, inplace=True),
            _apply_sn(nn.Conv2d(h_filters, h_filters * 2, 4, 2, 1, bias=False)),           # 14 -> 7
            nn.BatchNorm2d(h_filters * 2),
            nn.LeakyReLU(0.2, inplace=True),
            _apply_sn(nn.Conv2d(h_filters * 2, h_filters * 4, 4, 2, 1, bias=False)),       # 7 -> 3
            nn.BatchNorm2d(h_filters * 4),
            nn.LeakyReLU(0.2, inplace=True),
            _apply_sn(nn.Conv2d(h_filters * 4, self.n_outputs, 3, 1, 0, bias=False)),      # 3 -> 1
        )

    def forward(self, x):
        """Return (N, n_outputs) logits for a batch of images."""
        if self.img_channels is not None and self.img_size is not None:
            expected = self.img_size ** 2 * self.img_channels
            if x.numel() % expected != 0:
                # Report the divisor actually checked (the discriminator has no noise_dim).
                raise ValueError(f"Size mismatch. Input size: {x.numel()}. "
                                 f"Expected input divisible by: {expected}")
            x = x.view(-1, self.img_channels, self.img_size, self.img_size)
        x = self.main(x)
        return x.view(-1, self.n_outputs)

    def load(self, model):
        """Copy all parameters and buffers from another instance with the same layout."""
        self.load_state_dict(model.state_dict())


class GeneratorCNN28(nn.Module):
    """DCGAN-style transposed-convolution generator producing 28x28 images.

    The input is reshaped to (N, noise_dim, 1, 1) and upsampled by four
    transposed convolutions (1 -> 3 -> 6 -> 14 -> 28 pixels), with BatchNorm +
    ReLU after the first three. The output passes through Tanh (values in
    (-1, 1)) if `out_tanh`, otherwise Sigmoid (values in (0, 1)).

    Args:
        img_channels: number of output image channels.
        noise_dim: length of each input noise vector.
        h_filters: base width; the stages use 8x, 4x and 2x this many filters.
        out_tanh: choose Tanh instead of Sigmoid as the output activation.

    Raises:
        TypeError: if img_channels, noise_dim or h_filters is not an int.
        ValueError: if any of them is not positive.
    """

    def __init__(self, img_channels=1, noise_dim=_NOISE_DIM, h_filters=_H_FILTERS, out_tanh=False):
        if any(not isinstance(_arg, int) for _arg in [img_channels, noise_dim, h_filters]):
            raise TypeError("Unsupported operand type. Expected integer.")
        if min([img_channels, noise_dim, h_filters]) <= 0:
            raise ValueError("Expected strictly positive input arguments for the "
                             "number of output channels, the dimension of the noise "
                             "vector, as well as the depth of the convolution kernels.")
        super(GeneratorCNN28, self).__init__()
        self.noise_dim = noise_dim
        # BatchNorm widths must follow `h_filters`; sizing them from the
        # module-level default would break every non-default width.
        self.main = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, h_filters * 8, 3, 1, 0, bias=False),        # 1 -> 3
            nn.BatchNorm2d(h_filters * 8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(h_filters * 8, h_filters * 4, 4, 2, 1, bias=False),    # 3 -> 6
            nn.BatchNorm2d(h_filters * 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(h_filters * 4, h_filters * 2, 4, 2, 0, bias=False),    # 6 -> 14
            nn.BatchNorm2d(h_filters * 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(h_filters * 2, img_channels, 4, 2, 1, bias=False),     # 14 -> 28
            nn.Tanh() if out_tanh else nn.Sigmoid()
        )

    def forward(self, x):
        """Map N * noise_dim noise values to images of shape (N, img_channels, 28, 28)."""
        if x.numel() % self.noise_dim != 0:
            raise ValueError(f"Size mismatch. Input size: {x.numel()}. "
                             f"Expected input divisible by: {self.noise_dim}")
        x = x.view(-1, self.noise_dim, 1, 1)
        return self.main(x)

    def load(self, model):
        """Copy all parameters and buffers from another instance with the same layout."""
        self.load_state_dict(model.state_dict())


class MLP_mnist(nn.Module):
    """Fully connected classifier: [Linear -> ReLU -> Dropout(0.2)] * k -> Linear.

    Layers are registered inside ``self.model`` (an ``nn.Sequential``) under the
    names ``fc{i}``, ``relu{i}``, ``drop{i}`` (1-based) and ``out``, so state-dict
    keys look like ``model.fc1.weight``. ``self.layers`` holds the same module
    objects in a plain ``OrderedDict`` for lookup by name.

    Args:
        input_dims: flattened input size; inputs are reshaped to (N, input_dims).
        n_hiddens: one hidden width (int) or an iterable of widths.
        n_class: number of output logits.
    """

    def __init__(self, input_dims, n_hiddens, n_class):
        super(MLP_mnist, self).__init__()
        assert isinstance(input_dims, int), 'Expected int for input_dims'
        self.input_dims = input_dims
        n_hiddens = [n_hiddens] if isinstance(n_hiddens, int) else list(n_hiddens)

        layers = OrderedDict()
        current_dims = input_dims
        for i, n_hidden in enumerate(n_hiddens, start=1):
            layers[f'fc{i}'] = nn.Linear(current_dims, n_hidden)
            layers[f'relu{i}'] = nn.ReLU()
            layers[f'drop{i}'] = nn.Dropout(0.2)
            current_dims = n_hidden
        layers['out'] = nn.Linear(current_dims, n_class)
        self.layers = layers
        self.model = nn.Sequential(layers)

    def forward(self, input):
        """Return (N, n_class) logits for a batch flattened to (N, input_dims)."""
        input = input.view(input.size(0), -1)
        assert input.size(1) == self.input_dims
        # Call the module (not `.forward`) so registered hooks run.
        return self.model(input)

    def get_logits_and_fc2_outputs(self, x):
        """Return (logits, output of layer ``fc2`` before its ReLU).

        Raises KeyError if the network has fewer than two hidden layers.
        """
        x = x.view(x.size(0), -1)
        assert x.size(1) == self.input_dims
        fc2 = self.layers["fc2"]
        fc2_out = None
        for layer in self.model:
            x = layer(x)
            if layer is fc2:
                fc2_out = x
        return x, fc2_out


def pretrained_mnist_model(input_dims=784, n_hiddens=(256, 256), n_class=10,
                           pretrained=None):
    """Build an `MLP_mnist`, optionally load a checkpoint, and move it to GPU if available.

    Args:
        input_dims, n_hiddens, n_class: forwarded to `MLP_mnist`.
        pretrained: optional path to a saved state dict. If the path contains
            'parallel', a leading ``module.`` prefix (added when saving an
            ``nn.DataParallel`` model) is stripped from keys that carry it.

    Raises:
        FileNotFoundError: if `pretrained` is given but does not exist.
    """
    model = MLP_mnist(input_dims, n_hiddens, n_class)
    if pretrained is not None:
        if not os.path.exists(pretrained):
            raise FileNotFoundError(f"Could not find pretrained model: {pretrained}.")
        print('Loading trained model from %s' % pretrained)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(pretrained,
                                map_location='cuda:0' if torch.cuda.is_available() else 'cpu')
        if 'parallel' in pretrained:
            prefix = 'module.'
            # Strip the prefix only where present instead of blindly cutting 7 chars.
            state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in state_dict.items())
        model.load_state_dict(state_dict)
    if torch.cuda.is_available():
        model = model.cuda()
    return model

# Dataloader

In [None]:
class Binarize(object):
    """Transform mapping a tensor to floats: 1. where t > threshold, else 0."""

    def __init__(self, threshold=0.3):
        self.threshold = threshold

    def __call__(self, t):
        above = t > self.threshold
        return above.float()

    def __repr__(self):
        return '{}(th={})'.format(type(self).__name__, self.threshold)


class Smooth(object):
  """Transform replacing exact 1.0 by 1 - smooth and exact 0.0 by smooth.

  Other values are left unchanged. The tensor is modified in place and
  also returned.
  """
  def __init__(self, smooth=0.1):
    self.smooth = smooth

  def __call__(self, t):
    # Compute both masks before writing, so the second replacement cannot
    # hit values produced by the first (e.g. smooth=1.0 maps 1 -> 0).
    ones = t == 1.
    zeros = t == 0.
    t[ones] = 1 - self.smooth
    t[zeros] = 0 + self.smooth
    return t

  def __repr__(self):
    return self.__class__.__name__ + '(smooth={0})'.format(self.smooth)


def load_mnist(_data_root='datasets', binarized=False, bin_th=0.3, smooth=None):
    """Return the MNIST training split (downloaded into `_data_root` if needed).

    Images go through ToTensor, then optionally `Binarize(bin_th)` (when
    `binarized`) and `Smooth(smooth)` (when `smooth` is not None), in that order.
    """
    steps = [_transforms.ToTensor()]
    if binarized:
        steps.append(Binarize(bin_th))
    if smooth is not None:
        steps.append(Smooth(smooth))
    return _datasets.MNIST(_data_root, train=True, download=True,
                           transform=_transforms.Compose(steps))

# Generate constraints

In [None]:
# Number of random linear constraints sampled for each player.
num = 100

# Seed Python's and torch's global RNGs so the constraint sampling below is reproducible.
random.seed(0)
torch.manual_seed(0)
def generate_constraints(num, M, b_lb=3, b_ub=10, device='cuda'):
  """Sample `num` random linear constraints <theta, a_i> <= r_i for model M.

  Each normal a_i is stored as a deep copy of M (moved to `device`) whose
  parameters are overwritten with i.i.d. standard-normal values, so it has
  exactly M's parameter shapes. Each bound r_i is drawn uniformly from
  [b_lb, b_ub) with Python's `random`.

  Returns:
    [l, r]: a list of `num` modules (the normals) and a list of `num` floats.
  """
  device = torch.device(device)
  normals, bounds = [], []
  for _ in range(num):
    bounds.append(b_lb + (b_ub - b_lb) * random.random())
    normal = copy.deepcopy(M).to(device)
    with torch.no_grad():
      for p in normal.parameters():
        # Must write in place: rebinding the loop variable would leave the
        # module's parameters equal to M's.
        p.copy_(torch.randn(*p.shape, device=device))
    normals.append(normal)
  return [normals, bounds]

# Template networks: constraint normals are deep copies of these, so every
# normal has the same parameter shapes as the player it constrains.
G = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
D = DiscriminatorCNN28(spectral_norm=False, img_size=28)
# [normals, bounds] pairs read by proj(): C_g for generators, C_d for discriminators.
C_g, C_d = generate_constraints(num, G), generate_constraints(num, D)


def proj(A, generator, eps=1e-5, maxit=10000, constraints=None):
  """Greedily move A's parameters (in place) into a polyhedral feasible set.

  The set is {theta : <theta, a_i> <= r_i for all i}. Here theta is the
  concatenation of all of A's parameters, and a_i is the concatenation of the
  parameters of constraint module l[i], which has the same shapes as A.
  Each step computes the signed distance (<theta, a_i> - r_i) / ||a_i|| of
  every constraint and picks the largest. If it exceeds `eps`, theta is
  projected orthogonally onto that constraint's boundary hyperplane.
  Iteration stops when no constraint is violated by more than `eps`, or after
  `maxit` projections, in which case a failure message is printed.
  This greedy scheme does not in general yield the Euclidean projection onto
  the intersection of all half-spaces.

  Args:
    A: module whose parameters are updated in place.
    generator: when `constraints` is None, use the global C_g (True) or C_d (False).
    eps: tolerance on the signed distance.
    maxit: maximum number of single-hyperplane projections.
    constraints: optional [l, r] pair (modules, bounds) overriding the globals.
  """
  if constraints is None:
    constraints = C_g if generator else C_d
  normals, bounds = constraints
  params = list(A.parameters())
  with torch.no_grad():
    normal_params = [list(n.parameters()) for n in normals]
    # ||a_i|| never changes, so compute it once instead of every iteration.
    norms = [math.sqrt(float(sum(torch.sum(q ** 2) for q in qs)))
             for qs in normal_params]
    for _ in range(maxit):
      idx, d = -1, eps
      for i, (qs, r, norm) in enumerate(zip(normal_params, bounds, norms)):
        if norm == 0.0:
          continue  # a zero normal defines no hyperplane to project onto
        # One host sync per constraint: reduce on device, convert once.
        dot = float(sum(torch.sum(p * q) for p, q in zip(params, qs)))
        dist = (dot - r) / norm
        if dist > d:
          idx, d = i, dist
      if idx == -1:
        return
      # Orthogonal projection onto <theta, a> = r: theta -= (dist / ||a||) * a.
      scale = d / norms[idx]
      for p, q in zip(params, normal_params[idx]):
        p.sub_(scale * q)
    if maxit > 0:
      print('*************Projection Failed!****************')

# PI-ACVI training function

In [None]:
def get_disciminator_loss_x(Dx, Dy, Dlmd, x_real, x_gen, lbl_real, lbl_fake, beta, use_acvi=True):
  """Discriminator objective for the x-subproblem.

  Mean binary cross-entropy of Dx's outputs, treated as logits, against
  `lbl_real` on `x_real` plus the same against `lbl_fake` on `x_gen`.
  With `use_acvi`, adds the proximal term
      beta/2 * ||x - (y - lambda/beta)||^2
  summed over the parameter tensors, where x, y and lambda are the
  parameters of Dx, Dy and Dlmd respectively.
  """
  bce = torch.nn.functional.binary_cross_entropy_with_logits
  # Dx is applied to the real batch first, then to the generated one.
  D_x = Dx(x_real)
  D_G_z = Dx(x_gen)
  lossD = bce(D_x, lbl_real) + bce(D_G_z, lbl_fake)
  if use_acvi:
    for px, py, plmd in zip(Dx.parameters(), Dy.parameters(), Dlmd.parameters()):
      lossD = lossD + beta / 2 * torch.sum((px - (py - plmd / beta)) ** 2)
  return lossD


def get_generator_loss_x(Gx, Gy, Glmd, Dx, z, lbl_real, beta, use_acvi=True):
  """Generator objective for the x-subproblem.

  Mean binary cross-entropy of Dx(Gx(z)), treated as logits, against
  `lbl_real`; the generator is rewarded when its samples are scored as real.
  With `use_acvi`, adds beta/2 * ||x - (y - lambda/beta)||^2 summed over the
  parameter tensors, where x, y and lambda are the parameters of Gx, Gy and
  Glmd respectively.
  """
  bce = torch.nn.functional.binary_cross_entropy_with_logits
  lossG = bce(Dx(Gx(z)), lbl_real)
  if use_acvi:
    for px, py, plmd in zip(Gx.parameters(), Gy.parameters(), Glmd.parameters()):
      lossG = lossG + beta / 2 * torch.sum((px - (py - plmd / beta)) ** 2)
  return lossG


def get_sampler(dataset, batch_size, shuffle=True, drop_last=True):
  """Return a zero-argument callable that yields the next batch on every call.

  When an epoch is exhausted, a fresh DataLoader iterator is started
  transparently, so the callable never raises StopIteration for a
  non-empty loader. The first iterator is created eagerly, here.
  """
  loader = DataLoader(dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
  # Single-element box so the closure can swap in a new epoch iterator.
  current = [iter(loader)]

  def sampler():
    try:
      return next(current[0])
    except StopIteration:
      current[0] = iter(loader)
      return next(current[0])

  return sampler


def _acvi_y_lambda_update(X, Y, Lmd, beta, generator):
  """Apply the ACVI outer update to one player, in place.

  y <- proj(x + lambda / beta), then lambda <- lambda + beta * (x - y), where
  x, y and lambda are the parameters of X, Y and Lmd. `generator` selects the
  constraint set used by `proj`.
  """
  with torch.no_grad():
    for px, py, plmd in zip(X.parameters(), Y.parameters(), Lmd.parameters()):
      py.data = px + plmd / beta
    proj(Y, generator=generator)
    for px, py, plmd in zip(X.parameters(), Y.parameters(), Lmd.parameters()):
      plmd += beta * (px - py)


def train(Gx, Gy, Glmd, Dx, Dy, Dlmd, dataset, iterations, batch_size=32, lrDx=0.01, lrGx=0.01,
          beta1=0.99, eval_every=100, device=torch.device('cpu'), use_acvi=True,
          plot_func=lambda *args, **kwargs: None, extragrad=False,
          out_dir=None, lx=1, lx_warmup=100, train_time=1500, beta=0.5, ceps=-0.001):
  """Train the GAN players Gx/Dx with PI-ACVI or a projected GDA/EG baseline.

  Each outer iteration runs `lx` inner x-steps (`lx_warmup` on iteration 0).
  Each inner step is an Adam step on Dx followed by one on Gx. With
  `extragrad`, each inner step is preceded by an extrapolation step on
  copies G_extrax/D_extrax, and the main updates use the extrapolated
  opponent.

  With `use_acvi`, the inner losses include the proximal term towards
  y - lambda/beta. After the inner steps, (Gy, Glmd) and (Dy, Dlmd) are updated
  via y <- proj(x + lambda/beta) and lambda <- lambda + beta * (x - y).
  Without `use_acvi`, every parameter update is followed directly by `proj`.

  Evaluation runs every `eval_every` iterations, at the last iteration, and
  once `train_time` seconds have elapsed. It prints mean/std of
  sigmoid(Dx(Gx(z))) over 100 fixed noise vectors and calls `plot_func`.
  Training stops after the evaluation that follows the time budget being
  exceeded.

  Args:
    Gx, Dx: generator / discriminator being trained (updated in place).
    Gy, Dy: modules holding the ACVI y variables (updated when use_acvi).
    Glmd, Dlmd: modules whose parameters hold the dual variables lambda.
    dataset: iterable of (image, label) pairs, loaded entirely onto `device`.
    iterations: number of outer iterations.
    batch_size: minibatch size.
    lrDx, lrGx: Adam learning rates.
    beta1: Adam's first-moment coefficient.
    eval_every: evaluation period, in outer iterations.
    device: device for data, noise and all modules.
    use_acvi: enable the ACVI updates described above.
    plot_func: callback invoked as plot_func(samples, time_tick=..., D=..., G=...,
      iteration=..., G_avg=None, G_ema=None).
    extragrad: enable the extrapolation steps.
    out_dir, ceps: unused; kept for call compatibility.
    lx, lx_warmup: inner x-steps per outer iteration (after / at iteration 0).
    train_time: wall-clock budget in seconds.
    beta: ACVI penalty parameter.
  """
  # Materialise the whole dataset on `device` once; batches are then sliced
  # from these tensors without further host-to-device copies.
  XY = [(x, y) for x, y in dataset]
  X = torch.stack([xy[0] for xy in XY]).to(device)
  Y = torch.tensor([xy[1] for xy in XY]).long().to(device)
  dataset = torch.utils.data.TensorDataset(X, Y)

  sampler = get_sampler(dataset, batch_size, shuffle=True, drop_last=True)

  if extragrad:
    D_extrax = copy.deepcopy(Dx)
    G_extrax = copy.deepcopy(Gx)
  else:
    # Without extrapolation the "extra" players are the players themselves.
    D_extrax = Dx
    G_extrax = Gx

  # Optimizers
  optimizerDx = torch.optim.Adam(Dx.parameters(), lr=lrDx, betas=(beta1, 0.999))
  optimizerGx = torch.optim.Adam(Gx.parameters(), lr=lrGx, betas=(beta1, 0.999))
  optimizerD_extrax = torch.optim.Adam(D_extrax.parameters(), lr=lrDx, betas=(beta1, 0.999))
  optimizerG_extrax = torch.optim.Adam(G_extrax.parameters(), lr=lrGx, betas=(beta1, 0.999))

  # Target labels
  lbl_real = torch.ones(batch_size, 1, device=device)
  lbl_fake = torch.zeros(batch_size, 1, device=device)

  fixed_noise = torch.randn(100, Gx.noise_dim, device=device)

  for module in (Gx, Dx, Gy, Dy, Glmd, Dlmd, G_extrax, D_extrax):
    module.to(device)

  start_time = time.perf_counter()

  for itr in range(iterations):

    # ---- x-update ----
    for _ in range(lx_warmup if itr == 0 else lx):

      if extragrad:
        # STEP 1: extrapolated generator G_extrax against the current Dx.
        optimizerG_extrax.zero_grad()
        z = torch.randn(batch_size, G_extrax.noise_dim, device=device)
        lossGx = get_generator_loss_x(G_extrax, Gy, Glmd, Dx, z, lbl_real, beta, use_acvi=use_acvi)
        lossGx.backward()
        optimizerG_extrax.step()
        if not use_acvi:
          with torch.no_grad():
            proj(G_extrax, generator=True)

        # STEP 2: extrapolated discriminator D_extrax against the current Gx.
        optimizerD_extrax.zero_grad()
        x_real, _ = sampler()
        x_real = x_real.to(device)
        z = torch.randn(batch_size, Gx.noise_dim, device=device)
        with torch.no_grad():
          x_gen = Gx(z)
        lossDx = get_disciminator_loss_x(D_extrax, Dy, Dlmd, x_real, x_gen, lbl_real, lbl_fake, beta, use_acvi=use_acvi)
        lossDx.backward()
        optimizerD_extrax.step()
        if not use_acvi:
          with torch.no_grad():
            proj(D_extrax, generator=False)

      # STEP 3: Dx step against G_extrax.
      x_real, _ = sampler()
      x_real = x_real.to(device)
      z = torch.randn(batch_size, Gx.noise_dim, device=device)
      with torch.no_grad():
        x_gen = G_extrax(z)
      optimizerDx.zero_grad()
      lossDx = get_disciminator_loss_x(Dx, Dy, Dlmd, x_real, x_gen, lbl_real, lbl_fake, beta, use_acvi=use_acvi)
      lossDx.backward()
      optimizerDx.step()
      optimizerDx.zero_grad()
      if not use_acvi:
        with torch.no_grad():
          proj(Dx, generator=False)

      # STEP 4: Gx step against D_extrax.
      z = torch.randn(batch_size, Gx.noise_dim, device=device)
      optimizerGx.zero_grad()
      lossGx = get_generator_loss_x(Gx, Gy, Glmd, D_extrax, z, lbl_real, beta, use_acvi=use_acvi)
      lossGx.backward()
      optimizerGx.step()
      if not use_acvi:
        with torch.no_grad():
          proj(Gx, generator=True)

      if extragrad:
        # Restart the next extrapolation from the updated players.
        G_extrax.load_state_dict(Gx.state_dict())
        D_extrax.load_state_dict(Dx.state_dict())

    time_tick = time.perf_counter() - start_time

    # ---- evaluation ----
    if itr % eval_every == 0 or itr == iterations - 1 or time_tick >= train_time:
      with torch.no_grad():
        samplesx = Gx(fixed_noise)
        probasx = torch.sigmoid(Dx(samplesx))
        mean_probax = probasx.mean().cpu().item()
        std_probax = probasx.std().cpu().item()
      print(f"Iter {itr}: Mean proba from Dx(Gx(z)): {mean_probax:.4f} +/- {std_probax:.4f}")
      plot_func(samplesx.detach().cpu(), time_tick=time_tick, D=Dx, G=Gx, iteration=itr, G_avg=None, G_ema=None)

    if time_tick >= train_time:
      break

    # ---- y and lambda updates ----
    if use_acvi:
      _acvi_y_lambda_update(Gx, Gy, Glmd, beta, generator=True)
      _acvi_y_lambda_update(Dx, Dy, Dlmd, beta, generator=False)

# Display & Eval

In [None]:
def compute_mu_sigma_pretrained_model(dataset, pretrained_clf):
  """Mean vector and covariance matrix of the classifier's fc2 features on `dataset`.

  Iterates `dataset` in order in batches of 512 (the last partial batch is
  dropped). Batches go to the GPU when the classifier's parameters are on CUDA.
  The classifier is switched to eval mode.

  Returns:
    (mu, sigma): numpy arrays from np.mean(axis=0) and np.cov(rowvar=False).
  """
  loader = DataLoader(dataset, batch_size=512, num_workers=2, drop_last=True)
  on_gpu = next(pretrained_clf.parameters()).is_cuda
  pretrained_clf.eval()
  features = []
  for images, _ in loader:
    with torch.no_grad():
      if on_gpu:
        images = images.cuda()
      fc2_out = pretrained_clf.get_logits_and_fc2_outputs(images)[1]
    features.append(fc2_out.cpu())
  stacked = torch.cat(features, dim=0).numpy()
  return np.mean(stacked, axis=0), np.cov(stacked, rowvar=False)


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2), in NumPy.

        d = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr(sqrt(sigma1 sigma2))

    If the matrix square root is not finite, `eps` is added to both diagonals
    and it is recomputed. A small imaginary part (diagonal within 1e-3 of 0)
    left by numerical error is discarded.

    Raises:
        AssertionError: on mismatched mean or covariance shapes.
        ValueError: if the square root has a significant imaginary component.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    delta = mu1 - mu2

    # The product may be nearly singular.
    sqrt_prod, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(sqrt_prod).all():
        print("fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps)
        jitter = eps * np.eye(sigma1.shape[0])
        sqrt_prod = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

    if np.iscomplexobj(sqrt_prod):
        if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
            raise ValueError("Imaginary component {}".format(np.max(np.abs(sqrt_prod.imag))))
        sqrt_prod = sqrt_prod.real

    return delta.dot(delta) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(sqrt_prod)


def _calculate_metrics(pretrained_clf, G, dataset_length, mu_real, sigma_real,
                       n_classes=10, batch_size=1024):
    """Score generator G with a pretrained classifier.

    Draws ``dataset_length // batch_size`` batches of ``batch_size`` noise
    vectors (any remainder is dropped) and classifies ``G(noise)``. It then
    computes:

    * the entropy of the predicted-label histogram, and its total-variation
      and squared-L2 distances to the uniform distribution over `n_classes`;
    * the inception score: mean/std over `n_classes` equal splits of
      exp(mean KL(p(y|x) || p(y))) on the softmax predictions;
    * the Frechet distance between a Gaussian fit of the classifier's fc2
      features and N(mu_real, sigma_real).

    Returns:
        (label_entropy, label_tv, label_l2, is_mean, is_std, fid)

    Raises:
        ValueError: if dataset_length < batch_size (no batch would be drawn).
    """
    n_batches = dataset_length // batch_size
    if n_batches == 0:
        raise ValueError(f"dataset_length ({dataset_length}) must be at least "
                         f"batch_size ({batch_size}).")
    cuda = next(pretrained_clf.parameters()).is_cuda
    device = torch.device('cuda') if cuda else torch.device('cpu')

    # Classifier logits and fc2 features on generated samples. Only the batch
    # count matters, so a plain range suffices (no DataLoader workers).
    all_fc2_out, all_logits = [], []
    pretrained_clf.eval()
    for _ in range(n_batches):
        with torch.no_grad():
            noise = torch.randn(batch_size, G.noise_dim, device=device)
            logits, fc2_out = pretrained_clf.get_logits_and_fc2_outputs(G(noise).view(batch_size, -1))
        all_fc2_out.append(fc2_out.cpu())
        all_logits.append(logits.cpu())
    all_fc2_out = torch.cat(all_fc2_out, dim=0).numpy()
    logits = torch.cat(all_logits, dim=0)
    inception_predictions = torch.softmax(logits, dim=1).numpy()

    # Predicted labels: argmax of the raw logits (== argmax of the softmax).
    # Clipping logits from below before the argmax would send every sample
    # with all-negative logits to class 0.
    pred_labels = logits.numpy().argmax(axis=1)
    counts = np.bincount(pred_labels, minlength=n_classes).astype(np.float64)
    # Every sample adds 1 to its predicted class and 1e-20 to each other
    # class, which keeps log(y_vec) finite for classes never predicted.
    y_vec = (counts + (len(pred_labels) - counts) * 1e-20)[np.newaxis, :]
    y_vec = y_vec / np.sum(y_vec)
    gnd_vec = np.full((1, n_classes), 1.0 / n_classes)  # uniform reference distribution

    label_entropy = np.sum(-y_vec * np.log(y_vec)).tolist()
    label_tv = np.true_divide(np.sum(np.abs(y_vec - gnd_vec)), 2).tolist()
    label_l2 = np.sum((y_vec - gnd_vec) ** 2).tolist()

    # Inception score over n_classes equal splits.
    inception_scores = []
    n_samples = inception_predictions.shape[0]
    for i in range(n_classes):
        part = inception_predictions[i * n_samples // n_classes:(i + 1) * n_samples // n_classes, :]
        kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
        inception_scores.append(np.exp(np.mean(np.sum(kl, 1))))

    # FID against the real-data feature statistics.
    mu = np.mean(all_fc2_out, axis=0)
    sigma = np.cov(all_fc2_out, rowvar=False)
    _fid = calculate_frechet_distance(mu, sigma, mu_real, sigma_real)

    return (label_entropy, label_tv, label_l2,
            float(np.mean(inception_scores)),
            float(np.std(inception_scores)),
            _fid)


def get_metrics(pretrained_clf, dataset_length, mu_real, sigma_real, G):
    """Score generator G and return the results as a dict.

    Keys: 'entropy', 'TV', 'L2' (predicted-label statistics),
    'inception_mean', 'inception_std' and 'fid'.
    """
    names = ('entropy', 'TV', 'L2', 'inception_mean', 'inception_std', 'fid')
    values = _calculate_metrics(pretrained_clf, G, dataset_length,
                                mu_real, sigma_real)
    return dict(zip(names, values))

In [None]:
def save_models(G, D, opt_G, opt_D, out_dir, suffix):
  """Save the state dicts of both models and both optimizers into `out_dir`.

  Files: gen_<suffix>.pth, disc_<suffix>.pth, gen_optim_<suffix>.pth and
  disc_optim_<suffix>.pth, written in that order.
  """
  targets = (("gen", G), ("disc", D), ("gen_optim", opt_G), ("disc_optim", opt_D))
  for prefix, obj in targets:
    torch.save(obj.state_dict(), os.path.join(out_dir, f"{prefix}_{suffix}.pth"))


def get_plot_func(out_dir, img_size, num_samples_eval=10000, save_curves=None,
                  pretrained_path='/content/drive/My Drive/mnist_exp/ACVI/mnist/mnist.pth'):
  """Build the evaluation/plotting callback passed to `train`.

  Loads the MNIST training set and the pretrained classifier at
  `pretrained_path`, and precomputes the real-data fc2 feature statistics
  used for FID. Each call of the returned `plot_func`:
    * saves the sample grid to ``<out_dir>/<iteration:08d>.png``;
    * scores G (and G_avg / G_ema when given) with `get_metrics` using
      `num_samples_eval` as the sample budget;
    * plots inception score (mean +/- std) and FID against wall-clock time,
      saving the figure to ``<out_dir>/curves.png`` and every recorded curve
      to ``<out_dir>/curves.json``.

  Args:
    out_dir: existing directory receiving the images and curves.
    img_size: per-sample image shape, e.g. ``dataset[0][0].size()``.
    num_samples_eval: sample budget per metric evaluation.
    save_curves: unused; kept for call compatibility.
    pretrained_path: path of the pretrained MNIST classifier checkpoint.
  """
  dataset = load_mnist(_data_root='datasets', binarized=False)
  pretrained_clf = pretrained_mnist_model(pretrained=pretrained_path)
  mu_real, sigma_real = compute_mu_sigma_pretrained_model(dataset, pretrained_clf)
  inception_means, inception_stds, inception_means_ema, inception_means_avg, fids, fids_ema, fids_avg = [], [], [], [], [], [], []
  iterations, times = [], []

  def plot_func(samples, iteration, time_tick, G=None, D=None, G_avg=None, G_ema=None):
    fig = plt.figure(figsize=(12, 5), dpi=100)
    # Left panel: grid of generated samples (also saved as its own PNG).
    plt.subplot(1, 2, 1)
    samples = samples.view(-1, *img_size)
    file_name = os.path.join(out_dir, '%08d.png' % iteration)
    vision_utils.save_image(samples, file_name, nrow=10)
    grid_img = vision_utils.make_grid(samples, nrow=10, normalize=True, padding=0)
    plt.imshow(grid_img.permute(1, 2, 0), interpolation='nearest')

    # Right panel: metric curves against wall-clock time.
    plt.subplot(1, 2, 2)
    metrics = get_metrics(pretrained_clf, num_samples_eval, mu_real, sigma_real, G)
    fids.append(metrics['fid'])
    inception_means.append(metrics['inception_mean'])
    inception_stds.append(metrics['inception_std'])
    if G_avg is not None:
      metrics = get_metrics(pretrained_clf, num_samples_eval, mu_real, sigma_real, G_avg)
      fids_avg.append(metrics['fid'])
      inception_means_avg.append(metrics['inception_mean'])
    if G_ema is not None:
      metrics = get_metrics(pretrained_clf, num_samples_eval, mu_real, sigma_real, G_ema)
      fids_ema.append(metrics['fid'])
      inception_means_ema.append(metrics['inception_mean'])
    iterations.append(iteration)
    times.append(time_tick)

    # Inception score with a +/- one std band.
    is_low = [m - s for m, s in zip(inception_means, inception_stds)]
    is_high = [m + s for m, s in zip(inception_means, inception_stds)]
    plt.plot(times, inception_means, label="is", color='r')
    plt.fill_between(times, is_low, is_high, facecolor='r', alpha=.3)
    plt.yticks(np.arange(0, 10 + 1, 0.5))
    # FID
    plt.plot(times, fids, label="fid", color='b')
    plt.xlabel('Time (sec)')
    plt.ylabel('Metric')
    plt.grid()
    ax = fig.gca()
    ax.set_ylim(-0.1, 10)
    plt.legend(fancybox=True, framealpha=.5)
    curves_img_file_name = os.path.join(out_dir, 'curves.png')
    fig.savefig(curves_img_file_name)
    plt.show()
    # A new figure is created on every evaluation; release it explicitly.
    plt.close(fig)

    curves_file_name = os.path.join(out_dir, 'curves.json')
    curves = {
        'inception_means': list(inception_means),
        'inception_stds': list(inception_stds),
        'inception_means_ema': list(inception_means_ema),
        'inception_means_avg': list(inception_means_avg),
        'fids_ema': list(fids_ema),
        'fids_avg': list(fids_avg),
        'fids': list(fids),
        'iterations': iterations,
        'times': times
    }
    with open(curves_file_name, 'w') as fs:
      json.dump(curves, fs)

  return plot_func

# Experiments

## GDA

In [None]:
# Projected GDA baseline: no extrapolation, no ACVI.
args = dict(iterations=10000,
            batch_size=50,
            lrDx=0.001,
            lrGx=0.001,
            beta1=0.05,
            extragrad=False,
            eval_every=100,
            device='cuda',
            lx_warmup=1,
            lx=1,
            beta=0.5,
            use_acvi=False)

# One run per seed; each run's output directory is wiped and recreated.
for k in range(0, 4):
  torch.manual_seed(k)
  torch.cuda.manual_seed(k)
  np.random.seed(k)

  exp_key = (f"iter{args['iterations']}_bs{args['batch_size']}_lrD{args['lrDx']}"
             f"_lrG{args['lrGx']}_beta1{args['beta1']}_beta{args['beta']}"
             f"_extragrad{args['extragrad']}_ee{args['eval_every']}"
             f"_lx{args['lx']}_lxw{args['lx_warmup']}_useacvi{args['use_acvi']}")
  out_dir = f"/content/drive/My Drive/mnist_exp/ACVI/aistats/EG/{exp_key}/{k}/"
  shutil.rmtree(out_dir, ignore_errors=True)
  os.makedirs(out_dir, exist_ok=True)

  with open(os.path.join(out_dir, 'args.json'), 'w') as fs:
    json.dump(args, fs)

  dataset = load_mnist(_data_root='datasets', binarized=False)
  plot_func = get_plot_func(out_dir=out_dir,
                            img_size=dataset[0][0].size(),
                            num_samples_eval=10000)

  # Construction order fixes how parameter initialisation draws from the
  # RNG; the copying/zeroing below draws nothing.
  Gx = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
  Dx = DiscriminatorCNN28(spectral_norm=False, img_size=28)
  Gy = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
  Dy = DiscriminatorCNN28(spectral_norm=False, img_size=28)
  Glmd = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
  Dlmd = DiscriminatorCNN28(spectral_norm=False, img_size=28)
  with torch.no_grad():
    # y starts equal to x.
    for x_net, y_net in ((Gx, Gy), (Dx, Dy)):
      for px, py in zip(x_net.parameters(), y_net.parameters()):
        py.data = px.data.clone()
    # Dual variables start at zero.
    for lmd_net in (Glmd, Dlmd):
      for p in lmd_net.parameters():
        p.zero_()

  train(Gx, Gy, Glmd, Dx, Dy, Dlmd, dataset,
        device=torch.device(args['device']),
        plot_func=plot_func,
        out_dir=out_dir,
        **{key: args[key] for key in ('iterations', 'batch_size', 'lrDx', 'lrGx',
                                      'beta1', 'eval_every', 'extragrad', 'lx',
                                      'lx_warmup', 'beta', 'use_acvi')})

## EG

In [None]:
# Projected extragradient baseline: extrapolation on, no ACVI.
args = dict(iterations=5000,
            batch_size=50,
            lrDx=0.001,
            lrGx=0.001,
            beta1=0.05,
            extragrad=True,
            eval_every=100,
            device='cuda',
            lx_warmup=1,
            lx=1,
            beta=0.5,
            use_acvi=False)

# One run per seed; each run's output directory is wiped and recreated.
for k in range(0, 4):
  torch.manual_seed(k)
  torch.cuda.manual_seed(k)
  np.random.seed(k)

  exp_key = (f"iter{args['iterations']}_bs{args['batch_size']}_lrD{args['lrDx']}"
             f"_lrG{args['lrGx']}_beta1{args['beta1']}_beta{args['beta']}"
             f"_extragrad{args['extragrad']}_ee{args['eval_every']}"
             f"_lx{args['lx']}_lxw{args['lx_warmup']}_useacvi{args['use_acvi']}")
  out_dir = f"/content/drive/My Drive/mnist_exp/ACVI/aistats/EG/{exp_key}/{k}/"
  shutil.rmtree(out_dir, ignore_errors=True)
  os.makedirs(out_dir, exist_ok=True)

  with open(os.path.join(out_dir, 'args.json'), 'w') as fs:
    json.dump(args, fs)

  dataset = load_mnist(_data_root='datasets', binarized=False)
  plot_func = get_plot_func(out_dir=out_dir,
                            img_size=dataset[0][0].size(),
                            num_samples_eval=10000)

  # Construction order fixes how parameter initialisation draws from the
  # RNG; the copying/zeroing below draws nothing.
  Gx = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
  Dx = DiscriminatorCNN28(spectral_norm=False, img_size=28)
  Gy = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
  Dy = DiscriminatorCNN28(spectral_norm=False, img_size=28)
  Glmd = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
  Dlmd = DiscriminatorCNN28(spectral_norm=False, img_size=28)
  with torch.no_grad():
    # y starts equal to x.
    for x_net, y_net in ((Gx, Gy), (Dx, Dy)):
      for px, py in zip(x_net.parameters(), y_net.parameters()):
        py.data = px.data.clone()
    # Dual variables start at zero.
    for lmd_net in (Glmd, Dlmd):
      for p in lmd_net.parameters():
        p.zero_()

  train(Gx, Gy, Glmd, Dx, Dy, Dlmd, dataset,
        device=torch.device(args['device']),
        plot_func=plot_func,
        out_dir=out_dir,
        **{key: args[key] for key in ('iterations', 'batch_size', 'lrDx', 'lrGx',
                                      'beta1', 'eval_every', 'extragrad', 'lx',
                                      'lx_warmup', 'beta', 'use_acvi')})

## PI-ACVI

In [None]:
# PI-ACVI: ACVI updates on, 500 warm-up x-steps then 20 x-steps per iteration.
args = dict(iterations=5000,
            batch_size=50,
            lrDx=0.001,
            lrGx=0.001,
            beta1=0.05,
            extragrad=False,
            eval_every=100,
            device='cuda',
            lx_warmup=500,
            lx=20,
            beta=0.5,
            use_acvi=True)

# One run per seed; each run's output directory is wiped and recreated.
for k in range(0, 4):
  torch.manual_seed(k)
  torch.cuda.manual_seed(k)
  np.random.seed(k)

  exp_key = (f"iter{args['iterations']}_bs{args['batch_size']}_lrD{args['lrDx']}"
             f"_lrG{args['lrGx']}_beta1{args['beta1']}_beta{args['beta']}"
             f"_extragrad{args['extragrad']}_ee{args['eval_every']}"
             f"_lx{args['lx']}_lxw{args['lx_warmup']}_useacvi{args['use_acvi']}")
  out_dir = f"/content/drive/My Drive/mnist_exp/ACVI/aistats/EG/{exp_key}/{k}/"
  shutil.rmtree(out_dir, ignore_errors=True)
  os.makedirs(out_dir, exist_ok=True)

  with open(os.path.join(out_dir, 'args.json'), 'w') as fs:
    json.dump(args, fs)

  dataset = load_mnist(_data_root='datasets', binarized=False)
  plot_func = get_plot_func(out_dir=out_dir,
                            img_size=dataset[0][0].size(),
                            num_samples_eval=10000)

  # Construction order fixes how parameter initialisation draws from the
  # RNG; the copying/zeroing below draws nothing.
  Gx = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
  Dx = DiscriminatorCNN28(spectral_norm=False, img_size=28)
  Gy = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
  Dy = DiscriminatorCNN28(spectral_norm=False, img_size=28)
  Glmd = GeneratorCNN28(noise_dim=_NOISE_DIM, out_tanh=True)
  Dlmd = DiscriminatorCNN28(spectral_norm=False, img_size=28)
  with torch.no_grad():
    # y starts equal to x.
    for x_net, y_net in ((Gx, Gy), (Dx, Dy)):
      for px, py in zip(x_net.parameters(), y_net.parameters()):
        py.data = px.data.clone()
    # Dual variables start at zero.
    for lmd_net in (Glmd, Dlmd):
      for p in lmd_net.parameters():
        p.zero_()

  train(Gx, Gy, Glmd, Dx, Dy, Dlmd, dataset,
        device=torch.device(args['device']),
        plot_func=plot_func,
        out_dir=out_dir,
        **{key: args[key] for key in ('iterations', 'batch_size', 'lrDx', 'lrGx',
                                      'beta1', 'eval_every', 'extragrad', 'lx',
                                      'lx_warmup', 'beta', 'use_acvi')})