In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np
from pylab import cm
import torch
import torch.nn as nn
torch.manual_seed(2)

import sys
from google.colab import drive

from torch.distributions import Normal
from typing import List
from tqdm import tqdm
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Mount Google Drive (Colab) so the project code and saved checkpoints are reachable.
drive.mount('/content/drive')
# Relative path: resolves against the current working directory (presumably /content on Colab).
sys.path.append('drive/MyDrive/ATML_HT22')
# Star import from the project module; presumably this is where FlowModule comes from.
# TODO: import the needed names explicitly instead of `*`.
from models import *
# Automatically reload edited project modules before executing each cell.
%load_ext autoreload
%autoreload 2


Mounted at /content/drive


In [None]:
# Checkpoint helpers: save / restore model and optimizer state for continued training.
DRIVE_ROOT='/content/drive/MyDrive/'
def save_model(epoch, model, optimizer, name, root=None):
  """Save a training checkpoint to ``<root>/saved_models/<name>``.

  Args:
    epoch: training step/epoch counter, stored under the ``'steps'`` key.
    model: module whose ``state_dict()`` is stored under ``'model'``.
    optimizer: optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
    name: checkpoint file name inside the ``saved_models`` folder.
    root: base directory; defaults to ``DRIVE_ROOT``.

  Returns:
    The path the checkpoint was written to.
  """
  import os

  base = DRIVE_ROOT if root is None else root
  directory = os.path.join(base, "saved_models")
  # The folder may not exist yet (e.g. on a fresh Drive); torch.save does not create it.
  os.makedirs(directory, exist_ok=True)
  path = os.path.join(directory, name)

  torch.save({
    'steps': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
  }, path)
  return path

# Restore a checkpoint to continue training.
def load_model(model, optimizer, name, root=None, map_location=None):
  """Restore model and optimizer state from ``<root>/saved_models/<name>``.

  The checkpoint is expected to be a dict with ``'steps'``, ``'model'`` and
  ``'optimizer'`` keys, as written by ``save_model``.

  Args:
    model: module to load the ``'model'`` state dict into; switched to train mode.
    optimizer: optimizer to load the ``'optimizer'`` state dict into.
    name: checkpoint file name inside the ``saved_models`` folder.
    root: base directory; defaults to ``DRIVE_ROOT``.
    map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to open a GPU
      checkpoint on a CPU-only runtime).

  Returns:
    The ``'steps'`` value stored in the checkpoint.
  """
  import os

  base = DRIVE_ROOT if root is None else root
  path = os.path.join(base, "saved_models", name)
  checkpoint = torch.load(path, map_location=map_location)

  model.load_state_dict(checkpoint['model'])
  # The checkpoint is a dict: index it (calling it raised TypeError).
  optimizer.load_state_dict(checkpoint['optimizer'])

  model.train()
  # NOTE(review): looks carried over from a template whose model had a `controller`;
  # only switch it to train mode when the model actually has one.
  controller = getattr(model, 'controller', None)
  if controller is not None:
    controller.train()

  return checkpoint['steps']

In [None]:
def omega(x):
  """Logistic sigmoid 1 / (1 + exp(-x)), applied elementwise."""
  denominator = 1 + torch.exp(-x)
  return 1 / denominator
def w1(z):
  """sin(2*pi*z1/4), where z1 = z[:, 0] is the first coordinate of each row."""
  angle = 2 * torch.pi * z[:, 0] / 4
  return torch.sin(angle)
def w2(z):
  """Gaussian bump 3*exp(-1/2*((z1 - 1)/0.6)^2) of the first coordinate z1 = z[:, 0]."""
  standardized = (z[:, 0] - 1) / 0.6
  return 3 * torch.exp(-1/2 * standardized**2)
def w3(z):
  """Scaled sigmoid step 3*omega((z1 - 1)/0.3) of the first coordinate z1 = z[:, 0]."""
  shifted = (z[:, 0] - 1) / 0.3
  return 3 * omega(shifted)

def energy1(z):
  """Energy U1 of a batch of 2-D points (one value per row of z).

  U1(z) = 1/2*((||z|| - 2)/0.4)^2
          - log(exp(-1/2*((z1 - 2)/0.6)^2) + exp(-1/2*((z1 + 2)/0.6)^2))
  with z1 = z[:, 0] and the norm taken over the last dimension.
  """
  ring = 1/2*((torch.linalg.norm(z, dim=-1)-2)/0.4)**2
  # log(exp(a) + exp(b)) via logsumexp instead of log(exp(a) + exp(b) + eps):
  # the eps form saturates at -log(eps) once both exponentials are tiny,
  # which flattens the energy (and its gradient) far from the two modes.
  exponents = torch.stack([
    -1/2*((z[:, 0]-2)/0.6)**2,
    -1/2*((z[:, 0]+2)/0.6)**2,
  ], dim=-1)
  return ring - torch.logsumexp(exponents, dim=-1)

def log_density1(z):
  """Unnormalized log-density log p(z) = -U1(z)."""
  return -energy1(z)

def energy2(z):
  """Energy U2(z) = 1/2*((z2 - w1(z))/0.4)^2 with z2 = z[:, 1]; one value per row."""
  residual = (z[:, 1] - w1(z)) / 0.4
  return 1/2 * residual**2

def log_density2(z):
  """Unnormalized log-density log p(z) = -U2(z)."""
  return -energy2(z)

def energy3(z):
  """Energy U3 of a batch of 2-D points (one value per row of z).

  U3(z) = -log(exp(-1/2*((z2 - w1(z))/0.35)^2) + exp(-1/2*((z2 - w1(z) + w2(z))/0.35)^2))
  with z2 = z[:, 1].
  """
  offset = z[:, 1] - w1(z)
  # logsumexp instead of log(exp(a) + exp(b) + eps): with the eps form, any point
  # with |z2 - w1| larger than about 2 hit the -log(eps) floor and got zero gradient.
  exponents = torch.stack([
    -1/2*(offset/0.35)**2,
    -1/2*((offset + w2(z))/0.35)**2,
  ], dim=-1)
  return -torch.logsumexp(exponents, dim=-1)

def log_density3(z):
  """Unnormalized log-density log p(z) = -U3(z)."""
  return -energy3(z)

def energy4(z):
  """Energy U4 of a batch of 2-D points (one value per row of z).

  U4(z) = -log(exp(-1/2*((z2 - w1(z))/0.4)^2) + exp(-1/2*((z2 - w1(z) + w3(z))/0.35)^2))
  with z2 = z[:, 1].
  """
  offset = z[:, 1] - w1(z)
  # logsumexp instead of log(exp(a) + exp(b) + eps): the eps form saturates at
  # -log(eps) away from both branches, giving a flat energy with zero gradient.
  exponents = torch.stack([
    -1/2*(offset/0.4)**2,
    -1/2*((offset + w3(z))/0.35)**2,
  ], dim=-1)
  return -torch.logsumexp(exponents, dim=-1)

def log_density4(z):
  """Unnormalized log-density log p(z) = -U4(z)."""
  return -energy4(z)

def log_density5(z):
  """Standard-normal log-density of each row of z (per-coordinate log-probs summed over dim 1)."""
  standard_normal = Normal(0., 1.)
  per_coordinate = standard_normal.log_prob(z)
  return per_coordinate.sum(dim=1)

In [None]:
def log_N(z):
  """Log-density of each row of z under N(0, I): per-coordinate log-probs summed over dim 1."""
  base = Normal(loc=0., scale=1.)
  return base.log_prob(z).sum(dim=1)
def fit_posterial_approx(flowModel: FlowModule, optimizer, log_density_fn, dims:List, T:int):
  """Fit a flow to an unnormalized target density by stochastic reverse-KL minimization.

  Each step draws base samples z0 ~ N(0, I) with shape ``dims`` and pushes them
  through the flow. It then minimizes the Monte-Carlo estimate of
      mean(log N(z0) - log_det_sum - log p(zK)),
  which is KL(q_K || p) up to the constant log-normalizer of p.

  Args:
    flowModel: flow called as ``flowModel(z0)`` and returning ``(zK, log_det_sum)``.
    optimizer: optimizer over the flow's parameters; stepped once per iteration.
    log_density_fn: unnormalized target log-density, one value per row of zK.
    dims: shape of each base-sample batch, e.g. ``[batch_size, D]``.
    T: number of optimization steps.
  """
  for _ in tqdm(range(T)):
    z0 = torch.randn(dims).to(DEVICE)
    zk, log_det_sum = flowModel(z0)

    log_pz = log_density_fn(zk).to(DEVICE)
    # Flatten the log-det so a (B, 1) result cannot broadcast against the (B,)
    # per-sample terms into a (B, B) matrix (same mean, but B**2 memory).
    log_qk = log_N(z0) - log_det_sum.reshape(-1)
    loss = (log_qk - log_pz).mean()  # reverse KL (R-KL)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


In [None]:
def plot_true_density(true_log_density_fn, ax, axis_min_max=np.array([[-4., 4.], [-4., 4.]]), N=1000):
  """Draw exp(true_log_density_fn) on an N x N grid over ``axis_min_max`` onto ``ax``.

  Args:
    true_log_density_fn: maps an (N*N, 2) float tensor of points to one log-density per point.
    ax: matplotlib Axes to draw on.
    axis_min_max: ``[[x_min, x_max], [y_min, y_max]]`` plotting window.
    N: grid resolution per axis.
  """
  (x_min, x_max), (y_min, y_max) = axis_min_max
  x1, x2 = np.meshgrid(np.linspace(x_min, x_max, N), np.linspace(y_min, y_max, N))

  tx = torch.from_numpy(np.stack([x1.ravel(), x2.ravel()], axis=1)).float().to(DEVICE)
  with torch.no_grad():
    pz = torch.exp(true_log_density_fn(tx))
  # meshgrid puts y_min in row 0, so the image must be drawn with origin='lower';
  # the default 'upper' origin mirrors the density in z2.
  ax.imshow(pz.view(N, N).cpu().numpy(), extent=[x_min, x_max, y_min, y_max], origin='lower', cmap=cm.jet)
  ax.axis('off')
  ax.set_aspect(1)



def plot_flow_density(flow_model, ax, axis_min_max=np.array([[-4., 4.], [-4., 4.]]), N=1000, scale=5):
  """Plot the density of flow_model's outputs when its inputs follow N(0, I).

  An N x N grid of base points z0 over ``axis_min_max * scale`` is pushed through
  the flow. Each output point zK is coloured by
  q_K(zK) = exp(log N(z0) - log_det_sum), and the view is clipped to ``axis_min_max``.

  Args:
    flow_model: flow called as ``flow_model(z0)`` and returning ``(zK, log_det_sum)``.
    ax: matplotlib Axes to draw on.
    axis_min_max: ``[[x_min, x_max], [y_min, y_max]]`` plotting window.
    N: base-grid resolution per axis.
    scale: how much wider than the view the base grid extends, so points the
      flow moves into the view are still represented.
  """
  (x_min, x_max), (y_min, y_max) = axis_min_max
  x1, x2 = np.meshgrid(np.linspace(x_min*scale, x_max*scale, N), np.linspace(y_min*scale, y_max*scale, N))
  zo = torch.from_numpy(np.stack([x1.ravel(), x2.ravel()], axis=1)).float().to(DEVICE)

  # Plotting only: avoid building an autograd graph for N*N points through every layer.
  with torch.no_grad():
    zk, log_det_sum = flow_model(zo)
    # Flatten so a (B, 1) log-det cannot broadcast against the (B,) base log-density.
    qk = torch.exp(log_N(zo) - log_det_sum.reshape(-1))

  zk = zk.cpu().numpy()
  qk = qk.cpu().numpy()
  # Plot zK in its true orientation (y up), consistent with plot_true_density's origin='lower'.
  ax.pcolormesh(zk[:, 0].reshape(N, N), zk[:, 1].reshape(N, N), qk.reshape(N, N), rasterized=True, cmap=cm.jet)
  ax.axis('off')
  ax.set_xlim(x_min, x_max)
  ax.set_ylim(y_min, y_max)
  ax.set_aspect(1)


In [None]:
import os
# Create the checkpoint folder where save_model writes (<DRIVE_ROOT>/saved_models),
# not ./saved_models in the current working directory.
os.makedirs(os.path.join(DRIVE_ROOT, 'saved_models'), exist_ok=True)

In [None]:
def train(flow_layers_num, optim_lr, ITERATIONS,batch_size,true_log_density, flow_type='Planar', encoder_out_dim=0):
  """Build a FlowModule and fit it to ``true_log_density`` with fit_posterial_approx (reverse KL).

  Args:
    flow_layers_num: number of flow layers K.
    optim_lr: RMSprop learning rate (momentum fixed at 0.9).
    ITERATIONS: number of optimization steps.
    batch_size: number of base samples drawn per step.
    true_log_density: unnormalized target log-density on 2-D points.
    flow_type: flow layer type forwarded to FlowModule.
    encoder_out_dim: forwarded to FlowModule.

  Returns:
    ``(model, optimizer)`` after training.
  """
  D = 2  # the target energies in this notebook are functions of 2-D points

  model = FlowModule(D, num_layers=flow_layers_num, flow_type=flow_type, encoder_out_dim=encoder_out_dim)
  model.to(DEVICE)
  optimizer = torch.optim.RMSprop(model.parameters(), lr=optim_lr, momentum=0.9)
  fit_posterial_approx(model, optimizer, true_log_density, [batch_size, D], ITERATIONS)
  return model, optimizer

In [None]:


ITERATIONS=500000
batch_size=100*100
optim_lr = 1e-5
flow_lengths=[2, 8, 32]
log_densities =[log_density1, log_density2, log_density3, log_density4]

col=4
row =4

# One row per target density: the true density, then one panel per flow length K.
fig, axes =plt.subplots(row, col, figsize=(8, 8))  # square panels

for (i, ax) in enumerate(axes.flatten()):
  # Axes are flattened row-major, so the row (target density) index is i // col.
  true_log_density = log_densities[i//col]

  if(i%col==0):
    plot_true_density(true_log_density, ax)
    if i==0:
      ax.set_title('exp[-U(z)]')
  else:
    flow_layers_num = flow_lengths[i%col-1]

    try:

      flow_model, optimiser = train(flow_layers_num, optim_lr, ITERATIONS,batch_size,true_log_density, flow_type='Planar', encoder_out_dim=0)
      save_model(ITERATIONS, flow_model,optimiser, f'v2_U={i}-K={flow_layers_num}-lr={optim_lr}-iternations={ITERATIONS}')
      # NOTE(review): the flow panels stay empty while this line is commented out.
      # plot_flow_density(flow_model, ax)
      if(i<col):
        ax.set_title(f'K = {flow_layers_num}')
    except Exception as err:
      # Best-effort: report the failing panel and continue with the remaining ones.
      print(f' Error on K ={flow_layers_num}')
      print(err)


# plt.savefig('/content/drive/MyDrive/ATML_HT22/figure3.pdf') #/content/drive/MyDrive/ATML_HT22/figure3.pdf
plt.show()

In [None]:
# Figure 1: planar flows with K = 1, 2, 10 layers applied to a standard-normal and to a uniform base density

In [None]:
# Normal distribution: base density and its push-forward through untrained planar flows

N = 80

fig, axes = plt.subplots(1, 4, figsize=(8, 2))
fig.tight_layout()

axis_min_max = np.array([[-4., 4.], [-4., 4.]])
(x_min, x_max), (y_min, y_max) = axis_min_max

x1, x2 = np.meshgrid(np.linspace(x_min, x_max, N), np.linspace(y_min, y_max, N))
tx = torch.stack([torch.from_numpy(x1).view(-1), torch.from_numpy(x2).view(-1)]).T

density = Normal(0., 1.).log_prob(tx).sum(1).exp()

# meshgrid puts y_min in row 0, so draw with origin='lower' to keep z2 pointing up.
axes[0].imshow(density.view(N, N), extent=[x_min, x_max, y_min, y_max], origin='lower', cmap=cm.jet.reversed())
axes[0].set_title('q0')

# The base grid covers a window `scale` times wider than the view, so points the
# flow moves into the view are still represented.
scale = 4
new_N = N*scale
new_axis_min_max = axis_min_max*scale

flow_type = 'Planar'
latent_dim = 2

# The base grid does not depend on K, so build it once.
x1_, x2_ = np.meshgrid(np.linspace(new_axis_min_max[0][0], new_axis_min_max[0][1], new_N),
                       np.linspace(new_axis_min_max[1][0], new_axis_min_max[1][1], new_N))
tx_ = torch.stack([torch.from_numpy(x1_).view(-1), torch.from_numpy(x2_).view(-1)]).T.float().to(DEVICE)
# Change of variables: log q_K(z_K) = log q_0(z_0) - log_det_sum, so the base
# log-density is evaluated at the flow INPUTS z_0, not at the outputs.
log_q0 = log_N(tx_)

for i, K in enumerate([1, 2, 10]):
  model = FlowModule(latent_dim, num_layers=K, flow_type=flow_type)
  model.to(DEVICE)

  # Plotting only: no autograd graph needed.
  with torch.no_grad():
    new_tx, log_det_sum = model(tx_)
    new_density = torch.exp(log_q0 - log_det_sum.view(-1))

  j = i + 1
  axes[j].pcolormesh(new_tx[:, 0].cpu().numpy().reshape(new_N, new_N),
                     new_tx[:, 1].cpu().numpy().reshape(new_N, new_N),
                     new_density.cpu().numpy().reshape(new_N, new_N),
                     rasterized=True, cmap=cm.jet.reversed())
  axes[j].set_xlim(x_min, x_max)
  axes[j].set_ylim(y_min, y_max)
  axes[j].set_title(f'K = {K}')
plt.show()

In [None]:
# Debug check: shape of the standard-normal log-density of `new_tx`, the last flow
# output left over from the Normal-distribution cell (log_N sums over dim 1, so one value per row).
log_N(new_tx).shape

In [None]:

# Uniform distribution: base density and its push-forward through untrained planar flows

N = 100

fig, axes = plt.subplots(1, 4, figsize=(8, 2))
fig.tight_layout()

axis_min_max = np.array([[-4., 4.], [-4., 4.]])
(x_min, x_max), (y_min, y_max) = axis_min_max

x1, x2 = np.meshgrid(np.linspace(x_min, x_max, N), np.linspace(y_min, y_max, N))
tx = torch.stack([torch.from_numpy(x1).view(-1), torch.from_numpy(x2).view(-1)]).T

# Constant density, normalized over the grid points.
density = torch.ones_like(tx.sum(1))/tx.shape[0]

axes[0].imshow(density.view(N, N), extent=[x_min, x_max, y_min, y_max], origin='lower', cmap=cm.jet.reversed())
axes[0].set_title('q0')

# The base grid covers a window `scale` times wider than the view (at twice the
# per-unit resolution), so points the flow moves into the view are still represented.
scale = 4
new_N = N*scale*2
new_axis_min_max = axis_min_max*scale

flow_type = 'Planar'
latent_dim = 2

# The base grid does not depend on K, so build it once.
x1_, x2_ = np.meshgrid(np.linspace(new_axis_min_max[0][0], new_axis_min_max[0][1], new_N),
                       np.linspace(new_axis_min_max[1][0], new_axis_min_max[1][1], new_N))
tx_ = torch.stack([torch.from_numpy(x1_).view(-1), torch.from_numpy(x2_).view(-1)]).T.float().to(DEVICE)
# Uniform base over the input grid: every base point has log-density -log(#points).
log_q0 = torch.log(torch.ones_like(tx_.sum(1))/tx_.shape[0])

for i, K in enumerate([1, 2, 10]):
  model = FlowModule(latent_dim, num_layers=K, flow_type=flow_type)
  model.to(DEVICE)

  # Plotting only: no autograd graph needed.
  with torch.no_grad():
    new_tx, log_det_sum = model(tx_)
    new_density = torch.exp(log_q0 - log_det_sum.view(-1))

  j = i + 1
  # Plot the outputs in their true orientation (y up).
  axes[j].pcolormesh(new_tx[:, 0].cpu().numpy().reshape(new_N, new_N),
                     new_tx[:, 1].cpu().numpy().reshape(new_N, new_N),
                     new_density.cpu().numpy().reshape(new_N, new_N),
                     rasterized=True, cmap=cm.jet.reversed())
  axes[j].set_xlim(x_min, x_max)
  axes[j].set_ylim(y_min, y_max)
  axes[j].axis('off')
  axes[j].set_title(f'K = {K}')
plt.show()