# Model Performance Analysis


## Imports and Installs

In [None]:
!pip install kornia

In [None]:
# imports
# torch and friends
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import torch.nn.functional as F
from torchvision.utils import make_grid
from torchvision.datasets import CIFAR10, MNIST
from torchvision import transforms
from PIL import Image, ImageOps
import kornia.augmentation as K

# standard
import os
import random
import time
import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA, KernelPCA
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

## Helper Functions

In [None]:
def load_model(model, pretrained, device):
    """
    Restore the weights of a model, in place, from a checkpoint file.

    :param model: module whose parameters are overwritten
    :param pretrained: path of a checkpoint that stores the state dict under the 'model' key
    :param device: device the stored tensors are mapped to while loading
    """
    checkpoint = torch.load(pretrained, map_location=device)
    state_dict = checkpoint['model']
    # strict=False: keys that are missing from, or unexpected in, the checkpoint are skipped without error
    model.load_state_dict(state_dict, strict=False)


def reparameterize(mu, logvar):
    """
    Draw a latent sample with the reparameterization trick:
    z = mu + sigma * epsilon, where sigma = exp(0.5 * logvar) and epsilon ~ N(0,I)
    :param mu: mean of the latent Gaussian
    :param logvar: log variance of the latent Gaussian
    :return z: the sampled latent variable
    """
    sigma = (0.5 * logvar).exp()
    # noise has sigma's shape and is placed on the device of mu
    noise = torch.randn_like(sigma).to(mu.device)
    return mu + noise * sigma


def calc_reconstruction_loss(x, recon_x, loss_type='mse', reduction='sum'):
    """
    Reconstruction error between inputs and their reconstructions (both flattened per sample).

    For "mse" the squared error is first summed over the features of every sample and `reduction`
    is then applied over the batch ("none" returns one value per sample). For "l1" and "bce"
    `reduction` is handed to the torch loss directly, i.e. it acts on all elements.

    :param x: original inputs
    :param recon_x:  reconstruction of the VAE's input
    :param loss_type: "mse", "l1", "bce"
    :param reduction: "sum", "mean", "none"
    :return: recon_loss
    :raises NotImplementedError: for an unknown loss_type or reduction
    """
    if reduction not in ('sum', 'mean', 'none'):
        raise NotImplementedError

    flat_recon = recon_x.view(recon_x.size(0), -1)
    flat_x = x.view(x.size(0), -1)

    if loss_type == 'l1':
        return F.l1_loss(flat_recon, flat_x, reduction=reduction)
    if loss_type == 'bce':
        return F.binary_cross_entropy(flat_recon, flat_x, reduction=reduction)
    if loss_type != 'mse':
        raise NotImplementedError

    # mse: per-sample sum of squared errors, reduced over the batch afterwards
    per_sample = F.mse_loss(flat_recon, flat_x, reduction='none').sum(1)
    if reduction == 'sum':
        return per_sample.sum()
    if reduction == 'mean':
        return per_sample.mean()
    return per_sample


def interpolate(model, img_1=None, img_2=None, intervals=10, device="cpu"):
    """
    Decode a linear walk between two latent codes.

    Each endpoint is either the reparameterized encoding of the given image or, when the image
    is None, a sample z ~ N(0,I) of shape (1, model.zdim) placed on `device`.

    :param model: object exposing encode(), decode() and zdim
    :param img_1: image batch for the first endpoint, or None to sample it
    :param img_2: image batch for the second endpoint, or None to sample it
    :param intervals: number of steps; must be non-zero since the step fraction is i / intervals
    :param device: device for sampled endpoints
    :return: list of intervals + 1 decoded outputs, each with its leading dimension squeezed
    """
    def endpoint(img):
        if img is None:  # sample z ~ N(0,I)
            return torch.randn(1, model.zdim).to(device)
        mu, logvar = model.encode(img)  # encode
        return reparameterize(mu, logvar)

    z_start = endpoint(img_1)
    z_end = endpoint(img_2)

    frames = []
    for step in range(intervals + 1):
        t = step / intervals
        blended = z_start * (1 - t) + z_end * t
        frames.append(model.decode(blended).squeeze(0))
    return frames

## Model Definition

In [None]:
# Building Blocks
class GaussianNoise(nn.Module):
    """
    Additive Gaussian noise whose standard deviation is relative to the input (sigma * x).

    The layer is the identity in eval mode or when sigma is 0.

    :param sigma: relative standard deviation of the noise
    :param is_relative_detach: if True the noise scale is computed from x.detach(), so no
        gradient flows through the scale; otherwise the scale is computed from x itself
    """

    def __init__(self, sigma=0.1, is_relative_detach=True):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach

    def forward(self, x):
        if not self.training or self.sigma == 0:
            return x
        reference = x.detach() if self.is_relative_detach else x
        scale = self.sigma * reference
        unit_noise = torch.normal(mean=torch.zeros_like(x), std=torch.ones_like(x))
        return x + unit_noise * scale

class AugmentLayers(nn.Module):
    """
    Thin wrapper that applies kornia's RandomErasing to a batch of images.

    :param p_augment: probability handed to RandomErasing as `p`
    """

    def __init__(self, p_augment=0.9):
        super().__init__()
        self.p_augment = p_augment
        # (0.0, 0.1) is the first positional argument of RandomErasing
        # NOTE(review): presumably the erased-area scale range - confirm against the kornia docs
        self.Erase = K.RandomErasing((0.0, 0.1), p=p_augment)

    def forward(self, x):
        return self.Erase(x)

class ResidualBlock(nn.Module):
    """
    Two-convolution residual block.

    https://github.com/hhb072/IntroVAE
    Difference: self.bn2 on output and not on (output + identity)

    :param inc: number of input channels
    :param outc: number of output channels
    :param groups: groups of the two 3x3 convolutions
    :param scale: width multiplier of the hidden convolution (midc = int(outc * scale))
    """

    def __init__(self, inc=64, outc=64, groups=1, scale=1.0):
        super(ResidualBlock, self).__init__()

        midc = int(outc * scale)

        # Compare by value: the previous identity test (`is not`) only behaved for ints that
        # happen to be the same object, so equal widths above the small-int cache could end up
        # with a needless 1x1 projection (and extra 'conv_expand' state-dict keys).
        if inc != outc:
            # 1x1 projection so the skip path matches the output width
            self.conv_expand = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1, padding=0,
                                         groups=1, bias=False)
        else:
            self.conv_expand = None

        self.conv1 = nn.Conv2d(in_channels=inc, out_channels=midc, kernel_size=3, stride=1, padding=1, groups=groups,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(midc)
        self.relu1 = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(in_channels=midc, out_channels=outc, kernel_size=3, stride=1, padding=1, groups=groups,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(outc)
        self.relu2 = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Return relu2(bn2(conv2(relu1(bn1(conv1(x))))) + skip(x))."""
        if self.conv_expand is not None:
            identity_data = self.conv_expand(x)
        else:
            identity_data = x

        output = self.relu1(self.bn1(self.conv1(x)))
        output = self.conv2(output)
        output = self.bn2(output)
        output = self.relu2(torch.add(output, identity_data))
        return output

# Encoder
class Encoder(nn.Module):
    """
    Convolutional encoder that outputs the mean and log-variance of the latent Gaussian.

    Layout of `main`: 5x5 conv + BatchNorm + LeakyReLU + 2x average pooling, then for every entry
    of channels[1:] a ResidualBlock followed by 2x average pooling, then one more ResidualBlock.
    A linear layer maps the flattened features to 2 * zdim values that are split into (mu, logvar).
    GaussianNoise / Dropout layers are inserted only when the matching argument is > 0, and kornia
    augmentations are applied to the input (training mode only) when p_augment > 0.

    :param cdim: number of image channels
    :param zdim: latent dimensionality
    :param channels: channel width of every stage
    :param image_size: side length of the square input images
    :param conditional: if True the linear layer expects cond_dim extra input features
    :param cond_dim: size of the conditioning vector
    :param p_enc_s: dropout probability after the first stage (0 disables)
    :param p_enc_e: dropout probability after the last residual block (0 disables)
    :param nn_sigma: sigma of the GaussianNoise layers (0 disables)
    :param nn_gn_rel: forwarded to GaussianNoise as is_relative_detach
    :param p_augment: probability of each kornia augmentation (0 disables)
    """

    def __init__(self, cdim=3, zdim=512, channels=(64, 128, 256, 512, 512, 512), image_size=256, conditional=False,
                 cond_dim=10, p_enc_s=0, p_enc_e=0, nn_sigma=0, nn_gn_rel=True, p_augment=0):
        super(Encoder, self).__init__()
        self.zdim = zdim
        self.cdim = cdim
        self.image_size = image_size
        self.conditional = conditional
        self.cond_dim = cond_dim
        cc = channels[0]
        self.main = nn.Sequential(
            nn.Conv2d(cdim, cc, 5, 1, 2, bias=False),
            nn.BatchNorm2d(cc),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(2),
        )

        if p_augment > 0:
            print("Data augmentation was added")
            self.augment_module = torch.nn.Sequential(
                K.RandomAffine(degrees=0, translate=(1 / 8, 1 / 8), p=p_augment),
                K.RandomErasing((0.0, 0.5), p=p_augment),
            )
        else:
            self.augment_module = nn.Identity()

        # gn_i counts the GaussianNoise layers; it is part of their module names
        gn_i = 0
        if nn_sigma > 0:
            gn_i += 1
            self.main.add_module('GN_{}'.format(gn_i), GaussianNoise(sigma=nn_sigma, is_relative_detach=nn_gn_rel))

        if p_enc_s > 0:
            print("Dropout implemented in the start of the Encoder with value of: ", p_enc_s)
            self.main.add_module('dropout_1', nn.Dropout(p_enc_s))

        # sz tracks the spatial size; it only feeds the module names below
        sz = image_size // 2
        for ch in channels[1:]:
            self.main.add_module('res_in_{}'.format(sz), ResidualBlock(cc, ch, scale=1.0))
            self.main.add_module('down_to_{}'.format(sz // 2), nn.AvgPool2d(2))
            cc, sz = ch, sz // 2
            if nn_sigma > 0:
                gn_i += 1
                self.main.add_module('GN_{}'.format(gn_i), GaussianNoise(sigma=nn_sigma, is_relative_detach=nn_gn_rel))

        self.main.add_module('res_in_{}'.format(sz), ResidualBlock(cc, cc, scale=1.0))

        if nn_sigma > 0:
            gn_i += 1
            self.main.add_module('GN_{}'.format(gn_i), GaussianNoise(sigma=nn_sigma, is_relative_detach=nn_gn_rel))
            print(f"Gaussian noise added to {gn_i} layers of the encoder with sigma = {nn_sigma}, is_relative_detach={nn_gn_rel}")

        if p_enc_e > 0:
            print("Dropout implemented in the end of the encoder with value of: ", p_enc_e)
            self.main.add_module('dropout_2', nn.Dropout(p_enc_e))

        self.conv_output_size = self.calc_conv_output_size()
        # number of flattened features = product of the (C, H, W) entries
        num_fc_features = self.conv_output_size.numel()
        fc_inputs = num_fc_features + self.cond_dim if self.conditional else num_fc_features
        self.fc = nn.Linear(fc_inputs, 2 * zdim)

    def calc_conv_output_size(self):
        """
        Return the per-sample (C, H, W) shape produced by `main` for an image_size x image_size input.

        The probe runs in eval mode and without autograd so that it neither updates BatchNorm
        running statistics nor triggers the noise / dropout layers; the previous training flag of
        `main` is restored afterwards.
        """
        was_training = self.main.training
        probe_device = next(self.main.parameters()).device
        self.main.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.zeros(1, self.cdim, self.image_size, self.image_size, device=probe_device)
                dummy_output = self.main(dummy_input)
        finally:
            self.main.train(was_training)
        return dummy_output[0].shape

    def forward(self, x, o_cond=None):
        """
        :param x: batch of images
        :param o_cond: optional conditioning batch, concatenated to the flattened features when
            the encoder is conditional
        :return: (mu, logvar), the two halves of the linear layer output
        """
        if self.training:
            x = self.augment_module(x)

        y = self.main(x).view(x.size(0), -1)
        if self.conditional and o_cond is not None:
            y = torch.cat([y, o_cond], dim=1)
        y = self.fc(y)
        mu, logvar = y.chunk(2, dim=1)
        return mu, logvar

# Decoder
class Decoder(nn.Module):
    """
    Convolutional decoder that maps a latent code (optionally with a condition) to an image.

    A linear layer + ReLU expands the code to conv_input_size features, which are reshaped and
    passed through `main`: for every entry of reversed(channels) a ResidualBlock followed by 2x
    nearest-neighbour upsampling, then one more ResidualBlock and a 5x5 'predict' convolution with
    cdim output channels. GaussianNoise / Dropout layers are inserted only when the matching
    argument is > 0.

    :param cdim: number of image channels produced
    :param zdim: latent dimensionality
    :param channels: channel widths (used in reverse order)
    :param image_size: stored on the module; not used to build the layers
    :param conditional: if True the linear layer expects cond_dim extra input features
    :param conv_input_size: per-sample (C, H, W) shape fed to `main`; defaults to
        (channels[-1], 4, 4) when None
    :param cond_dim: size of the conditioning vector
    :param p_dec_s: dropout probability after the linear layer (0 disables)
    :param p_dec_e: dropout probability before the last residual block (0 disables)
    :param nn_sigma: sigma of the GaussianNoise layers (0 disables)
    :param nn_gn_rel: forwarded to GaussianNoise as is_relative_detach
    """

    def __init__(self, cdim=3, zdim=512, channels=(64, 128, 256, 512, 512, 512), image_size=256, conditional=False,
                 conv_input_size=None, cond_dim=10, p_dec_s=0, p_dec_e=0, nn_sigma=0, nn_gn_rel=True):
        super(Decoder, self).__init__()
        self.cdim = cdim
        self.image_size = image_size
        self.conditional = conditional
        cc = channels[-1]
        # forward() unpacks conv_input_size when reshaping, so it must never stay None; the
        # fallback mirrors the feature count (cc * 4 * 4) that was already used for this case
        if conv_input_size is None:
            conv_input_size = (cc, 4, 4)
        self.conv_input_size = conv_input_size
        num_fc_features = torch.Size(self.conv_input_size).numel()
        self.cond_dim = cond_dim
        fc_inputs = zdim + self.cond_dim if self.conditional else zdim
        self.fc = nn.Sequential(
            nn.Linear(fc_inputs, num_fc_features),
            nn.ReLU(True),
        )

        # gn_i counts the GaussianNoise layers; it is part of their module names
        gn_i = 0
        if nn_sigma > 0:
            gn_i += 1
            self.fc.add_module('GN_{}'.format(gn_i), GaussianNoise(sigma=nn_sigma, is_relative_detach=nn_gn_rel))

        if p_dec_s > 0:
            print("Dropout implemented in the start of the Decoder with value of: ", p_dec_s)
            self.fc.add_module('dropout_1', nn.Dropout(p_dec_s))

        # sz only feeds the module names below (it starts at 4 regardless of conv_input_size)
        sz = 4

        self.main = nn.Sequential()
        for ch in channels[::-1]:
            self.main.add_module('res_in_{}'.format(sz), ResidualBlock(cc, ch, scale=1.0))
            self.main.add_module('up_to_{}'.format(sz * 2), nn.Upsample(scale_factor=2, mode='nearest'))
            cc, sz = ch, sz * 2
            if nn_sigma > 0:
                gn_i += 1
                self.main.add_module('GN_{}'.format(gn_i), GaussianNoise(sigma=nn_sigma, is_relative_detach=nn_gn_rel))

        if p_dec_e > 0:
            print("Dropout implemented in the end of the Decoder with value of: ", p_dec_e)
            self.main.add_module('dropout_2', nn.Dropout(p_dec_e))

        self.main.add_module('res_in_{}'.format(sz), ResidualBlock(cc, cc, scale=1.0))

        if nn_sigma > 0:
            gn_i += 1
            self.main.add_module('GN_{}'.format(gn_i), GaussianNoise(sigma=nn_sigma, is_relative_detach=nn_gn_rel))
            print(f"Gaussian noise added to {gn_i} layers of the decoder with sigma = {nn_sigma}, is_relative_detach={nn_gn_rel}")

        self.main.add_module('predict', nn.Conv2d(cc, cdim, 5, 1, 2))

    def forward(self, z, y_cond=None):
        """
        :param z: batch of latent codes (flattened per sample before use)
        :param y_cond: optional conditioning batch, concatenated to z when the decoder is conditional
        :return: output of the 'predict' convolution (no output activation is applied)
        """
        z = z.view(z.size(0), -1)
        if self.conditional and y_cond is not None:
            y_cond = y_cond.view(y_cond.size(0), -1)
            z = torch.cat([z, y_cond], dim=1)
        y = self.fc(z)
        y = y.view(z.size(0), *self.conv_input_size)
        y = self.main(y)
        return y

# Soft-IntroVAE
class SoftIntroVAE(nn.Module):
    """
    Encoder/decoder pair of the Soft-IntroVAE model.

    The decoder is built with conv_input_size set to the encoder's conv_output_size. A condition
    is used only when the model is conditional AND a condition tensor is supplied; otherwise the
    encoder and decoder are called without it.
    """

    def __init__(self, cdim=3, zdim=512, channels=(64, 128, 256, 512, 512, 512), image_size=256, conditional=False,
                 cond_dim=10, p_enc_s=0, p_enc_e=0, p_dec_s=0, p_dec_e=0, nn_sigma_enc=0, nn_sigma_dec=0, nn_gn_rel=True,
                 p_augment=0):
        super().__init__()

        self.zdim = zdim
        self.conditional = conditional
        self.cond_dim = cond_dim

        self.encoder = Encoder(
            cdim, zdim, channels, image_size,
            conditional=conditional, cond_dim=cond_dim,
            p_enc_s=p_enc_s, p_enc_e=p_enc_e,
            nn_sigma=nn_sigma_enc, nn_gn_rel=nn_gn_rel, p_augment=p_augment,
        )

        self.decoder = Decoder(
            cdim, zdim, channels, image_size,
            conditional=conditional, conv_input_size=self.encoder.conv_output_size, cond_dim=cond_dim,
            p_dec_s=p_dec_s, p_dec_e=p_dec_e,
            nn_sigma=nn_sigma_dec, nn_gn_rel=nn_gn_rel,
        )

    def forward(self, x, o_cond=None, deterministic=False):
        """
        Encode x, pick a latent code and decode it.

        :param x: batch of images
        :param o_cond: optional condition, used for both encoding and decoding
        :param deterministic: if True use z = mu, otherwise sample z with reparameterize
        :return: (mu, logvar, z, y) where y is the decoded output
        """
        use_cond = self.conditional and o_cond is not None
        if use_cond:
            mu, logvar = self.encode(x, o_cond=o_cond)
        else:
            mu, logvar = self.encode(x)

        z = mu if deterministic else reparameterize(mu, logvar)

        y = self.decode(z, y_cond=o_cond) if use_cond else self.decode(z)
        return mu, logvar, z, y

    def sample(self, z, y_cond=None):
        """Decode the given latent codes."""
        return self.decode(z, y_cond=y_cond)

    def sample_with_noise(self, num_samples=1, device=torch.device("cpu"), y_cond=None):
        """Decode num_samples latent codes drawn from N(0,I) and moved to `device`."""
        latent = torch.randn(num_samples, self.zdim).to(device)
        return self.decode(latent, y_cond=y_cond)

    def encode(self, x, o_cond=None):
        """Return (mu, logvar) from the encoder; the condition is passed on only when it is used."""
        if not self.conditional or o_cond is None:
            mu, logvar = self.encoder(x)
        else:
            mu, logvar = self.encoder(x, o_cond=o_cond)
        return mu, logvar

    def decode(self, z, y_cond=None):
        """Return the decoder output; the condition is passed on only when it is used."""
        if not self.conditional or y_cond is None:
            return self.decoder(z)
        return self.decoder(z, y_cond=y_cond)

## Analysis

### Load Model

In [None]:
from google.colab import drive
# Mount Google Drive so the notebook can load checkpoints and save figures (Colab only)
drive.mount('/content/drive') # give notebook access to google drive to load/save checkpoints
# Directory that holds the checkpoints; the figure paths used further down are appended to it too
base_path = './drive/MyDrive/Colab Notebooks/Project B/' # enter path where checkpoint is saved, example: './drive/checkpoints/'
# Superseded: the checkpoint file names are now chosen per dataset in the parameters cell below
# checkpoint = 'cifar10_soft_intro_betas_1.0_256.0_1.0_fid_3.0722174215240443_model_epoch_400_iter_625200.pth' # enter checkpoint name, example: 'cifar10_soft_intro_model_epoch_400.pth'
# path = base_path + checkpoint

In [None]:
# Parameters
dataset = 'cifar10'  # Choose dataset: ['cifar10', 'mnist']
z_dim = 128  # latent dimensionality, passed to the models as zdim
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Per-dataset settings: image side length, stage widths, image channels and checkpoint file names
# (`checkpoint` is our model, `checkpoint_orig` the original model it is compared against)
if dataset == 'cifar10':
    ch, image_size = 3, 32
    channels = [64, 128, 256]
    checkpoint = "cifar10_soft_intro_betas_1.0_256.0_1.0_fid_2.9929277083944044_model_epoch_660_iter_1031580.pth"
    checkpoint_orig = "cifar10_soft_intro_betas_1.0_256.0_1.0_fid_3.676469025118422_model_epoch_360_iter_562680.pth"
elif dataset == 'mnist':
    ch, image_size = 1, 28
    channels = [64, 128]
    checkpoint = "mnist_soft_intro_betas_1.0_256.0_1.0_model_epoch_450_iter_843750.pth"
    checkpoint_orig = "mnist_orig_soft_intro_betas_1.0_256.0_1.0_model_epoch_400_iter_750000.pth"
else:
    raise NotImplementedError("Dataset is not supported")

In [None]:
# Load our model
model = SoftIntroVAE(cdim=ch, zdim=z_dim, channels=channels, image_size=image_size).to(device)
load_model(model, base_path + checkpoint, device)
# NOTE(review): eval() is disabled for both models, so they stay in the default training mode
# (train-mode behaviour of BatchNorm and similar layers applies) - confirm this is intentional
# model.eval()
# Load original model for comparison
model_orig = SoftIntroVAE(cdim=ch, zdim=z_dim, channels=channels, image_size=image_size).to(device)
load_model(model_orig, base_path + checkpoint_orig, device)
# model_orig.eval()
# Load data (downloaded on first use); ToTensor is the only transform applied
if dataset == 'cifar10':
    train_data = CIFAR10(root='./cifar10_ds', train=True, download=True, transform=transforms.ToTensor())  
    test_data  = CIFAR10(root='./cifar10_ds', train=False, download=True, transform=transforms.ToTensor())  
elif dataset == 'mnist':
    train_data = MNIST(root='./mnist_ds', train=True, download=True, transform=transforms.ToTensor())  
    test_data  = MNIST(root='./mnist_ds', train=False, download=True, transform=transforms.ToTensor())  

### Latent Space

#### Latent Map Analysis

In [None]:
def plot_latent_map(points, point_colors, class_colors, class_names, title, xlabel, ylabel):
    """
    Scatter a 2-D embedding of the latent codes, colored by class, with one legend entry per class.

    :param points: array of shape (num_samples, 2)
    :param point_colors: one color per sample
    :param class_colors: one color per class, aligned with class_names
    :param class_names: legend labels
    :param title: plot title
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    """
    fig, ax = plt.subplots()
    fig.set_size_inches(7, 7)
    ax.scatter(points[:, 0], points[:, 1], c=point_colors, marker='.')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # the scatter itself carries no labels, so the legend is built from proxy lines
    legend_elements = [Line2D([0], [0], color=c, label=name) for c, name in zip(class_colors, class_names)]
    ax.legend(handles=legend_elements)
    plt.show()


# Encode training data
data_loader = DataLoader(train_data, batch_size=1000, shuffle=True, num_workers=2)
num_batches = 1  # number of shuffled batches to embed
latent_reps = []
labels = []
with torch.no_grad():  # inference only: no autograd graph is needed
    for i, (imgs, lbls) in enumerate(data_loader):
        mu, logvar = model.encode(imgs.to(device))
        z = reparameterize(mu, logvar)
        # move to the CPU first: .numpy() is not available on CUDA tensors
        latent_reps.extend(z.cpu().numpy())
        labels.extend(lbls.numpy())
        if i + 1 >= num_batches:
            break

print(len(latent_reps))
latent_reps = np.asarray(latent_reps)
print(latent_reps.shape)
labels = np.asarray(labels)
print(labels.shape)

if dataset == 'cifar10':
    label_name = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
elif dataset == 'mnist':
    label_name = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

color = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
colors = [color[l] for l in labels]

# use PCA to visualize latent representations in 2D
embedding = PCA(n_components=2).fit_transform(latent_reps)
plot_latent_map(embedding, colors, color, label_name, "Latent Representation - PCA", "PC1", "PC2")

# use KPCA to visualize latent representations in 2D
# NOTE(review): scikit-learn documents gamma as ignored for the cosine kernel - confirm the kernel choice
embedding = KernelPCA(n_components=2, kernel='cosine', gamma=1/z_dim).fit_transform(latent_reps)
plot_latent_map(embedding, colors, color, label_name, "Latent Representation - KPCA", "PC1", "PC2")

# use T-SNE to visualize latent representations in 2D, once per perplexity value
for p in [4, 10, 50, np.sqrt(len(latent_reps))]:
    embedding = TSNE(n_components=2, perplexity=p, learning_rate='auto', init='pca').fit_transform(latent_reps)
    plot_latent_map(embedding, colors, color, label_name, "Latent Representation - TSNE", "Mapped dim1", "Mapped dim2")

#### Interpolation in the Latent Space

In [None]:
# Choose source of images to interpolate between: ['trainset', 'testset', 'random', 'cherry']
# ('random' samples the endpoint as z ~ N(0,I); 'cherry' takes a fixed image from the train set)
img_src_1 = 'trainset'
img_src_2 = 'trainset'

# if 'cherry' choose index (indices into train_data for image 1 and image 2)
idx1 = 12012
idx2 = 11933

# Single-image shuffled loaders: each next(iter(...)) yields one random image
train_data_loader = DataLoader(train_data, batch_size=1, shuffle=True, num_workers=1)
test_data_loader  = DataLoader(test_data, batch_size=1, shuffle=True, num_workers=1)

# Number of frames in the interpolation; the next cell lays them out on a 4-row grid,
# so use a multiple of 4
intervals = 20 # choose number of intervals for linear interpolation

In [None]:
def pick_image(source, cherry_idx):
    """
    Return a single-image batch on `device` for the requested source, or None for 'random'.

    :param source: one of 'trainset', 'testset', 'random', 'cherry'
    :param cherry_idx: index into train_data, used only when source is 'cherry'
    :raises NotImplementedError: for any other source
    """
    if source == 'trainset':  # take image from train set
        return next(iter(train_data_loader))[0].to(device)
    if source == 'testset':  # take image from test set
        return next(iter(test_data_loader))[0].to(device)
    if source == 'random':  # endpoint is sampled as z ~ N(0,I) inside interpolate
        return None
    if source == 'cherry':  # hand-picked train image; must live on the model's device as well
        return train_data[cherry_idx][0].unsqueeze(0).to(device)
    raise NotImplementedError("Image source is not supported")


# prepare image 1 and image 2
img_1 = pick_image(img_src_1, idx1)
img_2 = pick_image(img_src_2, idx2)

# interpolate (intervals - 1 steps give exactly `intervals` frames)
with torch.no_grad():  # inference only
    images = interpolate(model, img_1, img_2, intervals-1, device)
# show the real images, not their reconstructions, at the ends of the sequence
if img_1 is not None:
    images[0] = img_1.squeeze(0)
if img_2 is not None:
    images[-1] = img_2.squeeze(0)

# save images
if dataset == 'cifar10':
    path = 'figures/cifar10/cifar10_interpolation.jpg'
elif dataset == 'mnist':
    path = 'figures/mnist/mnist_interpolation.jpg'
save_path = base_path + path
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # make sure the figure directory exists
vutils.save_image(images, save_path, nrow=intervals // 4)

# show images
to_img = transforms.ToPILImage()
fig, axes = plt.subplots(4, intervals // 4, squeeze=True, figsize=(10, 8))
for ax, frame in zip(axes.flatten(), images):
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    # the decoder output is unbounded; clamp so the uint8 conversion cannot wrap around
    ax.imshow(to_img(frame.detach().cpu().clamp(0, 1)))

#### Latent Space Arithmetic

In [None]:
# Choose source of images: ['trainset', 'testset']
# (any value other than 'trainset' selects the test set in the next cell)
img_src = 'trainset'

In [None]:
# Prepare data: three random images from the chosen split
# (a dedicated name keeps the `data` alias of torch.utils.data intact)
source_data = train_data if img_src == 'trainset' else test_data
data_loader = DataLoader(source_data, batch_size=3, shuffle=True, num_workers=2)
image = next(iter(data_loader))[0].to(device)

# To use a fixed triplet instead, stack train_data[idx][0] for idx in (24543, 48246, 6041)

with torch.no_grad():  # inference only
    # Encode
    mu, logvar = model.encode(image)
    z = reparameterize(mu, logvar)
    # Perform Arithmetic Operation
    z_new = z[0] - z[1] + z[2]
    # Decode
    image_new = model.decode(z_new.unsqueeze(0))
images = torch.cat([image, image_new])

# Save Images
if dataset == 'cifar10':
    path = 'figures/cifar10/cifar10_arithmetic.jpg'
elif dataset == 'mnist':
    path = 'figures/mnist/mnist_arithmetic.jpg'
save_path = base_path + path
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # make sure the figure directory exists
vutils.save_image(images.detach().cpu(), save_path, nrow=4)

# Show Images: the three inputs followed by the decoded result
to_img = transforms.ToPILImage()
fig, axes = plt.subplots(1, 4, squeeze=True, figsize=(10, 10))
for ax, picture in zip(axes.flatten(), images):
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    # the decoder output is unbounded; clamp so the uint8 conversion cannot wrap around
    ax.imshow(to_img(picture.detach().cpu().clamp(0, 1)))

### Generation

#### Image Generation

In [None]:
# Parameters
# Number of images to generate; the next cell shows them on a 2-row grid
b_size = 14

In [None]:
# Create generated image
with torch.no_grad():  # inference only
    noise_batch = torch.randn(size=(b_size, z_dim)).to(device)
    images = model.sample(noise_batch)

# Two rows; rounding up keeps the last image when b_size is odd
n_cols = (b_size + 1) // 2

# Save Images
if dataset == 'cifar10':
    path = 'figures/cifar10/cifar10_gen.jpg'
elif dataset == 'mnist':
    path = 'figures/mnist/mnist_gen.jpg'
save_path = base_path + path
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # make sure the figure directory exists
vutils.save_image(images.detach().cpu(), save_path, nrow=n_cols)

# Show Images
to_img = transforms.ToPILImage()
fig, axes = plt.subplots(2, n_cols, squeeze=True, figsize=(b_size, 4))
for i, ax in enumerate(axes.flatten()):
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if i < len(images):  # the last cell stays empty for an odd b_size
        # the decoder output is unbounded; clamp so the uint8 conversion cannot wrap around
        ax.imshow(to_img(images[i].detach().cpu().clamp(0, 1)))

#### Recurring Encoding Decoding for Generated Image Improvement

In [None]:
# Parameters
# Number of images to generate (one column per image in the figure of the next cell)
b_size = 10
# Number of consecutive encode-decode passes applied to the generated images
num_cycles = 10

In [None]:
with torch.no_grad():  # inference only; also keeps the graph from growing with every cycle
    # Create generated image
    noise_batch = torch.randn(size=(b_size, z_dim)).to(device)
    generated_initial = model.sample(noise_batch)

    # Encode-Decode image for num_cycles
    generated_rec = generated_initial
    for cycle in range(num_cycles):
        _, _, _, generated_rec = model(generated_rec)
# first b_size entries: initial samples, last b_size entries: after num_cycles passes
images = torch.cat((generated_initial, generated_rec))

# Save Images
if dataset == 'cifar10':
    path = 'figures/cifar10/cifar10_rec_gen.jpg'
elif dataset == 'mnist':
    path = 'figures/mnist/mnist_rec_gen.jpg'
save_path = base_path + path
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # make sure the figure directory exists
vutils.save_image(images.detach().cpu(), save_path, nrow=b_size)

# Show Images
to_img = transforms.ToPILImage()
fig, axes = plt.subplots(2, b_size, squeeze=True, figsize=(2*b_size, 4))
for ax, picture in zip(axes.flatten(), images):
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    # the decoder output is unbounded; clamp so the uint8 conversion cannot wrap around
    ax.imshow(to_img(picture.detach().cpu().clamp(0, 1)))

### Reconstruction

#### Image Reconstruction

In [None]:
# Parameters
# Number of images to reconstruct (one column per image in the figure of the next cell)
b_size = 8

# Choose source of images: ['trainset', 'testset']
# (any value other than 'trainset' selects the test set in the next cell)
img_src = 'trainset'

In [None]:
# Prepare data: one random batch from the chosen split
# (a dedicated name keeps the `data` alias of torch.utils.data intact)
source_data = train_data if img_src == 'trainset' else test_data
data_loader = DataLoader(source_data, batch_size=b_size, shuffle=True, num_workers=2)
image = next(iter(data_loader))[0].to(device)

# Reconstruct with our model and with the original model
with torch.no_grad():  # inference only
    _, _, _, image_rec = model(image)
    _, _, _, image_rec_orig = model_orig(image)
# row order: inputs, our reconstructions, original-model reconstructions
images = torch.cat((image, image_rec, image_rec_orig))

# Save Images
if dataset == 'cifar10':
    path = 'figures/cifar10/cifar10_rec.jpg'
elif dataset == 'mnist':
    path = 'figures/mnist/mnist_rec.jpg'
save_path = base_path + path
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # make sure the figure directory exists
vutils.save_image(images.detach().cpu(), save_path, nrow=b_size)

# Show Images
to_img = transforms.ToPILImage()
fig, axes = plt.subplots(3, b_size, squeeze=True, figsize=(3*b_size, 9))
for ax, picture in zip(axes.flatten(), images):
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    # the decoder output is unbounded; clamp so the uint8 conversion cannot wrap around
    ax.imshow(to_img(picture.detach().cpu().clamp(0, 1)))

#### Partially Erased Image Reconstruction

In [None]:
# Parameters
# Number of images (one column per image in the figure of the next cell)
b_size = 7
# Probability handed to AugmentLayers when erase == 'block'
p_augment = 1
# Number of consecutive encode-decode passes applied to the erased images
num_cycles = 1

# Choose source of images: ['trainset', 'testset']
# (any value other than 'trainset' selects the test set in the next cell)
img_src = 'trainset'

# Choose form of erasing: ['block', 'pixels']
erase = 'block'

In [None]:
# Prepare data: one random batch from the chosen split
# (a dedicated name keeps the `data` alias of torch.utils.data intact)
source_data = train_data if img_src == 'trainset' else test_data
data_loader = DataLoader(source_data, batch_size=b_size, shuffle=True, num_workers=2)
image = next(iter(data_loader))[0].to(device)

# Erase part of image
if erase == 'block':
    AugModule = AugmentLayers(p_augment)
    image_erased = AugModule(image)
elif erase == 'pixels':
    # Zero random pixel positions in all channels. Positions are drawn with replacement, so the
    # erased share is at most 20% of the pixels.
    n_imgs, _, height, width = image.shape
    n_erased = int(0.2 * height * width)
    rows = torch.randint(low=0, high=height, size=(n_imgs, n_erased), device=image.device)
    cols = torch.randint(low=0, high=width, size=(n_imgs, n_erased), device=image.device)
    batch_idx = torch.arange(n_imgs, device=image.device).unsqueeze(1)  # broadcasts against rows/cols
    mask = torch.ones_like(image)
    mask[batch_idx, :, rows, cols] = 0
    image_erased = image * mask
else:
    raise NotImplementedError("Method is not supported")

# Reconstruct with our model and with the original model
with torch.no_grad():  # inference only
    image_rec = image_erased
    image_rec_orig = image_erased
    for cycle in range(num_cycles):
        _, _, _, image_rec = model(image_rec)
        _, _, _, image_rec_orig = model_orig(image_rec_orig)

# row order: inputs, erased inputs, original-model reconstructions, our reconstructions
images = torch.cat((image, image_erased, image_rec_orig, image_rec))

# Save Images
if dataset == 'cifar10':
    path = 'figures/cifar10/cifar10_erased_rec.jpg'
elif dataset == 'mnist':
    path = 'figures/mnist/mnist_erased_rec.jpg'
save_path = base_path + path
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # make sure the figure directory exists
vutils.save_image(images.detach().cpu(), save_path, nrow=b_size)

# Show Images
to_img = transforms.ToPILImage()
fig, axes = plt.subplots(4, b_size, squeeze=True, figsize=(2*b_size, 8))
for ax, picture in zip(axes.flatten(), images):
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    # the decoder output is unbounded; clamp so the uint8 conversion cannot wrap around
    ax.imshow(to_img(picture.detach().cpu().clamp(0, 1)))