In [1]:
from typing import List

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Linear, Sequential, Module


class FeedForwardNetwork(Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        architecture: List[int] = [],
        activation_function: str = "ReLU",
        device: str = "cpu",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self._input_dim = input_dim
        self._output_dim = output_dim

        self._activation_function_type = getattr(nn, activation_function)
        self._linear = self._create_linear_unit(architecture).to(device)

    def _create_linear_unit(self, architecture: List[int]) -> Sequential:
        """creates a linear unit specified with architecture and self._activation_function_type

        Args:
            architecture (List[int]): dimension of linear layers

        Returns:
            Sequential: sequential linear unit
        """
        # input layer
        if len(architecture) == 0:
            return Linear(self._input_dim, self._output_dim)

        layers = [
            Linear(self._input_dim, int(architecture[0])),
            self._activation_function_type(),
        ]
        # add hidden layers
        for idx in range(len(architecture) - 1):
            layers.extend(
                [
                    Linear(int(architecture[idx]), int(architecture[idx + 1])),
                    self._activation_function_type(),
                ]
            )
        # output layer
        layers.append(Linear(architecture[-1], self._output_dim))
        sequence = Sequential(*layers)
        return sequence

    def forward(self, x: Tensor):
        return self._linear(x)
    


class SBINetwork(Module):
    def __init__(self, theta_dim: int, simulator_out_dim: int, latent_dim: int = 256, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._theta_encoder = FeedForwardNetwork(theta_dim, latent_dim, [256])
        self._simulator_out_encoder = FeedForwardNetwork(simulator_out_dim, latent_dim, [256])
        self._collector = FeedForwardNetwork(2 * latent_dim, 1, [256, 256, 128])

    def forward(self, theta: Tensor, x_target: Tensor) -> Tensor:
        """_summary_

        Args:
            theta (Tensor): (batch_size, theta_dim)
            x_target (Tensor): (batch_size, n_target, simulator_dim)

        Returns:
            Tensor: (batch_size, n_target, 1)
        """
        theta_enc = self._theta_encoder.forward(theta)
        simulator_out_enc = self._simulator_out_encoder.forward(x_target)
        # repeat the theta  encoding along the n_target dimension
        theta_repeat_dim = (1, simulator_out_enc.shape[1], 1)
        theta_enc = theta_enc[:, None].repeat(theta_repeat_dim)
        
        res = self._collector(torch.cat([theta_enc, simulator_out_enc], dim=-1))
        return res
    

batch_size = 32
theta_dim = 5
sim_out_dim = 2
n_target = 7

theta = torch.rand(batch_size, theta_dim)
x = torch.rand(batch_size, sim_out_dim)
x_target = torch.rand(batch_size, n_target, sim_out_dim)


net = SBINetwork(theta_dim, sim_out_dim)
net.forward(theta, x_target).shape

torch.Size([32, 7, 1])

In [2]:
from typing import Tuple
from lightning import LightningModule
import torch


class SBICriterion(nn.Module):
    def __init__(
        self,
        distance_order: int = 2.0,
    ):
        self._distance_order = distance_order

    def forward(self, pred: Tensor, x: Tensor, x_target: Tensor) -> Tensor:
        """_summary_

        Args:
            pred (Tensor): (batch_size, 1)
            x (Tensor): (batch_size, n_sim_features)
            x_target (Tensor): (batch_size, n_target, n_sim_features)

        Returns:
            Tensor: loss
        """
        # distance matrix
        d = self.sample_distance(x, x_target)
        squared_distance = torch.float_power(pred[..., None] - d, 2)
        squared_distance = torch.sum(squared_distance, dim=-1)
        return torch.mean(squared_distance)

    def sample_distance(self, x: Tensor, x_target: Tensor) -> Tensor:
        """compute L2 distance

        Args:
            x (Tensor): (batch_size, n_sim_features)
            x_target (Tensor): (batch_size, n_target, n_sim_features)

        Returns:
            Tensor: (batch_size, n_target)
        """
        d = x[:, None] - x_target
        distance = torch.linalg.norm(d, ord=self._distance_order, dim=-1)
        return distance


class SBI(LightningModule):
    def __init__(self, prior_dim: int, simulator_out_dim: int, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.net = SBINetwork(
            theta_dim=prior_dim, simulator_out_dim=simulator_out_dim, latent_dim=256
        )
        self.criterion = SBICriterion(distance_order=2)


    def training_step(self, batch: Tuple[Tensor, Tensor, Tensor], batch_idx: int):
        prior, simulator_out, x_target = batch
        network_res = self.forward(prior, x_target)
        loss = self.criterion.forward(network_res, simulator_out, x_target)
        return loss

    def forward(self, prior: Tensor, x_target: Tensor) -> Tensor:
        """_summary_

        Args:
            prior (Tensor): (batch_size, n_prior_features)
            x_target (Tensor): (batch_size, n_target, n_sim_features)

        Returns:
            Tensor: (batch_size, n_target)
        """
        return self.net.forward(prior, x_target)
    
pdist = nn.PairwiseDistance(p=2)
input1 = torch.randn(batch_size, 1, 128)
input2 = torch.randn(batch_size, n_target, 128)
output = pdist(input1, input2)
print(output.shape)

output.shape, net.forward(theta, x_target).shape


criterion = SBICriterion()
criterion.forward(net.forward(theta, x_target), x,x_target)

torch.Size([32, 7])


tensor(2.2185, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [2]:
from gbi_diff.dataset import SBIDataset

dataset = SBIDataset.from_file("data/moon_100000.pt")



{'_theta': tensor([[ 0.8570, -0.7816],
        [-0.9455,  0.6795],
        [-0.5851,  0.7095],
        ...,
        [ 0.5830, -0.5815],
        [-0.1822, -0.6158],
        [ 0.4536,  0.3972]]), '_x': tensor([[ 0.3070, -1.1536],
        [ 0.1359,  1.2175],
        [ 0.2363,  0.8836],
        ...,
        [ 0.2641, -0.9388],
        [-0.2093, -0.2905],
        [-0.3317,  0.0571]]), '_target': tensor([[ 0.2965, -1.1465],
        [ 0.1270,  1.2170],
        [ 0.2245,  0.8811],
        ...,
        [ 0.2732, -0.9494],
        [-0.1932, -0.2865],
        [-0.3353,  0.0664]]), '_measured': None, '_target_noise_std': 0.01}


<gbi_diff.dataset.dataset.SBIDataset at 0x7e7497e2f290>

In [5]:
pdist = nn.PairwiseDistance(p=2)
input1 = torch.randn(batch_size, 1, 128)
input2 = torch.randn(batch_size, n_target, 128)
output = pdist(input1, input2)
print(output.shape)

output.shape, net.forward(theta, x_target).shape


criterion = SBICriterion()
criterion.forward(net.forward(theta, x_target), x,x_target)

torch.Size([32, 7])


tensor(2.0171, dtype=torch.float64, grad_fn=<MeanBackward0>)