# MPModel: Infering the detergent corona around a membrane protein

Throughout this tutorial, we will buid a coarse-grained model of a micelle around the  membrane protein Aquaporin-0 (PDB: 2B6P [1]) and use the `simSPI` package to simulate a cryo-EM experiment of the structure. For that, it is necessary to have installed the TEM-simulator [2] and the python libraries: Matplotlib Numpy, Gemmi, Scikit-learn and Scipy.


<p align="center">
    <img src="figures/2b6p_map.png" alt="Drawing" style="width: 700px;"/>
    <p style="text-align: center">
        Figure 1: _Bos taurus_ Aquaporin-0 complex
    </p>
</p>


[1] https://www.rcsb.org/structure/2B6P

[2] http://tem-simulator.sourceforge.net/

In [1]:
import os
import sys
import warnings

sys.path.append(os.path.dirname(os.getcwd()))
warnings.filterwarnings("ignore")

In [2]:
import numpy as np
import math
import gemmi
import torch
import mrcfile
from matplotlib import pyplot as plt, axes as ax
from scipy.spatial.transform import Rotation as rot

from simSPI.linear_simulator.linear_simulator import LinearSimulator
from test_linear_simulator import init_data

import detergent_belt as mpm
import atomic_models

## Creating the model
### Extracting the protein atomic coordinates from a PDB file

In [3]:
ptn = mpm.MembraneProtein("6dmy.pdb") #path to pdb file

<p align="center">
    <img src="figures/aquaporin.png" alt="Drawing" style="width: 400px;"/>
    <p style="text-align: center">
        Figure 2: Representation of the atomic coordinates of aquaporin
    </p>
</p>

### Centering and rotating the structure
<p style='text-align: justify;margin-right:10%;'>
Using Principal Component Analysis (PCA), a method for data decomposition, we can extract the main axis where our data is distributed. We can use a matrix transformation as a function to represent a modification on our dataset. In this case, the transformation is aligning the principal component of the coordinate points to the Z-axis. Therefore, we can distribute the atoms on another axis and, by doing so, rotate the protein so that its transmembrane portion is positioned properly in space.
</p>

In [8]:
ptn.rotate_protein(axis='Z')

<p align="center">
    <img src="figures/aquaporin_rotated.png" alt="Drawing" style="width: 400px;"/>
    <p style="text-align: center">
        Figure 3: Aquaporin structure rotated and centered at the origin
    </p>
</p>

***
### Creating a corona around the protein

To create the Ellipsoid we will compute the points that satisfy the condition: 
$\frac{x^{2}}{a^{2}} + \frac{y^{2}}{b^{2}} + \frac{z^{2}}{c^{2}} >= 1$

Consider $O = (0,0,0)$, as the protein was centered after the coordinates were imported.

<p align="center">
    <img src="figures/ellipsoid.png" alt="Drawing" style="width: 200px;"/>
    <p style="text-align: center">
        Figure 4: The geometric parameters of an ellipsoid
    </p>
</p>

The method used in this program creates an ellipsoid by generating points iteratively in regular intervals dependent on the ray of the pseudoatom. After that, a function removes the coordinates in the most inner part of the ellipsoid that would conflict with the protein's atoms. This process' speed depends on the number of pseudo atoms, that itself depends on the parameters of the corona such as the height, determined by size of the transmembrane portion of the protein, and on the ray of the pseudo atoms filling the ellipsoid.

Let us create the hydrophobic interior ellipsoid of the corona with parameters chosen empirically for a better visualization of the exemple on Pymol.</br>

major axis &rarr; 35Å</br>
minor axis &rarr; 25Å</br>
ellipticity &rarr; 2</br>
pseudo atom type &rarr; CA</br>
pseudo atom ray &rarr; 1.5</br>

major axis &rarr; 40Å</br>
minor axis &rarr; 30Å</br>
ellipticity &rarr; 2</br>
pseudo atom type &rarr; N</br>
pseudo atom ray &rarr; 1.5

In [51]:
axis = 30
C = mpm.BeltCore()
C.set_belt_parameters(axis-10, 50, 16,[-2,-2,40])    #set the parameters for the core of your micelle
C.set_atomic_parameters(1.5, "CA")  #set the parameters for you hydrophilic pseudoatom
C.generate_ellipsoid()

S = mpm.BeltShell()
S.set_belt_parameters(axis, 60, 26,[-2,-2,40])    #set the parameters for the shell of your micelle
S.set_atomic_parameters(1.5, "N")   #set the parameters for you hydrophobic pseudoatom
S.generate_ellipsoid()

### Excluding the overlaping pseudoatoms
Now that we have our two ellipsoids we need to delete from our corona the pseudo atoms that are physically overlapping with other atoms
Observation: the next steps take a couple of minutes to run

In [52]:
C.in_hull(ptn.final_coordinates)
C.remove()
 
S.in_hull(C.coordinates_set)
S.remove()
S.in_hull(ptn.final_coordinates)      
S.remove()

### Generating the final PDB file

In [53]:
file_name = "outputs/pdb/6dmy/6dmy_30x60x26.pdb"   #name your pdb file
M = mpm.Model()
M.clean_gemmi_structure()
M.write_atomic_model(file_name, model=gemmi.Model("model"))
M.create_model(file_name,C,S,ptn)

<p align="center">
    <img src="figures/aqp_model_abc.png" alt="Drawing" style="width: 600px;"/>
    <p style="text-align: center">
        Figure 5: Model of detergent micelle and membrane protein Aquaporin (AQP0) seen as A- full model, B- micelle
hydrophobic core and C- frontal cut. The micelle hydrophobic core is represented in yellow,
the hydrophilic shell in blue and the protein in pink.
    </p>
</p>

XXXXXXXXXXXXX CHIMERAX XXXXXXXXXXXXX molmap #[structure] [res]

ADD STEPS
ADD IMAGE
MAYBE AUTOMIZE

## Linear Simulation

In [4]:
def rotation(N):
    A = np.random.uniform(0,2*np.pi,N)
    V = np.random.uniform(0,1,3)
    R = []
    for a in A:
        r = rot.from_rotvec(a*V)
        R.append(r.as_matrix())
    return R #sampled rotations

In [54]:
def LinSim(vol_path, data_path, rotations):
    '''
    rotations -> list of rotation matrices
    '''
    with mrcfile.open(vol_path) as file:
        vol = file.data
    saved_data, config = init_data(data_path)
    
    #nx = vol.shape[0]
    nx = 200
    config["input_volume_path"] = vol_path
    config["side_len"] = nx
    config["ctf_size"] = nx
    config["pixel_size"] = 1.2
    
    config["noise"] = True
    config["noise_sigma"] = 1

    lin_sim = LinearSimulator(config)
    
    N = len(rotations)
    
    ctf_params = saved_data["ctf_params"]
    shift_params = saved_data["shift_params"]
    rot_params = {'rotmat': torch.Tensor(rotations)}
    
    ctf_params["defocus_u"] = torch.tensor([1.1000,2.1000,1.5000,1.7000,1.1000,2.1000,1.5000,1.7000]).reshape(N,1,1,1)
    ctf_params["defocus_v"] = torch.tensor([1.0000,2.2000,1.6000,1.8000,1.10000,2.2000,1.6000,1.8000]).reshape(N,1,1,1)
    ctf_params["defocus_angle"] = torch.tensor([0.3142,1.5708,2.5133,3.1416,0.3142,1.5708,2.5133,3.1416]).reshape(N,1,1,1)
    shift_params["shift_x"] = torch.tensor([ 4.0000,  5.5000, -3.2000, -6.0000, 4.0000,  5.5000, -3.2000, -6.0000 ])
    shift_params["shift_y"] = torch.tensor([ 6.0000, -4.5000, -4.0000,  0.0000,  6.0000, -4.5000, -4.0000,  0.0000])

    particles = lin_sim(rot_params, ctf_params, shift_params)
    
    return particles #torch tensor for particle projection

In [24]:
#R = rotation(8)

f, ax = plt.subplots(2, 4, figsize=(12, 6))

#ax[0].imshow(proj[0, 0])
#ax[1].imshow(proj[1, 0])
#ax[2].imshow(proj[2, 0])
#ax[3].imshow(proj[3, 0])

ax[0,0].imshow(proj[0, 0])
ax[0,1].imshow(proj[1, 0])
ax[0,2].imshow(proj[2, 0])
ax[0,3].imshow(proj[3, 0])
ax[1,0].imshow(proj[4, 0])
ax[1,1].imshow(proj[5, 0])
ax[1,2].imshow(proj[6, 0])
ax[1,3].imshow(proj[7, 0])

### Simulate images

In [55]:
R = [[[ 0.80635711, -0.01370839,  0.59127007],
       [ 0.39410307,  0.75788644, -0.51989509],
       [-0.44098864,  0.65224245,  0.61652964]],
     [[ 0.06689542,  0.85905001,  0.50750181],
       [ 0.97395306, -0.16666956,  0.15374228],
       [ 0.21665741,  0.48399829, -0.84782382]],
     [[ 0.1111205 ,  0.99294078,  0.04148306],
       [ 0.75318596, -0.11137452,  0.64831059],
       [ 0.64835418, -0.04079614, -0.76024505]],
     [[ 0.92351078, -0.06264793,  0.37842182],
       [ 0.21290437,  0.9043648 , -0.36985948],
       [-0.31906044,  0.42213688,  0.84852866]],
    [[ 0.92061443, -0.12424634,  0.37017823],
       [ 0.23403454,  0.93444747, -0.26839478],
       [-0.31256504,  0.3337226 ,  0.88934376]],
    [[-0.23368781,  0.79206824,  0.56393077],
       [ 0.91409036, -0.01871608,  0.40507842],
       [ 0.33140433,  0.61014557, -0.71964822]],
    [[-0.1085439 ,  0.99362457, -0.03047033],
       [ 0.5394632 ,  0.08462134,  0.83774619],
       [ 0.83498364,  0.07449461, -0.54520902]], 
    [[ 0.39405398,  0.05288187,  0.91756469],
       [ 0.78512593,  0.49964088, -0.36597304],
       [-0.47780616,  0.86461697,  0.15536655]]]

In [60]:
sample = "6dmy_60x60x26"
root_dir = "/home/halv/"  # change this to your simSPI directory.
vol_path = f"{root_dir}compSPI_fork/simSPI/notebooks/outputs/mrc/6dmy/{sample}.mrc"
data_path = f"{root_dir}compSPI_fork/simSPI/tests/data/linear_simulator_data.npy"
proj_path = f"{root_dir}compSPI_fork/simSPI/notebooks/outputs/projections/6dmy/gt.npy"

In [61]:
proj = LinSim(vol_path, data_path, R)
torch.save(proj, proj_path)

In [None]:
proj1 = torch.load(proj_path)
print(type(proj1))

f, ax = plt.subplots(2, 4, figsize=(12, 6))

ax[0,0].imshow(proj1[0, 0])
ax[0,1].imshow(proj1[1, 0])
ax[0,2].imshow(proj1[2, 0])
ax[0,3].imshow(proj1[3, 0])
ax[1,0].imshow(proj1[4, 0])
ax[1,1].imshow(proj1[5, 0])
ax[1,2].imshow(proj1[6, 0])
ax[1,3].imshow(proj1[7, 0])

<class 'torch.Tensor'>


## Parameter Inference

<p align="center">
    <img src="figures/parameters_new.png" alt="Drawing" style="width: 450px;"/>
    <p style="text-align: center">
        Figure 6: XXXXXXXXXXXXXXXXXXXXXXX ADD DESCRIPTION AND REFERENCE XXXXXXXXXXXXXXXXXXXX
    </p>
</p>

We will infer the parameter "a" by Rejection Approximation Bayesian Computation.

In [22]:
class ABC():
    '''
    e epsilon rejection criterion
    obs image vector of ground truth
    sim image vector of simulated model
    '''
    def l2(self, proj1, proj2):
        norm = []
        for i in range(len(proj1)):
            norm.append(torch.sum((proj1[i]-proj2[i])**2))
            rank = sorted(norm)
        return rank #np.mean(norm)

    def criterion(self, proj1, proj2):
        e = ((self.criterion*10000) + 40000)   #normalized criterion
        if l2(proj1,proj2) <= e:
            return True
            
    def generate_model():
        pass

    def runABC(self):
        samples = self.samples
        a = np.random.uniform()   #sample parameter a
        model = generate_model()  #create model using sampled a
        self.accepted = []
        for i in range(samples):   
            sim_proj = LinSim(vol_path, data_path, rot) #simulate image with generated model x100
            #superimpose images
            if criterion(obs_proj, sim_proj):
                self.accepted.append(a)   #save parameter

        self.result = np.mean(self.accepted)
        self.rank = sorted(self.accepted)

### Metrics

In [4]:
def mse(proj1, proj2):
    se = []
    for i in range(len(proj1)):
        se.append(torch.sum((proj1[i]-proj2[i])**2))
    return np.mean(se)
    
def rmsd(proj1, proj2):
    sd = []
    for i in range(len(proj1)):
        sd.append(torch.sum((proj1[i]-proj2[i])**2))
    return np.mean(np.sqrt(sd))

def psnr(proj1, proj2):
    p1 = torch.max(proj1)
    p2 = torch.max(proj2)
    if p1 >= p2:
        p = p1
    else:
        p = p2
    se = []
    for i in range(len(proj1)):
        se.append(torch.sum((proj1[i]-proj2[i])**2))
    mse = np.mean(se)
    psnr = (10*(np.log10(p)**2))/mse
    return psnr

In [5]:
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure as ms_ssim
from torchmetrics.image import UniversalImageQualityIndex

In [98]:
root_dir = "/home/halv/"  # change this to your simSPI directory.
gt = "6nsj_50x50x25"
gt_proj = torch.load(f"{root_dir}compSPI_fork/simSPI/notebooks/outputs/projections/6nsj/{gt}.npy")

dataset = dict()

In [114]:
sim = "6nsj_80x80x25"
sim_proj = torch.load(f"{root_dir}compSPI_fork/simSPI/notebooks/outputs/projections/6nsj/{sim}.npy")

In [115]:
uqi = UniversalImageQualityIndex()

distance = psnr(gt_proj, sim_proj)
dataset[sim] = distance
print(distance, dataset)

tensor(4.2166e-05) {'6nsj_45x45x25': tensor(7.1990e-05), '6nsj_50x50x25': tensor(0.0004), '6nsj_55x55x25': tensor(0.0001), '6nsj_60x60x25': tensor(5.9934e-05), '6nsj_65x65x25': tensor(4.3408e-05), '6nsj_70x70x25': tensor(3.9824e-05), '6nsj_75x75x25': tensor(3.9885e-05), '6nsj_80x80x25': tensor(4.2166e-05)}


In [None]:
6nsj_ell_mse = {'6nsj_45x50x25': 165984.66, '6nsj_50x50x25': 39938.473, '6nsj_55x50x25': 79683.164, 
           '6nsj_60x50x25': 206437.4, '6nsj_65x50x25': 307591.88, '6nsj_70x50x25': 349634.6, 
           '6nsj_75x50x25': 379704.56, '6nsj_80x50x25': 384298.12}
6nsj_ell_rmsd = {'6nsj_45x50x25': 400.16388, '6nsj_50x50x25': 199.8453, '6nsj_55x50x25': 280.30017, 
                 '6nsj_60x50x25': 444.63068, '6nsj_65x50x25': 541.0524, '6nsj_70x50x25': 576.74585, 
                 '6nsj_75x50x25': 601.90125, '6nsj_80x50x25': 606.81415}
6nsj_ell_psnr = {'6nsj_45x50x25': tensor(0.0001), '6nsj_50x50x25': tensor(0.0004), 
                 '6nsj_55x50x25': tensor(0.0002), '6nsj_60x50x25': tensor(8.7270e-05), 
                 '6nsj_65x50x25': tensor(5.9182e-05), '6nsj_70x50x25': tensor(5.0708e-05), 
                 '6nsj_75x50x25': tensor(4.5270e-05), '6nsj_80x50x25': tensor(4.4729e-05)}
6nsj_ell_uqi = {'6nsj_45x50x25': tensor(0.2010), '6nsj_50x50x25': tensor(0.4541), 
                '6nsj_55x50x25': tensor(0.3133), '6nsj_60x50x25': tensor(0.1501), 
                '6nsj_65x50x25': tensor(0.0975), '6nsj_70x50x25': tensor(0.0773), 
                '6nsj_75x50x25': tensor(0.0657), '6nsj_80x50x25': tensor(0.0580)}


6nsj_cir_mse = {'6nsj_45x45x25': 258523.9, '6nsj_50x50x25': 39938.473, '6nsj_55x55x25': 116498.914, 
           '6nsj_60x60x25': 297419.2, '6nsj_65x65x25': 395996.56, '6nsj_70x70x25': 431631.25, 
           '6nsj_75x75x25': 430970.25, '6nsj_80x80x25': 407661.16}
6nsj_cir_rmsd = {'6nsj_45x45x25': 503.02505, '6nsj_50x50x25': 199.8453, '6nsj_55x55x25': 339.31647, 
                 '6nsj_60x60x25': 538.5089, '6nsj_65x65x25': 621.0999, '6nsj_70x70x25': 649.4824, 
                 '6nsj_75x75x25': 650.2472, '6nsj_80x80x25': 633.68494}
6nsj_cir_psnr ={'6nsj_45x45x25': tensor(7.1990e-05), '6nsj_50x50x25': tensor(0.0004), 
                '6nsj_55x55x25': tensor(0.0001), '6nsj_60x60x25': tensor(5.9934e-05), 
                '6nsj_65x65x25': tensor(4.3408e-05), '6nsj_70x70x25': tensor(3.9824e-05), 
                '6nsj_75x75x25': tensor(3.9885e-05), '6nsj_80x80x25': tensor(4.2166e-05)}
6nsj_cir_uqi = {'6nsj_45x45x25': tensor(0.0994), '6nsj_50x50x25': tensor(0.4541), 
                '6nsj_55x55x25': tensor(0.2257), '6nsj_60x60x25': tensor(0.0464), 
                '6nsj_65x65x25': tensor(0.0232), '6nsj_70x70x25': tensor(0.0062), 
                '6nsj_75x75x25': tensor(0.0002), '6nsj_80x80x25': tensor(0.0026)}


7wsn_mse = {'7wsn_50x40x25': 40149.305, '7wsn_20x40x25': 279642.56, '7wsn_25x40x25': 265809.12, 
        '7wsn_30x40x25': 223505.69, '7wsn_35x40x25': 153820.53, '7wsn_40x40x25': 109778.05, 
        '7wsn_45x40x25': 59534.04, '7wsn_55x40x25': 49390.47, '7wsn_60x40x25': 70199.59, 
        '7wsn_65x40x25': 95364.17, '7wsn_70x40x25': 112922.984, '7wsn_75x40x25': 128386.125, 
        '7wsn_80x40x25': 139803.42}
7wsn_rmsd = {'7wsn_20x40x25': 526.97833, '7wsn_25x40x25': 513.81995, '7wsn_30x40x25': 471.21075, 
             '7wsn_35x40x25': 390.30798, '7wsn_40x40x25': 329.91663, '7wsn_45x40x25': 243.64532, 
             '7wsn_50x40x25': 200.37125, '7wsn_55x40x25': 222.11444, '7wsn_60x40x25': 264.14392, 
             '7wsn_65x40x25': 307.09174, '7wsn_70x40x25': 333.93845, '7wsn_75x40x25': 355.78076, 
             '7wsn_80x40x25': 371.1292}
7wsn_psnr = {'7wsn_20x40x25': tensor(6.8210e-05), '7wsn_25x40x25': tensor(7.1748e-05), 
             '7wsn_30x40x25': tensor(8.2719e-05), '7wsn_35x40x25': tensor(0.0001), 
             '7wsn_40x40x25': tensor(0.0002), '7wsn_45x40x25': tensor(0.0003), 
             '7wsn_50x40x25': tensor(0.0004), '7wsn_55x40x25': tensor(0.0003), 
             '7wsn_60x40x25': tensor(0.0002), '7wsn_65x40x25': tensor(0.0002), 
             '7wsn_70x40x25': tensor(0.0001), '7wsn_75x40x25': tensor(0.0001), 
             '7wsn_80x40x25': tensor(0.0001)}
7wsn_uqi = {'7wsn_20x40x25': tensor(0.0293), '7wsn_25x40x25': tensor(0.0293), 
            '7wsn_30x40x25': tensor(0.0481), '7wsn_35x40x25': tensor(0.0729), 
            '7wsn_40x40x25': tensor(0.1134), '7wsn_45x40x25': tensor(0.1893), 
            '7wsn_50x40x25': tensor(0.2734), '7wsn_55x40x25': tensor(0.2154), 
            '7wsn_60x40x25': tensor(0.1513), '7wsn_65x40x25': tensor(0.1021), 
            '7wsn_70x40x25': tensor(0.0803), '7wsn_75x40x25': tensor(0.0684), 
            '7wsn_80x40x25': tensor(0.0611)}