In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils
from torch.hub import load_state_dict_from_url
from torch.autograd import Variable

import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm

# Notebook-wide settings; the cells below read these as module-level globals.
batch_size = 64  # mini-batch size handed to both DataLoaders
image_size = 64  # side length used by the Resize / CenterCrop transforms
lr = 1e-4  # learning rate given to the Adam optimizer
nepoch = 100  # NOTE(review): not read anywhere in the visible cells
start_epoch = 0  # overwritten with the checkpoint's "epoch" entry once it is loaded
dataset_root = ""  # passed as `root` to the STL10 dataset constructors
save_dir = os.getcwd()  # base directory under which "/Models/" is looked up
model_name = "STL10_Resnet_64"  # first part of the checkpoint file name
latent_channels = 64  # latent channel count; NOTE(review): the VAE cell passes a literal 64 instead
load_checkpoint  = False  # NOTE(review): never consulted in the visible cells; the checkpoint is loaded unconditionally

In [None]:
import torch
import torch.nn as nn
import torch.utils.data


class ResDown(nn.Module):
    """
    Down-sampling residual unit used by the encoder.

    The main branch is conv(stride 2) -> BN -> ELU -> conv(stride 1); the
    shortcut is a single stride-2 conv. Their sum goes through BN and ELU.
    """

    def __init__(self, channel_in, channel_out, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        mid = channel_out // 2

        # Main branch: strided conv into a half-width bottleneck, then conv out.
        self.conv1 = nn.Conv2d(channel_in, mid, kernel_size, stride=2, padding=pad)
        self.bn1 = nn.BatchNorm2d(mid, eps=1e-4)
        self.conv2 = nn.Conv2d(mid, channel_out, kernel_size, stride=1, padding=pad)
        self.bn2 = nn.BatchNorm2d(channel_out, eps=1e-4)

        # Shortcut branch: strided conv so its output matches the main branch.
        self.conv3 = nn.Conv2d(channel_in, channel_out, kernel_size, stride=2, padding=pad)

        self.act_fnc = nn.ELU()

    def forward(self, x):
        shortcut = self.conv3(x)
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.act_fnc(hidden)
        hidden = self.conv2(hidden)
        return self.act_fnc(self.bn2(hidden + shortcut))


class ResUp(nn.Module):
    """
    Up-sampling residual unit used by the decoder.

    The input is first enlarged by nearest-neighbour upsampling; the main
    branch (conv -> BN -> ELU -> conv) and the single-conv shortcut both run
    on the enlarged tensor, and their sum goes through BN and ELU.
    """

    def __init__(self, channel_in, channel_out, kernel_size=3, scale_factor=2):
        super().__init__()
        pad = kernel_size // 2
        mid = channel_in // 2

        # Main branch: conv into a half-width bottleneck, then conv out.
        self.conv1 = nn.Conv2d(channel_in, mid, kernel_size, stride=1, padding=pad)
        self.bn1 = nn.BatchNorm2d(mid, eps=1e-4)
        self.conv2 = nn.Conv2d(mid, channel_out, kernel_size, stride=1, padding=pad)
        self.bn2 = nn.BatchNorm2d(channel_out, eps=1e-4)

        # Shortcut branch: one conv to change the channel count.
        self.conv3 = nn.Conv2d(channel_in, channel_out, kernel_size, stride=1, padding=pad)

        self.up_nn = nn.Upsample(scale_factor=scale_factor, mode="nearest")

        self.act_fnc = nn.ELU()

    def forward(self, x):
        enlarged = self.up_nn(x)
        shortcut = self.conv3(enlarged)
        hidden = self.conv1(enlarged)
        hidden = self.bn1(hidden)
        hidden = self.act_fnc(hidden)
        hidden = self.conv2(hidden)
        return self.act_fnc(self.bn2(hidden + shortcut))


class ResBlock(nn.Module):
    """
    Resolution-preserving residual unit.

    The shortcut is the identity when the channel counts match and a
    convolution otherwise.
    """

    def __init__(self, channel_in, channel_out, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        mid = channel_in // 2

        # Main branch: conv into a half-width bottleneck, then conv out.
        self.conv1 = nn.Conv2d(channel_in, mid, kernel_size, stride=1, padding=pad)
        self.bn1 = nn.BatchNorm2d(mid, eps=1e-4)
        self.conv2 = nn.Conv2d(mid, channel_out, kernel_size, stride=1, padding=pad)
        self.bn2 = nn.BatchNorm2d(channel_out, eps=1e-4)

        # Shortcut branch: only project when the channel count changes.
        if channel_in != channel_out:
            self.conv3 = nn.Conv2d(channel_in, channel_out, kernel_size, stride=1, padding=pad)
        else:
            self.conv3 = nn.Identity()

        self.act_fnc = nn.ELU()

    def forward(self, x):
        shortcut = self.conv3(x)
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.act_fnc(hidden)
        hidden = self.conv2(hidden)
        return self.act_fnc(self.bn2(hidden + shortcut))


class Encoder(nn.Module):
    """
    Convolutional encoder.

    Returns a latent code together with the mean and log-variance maps it was
    derived from.
    """

    def __init__(self, channels, ch=64, blocks=(1, 2, 4, 8), latent_channels=512):
        super().__init__()
        self.conv_in = nn.Conv2d(channels, blocks[0] * ch, 3, 1, 1)

        # One ResDown per entry of `blocks`: stage k goes from blocks[k]*ch to
        # blocks[k+1]*ch channels, and the final stage keeps its width.
        multipliers = list(blocks)
        next_multipliers = multipliers[1:] + [multipliers[-1]]
        top_width = blocks[-1] * ch

        stages = [ResDown(m_in * ch, m_out * ch)
                  for m_in, m_out in zip(multipliers, next_multipliers)]
        stages += [ResBlock(top_width, top_width), ResBlock(top_width, top_width)]
        self.res_blocks = nn.Sequential(*stages)

        # 1x1 heads producing the posterior parameters.
        self.conv_mu = nn.Conv2d(top_width, latent_channels, 1, 1)
        self.conv_log_var = nn.Conv2d(top_width, latent_channels, 1, 1)
        self.act_fnc = nn.ELU()

    def sample(self, mu, log_var):
        """Reparameterised draw: mu + eps * exp(log_var / 2), eps ~ N(0, I)."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x, sample=False):
        features = self.res_blocks(self.act_fnc(self.conv_in(x)))

        mu = self.conv_mu(features)
        log_var = self.conv_log_var(features)

        # Stochastic code while training (or on request); otherwise the mean.
        latent = self.sample(mu, log_var) if (self.training or sample) else mu
        return latent, mu, log_var


class Decoder(nn.Module):
    """
    Convolutional decoder, laid out as the reverse of the encoder.

    1x1 input conv -> two ResBlocks -> one ResUp per entry of `blocks` ->
    3x3 output conv squashed by tanh.
    """

    def __init__(self, channels, ch=64, blocks=(1, 2, 4, 8), latent_channels=512):
        super().__init__()
        top_width = ch * blocks[-1]
        self.conv_in = nn.Conv2d(latent_channels, top_width, 1, 1)

        # Walk the encoder's (in, out) width pairs backwards.
        multipliers = list(blocks)
        up_out = multipliers[::-1]
        up_in = (multipliers[1:] + [multipliers[-1]])[::-1]

        stages = [ResBlock(top_width, top_width), ResBlock(top_width, top_width)]
        stages.extend(ResUp(m_in * ch, m_out * ch) for m_in, m_out in zip(up_in, up_out))
        self.res_blocks = nn.Sequential(*stages)

        self.conv_out = nn.Conv2d(blocks[0] * ch, channels, 3, 1, 1)
        self.act_fnc = nn.ELU()

    def forward(self, x):
        hidden = self.conv_in(x)
        hidden = self.act_fnc(hidden)
        hidden = self.res_blocks(hidden)
        # tanh keeps the output inside (-1, 1).
        return torch.tanh(self.conv_out(hidden))


class VAE(nn.Module):
    """
    Residual VAE assembled from the Encoder and Decoder defined above.

    channel_in      = number of channels of the input image
    ch, blocks      = base width and per-stage width multipliers, forwarded
                      unchanged to both halves
    latent_channels = number of channels of the latent representation
    """

    def __init__(self, channel_in=3, ch=64, blocks=(1, 2, 4, 8), latent_channels=512):
        super().__init__()
        shared = dict(ch=ch, blocks=blocks, latent_channels=latent_channels)
        self.encoder = Encoder(channel_in, **shared)
        self.decoder = Decoder(channel_in, **shared)

    def forward(self, x):
        """Return (reconstruction, mu, log_var) for the batch `x`."""
        latent, mu, log_var = self.encoder(x)
        return self.decoder(latent), mu, log_var

In [None]:

# Base directory for the "/Models/" checkpoint lookup (same value as the assignment in the settings cell).
save_dir = os.getcwd()


def get_data_STL10(transform, batch_size, download = True, root = "/data"):
    """
    Build the STL10 train and test DataLoaders.

    The train loader reads the 'unlabeled' split with shuffling; the test
    loader reads the 'test' split in order. Both use the same `transform`,
    `batch_size`, `root` and `download` flag, and two worker processes.

    Returns (trainloader, testloader).
    """

    def _make_loader(split, shuffle):
        # Dataset and loader for one split, sharing the enclosing arguments.
        dataset = Datasets.STL10(root=root, split=split, transform=transform, download=download)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2)

    print("Loading trainset...")
    trainloader = _make_loader('unlabeled', True)

    print("Loading testset...")
    testloader = _make_loader('test', False)
    print("Done!")

    return trainloader, testloader

# Preprocessing pipeline: resize and centre-crop to image_size, random
# horizontal flip, convert to a tensor, then shift/scale each channel with
# mean 0.5 and std 0.5.
# NOTE(review): the same pipeline, including the random flip, is also applied
# to the test split below - confirm that is intended.
transform = transforms.Compose([transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.RandomHorizontalFlip(0.5),
                                transforms.ToTensor(),
                                transforms.Normalize(0.5, 0.5)])

# Build both loaders (downloads STL10 into dataset_root when it is missing).
trainloader, testloader = get_data_STL10(transform, batch_size, download=True, root=dataset_root)

In [None]:
import random
# Pull a single batch from each loader; the labels are discarded.
dataiter = iter(testloader)
test_images, _ = next(dataiter)
trainiter = iter(trainloader)
train_images, _ = next(trainiter)

# Preview one training image (index 42, so the batch must hold at least 43 samples).
plt.figure(figsize = (1,1))
out = vutils.make_grid(train_images[42], normalize=True)
plt.imshow(out.numpy().transpose((1, 2, 0)))

# NOTE(review): a bare expression in the middle of a cell is not displayed by the notebook.
test_images.shape

# Draw 30 distinct indices from the test batch; random.sample raises
# ValueError when the batch holds fewer than num_samples images.
num_samples = 30
sample_indices = random.sample(range(test_images.shape[0]), num_samples)

# Select those images
sampled_images = test_images[sample_indices]

# Average the sampled images over the batch axis, giving one image-shaped tensor.

average_image = torch.mean(sampled_images, dim=0)

In [None]:
# Show the averaged test image; the CHW grid is reordered to HWC for matplotlib.
out = vutils.make_grid(average_image, normalize=True)
plt.imshow(out.permute(1, 2, 0).numpy())


In [None]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
gpu_indx  = 0
device = torch.device(gpu_indx if use_cuda else "cpu")

# Latent width comes from the settings cell so the model and config cannot drift apart.
vae_net = VAE(channel_in=3, ch=64, blocks=(1, 2, 4, 8), latent_channels=latent_channels).to(device)
# setup optimizer
optimizer = optim.Adam(vae_net.parameters(), lr=lr, betas=(0.5, 0.999))

# Checkpoint file: <save_dir>/Models/<model_name>_<image_size>.pt
checkpoint_path = os.path.join(save_dir, "Models", model_name + "_" + str(image_size) + ".pt")
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU can still be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
print("Checkpoint loaded")
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
vae_net.load_state_dict(checkpoint['model_state_dict'])
start_epoch = checkpoint["epoch"]
loss_log = checkpoint["loss_log"]

# Last expression of the cell: the notebook prints the module structure.
vae_net

In [None]:
# Evaluation mode: the encoder then returns the posterior mean as its code.
vae_net.eval()

# Latent codes of two images from the test batch.
x,_,_ = vae_net.encoder(test_images[0].view(1, 3, 64, 64).cuda())
test_z,_,_ = vae_net.encoder(test_images[16].view(1, 3, 64, 64).cuda())

# Reference latents: the origin and Gaussian noise of the same shape.
zero_z = torch.zeros(1, 64, 4, 4).cuda()
rand_z = torch.randn_like(zero_z)

# Constant images filled with 0 and with 1.
# NOTE(review): inputs are normalised with mean 0.5 / std 0.5, so an all-zero
# tensor may correspond to mid-grey rather than black - confirm.
x_black = torch.zeros(1, 3, 64, 64).cuda()
x_white = torch.ones(1, 3, 64, 64).cuda()
print(x_white.shape, x_black.shape )
z_black,_,_ = vae_net.encoder(x_black)
z_white,_,_ = vae_net.encoder(x_white)

# Latent code of the averaged test image.
average_z ,_,_ = vae_net.encoder(average_image.view(1, 3, 64, 64).cuda())
#igg = get_blurred_image(test_images[7].view(1,3,64,64).cuda(), sigma=20)
#blurred_img = igg.view(1,3,64,64)
#blurred_z = vae_net.encoder(blurred_img)

def interpolate(start, end, steps):
    """
    Return `steps` evenly spaced points on the segment from `start` to `end`.

    Point k is start + (k / steps) * (end - start) for k = 0 .. steps - 1, so
    the list begins at `start` and stops one step short of `end`.
    """
    delta = end - start
    points = []
    for k in range(steps):
        fraction = float(k) / steps
        points.append(start + fraction * delta)
    return points

# Straight latent path from the origin to the code of the first test image.
interpolated_vectors = interpolate(zero_z, x, 10)

# Decode every latent on the path, and separately blend the first and last
# decoded images directly in pixel space.
interpolated_rec_images = list(map(vae_net.decoder, interpolated_vectors))
interpolated_images = interpolate(interpolated_rec_images[0], interpolated_rec_images[-1], 10)

# One small figure per decoded latent.
for inter_ing in interpolated_rec_images:
    plt.figure(figsize=(1, 1))
    decoded = vutils.make_grid(inter_ing, normalize=True)
    # matplotlib expects HWC; the grid is CHW.
    plt.imshow(decoded.cpu().permute(1, 2, 0).numpy())



In [None]:
# One small figure per pixel-space blend frame.
for inter_ing in interpolated_images:
    plt.figure(figsize=(1, 1))
    decoded = vutils.make_grid(inter_ing, normalize=True)
    # matplotlib expects HWC; the grid is CHW.
    plt.imshow(decoded.cpu().permute(1, 2, 0).numpy())


In [None]:
import numpy as np
from scipy.ndimage import gaussian_filter, center_of_mass
from tqdm import tqdm
from PIL import Image

def get_blurred_image(image, sigma=10):
    """
    Gaussian-blur an image or a batch of images.

    image: what gets blurred depends on the number of dimensions
        4-D torch tensor (N, C, H, W): blurred over H and W only; samples and
            channels are never mixed. Returned as a torch tensor on the same
            device as the input.
        3-D array: blurred over the first two axes and not the third, i.e.
            the array is treated as (H, W, C). Returned as produced by
            scipy's gaussian_filter.
        anything else: blurred with the same sigma along every axis.
    sigma: standard deviation of the Gaussian kernel.
    """
    if len(image.shape) == 4:
        batch = image.detach().cpu().numpy()
        # A zero sigma leaves an axis untouched, so only the two trailing
        # (spatial) axes of the NCHW batch are smoothed.
        blurred = gaussian_filter(batch, (0, 0, sigma, sigma))
        return torch.from_numpy(blurred).to(image.device)
    if len(image.shape) == 3:
        return gaussian_filter(image, (sigma, sigma, 0))
    return gaussian_filter(image, sigma)

plt.figure(figsize = (1,1))

# Blur test image 7 (reshaped to a one-image batch on the GPU) with a wide kernel.
igg = get_blurred_image(test_images[7].view(1,3,64,64).cuda(), sigma=20)

# make_grid with normalize=True rescales the values for display; igg now holds the rescaled image.
igg = vutils.make_grid(igg[0], normalize=True)
#plt.imshow(igg.transpose((1, 2, 0)))

# Move to the CPU and reorder CHW -> HWC for matplotlib.
plt.imshow(igg.cpu().numpy().transpose((1, 2, 0)))
igg.shape

In [None]:
# Reconstruct the blurred image through the full VAE.
# NOTE(review): igg was rescaled by make_grid(normalize=True) in the previous
# cell, so its value range may differ from the Normalize(0.5, 0.5) range used
# for the dataset images - confirm this is intended.
recon_img, _, _ = vae_net(igg.view(1,3,64,64))
# Decode the latent of the all-zero image. NOTE(review): `s` is not used again in the visible cells.
s = vae_net.decoder(z_black)
#s = s.squeeze().cpu().detach().numpy().transpose((1, 2, 0))
plt.figure(figsize = (1,1))


# Display the reconstruction (CHW -> HWC for matplotlib).
decoded = vutils.make_grid(recon_img, normalize=True)
plt.imshow(decoded.cpu().numpy().transpose((1, 2, 0)))

In [None]:
# Display the shape of the all-zero input image tensor.
x_black.shape

In [None]:
vae_net.eval()
from torch.autograd import Variable

# Create a random latent tensor of shape (batch_size, latent_channels, height, width).
# A single sample is used, with a 4x4 spatial latent grid.
# NOTE(review): this overwrites the notebook-wide batch_size (64) and
# latent_channels set in the settings cell; cells run afterwards see 1 and 64.
batch_size = 1
latent_channels = 64  # as defined in the VAE
height, width = 4, 4  # typical for a latent space, unless your VAE encoder produces differently shaped latent space

random_tensor = torch.randn(batch_size, latent_channels, height, width).cuda()

# Decode the random tensor using the VAE's decoder.
with torch.no_grad():  # since we're not training now
    decoded_image = vae_net.decoder(random_tensor)

print(decoded_image.shape)

In [None]:
import torch

# Ask cuDNN for deterministic kernels, disable its autotuner, and fix the
# torch RNG seed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(0)

# Run a garbage-collection pass and release cached CUDA memory before the
# memory-heavy Jacobian computations defined below.
import gc
gc.collect()
torch.cuda.empty_cache()


def compute_etta(model, zi, zi_minus, zi_plus, dt):
    """
    Update direction for one interior path node, via the dense encoder Jacobian.

    With g = model.decoder this evaluates

        -(J @ (g(zi_plus) - 2 * g(zi) + g(zi_minus)) / dt)

    where J is the Jacobian of the FIRST output of model.encoder with respect
    to the decoded image g(zi). The sizes are taken from the tensors
    themselves, so any latent / image shape works as long as that encoder
    output has as many elements as `zi`.

    model: object exposing `decoder(z)` and `encoder(img)` (encoder returns a tuple)
    zi, zi_minus, zi_plus: current node and its two neighbours on the path
    dt: spacing of the path discretisation

    Returns a tensor shaped like `zi`.
    """
    # The decoder passes only supply values, so no autograd graph is needed.
    with torch.no_grad():
        decoded = model.decoder(zi)
        second_diff = (model.decoder(zi_plus) - 2 * decoded + model.decoder(zi_minus)) / dt

    # jacobian() yields one entry per encoder output; entry 0 belongs to the
    # first output. Flatten it to (latent elements, image elements).
    jac_outputs = torch.autograd.functional.jacobian(model.encoder, decoded)
    Jh = jac_outputs[0].reshape(zi.numel(), -1)

    etta_i = -torch.mm(Jh, second_diff.reshape(-1, 1)).view_as(zi)

    # The dense Jacobian is large; drop it and hand cached blocks back.
    del Jh, jac_outputs
    gc.collect()
    torch.cuda.empty_cache()

    return etta_i

def compute_etta_d(model, zi, zi_minus, zi_plus, dt):
    """
    Update direction for one interior path node, via the decoder's transposed Jacobian.

    With g = model.decoder this evaluates

        -(J_g(zi)^T @ (g(zi_plus) - 2 * g(zi) + g(zi_minus)) / dt)

    as a vector-Jacobian product, so the Jacobian is never materialised. The
    finite difference keeps the decoder's own output shape, hence no image
    size has to be known here.

    model: object exposing `decoder(z)`
    zi, zi_minus, zi_plus: current node and its two neighbours on the path
    dt: spacing of the path discretisation

    Returns a tensor shaped like `zi`.
    """
    # The three decoder passes only supply values for the finite difference,
    # so they run without recording an autograd graph.
    with torch.no_grad():
        decoded = model.decoder(zi)
        second_diff = (model.decoder(zi_plus) - 2 * decoded + model.decoder(zi_minus)) / dt

    # vjp returns (decoder(zi), J^T v); only the second entry is needed.
    _, pulled_back = torch.autograd.functional.vjp(model.decoder, zi, second_diff)

    etta_i = -pulled_back.view_as(zi)

    # Free up memory
    del decoded, second_diff, pulled_back
    torch.cuda.empty_cache()

    return etta_i

def compute_etta1(model, zi, zi_minus, zi_plus, dt):
    """
    Update direction for one interior path node, via the dense Jacobian of the encoder mean.

    Same quantity as the Jacobian-based variant, but the Jacobian is taken of
    the SECOND encoder output only (index 1, the mean), which avoids
    differentiating the other outputs. Sizes are read from the tensors, so
    the latent and image shapes are not hard-coded; the mean must have as
    many elements as `zi`.

    model: object exposing `decoder(z)` and `encoder(img)` (encoder returns a tuple)
    zi, zi_minus, zi_plus: current node and its two neighbours on the path
    dt: spacing of the path discretisation

    Returns a tensor shaped like `zi`.
    """
    # Values only; jacobian() below builds its own graph from `decoded`.
    with torch.no_grad():
        decoded = model.decoder(zi)
        second_diff = (model.decoder(zi_plus) - 2 * decoded + model.decoder(zi_minus)) / dt

    def partial_encoder(input_data):
        # Restrict the encoder to its mean output.
        return model.encoder(input_data)[1]

    # Dense Jacobian flattened to (latent elements, image elements).
    Jh = torch.autograd.functional.jacobian(partial_encoder, decoded).reshape(zi.numel(), -1)

    etta_i = -torch.mm(Jh, second_diff.reshape(-1, 1)).view_as(zi)

    # Free up memory
    del decoded, second_diff, Jh
    torch.cuda.empty_cache()
    gc.collect()

    return etta_i


def compute_etta(model, zi, zi_minus, zi_plus, dt):
    """
    Update direction for one interior path node, via a Jacobian-vector product through the encoder mean.

    With g = model.decoder and h = the second output of model.encoder (index
    1, the mean) this evaluates

        -(J_h(g(zi)) @ (g(zi_plus) - 2 * g(zi) + g(zi_minus)) / dt)

    using torch.autograd.functional.jvp, so the Jacobian is never
    materialised. The finite difference keeps the decoder's own output shape,
    hence no image size has to be known here.

    model: object exposing `decoder(z)` and `encoder(img)` (encoder returns a tuple)
    zi, zi_minus, zi_plus: current node and its two neighbours on the path
    dt: spacing of the path discretisation

    Returns a tensor shaped like `zi`.
    """
    # Values only; jvp() below records the graph it needs on its own.
    with torch.no_grad():
        decoded = model.decoder(zi)
        second_diff = (model.decoder(zi_plus) - 2 * decoded + model.decoder(zi_minus)) / dt

    # Wrapper so that only the required encoder output is differentiated.
    def partial_encoder(input_data):
        return model.encoder(input_data)[1]

    # jvp returns (partial_encoder(decoded), J v); only the second entry is needed.
    _, pushed_forward = torch.autograd.functional.jvp(partial_encoder, decoded, second_diff)

    etta_i = -pushed_forward.view_as(zi)

    # Free up memory
    del decoded, second_diff, pushed_forward
    torch.cuda.empty_cache()
    gc.collect()

    return etta_i


def compute_etta_d(model, zi, zi_minus, zi_plus, dt):
    """
    Update direction for one interior path node, using only the decoder.

    Evaluates -(J^T v) where J is the Jacobian of model.decoder at `zi` and
    v = (decoder(zi_plus) - 2 * decoder(zi) + decoder(zi_minus)) / dt.
    The product is obtained with torch.autograd.functional.vjp, i.e. a single
    backward pass instead of a dense Jacobian. `v` keeps whatever shape the
    decoder returns, so the image size is not hard-coded.

    model: object exposing `decoder(z)`
    zi, zi_minus, zi_plus: current node and its two neighbours on the path
    dt: spacing of the path discretisation

    Returns a tensor shaped like `zi`.
    """
    with torch.no_grad():
        # Plain forward passes; nothing here needs to be differentiated.
        g_mid = model.decoder(zi)
        g_prev = model.decoder(zi_minus)
        g_next = model.decoder(zi_plus)
        finite_diff = (g_next - 2 * g_mid + g_prev) / dt

    # vjp returns a pair (function value, vector-Jacobian product).
    vjp_outputs = torch.autograd.functional.vjp(model.decoder, zi, finite_diff)
    etta_i = -vjp_outputs[1].view_as(zi)

    # Free up memory
    del g_mid, g_prev, g_next, finite_diff, vjp_outputs
    torch.cuda.empty_cache()

    return etta_i

In [None]:



def backtracking_line_search(model, z_collection, i, direction, start_alpha, beta, dt, max_iterations, c=0.001):
    """
    Backtracking (Armijo-style) search for the step size of node `i`.

    Starting from `start_alpha`, the step is multiplied by `beta` until

        E(z with node i moved by -alpha * direction)
            <= E(z) - c * alpha * ||direction||^2

    holds, where E is sum_of_etta_norms. At most `max_iterations` shrink
    steps are performed, so the search always terminates; if the condition is
    still violated afterwards, the last (smallest) step is returned.

    model, dt: forwarded to sum_of_etta_norms
    z_collection: list of path nodes; it is not modified
    i: index of the node being moved
    direction: update direction for node i
    start_alpha, beta: initial step and shrink factor
    max_iterations: upper bound on the number of shrink steps
    c: sufficient-decrease constant

    Returns the selected step size.
    """
    alpha = start_alpha
    current_energy = sum_of_etta_norms(model, z_collection, dt)
    gradient_norm_square = direction.norm().pow(2).item()

    # Only slot i is ever replaced (never mutated in place), so a shallow
    # copy of the list is enough and avoids cloning every node.
    trial_collection = list(z_collection)
    trial_collection[i] = z_collection[i] - alpha * direction

    shrinks = 0
    while sum_of_etta_norms(model, trial_collection, dt) > current_energy - c * alpha * gradient_norm_square:
        if shrinks >= max_iterations:
            break
        alpha *= beta
        trial_collection[i] = z_collection[i] - alpha * direction
        shrinks += 1

    return alpha

def sum_of_etta_norms_enc(model, z_collection, dt):
    """
    Sum of squared norms of compute_etta over all interior nodes of the path.

    The first and last node are skipped because they have only one neighbour.
    """
    total = 0
    triples = zip(z_collection, z_collection[1:], z_collection[2:])
    for z_prev, z_mid, z_next in triples:
        etta_j = compute_etta(model, z_mid, z_prev, z_next, dt)
        total += etta_j.norm().pow(2).item()
        # Drop the direction tensor before asking CUDA to release cached blocks.
        del etta_j
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()
    return total

def sum_of_etta_norms(model, z_collection, dt):
    """
    Sum of squared norms of compute_etta_d over all interior nodes of the path.

    The first and last node are skipped because they have only one neighbour.
    """
    total = 0
    triples = zip(z_collection, z_collection[1:], z_collection[2:])
    for z_prev, z_mid, z_next in triples:
        etta_j = compute_etta_d(model, z_mid, z_prev, z_next, dt)
        total += etta_j.norm().pow(2).item()
        # Release cached CUDA blocks after every node.
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()
    return total


def geodesic_path_algorithm(model, z0, zT, alpha, T, beta, epsilon, max_iterations):
    """
    Relax a straight latent path between `z0` and `zT` by fixed-step descent.

    The path is discretised into T + 1 nodes (spacing dt = 1 / T) that start
    exactly at `z0` and end exactly at `zT`. While sum_of_etta_norms of the
    path exceeds `epsilon`, every interior node is moved by
    -alpha * compute_etta_d(node, left neighbour, right neighbour); nodes are
    updated one after another, so later nodes of a sweep already see their
    updated left neighbour. The end points stay fixed.

    model: forwarded to compute_etta_d / sum_of_etta_norms
    z0, zT: end points of the path (not modified)
    alpha: fixed step size
    T: number of path segments
    beta: unused here; kept so the signature stays compatible with a
        line-search based step-size choice
    epsilon: stop once the summed squared norms are no larger than this
    max_iterations: maximum number of sweeps over the interior nodes

    Returns the list of T + 1 path nodes.
    """
    dt = 1.0 / T
    # The nodes are optimisation variables updated in place below, so they
    # are cut loose from whatever graph produced the end points.
    start, end = z0.detach(), zT.detach()
    z_collection = [start + float(k) / T * (end - start) for k in range(T)]
    # Close the path at zT itself; building T nodes only would stop one step short of it.
    z_collection.append(end.clone())

    iterations = 0
    while True:
        # One energy evaluation per sweep: it is a full pass over the path.
        energy = sum_of_etta_norms(model, z_collection, dt)
        if not energy > epsilon:
            break
        print(f"It{iterations}:Energy", energy)

        if iterations == max_iterations:
            break
        iterations += 1

        for i in range(1, T):
            etta_i = compute_etta_d(model, z_collection[i], z_collection[i-1], z_collection[i+1], dt)
            z_collection[i] -= alpha * etta_i

            del etta_i
            torch.cuda.empty_cache()
            gc.collect()
            torch.cuda.empty_cache()
    return z_collection

# Example usage:
# x: latent code of training image 1; z_0: latent code of training image 3
# (each reshaped to a one-image batch on the GPU); zero_z: all-zero latent.
x,_,_ =vae_net.encoder(train_images[1].view(1,3,64,64).cuda())
zero_z = torch.FloatTensor(1,64,4,4).zero_().cuda()
z_0,_,_ = vae_net.encoder(train_images[3].view(1,3,64,64).cuda())

# Relax the path from the zero latent to x: 10 segments, step 1e-5, at most 2 sweeps.
path = geodesic_path_algorithm(vae_net, zero_z, x, alpha=0.00001, T=10, beta=0.5, epsilon=1000, max_iterations=2)

In [None]:
# Display the shape of the all-zero latent tensor.
zero_z.shape

In [None]:
# Decode every node of the relaxed path back into image space.
interpolated_geodesic_images = list(map(vae_net.decoder, path))

# Drop the singleton batch axis so the images can be tiled into one grid.
vv = [img.squeeze() for img in interpolated_geodesic_images]
plt.figure(figsize=(20, 10))
out = vutils.make_grid(vv, normalize=True)
# matplotlib expects HWC; the grid is CHW.
plt.imshow(out.cpu().permute(1, 2, 0).numpy())

In [None]:
# Straight latent path between the two encoded training images.
interpolated_vectors = interpolate(z_0, x, 10)

# Decode these vectors to images
interpolated_rec_images = list(map(vae_net.decoder, interpolated_vectors))

# Drop the singleton batch axis so the images can be tiled into one grid.
vv = [img.squeeze() for img in interpolated_rec_images]
plt.figure(figsize=(20, 10))
out = vutils.make_grid(vv, normalize=True)
# matplotlib expects HWC; the grid is CHW.
plt.imshow(out.cpu().permute(1, 2, 0).numpy())

In [None]:
#x,_,_ =vae_net.encoder(train_images[1].view(1,3,64,64).cuda())
#zero_z = torch.FloatTensor(1,64,4,4).zero_().cuda()
#z_0,_,_ = vae_net.encoder(train_images[32].view(1,3,64,64).cuda())
# NOTE(review): interpolated_vectors is recomputed here but not decoded; the
# plot below reuses interpolated_rec_images from an earlier cell.
interpolated_vectors = interpolate(z_0, x, 10)



# Same grid as before, but without make_grid's value rescaling (normalize=False).
vv = [v.squeeze() for v in interpolated_rec_images]
plt.figure(figsize = (20,10))
out = vutils.make_grid(vv, normalize=False)
plt.imshow(out.cpu().numpy().transpose((1, 2, 0)))

In [None]:
# Blend the first and last decoded images directly in pixel space, 30 frames.
first_frame = interpolated_rec_images[0]
last_frame = interpolated_rec_images[-1]
interpolated_images = interpolate(first_frame, last_frame, 30)

# Drop the singleton batch axis so the frames can be tiled into one grid.
vv = [frame.squeeze() for frame in interpolated_images]
plt.figure(figsize=(20, 10))
out = vutils.make_grid(vv, normalize=True)
# matplotlib expects HWC; the grid is CHW.
plt.imshow(out.cpu().permute(1, 2, 0).numpy())

In [None]:
# Latent code of training image 42 (reshaped to a one-image batch on the GPU).
x,_,_ = vae_net.encoder(train_images[42].view(1, 3, 64, 64).cuda())

# Reference latents: the origin and Gaussian noise of the same shape.
zero_z = torch.zeros(1, 64, 4, 4).cuda()
rand_z = torch.randn_like(zero_z)

# All-zero input image and its latent code.
x_black = torch.zeros(1, 3, 64, 64).cuda()
z_black,_,_ = vae_net.encoder(x_black)

def interpolate(start, end, steps):
    """
    Linearly interpolate from `start` towards `end`.

    Produces `steps` points with weights 0/steps, 1/steps, ...,
    (steps - 1)/steps, so `start` is included and `end` itself is not.
    """
    span = end - start
    weights = (float(k) / steps for k in range(steps))
    return [start + weight * span for weight in weights]

# Straight latent path from the origin to the code of training image 42.
interpolated_vectors = interpolate(zero_z, x, 10)

# Decode every latent on the path, and separately blend the first and last
# decoded images directly in pixel space.
interpolated_rec_images = list(map(vae_net.decoder, interpolated_vectors))
interpolated_images = interpolate(interpolated_rec_images[0], interpolated_rec_images[-1], 10)

# One small figure per decoded latent.
for inter_ing in interpolated_rec_images:
    plt.figure(figsize=(1, 1))
    decoded = vutils.make_grid(inter_ing, normalize=True)
    # matplotlib expects HWC; the grid is CHW.
    plt.imshow(decoded.cpu().permute(1, 2, 0).numpy())

In [None]:
# Tile the first eight images of the test batch into one row.
plt.figure(figsize=(20, 10))
out = vutils.make_grid(test_images[:8], normalize=True)
# matplotlib expects HWC; the grid is CHW.
plt.imshow(out.permute(1, 2, 0).numpy())

In [None]:
# Re-draw the decoded geodesic path as a single grid.
vv = [img.squeeze() for img in interpolated_geodesic_images]
plt.figure(figsize=(20, 10))
out = vutils.make_grid(vv, normalize=True)
# matplotlib expects HWC; the grid is CHW.
plt.imshow(out.cpu().permute(1, 2, 0).numpy())

In [None]:
# Straight latent path from test_z to x, decoded node by node; the first and
# last decoded images are also blended in pixel space.
interpolated_vectors = interpolate(test_z, x, 10)
interpolated_rec_images = list(map(vae_net.decoder, interpolated_vectors))
interpolated_images = interpolate(interpolated_rec_images[0], interpolated_rec_images[-1], 10)

plt.figure(figsize=(20, 20))

# This cell shows the decoded latent path.
in_rec = [img.squeeze() for img in interpolated_rec_images]
decoded = vutils.make_grid(in_rec, normalize=True)
# matplotlib expects HWC; the grid is CHW.
plt.imshow(decoded.cpu().permute(1, 2, 0).numpy())

In [None]:
# Straight latent path from test_z to x, decoded node by node; the first and
# last decoded images are also blended in pixel space.
interpolated_vectors = interpolate(test_z, x, 10)
interpolated_rec_images = list(map(vae_net.decoder, interpolated_vectors))
interpolated_images = interpolate(interpolated_rec_images[0], interpolated_rec_images[-1], 10)

plt.figure(figsize=(20, 5))

# This cell shows the pixel-space blend rather than the decoded latent path.
in_rec = [frame.squeeze() for frame in interpolated_images]
decoded = vutils.make_grid(in_rec, normalize=True)
# matplotlib expects HWC; the grid is CHW.
plt.imshow(decoded.cpu().permute(1, 2, 0).numpy())

In [None]:
# An attempt to compute geodesics faster across GPUs. It FAILED and is kept for later investigation.

import torch.cuda as cuda


# Move the model to the first GPU and wrap it for data-parallel execution.
# NOTE(review): nn.DataParallel exposes the wrapped network as `.module`;
# code that reads `vae_net.decoder` / `vae_net.encoder` after this point has
# to go through `vae_net.module` instead.
device = 'cuda:0'
vae_net = vae_net.to(device)
vae_net = nn.DataParallel(vae_net)

import multiprocessing

# Use the 'spawn' start method for worker processes. set_start_method raises
# RuntimeError when a start method has already been fixed, which is ignored
# here so the cell can be re-run.
try:
    multiprocessing.set_start_method('spawn')
except RuntimeError:
    pass


def process_compute_etta(queue, device, model, zi, zi_minus, zi_plus, dt):
    """
    Worker entry point: compute one update direction on `device` and queue it.

    queue: multiprocessing queue that receives the result as a NumPy array
    device: CUDA device this worker should use
    model: the VAE, either bare or wrapped in nn.DataParallel
    zi, zi_minus, zi_plus: current node and its two neighbours on the path
    dt: spacing of the path discretisation
    """
    torch.cuda.set_device(device)

    # nn.DataParallel does not forward custom attributes such as `.decoder`,
    # so the underlying network is used directly.
    core = model.module if isinstance(model, nn.DataParallel) else model
    # Model and inputs must live on the same device as the computation.
    core = core.to(device)
    core.eval()

    zi, zi_minus, zi_plus = zi.to(device), zi_minus.to(device), zi_plus.to(device)

    etta = compute_etta(core, zi, zi_minus, zi_plus, dt)
    queue.put(etta.detach().cpu().numpy())


def async_compute_etta(device, model, zi, zi_minus, zi_plus, dt):
    """
    Compute one update direction with `device` selected as the current CUDA device.

    Evaluates -(J @ (g(zi_plus) - 2 * g(zi) + g(zi_minus)) / dt) with
    g = model.decoder and J the dense Jacobian of the FIRST output of
    model.encoder at g(zi). Sizes are read from the tensors, so the latent
    and image shapes are not hard-coded; that encoder output must have as
    many elements as `zi`.

    device: CUDA device to make current while computing
    model: object exposing `decoder(z)` and `encoder(img)` (encoder returns a tuple)
    zi, zi_minus, zi_plus: current node and its two neighbours on the path
    dt: spacing of the path discretisation

    Returns a tensor shaped like `zi`.
    """
    with torch.cuda.device(device):  # Set the current device to this GPU
        # Values only; jacobian() below builds its own graph from `decoded`.
        with torch.no_grad():
            decoded = model.decoder(zi)
            second_diff = (model.decoder(zi_plus) - 2 * decoded + model.decoder(zi_minus)) / dt

        # One Jacobian per encoder output; entry 0 belongs to the first output.
        jac_outputs = torch.autograd.functional.jacobian(model.encoder, decoded)
        Jh = jac_outputs[0].reshape(zi.numel(), -1)
        etta_i = -torch.mm(Jh, second_diff.reshape(-1, 1)).view_as(zi)
    return etta_i

from multiprocessing import Queue, Process

def _await_worker_result(result_queue, process, poll_seconds=1.0):
    """Return the worker's queued result, raising instead of blocking forever if the worker died."""
    from queue import Empty

    while True:
        try:
            return result_queue.get(timeout=poll_seconds)
        except Empty:
            if process.is_alive():
                continue
            # The worker may have queued its result just before exiting.
            try:
                return result_queue.get_nowait()
            except Empty:
                raise RuntimeError(
                    f"worker process exited with code {process.exitcode} without returning a result"
                ) from None


def geodesic_path_algorithm(model, z0, zT, alpha, T, epsilon=200, max_iterations=None):
    """
    Relax a straight latent path between `z0` and `zT`, one worker process per node update.

    The path has T + 1 nodes (spacing dt = 1 / T) with fixed end points. In
    each sweep every interior node i is moved by -alpha * etta_i, where
    etta_i is computed by process_compute_etta in a separate process. Odd
    nodes are sent to 'cuda:1' when a second GPU exists, all others to
    'cuda:0'. Sweeps stop once the summed etta norms fall below `epsilon`.

    model: forwarded to process_compute_etta
    z0, zT: end points of the path (not modified)
    alpha: fixed step size
    T: number of path segments
    epsilon: convergence threshold on the summed etta norms of a sweep
    max_iterations: optional cap on the number of sweeps; None means no cap

    Returns the list of T + 1 path nodes.

    Raises RuntimeError if a worker process exits without delivering a result.
    """
    dt = 1.0 / T
    # Non-leaf tensors that require grad cannot be sent to another process,
    # and the nodes are updated in place, so work on detached copies.
    start, end = z0.detach(), zT.detach()
    z_collection = [start + i / T * (end - start) for i in range(T+1)]

    has_second_gpu = torch.cuda.device_count() > 1
    sweeps = 0

    while max_iterations is None or sweeps < max_iterations:
        sweeps += 1
        etta_norms = []

        for i in range(1, T):
            device = 'cuda:1' if (i % 2 != 0 and has_second_gpu) else 'cuda:0'

            # Using multiprocessing to parallelize across GPUs
            q = Queue()
            p = Process(target=process_compute_etta, args=(q, device, model, z_collection[i], z_collection[i-1], z_collection[i+1], dt))
            p.start()
            etta_np = _await_worker_result(q, p)
            p.join()

            # Bring the result to the device of the node it updates.
            etta_i = torch.from_numpy(etta_np).to(z_collection[i].device)

            etta_norms.append(etta_i.norm().item())
            z_collection[i] -= alpha * etta_i
            del etta_i
            torch.cuda.empty_cache()

        if sum(etta_norms) < epsilon:
            break

        print("Energy", sum(etta_norms))

    return z_collection



# Run the multi-process variant from the latent of the all-zero image to x (10 segments, step 0.01).
path = geodesic_path_algorithm(vae_net, z_black, x, alpha=0.01, T=10)

