In [None]:
# ------------------------------------------------------------------------------
# Special file just for the improved variance estimate for VAEs as presented in
# 2018 G. Arvanitidis "Latent Space Oddity: on the Curvature of Deep Generative
# Models" (https://arxiv.org/abs/1710.11379)
# ------------------------------------------------------------------------------
import torch
import torch as pt
import numpy as np
# Compute device for the notebook: CUDA device number `gpu_indx` when CUDA is
# available, otherwise the CPU. Later cells move the models to this `device`.
use_cuda = torch.cuda.is_available()
gpu_indx  = 0
device = torch.device(gpu_indx if use_cuda else "cpu")

class RBF (pt.nn.Module):
    """
    Radial-basis-function network giving a per-output precision, used to improve
    the variance estimate of a VAE decoder (Arvanitidis et al., 2018).

    For latent points z the output is
        beta(z) = W @ exp(-bandwidth * ||z - centers||^2) + zeta
    and callers use 1 / beta(z) as the variance.

    Arguments:
        centers (array-like) [k, latent_dim] : fixed RBF centers.
        bandwidth (array-like) [k] : fixed bandwidth of each center.
        X_dim (int) : number of outputs (flattened data dimension).
        zeta (float) : constant added to the output (minimal precision).
    """
    def __init__(self, centers, bandwidth, X_dim, zeta=1e-1):
        super().__init__()

        num_centers = centers.shape[0]
        self.k = num_centers

        # Centers, bandwidths and zeta are frozen parameters: they follow the module
        # across devices and are saved in the state_dict, but are never optimised.
        self.centers = pt.nn.Parameter(pt.Tensor(centers), requires_grad=False)
        self.bandwidth = pt.nn.Parameter(pt.Tensor(bandwidth), requires_grad=False)

        # Only the mixing weights are learned
        self.W = pt.nn.Linear(num_centers, X_dim, bias=False)
        self.zeta = pt.nn.Parameter(pt.Tensor([zeta]), requires_grad=False)


    def forward(self, z_input):
        """Return beta(z) of shape [N, X_dim] for latent points `z_input` of shape [N, latent_dim]."""
        num_points, dim = z_input.shape[0], z_input.shape[1]

        # Broadcast [N, 1, d] - [1, k, d] -> [N, k, d], reduce to squared distances [N, k]
        offsets = z_input.view(num_points, 1, dim) - self.centers.view(1, self.k, dim)
        sq_dist = (offsets ** 2).sum(dim=-1)

        activations = pt.exp(-self.bandwidth * sq_dist)
        return self.W(activations) + self.zeta


def _rbf_encode_dataset(modelE, dataloader, device):
    """Encode every batch of `dataloader`; return (codes [N, d] NumPy, inputs [N, ...] CPU tensor, mean KL float)."""
    z_parts, x_parts, kl_parts = [], [], []
    with pt.no_grad():
        for batch in dataloader:
            X_input = batch[0].to(device)
            zmean, zlogvar = modelE(X_input)
            z_parts.append(zmean.detach().cpu().numpy())
            # Store the inputs in exactly the order their codes were produced. The
            # loader may shuffle, so looking samples up in the dataset by position
            # would pair codes with the wrong images.
            x_parts.append(X_input.detach().cpu())
            kl_parts.append((-0.5 * (1 + zlogvar - zmean**2 - zlogvar.exp()).sum(dim=1)).detach().cpu())

    z_all = np.concatenate(z_parts, axis=0)
    X_all = pt.cat(x_parts, dim=0)
    kl_mean = pt.cat(kl_parts, dim=0).mean().item()
    return z_all, X_all, kl_mean


def _rbf_kmeans(z_input, k, max_iter):
    """Lloyd's k-means on `z_input` [N, d]; return (centers [k, d], assignment [N])."""
    latent_dim = z_input.shape[1]

    # Initialize the centers randomly between zmin and zmax with margin 1
    z_min = np.min(z_input, axis=0) - 1
    z_max = np.max(z_input, axis=0) + 1
    centers = np.random.uniform(z_min, z_max, size=[k, latent_dim])

    def assign(c):
        # [N, 1, d] - [1, k, d] -> [N, k, d] -> squared distances [N, k] -> nearest center [N]
        return np.argmin(np.sum((z_input[:, None, :] - c[None, :, :])**2, axis=-1), axis=-1)

    Sidx = assign(centers)

    print("Starting k-means: ")
    for iterc in range(1, max_iter + 1):
        if iterc % 10 == 0:
            print(f"Iteration {iterc}")

        for i in range(k):
            S_i = z_input[Sidx == i]
            if S_i.shape[0] == 0:
                # We don't want empty centers: re-draw it inside the bounding box
                centers[i] = np.random.uniform(z_min, z_max, size=latent_dim)
            else:
                centers[i] = S_i.mean(axis=0)

        Sidx_new = assign(centers)
        if (Sidx == Sidx_new).all():
            break
        Sidx = Sidx_new
    else:
        # Re-drawn empty centers can keep assignments oscillating; stop anyway
        print(f"k-means stopped after {max_iter} iterations without converging")

    return centers, Sidx


def _rbf_bandwidths(z_input, centers, Sidx, curveMetric, eps=1e-6):
    """Per-center bandwidth 0.5 * (a / |S_i| * sum ||z - c_i||)^-2; tiny bandwidth for empty sets."""
    k = centers.shape[0]
    bandwidth = np.ones(k)
    for i in range(k):
        S_i = z_input[Sidx == i]
        if S_i.shape[0] == 0:
            # If there are no points in this center, it should have minimal influence
            bandwidth[i] = 1e-3
        else:
            # eps prevents a zero spread when all points sit exactly on the center
            dist = np.sqrt(np.sum((S_i - centers[i])**2, axis=-1))
            bandwidth[i] = 0.5 * (curveMetric / S_i.shape[0] * np.sum(dist + eps))**-2
    return bandwidth


def trainRBF (modelE, modelD, dataloader, latent_dim, X_dim, k, zeta=1e-3, curveMetric=1, max_epochs=100, batch_size=16, kmeans_max_iter=1000):
    """
    Using trained VAE we now fit more accurate variance estimates using a RBF network.

    Steps:
    1. Encode the data to latent means.
    2. Cluster the codes with k-means.
    3. Set one RBF bandwidth per cluster.
    4. Fit the non-negative RBF weights by minimising the Gaussian negative
       log-likelihood of the data under the frozen decoder mean.

    Arguments:
        modelE (torch.nn.Module) : Encoder of VAE, returning (zmean, zlogvar).
        modelD (torch.nn.Module) : Decoder of VAE; only its first output (mean) is used.
        dataloader (torch.utils.data.DataLoader) : training data used for VAE. Batches
            must have the inputs as first element; the loader may shuffle.
        latent_dim (int) : dimension of latent space.
        X_dim (int or sequence, e.g. [C, H, W]) : shape of input/output space.
        k (int) : number of clusters for k-means.
        zeta (float) : minimal precision (zeta > 0, at 0 the variance goes to infinity)
        curveMetric (float) : RBF bandwidth parameter, higher values create smoother variances, lower values make the RBF stick closer to the data (less smooth)
        max_epochs (int) : training epochs of the RBF.
        batch_size (int) : mini-batch size for RBF training.
        kmeans_max_iter (int) : upper bound on k-means iterations.

    Returns:
        Trained RBF (torch.nn.Module) in eval mode.

    Note: all inputs are kept in CPU memory during fitting so that every code is
    paired with its own input.
    """
    device = next(modelD.parameters()).device
    modelE.eval(), modelD.eval()

    ### Compute embedded vectors (and keep their inputs aligned with them)
    z_np, X_all, kl_loss = _rbf_encode_dataset(modelE, dataloader, device)
    N = z_np.shape[0]

    ### Cluster the codes and derive the RBF bandwidths
    centers, Sidx = _rbf_kmeans(z_np, k, kmeans_max_iter)
    bandwidth = _rbf_bandwidths(z_np, centers, Sidx, curveMetric)

    print("Training RBF...")
    # Flattened data dimension (X_dim may be a shape such as [C, H, W] or an int)
    input_dim = int(np.prod(X_dim))

    rbfNN = RBF(centers, bandwidth, input_dim, zeta)
    rbfNN.to(device)

    def clip_nonnegative(module):
        # Non-negative weights keep beta(z) >= zeta > 0, i.e. a valid precision
        if hasattr(module, 'weight'):
            module.weight.data.clamp_(0)

    optimizerRBF = pt.optim.Adam(rbfNN.parameters(), lr=0.01, weight_decay=1e-4)

    # DataLoader.batch_size may be None (e.g. with a custom batch_sampler)
    batch_size = batch_size or 16
    log_norm = float((input_dim / 2) * np.log(2 * np.pi))
    z_all = pt.as_tensor(z_np, dtype=pt.float32)

    rbfNN.train()
    rbfNN.apply(clip_nonnegative)
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        shuffledIdx = pt.randperm(N)

        for start in range(0, N, batch_size):
            idx = shuffledIdx[start:start + batch_size]
            z = z_all[idx].to(device)
            X = X_all[idx].to(device).reshape(-1, input_dim)

            # The decoder is frozen: only its mean is needed, so build no graph for it
            with pt.no_grad():
                Xmean = modelD(z)[0].reshape(-1, input_dim)

            with pt.set_grad_enabled(True):
                optimizerRBF.zero_grad()
                rbfVar = 1 / rbfNN(z)
                rec_loss = (
                    0.5 * pt.log(rbfVar).sum(dim=1)
                    + 0.5 * ((X - Xmean)**2 / rbfVar).sum(dim=1)
                    + log_norm
                )
                loss = rec_loss.sum()

            loss.backward()
            optimizerRBF.step()
            rbfNN.apply(clip_nonnegative)

            epoch_loss += loss.item()

        # Per-sample reconstruction NLL plus the (constant) KL term gives the full ELBO loss
        epoch_loss = epoch_loss / N + kl_loss

        if epoch % 10 == 0:
            print(f"Epoch [{epoch+1}/{max_epochs}]: Loss {epoch_loss:.4f}")

    rbfNN.eval()
    return rbfNN

In [None]:
import torch
import torchvision


class MNIST(torch.utils.data.Dataset):
    """
    Loads n amount of training data of specified classes from one of the MNIST datasets (Digits, Fashion, Kuzushiji, EMNIST). MNIST data has shape [-1, 1, 28, 28] and is rescaled from bytes to floats in [-1, 1].

    Arguments:
        dataset (torchvision.datasets) : one of the MNIST datasets; its `data` (uint8 images) and `targets` tensors are used directly.
        labels (List): an integer list containing all the classes you want to have in the data. Labels absent from `dataset` are reported and skipped.
        number_of_samples (int): how many of each class should appear (at most).
        train (boolean): kept for interface compatibility; the split is determined by `dataset`.
    """
    def __init__(self, dataset, labels, number_of_samples=1000, train=True):
        super().__init__()
        data_parts = []
        target_parts = []

        for label in labels:
            # Boolean mask over ALL samples: its length is the dataset size, so check
            # whether any sample carries the label rather than the mask's shape.
            idx = (dataset.targets == label)
            if not idx.any():
                print(f"ERROR: Label {label} not found!")
                continue

            # Originally data are bytes, we want floats between -1 and 1
            data_parts.append(2*(dataset.data[idx][:number_of_samples].float() / 255)-1)
            target_parts.append(dataset.targets[idx][:number_of_samples])

        if data_parts:
            self.data = torch.cat(data_parts, dim=0).view(-1, 1, 28, 28)
            # Typecast targets as longs.
            self.targets = torch.cat(target_parts, dim=0).long()
        else:
            # No requested label was found: empty dataset instead of a torch.cat error
            self.data = torch.empty(0, 1, 28, 28)
            self.targets = torch.empty(0, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return (self.data[idx], self.targets[idx])

class MNISTDigits (MNIST):
    """
    MNIST handwritten digits restricted to the requested classes.

    Downloads (if necessary) torchvision's MNIST into "Data/" and hands it to
    `MNIST`, which does the class filtering and rescaling.

    Arguments:
        digits (List): integer digits to keep.
        number_of_samples (int): maximum number of samples per digit.
        train (boolean): load the training split (True) or the test split (False).
    """
    def __init__(self, digits, number_of_samples=1000, train=True):
        super().__init__(
            torchvision.datasets.MNIST("Data/", train=train, download=True),
            digits,
            number_of_samples,
            train,
        )


In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import numpy as np

class BlackImagesDataset(Dataset):
    """
    Synthetic dataset of constant all-zero float32 images sharing one label.

    Arguments:
        num_images (int): number of items reported by len().
        image_size (tuple): shape of every image, e.g. (1, 28, 28).
        label: label returned with every image.
    """
    def __init__(self, num_images, image_size, label):
        self.num_images, self.image_size, self.label = num_images, image_size, label

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        # Every index yields a fresh zero image; idx is not used
        return torch.zeros(self.image_size, dtype=torch.float32), self.label

class WhiteImagesDataset(Dataset):
    """
    Synthetic dataset of constant all-one float32 images sharing one label.

    Arguments:
        num_images (int): number of items reported by len().
        image_size (tuple): shape of every image, e.g. (1, 28, 28).
        label: label returned with every image.
    """
    def __init__(self, num_images, image_size, label):
        self.num_images, self.image_size, self.label = num_images, image_size, label

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        # Every index yields a fresh all-ones image; idx is not used
        return torch.ones(self.image_size, dtype=torch.float32), self.label

def prepare_datasets(num_black_images=50, num_white_images=None, root='./data'):
    """
    Build MNIST train/test datasets augmented with constant black and white images.

    Each split is ConcatDataset([MNIST, black images (label 10), white images (label 11)]),
    with images of shape (1, 28, 28) and the MNIST part transformed by ToTensor().

    Arguments:
        num_black_images (int): number of black images added to each split.
        num_white_images (int, optional): number of white images added to each split;
            defaults to `num_black_images` (the previous behaviour).
        root (str): directory where torchvision stores/downloads MNIST.

    Returns:
        (train_dataset, test_dataset)

    NOTE(review): the decoder's output layer is Tanh; confirm the ToTensor value
    range of these images matches what the VAE is meant to reconstruct.
    """
    if num_white_images is None:
        num_white_images = num_black_images

    transform = transforms.Compose([transforms.ToTensor()])
    image_size = (1, 28, 28)

    def augmented(train):
        mnist = datasets.MNIST(root=root, train=train, download=True, transform=transform)
        # Label 10 marks black images, label 11 marks white images
        black = BlackImagesDataset(num_black_images, image_size, 10)
        white = WhiteImagesDataset(num_white_images, image_size, 11)
        # Concatenate MNIST with the black and white image datasets
        return ConcatDataset([mnist, black, white])

    return augmented(True), augmented(False)



# Prepare datasets (MNIST plus synthetic black and white images)
train_dataset, test_dataset = prepare_datasets()

# Data loaders (batch size 128). Both loaders shuffle, so iteration order does
# NOT match dataset index order.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)

# Now you can use train_loader and test_loader for training and testing your VAE


In [None]:
# Pull one batch from each loader for a quick look at the data.
dataiter = iter(test_loader)
test_images, test_labels = next(dataiter)
# NOTE(review): a bare expression is only displayed by a notebook when it is the
# last statement of the cell; here it has no effect.
test_images[0]
train_iter = iter(train_loader)
train_images, train_labels = next(train_iter)

In [None]:
# ------------------------------------------------------------------------------
# PyTorch implementation of a convolutional Variational Autoencoder (2014 D.
# Kingma "Auto-Encoding Variational Bayes" in https://arxiv.org/abs/1312.6114)
# ------------------------------------------------------------------------------

import os

import matplotlib.pyplot as plt
import numpy as np
import torch as pt

from torch import nn
from torchsummary import summary
from torchvision.utils import save_image



class ConvBlock (nn.Sequential):
    """Spectral-normalised Conv2d (no padding) -> BatchNorm2d (affine) -> ELU."""
    def __init__ (self, in_c, out_c, kernel_size, stride=1):
        super().__init__()
        layers = (
            ('Convolution', nn.utils.spectral_norm(nn.Conv2d(in_c, out_c, kernel_size, stride))),
            ('BatchNorm', nn.BatchNorm2d(out_c, affine=True)),
            ('Activation', nn.ELU()),
        )
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)

class ConvTransposeBlock (nn.Sequential):
    """Spectral-normalised ConvTranspose2d -> BatchNorm2d (affine) -> ELU."""
    def __init__ (self, in_c, out_c, kernel_size, stride=1):
        super().__init__()
        layers = (
            ('ConvTranspose', nn.utils.spectral_norm(nn.ConvTranspose2d(in_c, out_c, kernel_size, stride))),
            ('BatchNorm', nn.BatchNorm2d(out_c, affine=True)),
            ('Activation', nn.ELU()),
        )
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)


class Encoder (nn.Module):
    """
    Encoder of 2D VAE, producing latent multivariate normal distributions from input images. We return logarithmic variances as they have the real numbers as domain.

    Forward pass:
        1 Input:
            i)  Image of shape [N, C, H, W] (by default 1x28x28 MNIST images)
        2 Outputs:
            i)  Means of latent distributions of shape [N, latent_dim]
            ii) Logarithmic variances of latent distribution of shape [N, latent_dim] (Approximation of multivariate Gaussian, covariance is strictly diagonal, i.e. [N, d, d] is now [N, d])

    Arguments:
        X_dim (sequence) : dimensions of input 2D image, in the form of [Channels, Height, Width]
        latent_dim (int) : dimension of latent space.
    """
    def __init__(self, X_dim=(1, 28, 28), latent_dim=16):
        super(Encoder, self).__init__()

        conv1_outchannels = 32
        conv2_outchannels = 32

        def conv_out(size):
            # Spatial size after conv1 (kernel 4, stride 2) then conv2 (kernel 3,
            # stride 2), without padding: out = (in - kernel) // stride + 1
            after_conv1 = (size - 4) // 2 + 1
            return (after_conv1 - 3) // 2 + 1

        # Number of features after flattening the conv stack
        conv_outputshape = conv2_outchannels * conv_out(X_dim[1]) * conv_out(X_dim[2])

        self.enc = nn.Sequential(
            ConvBlock(X_dim[0], conv1_outchannels, kernel_size=4, stride=2),
            ConvBlock(conv1_outchannels, conv2_outchannels, kernel_size=3, stride=2)
        )

        self.zmean = nn.Linear(conv_outputshape, latent_dim)
        self.zlogvar = nn.Linear(conv_outputshape, latent_dim)


    def forward (self, X):
        """Return (mean, logvar), each of shape [N, latent_dim], for images X of shape [N, C, H, W]."""
        x = self.enc(X)
        x = x.view(X.shape[0], -1)
        mean = self.zmean(x)
        logvar = self.zlogvar(x)

        return mean, logvar


class Decoder (nn.Module):
    """
    Decoder of 2D VAE, producing output multivariate normal distributions from latent vectors.

    Forward pass:
        1 Input:
            i)  Latent vector of shape [N, latent_dim]
        2 Outputs:
            i)  Means of output distributions of shape [N, C, H, W]
            ii) Variances (NOT log-variances) of the output distribution (Approximation of multivariate Gaussian, covariance is strictly diagonal). With `improved_variance` False this is a constant 0.5 of shape [N, C, H, W]; with `improved_variance` True it is 1 / rbfNN(z), of shape [N, C*H*W]. We assume constant variance during VAE training.

    Arguments:
        X_dim (sequence) : dimensions of input 2D image, in the form of [Channels, Height, Width]
        latent_dim (int) : dimension of latent space.
    """
    def __init__(self, X_dim=(1, 28, 28), latent_dim=16):
        super(Decoder, self).__init__()

        # Currently number of clusters is set to 4*latent_dim as a good rule of thumb, can be changed, but then also change it later during training!
        self.improved_variance = True
        k = 4*latent_dim
        # Placeholder RBF (zero centers and bandwidths); a trained RBF has to be
        # assigned to `self.rbfNN` before the improved variance is meaningful.
        self.rbfNN = RBF(centers=pt.zeros(k, latent_dim), bandwidth=pt.zeros(k), X_dim=int(np.prod(X_dim)))

        conv1_outchannels = 32
        conv2_outchannels = 32

        def conv_out(size):
            # Same spatial size as the encoder's conv stack (kernel 4 / stride 2,
            # then kernel 3 / stride 2, no padding), which the transposed convs invert
            after_conv1 = (size - 4) // 2 + 1
            return (after_conv1 - 3) // 2 + 1

        self.conv_outputshape = (conv_out(X_dim[1]), conv_out(X_dim[2]))

        self.lin = nn.Linear(latent_dim, conv2_outchannels*self.conv_outputshape[0]*self.conv_outputshape[1])

        self.conv = nn.Sequential(
            ConvTransposeBlock(conv2_outchannels, conv1_outchannels, kernel_size=3, stride=2),
            ConvTransposeBlock(conv1_outchannels, 32, kernel_size=4, stride=2)
        )

        self.Xmean = nn.Sequential(
            nn.Conv2d(32, X_dim[0], kernel_size=1),
            # Output is grayscale between -1 and 1
            nn.Tanh()
        )


    def forward (self, z):
        """
        When improved_variance is set to True, we use a trained RBF to return a better variance estimate as described in "Arvanitidis et al. (2018): Latent Space Oddity". Of course the RBF has to be assigned to the Decoder first.
        """
        x = self.lin(z)
        x = x.view(z.shape[0], -1, self.conv_outputshape[0], self.conv_outputshape[1])
        x = self.conv(x)
        mean = self.Xmean(x)

        if not self.improved_variance:
            # We freeze the variance as constant 0.5. Requires grad so metric computation goes smoothly (i.e. returns no gradient)
            var = pt.ones_like(mean, requires_grad=True) * 0.5
        else:
            var = 1/self.rbfNN(z)

        return mean, var


def _save_reconstruction_snapshot(modelE, modelD, dataset, device, path):
    """Reconstruct one random sample of `dataset` with the VAE and save it as a grayscale image at `path`."""
    # Run in eval mode so this single-sample pass does not update BatchNorm running
    # statistics, then restore the previous modes.
    was_training = (modelE.training, modelD.training)
    modelE.eval(), modelD.eval()
    with pt.no_grad():
        randidx = np.random.randint(len(dataset))
        X_sample = dataset[randidx][0].to(device).unsqueeze(dim=0)
        output = modelD(modelE(X_sample)[0])[0].detach().cpu().squeeze().numpy()
    modelE.train(was_training[0])
    modelD.train(was_training[1])

    fig = plt.figure(figsize=(12, 12))
    # Image from [-1,1] to [0,1]
    plt.imshow((output+1)/2, cmap='gray')
    plt.savefig(path)
    plt.close(fig)


def _save_loss_history(loss_history, path):
    """Plot the per-epoch loss curve and save it to `path`."""
    fig = plt.figure(figsize=(12, 9))
    plt.plot(loss_history, label='Loss History')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.xlim(0, len(loss_history))
    plt.legend()
    plt.grid(True)
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)


def train (dataloader, modelE, modelD, latent_dim=2, lr=5e-3, max_epochs=100, device=None):
    """
    Trains the VAE on data presented in dataloader for max_epochs. By default trained using Adam optimizer with learning rate 5e-3 and weight decay 1e-4, and having a multiplicative learning rate scheduler (0.95 multiplier per epoch). Afterwards an RBF variance estimate is fitted and attached to the decoder.

    The reconstruction term is the Gaussian negative log-likelihood with the decoder's
    second output interpreted as a variance, averaged over L=4 latent samples.

    Every epoch writes a reconstruction snapshot and the loss curve to "Outputs/" and the
    model weights to "TrainedModels/".

    Arguments:
        dataloader (torch.utils.data.DataLoader) : 2D images of shape [N, C, H, W]
            loaded in batches (first element of each batch). Unsupervised, so labels are ignored.
        modelE (Encoder: nn.Module) : encoder to train.
        modelD (Decoder: nn.Module) : decoder to train.
        latent_dim (int) : latent dimension of autoencoder.
        lr (float) : initial learning rate.
        max_epochs (int) : number of epochs.
        device (torch.device, optional) : where batches are moved; defaults to the
            device of the encoder's parameters.

    Returns:
        modelE (Encoder: nn.Module) : trained encoder architecture.
        modelD (Decoder: nn.Module) : trained decoder architecture.
    """
    os.makedirs("Outputs", exist_ok=True)
    os.makedirs("TrainedModels", exist_ok=True)

    if device is None:
        # Follow the models so inputs always land where the weights are
        device = next(modelE.parameters()).device

    X_dim = dataloader.dataset[0][0].shape

    modelE.train(), modelD.train()
    modelD.improved_variance = False    # During training of VAE no RBF

    # Show the network architectures (torchsummary only distinguishes "cuda" and "cpu")
    summary_device = "cuda" if device.type == "cuda" else "cpu"
    summary(modelE, X_dim, device=summary_device)
    summary(modelD, (latent_dim,), device=summary_device)

    optimizer = pt.optim.Adam(
                list(modelE.parameters())+list(modelD.parameters()),
                lr=lr,
                weight_decay=1e-4
            )

    scheduler = pt.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.95)

    L = 4  # Monte-Carlo samples of z per input for the reconstruction term
    loss_history = []
    for epoch in range(max_epochs):
        ### Store an evaluation output in every epoch
        _save_reconstruction_snapshot(modelE, modelD, dataloader.dataset, device, f"Outputs/train_{epoch+1:04d}.png")

        epoch_loss = 0.0
        for batch in dataloader:
            # No need for targets/labels
            X_input = batch[0].to(device)
            with pt.set_grad_enabled(True):
                optimizer.zero_grad()

                zmean, zlogvar = modelE(X_input)
                kl_loss = -0.5 * (1 + zlogvar - zmean**2 - zlogvar.exp()).sum(dim=1)

                data_dims = tuple(range(1, X_input.dim()))
                # Gaussian normalisation constant (D/2) * log(2*pi); optional, keeps the loss positive
                log_norm = (np.prod(X_input.shape[1:]) / 2) * np.log(2 * np.pi)

                rec_loss = 0.0
                for _ in range(L):
                    # Reparametrization trick
                    z = zmean + pt.exp(zlogvar / 2) * pt.randn_like(zmean)

                    # The decoder's second output is a VARIANCE, not a log-variance
                    Xmean, Xvar = modelD(z)
                    rec_loss = rec_loss + (
                        0.5 * pt.log(Xvar).sum(dim=data_dims)
                        + 0.5 * ((X_input - Xmean)**2 / Xvar).sum(dim=data_dims)
                        + log_norm
                    )
                rec_loss = rec_loss / L

                loss = (kl_loss + rec_loss).mean()

            loss.backward()
            optimizer.step()

            # `loss` is a per-sample mean: weight by batch size to average over the dataset
            epoch_loss += loss.item() * X_input.shape[0]

        scheduler.step()

        epoch_loss /= len(dataloader.dataset)
        loss_history.append(epoch_loss)
        print(f"Epoch [{epoch+1}/{max_epochs}]: Loss {epoch_loss:.4e}")

        ### See how the loss evolves
        _save_loss_history(loss_history, "Outputs/LossHistory.png")

        pt.save(modelE.state_dict(), "TrainedModels/trainedVAE_E.pth")
        pt.save(modelD.state_dict(), "TrainedModels/trainedVAE_D.pth")


    ### After training has finished, create better variance estimate with RBF
    # Currently parameters are just set with default values, work well in general setting.
    rbfNN = trainRBF(modelE, modelD, dataloader, latent_dim, X_dim, k=4*latent_dim, zeta=1e-2, curveMetric=1, max_epochs=50, batch_size=dataloader.batch_size)
    # Set the better estimate in the decoder
    modelD.rbfNN = rbfNN
    modelD.improved_variance = True

    pt.save(modelD.state_dict(), "TrainedModels/trainedVAE_D.pth")

    return modelE, modelD



# Train a VAE with a 40-dimensional latent space on `train_loader`, then sample
# from the prior and save a 10x10 image grid.
print("Starting VAE training on MNIST data...")
latent_dim = 40
# Alternative data source (digits only), kept for reference:
#dataset = MNISTDigits(
#    list(range(10)),
#    number_of_samples=20000,
#    train=True
#)
#data_loader = pt.utils.data.DataLoader(
#    dataset,
#    batch_size=128,
#    shuffle=True,
#    num_workers=1,
#    pin_memory=True
#)

# Image shape taken from the first training sample
X_dim = train_loader.dataset[0][0].shape
modelE= Encoder(X_dim, latent_dim).to(device)
modelD = Decoder(X_dim, latent_dim).to(device)
modelE, modelD = train(train_loader, modelE, modelD, latent_dim, lr=5e-3, max_epochs=2)

print("Creating 10x10 grid of samples...")
N = 10
# Plot standard normal Gaussian z
z = pt.randn((N*N, latent_dim)).to(next(modelD.parameters()).device)

with pt.set_grad_enabled(False):
    # Only the decoder mean is used for the image
    X_pred = modelD(z)[0].cpu()

# Decoder means come out of a Tanh in [-1, 1]; map them to [0, 1] for saving
save_image(((X_pred+1)/2), "Outputs/VAE_samples.png", nrow=N)



In [None]:
# Re-fit the RBF variance estimate for the already trained VAE and store it in the
# decoder. This is module-level notebook code, so there is no `return` here
# (a bare `return` outside a function is a SyntaxError).
rbfNN = trainRBF(modelE, modelD, train_loader, latent_dim, X_dim, k=4*latent_dim, zeta=1e-2, curveMetric=1, max_epochs=50, batch_size=train_loader.batch_size)
# Set the better estimate in the decoder
modelD.rbfNN = rbfNN
modelD.improved_variance = True

pt.save(modelD.state_dict(), "TrainedModels/trainedVAE_D.pth")


In [None]:
# ------------------------------------------------------------------------------
# Computing metric tensor and metric derivative at points in latent space. Also
# includes computations that require the induced metric (e.g. curve length).
# ------------------------------------------------------------------------------

import torch as pt
import numpy as np
import gc


class InducedMetric:
    """
    Riemannian metric induced on the latent space by a (possibly stochastic) generator.

    At a latent point z the metric is M(z) = J(z)^T J(z), with J the Jacobian of the
    generator output w.r.t. z. If the generator returns a tuple (mean, variance), the
    Jacobian of the standard deviation is added: M = J_mean^T J_mean + J_std^T J_std.
    An optional `featureMapping` replaces the Euclidean output metric by J_f^T J_f.

    Arguments:
        modelG (torch.nn.Module) : generator/decoder, called on latent batches [N, latent_dim];
            returns the output or a tuple (mean, variance). Put into eval mode.
        X_dim (int or sequence) : output shape; multi-dimensional shapes are flattened.
        latent_dim (int) : latent dimension d.
        featureMapping (torch.nn.Module, optional) : applied to flattened generator outputs;
            must expose an `out_dim` attribute.
    """
    def __init__(self, modelG, X_dim, latent_dim, featureMapping=None):
        self.modelG = modelG
        self.modelG.eval()
        self.device = next(modelG.parameters()).device
        self.featureMapping = featureMapping

        self.X_dim = X_dim
        # Input_dim is the scalar dimension of the generator output, i.e. all dims
        # multiplied (works for ints, lists, tuples and torch.Size alike).
        self.input_dim = int(np.prod(X_dim))
        self.latent_dim = latent_dim


    def _M_numpy(self, points, M_batch_size):
        """Evaluate M at `points` (array-like [N, d]) and always return a NumPy array [N, d, d]."""
        M = self.M_valueAt(pt.Tensor(points).to(self.device), M_batch_size=M_batch_size)
        if isinstance(M, pt.Tensor):
            M = M.detach().cpu().numpy()
        return M


    def curveLength (self, dt, curve_points, curve_derivatives=None, M_batch_size=4):
        """
        For a discretized curve defined by N points we find the curve length. If analytic curve_derivatives can be computed, we use those, otherwise we use finite difference on the curve_points.

        Arguments:
            dt (float or np.ndarray) : Time difference. Scalar if the t are
                uniformly distributed, otherwise vector of time differences (should have one padded value such that shape is [N])
            curve_points (np.ndarray) : Shape (N, d) where d is the coordinate
                dimension of the curve.
            curve_derivatives (np.ndarray) : Shape (N, d) containing the
                derivatives of the curve at each point.
            M_batch_size (int) : batch size for evaluating the metric.

        Returns:
            Scalar value representing the length. No gradients enabled by default.
        """
        if curve_derivatives is None:
            # Finite differences: evaluate M at segment midpoints, sum sqrt(dz^T M dz)
            z_diff = curve_points[1:] - curve_points[:-1]
            midpoints = (curve_points[1:] + curve_points[:-1]) / 2
            M = self._M_numpy(midpoints, M_batch_size)
            sq_lengths = np.einsum('ni,nij,nj->n', z_diff, M, z_diff)
            return np.sqrt(sq_lengths + 1e-6).sum()

        M = self._M_numpy(curve_points, M_batch_size)
        sq_speeds = np.einsum('ni,nij,nj->n', curve_derivatives, M, curve_derivatives)
        # Multiply before summing so a per-point dt vector weights each point
        return np.sum(dt * np.sqrt(sq_speeds + 1e-6))


    def curve_measure (self, curve_points, curve_derivatives, M_batch_size=4):
        """
        Computes a measure to describe how much a curve follows the minimal eigenvectors of the induced metric tensor. Can show the improvement that is possible for a certain curve.

        Returns the condition-number-weighted mean of |<unit tangent, eigenvector of the
        smallest eigenvalue of M>| over the curve points (value in [0, 1]).
        """
        N = curve_points.shape[0]
        M = self._M_numpy(curve_points, M_batch_size)
        derivative_norm = np.sqrt(np.sum(curve_derivatives**2, axis=-1, keepdims=True))

        # M is symmetric: eigh returns real eigenvalues in ascending order and
        # orthonormal eigenvectors as the COLUMNS of eigv (eig may return complex values).
        eig, eigv = np.linalg.eigh(M)
        eigS = eig[:, 0]
        eigL = eig[:, -1]
        eigvS = eigv[:, :, 0]

        condition_number = (eigL/eigS).reshape(N)
        scalar_prod = np.einsum('ij,ij->i', curve_derivatives/(derivative_norm+1e-6), eigvS)
        # Would normally multiply with dt, but cancels out with normalization
        measure = np.sum(condition_number*np.abs(scalar_prod))
        normalized_measure = measure / np.sum(condition_number)

        return normalized_measure


    def M_valueAt(self, z, M_batch_size=None):
        """
        Computes the M = J.T * J value at a certain position in the latent space z.

        Arguments:
            z (torch.Tensor) : position in latent space. Shape [N, d] (or [d] for a single point).
            M_batch_size (int) : as computing the Jacobian and Hessian are very memory intensive, we may wish to use small batches instead. But this can only be used when no gradients are required!

        Returns:
            M (torch.Tensor or np.ndarray) : M matrix at z. Shape [N, d, d]. Numpy output when we use M_batch_size.
        """
        N = 1 if len(z.shape) == 1 else z.shape[0]
        z = z.view(N, -1)

        if M_batch_size is not None:
            ### Loop over ourselves in batches. Detach every output and move to CPU.
            M_values = []
            for start in range(0, N, M_batch_size):
                M_values.append(
                    self.M_valueAt(z[start:start + M_batch_size]).detach().cpu().numpy()
                )
                gc.collect()
            # NumPy output avoids confusion with a gradient-free tensor on another device
            return np.concatenate(M_values, axis=0)

        # One copy of every latent point per output dimension, so a single backward
        # pass with identity grad_outputs yields the full Jacobian.
        z_J = z.repeat_interleave(self.input_dim, dim=0)
        z_J.requires_grad_(True)

        X_pred = self.modelG(z_J)
        X_var = None
        # For stochastic decoder
        if isinstance(X_pred, tuple):
            X_pred, X_var = X_pred

        ### Feature Mapping (default: identity, i.e. Euclidean output metric)
        M_fx = None
        if self.featureMapping is not None:
            x_J = self.modelG(z)
            if isinstance(x_J, tuple):
                # For feature mapping to new output space we just consider mean of stochastic decoder
                x_J = x_J[0]

            x_J = x_J.reshape(-1, self.input_dim).detach().repeat_interleave(self.featureMapping.out_dim, dim=0)
            x_J.requires_grad_(True)

            feature_out = self.featureMapping(x_J)
            grad_outputs_f = pt.eye(self.featureMapping.out_dim).repeat(N, 1).to(self.device)

            J_fx = pt.autograd.grad(outputs=feature_out, inputs=x_J,
                grad_outputs=grad_outputs_f, create_graph=True, retain_graph=True,
                only_inputs=True)[0].reshape(N, self.featureMapping.out_dim, self.input_dim)

            M_fx = pt.matmul(pt.transpose(J_fx, 1, 2), J_fx)


        # Jacobian of the (mean) output w.r.t. z: [N, input_dim, latent_dim]
        grad_outputs = pt.eye(self.input_dim).repeat(N, 1).to(self.device)
        J = pt.autograd.grad(outputs=X_pred.reshape(-1, self.input_dim), inputs=z_J,
            grad_outputs=grad_outputs, create_graph=True, retain_graph=True,
            only_inputs=True)[0].reshape(N, self.input_dim, self.latent_dim)

        if M_fx is None:
            M = pt.matmul(pt.transpose(J, 1, 2), J)
        else:
            M = pt.matmul(pt.transpose(J, 1, 2), pt.matmul(M_fx, J))


        ### For stochastic decoder
        # A variance that does not track gradients (or does not depend on z) has a
        # zero Jacobian and contributes nothing to M.
        if X_var is not None and X_var.requires_grad:
            X_std = pt.sqrt(X_var)
            J_std = pt.autograd.grad(outputs=X_std.reshape(-1, self.input_dim), inputs=z_J,
                    grad_outputs=grad_outputs, create_graph=True, retain_graph=True,
                    only_inputs=True, allow_unused=True)[0]

            if J_std is not None:
                J_std = J_std.reshape(N, self.input_dim, self.latent_dim)
                if M_fx is None:
                    M = M + pt.matmul(pt.transpose(J_std, 1, 2), J_std)
                else:
                    M = M + pt.matmul(pt.transpose(J_std, 1, 2), pt.diagonal(M_fx, dim1=-2, dim2=-1).view(N, self.input_dim, 1) * J_std)


        # Prevent singular M matrices
        eps = 1e-6
        return M + eps*pt.eye(self.latent_dim).to(self.device)

In [None]:
# ------------------------------------------------------------------------------
# Bezier and Bspline curve definitions, including a trainable curve with
# variable parameters.
# ------------------------------------------------------------------------------

import torch as pt
import numpy as np

from torch import nn
from scipy.special import binom
from scipy.integrate import solve_bvp

import warnings


def BezierCurve (control_points):
    """
    Creates a Bezier curve using all control points provided. Curve has the dimension of the input control_points.

    Arguments:
        control_points (torch.Tensor) [n, d] : n points to use for the curve, first element
        should be start point and last element the end point.

    Returns:
        Python function `curve(t)`. For a tensor t (flattened to [b, 1] and moved to the dtype and
        device of control_points) it returns the tuple (points [b, d], derivatives [b, d]) of the
        curve and of its first derivative with respect to t. Non-tensor t is used as given.
    """
    dtype = control_points.dtype
    device = control_points.device

    def bernstein (i, n):
        # Binomial coefficients are computed by scipy on the CPU and converted explicitly to a
        # tensor with the control points' dtype/device, instead of relying on the ufunc to hand
        # back a torch tensor.
        binomial = pt.as_tensor(binom(n, i.cpu().numpy()), dtype=dtype, device=device)
        # If i has shape [n] and t has shape [b, 1], result is [b, n]
        return (lambda t: binomial * t**i * (1-t)**(n-i))

    # Curve of order n has points P0 ... Pn, i.e. n+1 total points
    n = control_points.shape[0]-1

    # Index tensors share the control points' dtype so that e.g. float64 curves work as well.
    b_0n = bernstein(pt.linspace(0, n, n+1, dtype=dtype, device=device), n)
    # Degree n-1 basis for the derivative. It cannot be derived from b_0n since t=1 would then
    # divide by zero; otherwise 0**0 == 1 is handled correctly by torch.
    b_0n_1 = bernstein(pt.linspace(0, n-1, n, dtype=dtype, device=device), n-1)

    def curve(t):
        if isinstance(t, pt.Tensor):
            t = t.view(-1, 1).to(device=device, dtype=dtype)

        return (
            pt.matmul(b_0n(t), control_points),
            n * pt.matmul(b_0n_1(t), control_points[1:] - control_points[:-1]),
        )

    return curve


def CubicBSpline (control_points, knot_vector=None):
    """
    Creates a Cubic B-spline using all control points provided. Order 4 B-spline. Requires at least 4 control points!

    Arguments:
        control_points (torch.Tensor) [n, d] : n control points to use for the curve, first element
        should be start point and last element the end point.
        knot_vector (optional) : n+4 knots (list, array or tensor). If None, a clamped knot vector
        with uniformly spaced interior knots on [0, 1] is used.

    Returns:
        Python function `curve(t)`. t (scalar or tensor) is flattened to [b, 1] and moved to the
        dtype/device of control_points; the function returns the tuple
        (gamma [b, d], dgamma [b, d]) of curve points and first derivatives with respect to t.
    """
    device = control_points.device
    dtype = control_points.dtype
    num_control = control_points.shape[0]
    order = 4
    if num_control < order:
        assert False, "Not enough control points for cubic Bspline!"

    if knot_vector is None:
        # Clamped Bspline with uniform interior knots if no knot vector is given.
        # Total length of knot vector is (n+k). We want shape [1, n+k] such that subtraction works later on.
        knot_vector = pt.as_tensor(
            np.concatenate([[0]*3, np.linspace(0, 1, num_control-2), [1]*3]).reshape(1, -1),
            dtype=dtype, device=device,
        )
    else:
        # Make sure it's a tensor on the right device, in the curve's dtype and with the correct shape.
        knot_vector = pt.as_tensor(knot_vector).view(1, -1).to(device=device, dtype=dtype)

    # Last knot (end of the parameter domain), shape [1, 1] for broadcasting.
    t_end = knot_vector[:, -1:]

    def basis_func (knots, k, t):
        """
        Evaluates all B-spline basis functions of order k (Cox-de Boor recursion) at t.

        Arguments:
            knots (torch.Tensor [1, n+k])
            k (int)
            t (torch.Tensor [N, 1])

        Returns:
            torch.Tensor [N, n]: the Bspline basis evaluated at given points for all knots
            torch.Tensor [N, n-1]: only for the top-level call (full knot vector), the order k-1
                basis functions needed for the derivative.
        """
        # Shape of knots is n+k, so n (active knots) is knots.shape - k
        n = knots.shape[1] - k

        if k == 1:
            left = knots[:, :-1]
            right = knots[:, 1:]
            # Half-open spans [u_i, u_{i+1}): a t lying exactly on an interior knot must activate
            # only one span, otherwise the basis sums to 2 there and the curve point is doubled.
            active = (t >= left) & (t < right)
            # Close the final non-empty span on the right so that t == t_end is still covered.
            active = active | ((t == right) & (right == t_end) & (left < right))
            return active.to(t.dtype)

        # Such that total new knot_vector length is still n+k
        B_ik_1 = basis_func(knots[:, :-1], k-1, t)
        B_i1k_1 = basis_func(knots[:, 1:], k-1, t)

        # t - knots[:n] is a shape [N, n] matrix; zero-width knot spans contribute 0.
        term1 = pt.where(knots[:, k-1:-1] - knots[:, :n] != 0,
                         B_ik_1 * (t - knots[:, :n]) / (knots[:, k-1:-1] - knots[:, :n]),
                         pt.zeros_like(B_ik_1))
        term2 = pt.where(knots[:, k:] - knots[:, 1:n+1] != 0,
                         B_i1k_1 * (knots[:, k:] - t) / (knots[:, k:] - knots[:, 1:n+1]),
                         pt.zeros_like(B_i1k_1))

        res = term1 + term2
        if knots.shape[1] == num_control+order:
            # We're in the highest loop: also return the order-3 basis N_{i+1,3}, i = 0..n-2
            return res, B_i1k_1[:, :-1]

        return res

    def curve(t):
        if not isinstance(t, pt.Tensor):
            t = pt.as_tensor(t)
        t = t.view(-1, 1).to(device=device, dtype=dtype)

        basis, dbasis = basis_func(knot_vector, order, t)
        gamma = pt.matmul(basis, control_points)
        # C'(t) = sum_i p / (u_{i+p+1} - u_{i+1}) * (P_{i+1} - P_i) * N_{i+1,p-1}(t), p = order-1
        dgamma = (order-1) * pt.matmul(
            dbasis / (knot_vector[:, order:-1] - knot_vector[:, 1:num_control]),
            (control_points[1:] - control_points[:-1])
        )

        return gamma, dgamma

    return curve


class trainableCurve (nn.Module):
    """
    Curve from a fixed `start` to a fixed `end` point with trainable interior control nodes,
    evaluated either as a cubic B-spline (default) or as a Bezier curve.

    Calling the module with parameter values t returns whatever CubicBSpline / BezierCurve
    return for the current control points, i.e. (curve points, derivatives w.r.t. t).
    """
    def __init__(self, start, end, max_nodes=10, bspline=True):
        """
        BSpline: nodes are added one at a time by inserting a knot in the middle of the currently
        largest knot interval, which refines the control polygon in between the previous knots.

        Arguments:
            start (torch.Tensor) : start point of the curve (not trained).
            end (torch.Tensor) : end point of the curve (not trained).
            max_nodes (int) : all node parameters are created at the start (such that
                              they are trainable parameters of the module). Counts all nodes in
                              between start and end (excluding both); at least 2 are required.
            bspline (bool) : use a cubic B-spline if True, otherwise a Bezier curve.
        """
        super().__init__()

        self.start = nn.Parameter(start, requires_grad=False)
        self.end = nn.Parameter(end, requires_grad=False)
        self.bspline = bspline

        # Without ParameterList the entries within a Parameter are not moved to device.
        # Node i is initialised at (i+1)/3 along the straight line, so the first 2 nodes lie on
        # the segment; later nodes are overwritten (B-spline) or used as-is (Bezier) when added.
        # Some PyTorch versions emit a UserWarning when a ParameterList is created; silence it
        # only for this construction instead of installing a process-wide warnings filter.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            self.new_nodes = nn.ParameterList([
                nn.Parameter(self.start + (i+1)/3 * (self.end - self.start))
                for i in range(max_nodes)
            ])

        # Keep as list such that moving to device and adding nodes work as expected.
        # CubicBsplines require 2 new nodes, whereas Bezier can start with 1 node.
        self.points = [self.start, self.new_nodes[0], self.new_nodes[1], self.end]
        self.nodecount = 2
        # Clamped knot vector of a single cubic segment (only used for B-splines).
        self.knot_vector = [0, 0, 0, 0, 1, 1, 1, 1]

    def add_node(self):
        """
        Activate the next unused node of `new_nodes`.

        Bezier: the node is inserted right before the end point, as initialised in __init__.
        BSpline: a knot is inserted in the middle of the largest knot interval and the control
        points are updated in place (Boehm knot insertion), adding one control point.
        """
        if (self.nodecount >= len(self.new_nodes)):
            assert False, "Not enough max_nodes to use for another node addition!"

        node = self.new_nodes[self.nodecount]

        ### For Bezier -----------
        if not self.bspline:
            self.points = self.points[:-1] + [node, self.end]
            self.nodecount += 1
            return

        ### For Bspline ----------
        knots = self.knot_vector
        p = 3  # degree of the cubic curve
        # Index k of the largest knot interval [u_k, u_{k+1}); its midpoint becomes the new knot.
        k = int(np.argmax(np.diff(knots)))
        new_knot = (knots[k+1] + knots[k]) / 2

        # For a non-decreasing knot vector all denominators u_{i+p} - u_i (i = k-p+1 .. k) are
        # strictly positive, since u_{i+p} >= u_{k+1} > u_k >= u_i.
        with pt.no_grad():
            # New control point Q_k = (1 - a_k) * P_{k-1} + a_k * P_k
            ratio = (new_knot - knots[k]) / (knots[k+p] - knots[k])
            node.copy_((1-ratio)*self.points[k-1] + ratio*self.points[k])

            # Q_i = a_i * P_i + (1 - a_i) * P_{i-1}, updated in place from back to front so
            # that P_{i-1} still holds its old value when it is read.
            for i in range(k-1, k-p, -1):
                ratio = (new_knot - knots[i]) / (knots[i+p] - knots[i])
                self.points[i].mul_(ratio).add_((1-ratio)*self.points[i-1])

        self.points.insert(k, node)
        self.knot_vector.insert(k+1, new_knot)
        self.nodecount += 1

    def forward(self, t):
        if not self.bspline:
            return BezierCurve(pt.stack(self.points))(t)
        else:
            return CubicBSpline(pt.stack(self.points), self.knot_vector)(t)

In [None]:
# ------------------------------------------------------------------------------
# Training shorter curves than straight line.
# ------------------------------------------------------------------------------

import torch as pt
import numpy as np
import matplotlib.pyplot as plt
import copy

import sys
sys.path.append('./')


def trainGeodesic (bc0, bc1, N, metricSpace, M_batch_size=4, max_epochs=1000, val_epoch=10, verbose=2):
    """
    Finds a shorter curve from bc0 to bc1 (attention, this may not be symmetric) in the metricSpace.

    The curve starts as the straight line and its interior control nodes are trained with Adam.
    Every `val_epoch`-th epoch is a validation epoch: a curve shorter than the best so far is
    stored, and if the length improved by less than `length_tol` (straight length / 200) training
    restarts from the best curve with a reduced learning rate. After `max_hardschedules` such
    reductions a control node is added instead and the learning rate is reset.

    Arguments:
        bc0 (torch.Tensor [latent_dim]) : starting point for interpolation.
        bc1 (torch.Tensor [latent_dim]) : end point of interpolation.
        N (int) : discretization of shorter curve (number of evaluation points in [0, 1]).
        metricSpace (Geometry.metric.InducedMetric) : contains generator model
            with jacobian computation.
        M_batch_size (int) : batchsize for computation of metric.
        max_epochs (int) : maximum number of epochs (training and validation together).
        val_epoch (int) : every val_epoch-th epoch is a validation epoch instead of a training one.
        verbose (int) : 0 prints nothing, >= 1 prints progress. This function makes no plots,
            so 2 currently behaves like 1.

    Returns:
        best_gamma (trainableCurve) : shortest curve found; calling it with parameter values t
            returns (points, derivatives) of the curve in latent space.
        length_history (list) : straight-line length followed by the length at every validation epoch.
    """
    ### Parameters for training
    lr_init = 1e0
    lr_gamma = 0.9
    max_nodecount = 10
    max_hardschedules = 5
    hardschedule_factor = 0.3

    def make_optimizer(curve):
        # Frozen parameters (start/end) never receive gradients, so only trainable ones are passed.
        opt = pt.optim.Adam(filter(lambda p: p.requires_grad, curve.parameters()), lr=lr_init, weight_decay=1e-4)
        # Multiplies the given lr with lr_gamma every call to the scheduler
        sched = pt.optim.lr_scheduler.MultiplicativeLR(opt, lambda epoch: lr_gamma)
        return opt, sched

    # The same N points are used for validation and, randomly perturbed, for training.
    t_val = pt.linspace(0, 1, N)

    gamma = trainableCurve(bc0, bc1, max_nodes=max_nodecount)
    gamma.to(metricSpace.device)

    # Start with straight line
    best_gamma = copy.deepcopy(gamma)
    with pt.set_grad_enabled(False):
        res, diff = best_gamma(t_val.to(metricSpace.device))
        g = res.detach().cpu().numpy()
        dg = diff.detach().cpu().numpy()

    dt = t_val[1] - t_val[0]
    best_length = metricSpace.curveLength(dt, g, dg, M_batch_size=M_batch_size)
    straight_measure = metricSpace.curve_measure(g, dg, M_batch_size=M_batch_size)

    # Let tolerance depend on the length of straight line.
    length_tol = best_length/200.

    if verbose >= 1:
        print(f"Straight curve length: {best_length:.3f}")
        print(f"Straight curve measure: {straight_measure:.3f}")

    optimizer, scheduler = make_optimizer(gamma)

    hardSchedules = 0
    length_history = [best_length]
    for epoch in range(max_epochs):
        if (epoch+1) % val_epoch:
            # Training: best_gamma unchanged
            runGammaEpoch(gamma, optimizer, scheduler, t_val, metricSpace, M_batch_size=M_batch_size, train=True)
            continue

        # Validation
        length = runGammaEpoch(gamma, None, None, t_val, metricSpace, M_batch_size=M_batch_size, train=False)
        length_history.append(length)

        if verbose >= 1:
            print('-'*10)
            print(f"Learning rate: {optimizer.param_groups[0]['lr']:.5e}")
            print(f"Epoch[{epoch+1:04d}/{max_epochs}]: Length: {length:.3f}")

        length_improvement = best_length - length
        if length < best_length:
            # Store current best network for minimal length
            if verbose >= 1:
                print("Found better curve!")
            best_gamma = copy.deepcopy(gamma)
            best_length = length

        if length_improvement < length_tol:
            # In case the loss increases, we first wanna rapidly decrease lr before we add nodes.
            # We restart from the best solution when adding nodes or decreasing LR
            if hardSchedules >= max_hardschedules:
                ### New Node
                if (best_gamma.nodecount >= max_nodecount):
                    if verbose >= 1:
                        print("Node limit reached!")
                    break
                if verbose >= 1:
                    print("*** Adding node ***")
                best_gamma.add_node()

            ### Set gamma, and Reset best_gamma so it isn't trained
            gamma = best_gamma
            best_gamma = copy.deepcopy(best_gamma)

            # Re-initialize the optimizer to only the gamma parameters with gradients
            optimizer, scheduler = make_optimizer(gamma)

            if hardSchedules < max_hardschedules:
                if verbose >= 1:
                    print("* Decreasing LR *")
                hardSchedules += 1
                optimizer.param_groups[0]['lr'] *= hardschedule_factor**hardSchedules
            else:
                # Reset hardSchedules when adding new node
                hardSchedules = 0

    with pt.set_grad_enabled(False):
        res, diff = best_gamma(t_val.to(metricSpace.device))

    curve_measure = metricSpace.curve_measure(res.detach().cpu().numpy(), diff.detach().cpu().numpy(), M_batch_size=M_batch_size)
    del res, diff

    if verbose >= 1:
        print(f"New curve length: {best_length:.3f}")
        print(f"New curve measure: {curve_measure:.3f}")

    return best_gamma, length_history


def runGammaEpoch(gamma, optimizer, scheduler, t_val, metricSpace, M_batch_size=4, train=True):
    """
    Run one pass of the curve `gamma` over the parameter values `t_val`.

    For each batch of M_batch_size points the squared speed dgamma^T M(gamma) dgamma is formed,
    the energy dt^2 * sum(squared speed) is back-propagated, and the discretised length
    dt * sum(sqrt(squared speed + eps)) is accumulated. With train=True the t values are
    randomly jittered and one optimizer and scheduler step is taken at the end; during
    validation the curve parameter is not perturbed and no step is taken.

    Returns:
        float : discretised curve length.
    """
    device = metricSpace.device
    eps = 1e-6
    dt = (t_val[1] - t_val[0]).to(device)
    t = t_val.to(device)

    if not train:
        gamma.eval()
    else:
        gamma.train()
        # Jitter every t (std 0.1*dt) and clamp back into [0, 1], so the curve is not sampled
        # at exactly the same points in every training epoch.
        noise = pt.normal(pt.zeros_like(t), 0.1*dt).to(device)
        t = pt.clamp(t + noise, 0.0, 1.0)

    total_length = 0
    gamma.zero_grad()
    num_points = t_val.shape[0]
    for first in range(0, num_points, M_batch_size):
        # Autograd is needed during validation too, because the metric M relies on it.
        with pt.set_grad_enabled(True):
            points, velocities = gamma(t[first:first + M_batch_size])
            batch = points.shape[0]
            M = metricSpace.M_valueAt(points)
            row = velocities.view(batch, 1, -1)
            col = velocities.view(batch, -1, 1)
            sq_speed = pt.matmul(pt.matmul(row, M), col).view(-1)

            energy = (dt**2) * sq_speed.sum()
            total_length += dt * pt.sqrt(sq_speed.detach().cpu() + eps).sum()
            # Backward also during validation, otherwise the graph behind M keeps GPU memory alive.
            energy.backward()

    if train:
        optimizer.step()
        scheduler.step()

    return total_length.item()


In [None]:
# ------------------------------------------------------------------------------
# Methods for evaluating generative models.
# ------------------------------------------------------------------------------

import math
import torch as pt
import numpy as np
import matplotlib.pyplot as plt


def create_sequence (model, straight_plot, curve_plot, seq_length=None):
    """
    Compares straight line vs shorter curve side by side: decoded images along the straight line
    in the top row and along the shorter curve in the bottom row.

    If seq_length is given, both curves are subsampled with stride ceil(N / seq_length), so the
    sequence has at most seq_length images (e.g. N=14 points and seq_length=10 give stride 2,
    i.e. 7 images).

    Arguments:
        model (nn.Module) : generative model creating images from latent vectors.
        straight_plot (np.ndarray [N, latent_dim]) : array of points on a
            straight line in latent space.
        curve_plot (np.ndarray [N, latent_dim]) : array of points on a
            shorter curve in latent space.
        seq_length (int) : maximum length of image sequence to create, if None then it's
            the length of straight_plot/curve_plot.

    Returns:
        None, creates a figure and stores it at "Outputs/interpolation_sequence.png".
    """
    print("Creating interpolation sequence...")

    if seq_length is not None:
        stride = math.ceil(straight_plot.shape[0] / seq_length)
        straight_plot, curve_plot = straight_plot[::stride], curve_plot[::stride]

    seq_length = straight_plot.shape[0]
    device = next(model.parameters()).device

    def decode(latent):
        """Decode one latent point; keep the first output if the model returns a tuple."""
        image = model(pt.Tensor(latent).to(device).view(1, -1))
        if isinstance(image, tuple):
            image = image[0]
        return image.detach().squeeze().cpu().numpy()

    grid = (2, seq_length)
    fig = plt.figure(figsize=(seq_length, 2))  # figsize is width x height

    for col in range(seq_length):
        top = plt.subplot2grid(grid, (0, col))
        top.axis('off')
        bottom = plt.subplot2grid(grid, (1, col))
        bottom.axis('off')
        if col == seq_length - 1:
            # Row labels are attached to the last column.
            top.text(30, 13, "Straight Curve", fontsize=9)
            bottom.text(30, 13, "Shorter Curve", fontsize=9)

        with pt.set_grad_enabled(False):
            straight_image = decode(straight_plot[col])
            curve_image = decode(curve_plot[col])

        top.imshow(straight_image, cmap='gray')
        bottom.imshow(curve_image, cmap='gray')

    fig.savefig("Outputs/interpolation_sequence.png", bbox_inches='tight')
    plt.close(fig)


def create_crosscorrelation (model, straight_plot, curve_plot, featureMapping=None):
    """
    Cross correlation of outputs from both curves, computed as the dot product between the
    normalized, flattened outputs of every pair of points on the same curve.

    Arguments:
        model (nn.Module) : generative model creating images from latent vectors.
        straight_plot (np.ndarray [N, latent_dim]) : array of points on a
            straight line in latent space.
        curve_plot (np.ndarray [N, latent_dim]) : array of points on a
            shorter curve in latent space.
        featureMapping (callable, optional) : applied to both decoded outputs before flattening.

    Returns:
        None, creates a figure and stores it at "Outputs/cross_correlation.png" and prints the
        variance of both correlation matrices.
    """
    print("Creating cross-correlation...")
    device = next(model.parameters()).device
    N = straight_plot.shape[0]

    def flatten(x):
        return x.detach().squeeze().cpu().numpy().reshape(-1)

    def decode_pair(straight_point, curve_point):
        """Decode one point of each curve and return both outputs as flat numpy vectors."""
        with pt.set_grad_enabled(False):
            curve_out = model(pt.Tensor(curve_point).to(device).view(1, -1))
            straight_out = model(pt.Tensor(straight_point).to(device).view(1, -1))

            if isinstance(straight_out, tuple):
                curve_out, straight_out = curve_out[0], straight_out[0]

            if featureMapping is not None:
                curve_out = featureMapping(curve_out)
                straight_out = featureMapping(straight_out)

            return flatten(straight_out), flatten(curve_out)

    pairs = [decode_pair(straight_plot[idx], curve_plot[idx]) for idx in range(N)]
    straight_vec = np.stack([s for s, _ in pairs], axis=0)
    curve_vec = np.stack([c for _, c in pairs], axis=0)

    # Normalize each output to unit length, so the Gram matrices below are cosine similarities.
    curve_vec /= np.linalg.norm(curve_vec, axis=1, keepdims=True)
    straight_vec /= np.linalg.norm(straight_vec, axis=1, keepdims=True)
    curve_corr = curve_vec @ curve_vec.T
    straight_corr = straight_vec @ straight_vec.T

    ### Plotting
    fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(6, 12))
    # Shared levels so both panels use the same color scale.
    vmin = min(np.amin(curve_corr), np.amin(straight_corr))
    vmax = max(np.amax(curve_corr), np.amax(straight_corr))
    levels = np.linspace(vmin, vmax, 15)
    positions = np.linspace(1, N, N)

    panels = ((ax0, straight_corr, "Straight Curve"), (ax1, curve_corr, "Shorter Curve"))
    for ax, corr, title in panels:
        filled = ax.contourf(positions, positions, corr, levels=levels, cmap='jet')
        fig.colorbar(filled, ax=ax)
        ax.set_title(title)

    fig.savefig("Outputs/cross_correlation.png", bbox_inches='tight')
    plt.close(fig)

    ### Compute a measure based on cross-correlation
    straight_var = np.var(straight_corr)
    curve_var = np.var(curve_corr)
    print(f"Straight variance: {straight_var:.3f}")
    print(f"Curve variance: {curve_var:.3f}")

In [None]:
# ------------------------------------------------------------------------------
# Interpolation in latent space, by default for MNIST digits, but can be changed
# easily.
# ------------------------------------------------------------------------------

import numpy as np
import torch as pt
import matplotlib.pyplot as plt
import os
import gc
import time

from argparse import ArgumentParser

import sys
sys.path.append('./')




# Output directory for all figures written below.
os.makedirs("Outputs", exist_ok=True)

device = pt.device('cuda') if pt.cuda.is_available() else pt.device('cpu')

# Need to manually set bc anyways.
X_dim = [1,28,28]

# Discretization of geodesic curve
N_t = 20
# Alternative boundary conditions:
#bc0 = -pt.ones(latent_dim)
#bc1 = pt.ones(latent_dim)
#bc0,_ = modelE(train_images[9].view(1,1,28,28).to(device))

# Boundary conditions: encode an all-zero image and train_images[9] into latent space.
# NOTE(review): `modelE`, `latent_dim`, `Decoder`, `InducedMetric` and `train_images` come from
# other notebook cells (`train_images` from a cell further down) -- run those first.
# The encoder runs without autograd: the boundary points are only used as fixed curve endpoints.
# Inputs are moved to `device` rather than hard-coded `.cuda()`, so CPU-only machines work too.
with pt.no_grad():
    bc0, _ = modelE(pt.zeros((1, 1, 28, 28)).to(device))
    bc1, _ = modelE(train_images[9].view(1, 1, 28, 28).to(device))
bc0 = bc0.squeeze()
bc1 = bc1.squeeze()

epochs = 200

M_batch_size = 1

modelG = Decoder(X_dim, latent_dim)
trained_gen = "trainedVAE_D.pth"

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
modelG.load_state_dict(pt.load(os.path.join("TrainedModels", trained_gen), map_location=device))
modelG.to(device)
modelG.eval()
print("Generator loaded!")

### Create metric space for curvelengths
metricSpace = InducedMetric(modelG, X_dim, latent_dim)


### Find shorter path than straight line
print("Optimizing for shorter path...")
start = time.time()
best_gamma, length_history = trainGeodesic(
    bc0, bc1, N_t, metricSpace,
    M_batch_size=M_batch_size,
    max_epochs=epochs,
    val_epoch=5
)
print(f"Optimization took {time.time()-start:.1f}s.")

# length_history[0] is the straight-line length; the remaining entries are validation lengths.
fig, ax1 = plt.subplots(figsize=(12,9))
ax1.plot(np.arange(len(length_history)-1), length_history[1:], linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Length', color='blue')
fig.savefig("Outputs/Length_History.png", bbox_inches='tight')
plt.close(fig)


### Plot shorter curve
t_plot = pt.linspace(0, 1, 2*N_t).to(device).view(-1,1)
dt = 1 / (2*N_t - 1)

with pt.set_grad_enabled(False):
    straight_plot = BezierCurve(pt.stack([bc0, bc1]).to(device))(t_plot)[0].cpu().numpy()
    curve_plot = best_gamma(t_plot)[0].detach().cpu().numpy()


### Evaluate interpolation curves
create_sequence(modelG, straight_plot, curve_plot, seq_length=20)
create_crosscorrelation(modelG, straight_plot, curve_plot)

In [None]:
# Re-encode boundary points with train_images[1] as target. Inputs go to `device` instead of a
# hard-coded `.cuda()`, and the results are squeezed like in the main interpolation cell, since
# trainGeodesic expects boundary points of shape [latent_dim].
with pt.no_grad():
    bc0, _ = modelE(pt.zeros((1, 1, 28, 28)).to(device))
    bc1, _ = modelE(train_images[1].view(1, 1, 28, 28).to(device))
bc0 = bc0.squeeze()
bc1 = bc1.squeeze()

In [None]:
# Pull a single batch of training images and labels from the dataloader for inspection.
dataiter = iter(dataloader)
first_batch = next(dataiter)
train_images, train_labels = first_batch

In [None]:
import torchvision.utils as vutils

# Show the first ten training images as one normalized grid (channels moved last for imshow).
plt.figure(figsize=(5, 10))
out = vutils.make_grid(train_images[:10], normalize=True)
plt.imshow(np.transpose(out.numpy(), (1, 2, 0)))