In [1]:
# Here is a short tutorial in case you want to use the code.

# If you want to extend and use recovar, you should import this first
import recovar.config

from recovar import dataset
from recovar.fourier_transform_utils import fourier_transform_utils
import jax
# Build the Fourier-transform helpers on top of jax.numpy so that they run on GPU by default
# (if you have one). Passing numpy instead makes them run on CPU.
ftu = fourier_transform_utils(jax.numpy)

# Directory holding the preprocessed files -- edit this to point at your own experiment.
experiment_directory = '/home/mg6942/mytigress/uniform/'

# Start from the default options, then fill in the path of each preprocessed input file.
dataset_dict = dataset.get_default_dataset_option()
for option_name, file_name in (
    ('ctf_file', 'ctf.pkl'),
    ('poses_file', 'poses.pkl'),
    ('particles_file', 'particles.128.mrcs'),
):
    dataset_dict[option_name] = experiment_directory + file_name

# Load the dataset described by the options; the returned object knows everything about it.
cryo_dataset = dataset.load_dataset_from_dict(dataset_dict)

(INFO) (ctf.py) (16-Oct-23 16:36:51) Image size (pix)  : 128
(INFO) (ctf.py) (16-Oct-23 16:36:51) A/pix             : 6.0
(INFO) (ctf.py) (16-Oct-23 16:36:51) DefocusU (A)      : 26795.69921875
(INFO) (ctf.py) (16-Oct-23 16:36:51) DefocusV (A)      : 26795.69921875
(INFO) (ctf.py) (16-Oct-23 16:36:51) Dfang (deg)       : 0.0
(INFO) (ctf.py) (16-Oct-23 16:36:51) voltage (kV)      : 300.0
(INFO) (ctf.py) (16-Oct-23 16:36:51) cs (mm)           : 2.0
(INFO) (ctf.py) (16-Oct-23 16:36:51) w                 : 0.10000000149011612
(INFO) (ctf.py) (16-Oct-23 16:36:51) Phase shift (deg) : 0.0


In [2]:
import numpy as np
import jax.numpy as jnp
# The computation below runs on GPU (if you have one), in batches of 1000 images.
batch_size = 1000
# Constant added to CTF**2 in the denominator of the CTF correction below.
weiner_param = 1

# Generator that hands the data over batch by batch, so each batch can be sent to GPU.
dataset_iterator = cryo_dataset.get_dataset_generator(batch_size=batch_size)

# CPU-side buffer that receives the corrected images, one batch at a time.
new_stack = np.empty((cryo_dataset.n_images, *cryo_dataset.image_shape))

# Running code on GPU with JAX is very easy: by default any jax.numpy.array is allocated on GPU
# and behaves like a normal numpy.array, with the operations executed on GPU. To bring a result
# back to CPU, either call jax.device_put(array, device = jax.devices("cpu")[0]) or convert it
# to a numpy array.

# A simple example: CTF correcting an image stack.
for images, batch_image_ind in dataset_iterator:
    # `images` is still on CPU here. jnp.array(images) would send it to GPU explicitly;
    # otherwise the functions below send it to GPU themselves.

    # Compute the CTF of every image in the batch.
    ctf_batch = cryo_dataset.CTF_fun(
        cryo_dataset.CTF_params[batch_image_ind],
        cryo_dataset.image_shape,
        cryo_dataset.voxel_size,
    )
    # Compute DFT, masking.
    fourier_images = cryo_dataset.image_stack.process_images(images)

    # CTF correction: scale every Fourier coefficient by CTF / (CTF**2 + weiner_param).
    correction_filter = ctf_batch / (ctf_batch**2 + weiner_param)
    corrected_fourier = correction_filter * fourier_images

    # Restore the per-image 2D layout and go back to the real domain.
    n_in_batch = corrected_fourier.shape[0]
    corrected_real = ftu.get_idft2(
        corrected_fourier.reshape(n_in_batch, *cryo_dataset.image_shape)
    )

    # Send back to CPU: keep the real part and store it at this batch's indices.
    new_stack[batch_image_ind] = np.array(corrected_real.real)



In [None]:
# In recovar.core, you will find functions to do all the basic cryo-EM operations, in batches for GPU.
import recovar.core as core

# The functions are only referenced here, not called: calling them with no arguments
# (as this cell previously did) would presumably raise, since they operate on images/volumes,
# and that would stop a "Restart & Run All" at this cell.
# Use help(core.get_slices) (or `core.get_slices?` in IPython) to see the arguments of any of them.
core_operations = (
    core.translate_images,
    core.get_slices,  # Slice a volume
    core.sum_batch_P_adjoint_mat_vec,  # Summed adjoint of slicing: v = \sum_i P_i^* im_i for some im_i
    core.batch_compute_ctf_crydgrn,  # CTF function
)
core_operations
