In [None]:
import torch as th 
from copy import deepcopy 
from sbi.inference import NLE 
from asbi.tasks import get_task

In [None]:
# Seed torch's global RNG so the prior draws, simulations and network
# initialisation are reproducible under Restart & Run All.
# NOTE(review): assumes the asbi prior/simulator draw from torch's RNG — confirm.
SEED = 0
th.manual_seed(SEED)

task = get_task("two_moons")
prior = task.get_prior_dist()
simulator = task.get_simulator()

n_sims = 1000
n_ensemble_members = 3

# Every member is trained on the SAME simulation budget; diversity between
# members comes only from the randomness of network initialisation/training.
ensemble = [NLE(prior, density_estimator='maf') for _ in range(n_ensemble_members)]
theta = prior((n_sims,))
x = simulator(theta)

flows = []
for member_idx, inference in enumerate(ensemble):
    # `train()` returns the trained density estimator; keep an independent
    # copy so later calls on `inference` cannot mutate the ensemble member.
    density_estimator = inference.append_simulations(theta, x).train()
    flows.append(deepcopy(density_estimator))
    print(f'member {member_idx + 1}/{n_ensemble_members}: training done')

In [None]:
# Reference ensemble implementation from the asbi package. It is imported under
# an alias because this notebook later defines its own `EnsembleFlow` class;
# sharing the name would make whichever cell ran last silently win.
from asbi.algorithms.EnsembleFlow import EnsembleFlow as AsbiEnsembleFlow

ensemble_flow = AsbiEnsembleFlow(flows)

In [None]:
# Inspect what a single prior draw looks like (a batch of one).
single_draw_shape = (1,)
prior(single_draw_shape)

In [None]:
# Condition on one prior draw. Unpacking the size-1 leading axis yields the
# same element as indexing `[0]` (assumes `prior(shape)` returns a tensor whose
# leading axis is `shape` — TODO confirm).
(t,) = prior((1,))

n_demo_samples = 33
samples = ensemble_flow.sample(n_demo_samples, t)

In [None]:
# Ensemble log-probabilities of `samples`, evaluated at the same condition `t` they were drawn with.
log_probs = ensemble_flow.log_prob(samples, t)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Log-probabilities at or below this floor are treated as outliers and dropped.
LOG_PROB_FLOOR = -10

# Plot the log-probabilities themselves so the data matches the
# "Log Probability" axis label. `.detach()` guards against tensors that carry
# autograd history (which cannot be converted with `.numpy()`).
log_probs_np = log_probs.detach().cpu().numpy()
kept_log_probs = log_probs_np[log_probs_np > LOG_PROB_FLOOR]

fig, ax = plt.subplots(figsize=(10, 6), dpi=400)
ax.hist(kept_log_probs, bins=30)
ax.set_xlabel('Log Probability')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of Log Probabilities')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Density-scale view of the same draws: drop log-probabilities at or below -10,
# exponentiate the rest, and label the axes as densities (not log-probabilities).
log_probs_np = log_probs.detach().cpu().numpy()
densities = np.exp(log_probs_np[log_probs_np > -10])

fig, ax = plt.subplots()
ax.hist(densities, bins=30)
ax.set_xlabel('Probability Density')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of Probability Densities')
plt.show()

In [None]:
# Marginal-entropy estimate of the asbi ensemble at condition `t`.
# NOTE(review): `N` is presumably the Monte Carlo sample count — confirm in asbi.
ensemble_flow.compute_marginal_entropy(t, N=100000)

In [None]:
# One fresh prior draw to condition on. NOTE: this rebinds `theta`, which
# previously held the 1000 training parameters.
theta = prior((1,))

# Ten draws from every ensemble member, stacked along the leading sample axis.
samples = [flow.sample((10,), theta) for flow in flows]

th.cat(samples, dim=0).shape

In [None]:
class EnsembleFlow:
    """Equal-weight mixture of conditional flows (e.g. an ensemble of trained NLE networks).

    Every member must provide ``sample(sample_shape, condition)`` and
    ``log_prob(x, condition)``. Their outputs are assumed to carry the sample
    axis first, so they can be concatenated/stacked along dim 0
    (TODO confirm for the sbi estimator version in use).
    """

    def __init__(self, flows) -> None:
        """Store the member flows.

        Args:
            flows: iterable of conditional flows; materialised into a list.

        Raises:
            ValueError: if ``flows`` is empty (a mixture needs a component).
        """
        flows = list(flows)
        if not flows:
            raise ValueError("EnsembleFlow needs at least one flow")
        self.n_flows = len(flows)
        self.flows = flows

    def log_prob(self, x, condition):
        """Log-density of ``x`` under the equal-weight mixture of the members.

        Computes ``log((1 / K) * sum_k p_k(x | condition))`` stably with
        logsumexp, i.e. the density that :meth:`sample` draws from. (Averaging
        the members' log-probabilities would instead give the log of their
        unnormalised geometric mean, not the mixture.)

        Returns:
            Tensor shaped like a single member's ``log_prob`` output.
        """
        import math

        with th.no_grad():
            per_flow = th.stack(
                [flow.log_prob(x, condition) for flow in self.flows], dim=0
            )
            return th.logsumexp(per_flow, dim=0) - math.log(self.n_flows)

    def sample(self, n_samples, condition):
        """Draw ``n_samples`` from the equal-weight mixture of the members.

        Sampling is stratified: each flow contributes ``n_samples // K`` draws
        and the leftover ``n_samples % K`` draws go one each to the first
        flows, so no member is over-represented. Draws are concatenated along
        dim 0, grouped by flow (not shuffled).

        Raises:
            ValueError: if ``n_samples`` is smaller than 1.
        """
        n_samples = int(n_samples)
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")

        base, remainder = divmod(n_samples, self.n_flows)
        samples = []
        # no_grad for consistency with log_prob: draws are used as data only.
        with th.no_grad():
            for idx, flow in enumerate(self.flows):
                count = base + (1 if idx < remainder else 0)
                if count > 0:
                    samples.append(flow.sample((count,), condition))

        return th.cat(samples, dim=0)

In [None]:
# Rebind `ensemble_flow` to an instance of the notebook-local EnsembleFlow class
# defined in the previous cell, replacing the asbi-backed instance. The local
# class has no `compute_marginal_entropy` method.
ensemble_flow = EnsembleFlow(flows)

In [None]:
# 103 is not a multiple of the ensemble size (3), so this also goes through
# the remainder path of `sample`.
n_check_samples = 103
samples = ensemble_flow.sample(n_check_samples, theta)

# Average log-probability the ensemble assigns to its own draws.
mean_log_prob = ensemble_flow.log_prob(samples, theta).mean()
mean_log_prob

In [None]:
# Class of `inference`, which is the loop variable left over from the training
# loop, i.e. the last ensemble member.
type(inference)

In [None]:
# Independent copy of the last ensemble member's trained network for direct
# inspection. `ensemble[-1]` is used explicitly rather than the loop variable
# `inference` leaked from the training cell; `deepcopy` is imported at the top.
flow = deepcopy(ensemble[-1]._neural_net)

flow

In [None]:
# Show `theta` (rebound earlier to a single prior draw) with an extra leading axis,
# which is the form passed as the flow's condition in the next cells.
theta.unsqueeze(0)

In [None]:
# Ten draws from the copied flow; the condition gets an extra leading axis.
batched_theta = theta.unsqueeze(0)
samples = flow.sample((10,), batched_theta)
samples.shape

In [None]:
# Draw with a 2-D sample shape to see how it is laid out in the output tensor.
sample_shape_2d = (2, 3)
samples = flow.sample(sample_shape_2d, theta.unsqueeze(0))

samples.shape



In [None]:
# Inspect the most recent `samples` tensor (`sample` was never defined).
samples

In [None]:
# Monte Carlo entropy estimate of the single copied flow at a fresh prior draw:
#   H(theta) ~= -mean_i log q(x_i | theta),  with x_i ~ q(. | theta).
# `th` is already imported in the first cell. NOTE: this rebinds `theta` again.
theta = prior.sample()
condition = theta.unsqueeze(0)
with th.no_grad():
    samples = flow.sample((1000,), condition)
    entropy_estimate = -th.mean(flow.log_prob(samples, condition))

samples.shape, entropy_estimate

In [None]:
# Per-sample log-probabilities of the latest draws. Evaluated under no_grad,
# like the cell above, so the result carries no autograd history.
with th.no_grad():
    per_sample_log_probs = flow.log_prob(samples, theta.unsqueeze(0))
per_sample_log_probs

In [None]:
# Inspect a private attribute of the last NLE object.
# NOTE(review): verify `_density_estimator` exists in the installed sbi version;
# the trained network is read elsewhere in this notebook via `_neural_net`.
inference._density_estimator

In [None]:
# Shape of the current `theta`. It has been rebound several times, most recently
# by `prior.sample()`.
theta.shape