In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import json
from multiprocessing import Pool
from lampe.data import JointLoader
from itertools import islice
from tqdm import tqdm
from lampe.diagnostics import expected_coverage_mc
from lampe.plots import coverage_plot

import sys
sys.path.insert(0, "../scr/cryo_sbi/inference/")
sys.path.insert(0, "../scr/cryo_sbi/inference/models")
sys.path.insert(0, "../scr/cryo_sbi/wpa_simulator/")

from models import build_models
from cryo_em_simulator import CryoEmSimulator
import priors

In [None]:
# --- Output locations ---------------------------------------------------------
file_name = '23_03_09_results'      # Prefix shared by every file this notebook writes
data_dir = "../experiments/benchmark_hsp90/results/raw_data/"
plot_dir = "../experiments/benchmark_hsp90/results/plots/"
config_dir = "../experiments/benchmark_hsp90/"

# --- Sample sizes -------------------------------------------------------------
num_samples_stats = 20000           # Number of simulations for computing posterior stats
num_samples_SBC = 10000             # Number of simulations for SBC
num_posterior_samples_SBC = 4096    # Number of posterior samples for each SBC simulation
num_samples_posterior = 50000       # Number of samples to draw from posterior
batch_size_sampling = 100           # Batch size for sampling posterior

# --- Compute resources --------------------------------------------------------
num_workers = 24                    # Worker processes handed to the SBC data loader
# Prefer the GPU, but fall back to the CPU when CUDA is not available so the
# setting never names a device that does not exist on this machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Output switches ----------------------------------------------------------
save_data = True                    # Write the .pt result files
save_figures = True                 # Write the figure files

## Load cryo-EM simulator and posterior with correct config

In [None]:
# Instantiate the simulator from its image-parameter JSON file in `config_dir`.
simulator_config_path = config_dir + "image_params_snr01_128.json"
cryosbi = CryoEmSimulator(simulator_config_path)

In [None]:
# Read the training configuration. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open(config_dir + "resnet18_encoder.json") as config_file:
    train_config = json.load(config_file)

# Build the NPE flow model described by the configuration.
estimator = build_models.build_npe_flow_model(train_config)

# `map_location` lets the checkpoint load even when it was saved from a device
# that is not present on this machine.
estimator.load_state_dict(
    torch.load(config_dir + "resnet18_encoder_snr01.estimator", map_location=device)
)

# Honour the configured `device` rather than hard-coding CUDA.
estimator.to(device)
# Switch to evaluation mode; the trailing `;` suppresses the cell output.
estimator.eval();

## Testing posterior on single images

In [None]:
# One image per conformation index, covering 0 .. cryosbi.max_index inclusive.
index_grid = np.arange(0, cryosbi.max_index + 1, 1)
indices = torch.tensor(index_grid, dtype=float)

# Simulate each index separately, then stack the results along a new leading axis.
simulated_images = [cryosbi.simulator(index) for index in indices]
images = torch.stack(simulated_images, dim=0)

In [None]:
# Draw `num_samples_posterior` posterior samples for every example image,
# processing the images in chunks of at most `batch_size_sampling`.
theta_samples = []
with torch.no_grad():  # inference only, no gradients needed
    for batched_images in torch.split(images, split_size_or_sections=batch_size_sampling, dim=0):
        # Move the batch to the configured `device` (instead of a hard-coded
        # CUDA call) and bring the samples back to the CPU for storage.
        batch_samples = estimator.sample(
            batched_images.to(device, non_blocking=True),
            shape=(num_samples_posterior,)
        ).cpu()
        theta_samples.append(batch_samples)

# Batches are joined along dim 1.
# NOTE(review): assumes estimator.sample returns (num_samples, batch, ...) — confirm.
samples = torch.cat(theta_samples, dim=1)

In [None]:
# Persist the example indices, their images and the posterior samples.
if save_data:
    examples_path = f'{data_dir}{file_name}_snr01_examples.pt'
    examples_payload = dict(
        indices=indices,
        images=images,
        posterior_samples=samples,
    )
    torch.save(examples_payload, examples_path)

In [None]:
# Show the first 20 simulated images on a 4x5 grid, each labelled with its index.
fig, axes = plt.subplots(4, 5, figsize=(10, 8))
for panel, ax in enumerate(axes.flat):
    ax.imshow(images[panel], vmax=5, vmin=-5)
    # Hide both tick axes; only the image and its label matter here.
    ax.set_xticks([])
    ax.set_yticks([])
    index_label = str(int(indices[panel].item()))
    ax.text(10, 20, index_label)

In [None]:
# Histogram the posterior samples of the first 20 images on a 4x5 grid and mark
# the index each image was generated from with a red vertical line.
fig, axes = plt.subplots(4, 5, figsize=(10, 8), sharex=True)
for idx, ax in enumerate(axes.reshape(1, -1)[0]):
    ax.hist(samples[:, idx].numpy(), bins=np.arange(0, 20, 0.5), histtype="step", color="blue", label="all")
    # Hide the y ticks (the original called this twice; once is enough).
    ax.set_yticks([])
    ax.set_xticks(range(0, 20, 4))
    # Pass a plain Python number, matching how the index is read for the image labels.
    ax.axvline(indices[idx].item(), color='red')

## Generate images from prior and evaluate posterior for statistical analysis

In [None]:
# Sample `num_samples_stats` indices from the 1D prior over 0 .. cryosbi.max_index.
index_prior = priors.get_unirom_prior_1d(cryosbi.max_index)
indices = index_prior.sample((num_samples_stats,))

# Simulate one image per sampled index and stack them along a new leading axis.
simulated_images = [cryosbi.simulator(index) for index in indices]
images = torch.stack(simulated_images, dim=0)

In [None]:
# Draw `num_samples_posterior` posterior samples for every prior-simulated image,
# processing the images in chunks of at most `batch_size_sampling`.
theta_samples = []
with torch.no_grad():  # inference only, no gradients needed
    for batched_images in torch.split(images, split_size_or_sections=batch_size_sampling, dim=0):
        # Move the batch to the configured `device` (instead of a hard-coded
        # CUDA call) and bring the samples back to the CPU for storage.
        batch_samples = estimator.sample(
            batched_images.to(device, non_blocking=True),
            shape=(num_samples_posterior,)
        ).cpu()
        # Pin the sample axis and let the batch axis follow the actual batch
        # size. Reshaping to (-1, batch_size_sampling) gave the wrong shape for
        # a final batch smaller than `batch_size_sampling`, which then made the
        # concatenation below fail.
        # NOTE(review): assumes a one-dimensional parameter per image — confirm.
        theta_samples.append(
            batch_samples.reshape(num_samples_posterior, -1)
        )

# Result has one row per posterior draw and one column per image.
samples = torch.cat(theta_samples, dim=1)

In [None]:
# Persist the prior-sampled indices, their images and the posterior samples.
if save_data:
    stats_path = f"{data_dir}{file_name}_stats.pt"
    stats_payload = dict(
        indices=indices,
        images=images,
        posterior_samples=samples,
    )
    torch.save(stats_payload, stats_path)

## Compute posterior calibration

In [None]:

# Stream (theta, x) pairs: theta is drawn from the prior and x is produced by
# the simulator, one pair per batch, using `num_workers` loader workers.
loader = JointLoader(
    priors.get_unirom_prior_1d(cryosbi.max_index),
    cryosbi.simulator,
    vectorized=False,
    batch_size=1,
    num_workers=num_workers,
    prefetch_factor=1,
)

# Lazily take `num_samples_SBC` pairs, standardize theta with the estimator and
# move both tensors to the configured `device` (instead of hard-coding CUDA).
sbc_pairs = (
    (estimator.standardize(theta.to(device)), x.to(device))
    for theta, x in islice(loader, num_samples_SBC)
)

# Expected coverage, using `num_posterior_samples_SBC` posterior samples per pair.
levels, coverages = expected_coverage_mc(
    estimator.flow,
    sbc_pairs,
    n=num_posterior_samples_SBC,
)

fig = coverage_plot(levels, coverages, legend='NPE')

if save_data:
    torch.save(
        {
        'levels': levels,
        'coverages': coverages
        },
        f'{data_dir}{file_name}_SBC.pt'
    )

if save_figures:
    fig.savefig(
        f'{plot_dir}{file_name}_SBC.pdf',
        dpi=400
    )

## Generate images from index 10 with known quaternions

In [None]:
# Load the list of quaternions; one image is simulated per entry.
quats = torch.from_numpy(np.load('quaternion_list.npy'))
num_simulations = len(quats)

# Every simulation uses index 10, stored as a column vector so it can be
# prepended to the quaternion columns.
indices = torch.full((num_simulations, 1), 10.0)
parameters = torch.cat((indices, quats), dim=1)

# Simulate each (index, quaternion) row and stack along a new leading axis.
simulated_images = [cryosbi._simulator_with_quat(param) for param in parameters]
images = torch.stack(simulated_images, dim=0)

In [None]:
# Draw `num_samples_posterior` posterior samples for every quaternion image,
# processing the images in chunks of at most `batch_size_sampling`.
theta_samples = []
with torch.no_grad():  # inference only, no gradients needed
    for batched_images in torch.split(images, split_size_or_sections=batch_size_sampling, dim=0):
        # Move the batch to the configured `device` (instead of a hard-coded
        # CUDA call) and bring the samples back to the CPU for storage.
        batch_samples = estimator.sample(
            batched_images.to(device, non_blocking=True),
            shape=(num_samples_posterior,)
        ).cpu()
        # Pin the sample axis and let the batch axis follow the actual batch
        # size. The number of quaternions need not be a multiple of
        # `batch_size_sampling`, and reshaping to (-1, batch_size_sampling)
        # gave the wrong shape for a smaller final batch.
        # NOTE(review): assumes a one-dimensional parameter per image — confirm.
        theta_samples.append(
            batch_samples.reshape(num_samples_posterior, -1)
        )

# Result has one row per posterior draw and one column per image.
samples = torch.cat(theta_samples, dim=1)

In [None]:
# Persist the fixed indices, the quaternion images and the posterior samples.
if save_data:
    quats_path = f"{data_dir}{file_name}_quats.pt"
    quats_payload = dict(
        indices=indices,
        images=images,
        posterior_samples=samples,
    )
    torch.save(quats_payload, quats_path)