In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import json
from multiprocessing import Pool
from lampe.inference import NRE, NRELoss, MetropolisHastings
from tqdm import tqdm

In [None]:
import sys

# Make the project's inference code, model definitions and simulator importable.
# Each entry is prepended, so the last one listed ends up first on sys.path.
for _project_dir in (
    "../scr/inference/",
    "../scr/inference/models",
    "../scr/wpa_simulator/",
):
    sys.path.insert(0, _project_dir)

In [None]:
from models import build_model
from cryo_em_simulator import CryoEmSimulator
import priors

## Load the cryo-EM simulator and the trained posterior with the correct config

In [None]:
# Simulator built from the image-parameter config of the benchmark.
# NOTE(review): this is the `image_params_snr10.json` config, while the estimator
# and loss loaded below are the `NRE_snr01.*` files — confirm the pairing is intended.
simulator_config_path = "../data/trained_posteriors/benchmark_hsp90/image_params_snr10.json"
cryosbi = CryoEmSimulator(simulator_config_path)

In [None]:
# Build the NRE classifier from its training config and load the trained weights.
# The config file is opened in a `with` block so the handle is closed right away
# (the previous `json.load(open(...))` left it open).
# NOTE(review): the config filename ".json" looks truncated (the weights and loss
# files next to it are "NRE_snr01.*") — confirm the intended file name.
with open("../data/trained_posteriors/benchmark_hsp90/.json") as config_file:
    train_config = json.load(config_file)
estimator = build_model.build_nre_classifier_model(train_config)
estimator.load_state_dict(
    torch.load("../data/trained_posteriors/benchmark_hsp90/NRE_snr01.estimator")
)

## Plot the training loss

In [None]:
# Training-loss curve stored next to the estimator weights.
# NOTE(review): earlier runs apparently stored a (train, validation) pair, plotted
# via `.numpy()[0]` / `.numpy()[1]`; adjust if loading such a file.
loss_history = torch.load("../data/trained_posteriors/benchmark_hsp90/NRE_snr01.loss")
fig, ax = plt.subplots()
ax.plot(loss_history.numpy(), label="Train")
ax.set_ylabel("Loss", fontsize=12)
ax.set_xlabel("Epoch", fontsize=12)
ax.legend()
# fig.savefig('loss_snr10.pdf', dpi=600)

## Test the posterior on an array of images

In [None]:
# True indices 0..19 as a float64 column vector, and one simulated image per index.
indices = torch.arange(0, 20, 1, dtype=torch.float64).reshape(-1, 1)
images = torch.stack(list(map(cryosbi.simulator, indices)), dim=0)

In [None]:
# Visual check of one simulated image (assumes 64x64 pixels — confirm with the config).
example_image = images[9].reshape(64, 64)
plt.imshow(example_image)

In [None]:
# Inference mode on the GPU, plus the 1-D prior over the index parameter.
estimator.cuda()
estimator.eval()
prior = priors.get_unirom_prior_1d()

In [None]:
with torch.no_grad():
    # 1096 concurrent Markov chains for each of the test images, initialised from the prior.
    theta_0 = prior.sample((1096, images.shape[0])).cuda()
    features = estimator.embedding(images.cuda())

    def log_p(theta):
        """Unnormalised log posterior: log r(theta, x) + log p(theta)."""
        return estimator.nre(theta, features) + prior.log_prob(theta.cpu()).cuda()

    sampler = MetropolisHastings(theta_0, log_f=log_p, sigma=0.6)
    samples = torch.cat(list(sampler(512 + 2024, burn=2024, step=8)))

In [None]:
fig, axes = plt.subplots(4, 5, figsize=(10, 10), sharex=True)
# One panel per test image: posterior histogram with the true index as a red line.
for panel, ax in enumerate(axes.ravel()):
    panel_draws = samples[:, panel].cpu().flatten().numpy()
    ax.hist(
        panel_draws,
        bins=np.arange(0, 20, 0.5),
        histtype="step",
        color="blue",
        label="all",
    )
    ax.set_yticks([])
    ax.set_xticks(range(0, 20, 4))
    ax.axvline(indices[panel], color="red")
# plt.savefig('Example_NSF.pdf', dpi=400)

## Compute the distribution of posterior confidence intervals

In [None]:
def batched_simulator(batche_of_indices):
    """Simulate one image per index and stack the results along a new first axis.

    Kept at module level so ``Pool.map`` can dispatch it to worker processes.

    Args:
        batche_of_indices: iterable of index tensors, each passed unchanged to
            ``cryosbi.simulator``.

    Returns:
        Tensor stacking the simulated images along dim 0.
    """
    simulated_images = []
    for index in batche_of_indices:
        simulated_images.append(cryosbi.simulator(index))
    return torch.stack(simulated_images, dim=0)

In [None]:
N_samples = 20000
# Fresh ground-truth indices from the prior; this replaces the 20 test indices used above.
uniform_prior = priors.get_unirom_prior_1d()
indices = uniform_prior.sample((N_samples,))

In [None]:
def _seeded_batch(seed_and_batch):
    """Reseed this worker's global torch/NumPy RNGs, then simulate one batch.

    Forked ``Pool`` workers start with an identical copy of the parent's RNG
    state. If ``cryosbi.simulator`` draws from the global torch/NumPy RNGs,
    every batch would then presumably reuse the same random draws (noise,
    poses, ...). A distinct per-batch seed prevents that.

    Args:
        seed_and_batch: ``(seed, batch_of_indices)`` tuple.

    Returns:
        The stacked images returned by ``batched_simulator``.
    """
    seed, batch = seed_and_batch
    torch.manual_seed(seed)
    np.random.seed(seed)
    return batched_simulator(batch)


batched_indices = torch.split(indices, split_size_or_sections=1000, dim=0)
# Seeds are drawn from the parent's torch RNG, so the whole run stays
# reproducible whenever the parent process is seeded.
batch_seeds = torch.randint(0, 2**31 - 1, (len(batched_indices),)).tolist()
with Pool(24) as p:
    images = p.map(_seeded_batch, list(zip(batch_seeds, batched_indices)))

In [None]:
# Pool.map returns a list of per-batch tensors; merge them into one tensor.
# The guard makes the cell safe to re-run: torch.cat on an already-merged
# tensor would raise instead.
if isinstance(images, list):
    images = torch.cat(images, dim=0)

In [None]:
theta_samples = list()  # per-batch posterior draws, joined along the image axis later

In [None]:
estimator.cuda()
estimator.eval()
batch_size = 1000
batched_images = torch.split(images, split_size_or_sections=batch_size, dim=0)

# Start from an empty list so re-running this cell does not append duplicate batches
# (which would misalign the draws with `indices` later on).
theta_samples = []
with torch.no_grad():
    for batch in tqdm(batched_images, unit="batch"):
        # 512 concurrent Markov chains per image in the batch, initialised from the prior.
        theta_0 = prior.sample((512, batch.shape[0])).cuda()

        # Embed the batch once; it is already on the GPU, so no further .cuda() is needed.
        features = estimator.embedding(batch.cuda())

        def log_p(theta):
            """Unnormalised log posterior: p(theta | x) = r(theta, x) p(theta)."""
            return estimator.nre(theta, features) + prior.log_prob(theta.cpu()).cuda()

        sampler = MetropolisHastings(theta_0, log_f=log_p, sigma=0.6)
        samples = torch.cat([theta for theta in sampler(512 + 1024, burn=1024, step=8)])
        # Move the draws off the GPU so device memory does not grow with every batch.
        theta_samples.append(samples.cpu())

In [None]:
# Join the per-batch draws along the image axis (dim 1) on the CPU. This avoids a
# second full-size copy in GPU memory and lets later cells call `.numpy()` directly.
samples = torch.cat([batch_draws.cpu() for batch_draws in theta_samples], dim=1)

In [None]:
# Posterior mean minus true index, one value per image.
# - The samples are moved to the CPU to match the CPU-resident `indices`.
# - Both sides are flattened, so a trailing singleton parameter dimension
#   cannot broadcast the difference into an (N, N) matrix.
mean_distance = (samples.mean(dim=0).cpu().reshape(-1) - indices.reshape(-1)).numpy()

In [None]:
# Central 95% posterior interval per image: the 2.5% and 97.5% quantiles over the draws.
# `.cpu()` is a no-op for CPU tensors; for CUDA tensors, `.numpy()` alone would fail.
posterior_quantiles = np.quantile(samples.cpu().numpy(), [0.025, 0.975], axis=0)
confidence_widths = posterior_quantiles[1] - posterior_quantiles[0]

In [None]:
# Distribution of the offset between posterior mean and true index.
_ = plt.hist(
    mean_distance,
    bins=np.arange(-20, 20, 0.2),
    histtype="step",
    density=True,
    linewidth=2,
)
plt.xlabel("Posterior mean - true index")
plt.ylabel("Density")
# savefig requires a filename; the bare `plt.savefig()` raised a TypeError.
# plt.savefig('mean_distance.pdf', dpi=300)

In [None]:
# Distribution of the 95% posterior interval widths across all simulated images.
_ = plt.hist(
    confidence_widths,
    bins=np.arange(0, 20, 0.4),
    histtype="step",
    density=True,
    linewidth=2,
)
plt.xlabel(r"Width of $95\%$-confidence interval")

In [None]:
# Interval width vs absolute error of the posterior mean; red line is y = x.
plt.scatter(confidence_widths, np.abs(mean_distance), s=0.05)
plt.plot([0, 10], [0, 10], linewidth=0.7, color="red")
plt.ylabel("|Posterior mean - true index|")
# Raw string: "\%" in a normal string is an invalid escape sequence (warning in recent Pythons).
plt.xlabel(r"Width of $95\%$-confidence interval")

In [None]:
# Persist this run's per-image interval widths and mean offsets for later comparison.
for out_file, values in (
    ("confidence_widths_snr01_wideres50_128.npy", np.array(confidence_widths)),
    ("mean_distance_snr01_wideres50_128.npy", mean_distance),
):
    np.save(out_file, values)

In [None]:
# Overlay the interval-width distributions of previously saved runs.
# A loop-local name is used so the `confidence_widths` computed above is not
# silently replaced by the last file loaded here.
# NOTE(review): "SNR=1, Resnet" labels a file named `*_snr01_wideres50.npy` — confirm the label.
for width_file, name in zip(
    [
        "confidence_widths_snr01_resnet.npy",
        "confidence_widths_snr01_wideres50.npy",
        "confidence_widths_snr10_deep_ce.npy",
        "confidence_widths_snr01_wideres50_128.npy",
    ],
    ["SNR=0.1 Resnet", "SNR=1, Resnet", "SNR=10", "SNR=0.1 128x128"],
):
    run_widths = np.load(width_file)
    _ = plt.hist(
        run_widths,
        bins=np.arange(0, 20, 0.3),
        histtype="step",
        density=True,
        label=name,
        linewidth=2,
    )
plt.xlabel(r"Width of $95\%$-confidence interval")
plt.legend()
# plt.savefig('Posterior_widths_SNR.pdf', dpi=300)

# Compare with BioEM calculations

In [None]:
# Text particle file exported for the BioEM comparison (first line is a header).
bioem_particle_path = "../../BioEM_production/hsp90_images/particle_from_16_snr1"
image = np.genfromtxt(bioem_particle_path, skip_header=1)

In [None]:
# Keep only the third column as a 1-D tensor.
# NOTE(review): presumably the pixel values — confirm against the BioEM file format.
pixel_column = image[:, 2]
image = torch.tensor(pixel_column)

In [None]:
# Visual check of the BioEM particle as a 64x64 image.
bioem_pixels = image.reshape(64, 64)
plt.imshow(bioem_pixels)

In [None]:
# The particle is expected to be normalised to unit standard deviation.
particle_std = image.std()
assert np.isclose(particle_std, 1)

In [None]:
estimator.cuda()
estimator.eval()
# Candidate indices 0..19 as a float32 column vector.
thetas = torch.arange(20, dtype=torch.float).reshape(-1, 1)
with torch.no_grad():
    # Raw estimator output for every candidate index.
    # NOTE(review): presumably the NRE log-ratio rather than a normalised log-probability.
    log_prob = estimator(thetas.cuda(), image.float().cuda()).cpu()

In [None]:
# Estimator output as a function of the candidate index.
plt.plot(thetas.numpy(), log_prob)

In [None]:
# Store the candidate grid and estimator output for comparison with BioEM.
lampe_posterior = dict(theta=thetas.cpu().numpy(), log_prob=log_prob.numpy())
np.savez_compressed("posterior_lampe_large_deep_ce_snr01_16", **lampe_posterior)

## Compute posterior calibration

In [None]:
from lampe.data import JointLoader
from priors import get_unirom_prior_1d
from lampe.diagnostics import expected_coverage_ni
from lampe.plots import coverage_plot
from itertools import islice

In [None]:
# Stream (theta, x) pairs: theta drawn from the prior, x produced by the simulator.
loader_options = dict(
    vectorized=True,
    batch_size=1,
    num_workers=24,
    prefetch_factor=1,
)
loader = JointLoader(get_unirom_prior_1d(), cryosbi.simulator, **loader_options)

In [None]:
estimator.cuda()
estimator.eval()

# JointLoader yields pairs indefinitely, so the number of test pairs must be bounded
# explicitly; otherwise expected_coverage_ni never returns.
# NOTE(review): the pair count is a choice — increase it for smoother coverage curves.
n_coverage_pairs = 1024


def log_p(theta, x):
    """Unnormalised log posterior log r(theta, x) + log p(theta), evaluated on the GPU."""
    return estimator(theta.cuda(), x.cuda()) + prior.log_prob(theta.cpu()).cuda()


nre_levels, nre_coverages = expected_coverage_ni(
    log_p,
    islice(loader, n_coverage_pairs),
    (torch.tensor([0.0]), torch.tensor([19.0])),
)

# Plot the results computed above (the old `levels`/`coverages` names were undefined).
fig = coverage_plot(nre_levels, nre_coverages, legend="NRE")
# fig.savefig('sbc_posterior.pdf', dpi=600)