# Organization


1. Training/loading the VAE
2. Computing geodesics on the data manifold
3. Geodesics restricted to the straight line
4. Derivative direction vs. interpolation direction
5. Greedy geodesics vs. straight-line geodesics
6. Studying the geodesics in more detail



In [1]:
import os
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import pickle
import sys
import numpy as np
import scipy
from scipy import stats
import matplotlib.pyplot as plt
import math

 # Google Colab users: enable GPU acceleration under Edit --> Notebook Settings.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
# Trained models and datasets are cached under `notebook_data_path`, so the
# models do not need to be retrained every time the notebook restarts.
#
# On Google Colab, mount your Drive and point the path there instead:
#   from google.colab import drive
#   drive.mount('/content/drive')
#   notebook_data_path = '/content/drive/MyDrive/geomtoolkit/'
#
# The path can be overridden with the NOTEBOOK_DATA_PATH environment variable;
# otherwise the original per-person default is used. The rest of the notebook
# builds paths with string `+`, so a trailing '/' is enforced.
notebook_data_path = os.environ.get('NOTEBOOK_DATA_PATH', '/home/gridsan/hanlaw/')
if not notebook_data_path.endswith('/'):
  notebook_data_path += '/'

# Training/loading the VAE

The following code either trains the VAE or loads a previously trained VAE from a pickle file. Loading is preferable once the VAE has already been trained, since it avoids retraining.

In [9]:
#@title Code: VAE class definition
# A slightly modified version of the simple VAE from the PyTorch examples
# repository.
class VAE(nn.Module):
    """MLP variational autoencoder over flattened 28x28 (784-pixel) images.

    Encoder: 784 -> 400 (ReLU) -> two heads of size ``latent_d`` giving the
    posterior mean and log-variance. Decoder: ``latent_d`` -> 400 (ReLU) ->
    784 (sigmoid). The attribute names fc1/fc21/fc22/fc3/fc4 are kept because
    whole model objects are pickled and reloaded.
    """

    def __init__(self, latent_d=20):
        super().__init__()
        hidden = 400
        self.latent_d = latent_d
        # Layers are created in the original order so that random
        # initialization consumes the RNG in the same sequence.
        self.fc1 = nn.Linear(784, hidden)
        self.fc21 = nn.Linear(hidden, latent_d)
        self.fc22 = nn.Linear(hidden, latent_d)
        self.fc3 = nn.Linear(latent_d, hidden)
        self.fc4 = nn.Linear(hidden, 784)

    def encode(self, x):
        """Return the posterior parameters (mu, logvar) for inputs ``x``."""
        shared = F.relu(self.fc1(x))
        return self.fc21(shared), self.fc22(shared)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def decode(self, z):
        """Map latents ``z`` to sigmoid pixel intensities."""
        return torch.sigmoid(self.fc4(F.relu(self.fc3(z))))

    def forward(self, x):
        """Encode, sample a latent, and decode; returns (recon, mu, logvar)."""
        mu, logvar = self.encode(x.view(-1, 784))
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

In [10]:
def get_trained_vae(dataset_name,latent_d):
  """Load the pickled VAE for (dataset_name, latent_d), training it first if needed.

  The model file is
  ``<notebook_data_path>trained_models/<dataset_name>_vae_d<latent_d>.pkl``;
  if it does not exist, ``train_vae`` is called to create it.

  Args:
    dataset_name: dataset key passed to ``train_vae``.
    latent_d: latent dimension of the VAE.

  Returns:
    The unpickled model object.

  Note: ``pickle.load`` can execute arbitrary code; only load model files you
  created yourself.
  """
  model_path = notebook_data_path + 'trained_models/'
  # makedirs (unlike mkdir) also creates a missing notebook_data_path and does
  # not fail when the directory already exists.
  os.makedirs(model_path, exist_ok=True)
  vae_filename = model_path + dataset_name + '_vae_d' + str(latent_d) + '.pkl'
  if not os.path.exists(vae_filename):
    print('VAE does not exist. Training it now.')
    train_vae(dataset_name,latent_d,vae_filename)
  # Close the file deterministically instead of leaking the handle.
  with open(vae_filename, 'rb') as f:
    model = pickle.load(f)
  return model

In [11]:
#@title Code: VAE training
# A slightly modified version of the VAE training example from the PyTorch
# examples repository.
def train_vae(datasetname,latent_d,filename):
    """Train a VAE on ``datasetname`` and pickle the whole model to ``filename``.

    Hyperparameters are the argparse defaults below. The optimizer is Adam
    (lr 1e-3), and the loss is the summed binary cross-entropy reconstruction
    term plus the KL divergence to the N(0, I) prior.

    Args:
      datasetname: 'mnist' (downloaded under ``notebook_data_path + 'data/'``),
        'angle' / 'circleangle' (image folders under ./data/<name>/), or
        'untrained', which runs zero epochs and saves a randomly initialized
        VAE without loading any data.
      latent_d: latent dimension of the VAE.
      filename: path the pickled model is written to.

    Raises:
      ValueError: if ``datasetname`` is not one of the names above.
    """
    parser = argparse.ArgumentParser(description='VAE MNIST Example')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--vis-interval',type=int, default=10, metavar='N',
                        help='how many batches to wait before dumping visualization')
    # Parse an explicit empty list: inside a notebook, the command line holds
    # the kernel's own arguments. Every option takes its default, and the
    # global sys.argv is no longer overwritten.
    args = parser.parse_args([])
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # torch.manual_seed(args.seed)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    dataset_name = datasetname
    if dataset_name == 'mnist':
        traindataset = datasets.MNIST(notebook_data_path + 'data/', train=True, download=True,
                      transform=transforms.ToTensor())
    elif dataset_name in ['angle', 'circleangle']:
        transformlist = transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                        transforms.ToTensor()])
        traindataset = datasets.ImageFolder('./data/' + dataset_name + '/', transform=transformlist)
    elif dataset_name == 'untrained':
        # Zero epochs: the model keeps its random initialization, so no data
        # is needed (and no network access for downloading MNIST).
        args.epochs = 0
        traindataset = None
    else:
        raise ValueError('Unknown dataset name: %r' % (dataset_name,))

    # Only the training split is used, so only its loader is built.
    train_loader = None
    if traindataset is not None:
        train_loader = torch.utils.data.DataLoader(traindataset,
            batch_size=args.batch_size, shuffle=True, **kwargs)

    # Reconstruction + KL divergence losses summed over all elements and batch
    def vae_loss_function(recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD

    def vae_train(model, optimizer, epoch):
        """Run one training epoch over train_loader, logging every log_interval batches."""
        model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = vae_loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(data)))

        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

    model = VAE(latent_d = latent_d).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(1, args.epochs + 1):
        vae_train(model, optimizer, epoch)
    # Close (and flush) the file before get_trained_vae reads it back.
    with open(filename, 'wb') as f:
        pickle.dump(model, f)
    print('trained',latent_d)

In [12]:
# Train (or, if already cached, load) a VAE on MNIST with a 5-dimensional
# latent space. The first run pickles the model under notebook_data_path so
# that later sessions can reuse it.
model = get_trained_vae(dataset_name='mnist', latent_d=5)

VAE does not exist. Training it now.
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /home/gridsan/hanlaw/data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

URLError: <urlopen error [Errno 101] Network is unreachable>

# Computing geodesics on manifolds

Implementation of the "classic" approach: start with the straight-line interpolation in latent space and perturb it to obtain a shorter path.

In [None]:
import torch.nn.functional as F
import torch.optim as optim

def minimize_path_energy(model, z0, z1, numpts=10,numtrainiter=1000,learning_rate=0.001,verbose=False):
  """Optimize a discrete latent path from z0 to z1 to minimize its decoded energy.

  The path z0 = p_0, p_1, ..., p_numpts, p_{numpts+1} = z1 starts as
  ``numpts`` evenly spaced points strictly between z0 and z1 on the straight
  segment. The interior points are then updated with Adam for
  ``numtrainiter`` steps to minimize the energy
  ``sum_i ||decode(p_{i+1}) - decode(p_i)||^2`` (summed over all outputs).
  Minimizing energy rather than length also pushes the decoded points
  toward equal spacing.

  Args:
    model: object whose ``decode`` maps a batch of latents to outputs.
    z0, z1: endpoints of shape (1, latent_d).
    numpts: number of interior points.
    numtrainiter: number of Adam steps (0 returns the straight-line path).
    learning_rate: Adam learning rate.
    verbose: print the energy every 1000 iterations.

  Returns:
    (path, energy): the (numpts + 2, latent_d) path including both endpoints,
    and the final energy as a detached scalar tensor.
  """
  # The endpoints are fixed, so their images are computed once, outside the graph.
  x0 = model.decode(z0).detach()
  x1 = model.decode(z1).detach()

  def path_energy(interior):
    imgs = model.decode(interior)
    # Differences between consecutive images along [x0, imgs..., x1].
    return torch.sum((torch.cat((x0, imgs)) - torch.cat((imgs, x1)))**2)

  # Straight-line initialization. t is built on z0's device/dtype rather than
  # on a global device, and the points are detached so they form a fresh leaf.
  t = torch.linspace(1 / (numpts + 1), numpts / (numpts + 1), numpts,
                     device=z0.device, dtype=z0.dtype)
  interp_points = (torch.outer(1 - t, torch.flatten(z0)) + torch.outer(t, torch.flatten(z1))).detach()
  interp_points.requires_grad = True

  optimizer = optim.Adam([interp_points], lr=learning_rate)
  for i in range(numtrainiter):
    loss = path_energy(interp_points)

    if verbose:
      with torch.no_grad():
        if i % 1000 == 0:
          print('loss',i,loss.to('cpu'))
    # Differentiate w.r.t. the path only; loss.backward() would also
    # accumulate gradients into every model parameter on each step.
    grad, = torch.autograd.grad(loss, [interp_points])
    interp_points.grad = grad
    optimizer.step()

  with torch.no_grad():
    loss = path_energy(interp_points)
  return torch.cat((z0, interp_points.detach(), z1)), loss.detach()

In [None]:
def get_curve_length(model, zseq):
  """Return the decoded (image-space) length of the polyline through the rows of ``zseq``.

  The length is ``sum_i ||decode(zseq[i+1]) - decode(zseq[i])||_2``, with the
  norm taken over all output entries of a point.

  Args:
    model: object whose ``decode`` maps a batch of latents to outputs.
    zseq: (n, latent_d) tensor of latent points.

  Returns:
    A scalar tensor; zero when ``zseq`` has fewer than two rows.
  """
  if zseq.shape[0] < 2:
    return zseq.new_zeros(())
  # Decode every point once, in a single batched call, instead of decoding
  # each interior point twice, one row at a time.
  imgs = model.decode(zseq)
  imgs = imgs.reshape(imgs.shape[0], -1)
  return torch.linalg.norm(imgs[1:] - imgs[:-1], dim=1).sum()

In [None]:
def disp_vec_img(a,filename=None):
  """Display (or save) a 784-element tensor as a 28x28 grayscale image.

  Args:
    a: tensor with 784 elements, e.g. a (1, 784) decoder output. It may live
      on any device and may require grad.
    filename: if given, the figure is saved there instead of being shown.
  """
  # detach() is required: converting a tensor that requires grad to numpy
  # raises, and torch.no_grad() does not clear requires_grad on an existing
  # CPU tensor (`.to('cpu')` is then a no-op returning the same tensor).
  pixels = a.detach().cpu().numpy().reshape((28, 28))
  plt.imshow(pixels, cmap='gray')
  if filename is None:
    plt.show()
  else:
    plt.savefig(filename)

In [None]:
import imageio
from base64 import b64encode
from IPython.display import HTML

def display_imageseq_video(zseq, video_file='test.mp4'):
  """Render the decoded latent sequence ``zseq`` as an inline HTML5 video.

  Each row of ``zseq`` is decoded with the notebook-global ``model`` into a
  28x28 frame. Frames are padded by 2 pixels on each side and written to
  ``video_file`` at 200 fps, upscaled to 320x320 with nearest-neighbour
  scaling. The file is then embedded as a base64 data URL.

  Returns:
    An IPython ``HTML`` object containing the video player.
  """
  with torch.no_grad():
    imageseq = (model.decode(zseq) * 255).byte()
  imageseq = imageseq.cpu().numpy()
  frames = [np.pad(img.reshape((28, 28)), ((2, 2), (2, 2))) for img in imageseq]
  # Write to the requested file (previously 'test.mp4' was always used).
  imageio.mimwrite(video_file, frames, ffmpeg_params=['-sws_flags', 'neighbor', '-vf', 'scale=320:320'], fps=200)
  with open(video_file, 'rb') as f:
    mp4 = f.read()
  data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
  htmlobj = HTML("""
  <video width=320 controls>
        <source src="%s" type="video/mp4">
  </video>
  """ % data_url)
  return htmlobj

In [None]:
# Sample two latent codes from the N(0, I) prior and show their decodings.
latent_shape = [1, model.latent_d]
z0 = torch.randn(latent_shape).to(device)
z1 = torch.randn(latent_shape).to(device)
print('We will be interpolating between the following two random digits:')
for endpoint in (z0, z1):
  disp_vec_img(model.decode(endpoint))

In [None]:
# Straight-line baseline (no optimization) vs. an energy-minimized path, both
# with 1000 interior points between z0 and z1.
numpts = 1000
zlinearseq = minimize_path_energy(model, z0, z1, numpts=numpts, numtrainiter=0)[0]
zseq = minimize_path_energy(model, z0, z1, numpts=numpts, numtrainiter=3001,
                            learning_rate=0.0005, verbose=True)[0]

In [None]:
# Image-space length of the straight-line interpolation (baseline).
get_curve_length(model, zlinearseq)


In [None]:
# Image-space length of the energy-minimized path, to compare with the baseline.
get_curve_length(model, zseq)

In [None]:
def component_orth_to(vec, subspace, rtol=1e-6):
  """Return the component of each row of ``vec`` orthogonal to the row span of ``subspace``.

  Args:
    vec: (n, d) tensor; each row is projected independently.
    subspace: (k, d) tensor whose rows span the subspace. The rows need not be
      orthogonal or linearly independent.
    rtol: a subspace row whose Gram-Schmidt residual has norm at most ``rtol``
      times its original norm is treated as linearly dependent and skipped,
      rather than being normalized by ~0 (which produced NaN/inf).

  Returns:
    A new (n, d) tensor; the inputs are not modified.
  """
  # Modified Gram-Schmidt: build an orthonormal basis of the subspace.
  basis = []
  for row in subspace:
    residual = row.clone()
    for q in basis:
      residual = residual - q * torch.dot(q, residual)
    if torch.norm(residual) <= rtol * torch.norm(row):
      continue
    basis.append(residual / torch.norm(residual))
  if not basis:
    return vec.clone()
  q_basis = torch.stack(basis)  # (r, d) with orthonormal rows
  # Remove the projection onto the subspace from all rows at once.
  return vec - (vec @ q_basis.T) @ q_basis


# Offsets of the optimized path from z0, with the part along the straight
# direction z1 - z0 removed (1d), and with the part in span{z0, z1} removed (2d).
offsets = zseq - z0
orth_to_1d = component_orth_to(offsets, z1 - z0)
orth_to_2d = component_orth_to(offsets, torch.cat((z0, z1)))
for norms in (torch.norm(zseq, dim=1), torch.norm(orth_to_1d, dim=1), torch.norm(orth_to_2d, dim=1)):
  print(norms)
# Ratio per path point (0/0 = NaN at the endpoints).
ratio = torch.norm(orth_to_1d, dim=1) / torch.norm(orth_to_2d, dim=1)
plt.plot(np.asarray(ratio.detach().cpu()))

# The takeaway of this code block when run on the untrained network is that
# the projection of zseq onto the (z0, z1) plane is roughly the same as its
# projection onto the z0 -> z1 line. This indicates that deviations from the
# straight-line interpolation are not due to "curvature of the manifold" but
# rather to "noise" effects.

In [None]:
# Inspect the optimized latent path (one row per point, endpoints included).
zseq

In [None]:
# Video of the decoded straight-line interpolation.
display_imageseq_video(zlinearseq,video_file='test1.mp4')

In [None]:
# Video of the decoded energy-minimized path, for comparison.
display_imageseq_video(zseq,'test2.mp4')

# Geodesics restricted to the straight line

In this section, we optimize over paths restricted to the straight line between $z_0$ and $z_1$. Since the points cannot leave the line, the only possible improvement over the naive interpolation (evenly spaced in latent space) comes from re-spacing the points along the line.

This can be computed in two ways:

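(1) We can compute the arc length $\ell$ of the image of the straight line (using Euler integration) and derive the optimal energy from it. For an $N$-segment path, Cauchy–Schwarz gives $E = \sum_i \|x_{i+1} - x_i\|^2 \ge \big(\sum_i \|x_{i+1} - x_i\|\big)^2 / N$, with equality when all segments have equal length. For a fine discretization, the optimal energy is therefore approximately $\ell^2 / N$.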

(2) We can optimize with Adam.

Naturally, I will implement and compare both approaches.

In [None]:
import torch.nn.functional as F
import torch.optim as optim

def minimize_straight_line_energy(model, z0, z1, numpts=10,numtrainiter=1000,learning_rate=0.001,verbose=False):
  """Minimize the decoded energy of a path constrained to the segment z0 -> z1.

  Interior points are parameterized as (1 - t_i) * z0 + t_i * z1. Only the
  scalars t_i (initialized evenly spaced in (0, 1)) are optimized with Adam,
  so the points stay on the line through z0 and z1 and only their spacing
  changes. The t_i are not clamped, so in principle points could move past
  the endpoints or reorder. The energy is
  ``sum_i ||decode(p_{i+1}) - decode(p_i)||^2`` over the full path.

  Args:
    model: object whose ``decode`` maps a batch of latents to outputs.
    z0, z1: endpoints of shape (1, latent_d).
    numpts: number of interior points.
    numtrainiter: number of Adam steps (0 returns the evenly spaced line).
    learning_rate: Adam learning rate.
    verbose: print the energy every 1000 iterations.

  Returns:
    (path, energy): the (numpts + 2, latent_d) path including both endpoints,
    and the final energy as a detached scalar tensor.
  """
  x0 = model.decode(z0).detach()
  x1 = model.decode(z1).detach()
  z0_flat = torch.flatten(z0).detach()
  z1_flat = torch.flatten(z1).detach()

  def points_at(params):
    return torch.outer(1 - params, z0_flat) + torch.outer(params, z1_flat)

  def path_energy(interior):
    imgs = model.decode(interior)
    return torch.sum((torch.cat((x0, imgs)) - torch.cat((imgs, x1)))**2)

  # Build t on z0's device/dtype instead of depending on a global device.
  t = torch.linspace(1 / (numpts + 1), numpts / (numpts + 1), numpts,
                     device=z0.device, dtype=z0.dtype)
  t.requires_grad = True

  optimizer = optim.Adam([t], lr=learning_rate)
  for i in range(numtrainiter):
    loss = path_energy(points_at(t))

    if verbose:
      with torch.no_grad():
        if i % 1000 == 0:
          print('loss',i,loss.to('cpu'))
    # Differentiate w.r.t. t only, so gradients do not accumulate into the
    # model's parameters on every step.
    grad, = torch.autograd.grad(loss, [t])
    t.grad = grad
    optimizer.step()

  with torch.no_grad():
    interp_points = points_at(t)
    loss = path_energy(interp_points)
  return torch.cat((z0, interp_points.detach(), z1)), loss.detach()

In [None]:
# Sample a pair (z0, z1) exactly `dist_parameter` apart in latent space. This
# mimics two independent N(0, I) samples. The direction of z1 - z0 is that of
# an N(0, 2I) draw (the sqrt(2) scale does not change the direction), and the
# midpoint (z0 + z1) / 2 is an N(0, I/2) draw.
numtrials = 1
dist_parameter = 4
diff = torch.randn([numtrials, model.latent_d]).to(device) * math.sqrt(2)
diff = diff / torch.norm(diff,dim=1).view((numtrials,1)) * dist_parameter
midpoint = torch.randn([numtrials, model.latent_d]).to(device) * math.sqrt(1/2)
# Center the pair on the midpoint. Previously z0 = midpoint + diff/2, which
# put the pair's midpoint at midpoint + diff, biased away from the origin.
z0 = midpoint - diff / 2
z1 = z0 + diff

In [None]:
# Draw fresh endpoints from the N(0, I) prior and display their decodings.
z0 = torch.randn([1, model.latent_d]).to(device)
z1 = torch.randn([1, model.latent_d]).to(device)
print('We will be interpolating between the following two random digits:')
for z in (z0, z1):
  disp_vec_img(model.decode(z))

In [None]:
# Compare three paths between z0 and z1 (1000 interior points each):
#   naive    - evenly spaced straight line, no optimization
#   straight - straight line with optimized spacing
#   path     - unconstrained energy-minimized path
numpts = 1000
training = dict(numpts=numpts, numtrainiter=10000, learning_rate=0.0005, verbose=True)
znaiveseq, naiveenergy = minimize_path_energy(model, z0, z1, numpts=numpts, numtrainiter=0)
zstraightseq, straightenergy = minimize_straight_line_energy(model, z0, z1, **training)
zseq, pathenergy = minimize_path_energy(model, z0, z1, **training)

In [None]:
# Image-space lengths of the naive, straight-optimized, and fully optimized paths.
for seq in (znaiveseq, zstraightseq, zseq):
  print(get_curve_length(model, seq))

In [None]:
# Compare naive (straight-line) and energy-optimized path lengths on a
# randomly initialized ('untrained') VAE.
model = get_trained_vae('untrained',5)

naive_lengths, straight_lengths, optimized_lengths = [], [], []
numtrials = 1

for trial_num in range(numtrials):
  # Independent endpoints from the N(0, I) prior. (Alternative: endpoints a
  # fixed distance apart, as in the fixed-distance sampling cell.)
  z0 = torch.randn([1, model.latent_d]).to(device)
  z1 = torch.randn([1, model.latent_d]).to(device)

  print('trial', trial_num)
  numpts = 20
  znaiveseq, naiveenergy = minimize_path_energy(model, z0, z1, numpts=numpts, numtrainiter=0)
  # zstraightseq, straightenergy = minimize_straight_line_energy(model, z0, z1, numpts=numpts,numtrainiter=3001,learning_rate=0.0005,verbose=True)
  zseq, pathenergy = minimize_path_energy(model, z0, z1, numpts=numpts, numtrainiter=10001,
                                          learning_rate=0.0005, verbose=True)

  with torch.no_grad():
    for lengths, seq in ((naive_lengths, znaiveseq), (optimized_lengths, zseq)):
      lengths.append(get_curve_length(model, seq).cpu())
    # straight_lengths.append(get_curve_length(model,zstraightseq).cpu())

naive_lengths = np.asarray(naive_lengths)
# straight_lengths = np.asarray(straight_lengths)
optimized_lengths = np.asarray(optimized_lengths)

In [None]:
# Optimized vs. naive length, drawn against a y = x reference made of points.
identity = np.linspace(0,25,100)
plt.scatter(identity, identity)
plt.scatter(np.asarray(naive_lengths), np.asarray(optimized_lengths))
# plt.scatter(np.asarray(naive_lengths), np.asarray(straight_lengths))
plt.show()
# Relative improvement as a function of the naive length.
ratio_to_naive = np.asarray(optimized_lengths) / np.asarray(naive_lengths)
plt.scatter(np.asarray(naive_lengths), ratio_to_naive)
plt.show()

In [None]:
# Optimized path lengths as a float array. float() works whether the entries
# are 0-d tensors (the list built in the trials loop) or numpy scalars (after
# the np.asarray conversion); `.cpu()` fails on numpy scalars.
np.asarray([float(x) for x in optimized_lengths])

In [None]:
# Longer optimization (100k Adam steps). minimize_path_energy returns
# (path, energy); unpack it so zseq stays a tensor of latent points,
# consistent with the other cells.
zseq, pathenergy = minimize_path_energy(model, z0, z1, numpts=numpts,numtrainiter=100000,learning_rate=0.0005,verbose=True)

Below, we will implement the other method, (1): a simple analytical calculation based on the arc length. [TODO: not yet implemented]

# Directions of derivatives

Although the geodesics are not straight lines in the latent space, they are pretty close! Why could this be?

#### Hypothesis:
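For any $z_0, z_1$ in the latent space, we have $-\nabla_z \|g(z) - g(z_1)\|^2 \big|_{z = z_0} \approx z_1 - z_0$, where $\approx$ means that the two directions are roughly aligned. The negative gradient is the descent direction that moves $g(z_0)$ toward $g(z_1)$. Equivalently, the gradient itself is aligned with $z_0 - z_1$, so its normalized dot product with $z_1 - z_0$ is close to $-1$.

This would be the case if, for example, $g$ were a linear map $g(z) = Az$ with $A^\top A \propto I$, since then $-\nabla_z \|g(z) - g(z_1)\|^2 \big|_{z = z_0} = 2A^\top A (z_1 - z_0)$. The hypothesis says that, to a first-order approximation, the neural network behaves like such a linear model.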

#### Takeaway:
The hypothesis is roughly correct, although less so when $z_0, z_1$ are far apart or when the latent dimension is large. See below.

In [None]:
import torch.nn.functional as F
import torch.optim as optim

def grad_gz0_to_gz1(model, z0, z1):
  """Return the gradient of ||decode(z) - decode(z1)||^2 with respect to z, at z = z0.

  The gradient points *away* from decode(z1); its negative is the descent
  direction that moves decode(z0) toward decode(z1). The squared distance is
  summed over all rows. Assuming ``decode`` processes rows independently
  (true for the notebook's MLP decoder), row i of the result is the gradient
  for the pair (z0[i], z1[i]).

  The gradient is taken on a detached copy of z0 with torch.autograd.grad.
  Neither z0 nor the model's parameters are modified, so repeated calls do
  not accumulate into z0.grad or the parameters' .grad.

  Returns:
    Tensor with the same shape as z0.
  """
  z = z0.detach().clone().requires_grad_(True)
  x0 = model.decode(z)
  x1 = model.decode(z1).detach()

  loss = torch.sum((x0 - x1)**2)
  grad, = torch.autograd.grad(loss, [z])
  return grad

In [None]:
import math
for latent_d in [100]:
  # Load the model once per latent dimension instead of re-unpickling it for
  # every distance setting; its weights are not modified in this loop.
  model = get_trained_vae('untrained',latent_d)
  numtrials = 10000

  for dist_parameter in [None, 0.5, 2, 4, 8, 16]:
    if dist_parameter is None:
      # Independent endpoints from the N(0, I) prior.
      z0 = torch.randn([numtrials, model.latent_d]).to(device)
      z1 = torch.randn([numtrials, model.latent_d]).to(device)

    else:
      # Endpoints exactly dist_parameter apart. The direction is random (the
      # sqrt(2) scale mimics z1 - z0 for independent prior samples and does not
      # affect the direction). The midpoint is N(0, I/2), mimicking (z0 + z1) / 2.
      diff = torch.randn([numtrials, model.latent_d]).to(device) * math.sqrt(2)
      diff = diff / torch.norm(diff,dim=1).view((numtrials,1)) * dist_parameter
      midpoint = torch.randn([numtrials, model.latent_d]).to(device) * math.sqrt(1/2)
      # Center the pair on the midpoint. Previously z0 = midpoint + diff/2,
      # which put the pair's midpoint at midpoint + diff.
      z0 = midpoint - diff / 2
      z1 = z0 + diff

    z0grad = grad_gz0_to_gz1(model, z0, z1)
    z0delta = z1 - z0
    sample_distance = torch.norm(z0delta,dim=1).detach().cpu()

    # Normalize both directions row-wise so that their dot product is a cosine.
    z0grad = z0grad / torch.norm(z0grad,dim=1).view((numtrials,1))
    z0delta = z0delta / torch.norm(z0delta,dim=1).view((numtrials,1))

    # If the gradient points along z0 - z1, this cosine is close to -1.
    grad_delta_alignment = torch.sum(z0grad * z0delta, dim=1).detach().cpu()
    plt.hist(grad_delta_alignment,bins=20)
    plt.title('Dot product of grad with linear interpolation direction. latent_d =' + str(latent_d) )
    plt.show()

    plt.scatter(sample_distance, grad_delta_alignment,marker='.')
    plt.xlabel('Distance of z0 to z1')
    plt.ylabel('Alignment of v = z1 - z0 vs. w = direction of gradient')
    plt.show()

# Greedy geodesic (start at z0 and follow the gradient so that g(z0) reaches g(z1))

The takeaway that I got from my limited experiments here is that straight-line interpolation is better than the path found by the greedy geodesic algorithm.

However, I only ran two trials, with d = 5, so this should be taken with a grain of salt.

Also, this code is not very well written, since I was jumping around between cells, so it will not run if the cells are executed top to bottom.

In [None]:
for latent_d in [5]:
  # NOTE: `latent_d` is not used below; the global `model` determines the
  # latent dimension.
  numtrials = 1
  numsteps = 10000
  step_size = 0.001
  z0 = torch.randn([numtrials, model.latent_d]).to(device)
  z1 = torch.randn([numtrials, model.latent_d]).to(device)

  # Greedy path: starting at z0, repeatedly take a step of fixed length
  # step_size along the negative normalized gradient of
  # ||decode(z) - decode(z1)||^2, until the squared image distance is below 0.02.
  ptlist = []
  currz = torch.clone(z0)

  currz.requires_grad = True
  x1 = model.decode(z1).detach()
  for step in range(numsteps):
    # Store a detached snapshot so the list does not keep autograd graphs alive.
    ptlist.append(currz.detach().clone())
    currx = model.decode(currz)
    loss = torch.sum((currx - x1)**2)
    print(loss)
    if loss < 0.02:
      break

    # Gradient w.r.t. currz only; nothing accumulates into the model's parameters.
    grad, = torch.autograd.grad(loss, [currz])
    gradnorm = torch.norm(grad, dim=1).view(numtrials, 1)
    if not torch.all(gradnorm > 0):
      # Zero gradient: the direction is undefined (normalizing would give NaN).
      break
    with torch.no_grad():
      currz.add_(-grad / gradnorm, alpha=step_size)

In [None]:
# Points visited by the greedy descent, one (1, latent_d) tensor per step, starting at z0.
ptlist

In [None]:
import torch.nn.functional as F
import torch.optim as optim

def minimize_path_energy(model, z0, z1, init_path=None, numpts=10,numtrainiter=1000,learning_rate=0.001,verbose=False):
  """Optimize the interior points of a latent path z0 -> z1 to minimize decoded energy.

  The energy of a path z0 = p_0, p_1, ..., p_{m+1} = z1 is
  ``sum_i ||decode(p_{i+1}) - decode(p_i)||^2`` (summed over all outputs).
  Only the interior points are updated, with Adam.

  Args:
    model: object whose ``decode`` maps a batch of latents to outputs.
    z0, z1: endpoints of shape (1, latent_d).
    init_path: optional (m, latent_d) tensor of interior points (endpoints
      excluded) to start from. It is copied and never modified. When given,
      ``numpts`` is ignored and the path has m interior points.
    numpts: number of interior points of the default start, which is evenly
      spaced on the straight segment between z0 and z1.
    numtrainiter: number of Adam steps.
    learning_rate: Adam learning rate.
    verbose: print the energy every 1000 iterations.

  Returns:
    (path, energy): the (m + 2, latent_d) path including both endpoints, and
    the final energy as a detached scalar tensor.
  """
  # The endpoints are fixed, so their images are computed once, outside the graph.
  x0 = model.decode(z0).detach()
  x1 = model.decode(z1).detach()

  def path_energy(interior):
    imgs = model.decode(interior)
    # Differences between consecutive images along [x0, imgs..., x1].
    return torch.sum((torch.cat((x0, imgs)) - torch.cat((imgs, x1)))**2)

  if init_path is None:
    # Straight-line start, built on z0's device/dtype (no global device).
    t = torch.linspace(1 / (numpts + 1), numpts / (numpts + 1), numpts,
                       device=z0.device, dtype=z0.dtype)
    init_path = torch.outer(1 - t, torch.flatten(z0)) + torch.outer(t, torch.flatten(z1))
  # Optimize a private leaf copy. The caller's tensor is neither flagged as
  # requiring grad nor updated in place by Adam, and a non-leaf init_path
  # (on which requires_grad cannot be set) is accepted.
  interp_points = init_path.detach().clone().requires_grad_(True)

  optimizer = optim.Adam([interp_points], lr=learning_rate)
  for i in range(numtrainiter):
    loss = path_energy(interp_points)

    if verbose:
      with torch.no_grad():
        if i % 1000 == 0:
          print('loss',i,loss.to('cpu'))
    # Differentiate w.r.t. the path only; loss.backward() would also
    # accumulate gradients into every model parameter on each step.
    grad, = torch.autograd.grad(loss, [interp_points])
    interp_points.grad = grad
    optimizer.step()

  with torch.no_grad():
    loss = path_energy(interp_points)
  return torch.cat((z0, interp_points.detach(), z1)), loss.detach()

In [None]:
# Energy-minimized path between the current z0 and z1.
# NOTE(review): 4852 interior points presumably matches the number of greedy
# steps (len(ptlist)) for a like-for-like comparison — confirm.
path, loss = minimize_path_energy(model,z0,z1,numpts=4852,numtrainiter=10000,learning_rate=0.0001,verbose=True)

In [None]:
# Image-space length of the optimized path. This uses get_curve_length
# (defined earlier) because compute_path_length is only defined in the next
# cell, so the original call failed with a NameError when the notebook was run
# top to bottom. Both compute the same quantity.
get_curve_length(model, path)

In [None]:
def compute_path_length(model,path):
  """Return the image-space length of the polyline through the rows of ``path``.

  The length is ``sum_i ||decode(path[i+1]) - decode(path[i])||_2``, with the
  norm taken over all output entries of a point.

  Args:
    model: object whose ``decode`` maps a batch of latents to outputs.
    path: (n, latent_d) tensor of latent points.

  Returns:
    A scalar tensor; zero when ``path`` has fewer than two rows.
  """
  if path.shape[0] < 2:
    return path.new_zeros(())
  # One batched decode of every point, instead of decoding each interior
  # point twice, one row at a time.
  imgs = model.decode(path)
  imgs = imgs.reshape(imgs.shape[0], -1)
  return torch.linalg.norm(imgs[1:] - imgs[:-1], dim=1).sum()

In [None]:
# Number of points recorded by the greedy descent (including the start point z0).
len(ptlist)

In [None]:
# Image-space length of the greedy path (its points stacked into one tensor).
compute_path_length(model,torch.cat(ptlist).detach())

In [None]:
# Refine the greedy path by energy minimization, starting from its points.
# NOTE(review): ptlist[0] is a copy of z0, so z0 appears twice at the start of
# the path (a zero-length first segment). Also, `numpts` is not used when
# init_path is given.
minimize_path_energy(model,z0,z1,init_path=torch.cat(ptlist).detach(), numpts=numsteps,numtrainiter=20000,verbose=True)

# Deprecated code: Geodesics of networks at initialization using analytic formula.

This runs into numerical errors when differentiating the arccosine: its derivative $-1/\sqrt{1 - x^2}$ diverges as $x \to \pm 1$, e.g. for nearly parallel points. In addition, the analytic formula does not account for the random biases. I found it better to simply use the 'untrained' network instead.



In [None]:
#@title Deprecated code: analytic distance calculation
import torch.nn.functional as F
import torch.optim as optim

def get_pairwise_energy_random_net(z1,z2,L,eps=1e-6):
  """Analytic pairwise energy between z1 and z2 for a random L-layer ReLU network (deprecated).

  The input angle is propagated through L applications of the ReLU
  arc-cosine kernel angle map, theta -> acos(((pi - theta) cos theta +
  sin theta) / pi). The function returns
  ``|z1|^2 + |z2|^2 - 2 |z1| |z2| cos(theta_L)``. With L = 0 this is exactly
  ``|z1 - z2|^2``. It is presumably meant as the expected squared output
  distance of an untrained network with norm-preserving initialization —
  confirm before relying on it.

  Args:
    z1, z2: latent tensors (all elements are summed over).
    L: number of layers (applications of the angle map).
    eps: acos arguments are clamped to [-1 + eps, 1 - eps].

  Returns:
    Scalar tensor energy.
  """
  norm1 = torch.sum(torch.square(z1))
  norm2 = torch.sum(torch.square(z2))
  dotprod = torch.sum(z1 * z2)
  normprod = torch.sqrt(norm1 * norm2)

  # Rounding can push the argument slightly outside [-1, 1] (acos -> NaN), and
  # acos'(x) = -1 / sqrt(1 - x^2) diverges at +/-1, which produced NaN/inf
  # gradients for (nearly) parallel points. Clamping keeps both finite.
  def safe_acos(x):
    return torch.acos(torch.clamp(x, -1 + eps, 1 - eps))

  currangle = safe_acos(dotprod / normprod)
  for _ in range(L):
    currangle = safe_acos(((np.pi - currangle) * torch.cos(currangle) + torch.sin(currangle))/np.pi)
  cosgenangle = torch.cos(currangle)

  energy = norm1 + norm2 - 2 * normprod * cosgenangle
  return energy

In [None]:
#@title Deprecated code: minimize path energy random net

import torch.nn.functional as F
import torch.optim as optim

def minimize_path_energy_random_net(z0, z1, L, numpts=10,numtrainiter=1000,learning_rate=0.001,verbose=False):
  """Minimize the analytic random-network path energy between z0 and z1 (deprecated).

  The path starts as ``numpts`` evenly spaced interior points on the segment
  z0 -> z1. Adam then minimizes the sum of
  ``get_pairwise_energy_random_net`` over consecutive points
  z0 -> p_0 -> ... -> p_{numpts-1} -> z1.

  Args:
    z0, z1: endpoints of shape (1, latent_d).
    L: number of layers passed to get_pairwise_energy_random_net.
    numpts: number of interior points.
    numtrainiter: number of Adam steps.
    learning_rate: Adam learning rate.
    verbose: print the energy every 100 iterations.

  Returns:
    (path, energy): the (numpts + 2, latent_d) path including both endpoints,
    and the final energy as a detached scalar tensor.
  """
  # Linear initialization on z0's device/dtype (no global device), as a fresh leaf.
  t = torch.linspace(1 / (numpts + 1), numpts / (numpts + 1), numpts,
                     device=z0.device, dtype=z0.dtype)
  interp_points = (torch.outer(1 - t, torch.flatten(z0)) + torch.outer(t, torch.flatten(z1))).detach()
  interp_points.requires_grad = True

  def totloss(points):
    # Sum of pairwise energies along the path, accumulated out of place.
    loss = get_pairwise_energy_random_net(z0, points[0, :], L)
    for i in range(numpts - 1):
      loss = loss + get_pairwise_energy_random_net(points[i, :], points[i + 1, :], L)
    loss = loss + get_pairwise_energy_random_net(points[numpts - 1, :], z1, L)
    return loss

  optimizer = optim.Adam([interp_points], lr=learning_rate)
  for i in range(numtrainiter):
    # (The unused `model.decode(interp_points)` on the global model was removed.)
    loss = totloss(interp_points)

    if verbose:
      with torch.no_grad():
        if i % 100 == 0:
          print('loss',i,loss.to('cpu'))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  with torch.no_grad():
    loss = totloss(interp_points)
  return torch.cat((z0, interp_points.detach(), z1)), loss.detach()

In [None]:
#@title Deprecated: optimize paths using analytic distance for random net

import math
for latent_d in [5]:
  # for dist_parameter in [None, 0.5, 2, 4, 8, 16]:
  for dist_parameter in [None]:
    model = get_trained_vae('mnist',latent_d)

    # The path optimizer below handles a single (z0, z1) pair.
    numtrials = 1
    assert numtrials == 1

    if dist_parameter is None:
      # Independent endpoints from the N(0, I) prior.
      z0 = torch.randn([numtrials, model.latent_d]).to(device)
      z1 = torch.randn([numtrials, model.latent_d]).to(device)

    else:
      # Endpoints exactly dist_parameter apart, centered on an N(0, I/2)
      # midpoint. Previously z0 = midpoint + diff/2, which shifted the pair's
      # midpoint to midpoint + diff.
      diff = torch.randn([numtrials, model.latent_d]).to(device) * math.sqrt(2)
      diff = diff / torch.norm(diff,dim=1).view((numtrials,1)) * dist_parameter
      midpoint = torch.randn([numtrials, model.latent_d]).to(device) * math.sqrt(1/2)
      z0 = midpoint - diff / 2
      z1 = z0 + diff

    print(get_pairwise_energy_random_net(z0,z1,L=2))
    # Anomaly detection is slow, so enable it only around this optimization.
    # torch.autograd.set_detect_anomaly(True) left it on for the rest of the session.
    with torch.autograd.detect_anomaly():
      outputs = minimize_path_energy_random_net(z0,z1,L=1,numpts=10,learning_rate=0.01,verbose=True)
    print(outputs)

In [None]:
# George Stepaniants Code: Computing Jacobian Statistics at Sample Points on Manifold
# For latent points z ~ N(0, I), form G = Jg^T Jg, where Jg is the decoder
# Jacobian at z, and record its condition number and extreme eigenvalues.
model_example = get_trained_vae('untrained', 20)

trials = 1000
conds = []
lmaxs = []
lmins = []
for i in range(trials):
  print(i)
  z = torch.randn([model_example.latent_d]).to(device)
  Jg = torch.autograd.functional.jacobian(model_example.decode, z)
  G = torch.matmul(Jg.t(), Jg)

  cond = torch.linalg.cond(G).item()
  # G is symmetric and only latent_d x latent_d, so its full spectrum is
  # computed exactly with eigvalsh (ascending order). This replaces two
  # iterative, approximate lobpcg solves.
  eigvals = torch.linalg.eigvalsh(G)
  lmin = eigvals[0].item()
  lmax = eigvals[-1].item()

  conds.append(cond)
  lmaxs.append(lmax)
  lmins.append(lmin)

In [None]:
# One density histogram per statistic, each in its own numbered figure.
histograms = [
    (1, 'Condition Number of Jg^T * Jg', conds),
    (2, 'Maximum Eigenvalue of Jg^T * Jg', lmaxs),
    (3, 'Minimum Eigenvalue of Jg^T * Jg', lmins),
]
for fignum, title, values in histograms:
  plt.figure(fignum)
  plt.title(title)
  plt.hist(values, bins=30, density=True)

plt.figure(4)
plt.title('Minimum vs. Maximum Eigenvalues')
plt.scatter(lmins, lmaxs)
# Draw the y = x diagonal across the union of both axis ranges.
bounds = [plt.xlim(), plt.ylim()]
lims = [np.min(bounds), np.max(bounds)]
plt.plot(lims, lims, 'r-', alpha=0.75, zorder=0)
plt.xlabel('Minimum Eigenvalue')
plt.ylabel('Maximum Eigenvalue')

In [None]:
import statistics

# Mean and median of the sampled condition numbers.
print(np.mean(conds))
print(statistics.median(conds))