In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
import json
import mrcfile
import umap

from cryo_sbi.inference.models import build_models
from cryo_sbi import CryoEmSimulator
from cryo_sbi.inference import priors
from cryo_sbi.utils.estimator_utils import sample_posterior, compute_latent_repr
from cryo_sbi.utils.image_utils import LowPassFilter, NormalizeIndividual, MRCtoTensor
from cryo_sbi.wpa_simulator.noise import circular_mask

In [None]:

# Cut-off passed to LowPassFilter in both preprocessing pipelines below.
fft_cut = 64

# Preprocessing for simulated images: low-pass filter, then per-image normalization.
test = transforms.Compose(
    [LowPassFilter(fft_cut, 128), NormalizeIndividual()]
)

# Preprocessing for experimental particles: MRCtoTensor (presumably reads the MRC file into a
# tensor), resize to 128x128, then the same low-pass filter and normalization as for simulations.
transform = transforms.Compose(
    [
        MRCtoTensor(),
        transforms.Resize(size=(128, 128)),
        LowPassFilter(fft_cut, 128),
        NormalizeIndividual(),
    ]
)

In [None]:
# Training image parameters. The `with` block closes the file handle that a bare open() would leak.
# NOTE(review): `config` does not appear to be read by later cells; the simulator loads the same file itself.
with open('../experiments/6wxb/image_params_training.json') as config_file:
    config = json.load(config_file)

In [None]:
# ---- Notebook parameters ----
# NOTE(review): several of these (stats/SBC sample counts, num_workers) are not referenced by the cells below.
file_name = '6wxb_nma'              # Experiment identifier
data_dir = "../experiments/6wxb/"   # Holds the simulator config, training config and estimator weights
num_samples_stats = 20000           # Number of simulations for computing posterior stats
num_samples_SBC = 10000             # Number of simulations for SBC
num_posterior_samples_SBC = 4096    # Number of posterior samples for each SBC simulation
num_samples_posterior = 50000       # Number of samples to draw from posterior
batch_size_sampling = 100           # Batch size for sampling posterior
num_workers = 24                    # Number of CPU cores
device = 'cuda'                     # Device for computations
save_figures = False                # Toggle for writing figures to disk
seed = 0                            # Seed for the torch and numpy random number generators

# Seed the RNGs so simulated particles and posterior samples are reproducible on Restart & Run All.
torch.manual_seed(seed)
np.random.seed(seed)

## Load the cryo-EM simulator and the trained posterior with the matching configs

In [None]:
# Simulator built from the training image parameters, with the SNR then overridden to 0.01.
# NOTE(review): assumes the simulator reads config['SNR'] at simulation time rather than at construction — confirm.
simulator_config_path = data_dir + "image_params_training.json"
cryosbi = CryoEmSimulator(simulator_config_path)
cryosbi.config['SNR'] = 0.01

In [None]:
# Build the NPE flow model from its training config and load the trained weights.
with open(data_dir + "resnet18_encoder.json") as train_config_file:
    train_config = json.load(train_config_file)
estimator = build_models.build_npe_flow_model(train_config)
# Load the checkpoint onto the CPU first so it loads regardless of the device it was saved from,
# then move the model to the configured device (instead of hard-coding .cuda()).
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
estimator.load_state_dict(torch.load(data_dir + "posterior_6wxb.estimator", map_location="cpu"))
estimator.to(device)
estimator.eval();

# Test the simulator over the whole prior range

In [None]:
# One simulated image every 5 prior indices, from 0 up to max_index, as a float64 column vector.
prior_grid = np.arange(0, cryosbi.max_index + 1, 5)
indices = torch.as_tensor(prior_grid, dtype=torch.float64).reshape(-1, 1)
images = torch.stack([cryosbi.simulator(grid_index) for grid_index in indices])

In [None]:
# Low-pass filter and normalize the simulated images with the `test` pipeline.
# NOTE: not idempotent — re-running this cell filters/normalizes `images` a second time;
# re-run the simulation cell above first.
images = test(images)

In [None]:
# 4x5 grid of the filtered simulated images, each labelled with the prior index it was simulated from.
# zip() stops at the shorter sequence, so fewer than 20 images no longer raises IndexError.
fig, axes = plt.subplots(4, 5, figsize=(10, 8))
for ax, image, index in zip(axes.flat, images, indices):
    ax.imshow(image, vmax=4, vmin=-4, cmap='binary')
    ax.text(10, 20, str(int(index.item())))
for ax in axes.flat:
    ax.set_yticks([])
    ax.set_xticks([])
if save_figures:
    fig.savefig('../experiments/6wxb/results/plots/6wxb_examples.pdf')

In [None]:
# 10,000 posterior samples per simulated image, computed on the configured device instead of a hard-coded 'cuda'.
samples = sample_posterior(estimator, images, num_samples=10000, device=device)

In [None]:
# Posterior histograms for the first 20 simulated images; the red line marks the prior index
# each image was simulated from.
hist_bins = np.arange(0, 100, 0.7)
fig, axes = plt.subplots(4, 5, figsize=(10, 8), sharex=True)
# zip() with range() stops at whichever runs out first, so fewer than 20 images is safe.
for ax, idx in zip(axes.flat, range(samples.shape[1])):
    ax.hist(samples[:, idx].flatten().numpy(), bins=hist_bins, histtype="step", color="blue", label="all")
    ax.axvline(indices[idx].item(), color='red')
    ax.set_yticks([])
    ax.set_xticks(range(0, 100, 20))
if save_figures:
    fig.savefig('../experiments/6wxb/results/plots/torsion_nma.pdf', dpi=300)

# Simulate the experimental setup by generating particles from the original cryo-EM structure

In [None]:
# 500 synthetic particles, all simulated at prior index 50 (float32 indices).
# NOTE(review): index 50 presumably corresponds to the original cryo-EM structure — confirm.
indices = torch.full((500,), 50.0, dtype=torch.float)
images = torch.stack(list(map(cryosbi.simulator, indices)), dim=0)

In [None]:
# Low-pass filter and normalize the synthetic particles with the `test` pipeline.
# NOTE: not idempotent — re-running this cell filters/normalizes `images` a second time;
# re-run the simulation cell above first.
images = test(images)

In [None]:
# 10,000 posterior samples per synthetic particle, computed on the configured device instead of a hard-coded 'cuda'.
samples_syn = sample_posterior(estimator, images, num_samples=10000, device=device)

In [None]:
# Posterior histograms for the first 20 synthetic particles; the red line marks the index they were simulated at.
hist_bins = np.arange(0, 100, 0.7)
fig, axes = plt.subplots(4, 5, figsize=(10, 8), sharex=True)
for ax, idx in zip(axes.flat, range(samples_syn.shape[1])):
    ax.hist(samples_syn[:, idx].flatten().numpy(), bins=hist_bins, histtype="step", color="blue", label="all")
    ax.axvline(indices[idx].item(), color='red')
    ax.set_yticks([])
    ax.set_xticks(range(0, 100, 20))
if save_figures:
    # Distinct filename so this figure does not overwrite torsion_nma.pdf from the prior-grid histograms.
    fig.savefig('../experiments/6wxb/results/plots/torsion_nma_synthetic.pdf', dpi=300)

In [None]:
# Distribution of posterior means across the synthetic particles.
# Flattened to 1-D explicitly so plt.hist always draws a single dataset.
posterior_means_syn = samples_syn.mean(dim=0).flatten().numpy()
plt.hist(posterior_means_syn, bins=np.linspace(0, 100, 50))
plt.xlabel('Posterior means', fontsize=15)
plt.yticks([]);

# Now with real (experimental) data

In [None]:
# Path to an MRC file of experimental particles. `transform` (MRCtoTensor -> Resize(128x128) ->
# LowPassFilter -> NormalizeIndividual) converts it into preprocessed images that later cells index per particle.
# NOTE(review): the path is relative to the notebook's working directory — confirm it exists on your machine.
# The misspelled name `particles_transfomed` is kept because later cells reference it.
img_file = '../../6wxb_test_particles/10532/data/03_Refined_Particles/P30_J363_particles/J342/localmotioncorrected/FoilHole_24136295_Data_24136362_24136364_20200224_020513_Fractions_particles_local_aligned.mrc'
particles_transfomed = transform(img_file)

In [None]:
# 10,000 posterior samples per experimental particle, computed on the configured device instead of a hard-coded 'cuda'.
samples_real = sample_posterior(estimator, particles_transfomed, num_samples=10000, device=device)

In [None]:
# Show 20 experimental particles, starting at particle number `offset`.
offset = 40
fig, axes = plt.subplots(4, 5, figsize=(10, 8))
# Slicing (instead of indexing with i + offset) tolerates stacks with fewer than offset + 20 particles.
for ax, particle in zip(axes.flat, particles_transfomed[offset:offset + axes.size]):
    ax.imshow(particle, vmax=4, vmin=-4, cmap='binary')
for ax in axes.flat:
    ax.set_yticks([])
    ax.set_xticks([])

In [None]:
# Posterior histograms for the experimental particles starting at `offset` (no ground truth, so no reference line).
hist_bins = np.arange(0, 100, 0.7)
fig, axes = plt.subplots(4, 5, figsize=(10, 8), sharex=True)
# zip() stops at the last available particle, so short stacks no longer raise IndexError.
for ax, idx in zip(axes.flat, range(offset, samples_real.shape[1])):
    ax.hist(samples_real[:, idx].flatten().numpy(), bins=hist_bins, histtype="step", color="blue", label="all")
    ax.set_yticks([])
    ax.set_xticks(range(0, 100, 20))

In [None]:
# Distribution of posterior means across the experimental particles.
# Flattened to 1-D explicitly so plt.hist always draws a single dataset.
posterior_means_real = samples_real.mean(dim=0).flatten().numpy()
plt.hist(posterior_means_real, bins=np.linspace(0, 100, 50))
plt.xlabel('Posterior means', fontsize=15)
plt.yticks([]);

# Comparing experimental and synthetic images in the latent space of the posterior

In [None]:
# Synthetic particles for the latent-space comparison, all simulated at prior index 50 and then
# passed through the `test` pipeline. To sample indices from the uniform prior instead, use
#   priors.get_uniform_prior_1d(cryosbi.max_index).sample((num_syn_parts,))
num_syn_parts = 500
indices = torch.full((num_syn_parts,), 50.0, dtype=torch.float)
images = test(
    torch.stack([cryosbi.simulator(particle_index) for particle_index in indices], dim=0)
)

In [None]:
# Latent representations of the synthetic and experimental particles, stacked synthetic-first.
syntetic_particles_latent = compute_latent_repr(estimator, images, batch_size=100, device=device)
particles_latent = compute_latent_repr(estimator, particles_transfomed, batch_size=100, device=device)
cat_latent_samples = torch.cat((syntetic_particles_latent, particles_latent), dim=0)
# Labels: 1 = synthetic, 0 = experimental. Sizes come from the latents themselves rather than a
# hard-coded particle count, so the labels always line up with cat_latent_samples.
labels_latent = torch.cat(
    (torch.ones((len(syntetic_particles_latent),)), torch.zeros((len(particles_latent),))),
    dim=0,
)

In [None]:
# 2-D UMAP embedding of all latent vectors. A fixed random_state makes the embedding reproducible
# across re-runs (umap-learn disables its parallelism when random_state is set).
reducer = umap.UMAP(metric='euclidean', n_components=2, n_neighbors=50, random_state=42)
embedding = reducer.fit_transform(cat_latent_samples.numpy())

In [None]:
# UMAP embedding: the first num_syn_parts rows are synthetic particles, the remaining rows experimental.
point_groups = (
    (slice(None, num_syn_parts), 'blue', 'Synthetic images'),
    (slice(num_syn_parts, None), 'red', 'Experimental images'),
)
for rows, colour, group_label in point_groups:
    plt.scatter(embedding[rows, 0], embedding[rows, 1], s=3, c=colour, label=group_label)
plt.xlabel('UMAP 1', fontsize=15)
plt.ylabel('UMAP 2', fontsize=15)
plt.legend(fontsize=15, markerscale=3, loc='lower left')

In [None]:
# Side-by-side comparison of one preprocessed experimental particle and one synthetic particle.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(particles_transfomed[1], vmax=3, vmin=-3, cmap='binary')
axes[0].set_title('Resized real particle')
axes[1].imshow(images[42], vmax=3, vmin=-3, cmap='binary')
axes[1].set_title('Synthetic particle')
if save_figures:
    fig.savefig('../experiments/6wxb/results/plots/comparison_particles.pdf', dpi=400)