In [1]:
import torch
import numpy as np

import pyro
import pyro.distributions as dist
import pyro.distributions.transforms as T
from torch.distributions import Transform, TransformedDistribution, MultivariateNormal, Normal
from torch.utils.data import DataLoader, TensorDataset, random_split

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.pyplot import figure
import scipy.io as sio
from __future__ import print_function
import sys
import re
import os
# Set KMP_DUPLICATE_LIB_OK=TRUE for this process.
# NOTE(review): presumably a workaround for the Intel OpenMP "duplicate runtime" abort when
# more than one library ships its own libiomp; torch is already imported above, so confirm
# the setting is still needed / still early enough.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

In [2]:
def standarized_IQR(data, device="cuda"):
    """Robustly standardise every column of a 2-D tensor by its median and IQR.

    Each column ``c`` is mapped to ``(c - median(c)) / IQR(c)`` with
    ``IQR = q75 - q25`` (``torch.quantile`` default linear interpolation).
    The median is ``torch.median``, i.e. the *lower* of the two middle values
    when the number of rows is even.

    Args:
        data: 2-D tensor, one row per sample and one column per feature.
        device: device the three returned tensors are moved to. Defaults to
            ``"cuda"`` (the notebook's original behaviour); pass ``"cpu"`` on a
            machine without a GPU.

    Returns:
        ``(standardized_data, median, IQR)``. ``median`` and ``IQR`` hold one
        entry per column and are what ``inv_standarized_IQR`` expects. A column
        whose IQR is 0 (e.g. a constant column) uses an IQR of 1, so it maps to
        ``column - median`` instead of NaN/inf.
    """
    if not data.is_floating_point():
        # torch.quantile only accepts floating-point input.
        data = data.to(torch.get_default_dtype())

    # Column-wise statistics in one call each (dim=0 reduces over rows).
    median = torch.median(data, dim=0).values
    q1 = torch.quantile(data, 0.25, dim=0)
    q3 = torch.quantile(data, 0.75, dim=0)
    IQR = q3 - q1
    # Guard against division by zero for columns with no spread.
    IQR = torch.where(IQR == 0, torch.ones_like(IQR), IQR)

    standardized_IQR_data = (data - median) / IQR
    return standardized_IQR_data.to(device), median.to(device), IQR.to(device)

def inv_standarized_IQR(data, medians, iqrs, device="cuda"):
    """Undo ``standarized_IQR``: map standardised columns back to the original scale.

    Computes ``data * iqrs + medians`` column-wise, with ``medians`` and
    ``iqrs`` (one entry per column) broadcast across the rows of ``data``.

    Args:
        data: 2-D tensor of standardised values, one column per feature.
        medians: per-column medians returned by ``standarized_IQR``.
        iqrs: per-column IQRs returned by ``standarized_IQR``.
        device: device the result is moved to. Defaults to ``"cuda"`` (the
            notebook's original behaviour); pass ``"cpu"`` without a GPU.

    Returns:
        Tensor with the same shape as ``data`` on the original scale.
    """
    data = torch.as_tensor(data)
    # Put the statistics next to the data so the broadcast works on any device.
    medians = torch.as_tensor(medians, device=data.device)
    iqrs = torch.as_tensor(iqrs, device=data.device)

    original_data = (data * iqrs) + medians
    return original_data.to(device)


def logit_transform(data, lower_bounds, upper_bounds, device="cuda"):
    """Map values lying inside ``(lower_bounds, upper_bounds)`` onto the real line.

    Element-wise ``log((x - lower) / (upper - x))``, with the per-column
    bounds broadcast across the rows of ``data``. Values exactly on a bound
    give +/-inf; values outside the bounds give NaN.

    Args:
        data: 2-D tensor, one row per sample, one column per variable.
        lower_bounds: per-column lower bounds.
        upper_bounds: per-column upper bounds.
        device: device the result is moved to. Defaults to ``"cuda"`` (the
            notebook's original behaviour); pass ``"cpu"`` without a GPU.

    Returns:
        Tensor with the same shape as ``data``.
    """
    data = torch.as_tensor(data)
    if not data.is_floating_point():
        data = data.to(torch.get_default_dtype())
    # Align the bounds with the data so one broadcast handles every row at once.
    lower = torch.as_tensor(lower_bounds, dtype=data.dtype, device=data.device)
    upper = torch.as_tensor(upper_bounds, dtype=data.dtype, device=data.device)

    trans_data = torch.log((data - lower) / (upper - data))
    return trans_data.to(device)

def inverse_logit_transform(data, lower_bounds, upper_bounds, device="cuda"):
    """Inverse of ``logit_transform``: map real values back into ``(lower, upper)``.

    Uses ``lower + (upper - lower) * sigmoid(y)``, which is the exact inverse of
    ``y = log((x - lower) / (upper - x))``. The former formula
    ``(exp(y) + lower) / (1 + exp(y))`` ignored ``upper_bounds``. It was only
    correct when every upper bound equals 1, and it gave NaN (inf/inf) once
    ``exp(y)`` overflowed. ``sigmoid`` saturates cleanly instead.

    Args:
        data: 2-D tensor of unconstrained values, one column per variable.
        lower_bounds: per-column lower bounds.
        upper_bounds: per-column upper bounds.
        device: device the result is moved to. Defaults to ``"cuda"`` (the
            notebook's original behaviour); pass ``"cpu"`` without a GPU.

    Returns:
        Tensor with the same shape as ``data``, each column in its bounds.
    """
    data = torch.as_tensor(data)
    if not data.is_floating_point():
        data = data.to(torch.get_default_dtype())
    lower = torch.as_tensor(lower_bounds, dtype=data.dtype, device=data.device)
    upper = torch.as_tensor(upper_bounds, dtype=data.dtype, device=data.device)

    trans_inv_data = lower + (upper - lower) * torch.sigmoid(data)
    return trans_inv_data.to(device)

In [None]:
n = 6  # number of parameters; theta and the flow have n+1 components

# Bounds for the logit transform: the first n components of theta lie in
# (-1, 1), the last one in (0, 1).
lower_bounds = torch.ones(n+1).cuda()
lower_bounds[0:n] = -1.0
lower_bounds[n] = 0.0
upper_bounds = torch.ones(n+1).cuda()

# Folder holding all .mat inputs (Google Drive mounted in Colab).
DATA_DIR = "/content/drive/MyDrive/Colab Notebooks/SVAR project"

# Observed dataset: provides the observation ('sy', used as x_0) and the 'true' theta.
full_dataset = sio.loadmat(os.path.join(DATA_DIR, "data_GVAR_dim6.mat"))
#simulated_dataset = sio.loadmat("/content/drive/MyDrive/Colab Notebooks/SVAR project/SNPE_prior_dimi6.mat") # dataset containing sims and their corresponding theta

# Training pairs from the SMC pilot run: load the file once, then read both arrays.
pilot_results = sio.loadmat(os.path.join(DATA_DIR, "results_summ_pilot_dim6.mat"))
theta_smc = pilot_results['part_vals_smc']
x_smc = pilot_results['part_sim_smc']

# SMC-NPE result
smc_npe_result = sio.loadmat(os.path.join(DATA_DIR, "SNPE_smc_dim6.mat"))['theta_smc_dim6_NPE']

theta_true = full_dataset['theta_true']
#theta = torch.from_numpy(simulated_dataset['part_vals']).to(torch.float32).cuda()
#x = torch.from_numpy(simulated_dataset['part_sim']).to(torch.float32).cuda()
theta = torch.from_numpy(theta_smc).to(torch.float32).cuda()
x = torch.from_numpy(x_smc).to(torch.float32).cuda()
x_0 = torch.from_numpy(full_dataset['sy']).to(torch.float32).cuda()

# NOTE(review): `model` (the simulator) is not defined or imported anywhere in this
# notebook; it must be provided before this cell runs.
generate_model = model(n+1, full_dataset, 1000)


def _has_invalid_summary(row):
    """Return True if a simulated summary row is unusable: any entry > 1e30, inf or NaN."""
    return bool(((row > 1e30) | ~torch.isfinite(row)).any())


# Replace unusable simulations: re-draw theta uniformly inside the bounds and re-simulate
# until the whole row is usable. Every column of x is checked (x need not have the same
# width as theta), and a freshly re-simulated row is checked again before moving on.
# NOTE(review): loops forever if the simulator keeps failing for this prior box.
for i in range(theta.shape[0]):
    while _has_invalid_summary(x[i, :]):
        theta[i, :] = dist.Uniform(lower_bounds, upper_bounds).sample().cuda()
        x[i, :] = generate_model.run_simulation(theta[i, :]).cuda()

print(torch.isinf(x).any())
print(torch.isnan(x).any())
assert torch.isfinite(x).all() and not (x > 1e30).any()

# Parameters: logit transform onto the real line, then median/IQR standardisation.
theta_logit = logit_transform(theta, lower_bounds, upper_bounds)
theta_trans, theta_median, theta_iqrs = standarized_IQR(theta_logit)

# Simulated data: median/IQR standardisation.
x_trans, x_median, x_iqrs = standarized_IQR(x)

# Observation: reuse the simulated data's median/IQR so it is on the same scale.
# Reshape to rows of x's width first; a column-vector 'sy' would otherwise
# broadcast against x_median into a d x d matrix.
y_trans = (x_0.reshape(-1, x_median.shape[0]) - x_median)/x_iqrs

# Sanity-check shapes and the transformed observation.
print(theta_logit.shape, theta_trans.shape, x_trans.shape, y_trans)

In [None]:
class NPE:
    """Neural posterior estimation with a conditional spline flow.

    Fits a conditional density q(theta | x) by maximising the log-likelihood of
    paired (x, theta) samples, using a 70/30 train/validation split.
    """

    def __init__(self, x_trans, theta_trans, n, steps, lr, batch_size, device=None):
        """Build the data loaders, the conditional flow and its optimiser.

        Args:
            x_trans: 2-D tensor of (transformed) simulated data, one row per
                sample; used as the flow's conditioning context.
            theta_trans: tensor of (transformed) parameters with n+1 columns,
                row-aligned with ``x_trans``.
            n: number of parameters; the flow is defined over n+1 dimensions.
            steps: number of training epochs (full passes over the training split).
            lr: Adam learning rate.
            batch_size: mini-batch size for the training and validation loaders.
            device: device for the base distribution and the flow parameters.
                Defaults to ``x_trans.device``.

        Raises:
            ValueError: if there are too few samples for non-empty train and
                validation splits.
        """
        if device is None:
            device = x_trans.device

        # Dataset preparation: 70% train / 30% validation
        self.full_dataset = TensorDataset(x_trans, theta_trans)
        n_total = len(self.full_dataset)
        n_train = int(n_total * 0.7)
        n_val = n_total - n_train
        if n_train == 0 or n_val == 0:
            raise ValueError(f"need at least 2 samples for a train/validation split, got {n_total}")
        self.train_dataset, self.val_dataset = random_split(self.full_dataset, [n_train, n_val])
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)

        # Flow: standard-normal base over n+1 dims pushed through a conditional spline.
        self.dist_base = Normal(torch.zeros(n+1, device=device), torch.ones(n+1, device=device))
        # The spline's hypernetwork is fed x, so context_dim is the width of x, not of theta.
        self.theta_transform = T.conditional_spline(n+1, context_dim=x_trans.shape[1])
        self.dist_theta_given_x = dist.ConditionalTransformedDistribution(self.dist_base, [self.theta_transform])
        self.modules = torch.nn.ModuleList([self.theta_transform]).to(device)
        self.optimizer = torch.optim.Adam(self.modules.parameters(), lr=lr)
        # Multiplies the LR by 0.9 every 1000 scheduler steps; train() steps it once per epoch,
        # so it only takes effect when steps > 1000.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1000, gamma=0.9)

        self.steps = steps
        # NOTE(review): ProgressBar is not defined or imported in this notebook.
        self.progress = ProgressBar(self.steps, fmt=ProgressBar.FULL)
        self.train_losses = []
        self.val_losses = []

    def train_one_step(self, x_batch, theta_batch):
        """Take one Adam step on a batch and return its loss as a float.

        The loss is the negative log-likelihood of ``theta_batch`` given
        ``x_batch``, summed over the batch.
        """
        self.optimizer.zero_grad()
        ln_p_theta_given_x = self.dist_theta_given_x.condition(x_batch.detach()).log_prob(theta_batch.detach())
        train_loss = -ln_p_theta_given_x.sum()
        train_loss.backward()
        self.optimizer.step()
        self.dist_theta_given_x.clear_cache()
        return train_loss.item()

    def validate(self):
        """Return the validation loss: summed NLL per batch, averaged over batches."""
        val_loss = 0.0
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for x_val_batch, theta_val_batch in self.val_loader:
                ln_p_theta_given_x_val = self.dist_theta_given_x.condition(x_val_batch).log_prob(theta_val_batch)
                val_loss += -ln_p_theta_given_x_val.sum().item()
        val_loss /= len(self.val_loader)
        return val_loss

    def train(self):
        """Train for ``self.steps`` epochs and return the fitted conditional distribution.

        After every epoch, the training loss is appended to ``self.train_losses``.
        It is averaged over that epoch's batches, the same convention as
        ``validate``. The validation loss is appended to ``self.val_losses``.
        """
        for step in range(self.steps):
            batch_losses = [self.train_one_step(x_batch, theta_batch)
                            for x_batch, theta_batch in self.train_loader]
            # Epoch mean instead of the last (possibly partial) batch, so it is
            # on the same scale as the validation loss.
            train_loss = sum(batch_losses) / len(batch_losses)
            self.train_losses.append(train_loss)

            val_loss = self.validate()
            self.val_losses.append(val_loss)

            self.scheduler.step()

            if step % 200 == 0:
                print(f'step: {step}, train_loss: {train_loss}, val_loss: {val_loss}')

            self.progress.current += 1
            self.progress()

        print('finish training')
        self.progress.done()

        # Return the final trained distribution
        return self.dist_theta_given_x

In [None]:
# First NPE round: fit q(theta | x) on the transformed SMC pilot samples.
# `steps` counts epochs, i.e. full passes over the training split.
NPE_iter1 = NPE(
    x_trans=x_trans,
    theta_trans=theta_trans,
    n=n,
    steps=801,
    lr=1e-4,
    batch_size=64,
)
NPE_model1 = NPE_iter1.train()  # the trained conditional distribution

In [None]:
# Learning curves: one point per epoch for both training and validation loss.
fig, ax = plt.subplots()
for series, name in ((NPE_iter1.train_losses, 'Train Loss'),
                     (NPE_iter1.val_losses, 'Validation Loss')):
    ax.plot(series, label=name)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# Draw 1000 samples from the fitted flow conditioned on the transformed observation,
# then undo the IQR standardisation and the logit transform to return to theta's scale.
theta_flow = NPE_model1.condition(y_trans).sample(torch.Size([1000,]))
posterior = inverse_logit_transform(inv_standarized_IQR(theta_flow, theta_median, theta_iqrs), lower_bounds, upper_bounds)

# loadmat returns 2-D arrays; flatten so entry k is a scalar whether theta_true
# was stored as a row or as a column vector.
theta_true_flat = np.ravel(theta_true)

# One panel per component of theta (n+1 of them), four panels per row.
n_panels = n + 1
n_cols = 4
n_rows = -(-n_panels // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)
for k, ax in enumerate(axes.flat):
    if k >= n_panels:
        ax.remove()  # unused grid cell
        continue
    # NOTE(review): `posterior_prior` is not defined in this notebook; judging by its
    # legend label it is the uniform-prior NPE posterior and must be created/loaded
    # before this cell, otherwise Restart & Run All fails here.
    sns.kdeplot(data=posterior_prior[:, k].cpu(), ax=ax, label='Uniform prior NPE')
    sns.kdeplot(data=posterior[:, k].cpu(), ax=ax, label='SMC-prior NPE')
    #sns.kdeplot(data=bsl_pos[:, k], ax=ax, label='bsl')
    # Each artist carries its own label. A positional list with four labels for three
    # artists would tag this true-value line as 'bsl'.
    ax.axvline(theta_true_flat[k], label='True')
    ax.legend()
plt.show()