In [None]:
import CompDoobTransform as cdt
import time
import math
import torch
import matplotlib.pyplot as plt
from CompDoobTransform.utils import normal_logpdf
# plotting style and compute device (GPU when available, otherwise CPU)
plt.style.use('ggplot')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Computing on {}'.format(device))

In [None]:
# --- tuning parameters ---
# standard deviation of the observation noise (its square is the observation variance)
std_obs = 0.25
# output file for the particle-filter comparison results (written with torch.save)
filename = 'cell_var_obs_small.pt'

In [None]:
# latent state process: the objects describing the diffusion are collected in `state`
d = 2 # dimension of the state
state = {'dim': d}

# parameters of the drift; P is the exponent in the x**P / (xi**P + x**P) terms
alpha, beta, kappa = torch.tensor(1.0), torch.tensor(1.0), torch.tensor(1.0)
P = torch.tensor(4.0)
xi = torch.tensor(0.5)

# drift
def drift(x):
    """Drift of the two-dimensional latent diffusion.

    For component i (with j the other component):
        alpha * x_i**P / (xi**P + x_i**P) + beta * xi**P / (xi**P + x_j**P) - kappa * x_i

    Args:
        x: floating-point tensor whose columns 0 and 1 hold the two state
           components (one row per particle).

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    # zeros_like (rather than torch.zeros(x.shape)) keeps the result on x's device
    # and in x's dtype; torch.zeros always gave CPU float32, silently downcasting
    # double-precision inputs and mismatching devices for GPU inputs
    out = torch.zeros_like(x)
    xi_P = xi**P
    x0_P = x[:,0]**P
    x1_P = x[:,1]**P
    out[:,0] = alpha * x0_P / (xi_P + x0_P) + beta * xi_P / (xi_P + x1_P) - kappa * x[:,0]
    out[:,1] = alpha * x1_P / (xi_P + x1_P) + beta * xi_P / (xi_P + x0_P) - kappa * x[:,1]
    return out
# `b` is a short alias that forwards to `drift` (looked up at call time);
# the model receives it through state['drift']
def b(x):
    return drift(x)
state['drift'] = b

# diffusion coefficient and terminal time, created directly on the compute device
sigma = torch.tensor(1.0, device = device)
T = torch.tensor(1.0, device = device)
state.update(sigma = sigma, terminal_time = T)

# time-discretization settings
M = 50 # number of time steps on [0, T] (step size T / M)

In [None]:
# observation model: observations are the state plus Gaussian noise of variance var_obs
p = 2 # dimension of each observation
var_obs = torch.tensor(std_obs**2, device = device) # observation noise variance

def obs_log_density(x, y):
    """Gaussian log-density of observation y given state x (used as the terminal condition).

    Per the original note, returns one value per particle (size N) -- TODO confirm
    against normal_logpdf.
    """
    return normal_logpdf(y, x, var_obs)

obs = {'dim': p, 'log_density': obs_log_density}

In [None]:
# simulate a long trajectory of states and observations from the model, using the
# Euler-Maruyama scheme with step T / M; it is used below to learn standardization
# statistics and as a pool of initial states
X0 = torch.ones(1, d, device = device) # same device as the storage tensors below
X = X0.clone()
J = 2000 # number of consecutive intervals of length T to simulate
max_index = J*M+1 # number of stored states, including X0
store_states = torch.zeros(max_index, d, device = device)
store_states[0,:] = X
store_obs = torch.zeros(J*M, p, device = device) # one observation per state after X0
# T is already a tensor on `device`; wrapping T / M in torch.tensor() copy-constructs
# from a tensor and triggers a PyTorch UserWarning
stepsize = T / M
sqrt_stepsize = torch.sqrt(stepsize) # loop-invariant
std_noise = torch.sqrt(var_obs) # loop-invariant
for index in range(1, max_index):
    euler = X + stepsize * b(X)
    W = sqrt_stepsize * torch.randn(X.shape, device = device)
    X = euler + sigma * W
    Y = X + std_noise * torch.randn(1, p, device = device)
    store_states[index,:] = X
    store_obs[index-1,:] = Y

# learning standardization means and standard deviations
standardization = {'x_mean': torch.mean(store_states, 0), 
                   'x_std': torch.std(store_states, 0), 
                   'y_mean': torch.mean(store_obs, 0), 
                   'y_std': torch.std(store_obs, 0)}
print(standardization)

# simulate initial states by subsampling (with replacement) the stored trajectory
initial = lambda N: store_states[torch.randint(0, max_index, size = (N,)), :]
state['initial'] = initial

# simulate observations: subsampled states plus observation noise
observation = lambda N: initial(N) + torch.sqrt(var_obs) * torch.randn(N, p, device = device)
obs['observation'] = observation

In [None]:
# V0 and Z neural network configuration; both networks share the standardization
# statistics learned from the simulated trajectory
net_config = {
    'V0': {'layers': [16], 'standardization': standardization},
    'Z': {'layers': [d+16], 'standardization': standardization},
}
V0_net_config, Z_net_config = net_config['V0'], net_config['Z']

# optimization configuration (standard training)
I = 2000 # number of optimization iterations (also the length of the loss curves plotted below)
optim_config = dict(minibatch = 100,
                    num_obs_per_batch = 10,
                    num_iterations = I,
                    learning_rate = 0.01,
                    initial_required = True)

In [None]:
# static CDT: build a fresh model and fit it with train_standard, reporting wall-clock time
model_static = cdt.core.model(state, obs, M, net_config, device = 'cpu')
time_start = time.time()
model_static.train_standard(optim_config)
time_end = time.time()
time_elapsed = time_end - time_start
print(f"Training time (secs): {time_elapsed}")

In [None]:
# iterative CDT: build a fresh model and fit it with train_iterative, reporting wall-clock time
model = cdt.core.model(state, obs, M, net_config, device = 'cpu')
time_start = time.time()
model.train_iterative(optim_config)
time_end = time.time()
time_elapsed = time_end - time_start
print(f"Training time (secs): {time_elapsed}")

In [None]:
# compare the training loss of the static and iterative schemes over the I iterations
iterations = torch.arange(I)
fig, ax = plt.subplots()
ax.plot(iterations, model_static.loss.to('cpu'), '-')
ax.plot(iterations, model.loss.to('cpu'), '-')
ax.set_xlabel('iteration', fontsize = 15)
ax.set_ylabel('loss', fontsize = 15)
ax.legend(['Static CDT', 'Iterative CDT'], fontsize = 15)
plt.show()

In [None]:
# guided intermediate resampling filter (GIRF): guiding functions built from the
# observation log-density, tempered by an inverse temperature that increases
# linearly from 0 to 1 over the M+1 time-grid points
inverse_temperature = torch.linspace(0.0, 1.0, M+1)

# NOTE: the parameter `p` below is the power applied to the inverse temperature,
# not the observation dimension `p` defined earlier.

def guiding_initial(x, y, p):
    """Log-guiding at time-grid index 0: observation log-density weighted by inverse_temperature[0]**p."""
    return inverse_temperature[0]**p * obs_log_density(x, y)

def guiding_intermediate(m, x, x_next, y, p):
    """Change in tempered log-potential from grid index m-1 (state x) to m (state x_next) for observation y."""
    current = inverse_temperature[m-1]**p * obs_log_density(x, y)
    following = inverse_temperature[m]**p * obs_log_density(x_next, y)
    return following - current

def guiding_obs_time(m, x, x_next, y, y_next, p):
    """Intermediate term for y plus the initial guiding term for the next observation y_next."""
    return guiding_intermediate(m, x, x_next, y, p) + guiding_initial(x_next, y_next, p)

def make_guiding(p):
    """Bundle the three guiding functions, with inverse-temperature power p, into the dict passed to run_GIRF."""
    return {
        'initial': lambda x, y: guiding_initial(x, y, p),
        'intermediate': lambda m, x, x_next, y: guiding_intermediate(m, x, x_next, y, p),
        'obs_time': lambda m, x, x_next, y, y_next: guiding_obs_time(m, x, x_next, y, y_next, p),
    }

guiding_linear = make_guiding(1.0) # inverse temperature used as-is
guiding_square = make_guiding(2.0) # squared inverse temperature

In [None]:
# repeat particle filters: for each (number of observations, number of particles)
# pair, simulate one data set and run every filter R times on it, recording the
# average ESS% and the final log-normalizing-constant estimate
multiplier = 1.0 # scales the observation noise used to generate the synthetic data
num_obs = [100, 200, 400, 800, 1600]
len_num_obs = len(num_obs)
num_particles = [2**6, 2**7, 2**8, 2**9, 2**10]
assert len(num_particles) == len_num_obs, 'num_obs and num_particles must have the same length'
R = 100 # number of repeats

# every filter is called as filters[name](initial_particles, Y, N); the insertion
# order fixes the order in which the filters run (and consume random numbers)
# within each repeat, and the order of the printed output
filters = {
    'BPF' : lambda x0, y, n: model.run_BPF(x0, y, n),
    'APF' : lambda x0, y, n: model.run_APF(x0, y, n),
    'APFF' : lambda x0, y, n: model_static.run_APF(x0, y, n),
    'GIRF1' : lambda x0, y, n: model.run_GIRF(x0, y, n, guiding_linear),
    'GIRF2' : lambda x0, y, n: model.run_GIRF(x0, y, n, guiding_square),
}
results = {name : {'ess' : torch.zeros(len_num_obs, R), 'log_estimate' : torch.zeros(len_num_obs, R)}
           for name in filters}
# per-filter result dicts, kept available under their individual names
BPF = results['BPF']
APF = results['APF']
APFF = results['APFF']
GIRF1 = results['GIRF1']
GIRF2 = results['GIRF2']

for i, (K, N) in enumerate(zip(num_obs, num_particles)):
    # K observations, N particles

    # simulate latent process and observations
    X0 = torch.ones(1,d)
    X = torch.zeros(K+1, d)
    X[0,:] = X0.clone()
    Y = torch.zeros(K, p)
    for k in range(K):
        X[k+1,:] = model.simulate_diffusion(X[k,:].reshape((1,d)))
        Y[k,:] = X[k+1,:] + multiplier * torch.sqrt(var_obs) * torch.randn(1,p)

    for r in range(R):
        ess = {}
        log_estimate = {}
        for name, run_filter in filters.items():
            # fresh copy of the initial particles for every filter, in case a
            # filter modifies its input in place
            output = run_filter(X0.repeat((N,1)), Y, N)
            ess[name] = torch.mean(output['ess'] * 100 / N) # average ESS%
            log_estimate[name] = output['log_norm_const'][-1] # log-likelihood estimate
            results[name]['ess'][i,r] = ess[name]
            results[name]['log_estimate'][i,r] = log_estimate[name]

        # print output
        print('No. of observations: ' + str(K) + ' Repeat: ' + str(r))
        for name in filters:
            print(name + ' ESS%: ' + str(ess[name]))
        for name in filters:
            print(name + ' log-estimate: ' + str(log_estimate[name]))

# save results (keys: 'BPF', 'APF', 'APFF', 'GIRF1', 'GIRF2')
torch.save(results, filename)