In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import json
import zuko
from multiprocessing import Pool
from lampe.inference import NRE, NRELoss, MetropolisHastings
from tqdm import tqdm

### Make the project modules importable (needed to rebuild the Lampe estimator/CNN and the simulator when loading the trained models)

In [None]:
import sys
# Make the project's inference code, model builders and simulator importable.
# Guarded so re-running this cell does not push duplicate entries onto sys.path;
# the insertion order (and therefore the resulting search order) matches the original.
# NOTE(review): the directory is spelled "scr" — confirm it is not meant to be "src".
for _path in ("../scr/inference/", "../scr/inference/models", "../scr/wpa_simulator/"):
    if _path not in sys.path:
        sys.path.insert(0, _path)

In [None]:
from models import build_model
from cryo_em_simulator import CryoEmSimulator
import priors

### Load the cryo-EM simulator and the trained posterior (NRE estimator) with their matching configs

In [None]:
# Simulator configured from the benchmark image-parameter file (SNR 0.1 per the file name).
cryosbi = CryoEmSimulator("../data/trained_posteriors/benchmark_hsp90/image_params_snr01.json")

In [None]:
# Rebuild the NRE estimator from its training config and load the trained weights.
with open("../data/trained_posteriors/benchmark_hsp90/testNRE.json") as config_file:
    train_config = json.load(config_file)
estimator = build_model.build_nre_classifier_model(train_config)
# map_location="cpu" lets the weights load on machines without a GPU; the estimator
# is moved to the GPU explicitly (estimator.cuda()) before GPU work.
estimator.load_state_dict(torch.load("../data/trained_posteriors/benchmark_hsp90/NRE_snr01.estimator", map_location="cpu"))

### Generate particles from double-well potential

In [None]:
import torch.distributions as D

# Fixed seed so the synthetic particle set (and everything simulated from it) is reproducible.
torch.manual_seed(0)

# Equal-weight mixture of two Gaussians (means 5 and 14, std 2) over the index axis,
# i.e. a double-well profile.
mix = D.Categorical(torch.ones(2,))
comp = D.Normal(torch.tensor([5., 14.]), torch.tensor([2., 2.]))
double_well_distribution = D.MixtureSameFamily(mix, comp)

# Indices outside [INDEX_MIN, INDEX_MAX] are invalid and are redrawn.
INDEX_MIN, INDEX_MAX = 0.0, 19.0

true_indices = double_well_distribution.sample((1000,))

# Rejection step: redraw every out-of-range sample until all are valid, which samples
# the mixture truncated to [INDEX_MIN, INDEX_MAX].
out_of_range = (true_indices < INDEX_MIN) | (true_indices > INDEX_MAX)
while out_of_range.any():
    true_indices[out_of_range] = double_well_distribution.sample((int(out_of_range.sum()),))
    out_of_range = (true_indices < INDEX_MIN) | (true_indices > INDEX_MAX)

In [None]:
# Density-normalized histogram (100 bins) of the sampled indices; counts/bins are reused below.
counts, bins, _ = plt.hist(true_indices, bins=100, density=True)

In [None]:
# -log of the histogram density: the free-energy profile in units of kT, up to an
# additive constant. Plot against bin centres (not right edges) and skip empty bins,
# where -log(0) would be +inf.
bin_centers = 0.5 * (bins[:-1] + bins[1:])
occupied = counts > 0
plt.plot(bin_centers[occupied], -np.log(counts[occupied]))
plt.xlabel('index')
plt.ylabel(r'$-\log p$');

In [None]:
def batched_simulator(batche_of_indices):
    """Simulate one image per index in the batch and stack them along dim 0.

    Defined at module level because it is handed to ``multiprocessing.Pool.map``,
    which pickles the function by reference.
    """
    simulated_images = list(map(cryosbi.simulator, batche_of_indices))
    return torch.stack(simulated_images, dim=0)

In [None]:
# Simulate the double-well particle images in parallel: 100 indices per task, 24 workers.
batched_indices = torch.split(true_indices, split_size_or_sections=100, dim=0)
with Pool(24) as pool:
    image_batches = pool.map(batched_simulator, batched_indices)
images = torch.cat(image_batches, dim=0)

In [None]:
# NOTE(review): serial re-simulation of the same indices; it overwrites the parallel
# result of the previous cell (presumably with new noise realizations). Keep only one.
images = torch.stack(list(map(cryosbi.simulator, true_indices)), dim=0)

In [None]:
# Quick check of the simulated image tensor shape.
images.shape

In [None]:
# Show the first simulated particle; images are stored flat and reshaped to 64x64.
plt.imshow(images[0].reshape(64, 64))

In [None]:
# Embed all images once. Only the features are needed below, so no autograd graph is
# built. eval() disables train-time layer behaviour (e.g. dropout) for inference.
estimator.eval()
with torch.no_grad():
    image_features = estimator.embedding(images)

In [None]:
# NRE log-ratios log r(theta, x) on the integer grid theta = 0..19:
# log_ratios[k] is estimator.nre evaluated at theta = k (presumably one value per image).
theta_grid = torch.arange(20)  # same int64 values as torch.linspace(0, 19, 20, dtype=int)
with torch.no_grad():
    log_ratios = torch.stack(
        [estimator.nre(theta.reshape(1), image_features) for theta in theta_grid],
        dim=0,
    )

In [None]:
# Check which device the log-ratio table lives on (it must match G in log_images).
log_ratios.device

### Cryo-BIFE: extract the free-energy profile

In [None]:
# Prior over the free-energy profile G: independent Uniform(0, 10) on each of the 20 indices.
G_prior = zuko.distributions.BoxUniform(torch.zeros(20), torch.full((20,), 10.0))

In [None]:
import scipy.interpolate as ip

In [None]:
# Alternative (currently unused): evaluate G at non-integer theta via cubic-spline
# interpolation over the grid points, e.g.
#     cs_scipy = ip.CubicSpline(torch.linspace(0, 19, G.size(0)).numpy(), G.numpy())
#     return -cs_scipy(theta.numpy())

def log_theta_prior(theta, G):
    """Unnormalized log-prior of the index theta under the free-energy profile G.

    Args:
        theta: integer index tensor (any shape) into G.
        G: 1-D tensor of free energies, one entry per index.

    Returns:
        ``-G[theta]``, i.e. log exp(-G_theta) without the normalizing constant.
    """
    return -G[theta]

In [None]:
# Sanity check: index 3 repeated four times against a random profile drawn from G_prior.
log_theta_prior(torch.tensor(4 * [3]), G_prior.sample())

In [None]:
# Shape check: log-ratios of all images at one index plus the matching prior term.
idx = 3
(log_ratios[idx] + log_theta_prior(torch.tensor(1000 * [idx]),  G_prior.sample())).shape

In [None]:
def log_images(G):
    """Unnormalized log-posterior of the free-energy profile G given all images (Cryo-BIFE).

    Each image is marginalized over the discrete index theta:

        log p(x_i | G) = logsumexp_theta[ log r(theta, x_i) + log rho(theta | G) ] + const
        rho(theta | G) = exp(-G_theta) / sum_theta' exp(-G_theta')

    where ``log_ratios[theta]`` holds the NRE log-ratios log r(theta, x_i). The
    per-image terms are summed, and the log-density of ``G_prior`` is added.

    Args:
        G: profile of shape (..., n_theta); leading dimensions are independent chains.

    Returns:
        Tensor of shape (...): one scalar log-density per chain, as Metropolis-Hastings
        expects (a single chain with G of shape (n_theta,) gives a 0-d tensor).
    """
    # Normalized log rho(theta | G). Without normalization, shifting G by a constant
    # would change the likelihood and bias G toward the lower bound of the prior.
    log_rho = torch.log_softmax(-G, dim=-1)  # (..., n_theta)
    # Marginalize theta for each image: (..., n_theta, n_images) -> (..., n_images).
    log_lik_per_image = torch.logsumexp(log_ratios + log_rho.unsqueeze(-1), dim=-2)
    return log_lik_per_image.sum(dim=-1) + G_prior.log_prob(G)

In [None]:
# One Markov chain over the 20-dimensional free-energy profile: G_0 has shape (20,),
# so this is a single chain (not a batch of concurrent chains).
G_0 = G_prior.sample()

sampler = MetropolisHastings(G_0, log_f=log_images, sigma=0.6)
# torch.stack keeps one row per retained draw, so samples has shape (n_draws, 20);
# torch.cat would have flattened all draws into a single 1-D vector.
samples = torch.stack(list(sampler(512 + 200, burn=1024, step=8)))

In [None]:
# Posterior-mean free-energy profile. reshape(-1, n) gives one row per draw whether
# the draws were stacked into rows or concatenated into one flat vector.
mean_G = samples.reshape(-1, G_0.shape[-1]).mean(dim=0)
plt.plot(mean_G.numpy())
plt.xlabel('index')
plt.ylabel('posterior mean of G');

In [None]:
# Inference mode on the GPU for the MCMC cells below (displays the module, as before).
estimator.eval().cuda()

In [None]:
# Sample p(theta | x) for every double-well image with Metropolis-Hastings.
# NOTE(review): `prior` was never defined in this notebook. This assumes it is the same 1-D
# prior used to draw indices/simulations (priors.get_unirom_prior_1d) — confirm.
prior = priors.get_unirom_prior_1d()
n_chains = 1096  # independent Markov chains per image

with torch.no_grad():
    theta_0 = prior.sample((n_chains, images.shape[0])).cuda()

    features = estimator.embedding(images.cuda())

    def log_p(theta):
        # log p(theta | x) = log r(theta, x) + log p(theta), up to a constant in theta
        return estimator.nre(theta, features) + prior.log_prob(theta.cpu()).cuda()

    sampler = MetropolisHastings(theta_0, log_f=log_p, sigma=0.6)
    # Draws are concatenated along the chain axis: (n_draws * n_chains, n_images, ...).
    samples = torch.cat([
        theta
        for theta in sampler(512 + 2024, burn=2024, step=8)
    ])

In [None]:
# Posterior histograms of the first 20 double-well images; the red line marks the true
# index each image was simulated from.
fig, axes = plt.subplots(4, 5, figsize=(10, 10), sharex=True)
for idx, ax in enumerate(axes.flat):
    ax.hist(samples[:, idx].cpu().flatten().numpy(), bins=np.arange(0, 20, 0.5), histtype="step", color="blue", label="all")
    ax.set_yticks([])
    ax.set_xticks(range(0, 20, 4))
    # These images were simulated from `true_indices`; `indices` is only defined later
    # (uniform-prior section) and refers to different images.
    ax.axvline(true_indices[idx], color='red')
#plt.savefig('Example_NSF.pdf', dpi=400)

## Compute the distribution of posterior confidence-interval widths

In [None]:
def batched_simulator(batche_of_indices):
    """Simulate one image per index in the batch and stack them along dim 0.

    NOTE(review): identical re-definition of the function defined earlier in the
    notebook, presumably so this section can be run on its own.
    """
    simulated_images = list(map(cryosbi.simulator, batche_of_indices))
    return torch.stack(simulated_images, dim=0)

In [None]:
# Draw N_samples indices from the project's 1-D prior (the name suggests a uniform prior — confirm).
N_samples = 20000
indices = priors.get_unirom_prior_1d().sample((N_samples,))

In [None]:
# Simulate one image per prior draw in parallel: 1000 indices per task, 24 workers.
batched_indices = torch.split(indices, split_size_or_sections=1000, dim=0)
with Pool(24) as pool:
    image_batches = pool.map(batched_simulator, batched_indices)
images = torch.cat(image_batches, dim=0)

In [None]:
# The parallel-simulation cell already concatenates its batches. Concatenate only if
# `images` is still a list of per-batch tensors, so this cell is safe to re-run.
if isinstance(images, (list, tuple)):
    images = torch.cat(images, dim=0)

In [None]:
# Collects the MCMC samples of each image batch, filled by the sampling loop below.
theta_samples = []

In [None]:
estimator.cuda()
estimator.eval()
# NOTE(review): `prior` was never defined in this notebook. This assumes it is the 1-D prior
# the indices were drawn from (priors.get_unirom_prior_1d) — confirm.
prior = priors.get_unirom_prior_1d()
batch_size = 1000
n_chains = 512  # independent Markov chains per image
batched_images = torch.split(images, split_size_or_sections=batch_size, dim=0)

theta_samples = []  # reset so re-running this cell does not append duplicate batches
with torch.no_grad():
    for batch in tqdm(batched_images, unit='batch'):
        theta_0 = prior.sample((n_chains, batch.shape[0])).cuda()

        features = estimator.embedding(batch.cuda())

        def log_p(theta):
            # log p(theta | x) = log r(theta, x) + log p(theta), up to a constant in theta
            return estimator.nre(theta, features) + prior.log_prob(theta.cpu()).cuda()

        sampler = MetropolisHastings(theta_0, log_f=log_p, sigma=0.6)
        samples = torch.cat([
            theta for theta in sampler(512 + 1024, burn=1024, step=8)
        ])
        # Offload to host memory: keeping every batch on the GPU grows with the number
        # of batches and can exhaust device memory.
        theta_samples.append(samples.cpu())

In [None]:
# Join the per-batch samples along dim 1, the image axis (each batch is indexed [draw, image, ...]).
samples = torch.cat(theta_samples, dim=1)

In [None]:
# Width of the central 95% interval of each image's posterior samples.
# Vectorized over images: one host transfer instead of one per image.
samples_np = samples.cpu().numpy()
# (n_draws, n_images, ...) -> (n_images, n_draws * ...): row i holds all draws for image i.
per_image = np.moveaxis(samples_np, 1, 0).reshape(samples_np.shape[1], -1)
lower_q, upper_q = np.quantile(per_image, [0.025, 0.975], axis=1)
confidence_widths = list(upper_q - lower_q)

In [None]:
# Distribution of the 95%-interval widths computed above.
_ = plt.hist(confidence_widths, bins=np.arange(0, 20, 0.2), histtype='step', density=True, linewidth=2)

In [None]:
# Persist the NRE interval widths for the cross-model comparison below.
np.save('nre_SNR=01.npy', np.array(confidence_widths))

In [None]:
# Compare interval-width distributions across runs. Only 'nre_SNR=01.npy' is written in
# this notebook; the other files are presumably produced by other experiments.
comparison_runs = [
    ('confidence_widths_snr10_deep_ce.npy', 'SNR=10'),
    ('confidence_widths_snr1_deep_ce.npy', 'SNR=1'),
    ('confidence_widths_snr01_large_deep_ce.npy', 'SNR=0.1'),
    ('nre_SNR=01.npy', 'NRE'),
]
for path, name in comparison_runs:
    # Separate name so the `confidence_widths` computed above is not overwritten.
    widths = np.load(path)
    plt.hist(widths, bins=np.arange(0, 20, 0.3), histtype='step', density=True, label=name, linewidth=2)
plt.xlabel(r'width of $95\%$ confidence interval')
plt.legend()
#plt.savefig('NRE_comp.pdf', dpi=300)

In [None]:
# Example figure: the first 4 simulated images (top row) and their posterior histograms
# (bottom row), with the true index in red.
fig, axes = plt.subplots(2, 4, figsize=(20, 10), sharex=False)
for idx, ax in enumerate(axes[0]):
    ax.imshow(images[idx].reshape(64, 64))
    ax.set_yticks([])
    ax.set_xticks([])
for idx, ax in enumerate(axes[1]):
    # .cpu() first: `samples` may live on the GPU, and .numpy() only works on host tensors.
    ax.hist(samples[:, idx].cpu().flatten().numpy(), bins=np.arange(0, 20, 0.2), histtype="step", color="blue", label="all")
    ax.set_yticks([])
    ax.set_xticks(range(0, 20, 4))
    ax.axvline(indices[idx], color='red')
# NOTE(review): the loaded estimator is the SNR 0.1 model ("snr01"); confirm this file name.
plt.savefig('SNR10_examples.pdf', dpi=300)

# Compare with BioEM calculations

In [None]:
# Reference particle image from the BioEM benchmark (text table with one header line).
# NOTE(review): this is the "snr1" particle while the estimator is the snr01 model — confirm.
bioem_image_path = '../../BioEM_production/hsp90_images/particle_from_16_snr1'
image = np.genfromtxt(bioem_image_path, skip_header=1)

In [None]:
# Keep column 2 of the BioEM table as the pixel values and convert to a tensor.
# NOTE(review): assumes the other columns are not intensities — confirm against the format.
# Guarded on ndim so re-running the cell is a no-op instead of indexing the 1-D result.
if image.ndim == 2:
    image = torch.tensor(image[:, 2])

In [None]:
# Visual check of the BioEM particle as a 64x64 image.
plt.imshow(image.reshape(64, 64))

In [None]:
# Sanity check: the image should have unit standard deviation, like the simulated images.
assert np.isclose(image.std(), 1)

In [None]:
estimator.cuda().eval()
# The integer indices 0..19 as a float32 column of shape (20, 1).
thetas = torch.arange(0, 20, dtype=torch.float).unsqueeze(-1)
with torch.no_grad():
    log_prob = estimator(thetas.cuda(), image.float().cuda()).cpu()

In [None]:
# Estimator output for the BioEM particle across the index grid.
plt.plot(thetas.cpu(), log_prob)

In [None]:
# Save the grid and estimator outputs for comparison with BioEM.
# NOTE(review): the file name says "lampe_large_deep_ce" but these values come from the NRE estimator — confirm.
np.savez_compressed('posterior_lampe_large_deep_ce_snr01_16', theta=thetas.cpu().numpy(), log_prob=log_prob.numpy())

## Compute posterior calibration

In [None]:
from lampe.data import JointLoader
from priors import get_unirom_prior_1d
from lampe.diagnostics import expected_coverage_ni
from lampe.plots import coverage_plot
from itertools import islice

In [None]:
# Stream of fresh (theta, x) pairs: theta from the 1-D prior, x from the simulator.
# NOTE(review): per the lampe docs, JointLoader is an infinite iterable — bound it with islice.
loader = JointLoader(get_unirom_prior_1d(), cryosbi.simulator, vectorized=True, batch_size=1, num_workers=24, prefetch_factor=1)

In [None]:
estimator.cuda()
estimator.eval()

# Same prior the loader above samples theta from.
prior = get_unirom_prior_1d()
# JointLoader never stops yielding, so the coverage estimate needs a fixed number of pairs.
n_coverage_pairs = 1024

def log_p(theta, x):
    # log p(theta | x) = log r(theta, x) + log p(theta), up to a constant in theta
    return estimator(theta.cuda(), x.cuda()) + prior.log_prob(theta.cpu()).cuda()

nre_levels, nre_coverages = expected_coverage_ni(
    log_p,
    islice(loader, n_coverage_pairs),
    (torch.tensor([0.]), torch.tensor([19.])),
)

fig = coverage_plot(nre_levels, nre_coverages, legend='NRE')
#fig.savefig('sbc_posterior.pdf', dpi=600)

In [None]:
# Spot-check the estimator on a few fresh (theta, x) pairs. The loader never ends,
# so take only a handful instead of looping forever.
with torch.no_grad():
    for theta, x in islice(loader, 4):
        print(estimator(theta.cuda(), x.cuda()))

In [None]:
# NOTE(review): stray empty-string expression left over from editing; safe to delete.
''