# MPModel: Infering the detergent corona around a membrane protein

Throughout this tutorial, we will buid a coarse-grained model of a micelle around the  membrane protein Aquaporin-0 (PDB: 2B6P [1]) and use the `simSPI` package to simulate a cryo-EM experiment of the structure. For that, it is necessary to have installed the TEM-simulator [2] and the python libraries: Matplotlib Numpy, Gemmi, Scikit-learn and Scipy.


<p align="center">
    <img src="figures/2b6p_map.png" alt="Drawing" style="width: 700px;"/>
    <p style="text-align: center">
        Figure 1: _Bos taurus_ Aquaporin-0 complex
    </p>
</p>

In [1]:
import os
import sys
import warnings

sys.path.append(os.path.dirname(os.getcwd()))
warnings.filterwarnings("ignore")

In [2]:
import numpy as np
import math
import gemmi
import torch
import mrcfile
from matplotlib import pyplot as plt, axes as ax
from scipy.spatial.transform import Rotation as rot

from simSPI.linear_simulator.linear_simulator import LinearSimulator
from test_linear_simulator import init_data

import detergent_belt as mpm
import atomic_models

from torchmetrics.image import UniversalImageQualityIndex   

## Creating the model
### Extracting the protein atomic coordinates from a PDB file

<p align="center">
    <img src="figures/aquaporin.png" alt="Drawing" style="width: 400px;"/>
    <p style="text-align: center">
        Figure 2: Representation of the atomic coordinates of aquaporin
    </p>
</p>

### Centering and rotating the structure
<p style='text-align: justify;margin-right:10%;'>
Using Principal Component Analysis (PCA), a method for data decomposition, we can extract the main axis where our data is distributed. We can use a matrix transformation as a function to represent a modification on our dataset. In this case, the transformation is aligning the principal component of the coordinate points to the Z-axis. Therefore, we can distribute the atoms on another axis and, by doing so, rotate the protein so that its transmembrane portion is positioned properly in space.
</p>

<p align="center">
    <img src="figures/aquaporin_rotated.png" alt="Drawing" style="width: 400px;"/>
    <p style="text-align: center">
        Figure 3: Aquaporin structure rotated and centered at the origin
    </p>
</p>

### Creating a corona around the protein

To create the Ellipsoid we will compute the points that satisfy the condition: 
$\frac{x^{2}}{a^{2}} + \frac{y^{2}}{b^{2}} + \frac{z^{2}}{c^{2}} >= 1$

Consider $O = (0,0,0)$, as the protein was centered after the coordinates were imported.

<p align="center">
    <img src="figures/ellipsoid.png" alt="Drawing" style="width: 200px;"/>
    <p style="text-align: center">
        Figure 4: The geometric parameters of an ellipsoid
    </p>
</p>

The method used in this program creates an ellipsoid by generating points iteratively in regular intervals dependent on the ray of the pseudoatom. After that, a function removes the coordinates in the most inner part of the ellipsoid that would conflict with the protein's atoms. This process' speed depends on the number of pseudo atoms, that itself depends on the parameters of the corona such as the height, determined by size of the transmembrane portion of the protein, and on the ray of the pseudo atoms filling the ellipsoid.

### Excluding the overlaping pseudoatoms
Now that we have our two ellipsoids we need to delete from our corona the pseudo atoms that are physically overlapping with other atoms

### Generating the final PDB file

<p align="center">
    <img src="figures/aqp_model_abc.png" alt="Drawing" style="width: 600px;"/>
    <p style="text-align: center">
        Figure 5: Model of detergent micelle and membrane protein Aquaporin (AQP0) seen as A- full model, B- micelle
hydrophobic core and C- frontal cut. The micelle hydrophobic core is represented in yellow,
the hydrophilic shell in blue and the protein in pink.
    </p>
</p>

XXXXXXXXXXXXX CHIMERAX XXXXXXXXXXXXX molmap #[structure] [res]

ADD STEPS
ADD IMAGE
MAYBE AUTOMIZE

In [3]:
def set_protein(path, ptn_axis):
    ptn = mpm.MembraneProtein(path)
    ptn.rotate_protein(ptn_axis)
    return ptn
        
def build_model(sample,ptn):
    micelle = mpm.DetergentBelt()
    micelle.set_belt_parameters(sample,50,25,7,[0,0,0]) #a1,a2,h,t,cset the parameters for the core of your micelle
    micelle.set_atomic_parameters(1.5, "CA", "N")  #set the parameters for hydrophobic and hydrophilic pseudoatoms
    micelle.generate_core()
    micelle.generate_shell()

    micelle.in_hull(micelle.core_coordinates_set,ptn.final_coordinates)
    micelle.remove(micelle.core_coordinates_set)

    micelle.in_hull(micelle.shell_coordinates_set,micelle.core_coordinates_set)
    micelle.remove(micelle.shell_coordinates_set)
    micelle.in_hull(micelle.shell_coordinates_set,ptn.final_coordinates)      
    micelle.remove(micelle.shell_coordinates_set)

    file_name = f"outputs/pdb/7wsn/experiment_2/7wsn_abc_{sample}x50x25.pdb"   #name your pdb file
    M = mpm.Model()
    M.clean_gemmi_structure()
    M.write_atomic_model(file_name, model=gemmi.Model("model"))
    M.create_model(file_name,micelle,ptn)

## Linear Simulation

In [4]:
def rotation(N):
    A = np.random.uniform(0,2*np.pi,N)
    V = np.random.uniform(0,1,3)
    R = []
    for a in A:
        r = rot.from_rotvec(a*V)
        R.append(r.as_matrix())
    return R #sampled rotations

def LinSim(vol_path, data_path, rotations):
    '''
    rotations -> list of rotation matrices
    '''
    with mrcfile.open(vol_path) as file:
        vol = file.data
    saved_data, config = init_data(data_path)
    
    #nx = vol.shape[0]
    nx = 200
    config["input_volume_path"] = vol_path
    config["side_len"] = nx
    config["ctf_size"] = nx
    config["pixel_size"] = 1.2
    
    config["noise"] = True
    config["noise_sigma"] = 1

    lin_sim = LinearSimulator(config)
    
    N = len(rotations)
    
    ctf_params = saved_data["ctf_params"]
    shift_params = saved_data["shift_params"]
    rot_params = {'rotmat': torch.Tensor(rotations)}
    
    ctf_params["defocus_u"] = torch.tensor([1.1000,2.1000,1.5000,1.7000,1.1000,2.1000,1.5000,1.7000]).reshape(N,1,1,1)
    ctf_params["defocus_v"] = torch.tensor([1.0000,2.2000,1.6000,1.8000,1.10000,2.2000,1.6000,1.8000]).reshape(N,1,1,1)
    ctf_params["defocus_angle"] = torch.tensor([0.3142,1.5708,2.5133,3.1416,0.3142,1.5708,2.5133,3.1416]).reshape(N,1,1,1)
    shift_params["shift_x"] = torch.tensor([ 4.0000,  5.5000, -3.2000, -6.0000, 4.0000,  5.5000, -3.2000, -6.0000])
    shift_params["shift_y"] = torch.tensor([ 6.0000, -4.5000, -4.0000,  0.0000,  6.0000, -4.5000, -4.0000,  0.0000])

    particles = lin_sim(rot_params, ctf_params, shift_params)
    
    return particles #torch tensor for particle projection

## Parameter Inference

<p align="center">
    <img src="figures/parameters_new.png" alt="Drawing" style="width: 450px;"/>
    <p style="text-align: center">
        Figure 6: XXXXXXXXXXXXXXXXXXXXXXX ADD DESCRIPTION AND REFERENCE XXXXXXXXXXXXXXXXXXXX
    </p>
</p>

We will infer the parameter "a" by Rejection Approximation Bayesian Computation.

In [83]:
def run_sampling(runs,ptn,span):
    for i in range(runs):
        a = np.random.uniform(span[0],span[1])   #sample parameter a
        model = build_model(a,ptn)  #create model using sampled a
        print(a)

def pdb_to_mrc():
    pass
    
def run_projection(): 
    pass
    proj = LinSim(vol_path, data_path, R)
    torch.save(proj, proj_path)

def psnr(proj1, proj2):
    p1 = torch.max(proj1)
    p2 = torch.max(proj2)
    if p1 >= p2:
        p = p1
    else:
        p = p2
    se = []
    for i in range(len(proj1)):
        se.append(torch.sum((proj1[i]-proj2[i])**2))
    mse = np.mean(se)
    psnr = (10*(np.log10(p)**2))/mse
    return psnr
        
def run_inference(folder,rot,gt):
    dataset = dict()
    for filename in os.scandir(folder):
        sim_proj = torch.load(filename.path)
        distance = psnr(gt_proj, sim_proj)
        dataset[str(filename).strip('<DirEntry npy>')] = distance
        print(distance)
    rank = sorted(dataset.items(), key=lambda x:x[1])
    return rank

In [39]:
ptn = set_protein("7wsn.pdb", "Z")
runs = 20
span = [35,90]

print(run_sampling(runs,ptn,span))

42.75206636840166
50.538531163199565
41.02816945797648
57.069605230522306
70.72695364411129
81.61809839497835
36.193432428890794
65.53347265841435
58.03130371262615
73.91089912402805
80.49550188370858
87.66308602190934
47.323975989565454
66.51786560092798
50.601283141712884
56.14315867978939
61.096895253589544
54.27858268336376
83.84427545949086
36.04199257086475
None


In [7]:
root_dir = "/home/halv/"  # change this to your simSPI directory.
gt_file = "7wsn_50x40x25"
gt_model = f"{root_dir}compSPI_fork/simSPI/notebooks/outputs/mrc/7wsn/experiment_2/{gt_file}.mrc"
data_path = f"{root_dir}compSPI_fork/simSPI/tests/data/linear_simulator_data.npy"
proj_folder = f"{root_dir}compSPI_fork/simSPI/notebooks/outputs/projections/7wsn/experiment_2"

In [8]:
rot = [[[ 0.71776143,  0.52648782, -0.45566337],
       [ 0.02935378,  0.63095527,  0.7752637 ],
       [ 0.6956701 , -0.56982982,  0.43742095]], [[ 0.09148036,  0.81331589,  0.57458559],
       [ 0.97592573, -0.1879467 ,  0.11065714],
       [ 0.19799068,  0.55062991, -0.81092934]], [[ 0.5965571 ,  0.66521582, -0.44900729],
       [ 0.12932599,  0.47247297,  0.87180507],
       [ 0.79208233, -0.57814982,  0.19582741]], [[ 0.42557599,  0.82837522, -0.36425207],
       [ 0.30289736,  0.24890439,  0.91994554],
       [ 0.85272403, -0.50183772, -0.14498494]], [[ 0.43528531,  0.82011672, -0.37139636],
       [ 0.29203428,  0.26159994,  0.91993557],
       [ 0.85161181, -0.50889491, -0.12563159]], [[ 0.17689998,  0.98298448, -0.04947645],
       [ 0.63803148, -0.07625517,  0.76622515],
       [ 0.74941459, -0.16711275, -0.64066458]], [[ 0.71942651,  0.02813017,  0.6939987 ],
       [ 0.52443221,  0.63313247, -0.56931023],
       [-0.4554079 ,  0.77353214,  0.44073991]], [[ 0.45444147,  0.80345366, -0.38463641],
       [ 0.27097115,  0.28664782,  0.91891657],
       [ 0.84856207, -0.52181916, -0.08744811]]]

In [9]:
gt_proj = LinSim(gt_model, data_path, rot) #simulate image with ground truth model
torch.save(gt_proj, proj_folder+"gt.npy")

In [31]:
sample = "7wsn_abc_36.04199257086475"
root_dir = "/home/halv/"  # change this to your simSPI directory.
vol_path = f"{root_dir}compSPI_fork/simSPI/notebooks/outputs/mrc/7wsn/experiment_2/{sample}.mrc"
data_path = f"{root_dir}compSPI_fork/simSPI/tests/data/linear_simulator_data.npy"
proj_path = f"{root_dir}compSPI_fork/simSPI/notebooks/outputs/projections/7wsn/experiment_2/{sample}.npy"
proj = LinSim(vol_path, data_path, rot)
torch.save(proj, proj_path)

In [84]:
a = run_inference(proj_folder,rot,gt_proj)
print(a)

tensor(6.5326e-05)
tensor(5.1446e-05)
tensor(6.4613e-05)
tensor(5.5290e-05)
tensor(6.1508e-05)
tensor(6.3590e-05)
tensor(5.2700e-05)
tensor(5.5455e-05)
tensor(4.7596e-05)
tensor(4.8372e-05)
tensor(6.0093e-05)
tensor(6.8271e-05)
tensor(6.4026e-05)
tensor(6.5140e-05)
tensor(5.7961e-05)
tensor(6.3196e-05)
tensor(5.6596e-05)
tensor(5.5711e-05)
tensor(6.5295e-05)
tensor(6.2206e-05)
[("'7wsn_abc_36.193432428890794.npy'", tensor(4.7596e-05)), ("'7wsn_abc_36.04199257086475.npy'", tensor(4.8372e-05)), ("'7wsn_abc_87.66308602190934.npy'", tensor(5.1446e-05)), ("'7wsn_abc_83.84427545949086.npy'", tensor(5.2700e-05)), ("'7wsn_abc_42.75206636840166.npy'", tensor(5.5290e-05)), ("'7wsn_abc_81.61809839497835.npy'", tensor(5.5455e-05)), ("'7wsn_abc_80.49550188370858.npy'", tensor(5.5711e-05)), ("'7wsn_abc_41.02816945797648.npy'", tensor(5.6596e-05)), ("'7wsn_abc_73.91089912402805.npy'", tensor(5.7961e-05)), ("'7wsn_abc_70.72695364411129.npy'", tensor(6.0093e-05)), ("'7wsn_abc_66.51786560092798.npy'", t