In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import json
from multiprocessing import Pool
from lampe.data import JointLoader
from itertools import islice
from tqdm import tqdm
from lampe.diagnostics import expected_coverage_mc
from lampe.plots import coverage_plot
import mrcfile

from cryo_sbi.inference.models import build_models
from cryo_sbi import CryoEmSimulator
from cryo_sbi.inference import priors

In [None]:
# Notebook-wide settings. Paths are built by string concatenation with
# `data_dir`, so its trailing slash must be kept.
file_name = '6wxb_torsion_nma'    # File name
data_dir = "../experiments/6wxb/"   # Experiment directory (with trailing slash)
num_samples_stats = 20000           # Number of simulations for computing posterior stats
num_samples_SBC = 10000             # Number of simulations for SBC
num_posterior_samples_SBC = 4096    # Number of posterior samples for each SBC simulation
num_samples_posterior = 50000       # Number of samples to draw from posterior
batch_size_sampling = 100           # Batch size for sampling posterior
num_workers = 24                    # Number of CPU cores
device = 'cuda'                     # Device for computations
save_figures = False                # NOTE(review): presumably toggles writing figures to disk; verify where it is consumed

## Load the cryo-EM simulator and the posterior estimator with the matching config

In [None]:
# Build the image simulator from the experiment's image-parameter JSON file.
simulator_config_path = data_dir + "image_params_training.json"
cryosbi = CryoEmSimulator(simulator_config_path)

In [None]:
# Rebuild the posterior estimator from its JSON config and restore the saved
# state dict.
with open(data_dir + "resnet18_encoder.json") as config_file:
    # The context manager closes the file handle as soon as it has been parsed.
    train_config = json.load(config_file)

estimator = build_models.build_npe_flow_model(train_config)
# map_location places the checkpoint tensors on the configured device, even if
# the checkpoint was written from a different one.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
estimator.load_state_dict(
    torch.load(data_dir + "6wxb_resnet18.estimator", map_location=device)
)
estimator.to(device)
# Inference mode; the trailing semicolon suppresses the module repr in the
# cell output.
estimator.eval();

In [None]:
# Override the 'SNR' entry of the simulator config after construction, so the
# images simulated from here on use these values.
# NOTE(review): presumably read by the simulator as an SNR range; confirm the
# expected element order.
cryosbi.config['SNR'] = [0.1 , 0.01]

In [None]:
# Indices to simulate: every 5th value starting at 0 and not exceeding
# cryosbi.max_index, stored as a float64 column of shape (n, 1).
index_grid = np.arange(0, cryosbi.max_index + 1, 5)
indices = torch.tensor(index_grid, dtype=float).reshape(-1, 1)

# One simulator call per index row; results are stacked along a new leading
# dimension.
images = torch.stack(list(map(cryosbi.simulator, indices)), dim=0)

In [None]:
# Show the first 20 simulated images in a 4x5 grid; each panel is annotated
# with the index it was simulated from.
fig, axes = plt.subplots(4, 5, figsize=(10, 8))
for panel, ax in enumerate(axes.flat):
    ax.imshow(images[panel], vmax=4, vmin=-4, cmap='binary')
    # Hide both tick sets: the pixel coordinates carry no information here.
    ax.set(xticks=[], yticks=[])
    ax.text(10, 20, str(int(indices[panel].item())))
#plt.savefig('../experiments/6wxb/results/plots/6wxb_examples.pdf')

In [None]:
# Draw `num_samples_posterior` posterior samples for every simulated image,
# one image at a time, and collect the per-image results on the CPU.
theta_samples = []
with torch.no_grad():
    # split_size_or_sections=1 yields chunks that hold a single image each.
    for batched_images in torch.split(images, split_size_or_sections=1, dim=0):
        # NOTE(review): `circular_mask` is neither defined nor imported in the
        # visible part of this notebook; on a fresh kernel this line raises
        # NameError unless it is provided elsewhere. Confirm its origin and
        # whether the mask selects the pixels inside or outside the circle.
        # The chunks returned by torch.split are views, so this assignment
        # also zeroes the selected pixels in `images` itself.
        batched_images[0][circular_mask(128, 64)] = 0
        #plt.imshow(batched_images[0])
        #plt.show()
        samples = estimator.sample(
            batched_images.cuda(non_blocking=True),
            shape=(num_samples_posterior,)
        ).cpu()
        theta_samples.append(
            samples
        )
    # Rebinds `samples` from the last per-image draw to the concatenation of
    # all draws along dim=1 (the axis later indexed per image).
    samples = torch.cat(theta_samples, dim=1)

In [None]:
# Posterior histogram for each simulated image in a 4x5 grid; the red vertical
# line marks the index the image was simulated from.
fig, axes = plt.subplots(4, 5, figsize=(10, 8), sharex=True)
hist_bins = np.arange(0, 100, 0.5)  # identical for every panel, so built once
for idx, ax in enumerate(axes.flat):
    ax.hist(samples[:, idx].flatten().numpy(), bins=hist_bins, histtype="step", color="blue", label="all")
    ax.set_yticks([])
    ax.set_xticks(range(0, 100, 20))
    # Hand matplotlib a plain Python float rather than a 1-element tensor.
    ax.axvline(indices[idx].item(), color='red')
#plt.savefig('../experiments/6wxb/results/plots/torsion_nma.pdf', dpi=300)

In [None]:
# Load the values stored in distance_to_reference.npy and pre-format each one
# as a string with two decimals.
labels = np.load('../../6wxb/6wxb_models/distance_to_reference.npy')
labels_str = [f'{value:.2f}' for value in labels]

In [None]:
# Read the particle images stored in the MRC file into a torch tensor.
with mrcfile.open('../../6wxb/6wxb_models/particles/FoilHole_24136295_Data_24136362_24136364_20200224_020513_Fractions_particles_local_aligned.mrc') as mrc:
    # mrcfile opens files read-only by default and exposes a non-writable
    # array. Copying gives torch.from_numpy writable memory that the tensor
    # owns, which avoids PyTorch's non-writable-array warning.
    particles = torch.from_numpy(mrc.data.copy())

In [None]:
def f(*args):
    """Print the tuple of positional arguments the function was called with."""
    received = args
    print(received)


# Quick check: a single argument is received as a one-element tuple.
f(1)

In [None]:
import torchvision.transforms as transforms

class NormalizeIndividual():
    """Standardize every image to zero mean and unit standard deviation.

    The statistics are computed separately for each image over its last two
    dimensions (height, width). A single ``(H, W)`` image, an ``(N, H, W)``
    stack and tensors with further leading dimensions are all handled the
    same way.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, images):
        """Return a normalized copy of ``images``; the input is not modified.

        Args:
            images (torch.Tensor): Floating-point tensor of shape
                ``(..., H, W)``.

        Returns:
            torch.Tensor: Tensor of the same shape and dtype in which every
            ``(H, W)`` image has mean 0 and (unbiased) standard deviation 1.

        Raises:
            ValueError: If any image has a standard deviation of zero, since
                normalizing it would divide by zero.
        """
        # keepdim=True keeps the reduced axes as size-1 dimensions so the
        # statistics broadcast against the images for any number of leading
        # dimensions.
        mean = images.mean(dim=(-2, -1), keepdim=True)
        std = images.std(dim=(-2, -1), keepdim=True)
        if (std == 0).any():
            raise ValueError(
                'std evaluated to zero for at least one image, leading to division by zero.'
            )
        return (images - mean) / std

    
class MRCtoTensor():
    """Load the data array of an MRC file as a torch tensor."""

    def __init__(self) -> None:
        pass

    def __call__(self, image_path):
        """Read ``image_path`` and return its data as a tensor.

        Args:
            image_path (str | os.PathLike): Path of the MRC file to read.

        Returns:
            torch.Tensor: Tensor holding a copy of the file's data array.

        Raises:
            TypeError: If ``image_path`` is neither a string nor a path-like
                object.
        """
        import os  # local import keeps this cell self-contained

        # Explicit check instead of `assert`, which is stripped under
        # `python -O` and would let invalid input reach mrcfile.
        if not isinstance(image_path, (str, os.PathLike)):
            raise TypeError('image path needs to be a string or path-like object')
        with mrcfile.open(os.fspath(image_path)) as mrc:
            # Files are opened read-only by default and expose a non-writable
            # array; copy it so the tensor owns writable memory.
            image = mrc.data.copy()
        return torch.from_numpy(image)
    


# Preprocessing pipeline for a particle file: read the MRC data, resize every
# image to 128x128 pixels, then normalize each image individually.
preprocessing_steps = [
    MRCtoTensor(),
    transforms.Resize(size=(128, 128)),
    NormalizeIndividual(),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Path of the MRC particle file that is passed through `transform` below.
img_file = '../../6wxb_test_particles/10532/data/03_Refined_Particles/P30_J363_particles/J342/localmotioncorrected/FoilHole_24136295_Data_24136362_24136364_20200224_020513_Fractions_particles_local_aligned.mrc'
# NOTE(review): the variable is called `euler`, but the path points at a
# `*_particles_traj.npy` file; confirm what the file actually contains.
euler = '../../6wxb_test_particles/10532/data/03_Refined_Particles/P30_J363_particles/J342/localmotioncorrected/FoilHole_24136295_Data_24136362_24136364_20200224_020513_Fractions_particles_traj.npy'

In [None]:
# Preview the first 20 preprocessed particles of `img_file` in a 4x5 grid.
# The file is read and transformed once up front; running the pipeline inside
# the loop would repeat the same file read and resize for every panel.
preview_particles = transform(img_file)

fig, axes = plt.subplots(4, 5, figsize=(10, 8))
for i, ax in enumerate(axes.reshape(-1)):
    ax.imshow(preview_particles[i], vmax=4, vmin=-4, cmap='binary')
    ax.set_yticks([])
    ax.set_xticks([])

In [None]:
# `particles` is already a tensor, while the first stage of `transform` is the
# MRC file loader, which only accepts a file path. Apply just the remaining
# tensor-level stages (resize and per-image normalization).
particles_transfomed = transforms.Compose(transform.transforms[1:])(particles)

In [None]:
# Side-by-side view of one preprocessed real particle and one simulated image,
# drawn with the same colour limits.
# NOTE(review): if the sampling cell for the simulated images has run, its
# in-place masking also changed `images`, so the right panel shows the masked
# image; confirm this is intended.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(particles_transfomed[1], vmax=3, vmin=-3)
axes[0].set_title('Resized real particle')
axes[1].imshow(images[2], vmax=3, vmin=-3)
axes[1].set_title('Synthetic particle')
#fig.savefig('../experiments/6wxb/results/plots/comparison_particles.pdf', dpi=400)

In [None]:
# Pass particle 10, reshaped to (1, 1, 128, 128), through the first
# convolution of the embedding ResNet and then its ReLU, without tracking
# gradients. The activation is kept in `t`.
with torch.no_grad():
    resnet = estimator.embedding.resnet
    first_layer_input = particles_transfomed[10].reshape(1, 1, 128, 128).cuda()
    t = resnet.relu(resnet.conv1(first_layer_input))

In [None]:
# Display particle 10 reshaped to (1, 1, 128, 128); as the last expression of
# the cell, the resulting tensor is shown as the cell output.
single_particle = particles_transfomed[10]
single_particle.reshape(1, 1, 128, 128)

In [None]:
# Draw `num_samples_posterior` posterior samples for every preprocessed
# particle. Sampling runs without gradient tracking and in chunks of
# `batch_size_sampling` particles so that device memory stays bounded. The
# per-chunk results are moved to the CPU and joined along dim=1, the axis that
# is indexed per particle in the plots below.
particle_samples = []
with torch.no_grad():
    for particle_batch in torch.split(particles_transfomed, batch_size_sampling, dim=0):
        particle_samples.append(
            estimator.sample(
                particle_batch.cuda(non_blocking=True),
                shape=(num_samples_posterior,),
            ).cpu()
        )
samples = torch.cat(particle_samples, dim=1)

In [None]:
# Overlay the posterior histograms of the first five particles as step curves
# that share one set of bins.
posterior_bins = np.linspace(0, 101, 101)
for particle_idx in range(5):
    particle_posterior = samples[:, particle_idx].flatten().numpy()
    _ = plt.hist(particle_posterior, bins=posterior_bins, histtype='step')
    #_ = plt.xticks(ticks=np.linspace(0, 101, 101)[::10], labels=labels_str[::10])
#plt.savefig('../experiments/6wxb/results/plots/posteriors_100_particles.pdf', dpi=400)

In [None]:
# Histogram of all posterior samples pooled over every particle.
pooled_samples = samples.flatten().numpy()
_ = plt.hist(pooled_samples, bins=np.linspace(0, 101, 101))
#plt.savefig('../experiments/6wxb/results/plots/averaged_posterior_all_particles.pdf', dpi=400)