# (Amortized) distance learning directly using a feedforward network
- define prior $p(\theta)$ and sample
- produce simulations $x \sim p(x|\theta)$
- append observations $x_o$
- compute distances $d(\theta_i, x_j) := d(x(\theta_i), x_j)$ for every $x_j \in X_{\text{simulated}} \cup X_{\text{observed}}$
- train network $NN_{\phi}(\theta_i, x_j) \rightarrow d(x(\theta_i),x_j)$
- at inference time, define the generalized likelihood $e^{-\beta \, NN_{\phi}(\theta, x_o)}$ and sample from the GBI posterior $p(\theta \mid x_o) \propto p(\theta)\, e^{-\beta \, NN_{\phi}(\theta, x_o)}$

In [1]:
%load_ext autoreload
%autoreload 2

In [51]:
import matplotlib.pyplot as plt
import torch
from torch import zeros, ones, nn, Tensor

from sbi.utils import BoxUniform, likelihood_nn
from sbi.utils.gbi import build_generalized_log_likelihood, GBIPotential, mse_dist
# from sbi.inference import SNLE, RejectionPosterior, likelihood_estimator_based_potential

In [5]:
# uniform box prior over a 2-D theta, with bounds -1 and 1 in each dimension
prior = BoxUniform(-ones(2), ones(2))
def simulator(theta):
    """Simulate one noisy observation per parameter set.

    Squares ``theta`` element-wise and adds independent Gaussian noise with
    standard deviation 0.1. The result has the same shape as ``theta``.
    """
    noise = 0.1 * torch.randn(theta.shape)
    return theta.pow(2) + noise

In [96]:
# fix the RNG seed so the draws below are reproducible; binding the return
# value to `_` keeps the notebook from displaying it
_ = torch.manual_seed(0)

# simulate: draw 500 parameter sets from the prior and run each through the simulator
theta = prior.sample((500,))
x = simulator(theta)

# make "observations": 20 further prior draws, simulated the same way
theta_obs = prior.sample((20,))
x_obs = simulator(theta_obs)

# make some misspecified examples: perturb the last 10 observations in place
# with standard-normal noise (the simulator itself only adds noise scaled by 0.1)
x_obs[10:] -= torch.randn(x_obs[10:].shape)
# stack along dim 0: the 500 simulations first, then the 20 observations
xs = torch.concat((x, x_obs), 0)

# compute distances: one mse_dist call per row x_i of xs, each against all
# simulated x (given a singleton middle axis); results are stacked, then transposed
# NOTE(review): presumably dists ends up (500, 520), i.e. (simulation, target)
# — confirm against mse_dist's return shape
dists = torch.vstack([mse_dist(x.unsqueeze(1), x_i) for x_i in xs]).T

In [98]:
# number of rows in xs (simulated and observed targets together)
xs.shape[0]

520

In [103]:
# shape scratchpad: the notebook only displays the last expression of the cell,
# so the earlier lines are evaluated and discarded
dists.shape, theta.shape, xs.shape

# dists are in chunks of 520 i.e., the second dim
dists.reshape((-1,)).shape, theta.repeat((xs.shape[0],1)).shape

theta.shape

# tile theta along a new trailing axis (one copy per row of xs) and xs along a
# new trailing axis (one copy per row of theta); this is the displayed result
theta.unsqueeze(2).repeat((1,1, xs.shape[0])).shape, xs.unsqueeze(2).repeat((1,1, theta.shape[0])).shape


# plt.imshow(torch.log(dists).numpy())

(torch.Size([500, 2, 520]), torch.Size([520, 2, 500]))

In [None]:
# construct dataset where X = [theta, xs], Y = dist


In [52]:
# define the network that regresses the distance from (theta, x)
class DistanceRegressionEstimator(nn.Module):
    """Feedforward network that regresses a scalar distance from (theta, x).

    The network is an MLP with ReLU activations: an input layer mapping
    ``theta_dim + x_dim`` features to ``hidden_features``, followed by
    ``num_layers - 1`` hidden-to-hidden layers and a final linear layer with
    a single output.
    """

    def __init__(
        self, theta_dim: int, x_dim: int, hidden_features: int, num_layers: int
    ) -> None:
        """Build the MLP.

        Args:
            theta_dim: Size of the last dimension of ``theta``.
            x_dim: Size of the last dimension of the observation.
            hidden_features: Width of every hidden layer.
            num_layers: Number of hidden layers. Values below 1 still yield
                one hidden layer, because the input layer is always created.
        """
        super().__init__()

        input_dim = theta_dim + x_dim
        output_dim = 1

        layers = [nn.Linear(input_dim, hidden_features), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.extend([nn.Linear(hidden_features, hidden_features), nn.ReLU()])
        layers.append(nn.Linear(hidden_features, output_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, theta: Tensor, x_o_batched: Tensor) -> Tensor:
        """Predict the distance between each ``theta`` and the observation.

        Args:
            theta: Parameters of shape ``(*batch, theta_dim)``.
            x_o_batched: Observations of shape ``(*batch, x_dim)``. A single
                observation of shape ``(x_dim,)`` or ``(1, x_dim)`` is also
                accepted and is expanded to the batch shape of ``theta``.

        Returns:
            Predicted distances of shape ``(*batch, 1)``.

        Raises:
            RuntimeError: If ``x_o_batched`` cannot be expanded to the batch
                shape of ``theta``.
        """
        batch_shape = theta.shape[:-1]
        if x_o_batched.shape[:-1] != batch_shape:
            # Expand (a view, no copy) so callers need not repeat a single
            # observation themselves; incompatible shapes still raise.
            x_o_batched = x_o_batched.expand(*batch_shape, x_o_batched.shape[-1])
        return self.net(torch.cat((theta, x_o_batched), dim=-1))

In [39]:
# compare the shape of x with a singleton middle axis against the stacked xs
# NOTE(review): the recorded output below reports 1000/1020 rows, which does not
# match the 500/520 produced by the simulation cell — presumably stale; re-run to confirm
x.unsqueeze(1).shape, xs.shape

(torch.Size([1000, 1, 2]), torch.Size([1020, 2]))

In [35]:
# broadcasted difference between every simulated x and the first row of xs;
# the first row of the recorded output is zero, consistent with xs[0] being x[0]
x.unsqueeze(0) - xs[0]

tensor([[[ 0.0000,  0.0000],
         [ 0.7577,  0.1028],
         [ 0.2724, -0.3911],
         ...,
         [ 0.2755, -0.0628],
         [-0.0267, -0.3459],
         [ 0.1376,  0.1303]]])