In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import json
from multiprocessing import Pool
from lampe.data import JointLoader
from lampe.diagnostics import expected_coverage_mc
from lampe.plots import coverage_plot
from tqdm import tqdm
from itertools import islice

from cryo_sbi.inference.models import build_models
from cryo_sbi import CryoEmSimulator
from cryo_sbi.inference import priors

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# Scratch check left over from development: a stack of two 128x128 images is 3-D,
# so this expression evaluates to False. Its result is not used by any later cell.
torch.randn(2, 128, 128).ndim == 2

False

In [2]:
# Experiment configuration: every value a reader might want to tune lives here.
file_name = "23_03_17_missmatch"  # Prefix for every saved result file
data_dir = "../experiments/benchmark_hsp90/results/raw_data/"  # Where results are written
config_dir = "../experiments/benchmark_hsp90/"  # Simulator/model configs and weights
num_samples_stats = 20000  # Number of simulations for computing posterior stats
num_samples_SBC = 10000  # Number of simulations for SBC
num_posterior_samples_SBC = 4096  # Number of posterior samples for each SBC simulation
num_samples_posterior = 50000  # Number of samples to draw from posterior
num_samples_umap = 2000  # Number of simulations for UMAP analysis
batch_size_sampling = 100  # Batch size for sampling posterior
batch_size_latent = 1000  # Batch size for calculating latent representation
num_workers = 24  # Number of CPU cores
device = "cuda"  # Device for computations
save_data = True
seed = 0  # Seed for the torch/numpy global RNGs

# Seed the global RNGs so reruns draw the same prior samples and noise, as far as the
# simulator relies on torch/numpy randomness (TODO confirm for DataLoader workers).
torch.manual_seed(seed)
np.random.seed(seed)

## Load the cryo-EM simulator and the trained posterior estimator with the matching config

In [None]:
# Simulator built from the training image-parameter config.
cryosbi = CryoEmSimulator(f"{config_dir}image_params_training.json")

In [None]:
# Rebuild the NPE flow from its training config and load the trained weights.
# `with` closes the config file instead of leaking the handle.
with open(config_dir + "resnet18_encoder.json") as config_file:
    train_config = json.load(config_file)
estimator = build_models.build_npe_flow_model(train_config)
# map_location makes the checkpoint load onto the configured device regardless of
# the device it was saved from.
estimator.load_state_dict(
    torch.load(config_dir + "posterior_hsp90.estimator", map_location=device)
)
estimator.to(device)
estimator.eval();

## Compute posterior accuracy and precision under structure mismatch

In [None]:
# Structure array referenced by "MODEL_FILE" in the training simulator config.
# `with` closes the JSON file instead of leaking the handle.
with open(config_dir + "image_params_training.json") as config_file:
    models = np.load(json.load(config_file)["MODEL_FILE"])

In [None]:
# For each slice models[:, row], simulate images, sample the posterior for each image
# and store the signed error of the posterior mean and the 95% interval width.
for row in range(5):
    cryosbi.models = models[:, row]

    # NOTE(review): every other cell calls priors.get_unirom_prior_1d; confirm which
    # spelling the installed cryo_sbi version provides.
    indices = priors.get_uniform_prior_1d(cryosbi.max_index).sample(
        (num_samples_stats,)
    )
    images = torch.stack([cryosbi.simulator(index) for index in indices], dim=0)

    theta_samples = []
    with torch.no_grad():
        for batched_images in torch.split(
            images, split_size_or_sections=batch_size_sampling, dim=0
        ):
            samples = estimator.sample(
                batched_images.to(device, non_blocking=True),
                shape=(num_samples_posterior,),
            ).cpu()
            # Size by the fixed sample count (not the batch size) so a smaller final
            # batch still reshapes; assumes sample-major output — TODO confirm.
            theta_samples.append(samples.reshape(num_samples_posterior, -1))
    # One column of posterior samples per simulated image.
    samples = torch.cat(theta_samples, dim=1)

    mean_distance = (samples.mean(dim=0) - indices.reshape(-1)).numpy()
    posterior_quantiles = np.quantile(samples.numpy(), [0.025, 0.975], axis=0)
    confidence_widths = posterior_quantiles[1] - posterior_quantiles[0]

    if save_data:
        np.save(
            f"{data_dir}{file_name}_confidence_widths_r={row}.npy",
            np.array(confidence_widths),
        )
        np.save(f"{data_dir}{file_name}_mean_distance_r={row}.npy", mean_distance)

In [None]:
# UMAP test set rendered from models[:, 4] (plotted as "Wrong row (4)").
cryosbi.models = models[:, 4]
indices = priors.get_unirom_prior_1d(cryosbi.max_index).sample(
    (num_samples_umap,)
)
images_wrong_row = torch.stack(list(map(cryosbi.simulator, indices)), dim=0)

In [None]:
# Posterior samples for the wrong-row images: one column per image.
theta_samples = []
with torch.no_grad():
    for batched_images in torch.split(
        images_wrong_row, split_size_or_sections=batch_size_sampling, dim=0
    ):
        samples = estimator.sample(
            batched_images.to(device, non_blocking=True),
            shape=(num_samples_posterior,),
        ).cpu()
        # Size by the sample count so a smaller final batch still reshapes.
        theta_samples.append(samples.reshape(num_samples_posterior, -1))
samples = torch.cat(theta_samples, dim=1)

In [None]:
# Persist the wrong-row test set together with its posterior samples.
if save_data:
    wrong_row_record = dict(
        indices=indices, images=images_wrong_row, posterior_samples=samples
    )
    torch.save(wrong_row_record, data_dir + file_name + "_row4.pt")

## Compute posterior calibration under model misspecification

In [None]:
# Expected coverage (calibration) of the posterior for each slice models[:, row].
all_levels = []
all_coverages = []

# Device and eval mode do not change inside the loop, so set them once.
estimator.to(device)
estimator.eval()

for row in range(5):
    cryosbi.models = models[:, row]
    loader = JointLoader(
        priors.get_unirom_prior_1d(cryosbi.max_index),
        cryosbi.simulator,
        vectorized=False,
        batch_size=1,
        num_workers=num_workers,
        prefetch_factor=1,
    )

    # theta goes through estimator.standardize before it reaches estimator.flow.
    levels, coverages = expected_coverage_mc(
        estimator.flow,
        (
            (estimator.standardize(theta.to(device)), x.to(device))
            for theta, x in islice(loader, num_samples_SBC)
        ),
        n=num_posterior_samples_SBC,
    )

    all_levels.append(levels)
    all_coverages.append(coverages)

# One column per row.
all_levels = torch.stack(all_levels, dim=1)
all_coverages = torch.stack(all_coverages, dim=1)

if save_data:
    torch.save([all_levels, all_coverages], f"{data_dir}{file_name}_sbc_rows")

## Compute posterior accuracy and precision as a function of SNR

In [None]:
# Reset simulator
cryosbi = CryoEmSimulator(config_dir + "image_params_snr01_128.json")
snrs = np.logspace(-0.5, -1.5, 9)

In [None]:
# For each SNR, simulate images, sample the posterior for each image and store the
# signed error of the posterior mean and the 95% interval width.
for snr in snrs:
    cryosbi.config["SNR"] = snr

    indices = priors.get_unirom_prior_1d(cryosbi.max_index).sample((num_samples_stats,))
    images = torch.stack([cryosbi.simulator(index) for index in indices], dim=0)

    theta_samples = []
    with torch.no_grad():
        for batched_images in torch.split(
            images, split_size_or_sections=batch_size_sampling, dim=0
        ):
            samples = estimator.sample(
                batched_images.to(device, non_blocking=True),
                shape=(num_samples_posterior,),
            ).cpu()
            # Size by the sample count so a smaller final batch still reshapes.
            theta_samples.append(samples.reshape(num_samples_posterior, -1))
    # One column of posterior samples per simulated image.
    samples = torch.cat(theta_samples, dim=1)

    mean_distance = (samples.mean(dim=0) - indices.reshape(-1)).numpy()
    posterior_quantiles = np.quantile(samples.numpy(), [0.025, 0.975], axis=0)
    confidence_widths = posterior_quantiles[1] - posterior_quantiles[0]

    if save_data:
        np.save(
            f"{data_dir}{file_name}_confidence_widths_snr={snr}.npy",
            np.array(confidence_widths),
        )
        np.save(f"{data_dir}{file_name}_mean_distance_snr={snr}.npy", mean_distance)

In [None]:
# UMAP test set at SNR 0.03 (plotted as "Wrong SNR (0.03)").
cryosbi.config["SNR"] = 0.03
indices = priors.get_unirom_prior_1d(cryosbi.max_index).sample(
    (num_samples_umap,)
)
images_wrong_snr = torch.stack(list(map(cryosbi.simulator, indices)), dim=0)

In [None]:
# Posterior samples for the low-SNR images: one column per image.
theta_samples = []
with torch.no_grad():
    for batched_images in torch.split(
        images_wrong_snr, split_size_or_sections=batch_size_sampling, dim=0
    ):
        samples = estimator.sample(
            batched_images.to(device, non_blocking=True),
            shape=(num_samples_posterior,),
        ).cpu()
        # Size by the sample count so a smaller final batch still reshapes.
        theta_samples.append(samples.reshape(num_samples_posterior, -1))
samples = torch.cat(theta_samples, dim=1)

In [None]:
# Persist the low-SNR test set together with its own posterior samples. It gets a
# separate file so the wrong-row results in "_row4.pt" are not overwritten.
if save_data:
    torch.save(
        {"indices": indices, "images": images_wrong_snr, "posterior_samples": samples},
        f"{data_dir}{file_name}_snr=0.03.pt",
    )

In [None]:
# Expected coverage (calibration) of the posterior at each SNR of the sweep.
all_levels = []
all_coverages = []

for snr in snrs:
    cryosbi.config["SNR"] = snr
    loader = JointLoader(
        priors.get_unirom_prior_1d(cryosbi.max_index),
        cryosbi.simulator,
        vectorized=False,
        batch_size=1,
        num_workers=num_workers,  # use the configured worker count, not a literal 24
        prefetch_factor=1,
    )

    # theta goes through estimator.standardize before it reaches estimator.flow.
    levels, coverages = expected_coverage_mc(
        estimator.flow,
        (
            (estimator.standardize(theta.to(device)), x.to(device))
            for theta, x in islice(loader, num_samples_SBC)
        ),
        n=num_posterior_samples_SBC,
    )

    all_levels.append(levels)
    all_coverages.append(coverages)

# One column per SNR value.
all_levels = torch.stack(all_levels, dim=1)
all_coverages = torch.stack(all_coverages, dim=1)

if save_data:
    torch.save([all_levels, all_coverages], f"{data_dir}{file_name}_sbc_snr")

## Model misspecification: non-Gaussian (colored) noise

In [None]:
import simulator_colored_noise

In [None]:
# Colored-noise simulator variant, built from the same config file as the SNR sweep.
cryosbi_colored_noise = simulator_colored_noise.CryoEmSimulatorColoredNoise(
    f"{config_dir}image_params_snr01_128.json"
)

In [None]:
# Posterior accuracy/precision on colored-noise simulations.
indices = priors.get_unirom_prior_1d(cryosbi_colored_noise.max_index).sample(
    (num_samples_stats,)
)
images = torch.stack(
    [cryosbi_colored_noise.simulator(index) for index in indices], dim=0
)

theta_samples = []
with torch.no_grad():
    for batched_images in torch.split(
        images, split_size_or_sections=batch_size_sampling, dim=0
    ):
        samples = estimator.sample(
            batched_images.to(device, non_blocking=True),
            shape=(num_samples_posterior,),
        ).cpu()
        # Size by the sample count so a smaller final batch still reshapes.
        theta_samples.append(samples.reshape(num_samples_posterior, -1))
# One column of posterior samples per simulated image.
samples = torch.cat(theta_samples, dim=1)

mean_distance = (samples.mean(dim=0) - indices.reshape(-1)).numpy()
posterior_quantiles = np.quantile(samples.numpy(), [0.025, 0.975], axis=0)
confidence_widths = posterior_quantiles[1] - posterior_quantiles[0]

if save_data:
    np.save(
        f"{data_dir}{file_name}_confidence_widths_gradient_snr.npy",
        np.array(confidence_widths),
    )
    np.save(f"{data_dir}{file_name}_mean_distance_gradient_snr.npy", mean_distance)

In [None]:
# Expected coverage (calibration) of the posterior on colored-noise simulations.
loader = JointLoader(
    priors.get_unirom_prior_1d(cryosbi_colored_noise.max_index),
    cryosbi_colored_noise.simulator,
    batch_size=1,
    num_workers=num_workers,
    prefetch_factor=1,
    vectorized=False,
)

levels, coverages = expected_coverage_mc(
    estimator.flow,
    (
        (estimator.standardize(theta.cuda()), x.cuda())
        for theta, x in islice(loader, num_samples_SBC)
    ),
    n=num_posterior_samples_SBC,
)

if save_data:
    torch.save([levels, coverages], data_dir + file_name + "_sbc_gradient_snr.pt")

## Model misspecification: no particle

In [None]:
# Standard-normal noise images containing no particle at all.
images_no_particle = torch.randn(num_samples_umap, 1, 128, 128)

In [None]:
# Preview the first noise-only image.
plt.imshow(images_no_particle[0, 0])

In [None]:
# Posterior samples for the noise-only images: one column per image.
theta_samples = []
with torch.no_grad():
    for batched_images in torch.split(
        images_no_particle, split_size_or_sections=batch_size_sampling, dim=0
    ):
        samples = estimator.sample(
            batched_images.to(device, non_blocking=True),
            shape=(num_samples_posterior,),
        ).cpu()
        # Size by the sample count so a smaller final batch still reshapes.
        theta_samples.append(samples.reshape(num_samples_posterior, -1))
samples = torch.cat(theta_samples, dim=1)

In [None]:
# Left: one noise-only input image. Right: posteriors for the first 10 such images.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(images_no_particle[0].reshape(128, 128))
axes[0].set_title("Input image (noise only)")
axes[1].hist(
    samples[:, :10].numpy(),
    bins=np.arange(0, 10, 0.1),
    histtype="step",
    density=True,
    linewidth=2,
)
axes[1].set_xlabel("Index")
axes[1].set_ylabel("Posterior density")
# Save through the figure handle rather than pyplot's implicit "current figure".
fig.savefig(f"Posterior_no_particles_{file_name}.pdf", dpi=500)

## Model misspecification: wrong particle (one arm of Hsp90)

In [None]:
# Simulator configs live in config_dir (data_dir is the results folder), as in the
# other cells that load image_params_snr01_128.json.
cryosbi = CryoEmSimulator(config_dir + "image_params_snr01_128.json")
# Keep only the first 603 entries along the last axis of models[:, 0]; per the section
# title this is meant to be one arm of Hsp90 — TODO confirm the 603 cut-off.
cryosbi.models = models[:, 0, :, :603]

In [None]:
# One truncated-particle image simulated at index 10, and its posterior.
test_image = cryosbi.simulator(torch.tensor([10.0]))
# no_grad, as in the other sampling cells: drawing 100k samples should not build an
# autograd graph.
with torch.no_grad():
    samples = estimator.sample(test_image.to(device), shape=(100000,)).cpu()

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(test_image.reshape(128, 128))
_ = axes[1].hist(
    samples.flatten().numpy(),
    bins=np.arange(0, 20, 0.1),
    histtype="step",
    density=True,
    linewidth=2,
)
axes[1].set_xlabel("Index")
axes[1].set_ylabel("Posterior density")
fig.savefig("Example_wrong_particle.pdf", dpi=500)

In [None]:
# UMAP test set rendered with the truncated (wrong) particle.
indices = priors.get_unirom_prior_1d(cryosbi.get_max_index()).sample((num_samples_umap,))
images_wrong_particle = torch.stack(list(map(cryosbi.simulator, indices)), dim=0)

In [None]:
# Posterior samples for the wrong-particle images: one column per image.
theta_samples = []
with torch.no_grad():
    for batched_images in torch.split(
        images_wrong_particle, split_size_or_sections=batch_size_sampling, dim=0
    ):
        samples = estimator.sample(
            batched_images.to(device, non_blocking=True),
            shape=(num_samples_posterior,),
        ).cpu()
        # Size by the sample count so a smaller final batch still reshapes.
        theta_samples.append(samples.reshape(num_samples_posterior, -1))
samples = torch.cat(theta_samples, dim=1)

In [None]:
# Posteriors for the first 20 wrong-particle images, one step histogram each.
# No per-curve label: `name` is only defined in the UMAP cell further down (NameError
# on a fresh kernel), and its 5 entries could not label 20 histograms anyway.
fig, ax = plt.subplots()
_ = ax.hist(
    samples[:, :20].numpy(),
    bins=np.arange(0, 20, 0.1),
    histtype="step",
    density=True,
    linewidth=2,
)
ax.set_xlabel("Index")
ax.set_ylabel("Posterior density");

## Detecting model misspecification

In [None]:
# Misspecified image sets; this order must match `name[1:]` in the UMAP figure.
images_mstar = list(
    (images_no_particle, images_wrong_particle, images_wrong_row, images_wrong_snr)
)

In [None]:
# Embedding-network features for every misspecified image set, in images_mstar order.
# The loop variable is `image_set` so the global `images` is not clobbered.
latent_repr_mstar = []
with torch.no_grad():
    for image_set in images_mstar:
        latent_space_samples = []
        for batched_images in torch.split(
            image_set, split_size_or_sections=batch_size_latent, dim=0
        ):
            samples = estimator.embedding(
                batched_images.to(device, non_blocking=True)
            ).cpu()
            # One flattened feature row per image, sized by the actual batch so a
            # smaller final batch still reshapes.
            latent_space_samples.append(samples.reshape(batched_images.shape[0], -1))
        latent_repr_mstar.append(torch.cat(latent_space_samples, dim=0))

In [None]:
# Reference ("Ground truth") set from a freshly built simulator. Configs live in
# config_dir, not in the results folder data_dir.
cryosbi = CryoEmSimulator(config_dir + "image_params_snr01_128.json")
indices = priors.get_unirom_prior_1d(cryosbi.get_max_index()).sample(
    (num_samples_umap,)
)
images = torch.stack([cryosbi.simulator(index) for index in indices], dim=0)

In [None]:
# Embedding-network features for the reference images (batch size from the config;
# the unused local `batch_size = 1000` was removed).
latent_space_samples = []
with torch.no_grad():
    for batched_images in torch.split(
        images, split_size_or_sections=batch_size_latent, dim=0
    ):
        samples = estimator.embedding(batched_images.to(device, non_blocking=True)).cpu()
        # One flattened feature row per image, sized by the actual batch.
        latent_space_samples.append(samples.reshape(batched_images.shape[0], -1))
latent_repr_m = torch.cat(latent_space_samples, dim=0)

In [None]:
# Posterior samples for the reference images: one column per image.
theta_samples = []
with torch.no_grad():
    for batched_images in torch.split(
        images, split_size_or_sections=batch_size_sampling, dim=0
    ):
        samples = estimator.sample(
            batched_images.to(device, non_blocking=True),
            shape=(num_samples_posterior,),
        ).cpu()
        # Size by the sample count so a smaller final batch still reshapes.
        theta_samples.append(samples.reshape(num_samples_posterior, -1))
samples = torch.cat(theta_samples, dim=1)

In [None]:
# Width of the central 95% posterior interval for each reference image.
posterior_quantiles = np.quantile(samples.numpy(), q=(0.025, 0.975), axis=0)
confidence_widths = np.subtract(posterior_quantiles[1], posterior_quantiles[0])

In [None]:
# Stack the reference features first, then each misspecified set in images_mstar order.
latent_groups = [latent_repr_m, *latent_repr_mstar]
cat_latent_samples = torch.cat(latent_groups)

# Group id per row (0 = reference, 1.. = images_mstar entries). Group sizes come from
# the data, not a hard-coded 10000 that no longer matched num_samples_umap.
labels = torch.cat(
    [torch.full((len(group),), float(k)) for k, group in enumerate(latent_groups)]
)

In [None]:
import umap

# 2-D UMAP projection of the stacked embedding features (50 neighbours, Euclidean).
reducer = umap.UMAP(n_neighbors=50, n_components=2, metric="euclidean")
latent_matrix = cat_latent_samples.numpy()
embedding = reducer.fit_transform(latent_matrix)

In [None]:
# Display names and colours for the five embedding groups, in the order their rows
# were stacked into cat_latent_samples (reference first, then images_mstar).
name = [
    "Ground truth",
    "No particle",
    "Wrong particle",
    "Wrong row (4)",
    "Wrong SNR (0.03)",
]
colors = ["red", "blue", "green", "yellow", "black"]
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)

# Panel 1: every group, one colour each; each group occupies num_samples_umap rows.
for group, (label, color) in enumerate(zip(name, colors)):
    rows = slice(group * num_samples_umap, (group + 1) * num_samples_umap)
    axes[0].scatter(
        embedding[rows, 0],
        embedding[rows, 1],
        label=label,
        c=color,
        s=0.1,
    )
axes[0].legend(fontsize=10, markerscale=10)
axes[0].set_ylabel("UMAP dimension 2")

# Panels 2 and 3: reference images only, coloured by the index they were simulated
# at and by the width of their 95% posterior interval.
reference = slice(0, num_samples_umap)
im1 = axes[1].scatter(
    embedding[reference, 0],
    embedding[reference, 1],
    c=indices.numpy().ravel(),
    s=0.5,
    cmap="viridis",
)
fig.colorbar(im1, ax=axes[1], label="Index")

im2 = axes[2].scatter(
    embedding[reference, 0],
    embedding[reference, 1],
    c=confidence_widths,
    s=0.5,
    cmap="viridis",
)
fig.colorbar(im2, ax=axes[2], label="Width of 95% confidence interval")

# Column 0 of the embedding is on the x axis of every panel.
for ax in axes:
    ax.set_xlabel("UMAP dimension 1")

fig.savefig(f"UMAP_analysis_{file_name}.png", dpi=500)

In [None]:
def sample_posterior(estimator, images, num_samples, batch_size=100, device="cpu"):
    """Draw posterior samples for a stack of images in mini-batches.

    Args:
        estimator: Object exposing ``sample(x, shape=(num_samples,))``.
        images: Tensor of images stacked along dim 0.
        num_samples: Number of posterior samples to draw per image.
        batch_size: Maximum number of images passed to ``estimator.sample`` at once.
        device: Device each image batch is moved to before sampling.

    Returns:
        CPU tensor of shape ``(num_samples, n_images)``; column ``j`` holds the samples
        for ``images[j]``. This assumes a scalar parameter and sample-major output from
        ``estimator.sample`` — TODO confirm for other estimators.
    """
    theta_samples = []

    with torch.no_grad():
        # Always split. torch.split yields a single chunk when there are at most
        # batch_size images, so the batch dimension is never lost, and the last
        # chunk may be smaller than batch_size.
        for image_batch in torch.split(images, split_size_or_sections=batch_size, dim=0):
            samples = estimator.sample(
                image_batch.to(device, non_blocking=True), shape=(num_samples,)
            ).cpu()
            # Size by num_samples (not batch_size) so a ragged last batch reshapes too.
            theta_samples.append(samples.reshape(num_samples, -1))

    return torch.cat(theta_samples, dim=1)


def compute_latent_repr(estimator, images, batch_size=1000, device="cpu"):
    """Compute embedding-network features for a stack of images in mini-batches.

    Args:
        estimator: Object exposing an ``embedding(x)`` callable.
        images: Tensor of images stacked along dim 0.
        batch_size: Maximum number of images passed to ``estimator.embedding`` at once.
        device: Device each image batch is moved to before embedding.

    Returns:
        CPU tensor with one flattened feature row per input image.
    """
    latent_space_samples = []

    with torch.no_grad():
        # Always split. For small inputs torch.split yields one chunk that keeps its
        # batch dimension, instead of feeding unbatched images one at a time.
        for image_batch in torch.split(images, split_size_or_sections=batch_size, dim=0):
            samples = estimator.embedding(
                image_batch.to(device, non_blocking=True)
            ).cpu()
            # Size by the actual batch so a ragged last batch reshapes too.
            latent_space_samples.append(samples.reshape(image_batch.shape[0], -1))

    return torch.cat(latent_space_samples, dim=0)