In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

import numpy as  np


class CNN_Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to a flat feature vector.

    Five convolutions (LeakyReLU activations, BatchNorm from the second
    convolution on) followed by a Linear + BatchNorm1d + LeakyReLU head.

    Args:
        output_size (int): width of the feature vector returned by forward().
        input_size (tuple): (channels, height, width) of a single input image.
    """
    def __init__(self, output_size, input_size=(1, 28, 28)):
        super(CNN_Encoder, self).__init__()

        self.input_size = input_size
        self.channel_mult = 16

        #convolutions
        self.conv = nn.Sequential(
            # Channel count comes from input_size instead of assuming
            # single-channel images.
            nn.Conv2d(in_channels=input_size[0],
                     out_channels=self.channel_mult*1,
                     kernel_size=4,
                     stride=1,
                     padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(self.channel_mult*1, self.channel_mult*2, 4, 2, 1),
            nn.BatchNorm2d(self.channel_mult*2),
            nn.LeakyReLU(),
            nn.Conv2d(self.channel_mult*2, self.channel_mult*4, 4, 2, 1),
            nn.BatchNorm2d(self.channel_mult*4),
            nn.LeakyReLU(),
            nn.Conv2d(self.channel_mult*4, self.channel_mult*8, 4, 2, 1),
            nn.BatchNorm2d(self.channel_mult*8),
            nn.LeakyReLU(),
            nn.Conv2d(self.channel_mult*8, self.channel_mult*16, 3, 2, 1),
            nn.BatchNorm2d(self.channel_mult*16),
            nn.LeakyReLU()
        )

        self.flat_fts = self.get_flat_fts(self.conv)

        self.linear = nn.Sequential(
            nn.Linear(self.flat_fts, output_size),
            nn.BatchNorm1d(output_size),
            nn.LeakyReLU(),
        )

    def get_flat_fts(self, fts):
        """Return the number of features `fts` emits for one input image.

        The probe runs in eval mode without autograd so the dummy all-ones
        image does not leak into the BatchNorm running statistics; the
        previous train/eval mode of `fts` is restored afterwards.
        """
        was_training = fts.training
        fts.eval()
        try:
            with torch.no_grad():
                f = fts(torch.ones(1, *self.input_size))
        finally:
            fts.train(was_training)
        return int(np.prod(f.size()[1:]))

    def forward(self, x):
        """Encode `x` (any shape reshapeable to [-1, *input_size]) to [N, output_size]."""
        x = self.conv(x.view(-1, *self.input_size))
        x = x.view(-1, self.flat_fts)
        return self.linear(x)

class CNN_Decoder(nn.Module):
    """Transposed-convolution decoder from a latent vector to a flat 28x28 image.

    A Linear + BatchNorm1d + ELU stem lifts the latent code to 512 features,
    which are treated as a 512 x 1 x 1 map and upsampled to 1 x 28 x 28 by
    four transposed convolutions ending in a Sigmoid.

    Args:
        embedding_size (int): dimensionality of the latent input.
        input_size (tuple): accepted for symmetry with CNN_Encoder; the output
            geometry is fixed to 1 x 28 x 28 by the layer stack below.
    """
    def __init__(self, embedding_size, input_size=(1, 28, 28)):
        super(CNN_Decoder, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = embedding_size
        self.channel_mult = 16
        self.output_channels = 1
        self.fc_output_dim = 512

        width = self.channel_mult
        hidden = self.fc_output_dim

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ELU()
        )

        upsampling = [
            # 512 x 1 x 1  ->  (width*4) x 4 x 4
            nn.ConvTranspose2d(hidden, width*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width*4),
            nn.ELU(),
            # (width*4) x 4 x 4  ->  (width*2) x 7 x 7
            nn.ConvTranspose2d(width*4, width*2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(width*2),
            nn.ELU(),
            # (width*2) x 7 x 7  ->  width x 14 x 14
            nn.ConvTranspose2d(width*2, width*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width*1),
            nn.ELU(),
            # width x 14 x 14  ->  output_channels x 28 x 28, squashed to [0, 1]
            nn.ConvTranspose2d(width*1, self.output_channels, 4, 2, 1, bias=False),
            nn.Sigmoid(),
        ]
        self.deconv = nn.Sequential(*upsampling)

    def forward(self, x):
        """Decode latent codes [N, embedding_size] to flat images [N, 784]."""
        seed_map = self.fc(x).view(-1, self.fc_output_dim, 1, 1)
        pixels_per_image = self.input_width*self.input_height
        return self.deconv(seed_map).view(-1, pixels_per_image)

In [None]:
class VAE_MNIST(nn.Module):
    """Variational auto-encoder for 28x28 images built on CNN_Encoder/CNN_Decoder.

    Args:
        output_size (int): width of the encoder feature vector feeding the
            mean / log-variance heads.
        embedding_size (int): dimensionality of the latent space.
    """
    def __init__(self, output_size=512, embedding_size=16):
        super(VAE_MNIST, self).__init__()
        # `output_size` is honoured as passed; it used to be overwritten
        # with 512 here, silently ignoring the caller's value.
        self.encoder = CNN_Encoder(output_size)
        # `var` produces the log-variance (it is consumed as `logvar` below);
        # the attribute name is kept so existing checkpoints still load.
        self.var = nn.Linear(output_size, embedding_size)
        self.mu = nn.Linear(output_size, embedding_size)

        self.decoder = CNN_Decoder(embedding_size)

    def encode(self, x):
        """Return (z, mu, logvar); z is sampled in train mode and equals mu in eval mode."""
        x = self.encoder(x)
        mu = self.mu(x)
        var = self.var(x)
        if self.training:
            z = self.reparameterize(mu, var)
        else:
            z = mu
        return z, mu, var

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterisation trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Decode latent codes with the decoder network."""
        return self.decoder(z)

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for a batch of 784-pixel images."""
        z, mu, logvar = self.encode(x.view(-1, 784))
        return self.decode(z), mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Summed binary cross-entropy reconstruction term plus KL divergence."""
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

In [None]:

class ConvBlock (nn.Sequential):
    """Spectrally-normalised Conv2d followed by BatchNorm2d and ELU.

    Sub-modules are registered as 'Convolution', 'BatchNorm' and 'Activation'.
    """
    def __init__ (self, in_c, out_c, kernel_size, stride=1):
        super().__init__()
        conv = nn.Conv2d(in_c, out_c, kernel_size, stride)
        stages = (
            ('Convolution', nn.utils.spectral_norm(conv)),
            ('BatchNorm', nn.BatchNorm2d(out_c, affine=True)),
            ('Activation', nn.ELU()),
        )
        for label, layer in stages:
            self.add_module(label, layer)

class ConvTransposeBlock (nn.Sequential):
    """Spectrally-normalised ConvTranspose2d followed by BatchNorm2d and ELU.

    Sub-modules are registered as 'ConvTranspose', 'BatchNorm' and 'Activation'.
    """
    def __init__ (self, in_c, out_c, kernel_size, stride=1):
        super().__init__()
        deconv = nn.ConvTranspose2d(in_c, out_c, kernel_size, stride)
        stages = (
            ('ConvTranspose', nn.utils.spectral_norm(deconv)),
            ('BatchNorm', nn.BatchNorm2d(out_c, affine=True)),
            ('Activation', nn.ELU()),
        )
        for label, layer in stages:
            self.add_module(label, layer)


class Encoder (nn.Module):
    """
    Encoder half of the 2D VAE: maps images to the parameters of a diagonal
    Gaussian in latent space. Log-variances are returned (rather than
    variances) because they range over all real numbers.

    Forward pass:
        Input:
            Image batch of shape [N, C, H, W] (1x28x28 MNIST images by default)
        Outputs:
            i)  Means of the latent distributions, shape [N, latent_dim]
            ii) Log-variances of the latent distributions, shape [N, latent_dim]
                (diagonal covariance, so [N, d] instead of [N, d, d])

    Arguments:
        X_dim (list) : input image dimensions as [Channels, Height, Width]
        latent_dim (int) : dimension of the latent space.
    """
    def __init__(self, X_dim=[1,28,28], latent_dim=16):
        super(Encoder, self).__init__()

        conv1_outchannels = 32
        conv2_outchannels = 32

        def reduced(extent):
            # Spatial extent left after the two unpadded stride-2
            # convolutions below (kernel 4, then kernel 3).
            return int(((extent-4)/2 - 2)/2 + 1)

        conv_outputshape = conv2_outchannels * reduced(X_dim[1]) * reduced(X_dim[2])

        first_stage = ConvBlock(X_dim[0], conv1_outchannels, kernel_size=4, stride=2)
        second_stage = ConvBlock(conv1_outchannels, conv2_outchannels, kernel_size=3, stride=2)
        self.enc = nn.Sequential(first_stage, second_stage)

        self.zmean = nn.Linear(conv_outputshape, latent_dim)
        self.zlogvar = nn.Linear(conv_outputshape, latent_dim)


    def forward (self, X):
        """Return (mean, logvar), each of shape [N, latent_dim]."""
        flat = self.enc(X).view(X.shape[0], -1)
        return self.zmean(flat), self.zlogvar(flat)


class Decoder (nn.Module):
    """
    Decoder half of the 2D VAE: maps latent vectors to the mean image of the
    output distribution.

    Forward pass:
        Input:
            Latent vectors of shape [N, latent_dim]
        Output:
            Mean images of shape [N, C, H', W'] with values in [-1, 1] (Tanh).
            H' and W' are fixed by the two transposed convolutions; for the
            default 28x28 setting they are 28x28.

    An RBF-based variance head (Arvanitidis et al. 2018, "Latent Space
    Oddity") was sketched for this class but is not enabled: only the mean is
    returned.

    Arguments:
        X_dim (list) : image dimensions as [Channels, Height, Width]
        latent_dim (int) : dimension of the latent space.
    """
    def __init__(self, X_dim=[1,28,28], latent_dim=16):
        super(Decoder, self).__init__()

        conv1_outchannels = 32
        conv2_outchannels = 32

        def reduced(extent):
            # Spatial extent the matching Encoder reaches after its two
            # unpadded stride-2 convolutions (kernel 4, then kernel 3).
            return int(((extent-4)/2 - 2)/2 + 1)

        self.conv_outputshape = (reduced(X_dim[1]), reduced(X_dim[2]))
        rows, cols = self.conv_outputshape

        self.lin = nn.Linear(latent_dim, conv2_outchannels*rows*cols)

        first_up = ConvTransposeBlock(conv2_outchannels, conv1_outchannels, kernel_size=3, stride=2)
        second_up = ConvTransposeBlock(conv1_outchannels, 32, kernel_size=4, stride=2)
        self.conv = nn.Sequential(first_up, second_up)

        # 1x1 convolution to the image channels; Tanh bounds the output to [-1, 1].
        self.Xmean = nn.Sequential(
            nn.Conv2d(32, X_dim[0], kernel_size=1),
            nn.Tanh()
        )


    def forward (self, z):
        """Return the mean image for each latent vector in `z`."""
        rows, cols = self.conv_outputshape[0], self.conv_outputshape[1]
        feature_map = self.lin(z).view(z.shape[0], -1, rows, cols)
        return self.Xmean(self.conv(feature_map))

class VAE_MNIST(nn.Module):
    """VAE built on the spectral-norm `Encoder` / `Decoder` pair.

    This definition reuses the name of the CNN_Encoder-based VAE_MNIST from an
    earlier cell, so running this cell replaces that class.

    Args:
        output_size (int): unused; kept so the constructor signature matches
            the earlier VAE_MNIST.
        embedding_size (int): dimensionality of the latent space.
    """
    def __init__(self, output_size=512, embedding_size=16):
        super(VAE_MNIST, self).__init__()
        self.input_shape = (1, 28, 28)
        # Forward embedding_size; it used to be ignored, so the latent size
        # was always the Encoder/Decoder default.
        self.encoder = Encoder(X_dim=list(self.input_shape), latent_dim=embedding_size)
        self.decoder = Decoder(X_dim=list(self.input_shape), latent_dim=embedding_size)

    def encode(self, x):
        """Return (z, mu, logvar); z is sampled in train mode and equals mu in eval mode.

        `x` may be flat ([N, 784]) or image-shaped ([N, 1, 28, 28]); the
        convolutional encoder needs the latter, so it is reshaped here.
        """
        mu, var = self.encoder(x.view(-1, *self.input_shape))
        if self.training:
            z = self.reparameterize(mu, var)
        else:
            z = mu
        return z, mu, var

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterisation trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Decode latent codes to flat images of shape [N, 784] with values in [0, 1].

        The Decoder ends in Tanh ([-1, 1]) and returns [N, 1, 28, 28]. The
        binary cross-entropy loss below and the rest of the notebook expect
        flat reconstructions in [0, 1], so the output is rescaled and flattened.
        """
        recon = (self.decoder(z) + 1) / 2
        return recon.view(recon.shape[0], -1)

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for a batch of 28x28 images."""
        z, mu, logvar = self.encode(x)
        return self.decode(z), mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Summed binary cross-entropy reconstruction term plus KL divergence."""
        BCE = F.binary_cross_entropy(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils
from torch.hub import load_state_dict_from_url

import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm

# --- Experiment configuration -------------------------------------------
latent_channels = 16
batch_size = 128
lr = 0.005
nepoch = 20
start_epoch = 0
dataset_root = ""
save_dir = os.getcwd()
model_name = f"MNIST_VAE__black_lrelu_{latent_channels}"
load_checkpoint  = False
# Makes CUDA kernel launches synchronous (useful when debugging device errors).
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


# --- Device selection: first GPU when available, CPU otherwise ----------
use_cuda = torch.cuda.is_available()
gpu_indx  = 0
if use_cuda:
    device = torch.device(gpu_indx)
else:
    device = torch.device("cpu")


In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import numpy as np

class _ConstantImagesDataset(Dataset):
    """Synthetic dataset of `num_images` identical constant-valued images.

    Args:
        num_images (int): number of samples the dataset reports.
        image_size (tuple): shape of each generated image tensor.
        label: label returned alongside every image.
    """
    # Pixel value used for every generated image; set by subclasses.
    fill_value = 0.0

    def __init__(self, num_images, image_size, label):
        self.num_images = num_images
        self.image_size = image_size
        self.label = label

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        # Without a bounds check, plain iteration (`for x in dataset`) would
        # never terminate, because Python stops only on IndexError.
        if not -self.num_images <= idx < self.num_images:
            raise IndexError("index %s out of range for dataset of size %d"
                             % (idx, self.num_images))
        image = torch.full(self.image_size, self.fill_value, dtype=torch.float32)
        return image, self.label


class BlackImagesDataset(_ConstantImagesDataset):
    """All-zero (black) float32 images, each paired with a fixed label."""
    fill_value = 0.0


class WhiteImagesDataset(_ConstantImagesDataset):
    """All-one (white) float32 images, each paired with a fixed label."""
    fill_value = 1.0

def prepare_datasets(num_black_images=50):
    """Return (train, test) MNIST datasets extended with constant images.

    Each split is the corresponding MNIST split followed by
    `num_black_images` black images (label 10) and the same number of white
    images (label 11). MNIST is downloaded to ./data if missing.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    image_shape = (1, 28, 28)

    def build_split(is_train):
        # One MNIST split plus its black and white padding, in that order.
        mnist = datasets.MNIST(root='./data', train=is_train, download=True, transform=to_tensor)
        black = BlackImagesDataset(num_black_images, image_shape, 10)
        white = WhiteImagesDataset(num_black_images, image_shape, 11)
        return ConcatDataset([mnist, black, white])

    train_dataset = build_split(True)
    test_dataset = build_split(False)
    return train_dataset, test_dataset



# Prepare datasets
# MNIST plus 50 black (label 10) and 50 white (label 11) images per split.
train_dataset, test_dataset = prepare_datasets()

# Data loaders
# NOTE(review): the next cell rebinds train_loader / test_loader to plain
# MNIST loaders, so when the notebook runs top to bottom the augmented
# loaders built here are not the ones used later. Confirm which is intended.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)

# Now you can use train_loader and test_loader for training and testing your VAE


In [None]:
import torchvision
from torchvision import transforms
# The transform only converts images to tensors (ToTensor); no mean/std
# normalisation is applied.
save_dir = os.getcwd()

transform = transforms.Compose([transforms.ToTensor()])
# Download and load the training data
# NOTE(review): this overwrites the train_loader / test_loader created in the
# previous cell (MNIST + black/white images) with plain MNIST loaders.
trainset = torchvision.datasets.MNIST(save_dir, download=True, train=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.MNIST(save_dir, download=True, train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)

import os


In [None]:
# Grab one fixed batch from each loader for visualisation and for the
# per-epoch reconstruction check. Both loaders shuffle, so which images end
# up in these batches changes from run to run unless a seed is set.
dataiter = iter(test_loader)
test_images, test_labels = next(dataiter)
train_iter = iter(train_loader)
train_images, train_labels = next(train_iter)


In [None]:
# Show the whole held-out batch as a single image grid.
plt.figure(figsize=(5, 10))
out = vutils.make_grid(test_images, normalize=True)
# make_grid returns C x H x W; imshow expects H x W x C.
plt.imshow(np.transpose(out.numpy(), (1, 2, 0)))

In [None]:
# Quick inspection: dtype of a single label from the sampled test batch.
test_labels[2].dtype

In [None]:
import os
vae_net = VAE_MNIST().to(device)

optimizer = optim.Adam(vae_net.parameters(), lr=lr, weight_decay=1e-4)
# Per-epoch average training loss, filled in by the training loop.
loss_log = []
from torchsummary import summary
# Switch to eval mode before any diagnostic forward pass so the dummy inputs
# used by summary() and the sanity check below do not update the BatchNorm
# running statistics of the untrained model.
vae_net.eval()
summary(vae_net, (1, 28, 28))
# Sanity check: one image through the full model, without building a graph.
with torch.no_grad():
    recon_img, mu, logvar = vae_net(test_images[0].view(-1,1,28,28).to(device))
# Display only the shape instead of dumping the whole reconstruction tensor.
recon_img.view(1,1,28,28).shape

In [None]:
num_epochs = 20

# vutils.save_image and torch.save do not create missing folders, so make
# sure the output directories exist before the first epoch finishes.
os.makedirs(save_dir + "/Results", exist_ok=True)
os.makedirs(save_dir + "/Models", exist_ok=True)

for epoch in trange(start_epoch, num_epochs, leave=False):
    vae_net.train()
    train_loss = 0
    for i, (images, _) in enumerate(tqdm(train_loader, leave=False)):
        images = images.to(device)

        recon_img, mu, logvar = vae_net(images)

        # Loss is summed over the batch (reconstruction + KL).
        loss = vae_net.loss_function(recon_img, images, mu, logvar)

        vae_net.zero_grad()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    avg_loss = train_loss / len(train_loader)  # Average of the per-batch summed losses
    loss_log.append(avg_loss)  # Append average loss to loss_log
    #In eval mode the model will use mu as the encoding instead of sampling from the distribution
    vae_net.eval()
    with torch.no_grad():
        recon_img, _, _ = vae_net(test_images.to(device))
        # Stack reconstructions above the originals (concatenate along height).
        img_cat = torch.cat((recon_img.view(-1,1,28,28).cpu(), test_images), 2)

        # NOTE(review): the file name ends in the constant 28, not the epoch,
        # so each epoch overwrites the previous image. Confirm this is intended.
        vutils.save_image(img_cat,
                          "%s/%s/%s_%d.png" % (save_dir, "Results" , model_name, 28),
                          normalize=True)

        #Save a checkpoint
        torch.save({
                    'epoch'                         : epoch,
                    'loss_log'                      : loss_log,
                    'model_state_dict'              : vae_net.state_dict(),
                    'optimizer_state_dict'          : optimizer.state_dict()

                     }, save_dir + "/Models/" + model_name  + ".pt")
    print(f'Epoch {epoch}/{num_epochs} - Avg Loss: {avg_loss}')


In [None]:
use_cuda = torch.cuda.is_available()
gpu_indx  = 0
device = torch.device(gpu_indx if use_cuda else "cpu")
image_size = 28

vae_net = VAE_MNIST().to(device)

# Restore the trained weights. map_location lets a checkpoint written on a
# GPU machine be loaded when only the CPU is available (and vice versa).
checkpoint = torch.load(save_dir + "/Models/" + model_name + ".pt", map_location=device)
print("Checkpoint loaded")
vae_net.load_state_dict(checkpoint['model_state_dict'])

# Eval mode: BatchNorm uses running statistics and encode() returns mu.
vae_net.eval()

In [None]:
image_size = 28
vae_net.eval()
def compute_etta(model, zi, zi_minus, zi_plus, dt):
    """Energy-gradient estimate at path point `zi`, using the encoder Jacobian.

    The second finite difference of the decoded curve,
    (g(z_{i+1}) - 2 g(z_i) + g(z_{i-1})) / dt, is pushed through the encoder
    with a Jacobian-vector product evaluated at g(z_i); the negated result,
    shaped like `zi`, is returned.

    Relies on the module-level `image_size` and on single-channel square
    images (tensors are reshaped to 1 x 1 x image_size x image_size).
    """
    image_shape = (1, 1, image_size, image_size)

    # Decode the three neighbouring path points to flat images.
    decoded_prev = model.decode(zi_minus).view(-1)
    decoded_here = model.decode(zi).view(-1)
    decoded_next = model.decode(zi_plus).view(-1)

    second_diff = (decoded_next - 2 * decoded_here + decoded_prev) / dt
    second_diff = second_diff.view(*image_shape)  # image layout expected by the encoder

    def latent_of(image):
        # encode() returns (z, mu, logvar); only z is differentiated.
        return model.encode(image)[0]

    # jvp returns (function value, J @ v); only the product is needed.
    jvp_result = torch.autograd.functional.jvp(latent_of, decoded_here.view(*image_shape), second_diff)
    pushed = jvp_result[1].view_as(zi)

    etta_i = -pushed

    # Release the intermediates before asking CUDA to free cached blocks.
    del decoded_prev, decoded_here, decoded_next, second_diff, pushed, jvp_result
    torch.cuda.empty_cache()

    return etta_i

In [None]:
def compute_etta_d(model, zi, zi_minus, zi_plus, dt):
    """Energy-gradient estimate at path point `zi`, using the decoder Jacobian.

    Computes -J_g(z_i)^T (g(z_{i+1}) - 2 g(z_i) + g(z_{i-1})) / dt, where g is
    `model.decode` and J_g its Jacobian at `zi`.

    Args:
        model: object exposing `decode(z)`.
        zi, zi_minus, zi_plus: current, previous and next latent points.
        dt (float): path step size.

    Returns:
        Tensor shaped like `zi`, detached from any autograd graph.
    """
    # The finite difference is only used as the cotangent vector of the
    # vector-Jacobian product, so it needs no graph of its own.
    with torch.no_grad():
        g_zi_minus = model.decode(zi_minus)
        g_zi = model.decode(zi)
        g_zi_plus = model.decode(zi_plus)
        finite_diff = (g_zi_plus - 2 * g_zi + g_zi_minus) / dt

    # `finite_diff` already has the shape of decode's output, which is what
    # vjp requires of `v`; no image-size-specific reshape is needed.
    vjp_outputs = torch.autograd.functional.vjp(model.decode, zi, finite_diff)
    Jv = vjp_outputs[1].view_as(zi)

    return -Jv


def sum_of_etta_norms(model, z_collection, dt):
    """Sum of squared norms of compute_etta_d over the interior path points.

    The first and last entries of `z_collection` are treated as fixed
    endpoints and skipped. Returns 0 when there are no interior points.
    """
    total = 0
    last_index = len(z_collection) - 1
    for idx in range(1, last_index):
        etta = compute_etta_d(model, z_collection[idx], z_collection[idx-1], z_collection[idx+1], dt)
        total += etta.norm().pow(2).item()
        # Aggressive memory clean-up after every point.
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()
    return total


def backtracking_line_search(model, z_collection, i, direction, start_alpha, beta, dt, max_iterations, c=0.001):
    """Backtracking (Armijo) line search for the step size at path point `i`.

    Starting from `start_alpha`, the step is multiplied by `beta` until moving
    z_i to z_i - alpha * direction lowers the path energy by at least
    c * alpha * ||direction||^2, or until `max_iterations` shrink steps have
    been taken.

    Returns:
        The last step size tried (it may not satisfy the Armijo condition if
        the shrink budget was exhausted).
    """
    alpha = start_alpha
    current_energy = sum_of_etta_norms(model, z_collection, dt)
    gradient_norm_square = direction.norm() ** 2

    # Work on a copy so the caller's path is left untouched.
    new_z_collection = [element.clone() for element in z_collection]
    new_z_collection[i] = z_collection[i] - alpha * direction
    iterations_count = 0

    while sum_of_etta_norms(model, new_z_collection, dt) > current_energy - c * alpha * gradient_norm_square:
        # `>=` caps the number of shrink steps at exactly max_iterations
        # (the previous `>` allowed one extra step).
        if iterations_count >= max_iterations:
            break
        alpha *= beta
        new_z_collection[i] = z_collection[i] - alpha * direction
        iterations_count += 1

    return alpha

def sum_of_etta_norms_enc(model, z_collection, dt):
    """Sum of (unsquared) norms of compute_etta over the interior path points.

    Encoder-Jacobian counterpart of sum_of_etta_norms. The first and last
    entries of `z_collection` are fixed endpoints and skipped.
    """
    norms = []
    # Interior points are 1 .. len-2 inclusive; the previous upper bound of
    # len-2 (exclusive) left out the last interior point.
    for j in range(1, len(z_collection) - 1):
        etta_j = compute_etta(model, z_collection[j], z_collection[j-1], z_collection[j+1], dt)
        norms.append(etta_j.norm().item())
        del etta_j
        torch.cuda.empty_cache()
    return sum(norms)

import copy

def geodesic_path_algorithm(model, z0, zT, alpha, T, beta, epsilon, max_iterations):
    """Gradient-descent approximation of a geodesic between two latent codes.

    The path starts as a straight line of `T` points from `z0` to `zT` (both
    endpoints included). The interior points are then moved against the
    energy gradient from compute_etta_d with a fixed step `alpha`, until the
    path energy drops to `epsilon` or `max_iterations` sweeps have run.

    Side effects:
        Appends a snapshot of the path to the module-level list `z_list` at
        the start and after every 10th sweep, and prints the energy once per
        sweep.

    Args:
        model: object exposing `decode(z)`.
        z0, zT: start and end latent codes (same shape).
        alpha (float): step size.
        T (int): number of path points, at least 2.
        beta: unused while the backtracking line search is disabled.
        epsilon (float): energy threshold for stopping.
        max_iterations (int): maximum number of sweeps.

    Returns:
        List of `T` latent tensors describing the optimised path.

    Raises:
        ValueError: if `T` is smaller than 2.
    """
    if T < 2:
        raise ValueError("T must be at least 2 so the path contains both endpoints")

    # T points span T-1 segments. Dividing by T-1 makes the last point equal
    # zT; the previous i / T spacing stopped one step short of it.
    segments = T - 1
    dt = 1.0 / segments

    # The optimisation never back-propagates into whatever produced z0 / zT,
    # so drop their autograd history instead of carrying it along.
    z0 = z0.detach()
    zT = zT.detach()
    z_collection = [z0 + float(i) / segments * (zT - z0) for i in range(T)]

    z_list.append([z.clone() for z in z_collection])

    iterations = 0

    while iterations != max_iterations:
        # Evaluate the energy once per sweep and reuse it for both the
        # stopping test and the log line.
        energy = sum_of_etta_norms(model, z_collection, dt)
        if not energy > epsilon:
            break

        print("Energy", energy)

        # Endpoints stay fixed; interior points are updated in place, so each
        # update already sees its freshly moved left neighbour.
        for i in range(1, T-1):
            etta_i = compute_etta_d(model, z_collection[i], z_collection[i-1], z_collection[i+1], dt)
            z_collection[i] -= alpha * etta_i

            del etta_i
            torch.cuda.empty_cache()

        if (iterations+1) % 10 == 0:
            z_list.append([z.clone() for z in z_collection])

        iterations+=1

    return z_collection

# Side length of the square images; read as a global by compute_etta.
image_size = 28
# gc is used inside sum_of_etta_norms (defined above in this cell).
import gc
# NOTE(review): VAE_MNIST() is built with its default embedding_size of 16
# and latent_channels is 16 in the config cell, but this is 32. Confirm which
# latent size the trained model actually uses.
latent_dim = 32

In [None]:
feature_batches = []
label_batches = []

# Collect the whole test set from the DataLoader.
for batch_features, batch_labels in test_loader:
    feature_batches.append(batch_features)
    label_batches.append(batch_labels)

# Concatenate all features and labels
all_features = torch.cat(feature_batches, dim=0)
all_labels = torch.cat(label_batches, dim=0)
y_test = all_labels
# Sorted distinct class labels present in the test set.
labels = np.unique(all_labels.cpu().numpy())

E = vae_net.encoder
with torch.no_grad():
    # Use the configured device rather than assuming CUDA is available.
    encoded = E(all_features.to(device))
    # The convolutional `Encoder` returns (mean, logvar); keep the mean.
    if isinstance(encoded, (tuple, list)):
        encoded = encoded[0]
    EX_test = encoded.cpu().numpy()

In [None]:
from sklearn.manifold import TSNE

# Project the encoder features to 2-D for plotting; random_state fixes the
# embedding between runs.
# NOTE(review): recent scikit-learn releases rename `n_iter` to `max_iter`;
# confirm against the installed version.
tsne = TSNE(n_components=2, random_state=0, learning_rate=70, n_iter=2000, n_iter_without_progress=400, verbose=1)
EX2D = tsne.fit_transform(EX_test)

In [None]:
plt.figure(figsize=(10,15))
plt.rcParams.update({'font.size': 30})
colors = ['brown', 'g', 'olive', 'c', 'm', 'y', 'k', 'lightblue', 'orange', 'gray', "b", "r"]

# EX2D is a NumPy array, so build the class masks in NumPy as well.
y_test_np = y_test.cpu().numpy()

# One scatter series per class; zip stops at the shorter of labels / colors.
for i, c in zip(labels, colors):
    idx = y_test_np == i
    plt.scatter(EX2D[idx, 0], EX2D[idx, 1], c=c, label=str(i))

plt.legend(bbox_to_anchor=(-0.1, 0.55), loc=2, borderaxespad=0)
plt.axis('off')

In [None]:
# Show the first 20 images of the sampled training batch as one grid.
plt.figure(figsize=(5, 10))
out = vutils.make_grid(train_images[:20], normalize=True)
# make_grid returns C x H x W; imshow expects H x W x C.
plt.imshow(np.transpose(out.numpy(), (1, 2, 0)))

In [None]:
# Latent codes used as path endpoints. Encoding happens without autograd and
# on the configured device (no hard-coded CUDA). The model must be in eval
# mode here so encode() returns the mean and BatchNorm accepts a batch of one.
with torch.no_grad():
    test_z,_,_ = vae_net.encode(train_images[49].view(1,1,28,28).to(device))
    x,_,_ =vae_net.encode(train_images[16].view(1,1,28,28).to(device))
    white_z,_,_ = vae_net.encode(torch.ones((1, 1, 28, 28)).to(device))
    black_z,_,_ = vae_net.encode(torch.zeros((1, 1, 28, 28)).to(device))
zero_z = torch.zeros(1, latent_dim, device=device)

# Snapshots of the path are appended here by geodesic_path_algorithm.
z_list = []


In [None]:
# Optimise a 10-point latent path from the black-image code to the code of
# the chosen training image; at most 3 gradient sweeps are run.
path = geodesic_path_algorithm(vae_net, black_z, x, alpha=0.01, T=10, beta=0.7, epsilon=10, max_iterations=3)


In [None]:

# Decode every point of the optimised path to a 1 x 28 x 28 image.
interpolated_geodesic_images = []
for vec in path:
    interpolated_geodesic_images.append(vae_net.decode(vec).view(-1,28,28))

vv = list(interpolated_geodesic_images)
plt.figure(figsize=(20, 5))
out = vutils.make_grid(vv, normalize=True)
# make_grid returns C x H x W; imshow expects H x W x C.
plt.imshow(np.transpose(out.cpu().numpy(), (1, 2, 0)))

In [None]:

def interpolate(start, end, steps):
    """Return `steps` evenly spaced points from `start` to `end`, both included.

    Works for anything supporting `+`, `-` and multiplication by a float
    (numbers, tensors, arrays). `steps == 1` yields `[start]` and
    `steps == 0` an empty list.
    """
    # steps points span steps-1 intervals; dividing by steps (as before) never
    # reached `end`. max(..., 1) avoids a zero division for a single point.
    intervals = max(steps - 1, 1)
    return [start + float(i) / intervals * (end - start) for i in range(steps)]

# Straight-line latent path from the black-image code to x.
interpolated_vectors = interpolate(black_z, x, 10)

# Decode each latent point to a 1 x 28 x 28 image.
interpolated_rec_images = []
for vec in interpolated_vectors:
    interpolated_rec_images.append(vae_net.decode(vec).view(-1,28,28))

# Pixel-space interpolation between the first and last decoded image.
first_image, last_image = interpolated_rec_images[0], interpolated_rec_images[-1]
interpolated_images = interpolate(first_image, last_image, 10)

vv = list(interpolated_rec_images)
plt.figure(figsize=(20, 5))
out = vutils.make_grid(vv, normalize=True)
# make_grid returns C x H x W; imshow expects H x W x C.
plt.imshow(np.transpose(out.cpu().numpy(), (1, 2, 0)))

In [None]:
# Grid of the pixel-space interpolation between the two decoded endpoints.
vv = list(interpolated_images)
plt.figure(figsize=(20, 5))
out = vutils.make_grid(interpolated_images, normalize=True)
# make_grid returns C x H x W; imshow expects H x W x C.
plt.imshow(np.transpose(out.cpu().numpy(), (1, 2, 0)))

In [None]:
class VAEClassifier(nn.Module):
    """Small classifier head on top of a frozen, pretrained VAE encoder.

    Args:
        pretrained_vae: object exposing `encoder` and `mu` modules; inputs are
            mapped to latent means with `mu(encoder(x))`.
        num_classes (int): number of output classes.

    The encoder and `mu` modules are shared with `pretrained_vae` (not
    copied). Their parameters are frozen and they are kept in eval mode even
    while the classifier trains, so BatchNorm running statistics of the
    pretrained model are not modified.
    """
    def __init__(self, pretrained_vae, num_classes=10):
        super(VAEClassifier, self).__init__()

        # Use the encoder from the pretrained VAE
        self.encoder = pretrained_vae.encoder
        self.mu = pretrained_vae.mu

        # Freeze the encoder parameters
        for param in self.encoder.parameters():
            param.requires_grad = False

        for param in self.mu.parameters():
            param.requires_grad = False

        # Size the head from the latent dimension of the VAE instead of
        # hard-coding it.
        latent_dim = self.mu.out_features
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 16),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(16, num_classes)
        )

        self.encoder.eval()
        self.mu.eval()

    def train(self, mode=True):
        """Set train/eval mode on the head only; the frozen encoder stays in eval."""
        super(VAEClassifier, self).train(mode)
        # requires_grad=False does not stop BatchNorm from updating its
        # running statistics in train mode, so pin these modules to eval.
        self.encoder.eval()
        self.mu.eval()
        return self

    def forward(self, x):
        """Return class logits of shape [N, num_classes]."""
        # No torch.no_grad() here: gradients with respect to the input are
        # needed by attribution methods.
        x = self.encoder(x)
        x = self.mu(x)
        x = self.classifier(x)
        return x


In [None]:
# Classifier head on top of the encoder of the trained VAE.
classifier = VAEClassifier(pretrained_vae=vae_net).to(device)

num_epochs = 20
# Only the head's parameters are optimised. This rebinds the global
# `optimizer` that previously held the VAE optimizer.
optimizer = optim.Adam(classifier.classifier.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Evaluation helper
def test_model():
    """Evaluate `classifier` on `test_loader`; print and return the mean per-sample loss."""
    classifier.eval()
    test_loss = 0
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = classifier(data)
            # criterion averages over the batch; weight by batch size so the
            # final division by the sample count gives a true per-sample mean.
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)  # Index of the largest logit
            correct += pred.eq(target.view_as(pred)).sum().item()
            seen += data.size(0)

    test_loss /= max(seen, 1)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    return test_loss

# Train the classifier head; the VAE encoder underneath stays frozen.
for epoch in range(num_epochs):
    classifier.train()
    train_loss = 0
    seen = 0
    for batch_idx, (data, target) in enumerate(tqdm(train_loader, leave=False)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = classifier(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        # criterion averages over the batch; weight by batch size so the
        # epoch figure is a true per-sample mean.
        train_loss += loss.item() * data.size(0)
        seen += data.size(0)

    # Print epoch statistics
    train_loss /= max(seen, 1)
    print('Epoch: {} Average training loss: {:.4f}'.format(epoch, train_loss))

    # Test the model at the end of each epoch
    test_model()





In [None]:
# Boilerplate imports.
import numpy as np
import PIL.Image
from matplotlib import pylab as P
import torch
from torchvision import models, transforms

# From our repository.

%matplotlib inline
def ShowImage(im, title='', ax=None):
    """Draw `im` with pylab, axes hidden, under `title`.

    `ax` is only used as a flag: when it is None a new figure is opened,
    otherwise drawing goes to pylab's current axes (not to `ax` itself).
    """
    if ax is None:
        P.figure()
    P.axis('off')
    P.imshow(im)
    P.title(title)

def ShowGrayscaleImage(im, title='', ax=None):
    """Draw `im` as a grayscale image after min-max scaling it to [0, 1].

    `ax` is only used as a flag: when it is None a new figure is opened,
    otherwise drawing goes to pylab's current axes. A constant image is shown
    as all zeros instead of dividing by zero.
    """
    if ax is None:
        P.figure()
    im_min = im.min()
    im_max = im.max()
    value_range = im_max - im_min
    if value_range == 0:
        # Constant image: (im - min) / 0 would yield NaNs.
        im_normalized = im - im_min
    else:
        im_normalized = (im - im_min) / value_range

    P.axis('off')
    P.imshow(im_normalized.squeeze(), cmap=P.cm.gray, vmin=0, vmax=1)
    P.title(title)

def ShowHeatMap(im, title, ax=None):
    """Draw `im` with the 'inferno' colormap, axes hidden, under `title`.

    `ax` is only used as a flag: when it is None a new figure is opened,
    otherwise drawing goes to pylab's current axes.
    """
    if ax is None:
        P.figure()
    P.axis('off')
    P.imshow(im, cmap='inferno')
    P.title(title)

def LoadImage(file_path):
    """Open the image at `file_path`, resize it to 299x299 and return it as a NumPy array."""
    im = PIL.Image.open(file_path)
    im = im.resize((299, 299))
    im = np.asarray(im)
    return im

# Per-channel mean / std for three-channel inputs (the values torchvision
# documents for its ImageNet-pretrained models).
transformer = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
def PreprocessImages(images):
    """Convert a batch of channel-last images in [0, 255] to a normalised tensor.

    Args:
        images: 4-D array-like of shape [N, H, W, C] with values in [0, 255].

    Returns:
        float32 tensor of shape [N, C, H, W], normalised with `transformer`
        and with requires_grad enabled.
    """
    # torchvision have color channel as first dimension
    # with normalization relative to mean/std of ImageNet:
    #    https://pytorch.org/vision/stable/models.html
    images = np.array(images)
    images = images/255
    images = np.transpose(images, (0,3,1,2))
    images = torch.tensor(images, dtype=torch.float32)
    # Call the module itself rather than .forward() so hooks are honoured.
    images = transformer(images)
    return images.requires_grad_(True)


In [None]:
import numpy as np

INPUT_OUTPUT_GRADIENTS = 'INPUT_OUTPUT_GRADIENTS'

SHAPE_ERROR_MESSAGE = {
    INPUT_OUTPUT_GRADIENTS: (
        'Expected key INPUT_OUTPUT_GRADIENTS to be the same shape as input '
        'x_value_batch - expected {}, actual {}'
    ),
}

class IntegratedGradients:
    """Class that implements the integrated gradients method.

    https://arxiv.org/abs/1703.01365
    """

    expected_keys = [INPUT_OUTPUT_GRADIENTS]

    def GetMask(self, x_value, call_model_function, variant='vanilla',
                call_model_args=None, x_baseline=None, x_steps=25,
                batch_size=1, interpolation_points=None):
        """Returns an integrated gradients mask.

        Args:
          x_value: input tensor to explain (one sample, no batch dimension).
          call_model_function: callable
            `f(x_batch, call_model_args=..., expected_keys=...)` returning a
            dict whose INPUT_OUTPUT_GRADIENTS entry holds the gradients of the
            model output w.r.t. `x_batch`, with the same shape as `x_batch`.
          variant: 'vanilla' integrates along the straight line from
            `x_baseline` to `x_value`; 'manifold' (alias 'custom') sums the
            gradients at the caller-supplied `interpolation_points`.
          call_model_args: forwarded unchanged to `call_model_function`.
          x_baseline: start of the straight-line path for 'vanilla'; defaults
            to zeros shaped like `x_value`.
          x_steps: number of points on the straight-line path ('vanilla').
          batch_size: maximum number of path points per model call ('vanilla').
          interpolation_points: list of tensors for the 'manifold' variant.

        Returns:
          Tensor of attributions: summed gradients times the path difference,
          divided by the number of path points.

        Raises:
          ValueError: for an unknown variant, or when the returned gradients
            do not match the shape of the batch they were computed for.
        """
        if variant == 'vanilla':
            if x_baseline is None:
                x_baseline = torch.zeros_like(x_value)
            assert x_baseline.shape == x_value.shape

            x_diff = x_value - x_baseline

            total_gradients = torch.zeros_like(x_value, dtype=torch.float32)

            alphas = np.linspace(0, 1, x_steps)
            last_step = len(alphas) - 1
            x_step_batched = []
            for step, alpha in enumerate(alphas):
                # float() keeps the interpolation a plain tensor-scalar product.
                x_step_batched.append(x_baseline + float(alpha) * x_diff)
                # Flush on a full batch, and always on the final step (decided
                # by index rather than by comparing floats).
                if len(x_step_batched) == batch_size or step == last_step:
                    total_gradients += self._summed_gradients(
                        x_step_batched, call_model_function, call_model_args,
                        total_gradients)
                    x_step_batched = []

            return total_gradients * x_diff / x_steps

        elif variant in ('manifold', 'custom'):
            assert interpolation_points is not None #Provide interpolation points for custom variant.
            x_diff = interpolation_points[-1] - interpolation_points[0]

            total_gradients = torch.zeros_like(x_diff, dtype=torch.float32)
            # All points are evaluated in a single model call.
            total_gradients += self._summed_gradients(
                interpolation_points, call_model_function, call_model_args,
                total_gradients)

            return total_gradients * x_diff / len(interpolation_points)

        else:
            raise ValueError("Invalid variant provided.")

    def _summed_gradients(self, points, call_model_function, call_model_args, like):
        """Evaluate the model on `points` and return the gradients summed over the batch.

        The result is a tensor with the dtype and device of `like`, so it can
        be accumulated into a torch tensor even though the checked model
        output is a NumPy array.
        """
        x_step_batched = torch.stack(list(points))
        call_model_output = call_model_function(
            x_step_batched,
            call_model_args=call_model_args,
            expected_keys=self.expected_keys)

        self.format_and_check_call_model_output(call_model_output,
                                                x_step_batched.shape,
                                                self.expected_keys)

        summed = call_model_output[INPUT_OUTPUT_GRADIENTS].sum(axis=0)
        return torch.as_tensor(summed, dtype=like.dtype, device=like.device)

    def format_and_check_call_model_output(self, output, input_shape, expected_keys):
        """Converts keys in the output into an np.ndarray, and confirms its shape.

        Args:
          output: dict returned by the call_model_function; modified in place.
          input_shape: shape of the batch the model was called with.
          expected_keys: keys of `output` to convert and check.

        Raises:
          ValueError: if a checked entry does not have the expected shape.
        """
        check_full_shape = [INPUT_OUTPUT_GRADIENTS]
        for expected_key in expected_keys:
            output[expected_key] = np.asarray(output[expected_key])
            expected_shape = tuple(input_shape)
            actual_shape = tuple(output[expected_key].shape)
            if expected_key not in check_full_shape:
                # Only the batch dimension has to agree for these keys.
                expected_shape = expected_shape[0]
                actual_shape = actual_shape[0]
            if expected_shape != actual_shape:
                raise ValueError(SHAPE_ERROR_MESSAGE[expected_key].format(
                                expected_shape, actual_shape))


In [None]:
class_idx_str = 'class_idx_str'
classifier.eval()
def call_model_function(images, call_model_args=None, expected_keys=None):
    """Gradient of the target-class logit of `classifier` w.r.t. `images`.

    Args:
        images: batch tensor accepted by `classifier`.
        call_model_args: dict holding the target class index under the key
            stored in `class_idx_str`.
        expected_keys: must contain INPUT_OUTPUT_GRADIENTS.

    Returns:
        {INPUT_OUTPUT_GRADIENTS: ndarray} with the same shape and axis order
        as `images`.

    Raises:
        ValueError: if INPUT_OUTPUT_GRADIENTS is not requested; no other
            output is implemented.
    """
    # Run on whatever device the classifier lives on, with a fresh leaf
    # tensor so gradients can be taken with respect to the input.
    model_device = next(classifier.parameters()).device
    tensor_images = images.detach().to(model_device).requires_grad_(True)

    target_class_idx =  call_model_args[class_idx_str]
    output = classifier(tensor_images)
    if expected_keys is not None and INPUT_OUTPUT_GRADIENTS in expected_keys:
        outputs = output[:,target_class_idx]
        grads = torch.autograd.grad(outputs, tensor_images, grad_outputs=torch.ones_like(outputs))
        # Keep the gradients in the layout of the input. The former
        # movedim(2, 3) swapped height and width, transposing every map.
        gradients = grads[0].detach().cpu().numpy()
        return {INPUT_OUTPUT_GRADIENTS: gradients}

    # The previous fallback referenced names that do not exist in this
    # notebook (`model`, `conv_layer_outputs`); fail with a clear message.
    raise ValueError(
        "call_model_function only supports the INPUT_OUTPUT_GRADIENTS key, got %r"
        % (expected_keys,))

In [None]:
# Key under which the target class index is passed to call_model_function.
class_idx_str = 'class_idx_str'

# Classify one training sample. This is inference only, so autograd is
# disabled to avoid building a graph that is never back-propagated.
m = torch.nn.Softmax(dim=1)
with torch.no_grad():
    predictions = m(classifier(train_images[8].cuda()))
predictions = predictions.cpu().numpy()
# The most probable class becomes the attribution target.
prediction_class = np.argmax(predictions[0])
call_model_args = {class_idx_str: prediction_class}

print("Prediction class: " + str(prediction_class))

In [None]:
def ShowGrayscaleImage(im, title='', ax=None, vmin=0, vmax=1):
    """Draw ``im`` as a grayscale image using a fixed intensity range.

    The data is shown without any rescaling; ``vmin``/``vmax`` alone decide
    how values map to black and white.

    Args:
      im: array-like exposing ``.squeeze()``; singleton dimensions are dropped
        before display.
      title: text placed above the image.
      ax: axes to draw on. When None, a new figure is opened and its current
        axes are used.
      vmin: value mapped to black.
      vmax: value mapped to white.
    """
    if ax is None:
        P.figure()
        ax = P.gca()
    # Draw on the requested axes explicitly instead of relying on whichever
    # axes pyplot currently considers active.
    ax.axis('off')
    ax.imshow(im.squeeze(), cmap='gray', vmin=vmin, vmax=vmax)
    ax.set_title(title)

# Subplot grid used for the saliency panels and the size (in inches) of one panel.
ROWS = 1
COLS = 5
UPSCALE_FACTOR = 10
# figsize is (width, height): width follows the column count, height the row count.
P.figure(figsize=(COLS * UPSCALE_FACTOR, ROWS * UPSCALE_FACTOR))

In [None]:


integrated_gradients = IntegratedGradients()

# The baseline is the first point of the geodesic path and the explained
# input is its last point, so every mask below shares the same endpoints.
baseline = interpolated_geodesic_images[0].cpu()
test = interpolated_geodesic_images[-1].cuda()

# Arguments shared by all GetMask calls; GetMask is fed CPU tensors and
# call_model_function moves each batch to the GPU itself.
ig_common_kwargs = dict(x_value=test.cpu(),
                        call_model_function=call_model_function,
                        call_model_args=call_model_args,
                        x_steps=20, x_baseline=baseline, batch_size=100)

# Vanilla IG between baseline and input.
vanilla_ig = integrated_gradients.GetMask(variant='vanilla', **ig_common_kwargs)

# Explicit interpolation paths for the 'manifold' variant.
geodesic_points = [point.cpu() for point in interpolated_geodesic_images]
linear_points = [point.cpu() for point in interpolated_rec_images]

ig_geodesic = integrated_gradients.GetMask(variant='manifold',
                                           interpolation_points=geodesic_points,
                                           **ig_common_kwargs)
ig_linear = integrated_gradients.GetMask(variant='manifold',
                                         interpolation_points=linear_points,
                                         **ig_common_kwargs)

# Render the saliency masks side by side on a common intensity range.
mask_panels = [(vanilla_ig, 'Vanilla IG'),
               (ig_geodesic, 'IG Geodesic'),
               (ig_linear, 'IG Linear')]
for panel_idx, (panel_mask, panel_title) in enumerate(mask_panels, start=1):
    ShowGrayscaleImage(panel_mask.cpu().detach().numpy(), title=panel_title,
                       ax=P.subplot(ROWS, COLS, panel_idx), vmin=-0.00001, vmax=0.005)

# The explained input itself, shown on the [0, 1] range.
ShowGrayscaleImage(test.cpu().detach().numpy(), title='Original', ax=P.subplot(ROWS, COLS, 4), vmin=0, vmax=1)


In [None]:
# Preview the first 16 training images, tiled into one normalised grid.
plt.figure(figsize=(5, 10))
out = vutils.make_grid(train_images[:16], normalize=True)
# Axis 0 of the grid is moved to the end before handing the array to imshow.
plt.imshow(out.permute(1, 2, 0).numpy())

In [None]:

import matplotlib.pyplot as plt
# Plot linear_det_val and geodesic_det_val against point index on one
# log-scaled axis so the two paths can be compared directly.
det_fig, det_ax = plt.subplots(figsize=(10, 6))
det_ax.plot(linear_det_val, label='Linear Path', marker='o', linestyle='-')
det_ax.plot(geodesic_det_val, label='Geodesic Path', marker='x', linestyle='--')

# Logarithmic y axis because the values span a large range.
det_ax.set_yscale('log')
det_ax.set_title('Determinant of G along Linear and Geodesic Paths')
det_ax.set_xlabel('Point Index')
det_ax.set_ylabel('Determinant Value (Log Scale)')
det_ax.legend()
det_ax.grid(True, which="both", ls="--")
plt.show()



In [None]:
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.pyplot as plt
# One curve per entry of path_list, all drawn on a shared log-scaled axis.
path_fig, path_ax = plt.subplots(figsize=(10, 6))

# One colour per path, sampled from the lowest fifth of the viridis range.
cm = plt.get_cmap("viridis")
num_of_paths = len(path_list)
colors = [cm(0.2 * rank / num_of_paths) for rank in range(num_of_paths)]

for index, path_values in enumerate(path_list):
    # Bring every value to the CPU as a NumPy array so matplotlib can plot it.
    curve = [val.cpu().numpy() for val in path_values]
    path_ax.plot(curve, color=colors[index],
                 label=f"pth_iteration{index*10-index}")

# Logarithmic y axis because the values span a large range.
path_ax.set_yscale('log')
path_ax.set_title('Determinant of G along every updated path between linear and geodesic')
path_ax.set_xlabel('Point Index')
path_ax.set_ylabel('Determinant Value (Log Scale)')
# Anchor the legend outside the plotting area, to the right of the axes.
path_ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
path_ax.grid(True, which="both", ls="--")
plt.show()
