In [None]:
%matplotlib inline
import os
import sys
import starfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mrcfile
from tqdm import tqdm
import configargparse
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)
module_path = os.path.abspath(os.path.join('../src/experiment_scripts'))
if module_path not in sys.path:
    sys.path.append(module_path)
from aNiMAte.src.experiment_scripts.main import init_config
from aNiMAte.src.atomic_utils import AtomicModel
from aNiMAte.src.prody_utils import read_prody_model
from aNiMAte.src.dynamics_utils import DynamicsModelNMA
from prody import *
import pykeops

# Turn off PyKeOps' verbose output for the rest of the notebook.
pykeops.set_verbose(False)

# --- Sampling parameters -------------------------------------------------
# Number of per-frame datasets to draw particles from (frames 01..DATASET_NUM).
DATASET_NUM = 50
# Number of normal modes requested from the NMA model.
MAX_NUM_MODES = 128

# --- Input / output locations --------------------------------------------
# Root directory holding one sub-directory per simulated frame; sampled
# datasets are written underneath it as well.
DATA_PATH = '/sdf/group/ml/CryoNet/experiments/ake/simulated/'
# Training configuration whose defaults are reused when building the models.
CONFIG_PATH = '/sdf/group/ml/CryoNet/train_configs/ak-atomic-primal-sim.ini'
# Directory containing the per-frame PDB models.
PDB_PATH = '/sdf/group/ml/CryoNet/experiments/ake/simulated/models/frames/'
# Reference structure that every other frame is compared against.
START_PDB = os.path.join(PDB_PATH, 'new_frame00.pdb')

def populate_starfile_list():
    """Load every per-frame star file together with its NMA coefficients.

    The reference structure ``START_PDB`` is turned into a ``DynamicsModelNMA``
    with ``MAX_NUM_MODES`` modes. Its eigenvectors are scaled column-wise by
    the inverse eigenvalues, and the left singular vectors of that scaled
    matrix are used as the projection basis. For each frame
    ``1..DATASET_NUM`` the flattened coordinate difference to the reference is
    projected onto that basis.

    Returns:
        list[dict]: one entry per frame, in frame order, with the keys
            ``'starfile'`` (result of ``starfile.read`` on the frame's star file),
            ``'dataset_path'`` (directory of that frame's dataset) and
            ``'alphas'`` (projection coefficients of the frame onto the basis).

    Note:
        ``curated_gemmi.pdb`` in the current working directory is used as a
        scratch file and is overwritten for every structure
        (presumably written by ``AtomicModel`` via ``pdb_out`` — confirm).
    """
    # Build the configuration from the defaults declared in main.init_config so
    # this notebook stays in sync with the training script.
    arg_parser = configargparse.ArgParser()
    arg_parser.add_argument('-c', '--config', required=False, is_config_file=True,
                            help='Path to config file.', default=CONFIG_PATH)
    # IPython passes its own "-f <connection file>" argument; accept and ignore it.
    arg_parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")
    init_config(arg_parser)
    config = arg_parser.parse_args()

    # Reference structure: curate, then read the curated coordinates back.
    reference_model = AtomicModel(START_PDB, config.atomic_clean_pdb, config.atomic_center,
                                  pdb_out=os.path.join('.', 'curated_gemmi.pdb'))
    reference_coords = read_prody_model('curated_gemmi.pdb').getCoords().reshape((-1, 1))

    # Normal-mode analysis of the reference structure.
    nma_model = DynamicsModelNMA(reference_model,
                                 atomic_clean_pdb=config.atomic_clean_pdb,
                                 atomic_cg_selection=config.atomic_cg_selection,
                                 atomic_nma_cutoff=config.atomic_nma_cutoff,
                                 atomic_nma_gamma=config.atomic_nma_gamma,
                                 atomic_nma_number_modes=MAX_NUM_MODES,
                                 by_chain=config.atomic_nma_by_chain)
    eigvecs = nma_model.eigvecs.detach().cpu().numpy()
    eigvals = nma_model.eigvals.detach().cpu().numpy()

    # Scale each mode by 1/eigenvalue, then orthonormalise through an SVD.
    mode_matrix = eigvecs.reshape((-1, MAX_NUM_MODES))
    scaled_modes = np.matmul(mode_matrix, np.diag(1./eigvals))
    basis, _, _ = np.linalg.svd(scaled_modes, full_matrices=False)

    dataset_starfiles = []
    for frame_number in range(1, DATASET_NUM + 1):
        # Curate the target frame; the instance itself is not used afterwards,
        # only the curated PDB file that is read back on the next line.
        frame_pdb = os.path.join(PDB_PATH, 'new_frame%02d.pdb' % frame_number)
        frame_model = AtomicModel(frame_pdb, config.atomic_clean_pdb, config.atomic_center,
                                  pdb_out=os.path.join('.', 'curated_gemmi.pdb'))
        frame_coords = read_prody_model('curated_gemmi.pdb').getCoords().reshape((-1, 1))

        # Coefficients of the displacement from the reference in the mode basis.
        displacement = frame_coords - reference_coords
        alphas = np.squeeze(np.matmul(basis.T, displacement))

        dataset_path = os.path.join(DATA_PATH, 'new_frame%02d' % frame_number)
        star_path = os.path.join(dataset_path, 'new_frame%02d.star' % frame_number)
        dataset_starfiles.append({'starfile': starfile.read(star_path),
                                  'dataset_path': dataset_path,
                                  'alphas': alphas})
    return dataset_starfiles

def get_filename(step, n_char=6):
    """Return ``step`` as a decimal string zero-padded to ``n_char`` characters.

    The previous implementation counted digits with ``int(np.log10(step))``,
    which is one less than the number of digits, so every ``step > 0`` came
    out ``n_char + 1`` characters wide while ``step == 0`` was ``n_char`` wide.
    Padding is now done with ``str.zfill`` so all values share the same width.

    Args:
        step: non-negative integer (Python or NumPy) to format.
        n_char: minimum width of the result; longer numbers are not truncated.

    Returns:
        str: e.g. ``get_filename(12)`` -> ``'000012'``.

    Raises:
        ValueError: if ``step`` is negative.
    """
    step = int(step)
    if step < 0:
        raise ValueError(f'step must be non-negative, got {step}')
    return str(step).zfill(n_char)

def write_sampled_dataset(dataset_name, dataset_indeces, batch_size=250):
    """Assemble a new particle dataset by sampling from the per-frame datasets.

    For every entry of ``dataset_indeces`` the next unused particle of that
    source dataset is copied into an output ``.mrcs`` stack under
    ``<DATA_PATH>/<dataset_name>/Particles/`` (``batch_size`` images per
    stack), and its star-file row is carried over with three changes:
    ``rlnImageName`` points at the new stack, ``nmaAlphas`` holds the source
    frame's comma-separated coefficients, and ``pdbIndex`` is the 1-based
    source index. The combined table is written to
    ``<DATA_PATH>/<dataset_name>/<dataset_name>.star``.

    Args:
        dataset_name: name of the output sub-directory and star file.
        dataset_indeces: sequence of 0-based indices into the list returned
            by ``populate_starfile_list``; one output particle per entry.
        batch_size: number of particles stored per output ``.mrcs`` file.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        IndexError: if a source dataset is asked for more particles than it
            contains.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    data_out_path = os.path.join(DATA_PATH, dataset_name)
    dataset_starfiles = populate_starfile_list()
    template = dataset_starfiles[0]['starfile']
    box_size = template['optics']['rlnImageSize'][0]
    os.makedirs(os.path.join(data_out_path, 'Particles/'), exist_ok=True)

    total_particles = len(dataset_indeces)
    # Position of the next unused particle in each source dataset; replaces
    # repeatedly slicing the first row off the source tables.
    next_unused = [0] * len(dataset_starfiles)
    # Rows are collected and concatenated once at the end (concat inside the
    # loop is quadratic in the number of particles).
    particle_rows = []

    mrc = None
    mrc_index = 0
    mrc_relative_path = None
    try:
        for i, dataset_index in enumerate(tqdm(dataset_indeces)):
            if i % batch_size == 0:
                # Start a new output stack; close the previous one so its
                # memory map is flushed to disk and the handle released.
                if mrc is not None:
                    mrc.close()
                findex = get_filename((i // batch_size) + 1, n_char=6)
                mrc_relative_path = f'Particles/particles_{findex}.mrcs'
                # The final stack may hold fewer than batch_size images.
                stack_len = min(batch_size, total_particles - i)
                mrc = mrcfile.new_mmap(os.path.join(data_out_path, mrc_relative_path),
                                       shape=(stack_len, box_size, box_size),
                                       mrc_mode=2, overwrite=True)
                mrc_index = 0

            source = dataset_starfiles[dataset_index]
            source_particles = source['starfile']['particles']
            cursor = next_unused[dataset_index]
            if cursor >= len(source_particles):
                raise IndexError(
                    f'source dataset {dataset_index} has only {len(source_particles)} '
                    f'particles but more were requested')
            next_unused[dataset_index] = cursor + 1
            # Copy so the assignments below do not write into a slice of the source table.
            particle_row = source_particles.iloc[[cursor]].copy()

            # rlnImageName has the form "<1-based index>@<relative stack path>".
            image_number, image_stack = particle_row['rlnImageName'].values[0].split('@')
            in_mrc_path = os.path.join(source['dataset_path'], image_stack)
            pidx = int(image_number) - 1
            with mrcfile.mmap(in_mrc_path, mode='r', permissive=True) as f:
                mrc.data[mrc_index] = f.data[pidx]
            mrc_index += 1  # now the 1-based position inside the output stack

            particle_row['rlnImageName'] = get_filename(mrc_index, n_char=6) + '@' + mrc_relative_path
            particle_row['nmaAlphas'] = ','.join(['%.5f' % num for num in source['alphas']])
            particle_row['pdbIndex'] = dataset_index + 1
            particle_rows.append(particle_row)
    finally:
        if mrc is not None:
            mrc.close()

    if particle_rows:
        out_particles = pd.concat(particle_rows, ignore_index=True)
    else:
        # No particles requested: keep an empty table with the source columns.
        out_particles = pd.DataFrame(columns=template['particles'].columns)

    out_starfile = {'optics': template['optics'], 'particles': out_particles}
    starfile.write(out_starfile, os.path.join(data_out_path, f'{dataset_name}.star'), overwrite=True)

### Generate Uniform Dataset

In [None]:
OUT_PARTICLES_NUM = 50000
# Fixed seed so the dataset written to disk can be regenerated exactly.
UNIFORM_SEED = 0

# A local Generator makes the draw reproducible without touching NumPy's
# global random state.
rng = np.random.default_rng(UNIFORM_SEED)
# One source-model index in [0, DATASET_NUM) per output particle, drawn
# uniformly with replacement.
dataset_indeces = rng.integers(0, DATASET_NUM, size=OUT_PARTICLES_NUM)

# One histogram bin per model: edges fixed at the integers 0..DATASET_NUM
# instead of depending on the smallest/largest index that happened to be drawn.
fig, ax = plt.subplots()
ax.hist(dataset_indeces, bins=DATASET_NUM, range=(0, DATASET_NUM))
ax.set_title('Uniform dataset: particles per model')
ax.set_xlabel('Model')
ax.set_ylabel('Particles');

write_sampled_dataset('uniform', dataset_indeces)

### Generate Discontinuous Dataset

In [None]:
OUT_PARTICLES_NUM = 50000
EPS = 1e-6
# Fixed seed so the dataset written to disk can be regenerated exactly.
DISCONTINUOUS_SEED = 1

# A local Generator makes the draw reproducible without touching NumPy's
# global random state.
rng = np.random.default_rng(DISCONTINUOUS_SEED)

# Equal-weight mixture of three Gaussians centred at 1, 2 and 3. The small
# standard deviation (0.05) keeps the three clusters well separated.
means = rng.choice([1, 2, 3], size=OUT_PARTICLES_NUM)
stdevs = np.full(OUT_PARTICLES_NUM, 0.05)
positions = rng.normal(loc=means, scale=stdevs)

# Rescale the samples to [0, 1): EPS keeps the maximum strictly below 1 so the
# floor below never produces the out-of-range index DATASET_NUM.
positions -= positions.min()
positions /= (positions.max() + EPS)
dataset_indeces = np.floor(positions * DATASET_NUM).astype(int)

# One histogram bin per model, with edges fixed at the integers 0..DATASET_NUM.
fig, ax = plt.subplots()
ax.hist(dataset_indeces, bins=DATASET_NUM, range=(0, DATASET_NUM))
ax.set_title('Discontinuous dataset: particles per model')
ax.set_xlabel('Model')
ax.set_ylabel('Particles');

write_sampled_dataset('discontinuous', dataset_indeces)

### Generate Continuous Dataset

In [None]:
OUT_PARTICLES_NUM = 50000
EPS = 1e-6
# Fixed seed so the dataset written to disk can be regenerated exactly.
CONTINUOUS_SEED = 2

# A local Generator makes the draw reproducible without touching NumPy's
# global random state.
rng = np.random.default_rng(CONTINUOUS_SEED)

# Equal-weight mixture of three Gaussians centred at 1, 2 and 3. The larger
# standard deviation (0.25) makes neighbouring components overlap, so models
# between the three centres also receive particles.
means = rng.choice([1, 2, 3], size=OUT_PARTICLES_NUM)
stdevs = np.full(OUT_PARTICLES_NUM, 0.25)
positions = rng.normal(loc=means, scale=stdevs)

# Rescale the samples to [0, 1): EPS keeps the maximum strictly below 1 so the
# floor below never produces the out-of-range index DATASET_NUM.
positions -= positions.min()
positions /= (positions.max() + EPS)
dataset_indeces = np.floor(positions * DATASET_NUM).astype(int)

# One histogram bin per model, with edges fixed at the integers 0..DATASET_NUM.
fig, ax = plt.subplots()
ax.hist(dataset_indeces, bins=DATASET_NUM, range=(0, DATASET_NUM))
ax.set_title('Continuous dataset: particles per model')
ax.set_xlabel('Model')
ax.set_ylabel('Particles');

write_sampled_dataset('continuous', dataset_indeces)