In [2]:
# Numerical libraries
import numpy as np
import pandas as pd
from scipy.spatial.transform import Rotation

# Plotting
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns

# MD Stuff
import MDAnalysis as mda

# Utils
from tqdm import tqdm
import pickle

# SBI
import torch
import sbi
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi.utils.get_nn_models import posterior_nn

# Functions and utilities

In [4]:
def gen_quat(size):
    """Draw `size` random unit quaternions by rejection sampling.

    Each candidate is a 4-vector drawn uniformly from the cube [-1, 1)^4.
    A candidate is kept only if its Euclidean norm lies in [0.2, 1.0]; kept
    candidates are rescaled to unit length, all others are redrawn.

    Parameters
    ----------
    size : int
        Number of quaternions to generate.

    Returns
    -------
    numpy.ndarray
        Array of shape (size, 4) whose rows all have unit norm.
    """
    # Original implementation credited to Sonya.
    accepted = np.zeros((size, 4))
    n_accepted = 0

    while n_accepted < size:
        # Half-open interval: -1 can be drawn, 1 cannot.
        candidate = np.random.uniform(-1, 1, 4)
        length = np.sqrt(np.sum(candidate**2))

        # Discard points outside the unit ball (and very short vectors, whose
        # normalisation would amplify round-off).
        if not (0.2 <= length <= 1.0):
            continue

        accepted[n_accepted] = candidate / length
        n_accepted += 1

    return accepted

def gen_img(coord, n_pixels, pixel_size, sigma):
    """Render a 2D image as a sum of isotropic Gaussians, one per atom.

    Only the first two rows of `coord` are used; any further rows are ignored,
    i.e. the model is projected along the remaining axis.

    Parameters
    ----------
    coord : numpy.ndarray
        Array whose row 0 and row 1 hold the two in-plane coordinates of each
        atom; `coord.shape[1]` is taken as the number of atoms.
    n_pixels : int
        Number of pixels along each side of the (square) image.
    pixel_size : float
        Spacing between neighbouring pixel centres, in the units of `coord`.
    sigma : float
        Standard deviation of each Gaussian, in the units of `coord`.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_pixels, n_pixels). Axis 0 follows `coord[0]`, axis 1
        follows `coord[1]`. Values are divided by `2 * pi * sigma**2 * n_atoms`.
    """
    n_atoms = coord.shape[1]
    norm = 2 * np.pi * sigma**2 * n_atoms

    # Pixel centres, symmetric about the origin. np.linspace always returns
    # exactly n_pixels samples, whereas np.arange with a floating-point step
    # can return one extra element because of round-off in the stop value.
    half_extent = pixel_size * (n_pixels - 1) * 0.5
    grid = np.linspace(-half_extent, half_extent, n_pixels)

    # 1D Gaussian profile of every atom along each axis: (n_pixels, n_atoms).
    gauss_x = np.exp(-0.5 * ((grid[:, None] - coord[0, :]) / sigma) ** 2)
    gauss_y = np.exp(-0.5 * ((grid[:, None] - coord[1, :]) / sigma) ** 2)

    # image[i, j] = sum_a gauss_x[i, a] * gauss_y[j, a]; the matrix product
    # avoids materialising an (n_pixels, n_pixels, n_atoms) intermediate.
    Icalc = gauss_x @ gauss_y.T

    return Icalc / norm

def load_model(fname, filter = "name CA"):
    """Load a structure file and return centred coordinates of selected atoms.

    Parameters
    ----------
    fname : str
        Path of a structure file readable by `MDAnalysis.Universe`.
    filter : str, optional
        MDAnalysis selection string choosing which atoms to return
        (default: "name CA").

    Returns
    -------
    numpy.ndarray
        Transposed `positions` array of the selected atoms. MDAnalysis
        positions are (n_atoms, 3), so the result is (3, n_atoms).
    """
    universe = mda.Universe(fname)

    # Move the whole model so that the centre of mass of *all* atoms (not just
    # the selected ones) sits at the origin.
    center = universe.select_atoms('all').center_of_mass()
    universe.atoms.translate(-center)

    # Pull out the requested subset after centring.
    selection = universe.select_atoms(filter)
    return selection.positions.T

def add_noise(img, n_pixels, pixel_size, snr):
    """Add Gaussian white noise at a given signal-to-noise ratio and normalise.

    The signal level is the standard deviation of the image inside a circular
    mask of radius `pixel_size * (n_pixels + 1) / 2` centred on the image.
    Noise is drawn from the global NumPy RNG (`np.random.normal`).

    Parameters
    ----------
    img : array_like
        Image with `n_pixels * n_pixels` elements; it is reshaped to
        (n_pixels, n_pixels), so flattened input is accepted.
    n_pixels : int
        Number of pixels along each side of the image.
    pixel_size : float
        Spacing between neighbouring pixel centres.
    snr : float
        Signal-to-noise ratio, defined here as signal variance over noise
        variance. Larger values mean *less* noise; `numpy.inf` adds none.

    Returns
    -------
    numpy.ndarray
        Noisy image of shape (n_pixels, n_pixels), divided by its own standard
        deviation.

    Raises
    ------
    ValueError
        If `snr` is not strictly positive.
    """
    if snr <= 0:
        raise ValueError("snr must be strictly positive")

    # Actually use the reshaped array (it used to be computed and discarded).
    img = np.asarray(img).reshape(n_pixels, n_pixels)

    rad_sq = (pixel_size * (n_pixels + 1) * 0.5) ** 2

    # Pixel centres; np.linspace guarantees exactly n_pixels samples so the
    # mask always matches the image shape.
    half_extent = pixel_size * (n_pixels - 1) * 0.5
    grid = np.linspace(-half_extent, half_extent, n_pixels)

    # Circular mask: the image corners are excluded from the signal estimate.
    mask = grid[None, :] ** 2 + grid[:, None] ** 2 < rad_sq

    signal_std = np.std(img[mask])

    # SNR = var(signal) / var(noise)  =>  std(noise) = std(signal) / sqrt(SNR).
    noise_std = signal_std / np.sqrt(snr)
    noise = np.random.normal(loc=0.0, scale=noise_std, size=img.shape)

    img_noise = img + noise

    # Normalise the noisy image to unit standard deviation.
    return img_noise / np.std(img_noise)

# Generating data

In [5]:
# Uniform prior over the one-dimensional model index, with bounds 1 and 20.
# NOTE(review): simulator_plain_1d rounds the sampled index to the nearest
# integer, so the end states 1 and 20 are drawn half as often as the others.
prior_indices = utils.BoxUniform(
    low=torch.ones(1),
    high=torch.full((1,), 20.0),
)

In [18]:
def simulator_plain_1d(index):
    """Simulate one noisy image for the model selected by `index`.

    `index` is rounded to the nearest integer `i`, the structure
    "models/state_1_{i}.pdb" is loaded with `load_model`, rendered with
    `gen_img` on a 32 x 32 grid (pixel size 4, sigma 1), and passed through
    `add_noise` with SNR = 10.

    Parameters
    ----------
    index : scalar or size-1 array/tensor
        Continuous model index; rounded to pick the structure file.

    Returns
    -------
    The value returned by `add_noise` for the rendered image.
    """
    index1 = int(np.round(index))

    coord = load_model(f"models/state_1_{index1}.pdb")

    # Uncomment if you want rotations
    # quat = gen_quat(1)[0]
    # rot_mat = Rotation.from_quat(quat).as_matrix()
    # coord = np.matmul(rot_mat, coord)

    N_PIXELS = 32
    PIXEL_SIZE = 4
    SIGMA = 1
    IMG_PARAMS = (N_PIXELS, PIXEL_SIZE, SIGMA)

    image = gen_img(coord, *IMG_PARAMS)

    # Comment out BOTH lines below if you want noiseless images. SNR is
    # defined here so the function does not rely on a leftover kernel global.
    SNR = 10
    image = add_noise(image, N_PIXELS, PIXEL_SIZE, SNR)

    return image

# Wrap the simulator and prior for sbi. The prior defined in this notebook is
# `prior_indices`; `prior_indices_1d` is never defined and raised a NameError
# on a fresh kernel.
sim_plain_1d, prior_1d = prepare_for_sbi(simulator_plain_1d, prior_indices)

100%|██████████| 20/20 [01:25<00:00,  4.29s/it]


# Post-processing of the generated images 

Because we are generating images with no rotations, we need to add noise so that PyTorch doesn't die. This could also be the place to normalize the images or to apply the CTF. I prefer not to add these things when generating the images, because having a set of raw images allows me to run many tests with just one dataset.

# Training SBI's neural network

In [20]:
# Builder for the posterior density estimator: a masked autoregressive flow
# ('maf') with 10 hidden features and 4 transforms.
density_estimator_build_fun = posterior_nn(
    model='maf',
    hidden_features=10,
    num_transforms=4,
)

# SNPE inference object that uses the index prior and the estimator above.
inference = SNPE(
    prior=prior_indices,
    density_estimator=density_estimator_build_fun,
)

In [21]:
# Calculate and save posterior
# Train the density estimator on the simulated (index, image) pairs, build the
# posterior, and save it to disk.
# NOTE(review): `indices` and `images` are not defined in any visible cell;
# they must come from a data-generation/post-processing cell run beforehand.
# NOTE(review): `inference` is rebound in place, so re-running this cell
# presumably appends the same simulations a second time — confirm, or
# re-create `inference` before re-running.
inference = inference.append_simulations(indices, images)
density_estimator = inference.train()
posterior_0 = inference.build_posterior(density_estimator)

# Persist the trained posterior; the "posteriors/" directory must already exist.
with open("posteriors/posterior_no_rot_norm.pkl", "wb") as handle:
    pickle.dump(posterior_0, handle)

Error: Canceled future for execute_request message before replies were done

In [None]:
# Or load it
# Alternative to training: load a previously saved posterior from disk.
# NOTE(review): this reads "posterior_no_rot.pkl", whereas the training cell
# writes "posterior_no_rot_norm.pkl" — confirm which file is intended.
# pickle.load can execute arbitrary code; only load files you created yourself.
with open("posteriors/posterior_no_rot.pkl", "rb") as handle:
    posterior_0 = pickle.load(handle)

# Checking the posterior trained well with one example

In [14]:
# Draw one index from the prior to use as the ground truth for the check.
true_index = prior_indices.sample((1,))
# NOTE(review): the line below immediately overwrites the `true_index` sampled
# above and passes `indices` instead of it — presumably `true_index` was meant
# to be the first argument; confirm against simulate_for_sbi_cpp's signature.
# NOTE(review): `simulate_for_sbi_cpp`, `indices` and `args_dict` are not
# defined in any visible cell. The recorded NameError for `prior_indices` shows
# this cell was last run before the cell that defines the prior.
true_index, true_images = simulate_for_sbi_cpp(indices, args_dict, n_ranks=4, n_threads=1, imgs_per_index=2500)

NameError: name 'prior_indices' is not defined