In [1]:
import sys
sys.path.append('../')
from sbi.inference.snle.snle_a import SNLE_A
from sbi.inference.base import *
import sbi.utils as utils
from sbi.utils.get_nn_models import *
from sbi.utils.sbiutils import *
from sbi.utils.torchutils import *
from sbi.inference.base import infer
import torch
import numpy as np
import tqdm
import matplotlib.pyplot as plt
from torch import nn
import seaborn as sns
import random
import pickle
from pyknos.nflows import flows
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

from simulators.ricker import ricker
from utils.metrics import RMSE
from utils.plot_config import update_plot_style
import io
import time
from torch.utils.data import DataLoader

# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cpu


In [2]:
def temporalMomentsGeneral(Y, K=3, B=4e9):
    """Log raw temporal moments of every row of ``Y``.

    Each row is treated as a signal sampled on an evenly spaced grid ``tau``
    spanning [0, 100]. The k-th moment is the trapezoidal integral of
    ``tau**k * Y[i]`` for k = 0, ..., K-1.

    Args:
        Y: 2-D array-like of shape (N, Ns). A torch tensor (any device, with or
            without grad) or anything ``np.asarray`` accepts.
        K: number of moments to compute (orders 0 .. K-1).
        B: unused; kept only so existing call sites keep working.

    Returns:
        ndarray of shape (N, K) holding ``log(moment + 1e-4)``. The 1e-4 offset
        guards against log(0); a moment below -1e-4 still yields NaN.
    """
    if isinstance(Y, torch.Tensor):
        # .cpu() is required before .numpy() when Y lives on a CUDA device.
        Y = Y.detach().cpu().numpy()
    Y = np.asarray(Y)
    N, Ns = Y.shape
    tau = np.linspace(0, 100, Ns)
    # np.trapz was renamed np.trapezoid in NumPy 2.0; support both.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    out = np.empty((N, K))
    for k in range(K):
        # Integrate all rows at once along the time axis.
        out[:, k] = trapezoid(tau**k * Y, tau, axis=-1) + 1e-4
    return np.log(out)

In [3]:
# Size of the cached training set, and the value passed as ricker(N=...) which
# also forms the second axis of the reshaped training tensor `x`.
num_simulations, N = 1000, 100

In [4]:
# Load the cached observation and the pre-simulated training set from ../data.
SERIES_LEN = 100  # time steps per simulated Ricker series
obs_cont = torch.tensor(np.load("../data/ricker_obs_1.npy")).to(device)

theta = torch.tensor(np.load("../data/ricker_theta_1000.npy")).to(device)
# Training data reshaped to (num_simulations, N, SERIES_LEN).
x = torch.tensor(np.load("../data/ricker_x_1000.npy")).reshape(num_simulations, N, SERIES_LEN).to(device)

# Each item is one simulation block of shape (N, SERIES_LEN).
# NOTE(review): shuffling is unseeded, so epoch losses are not reproducible.
dataloader = DataLoader(x, batch_size=200, shuffle=True)

In [6]:
# Independent uniform box priors over (theta_1, theta_2). `prior` bounds
# theta_2 to [0, 20]; `prior_new` widens that upper bound to 80.
prior = [Uniform(lo * torch.ones(1).to(device), hi * torch.ones(1).to(device))
         for lo, hi in ((2, 8), (0, 20))]
prior_new = [Uniform(lo * torch.ones(1).to(device), hi * torch.ones(1).to(device))
             for lo, hi in ((2, 8), (0, 80))]

# Each call wraps a fresh simulator and processes one prior; the simulator from
# the second call replaces the first.
simulator, prior = prepare_for_sbi(ricker(N=N), prior)
simulator, prior_new = prepare_for_sbi(ricker(N=N), prior_new)

In [7]:
class RickerSummary(nn.Module):
    """1-D convolutional autoencoder applied to each Ricker series independently.

    An input sample is a block of ``N_SERIES`` series of length ``SERIES_LEN``.
    Each series is encoded to a ``hidden_dim``-vector by three stride-4
    convolutions (length 100 -> 25 -> 6 -> 1). It is decoded with three
    transposed convolutions (1 -> 3 -> 11 -> 43) followed by an upsample
    back to ``SERIES_LEN``.

    NOTE(review): there are no activations between layers, so encoder and
    decoder are each affine maps; left as-is to keep the trained model intact.
    """

    SERIES_LEN = 100  # time steps per series
    N_SERIES = 100    # series per sample

    def __init__(self, input_size, hidden_dim):
        """
        Args:
            input_size: channels of the first conv / last transposed conv.
                forward() reshapes inputs to a single channel, so only 1 works.
            hidden_dim: embedding channels per series. This argument used to be
                ignored, with the width hard-coded to 4.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_size = input_size

        self.encoder = nn.Sequential(
            nn.Conv1d(input_size, hidden_dim, 3, 4),
            nn.Conv1d(hidden_dim, hidden_dim, 3, 4),
            nn.Conv1d(hidden_dim, hidden_dim, 3, 4),
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(hidden_dim, hidden_dim, 3, 4),
            nn.ConvTranspose1d(hidden_dim, hidden_dim, 3, 4),
            nn.ConvTranspose1d(hidden_dim, input_size, 3, 4),
            nn.Upsample(self.SERIES_LEN),
        )

    def forward(self, Y):
        """Reconstruct ``Y``; returns shape (-1, N_SERIES, SERIES_LEN)."""
        embeddings = self.encoder(Y.reshape(-1, 1, self.SERIES_LEN))
        reconstruction = self.decoder(embeddings.reshape(-1, self.hidden_dim, 1))
        return reconstruction.reshape(-1, self.N_SERIES, self.SERIES_LEN)

    def forward_encoder(self, Y):
        """Per-series embeddings; returns shape (-1, N_SERIES, hidden_dim)."""
        embeddings = self.encoder(Y.reshape(-1, 1, self.SERIES_LEN))
        return embeddings.reshape(-1, self.N_SERIES, self.hidden_dim)

In [8]:
def solve_normal(num_epochs=10, lr=0.01, loader=None):
    """Train a RickerSummary autoencoder with a plain reconstruction loss.

    Args:
        num_epochs: number of passes over ``loader``.
        lr: Adam learning rate.
        loader: sized iterable of input batches; defaults to the notebook-level
            ``dataloader``.

    Returns:
        The trained RickerSummary, on ``device``.
    """
    if loader is None:
        loader = dataloader
    summary_net_normal = RickerSummary(1, 4).to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(summary_net_normal.parameters(), lr=lr)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs in loader:
            optimizer.zero_grad()
            outputs = summary_net_normal(inputs)
            # Same /1e4 rescaling as the reconstruction term in solve_robust.
            loss = criterion(outputs, inputs) / 10000
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(loader)}")
    return summary_net_normal

In [9]:
def solve_robust(beta, obs_cont, num_epochs=10, lr=0.01, n_context=200):
    """Train a RickerSummary autoencoder with an MMD penalty towards ``obs_cont``.

    Per-batch loss = reconstruction MSE / 1e4 + beta * MMD(context, observed).
    Both embedding sets are per-sample means of ``forward_encoder`` outputs.

    Args:
        beta: weight of the MMD term.
        obs_cont: observed data, passed through ``forward_encoder`` every step.
        num_epochs: number of passes over the notebook-level ``dataloader``.
        lr: Adam learning rate.
        n_context: number of simulations drawn without replacement from the
            notebook-level ``x`` each step to form the MMD context set.

    Returns:
        The trained RickerSummary, on ``device``.
    """
    summary_net_robust = RickerSummary(1, 4).to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(summary_net_robust.parameters(), lr=lr)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs in dataloader:
            optimizer.zero_grad()

            outputs = summary_net_robust(inputs)

            # Random subset of the simulations, drawn on x's own device.
            context_idx = torch.randperm(len(x), device=x.device)[:n_context]
            context_embeddings = torch.mean(summary_net_robust.forward_encoder(x[context_idx]), dim=1)
            # Recomputed each step because the encoder weights keep changing.
            obs_embeddings = torch.mean(summary_net_robust.forward_encoder(obs_cont), dim=1)

            ae_loss = criterion(outputs, inputs) / 10000
            # NOTE(review): `metrics` is not among the explicit imports (only
            # RMSE is imported from utils.metrics); presumably it arrives via a
            # star import -- confirm.
            summary_loss = metrics.MMD_unweighted(
                context_embeddings, obs_embeddings,
                lengthscale=metrics.median_heuristic(context_embeddings))

            loss = ae_loss + beta * summary_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}")
    return summary_net_robust

In [12]:
summary_net_normal = solve_normal()

# Per-sample summaries: sum of per-series embeddings over the series axis.
# no_grad because these are fixed inputs to SNLE and need no autograd graph.
with torch.no_grad():
    x_summary = torch.sum(summary_net_normal.forward_encoder(x), dim=1).cpu().numpy()
    obs_summary = torch.sum(summary_net_normal.forward_encoder(obs_cont), dim=1).cpu()

# SNLE_A is created with device='cpu', so hand it CPU tensors.
inference_normal = SNLE_A(prior=prior, device='cpu')
# .train() returns the trained estimator; append_simulations returns the inference object.
density_estimator_normal = inference_normal.append_simulations(
    theta=theta.cpu(), x=torch.tensor(x_summary)).train()

# NOTE(review): built with `prior_new` (theta_2 up to 80), whereas
# posterior_robust uses the training prior (theta_2 up to 20); the two
# posteriors are therefore not compared under the same prior -- confirm intent.
posterior_normal = inference_normal.build_posterior(prior=prior_new)

Epoch 1, Loss: 18.409171104431152
Epoch 2, Loss: 18.078046607971192
Epoch 3, Loss: 16.805509185791017
Epoch 4, Loss: 13.82416648864746
Epoch 5, Loss: 13.2974702835083
Epoch 6, Loss: 12.692077159881592
Epoch 7, Loss: 12.869472885131836
Epoch 8, Loss: 12.642597198486328
Epoch 9, Loss: 12.64726676940918
Epoch 10, Loss: 12.586797142028809


In [17]:
# beta = 3 weights the MMD robustness term in solve_robust.
summary_net_ours = solve_robust(3, obs_cont)

# Per-sample summaries: sum of per-series embeddings over the series axis.
# no_grad because these are fixed inputs to SNLE and need no autograd graph.
with torch.no_grad():
    x_summary_robust = torch.sum(summary_net_ours.forward_encoder(x), dim=1).cpu().numpy()
    obs_summary_robust = torch.sum(summary_net_ours.forward_encoder(obs_cont), dim=1).cpu()

# SNLE_A is created with device='cpu', so hand it CPU tensors.
inference_robust = SNLE_A(prior=prior, device='cpu')
# .train() returns the trained estimator; append_simulations returns the inference object.
density_estimator_robust = inference_robust.append_simulations(
    theta=theta.cpu(), x=torch.tensor(x_summary_robust)).train()

# NOTE(review): uses the training prior, unlike posterior_normal (prior_new).
posterior_robust = inference_robust.build_posterior()

Epoch 1, Loss: 18.804403686523436
Epoch 2, Loss: 17.69242706298828
Epoch 3, Loss: 14.976074600219727
Epoch 4, Loss: 13.736443901062012
Epoch 5, Loss: 13.227659797668457
Epoch 6, Loss: 13.415997886657715
Epoch 7, Loss: 13.16824893951416
Epoch 8, Loss: 13.171481895446778
Epoch 9, Loss: 13.08675537109375
Epoch 10, Loss: 13.099022483825683
 Neural network successfully converged after 413 epochs.

In [35]:
# Draw 100 posterior samples from each posterior. Each one is conditioned on the
# observed summary produced by its own summary network.
# NOTE(review): both calls presumably share the global RNG, so swapping their
# order would change the draws.
robust_samples = posterior_robust.sample([100], obs_summary_robust)
normal_samples = posterior_normal.sample([100], obs_summary)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1100 [00:00<?, ?it/s]

In [36]:
# Parameter values treated as the truth for the observation (labelled
# "True theta" in the plot below).
theta_gt = torch.tensor([4, 10])
# Root-mean-squared error of each posterior's samples to theta_gt.
print(f"RMSE normal {float(RMSE(theta_gt, normal_samples, p=2))}")
print(f"RMSE robust {float(RMSE(theta_gt, robust_samples, p=2))}")

RMSE normal 9.191390037536621
RMSE robust 3.748779535293579


In [37]:
# Posterior-predictive check: simulate from each posterior draw (first row of a
# ricker(N=1) output) and summarise with log temporal moments. The result is
# compared to the observed series' statistics via unweighted MMD with
# lengthscale 1.
ricker_simulator = ricker(N=1)
# .cpu(): temporalMomentsGeneral calls .numpy(), which fails on CUDA tensors.
obs_stat = torch.tensor(temporalMomentsGeneral(obs_cont.reshape(100, 100).cpu()))

n_pred = min(len(normal_samples), len(robust_samples))
predictive_data_normal = torch.zeros(n_pred, 100)
predictive_data_robust = torch.zeros(n_pred, 100)
# Normal/robust simulations stay interleaved per draw, as before, so a
# stochastic simulator consumes its random draws in the same order.
for j in range(n_pred):
    predictive_data_normal[j] = ricker_simulator(normal_samples[j])[0]
    predictive_data_robust[j] = ricker_simulator(robust_samples[j])[0]

pred_stat_normal = torch.tensor(temporalMomentsGeneral(predictive_data_normal))
pred_stat_robust = torch.tensor(temporalMomentsGeneral(predictive_data_robust))

mmd_normal = float(metrics.MMD_unweighted(pred_stat_normal, obs_stat, lengthscale=1))
mmd_robust = float(metrics.MMD_unweighted(pred_stat_robust, obs_stat, lengthscale=1))

print("mmd normal", mmd_normal)
print("mmd robust", mmd_robust)

mmd normal 0.8686911479283871
mmd robust 0.1915174432457829


In [None]:
# Joint and marginal KDEs of both posteriors (NLE in orange, ours in blue), with
# the true parameter and the training-prior box overlaid.
# Plain CPU arrays / floats so seaborn and matplotlib never receive CUDA tensors.
normal_np = normal_samples.detach().cpu().numpy()
robust_np = robust_samples.detach().cpu().numpy()
theta_true = [float(v) for v in theta_gt]

# `fill=` replaces the `shade=` keyword deprecated in seaborn 0.11.
graph = sns.jointplot(x=normal_np[:, 0], y=normal_np[:, 1],
                      cmap="Oranges", kind="kde", height=4,
                      marginal_kws={"color": "C1", "alpha": .5, "fill": True},
                      fill=True, thresh=0.05, alpha=.5,
                      label='NLE')

# Re-point the grid at the robust samples so the plot_joint / plot_marginals
# calls below draw them onto the same axes.
graph.x = robust_np[:, 0]
graph.y = robust_np[:, 1]
graph.plot_joint(sns.kdeplot, cmap="Blues", fill=True, alpha=.5, label='Ours')
graph.ax_joint.axvline(x=theta_true[0], lw=1, ls="-", c="black", label="True $\\theta$")
graph.ax_joint.axhline(y=theta_true[1], lw=1, ls="-", c="black")

# Bounds of the training prior: theta_1 in [2, 8], theta_2 in [0, 20].
for x_bound in (2, 8):
    graph.ax_joint.axvline(x=x_bound, ls="--", lw=1, c="gray", alpha=0.3)
for y_bound in (0, 20):
    graph.ax_joint.axhline(y=y_bound, ls="--", lw=1, c="gray", alpha=0.3)

legend_elements = [Line2D([0], [0], color='k', lw=1, label='True $\\theta$'),
                   Patch(facecolor='C0', edgecolor='C0', label='Ours'),
                   Patch(facecolor='C1', edgecolor='C1', label='NLE')]
graph.ax_joint.legend(handles=legend_elements, loc='upper right', fontsize=10)

graph.ax_joint.set_xlabel('$\\theta_1$')
graph.ax_joint.set_ylabel('$\\theta_2$')

graph.plot_marginals(sns.kdeplot, color='C0', fill=True, alpha=.5, legend=False)
plt.tight_layout()
plt.savefig("NLE_posterior.pdf", dpi=300)
