In [None]:
#package imports

import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions import MultivariateNormal
import scipy.optimize
import scipy.stats
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
from typing import Tuple
from typing import Callable
from mpl_toolkits.mplot3d import Axes3D
from pathlib import Path
import timeit as time

In [None]:
#definitions for rk4 to solve hiv equations

def evalRHSvec(p1, p2, p3, p4, p5, x):
    """Evaluate the right-hand side of the three-state HIV ODE system.

    Args:
        p1, p2, p3, p4, p5: model parameters (scalars).
        x: current state, indexable as x[0], x[1], x[2].

    Returns:
        A new 1-D tensor with the three time derivatives.
    """
    x1, x2, x3 = x[0], x[1], x[2]

    #the bilinear coupling term appears with opposite signs in the first two equations
    coupling = p3*x1*x3

    dx1 = p1 - p2*x1 - coupling
    dx2 = coupling - p4*x2
    dx3 = p1*p4*x2 - p5*x3

    return torch.tensor([dx1, dx2, dx3])

def solveRK4vec(p1, p2, p3, p4, p5, x0, deltaT, totalSteps):
    """Integrate the HIV system with the classical fourth-order Runge-Kutta scheme.

    Args:
        p1, p2, p3, p4, p5: model parameters forwarded to evalRHSvec.
        x0: initial state (1-D tensor); its length sets the number of columns.
        deltaT: fixed time step.
        totalSteps: number of rows to produce, including the initial condition.

    Returns:
        Tensor of shape (totalSteps, len(x0)); row 0 is x0 and row k is the
        state after k steps of size deltaT.
    """
    #solution matrix, one row per time level
    sol = torch.zeros((totalSteps, len(x0)))

    #set initial condition
    sol[0,:] = x0

    for step in range(1, totalSteps):

        previous = sol[step-1,:]

        #the four rk4 stage slopes
        k1 = evalRHSvec(p1, p2, p3, p4, p5, previous)
        k2 = evalRHSvec(p1, p2, p3, p4, p5, previous + 0.5*deltaT*k1)
        k3 = evalRHSvec(p1, p2, p3, p4, p5, previous + 0.5*deltaT*k2)
        k4 = evalRHSvec(p1, p2, p3, p4, p5, previous + deltaT*k3)

        #weighted combination of the stages advances the state by one step
        sol[step,:] = previous + (deltaT/6.0)*(k1 + 2*k2 + 2*k3 + k4)

    return sol

def solveHIVvec(p1, p2, x2_0, num_data):
    """Generate the x3 trajectory of the HIV model on the fixed rk4 grid.

    Args:
        p1, p2: the two unknown model parameters.
        x2_0: initial value of the second state.
        num_data: number of rows of the returned tensor; must be at least the
            number of grid points (41). Rows beyond the grid stay zero.

    Returns:
        Tensor of shape (num_data, 1) whose first 41 rows hold x3 at
        t = 0, 0.05, ..., 2.

    Raises:
        IndexError: if num_data is smaller than the number of grid points.
    """
    #set final time and step size
    totalTime = 2.06                      #slightly past 2 so the truncation below gives 41 grid points
    deltaT = 0.05
    totalSteps = int(totalTime/deltaT)    #41 rows: the initial condition plus 40 steps

    if num_data < totalSteps:
        raise IndexError(
            f"num_data={num_data} is too small to hold the {totalSteps} solution points"
        )

    #set known parameters
    p3 = 4.1
    p4 = 10.2
    p5 = 2.6

    #set initial condition
    x1_0 = 0
    x3_0 = 1

    x0 = torch.tensor([x1_0, x2_0, x3_0])

    #solve using rk4
    sol = solveRK4vec(p1, p2, p3, p4, p5, x0, deltaT, totalSteps)

    #select only x3 component as data points (only known output)
    #every grid point is used
    data = torch.zeros(num_data, 1)
    data[:len(sol), 0] = sol[:, 2]

    return data

In [None]:
#the planar flow class chains the planar flow layers using nn.Sequential

class PlanarFlow(nn.Module):
    """Composition of K planar transforms applied one after another.

    The transforms are kept both in a plain list (iterated in forward) and in
    an nn.Sequential (which registers their parameters with this module).
    """

    def __init__(self, K: int = 6): #K is the number of planar flow layers chained together
        super().__init__()

        #one independent planar transform per layer
        self.layers = [PlanarTransform() for _ in range(K)]

        #wrapping the list registers every layer's parameters
        self.model = nn.Sequential(*self.layers)

    def forward(self, z: Tensor) -> Tuple[Tensor, float]:
        """Push z through every layer and accumulate the log-determinants.

        Returns the transformed points together with the running sum of each
        layer's log_det_J, evaluated at that layer's input (the sum is the
        integer 0 when there are no layers).
        """
        #running total of the per-layer log Jacobians
        total_log_det = 0
        current = z

        for transform in self.layers:

            #the log Jacobian must be evaluated before the points are moved
            total_log_det = total_log_det + transform.log_det_J(current)

            #advance the points through this layer
            current = transform(current)

        return current, total_log_det

In [None]:
#the planar transform class contains the functions for the planar flow

class PlanarTransform(nn.Module):
    """One planar-flow layer, z -> z + u*tanh(z w^T + b), for 3-dimensional points.

    u and w are (1, 3) row vectors and b has shape (1,). Whenever u w^T drops
    below -1, u is overwritten in place by get_u_hat before the layer is used.
    """

    def __init__(self):
        super().__init__()

        #every parameter starts as a draw from a normal with mean 0 and std 0.1
        self.u = nn.Parameter(torch.randn(1, 3).normal_(0, 0.1))
        self.w = nn.Parameter(torch.randn(1, 3).normal_(0, 0.1))
        self.b = nn.Parameter(torch.randn(1).normal_(0, 0.1))

    def _enforce_invertibility(self) -> None:
        #replace u when the invertibility condition u w^T >= -1 is violated
        if torch.mm(self.u, self.w.T) < -1:
            self.get_u_hat()

    def forward(self, z: Tensor) -> Tensor:
        """Apply the planar map to a batch z of shape (N, 3)."""
        self._enforce_invertibility()

        #scalar activation per point, shape (N, 1)
        activation = torch.tanh(torch.mm(z, self.w.T) + self.b)

        return z + self.u*activation

    def log_det_J(self, z: Tensor) -> Tensor:
        """Return log|det J| of the planar map at each row of z, shape (1, N)."""
        self._enforce_invertibility()

        pre_activation = torch.mm(z, self.w.T) + self.b

        #derivative of tanh times w, one row per point
        tanh_slope = 1 - torch.tanh(pre_activation)**2
        psi = tanh_slope*self.w

        det_term = 1 + torch.mm(self.u, psi.T)

        #the small offset keeps the logarithm finite when the determinant vanishes
        return torch.log(1e-10 + det_term.abs())

    def get_u_hat(self) -> None:
        """Overwrite u so that u w^T becomes -1 + log(1 + exp(u w^T))."""
        u_dot_w = torch.mm(self.u, self.w.T)
        corrected = -1 + torch.log(1 + torch.exp(u_dot_w))
        w_norm_sqr = torch.norm(self.w, p=2, dim=1)**2

        #shift u along w; writing to .data keeps u the same registered parameter
        self.u.data = self.u + (corrected - u_dot_w)*self.w/w_norm_sqr

In [None]:
#the target distribution class holds the definitions for the 3D target distributions used in the HIV example

class TargetDistribution:
    """Annealed log-posterior of the HIV model, used as the target of the flow.

    Calling an instance with the observed data d and a batch of parameter
    samples z (one row per sample, columns p1, p2, x2_0) returns a 1-D tensor
    with one log-density value per row.
    """

    def __init__(self, name: str, t: float, sigma_sqr: float):  #t is the annealing value, sigma_sqr is the error used

        #build the log-density function of the requested target distribution
        self.func = self.get_target_distribution(name, t, sigma_sqr)

    def __call__(self, d: Tensor, z: Tensor) -> Tensor:

        #return the target function
        return self.func(d, z)

    @staticmethod
    def get_target_distribution(name: str, t: float, sigma_sqr: float) -> Callable[[Tensor, Tensor], Tensor]:
        """Return the log-density function for the named target.

        Args:
            name: target identifier; only "HIV" is available.
            t: annealing value multiplying the data-misfit term.
            sigma_sqr: noise variance of the gaussian likelihood.

        Returns:
            A function (d, params) -> tensor of log-posterior values.

        Raises:
            ValueError: if name is not a known target.
        """
        if name != "HIV":
            #fail here instead of handing back None, which would only surface later as "not callable"
            raise ValueError(f"unknown target distribution {name!r}; only 'HIV' is available")

        #log of the gaussian normalising constant 1/sqrt((2*pi*sigma_sqr)**40)   (40 data points, 1 dimension)
        #computed directly in log space so that a small sigma_sqr cannot underflow the power to zero
        #(constant value, doesn't affect the free energy bound)
        log_c = torch.tensor(-20.0*np.log(2*np.pi*sigma_sqr), dtype=torch.double)

        def evalRHS(p1, p2, p3, p4, p5, x1, x2, x3):
            return p1 - p2*x1 - p3*x1*x3, p3*x1*x3 - p4*x2, p1*p4*x2 - p5*x3

        def solveRK4(p1, p2, p3, p4, p5, x1_0, x2_0, x3_0, deltaT, totalSteps, true_data):

            #set initial conditions
            x1, x2, x3 = x1_0, x2_0, x3_0

            #squared misfit at the initial condition
            val = (true_data[0] - x3)**2

            #solve at each time step using rk4
            for step in range(1, totalSteps):

                k1x1, k1x2, k1x3 = evalRHS(p1, p2, p3, p4, p5, x1, x2, x3)
                k2x1, k2x2, k2x3 = evalRHS(p1, p2, p3, p4, p5, x1 + 0.5*deltaT*k1x1, x2 + 0.5*deltaT*k1x2, x3 + 0.5*deltaT*k1x3)
                k3x1, k3x2, k3x3 = evalRHS(p1, p2, p3, p4, p5, x1 + 0.5*deltaT*k2x1, x2 + 0.5*deltaT*k2x2, x3 + 0.5*deltaT*k2x3)
                k4x1, k4x2, k4x3 = evalRHS(p1, p2, p3, p4, p5, x1 + deltaT*k3x1, x2 + deltaT*k3x2, x3 + deltaT*k3x3)

                x1 = x1 + (deltaT/6.0)*(k1x1 + 2*k2x1 + 2*k3x1 + k4x1)
                x2 = x2 + (deltaT/6.0)*(k1x2 + 2*k2x2 + 2*k3x2 + k4x2)
                x3 = x3 + (deltaT/6.0)*(k1x3 + 2*k2x3 + 2*k3x3 + k4x3)

                #accumulate the summation in the exponential of the gaussian likelihood;
                #every time step has a matching data point
                val = val + (true_data[step] - x3)**2

            return val

        def solveHIV(p1, p2, x2_0, true_data):

            #set the final time and step size (double precision keeps the ode arithmetic in double)
            totalTime = torch.tensor(2.06, dtype=torch.double)    #truncation below gives 41 grid points, t = 0, ..., 2
            deltaT = torch.tensor(0.05, dtype=torch.double)
            totalSteps = int(totalTime/deltaT)

            #set known parameters
            p3 = torch.tensor(4.1, dtype=torch.double)
            p4 = torch.tensor(10.2, dtype=torch.double)
            p5 = torch.tensor(2.6, dtype=torch.double)

            #set initial conditions
            x1_0 = torch.tensor(0, dtype=torch.double)
            x3_0 = torch.tensor(1, dtype=torch.double)

            #compute the summation in the gaussian likelihood
            val = solveRK4(p1, p2, p3, p4, p5, x1_0, x2_0, x3_0, deltaT, totalSteps, true_data)

            #replace infinite values and limit too large of values
            #(the replacement is a constant, so such samples carry no gradient)
            if val > 10000 or torch.isnan(val):
                val = torch.tensor([10000])

            return val

        #target posterior for HIV equations using a gaussian likelihood and uniform prior
        def HIV(d, params):

            #compute the posterior probabilities for each sampled point of parameters (p1,p2,x2_0)
            f = torch.zeros(len(params))
            for j in range(len(params)):

                #exponential summation
                val = solveHIV(params[j][0], params[j][1], params[j][2], d)

                #compute the log of the posterior
                f[j] = log_c + -t*val/(2*sigma_sqr)

            return f

        return HIV

In [None]:
#free energy loss function used

class VariationalLoss(nn.Module):
    """Free-energy loss between the flow's output density and a target density.

    The base distribution is a 3-dimensional normal with mean `mean` in every
    coordinate and covariance std**2 times the identity.
    """

    def __init__(self, distribution: TargetDistribution, mean: int, std: int):
        super().__init__()

        #callable returning the log of the target density
        self.distr = distribution

        #the starting distribution
        location = mean*torch.ones(3)
        covariance = (std**2)*torch.eye(3)
        self.base_distr = MultivariateNormal(location, covariance)

    def forward(self, d: Tensor, z0: Tensor, z: Tensor, sum_log_det_J: float) -> float:
        """Return the batch mean of log q0(z0) - log p(d, z) - sum_log_det_J."""
        #log of the starting density at the initial points z0
        log_q0 = self.base_distr.log_prob(z0)

        #log of the target density at the transformed points z
        log_target = self.distr(d, z)

        #free energy per sample, then averaged over the batch
        free_energy = log_q0 - log_target - sum_log_det_J
        return free_energy.mean()

In [None]:
#function definitions required for plotting purposes

#the starting function
def StartingFunction(x, mean, std):
    """Density of the 3-D starting normal (mean `mean` per coordinate, variance std**2) at x."""
    center = [mean]*3
    variance = std**2

    #diagonal covariance written out as nested lists
    covariance = [[variance if row == col else 0 for col in range(3)] for row in range(3)]

    return scipy.stats.multivariate_normal.pdf(x, center, covariance)

#the jacobian calculation using tanh
def Jacobian(x, P):
    """Absolute Jacobian determinant of a tanh planar layer at each row of x.

    P packs the layer parameters as (u1, u2, u3, w1, w2, w3, b).
    Returns a 1-D float array with one value per row of x.
    """
    u = np.array([P[k] for k in (0, 1, 2)])
    w = np.array([P[k] for k in (3, 4, 5)])
    b = P[6]

    #u.w does not depend on the point
    u_dot_w = u@w.T

    J = np.zeros(np.shape(x)[0])
    for row in range(J.size):
        tanh_slope = 1 - np.tanh(w.T@x[row] + b)**2
        J[row] = np.abs(1 + u_dot_w*tanh_slope)

    return J

#forward planar map H and its numerical inverse computeInverse
def H(x, P):
    """Apply one tanh planar layer, x + u*tanh(w.x + b), to a single point x.

    P packs the layer parameters as (u1, u2, u3, w1, w2, w3, b).
    """
    u = np.array([P[k] for k in (0, 1, 2)])
    w = np.array([P[k] for k in (3, 4, 5)])
    b = P[6]

    activation = np.tanh(w.T@x + b)
    return x + u*activation

def computeInverse(z, P):
    """Numerically invert the planar layer H for every row of z.

    P packs the layer parameters as (u1, u2, u3, w1, w2, w3, b). Each row is
    solved separately with scipy's fsolve.
    """
    u = np.array([P[k] for k in (0, 1, 2)])
    w = np.array([P[k] for k in (3, 4, 5)])
    b = P[6]

    #closed-form starting guess, refined row by row below
    hInverse = (z - u*b)/(w + u*w)

    for row in range(np.shape(hInverse)[0]):

        #root of H(x) - z[row] is the pre-image of z[row]
        residual = lambda x, target=z[row]: H(x, P) - target

        hInverse[row] = scipy.optimize.fsolve(residual, hInverse[row])

    return hInverse

In [None]:
#the definitions for plotting

#global matplotlib defaults: Arial font and extra-small tick labels on both axes
plt.rc('font', family='Arial')
for tick_group in ('xtick', 'ytick'):
    plt.rc(tick_group, labelsize='x-small')

def plot_optimized(model, mean, std, points=500, cmap=cm.inferno):
    """Scatter samples pushed through the trained flow, coloured by their density.

    Args:
        model: callable returning (transformed points, summed log Jacobians)
            for a (points, 3) batch.
        mean, std: parameters of the starting normal distribution.
        points: number of samples to draw.
        cmap: matplotlib colormap used for the density values.
    """
    #compute the transformed sample points
    z0 = torch.zeros((points,3)).normal_(mean=mean, std=std)
    zk, log_jacobians = model(z0)
    zk = zk.detach().numpy()

    #density of the transformed points: starting density divided by the Jacobian
    #(the starting density needs the same mean and std the samples were drawn with)
    log_q0 = np.log(StartingFunction(z0.numpy(), mean, std))
    prob = np.exp(log_q0 - log_jacobians.detach().numpy().reshape(-1))

    #create the figure and subplot
    fig = plt.figure(figsize=(10, 6), dpi=300)
    ax1 = fig.add_subplot(111, projection='3d')

    #3d plot
    CB = ax1.scatter(zk[:,0], zk[:,1], zk[:,2], c=prob, cmap=cmap, alpha=0.3)
    plt.colorbar(CB)
    ax1.set_xlabel('p1')
    ax1.set_ylabel('p2')
    ax1.set_zlabel('x2_0')
    ax1.set_title('Optimized PDF')
    #plt.savefig(f'Optimized PDF.png', bbox_inches='tight')
    plt.show()
    
def plot_tvals(tvals):
    """Plot the sequence of annealing values against the step number (starting at 1)."""
    step_numbers = np.linspace(1, len(tvals), len(tvals))

    #plot
    fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
    ax.plot(step_numbers, tvals, 'darkmagenta', lw=1)
    ax.set_title('Annealing Schedule')
    #plt.savefig(f'Annealing Schedule.png', bbox_inches='tight')
    plt.show()

In [None]:
#set true parameter values, number of data points, and error
p1_true = 1.2
p2_true = 0.8
x2_0_true = 1.5
num_data = 41          #grid points t = 0, 0.05, ..., 2 (initial condition plus 40 steps)
sigma_sqr = 0.0005

#fix the random seed so the synthetic noise, and everything sampled after it, is reproducible
seed = 0
torch.manual_seed(seed)

#add Gaussian noise to selected data points (not including initial values)
data_no_noise = solveHIVvec(p1_true, p2_true, x2_0_true, num_data)
error = torch.zeros(num_data,1)
error[1:num_data] = MultivariateNormal(torch.zeros(1), sigma_sqr*torch.eye(1)).sample((num_data - 1,))
data = data_no_noise + error

In [None]:
#plot the noise-free x3 trajectory as a line and the noisy observations as dots
times = np.linspace(0,2,41)

fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
ax.plot(times, data_no_noise.detach(), color='darkmagenta')
ax.plot(times, data.detach(), '.', color='tomato')
ax.set_xlabel('Time')
ax.set_ylabel('x3')
ax.set_title('Trajectory')

In [None]:
#AdaAnn scheduler set-up

#list to collect t values
tvals = []

#choose the starting distribution parameters
mean_starting = 0
std_starting = 2

#choose the target distribution
target_distr = "HIV"

#set the parameters
flow_length = 250      #number of planar layers
lr = 0.0005            #Adam learning rate
t0 = 0.00005           #first annealing value
tol = 0.005            #tolerance in the adaptive step dt = tol/std
M = 100                #number of samples used to estimate dt
dt = 0                 #zero so that the first pass of the loop runs at t = t0

#set the number of samples in each batch
N = 100                #at t0 and each annealing step
N_1 = 200              #at t = 1

#set the number of iterations
T_0 = 1000             #at t0
T = 5                  #at each annealing step
T_1 = 5000             #at t = 1

#create the model and optimizer using Adam
model = PlanarFlow(K=flow_length)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

#target density at t = 1, used only to estimate dt (it does not change between steps)
density_dt = TargetDistribution(target_distr, t=1, sigma_sqr=sigma_sqr)

#start time
start = time.default_timer()

#optimization using AdaAnn
t = t0
while t < 1:

    #new t value
    t = min(1, t + dt)
    tvals.append(t)

    #number of iterations and batch size at each annealing step
    num_iter = T
    batch_size = N

    #the first step (t = t0) runs the longer T_0 schedule
    if t == t0:
        num_iter = T_0

    #update parameters at t = 1 and include a learning rate scheduler
    if t == 1:
        num_iter = T_1
        batch_size = N_1
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1000, gamma=0.75)

    #update the target density and loss function with current t value
    density = TargetDistribution(target_distr, t=t, sigma_sqr=sigma_sqr)
    bound = VariationalLoss(density, mean=mean_starting, std=std_starting)

    #train the model
    for iter_num in range(1, num_iter + 1):

        #get the batches from starting distribution
        batch = torch.zeros((batch_size, 3)).normal_(mean=mean_starting, std=std_starting)

        #pass the batch through the planar flow model
        zk, log_jacobians = model(batch)

        #compute the loss
        loss = bound(data, batch, zk, log_jacobians)

        #train the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #apply a learning rate scheduler when t = 1
        if t == 1:
            scheduler.step()

    #compute the dt value using M points; no gradients are needed for this estimate
    with torch.no_grad():
        zk, log_jacobians = model(torch.zeros((M,3)).normal_(mean=mean_starting,std=std_starting))
        log_qk = density_dt(data, zk)
    dt = (tol/torch.sqrt(log_qk.var())).item()

#keep the annealing values as an array
tvals = np.array(tvals)

#compute time
end = time.default_timer()
opt_time = end - start

#plot approximation and annealing schedule
plot_optimized(model, mean=mean_starting, std=std_starting)
plot_tvals(tvals)

In [None]:
#no scheduler set-up

#choose the starting distribution parameters
mean_starting = 0
std_starting = 2

#choose the target distribution
target_distr = "HIV"

#set the parameters
flow_length = 250
lr = 0.0005

#annealing settings, not used by this cell (same values as the AdaAnn set-up)
t0 = 0.00005
tol = 0.005
M = 100
dt = 0

#set the number of samples in each batch
batch_size = 100

#set the number of iterations
num_iter = 20000

#create the model, optimizer using Adam, target density, and loss function
#without annealing the target is the full posterior, so t is fixed at 1 here
#rather than read from whatever value an earlier cell left behind
model = PlanarFlow(K=flow_length)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
density = TargetDistribution(target_distr, t=1, sigma_sqr=sigma_sqr)
bound = VariationalLoss(density, mean=mean_starting, std=std_starting)

#learning rate scheduler
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1000, gamma=0.75)

#start time
start = time.default_timer()

#optimization using no annealing scheduler
#train the model
for iter_num in range(1, num_iter + 1):

    #get the batches from starting distribution
    batch = torch.zeros((batch_size, 3)).normal_(mean=mean_starting, std=std_starting)

    #pass the batch through the planar flow model
    zk, log_jacobians = model(batch)

    #compute the loss
    loss = bound(data, batch, zk, log_jacobians)

    #train the model
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    #apply a learning rate scheduler when t = 1
    #scheduler.step()

#compute time
end = time.default_timer()
opt_time = end - start

#plot approximation
plot_optimized(model, mean=mean_starting, std=std_starting)