In [1]:
import MDAnalysis as mda
from MDAnalysis.analysis.base import AnalysisBase
import numpy as np
from multiprocessing import Pool
import itertools
import sys
import os
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# build custom data structure
class LipidContacts(AnalysisBase):
    def __init__(self, protein, lipids, cutoff=4.5, smoothing_cutoff=3, min_bind=5, **kwargs):
        super().__init__(protein.universe.trajectory, **kwargs)
        self.protein = protein
        self.u = self.protein.universe
        self.lipids = lipids
        self.cutoff = cutoff
        self.smoothing_cutoff = smoothing_cutoff
        self.min_bind = min_bind
        
        if np.unique(self.protein.segids).shape[0] > 1:
            self.complex = True
        else:
            self.complex = False
        
    
    ###################
    # PRIVATE METHODS #
    ###################
    
    def _prepare(self):
        '''
        Preprocessing: Set up custom data structure based on full
        TM region residues
        '''
        
        tm = self.identify_tm()
        self.interactions = self.construct_interaction_array(tm)
        self.mapping = self.map_lipids()  
    
    
    def _single_frame(self):
        '''
        What to do at each frame
        '''
        
        frame = self._ts.frame
        
        # iteration through the tm residues, find unique lipid contacts
        for key in self.interactions.keys():
            res, seg = key.split('-')
            lips = self.u.select_atoms(f'segid MEMB and around {self.cutoff} (protein and resid {res} and segid {seg})')
            lip_resi = lips.residues.ix # this is the list of unique lipid resIDs for contacts
            lip_resn = [self.mapping[resID] for resID in lip_resi]
            
            for (lipID, lipRN) in zip(lip_resi, lip_resn):
                if lipID not in self.interactions[key][lipRN].keys():
                    self.interactions[key][lipRN].update({lipID:[frame]})
                else:
                    self.interactions[key][lipRN][lipID] += [frame]
    
    
    def _conclude(self):
        '''
        Postprocessing: Normalization of lipid binding events into streamlined output
        '''
        
        self.results = {}
        for pres in self.interactions.keys():
            for lip in self.interactions[pres].keys():
                # check for empty
                if self.interactions[pres][lip]:
                    coeffs = self.get_coeff(self.interactions[pres][lip])
                    self.results.update({f'{pres}-{lip}': coeffs})
    
    
    ##################
    # PUBLIC METHODS #
    ##################
    
    def construct_interaction_array(self, tm_residues):
        '''
        Generate a nested dict structure to track lipid contacts on a per
        reside basis
        '''
        
        lipids_of_interest = ['PC','PE','PG','PI','PS','PA','CL','SM','CHOL']
        inter = {key:{lip:{} for lip in lipids_of_interest} for key in tm_residues}
        
        return inter
    
    
    def identify_tm(self):
        # find phosphate plane for membrane boundary
        memb_zcog = self.lipids.center_of_geometry()[2] # already defined
        z_top = self.u.select_atoms(f'name P and prop z > {memb_zcog}').center_of_geometry()[2]
        z_bot = self.u.select_atoms(f'name P and prop z < {memb_zcog}').center_of_geometry()[2]
        
        # obtain list of resids pertaining to residues within this boundary
        protein_residues = self.u.select_atoms(f'protein and prop z > {z_bot} and prop z < {z_top}').residues
        
        return [f'{resID}-{segID}' for (resID,segID) in zip(protein_residues.ix,protein_residues.segids)]
    
    
    def map_lipids(self):
        lipids = u.select_atoms('segid MEMB and name P').residues
        lip_mapping = {name:name[-2:] if name != 'CHOL' else name for name in lipids.resnames}
        mapping = {resid:lip_mapping[resn] for resid, resn in zip(lipids.ix, lipids.resnames)}
        
        return mapping
    
    
    def get_coeff(self, simdata):
        '''
        Obtain the distribution of binding events in order to fit an exponential.
        Returns the coefficients of said exponential to be used as training/test data.
        '''
        
        events = []
        for key_ in simdata.keys():
            events += [event for event in self.get_binding_profile(simdata[key])]
            
        # throw out minimal binding
        culled = [event for event in events if event > self.min_bind]
        
        # fit exponential to `culled` distribution
        hist = np.histogram(culled, bins=50, density=True)
        X, Y = ((hist[1][:-1] + hist[1][1:]) / 2), hist[0]
        A, B = np.polyfit(X, Y, 2, w=np.sqrt(Y))
        
        return [A, B]
    
    
    def get_binding_profile(self, pairdata):
        # history is used to track the local binding history to handle edge cases
        history = [0]*20
        events = []
        lastframe = pairdata[-1]
        
        i = 0
        while i < lastframe:
            # check if bound in this frame
            bound = 1 if pairdata[0] == i else 0
            
            if bound:
                # if you have been bound within the hyst cutoff
                # you are considered `resident`
                resident = 1 if sum(history[:self.smoothing_cutoff]) > 0 else 0
                history.insert(0, 1)
                pairdata.pop(0)
                
            else:
                resident = 0
                history.insert(0, 0)
                
                try:
                    events.append(current)
                except Exception as e:
                    pass
                
            history.pop()
             
            
            if bound and resident:
                current += 1
            elif bound and not resident:
                current = 1
                
            i += 1
            
        # need to check last frame for binding since this would not be appended otherwise
        if sum(history[:self.smoothing_cutoff]) > 0:
            events.append(current)
            
        return events

In [3]:
u = mda.Universe('testing/T6R2.psf',
                 'testing/T6R2.dcd')

protein = u.select_atoms('protein')
lipids = u.select_atoms('segid MEMB')

lipid_analysis = LipidContacts(protein, lipids)

In [4]:
#lipid_analysis.run(verbose=True)

In [5]:
def parallelize_run(analysis, n_workers, worker_id):
    analysis.run(start=worker_id, step=n_workers, verbose=not worker_id)
    return analysis

#def display_hack():
#    sys.stdout.write(' ')
#    sys.stdout.flush()
    
n_workers = os.cpu_count()

params = zip(itertools.repeat(lipid_analysis),
             itertools.repeat(n_workers),
             range(n_workers))


In [6]:
pool = Pool(processes=n_workers)#, initializer=display_hack)

analyses = pool.starmap(parallelize_run, params)



Process SpawnPoolWorker-1:
Traceback (most recent call last):
  File "/Users/matt/opt/anaconda3/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/matt/opt/anaconda3/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/matt/opt/anaconda3/lib/python3.9/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/Users/matt/opt/anaconda3/lib/python3.9/multiprocessing/queues.py", line 368, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute 'parallelize_run' on <module '__main__' (built-in)>
Process SpawnPoolWorker-2:
Traceback (most recent call last):
  File "/Users/matt/opt/anaconda3/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/matt/opt/anaconda3/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/matt/opt/anaconda3/

KeyboardInterrupt: 

In [None]:
pool.close()

In [None]:
lipid_analysis.interactions

In [1]:
import MDAnalysis as mda

In [2]:
u = mda.Universe('testing/T6R2.psf','testing/trimmed.dcd')
protein = u.select_atoms('protein')
lipids = u.select_atoms('segid MEMB')

In [17]:
protein.residues.ix

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
       156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
       169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 18

In [25]:
for ts in u.trajectory[::10]:
    lips = lipids.select_atoms('segid MEMB and around 3.5 global (resid 367 and segid PROA and group protein)',protein=protein)
    lip_resi = lips.residues.ix
    print(lip_resi)

[801 805 822]
[498 815]
[815 817]
[802 817]
[805 817 829]
[ 829 1039]
[829]
[ 569  815 1039]
[802 815]
[802 815]
