In [None]:
#package imports

import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions import MultivariateNormal
import scipy.optimize
import scipy.stats
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
from typing import Tuple
from typing import Callable
from mpl_toolkits.mplot3d import Axes3D
from pathlib import Path
import timeit as time

In [None]:
#definitions for rk4 to solve lorenz equations

def evalRHSvec(s, r, b, x):
    """Evaluate the right-hand side of the Lorenz equations at state x.

    Args:
        s, r, b: the three Lorenz coefficients.
        x: indexable state (x[0], x[1], x[2]).

    Returns:
        A new 3-element tensor holding the three time derivatives.
    """
    dx0 = s*(-x[0] + x[1])
    dx1 = r*x[0] - x[1] - x[0]*x[2]
    dx2 = -b*x[2] + x[0]*x[1]

    #torch.tensor builds a fresh tensor from the three component values
    return torch.tensor([dx0, dx1, dx2])

def solveRK4vec(s, r, b, x0, deltaT, totalSteps):
    """Integrate the Lorenz system with the classical fourth-order Runge-Kutta scheme.

    Args:
        s, r, b: Lorenz coefficients, forwarded unchanged to evalRHSvec.
        x0: initial state; its length sets the number of columns of the result.
        deltaT: step size.
        totalSteps: number of stored states, including the initial one.

    Returns:
        Tensor of shape (totalSteps, len(x0)); row 0 is x0 and row k is the
        state after k steps. An empty (0, len(x0)) tensor is returned when
        totalSteps is smaller than 1.
    """

    #nothing to store: avoid indexing row 0 of an empty solution
    if totalSteps < 1:
        return torch.zeros((0, len(x0)))

    #create the solution array and set the initial condition
    #(no time vector is kept: it was never returned and shadowed the module-level `time` alias)
    sol = torch.zeros((totalSteps, len(x0)))
    sol[0,:] = x0

    #solve at each time step using rk4
    for loopA in range(1, totalSteps):

        prev = sol[loopA-1,:]

        k1 = evalRHSvec(s, r, b, prev)
        k2 = evalRHSvec(s, r, b, prev + 0.5*deltaT*k1)
        k3 = evalRHSvec(s, r, b, prev + 0.5*deltaT*k2)
        k4 = evalRHSvec(s, r, b, prev + deltaT*k3)

        sol[loopA,:] = prev + (deltaT/6.0)*(k1 + 2*k2 + 2*k3 + k4)

    return sol

def solveLorenzVec(s, b, r, num_data):
    """Solve the Lorenz system from (1, 1, 1) and pick every other state as data.

    Note the argument order (s, b, r); the solver is called with (s, r, b).

    Args:
        s, b, r: Lorenz coefficients.
        num_data: number of rows of the returned data tensor.

    Returns:
        (data, sol): sol is the full RK4 trajectory; data has shape (num_data, 3)
        and holds the states at even step indices (0, 2, 4, ...). If num_data
        exceeds the number of even-indexed states the remaining rows stay zero;
        if it is smaller only the first num_data of them are kept.
    """

    #set final time and step size
    totalTime = 1.53                      #slightly above 1.5 so that the last stored state is at t = 1.5
    deltaT = 0.025
    totalSteps = int(totalTime/deltaT)    #61 stored states: t = 0, 0.025, ..., 1.5

    #set initial condition
    x0 = torch.tensor([1.0, 1.0, 1.0])

    #solve using rk4
    sol = solveRK4vec(s, r, b, x0, deltaT, totalSteps)

    #keep every other state; slicing cannot run past the end of `data`
    every_other = sol[::2]
    num_copied = min(num_data, len(every_other))

    data = torch.zeros(num_data, 3)
    data[:num_copied] = every_other[:num_copied]

    return data, sol

In [None]:
#the planar flow class chains the planar flow layers using nn.Sequential

class PlanarFlow(nn.Module):
    """Stack of K planar transforms applied one after another."""

    def __init__(self, K: int = 6):
        """Build K planar layers; K is the number of planar flow layers chained together."""
        super().__init__()

        #keep the layers in a plain list for explicit iteration in forward
        transforms = []
        for _ in range(K):
            transforms.append(PlanarTransform())
        self.layers = transforms

        #nn.Sequential holds the same layer objects as submodules of this module
        self.model = nn.Sequential(*transforms)

    def forward(self, z: Tensor) -> Tuple[Tensor, float]:
        """Push z through every layer.

        Returns:
            The transformed points and the sum of the per-layer log-Jacobian terms.
        """
        total_log_det = 0

        for transform in self.layers:

            #the Jacobian term is evaluated at the input of the layer, so compute it before updating z
            total_log_det = total_log_det + transform.log_det_J(z)
            z = transform(z)

        return z, total_log_det

In [None]:
#the planar transform class contains the functions for the planar flow

class PlanarTransform(nn.Module):
    """Single planar flow layer f(z) = z + u*tanh(z w^T + b) acting on 3-component points."""

    def __init__(self):
        super().__init__()

        #starting values of the layer parameters are drawn from a normal distribution (mean 0, std 0.1)
        self.u = nn.Parameter(torch.randn(1, 3).normal_(0, 0.1))
        self.w = nn.Parameter(torch.randn(1, 3).normal_(0, 0.1))
        self.b = nn.Parameter(torch.randn(1).normal_(0, 0.1))

    def _enforce_invertibility(self) -> None:
        #u is only rewritten when u w^T drops below -1
        if torch.mm(self.u, self.w.T) < -1:
            self.get_u_hat()

    def forward(self, z: Tensor) -> Tensor:
        """Apply the planar map to a batch z of shape (N, 3)."""
        self._enforce_invertibility()

        pre_activation = torch.mm(z, self.w.T) + self.b
        return z + self.u*torch.tanh(pre_activation)

    def log_det_J(self, z: Tensor) -> Tensor:
        """Return log|1 + u psi(z)^T| for every row of z, as a tensor of shape (1, N)."""
        self._enforce_invertibility()

        activation = torch.tanh(torch.mm(z, self.w.T) + self.b)

        #psi = tanh'(a)*w, one row per sample
        psi = (1 - activation**2)*self.w
        abs_det = torch.abs(1 + torch.mm(self.u, psi.T))

        #small offset keeps the logarithm finite when the determinant vanishes
        return torch.log(1e-10 + abs_det)

    def get_u_hat(self) -> None:
        """Overwrite u (through .data) with the corrected vector u_hat."""
        wtu = torch.mm(self.u, self.w.T)

        #m(x) = -1 + log(1 + exp(x))
        m_wtu = -1 + torch.log(1 + torch.exp(wtu))
        w_norm_sqr = torch.norm(self.w, p=2, dim=1)**2

        self.u.data = (self.u + (m_wtu - wtu)*self.w/w_norm_sqr)

In [None]:
#the target distribution class holds the definitions for the 3D target distributions used in the lorenz example

class TargetDistribution:
    """Callable wrapper around a log target density selected by name.

    Only the name "Lorenz" is defined; any other name raises ValueError.
    """

    def __init__(self, name: str, t: float, sigma_sqr: float):  #t is the annealing value, sigma_sqr is the error used

        #look up the target distribution to be used
        self.func = self.get_target_distribution(name, t, sigma_sqr)

    def __call__(self, d: Tensor, z: Tensor) -> Tensor:

        #evaluate the target function for data d at the sampled points z
        return self.func(d, z)

    @staticmethod
    def get_target_distribution(name: str, t: float, sigma_sqr: float) -> Callable[[Tensor, Tensor], Tensor]:
        """Build the log-density function for the requested target.

        Args:
            name: target name; only "Lorenz" is supported.
            t: annealing value multiplying the data-misfit term.
            sigma_sqr: error variance used in the Gaussian likelihood.

        Returns:
            A function (d, params) -> tensor of len(params) log-density values,
            where each row of params is read as (s, b, r).

        Raises:
            ValueError: if name is not a known target.
        """

        #fail here instead of returning None and failing later with "NoneType is not callable"
        if name != "Lorenz":
            raise ValueError(f"unknown target distribution: {name!r}")

        #final time and step size; 1.53 is slightly above 1.5 so that the last step lands on t = 1.5
        deltaT = torch.tensor(0.025, dtype=torch.double)
        totalSteps = int(1.53/0.025)      #61, i.e. 60 rk4 steps after the initial state

        #cap for the data misfit (also used when the trajectory produces NaN)
        max_val = 10000

        #log of the constant 1/sqrt((2*pi*sigma_sqr)**(3*30))  (30 data points, 3 dimensions)
        #computed directly in log space: the power itself overflows/underflows for large/small sigma_sqr
        #the constant does not affect the free energy bound
        log_c = float(-0.5*(3*30)*np.log(2*np.pi*sigma_sqr))

        def evalRHS(s, r, b, x, y, z):
            return s*(-x + y), r*x - y - x*z, -b*z + x*y

        def solveRK4(s, r, b, x0, y0, z0, deltaT, totalSteps, true_data):

            #set initial conditions
            x = x0
            y = y0
            z = z0

            #squared distance to the data at the initial conditions
            val = (true_data[0][0] - x0)**2 + (true_data[0][1] - y0)**2 + (true_data[0][2] - z0)**2

            #solve at each time step using rk4
            j = 1
            for loopA in range(1, totalSteps):

                k1x, k1y, k1z = evalRHS(s, r, b, x, y, z)
                k2x, k2y, k2z = evalRHS(s, r, b, x + 0.5*deltaT*k1x, y + 0.5*deltaT*k1y, z + 0.5*deltaT*k1z)
                k3x, k3y, k3z = evalRHS(s, r, b, x + 0.5*deltaT*k2x, y + 0.5*deltaT*k2y, z + 0.5*deltaT*k2z)
                k4x, k4y, k4z = evalRHS(s, r, b, x + deltaT*k3x, y + deltaT*k3y, z + deltaT*k3z)

                x = x + (deltaT/6.0)*(k1x + 2*k2x + 2*k3x + k4x)
                y = y + (deltaT/6.0)*(k1y + 2*k2y + 2*k3y + k4y)
                z = z + (deltaT/6.0)*(k1z + 2*k2z + 2*k3z + k4z)

                #accumulate the summation in the exponential of the gaussian likelihood:
                #every second step is compared with the next row of the data
                if loopA % 2 == 0:
                    val += (true_data[j][0] - x)**2 + (true_data[j][1] - y)**2 + (true_data[j][2] - z)**2
                    j += 1

            return val

        def solveLorenz(s, b, r, true_data):

            #set the initial conditions
            x0 = torch.tensor(1.0, dtype=torch.double)
            y0 = torch.tensor(1.0, dtype=torch.double)
            z0 = torch.tensor(1.0, dtype=torch.double)

            #compute the summation in the gaussian likelihood
            val = solveRK4(s, r, b, x0, y0, z0, deltaT, totalSteps, true_data)

            #replace NaN values and limit too large values (the replacement is a constant)
            if val > max_val or torch.isnan(val):
                val = torch.tensor(float(max_val), dtype=torch.double)

            return val

        #target posterior for lorenz equations using a gaussian likelihood and uniform prior
        def Lorenz(d, params):

            #compute the log posterior for each sampled point of parameters (s,b,r)
            f = torch.zeros(len(params))
            for j in range(len(params)):

                #exponential summation
                val = solveLorenz(params[j][0], params[j][1], params[j][2], d)

                #log of the posterior
                f[j] = log_c + -t*val/(2*sigma_sqr)

            return f

        return Lorenz

In [None]:
#free energy loss function used

class VariationalLoss(nn.Module):
    """Free energy loss between the flow-transformed base density and a target log-density."""

    def __init__(self, distribution: Callable[[Tensor, Tensor], Tensor], mean: float, std: float):
        """
        Args:
            distribution: callable (d, z) -> log target density at the points z.
            mean: common mean of the three components of the starting distribution.
            std: common standard deviation of the three components.
        """
        super().__init__()

        #the target distribution
        self.distr = distribution

        #the starting distribution: isotropic gaussian centred at (mean, mean, mean)
        #(mean*torch.zeros(3) would always be the zero vector and ignore `mean`)
        self.base_distr = MultivariateNormal(mean*torch.ones(3), (std**2)*torch.eye(3))

    def forward(self, d: Tensor, z0: Tensor, z: Tensor, sum_log_det_J: Tensor) -> Tensor:
        """Return the mean of log q0(z0) - log p(d, z) - sum_log_det_J over the batch.

        Args:
            d: data passed through to the target distribution.
            z0: points sampled from the starting distribution.
            z: the same points after the flow.
            sum_log_det_J: summed log-Jacobian terms of the flow.
        """

        #calculate the log of the starting distribution at initial points z0
        base_log_prob = self.base_distr.log_prob(z0)

        #calculate the log of the target distribution at final points z
        #this function returns the log of the target density
        target_density_log_prob = self.distr(d, z)

        #calculate the free energy
        return (base_log_prob - target_density_log_prob - sum_log_det_J).mean()

In [None]:
#function definitions required for plotting purposes

#the starting function
def StartingFunction(x, mean, std):
    """Density of the 3D gaussian with mean (mean, mean, mean) and covariance std**2 * I, evaluated at x."""
    center = [mean]*3

    #diagonal covariance with std**2 on the diagonal
    covariance = [[std**2 if row == col else 0 for col in range(3)] for row in range(3)]

    return scipy.stats.multivariate_normal.pdf(x, center, covariance)

#the jacobian calculation using tanh
def Jacobian(x, P):
    """Absolute Jacobian determinant of one planar layer at every row of x.

    P packs the layer parameters as (u0, u1, u2, w0, w1, w2, b).
    Returns a float array with one entry per row of x.
    """
    u = np.array([P[k] for k in (0, 1, 2)])
    w = np.array([P[k] for k in (3, 4, 5)])
    b = P[6]

    #u.w does not depend on the sample, so compute it once
    u_dot_w = u@w.T

    J = np.zeros(np.shape(x)[0])
    for j in range(len(J)):
        tanh_prime = 1 - np.tanh(w.T@x[j] + b)**2
        J[j] = np.abs(1 + u_dot_w*tanh_prime)

    return J

#forward planar map H(x) = x + u*tanh(w.x + b); computeInverse below inverts it numerically
def H(x, P):
    """Planar map x + u*tanh(w.x + b) for a single point x.

    P packs the layer parameters as (u0, u1, u2, w0, w1, w2, b).
    """
    u = np.array([P[k] for k in (0, 1, 2)])
    w = np.array([P[k] for k in (3, 4, 5)])
    b = P[6]

    activation = np.tanh(w.T@x + b)
    return x + u*activation

def computeInverse(z, P):
    """Numerically solve H(x, P) = z[j] for every row j of z.

    P packs the layer parameters as (u0, u1, u2, w0, w1, w2, b).
    Returns an array with one solution per row of z.
    """
    u = np.array([P[k] for k in (0, 1, 2)])
    w = np.array([P[k] for k in (3, 4, 5)])
    b = P[6]

    #closed-form expression used as the starting guess of the root finder
    hInverse = (z - u*b)/(w + u*w)

    for j in range(np.shape(hInverse)[0]):

        target = z[j]

        #residual of the planar map with respect to the current target point
        def residual(x, target=target):
            return H(x, P) - target

        hInverse[j] = scipy.optimize.fsolve(residual, hInverse[j])

    return hInverse

In [None]:
#the definitions for plotting

#global matplotlib defaults: Arial font and extra-small tick labels on both axes
plt.rcParams.update({
    'font.family': 'Arial',
    'xtick.labelsize': 'x-small',
    'ytick.labelsize': 'x-small',
})
    
def plot_optimized(model, mean, std, points=500, cmap=cm.inferno):
    """Scatter-plot samples pushed through the flow, coloured by their density.

    Args:
        model: callable returning (transformed points, summed log-Jacobians) for a batch.
        mean, std: mean and standard deviation of the starting distribution.
        points: number of samples drawn from the starting distribution.
        cmap: colormap used for the scatter plot.
    """

    #compute the transformed sample points (no gradients are needed for plotting)
    z0 = torch.zeros((points,3)).normal_(mean=mean, std=std)
    with torch.no_grad():
        zk, log_jacobians = model(z0)
    zk = zk.detach().numpy()

    #flatten the log-Jacobians to one value per sample
    log_det = log_jacobians.detach().numpy().reshape(-1)

    #compute the probability: starting density divided by the Jacobian of the flow
    #(StartingFunction needs the mean and std of the starting distribution)
    prob = np.exp(np.log(StartingFunction(z0.numpy(), mean, std)) - log_det)

    #create the figure and subplot
    fig = plt.figure(figsize=(10, 6), dpi=300)
    ax1 = fig.add_subplot(111, projection='3d')

    #3d plot, using the colormap passed by the caller
    CB = ax1.scatter(zk[:,0], zk[:,1], zk[:,2], c=prob, cmap=cmap, alpha=0.3)
    plt.colorbar(CB)
    ax1.set_xlabel('s')
    ax1.set_ylabel('b')
    ax1.set_zlabel('r')
    plt.title('Optimized PDF')
    #plt.savefig(f'Optimized PDF.png', bbox_inches='tight')
    plt.show()
    
def plot_tvals(tvals):
    """Plot the collected annealing values against their 1-based step number."""
    num_steps = len(tvals)
    step_number = np.linspace(1, num_steps, num_steps)

    #plot
    plt.figure(figsize=(6, 4), dpi=300)
    plt.plot(step_number, tvals, 'darkmagenta', lw=1)
    plt.title('Annealing Schedule')
    #plt.savefig(f'Annealing Schedule.png', bbox_inches='tight')
    plt.show()

In [None]:
#set true parameter values, number of data points, and error
s_true, b_true, r_true = 10, 8/3, 28
num_data = 31
sigma_sqr = 0.2

#noise-free data points and the full trajectory
data_no_noise, sol = solveLorenzVec(s_true, b_true, r_true, num_data)

#Gaussian noise with covariance sigma_sqr*I on every data point except the initial values (first row stays zero)
error = torch.cat([
    torch.zeros(1, 3),
    MultivariateNormal(torch.zeros(3), sigma_sqr*torch.eye(3)).sample((num_data - 1,)),
])
data = data_no_noise + error

In [None]:
#plot trajectory and noisy data
#time axes for the full trajectory (61 states) and the data (31 points)
time1 = np.linspace(0,1.5,61)
time2 = np.linspace(0,1.5,31)

fig = plt.figure(figsize=(11, 8), dpi=300)
ax1 = fig.add_subplot(221, projection='3d')
ax2, ax3, ax4 = (fig.add_subplot(position) for position in (222, 223, 224))

#3d trajectory
ax1.plot(*(sol[:,k].detach().numpy() for k in range(3)), color='darkmagenta')
ax1.plot(*(data[:,k].detach().numpy() for k in range(3)), '.', color='tomato')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
ax1.set_zlabel('z')

#x, y and z components against time, one panel each
for component, (panel, label) in enumerate(zip((ax2, ax3, ax4), ('x', 'y', 'z'))):
    panel.plot(time1, sol[:,component], color='darkmagenta')
    panel.plot(time2, data[:,component], '.', color='tomato')
    panel.set_xlabel('Time')
    panel.set_ylabel(label)

plt.suptitle('3D and Component Trajectories')

In [None]:
#AdaAnn scheduler set-up

#list to collect t values
tvals = []

#choose the starting distribution parameters
mean_starting = 10
std_starting = 2

#choose the target distribution
target_distr = "Lorenz"

#set the parameters
flow_length = 250
lr = 0.001
t0 = 0.05
tol = 0.5
M = 100
dt = 0

#set the number of samples in each batch
N = 100                #at t0 and each annealing step
N_1 = 200              #at t = 1

#set the number of iterations
T_0 = 500              #at t0
T = 5                  #at each annealing step
T_1 = 5000             #at t = 1

#create the model and optimizer using Adam
model = PlanarFlow(K=flow_length)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

#start time
start = time.default_timer()

#optimization using AdaAnn
t = t0
while t < 1:

    #new t value (dt is 0 on the first pass, so the first stage runs exactly at t0)
    t = min(1, t + dt)
    tvals.append(t)

    #number of iterations and batch size at each annealing step
    num_iter = T
    batch_size = N

    #update parameters at t0
    #(this previously assigned an unused variable, so the T_0 iterations at t0 never ran)
    if t == t0:
        num_iter = T_0

    #update parameters at t = 1 and include a learning rate scheduler
    if t == 1:
        num_iter = T_1
        batch_size = N_1
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 500, gamma=0.75)

    #update the target density and loss function with current t value
    density = TargetDistribution(target_distr, t=t, sigma_sqr=sigma_sqr)
    bound = VariationalLoss(density, mean=mean_starting, std=std_starting)

    #train the model
    for iter_num in range(1, num_iter + 1):

        #get the batches from starting distribution
        batch = torch.zeros((batch_size, 3)).normal_(mean=mean_starting, std=std_starting)

        #pass the batch through the planar flow model
        zk, log_jacobians = model(batch)

        #compute the loss
        loss = bound(data, batch, zk, log_jacobians)

        #train the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #apply a learning rate scheduler when t = 1
        if t == 1:
            scheduler.step()

    #compute the dt value using M points; no gradients are needed for this estimate
    density_dt = TargetDistribution(target_distr, t=1, sigma_sqr=sigma_sqr)
    with torch.no_grad():
        zk, log_jacobians = model(torch.zeros((M,3)).normal_(mean=mean_starting,std=std_starting))
        log_qk = density_dt(data, zk)

    #keep dt (and therefore t) a plain Python float
    dt = float(tol/torch.sqrt(log_qk.var()))

#annealing schedule as an array
tvals = np.array(tvals)

#compute time
end = time.default_timer()
opt_time = end - start

#plot approximation and annealing schedule
plot_optimized(model, mean=mean_starting, std=std_starting)
plot_tvals(tvals)