In [None]:
#package imports

import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions import MultivariateNormal
import scipy.optimize
import scipy.stats
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
from typing import Tuple
from typing import Callable
from mpl_toolkits.mplot3d import Axes3D
from pathlib import Path
import timeit as time

In [None]:
#the planar flow class chains the planar flow layers using nn.Sequential

class PlanarFlow(nn.Module):
    """Stack of K planar transforms applied one after another.

    forward returns the transformed points together with the sum of the
    log Jacobian terms of all layers, each evaluated at that layer's input.
    """

    def __init__(self, K: int = 6): #K is the number of planar flow layers chained together
        super().__init__()

        #build the K planar transforms and keep them in a plain python list
        self.layers = [PlanarTransform() for _ in range(K)]

        #wrapping the list in nn.Sequential registers the layers as submodules of this module
        self.model = nn.Sequential(*self.layers)

    def forward(self, z: Tensor) -> Tuple[Tensor, float]:

        #running total of the per-layer log Jacobian terms (stays the int 0 when there are no layers)
        total_log_det = 0

        for transform in self.layers:

            #the Jacobian term is evaluated at the layer input, so it comes before z is overwritten
            total_log_det += transform.log_det_J(z)

            #push the points through the layer
            z = transform(z)

        return z, total_log_det

In [None]:
#the planar transform class contains the functions for the planar flow

class PlanarTransform(nn.Module):
    """One planar flow layer on 2D points: f(z) = z + u*tanh(z w^T + b)."""

    def __init__(self):
        super().__init__()

        #u and w are 1x2, b has length 1; all start as draws from a normal with mean 0 and std 0.1
        #(the registration order u, w, b determines the order returned by parameters())
        self.u = nn.Parameter(torch.randn(1, 2).normal_(0, 0.1))
        self.w = nn.Parameter(torch.randn(1, 2).normal_(0, 0.1))
        self.b = nn.Parameter(torch.randn(1).normal_(0, 0.1))

    def _enforce_invertibility(self) -> None:
        #u is only corrected when the condition u.w >= -1 is violated
        if torch.mm(self.u, self.w.T) < -1:
            self.get_u_hat()

    def forward(self, z: Tensor) -> Tensor:
        """Apply the planar map to every row of z."""

        #make sure the layer is invertible before using u
        self._enforce_invertibility()

        #argument of the tanh, one value per point
        pre_activation = torch.mm(z, self.w.T) + self.b

        return z + self.u*torch.tanh(pre_activation)

    def log_det_J(self, z: Tensor) -> Tensor:
        """Return log(1e-10 + |det J|) of the layer at every row of z, as a 1 x N tensor."""

        #make sure the layer is invertible before using u
        self._enforce_invertibility()

        #psi is the derivative of tanh(z w^T + b) with respect to z, one row per point
        activation = torch.tanh(torch.mm(z, self.w.T) + self.b)
        psi = (1 - activation**2)*self.w

        #the small constant keeps the log finite when the determinant vanishes
        det = 1 + torch.mm(self.u, psi.T)
        return torch.log(1e-10 + det.abs())

    def get_u_hat(self) -> None:
        """Overwrite u in place so that u.w equals -1 + log(1 + exp(u.w))."""

        w_dot_u = torch.mm(self.u, self.w.T)
        shifted = -1 + torch.log(1 + torch.exp(w_dot_u))
        w_sq_norm = torch.norm(self.w, p=2, dim=1)**2

        #move u along w; assigning through .data replaces the parameter value directly
        self.u.data = (self.u + (shifted - w_dot_u)*self.w/w_sq_norm)

In [None]:
#the target distribution class holds the definitions for the 2D target distribution used in the examples

class TargetDistribution:
    """Callable wrapper around a named, annealed 2D target density.

    Calling an instance with an (N, 2) tensor returns the density values as a
    length-N tensor.
    """

    def __init__(self, name: str, t: float, a: float): #t is the annealing value, a is the mean for GMM

        #look up the density function for the requested target distribution
        self.func = self.get_target_distribution(name, t, a)

    def __call__(self, z: Tensor) -> Tensor:

        #evaluate the target function at the points z
        return self.func(z)

    @staticmethod
    def get_target_distribution(name: str, t: float, a: float) -> Callable[[Tensor], Tensor]:
        """Return the density function registered under name.

        Raises:
            ValueError: if name is not a known target distribution.
        """

        #target function in 2 dimensions
        if name == "Bimodal2D":

            def Bimodal2D(z):

                #two equally weighted bumps centred at (-(1+a), a) and (1+a, a);
                #the annealing value t multiplies the exponent of both bumps
                scale = 0.5*16/(np.pi)
                left = torch.exp(-16*t*((z[:,0]+1+a)**2 + (z[:,1]-a)**2))
                right = torch.exp(-16*t*((z[:,0]-1-a)**2 + (z[:,1]-a)**2))
                return scale*left + scale*right

            return Bimodal2D

        #fail here instead of returning None and breaking later when the instance is called
        raise ValueError(f"unknown target distribution: {name!r}")

In [None]:
#free energy loss function used

class VariationalLoss(nn.Module):
    """Free energy loss: mean of log q0(z0) - log p(z) - sum of log Jacobian terms."""

    def __init__(self, distribution: TargetDistribution, mean: int, std: int):
        super().__init__()

        #the target distribution
        self.distr = distribution

        #the starting distribution: 2D normal with mean (mean, mean) and covariance std^2 * I
        loc = mean*torch.ones(2)
        covariance = (std**2)*torch.eye(2)
        self.base_distr = MultivariateNormal(loc, covariance)

    def forward(self, z0: Tensor, z: Tensor, sum_log_det_J: float) -> float:

        #log density of the starting distribution at the initial points z0
        log_q0 = self.base_distr.log_prob(z0)

        #log of the target density at the final points z (the offset avoids log(0))
        log_target = torch.log(self.distr(z)+1e-20)

        #average the free energy over the batch
        free_energy = log_q0 - log_target - sum_log_det_J
        return free_energy.mean()

In [None]:
#function definitions required for plotting purposes

#the starting function
def StartingFunction(x, mean, std):
    """Density at x of the 2D normal with mean (mean, mean) and covariance std^2 * I."""
    center = [mean, mean]
    covariance = [[std**2, 0], [0, std**2]]
    return scipy.stats.multivariate_normal.pdf(x, center, covariance)

#the jacobian calculation using tanh
def Jacobian(x, P):
    """Absolute Jacobian determinant of one planar layer at every row of x.

    P holds the layer parameters in the order [u0, u1, w0, w1, b]; the value for
    a point x_j is |1 + (u.w)*(1 - tanh(w.x_j + b)^2)|.

    Returns a 1D array with one entry per row of x.
    """

    u = np.array([P[0], P[1]])
    w = np.array([P[2], P[3]])
    b = P[4]

    #view the input as rows of 2D points so all points are handled in one array expression
    points = np.asarray(x, dtype=float).reshape(-1, 2)

    #derivative of tanh at w.x + b, one value per point
    slope = 1 - np.tanh(points@w + b)**2

    return np.abs(1 + (u@w)*slope)

#the planar map h for a single point (forward direction); computeInverse below solves h(x) = z for x
def H(x, P):
    """Planar map x + u*tanh(w.x + b) for one point x, with P = [u0, u1, w0, w1, b]."""
    u, w, b = np.array([P[0], P[1]]), np.array([P[2], P[3]]), P[4]

    #scalar argument of the tanh for this point
    pre_activation = w.T@x + b

    return x + u*np.tanh(pre_activation)

def computeInverse(z, P):
    """Numerically invert the planar map H for every row of z.

    Each row z[j] is inverted by solving H(x, P) = z[j] with scipy.optimize.fsolve,
    started from the elementwise guess (z - u*b)/(w + u*w).
    """
    u, w, b = np.array([P[0], P[1]]), np.array([P[2], P[3]]), P[4]

    #starting guesses, one row per point; each row is overwritten by its solved root
    roots = (z - u*b)/(w + u*w)

    for row in range(np.shape(roots)[0]):
        target = z[row]

        #residual of the planar map against the point being inverted
        roots[row] = scipy.optimize.fsolve(lambda x: H(x, P) - target, roots[row])

    return roots

In [None]:
#the definitions for plotting

#global matplotlib defaults: Arial font and extra-small tick labels on both axes
plt.rcParams.update({
    'font.family': 'Arial',
    'xtick.labelsize': 'x-small',
    'ytick.labelsize': 'x-small',
})

def plot_points(num, lim=5):
    """Build a num x num grid on [-lim, lim]^2.

    Returns:
        X, Y: the two mesh tensors, each of shape (num, num).
        XY: the grid points as a (num*num, 2) tensor, one point per row.
        shape: the shape of X, for reshaping per-point values back onto the grid.
    """

    #create the points in x and y directions for plotting
    x = y = torch.linspace(-lim, lim, num)

    #create the mesh grid; 'ij' indexing makes X vary along rows and Y along columns
    X, Y = torch.meshgrid(x, y, indexing="ij")

    #flatten the mesh into rows of (x, y) points without leaving torch
    shape = X.shape
    XY = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=1)

    return X, Y, XY, shape
    
def plot_density(density, num=500, lim=5, cmap=cm.magma):
    """Show density on a num x num grid over [-lim, lim]^2 as a 3D surface and a 2D colour mesh."""

    X, Y, XY, shape = plot_points(num, lim)

    #evaluate the density on the flattened grid and put the values back onto the grid
    Z = density(XY).reshape(shape).detach().numpy()

    #one figure with the surface on the left and the colour mesh on the right
    fig = plt.figure(figsize=(11, 4), dpi=300)
    surface_ax = fig.add_subplot(121, projection='3d')
    mesh_ax = fig.add_subplot(122)

    #3d plot
    surface_ax.plot_surface(X, Y, Z, cmap=cmap, antialiased=True)

    #2d plot
    mesh = mesh_ax.pcolormesh(X, Y, Z, cmap=cmap, shading='auto')
    plt.colorbar(mesh)

    #same axis labels on both panels
    for panel in (surface_ax, mesh_ax):
        panel.set_xlabel('x')
        panel.set_ylabel('y')

    plt.suptitle('Target PDF')
    #plt.savefig(f'2D Target PDF.png', bbox_inches='tight')
    plt.show()

def compute_opt(mean, std, num, lim):
    """Evaluate the density of the trained flow on a num x num grid over [-lim, lim]^2.

    Reads the global `model`. The grid points are pulled back through the
    layers from last to first with computeInverse, dividing by each layer's
    Jacobian, and the result is multiplied by the starting density.

    Returns:
        X, Y: the mesh tensors from plot_points.
        Z: the density values as a (num, num) numpy array.

    Raises:
        ValueError: if the parameter count of `model` is not a multiple of 5.
    """

    #put the parameters into one float64 vector
    #assumes every layer contributes 5 values in the order u (2), w (2), b (1)
    flat_params = [param.detach().numpy().reshape(-1) for param in model.parameters()]
    P = np.concatenate(flat_params).astype(float) if flat_params else np.zeros(0)

    #take the number of layers from the model itself so it cannot disagree with a stale global
    if P.size % 5 != 0:
        raise ValueError(f"expected 5 parameters per planar layer, got {P.size} in total")
    num_layers = P.size // 5

    #points
    X, Y, XY, shape = plot_points(num, lim)
    XY = XY.detach().numpy().astype(float)

    #compute the optimized pdf
    Z = np.ones(np.shape(XY)[0])

    #undo the layers in reverse order, accumulating the inverse Jacobian factors
    for layer in reversed(range(num_layers)):
        layer_params = P[5*layer:5*layer+5]
        XY = computeInverse(XY, layer_params)
        Z = Z/Jacobian(XY, layer_params)

    #XY now holds the pulled-back points, where the starting density is evaluated
    Z = Z*StartingFunction(XY, mean, std)
    Z = Z.reshape(num, num)

    return X, Y, Z
    
#two different ways to plot the optimized pdf (using a grid vs sample points)
def plot_optimized(mean, std, num=100, lim=5, cmap=cm.magma):
    """Show the density returned by compute_opt as a 3D surface and a 2D colour mesh."""

    #compute the optimized pdf on the grid
    X, Y, Z = compute_opt(mean, std, num, lim)

    #one figure with the surface on the left and the colour mesh on the right
    fig = plt.figure(figsize=(11, 4), dpi=300)
    surface_ax = fig.add_subplot(121, projection='3d')
    mesh_ax = fig.add_subplot(122)

    #3d plot
    surface_ax.plot_surface(X, Y, Z, cmap=cmap, antialiased=True)

    #2d plot
    mesh = mesh_ax.pcolormesh(X, Y, Z, cmap=cmap, shading='auto')
    plt.colorbar(mesh)

    #same axis labels on both panels
    for panel in (surface_ax, mesh_ax):
        panel.set_xlabel('x')
        panel.set_ylabel('y')

    plt.suptitle('Optimized PDF')
    #plt.savefig(f'2D Optimized PDF.png', bbox_inches='tight')
    plt.show()
    
def plot_model_points(model, mean, std, num=500, lim=5, cmap=cm.magma):
    """Plot the density of the flow at the images of a grid of starting points.

    The grid spans [-5*lim, 5*lim]^2 while the axes are limited to [-lim, lim].
    """

    #get points
    _, _, grid_points, _ = plot_points(num, 5*lim)

    #push the grid through the model
    warped, sum_log_jacobians = model(grid_points)

    #density at the warped points: starting log density minus the summed log Jacobian terms
    base_distr = MultivariateNormal(mean*torch.ones(2), (std**2)*torch.eye(2))
    start_log_prob = base_distr.log_prob(grid_points).reshape(1,num**2)
    prob = torch.exp(start_log_prob - sum_log_jacobians)

    #reshape the warped coordinates and the density back onto the grid
    X = warped[:, 0].detach().reshape(num, num)
    Y = warped[:, 1].detach().reshape(num, num)
    prob = prob.detach().reshape(num, num)

    #plot
    plt.figure(figsize=(4, 3), dpi=300)
    mesh = plt.pcolormesh(X, Y, prob, cmap=cmap, shading='auto')
    plt.colorbar(mesh)
    plt.xlim(-lim, lim)
    plt.ylim(-lim, lim)
    plt.xlabel('x')
    plt.ylabel('y')

    plt.title('Optimized PDF')
    #plt.savefig(f'2D Optimized PDF.png', bbox_inches='tight')
    plt.show()
    
def plot_tvals(tvals):
    """Line plot of the collected t values against their 1-based position in tvals."""

    #x positions 1..len(tvals)
    num_steps = len(tvals)
    step_index = np.linspace(1, num_steps, num_steps)

    #plot
    plt.figure(figsize=(6, 4), dpi=300)
    plt.plot(step_index, tvals, 'darkmagenta', lw=1)
    plt.title('Annealing Schedule')
    #plt.savefig(f'Annealing Schedule.png', bbox_inches='tight')
    plt.show()

In [None]:
#AdaAnn scheduler set-up
#trains the planar flow while t is raised from t0 to 1 with an adaptive step dt,
#then runs a longer final stage at t = 1 with a decaying learning rate

#list to collect t values
tvals = []

#choose the starting distribution parameters
mean_starting = 0
std_starting = 2

#choose the target distribution and set a
target_distr = "Bimodal2D"
a = 0

#plot target density
density_target = TargetDistribution(target_distr, t=1, a=a)
plot_density(density_target, lim=4)

#set the parameters
flow_length = 75       #number of planar flow layers K
lr = 0.0005            #learning rate for the optimizer
t0 = 0.01              #starting t value
tol = 0.01             #KL divergence tolerance for AdaAnn
M = 1000               #number of sample points to compute step size
dt = 0                 #no step before the first stage, so training starts at t = t0

#set the number of samples in each batch
N = 100                #at t0 and each annealing step
N_1 = 1000             #at t = 1

#set the number of iterations
T_0 = 500              #at t0
T = 5                  #at each annealing step
T_1 = 5000             #at t = 1

#create the model and optimizer using Adam
model = PlanarFlow(K=flow_length)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

#start time
start = time.default_timer()

#optimization using AdaAnn
t = t0
while t < 1:

    #new t value
    t = min(1, t + dt)
    tvals = np.concatenate([tvals, np.array([t])])

    #number of iterations and batch size at each annealing step
    num_iter = T
    batch_size = N

    #run the longer warm-up of T_0 iterations at t0
    if t == t0:
        num_iter = T_0

    #update parameters at t = 1 and include a learning rate scheduler
    if t == 1:
        num_iter = T_1
        batch_size = N_1
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.75)

    #update the target density and loss function with current t value
    density = TargetDistribution(target_distr, t=t, a=a)
    bound = VariationalLoss(density, mean=mean_starting, std=std_starting)

    #train the model
    for iter_num in range(1, num_iter + 1):

        #get the batches from starting distribution
        batch = torch.zeros((batch_size,2)).normal_(mean=mean_starting, std=std_starting)

        #pass the batch through the planar flow model
        zk, log_jacobians = model(batch)

        #compute the loss
        loss = bound(batch, zk, log_jacobians)

        #train the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #apply a learning rate scheduler when t = 1
        if t == 1:
            scheduler.step()

    #compute the dt value using M points: tol divided by the standard deviation of the
    #log target density (at t = 1) over samples pushed through the current flow
    density_dt = TargetDistribution(target_distr, t=1, a=a)
    with torch.no_grad():
        zk, log_jacobians = model(torch.zeros((M,2)).normal_(mean=mean_starting, std=std_starting))
        log_qk = torch.log(density_dt(zk) + 1e-10)

        #keep dt as a plain python float so t stays a python number
        dt = (tol/torch.sqrt(log_qk.var())).item()

#compute time
end = time.default_timer()
opt_time = end - start

#plot approximation and annealing schedule
plot_model_points(model, mean=mean_starting, std=std_starting)
plot_tvals(tvals)

In [None]:
#linear scheduler set-up
#trains the planar flow while t is raised from t0 to 1 in constant steps of eps,
#then runs a longer final stage at t = 1 with a decaying learning rate

#list to collect t values
tvals = []

#choose the starting distribution parameters
mean_starting = 0
std_starting = 2

#choose the target distribution and set a
target_distr = "Bimodal2D"
a = 0

#plot target density
density_target = TargetDistribution(target_distr, t=1, a=a)
plot_density(density_target, lim=4)

#set the parameters
flow_length = 75       #number of planar flow layers K
lr = 0.0005            #learning rate for the optimizer
t0 = 0.01              #starting t value
eps = 1/10000          #set constant step size
dt = 0                 #no step before the first stage, so training starts at t = t0

#set the number of samples in each batch
N = 100                #at t0 and each annealing step
N_1 = 1000             #at t = 1

#set the number of iterations
T_0 = 500              #at t0
T = 1                  #at each annealing step
T_1 = 5000             #at t = 1

#create the model and optimizer using Adam
model = PlanarFlow(K=flow_length)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

#start time
start = time.default_timer()

#optimization using the linear annealing schedule
t = t0
while t < 1:

    #new t value
    t = min(1, t + dt)
    tvals = np.concatenate([tvals, np.array([t])])

    #number of iterations and batch size at each annealing step
    num_iter = T
    batch_size = N

    #run the longer warm-up of T_0 iterations at t0
    if t == t0:
        num_iter = T_0

    #update parameters at t = 1 and include a learning rate scheduler
    if t == 1:
        num_iter = T_1
        batch_size = N_1
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.75)

    #update the target density and loss function with current t value
    density = TargetDistribution(target_distr, t=t, a=a)
    bound = VariationalLoss(density, mean=mean_starting, std=std_starting)

    #train the model
    for iter_num in range(1, num_iter + 1):

        #get the batches from starting distribution
        batch = torch.zeros((batch_size,2)).normal_(mean=mean_starting, std=std_starting)

        #pass the batch through the planar flow model
        zk, log_jacobians = model(batch)

        #compute the loss
        loss = bound(batch, zk, log_jacobians)

        #train the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #apply a learning rate scheduler when t = 1
        if t == 1:
            scheduler.step()

    #every step after the first stage advances t by the constant eps
    dt = eps

#compute time
end = time.default_timer()
opt_time = end - start

#plot approximation and annealing schedule
plot_model_points(model, mean=mean_starting, std=std_starting)
plot_tvals(tvals)

In [None]:
#no scheduler set-up
#trains the planar flow directly against the target at t = 1, without annealing

#choose the target distribution and set a
target_distr = "Bimodal2D"
a = 0

#choose the starting distribution parameters
mean_starting = 0
std_starting = 2

#plot target density
density_target = TargetDistribution(target_distr, t=1, a=a)
plot_density(density_target, lim=4)

#set the parameters
flow_length = 75       #number of planar flow layers K
lr = 0.0005            #learning rate for the optimizer

#set the number of samples in each batch
batch_size = 100

#set the number of iterations
num_iter = 5000

#create the model, optimizer using Adam, target density, and loss function
model = PlanarFlow(K=flow_length)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

#the target is fixed at t = 1 here, so the cell does not depend on a t left over from another cell
density = TargetDistribution(target_distr, t=1, a=a)
bound = VariationalLoss(density, mean=mean_starting, std=std_starting)

#learning rate scheduler
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.75)

#start time
start = time.default_timer()

#optimization using no annealing scheduler
#train the model
for iter_num in range(1, num_iter + 1):

    #get the batches from starting distribution
    batch = torch.zeros((batch_size,2)).normal_(mean=mean_starting, std=std_starting)

    #pass the batch through the planar flow model
    zk, log_jacobians = model(batch)

    #compute the loss
    loss = bound(batch, zk, log_jacobians)

    #train the model
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    #apply a learning rate scheduler
    #scheduler.step()

#compute time
end = time.default_timer()
opt_time = end - start

#plot approximation
plot_model_points(model, mean=mean_starting, std=std_starting)