```python
import mdtraj as md
import numpy as np
import networkx as nx
```

## Goal: Automate the identification of which atoms belong to each of the two tails of the ceramides (CERs)

Examining each tail separately is necessary to determine whether a CER lipid is in a hairpin or extended conformation. It is also necessary to identify which leaflet a given tail is in, so that various properties can be calculated on a per-leaflet basis. While we could simply define a range of indices corresponding to each tail, failing to update these indices when changing to a different CER type or chain length, or when working with mixtures of CERs, would likely lead to errors in the analysis.
* Note that incorrect chain indices might not cause the analysis code to crash at all, but simply produce incorrect output: the wrong indices would most likely still point to valid atoms (so no `IndexError` and no segmentation fault), and we would silently analyze the wrong part of the molecule.


To automate this detection, we will use `networkx` to create a graph of all bonds between carbon atoms and between carbon and hydrogen atoms. This yields two disconnected subgraphs: in the CERs the two tails are joined only through the nitrogen in the headgroup backbone, so including only carbon and hydrogen atoms prevents the two tails from being seen as connected via the headgroup.
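
As a toy illustration of why dropping the nitrogen disconnects the tails, the sketch below builds a heavily simplified, hypothetical skeleton (two short carbon chains joined through one nitrogen, with hydrogens and oxygens omitted), keeps only the carbon nodes, and counts the pieces with `nx.connected_components`. The atom numbering is made up for the example and does not correspond to a real CER topology.

```python
import networkx as nx

# Hypothetical, heavily simplified skeleton: chain A (atoms 0-1-2) and chain B
# (atoms 4-5-6) joined through a nitrogen (atom 3). H and O atoms are omitted
# to keep the example small.
elements = {0: 'C', 1: 'C', 2: 'C', 3: 'N', 4: 'C', 5: 'C', 6: 'C'}
full = nx.Graph([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)])
print(nx.number_connected_components(full))  # 1: the whole molecule is connected

# Keep only the carbon nodes (the real analysis keeps C and H); removing the
# nitrogen leaves nothing that bridges the two chains.
carbon_only = full.subgraph([n for n, el in elements.items() if el == 'C'])
components = sorted(sorted(c) for c in nx.connected_components(carbon_only))
print(components)  # [[0, 1, 2], [4, 5, 6]]
```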

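To ensure we do not include headgroup atoms, we only consider carbon atoms with exactly four connections whose local structure is `[C;X4][C][C][H][H]` or `[C;X4][C][H][H][H]`, i.e., bonded to two carbons and two hydrogens, or to one carbon and three hydrogens (written here in SMARTS-like shorthand, but encoded as simple `if` statements in the code below). Because the graph contains only C-C and C-H edges, any carbon bonded to N or O has fewer than four connections in it, so headgroup carbons are excluded automatically (the same test also skips double-bonded carbons, which have only three neighbors).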
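
The carbon test described above can be written as a small standalone predicate. The sketch below is one way to express it, assuming a `networkx` graph that contains only C-C and C-H edges; `is_tail_carbon`, its `include_terminal_methyl` flag, and the toy fragment are hypothetical and are not part of the original analysis code.

```python
import networkx as nx


def is_tail_carbon(graph, elements, node, include_terminal_methyl=True):
    """Return True if `node` is a four-connected tail carbon in a C/H-only graph.

    graph: networkx Graph whose edges are only C-C and C-H bonds.
    elements: sequence (or dict) mapping node index -> element symbol.
    """
    # Not a carbon, or some of its bonds go to atoms outside the graph (N, O)
    # or it is double-bonded: either way it has fewer than 4 C/H neighbors.
    if elements[node] != 'C' or graph.degree(node) != 4:
        return False
    n_h = sum(1 for nb in graph.neighbors(node) if elements[nb] == 'H')
    n_c = sum(1 for nb in graph.neighbors(node) if elements[nb] == 'C')
    if include_terminal_methyl:
        return n_c >= 1 and n_h >= 2  # [C][C][H][H] or [C][H][H][H]
    return n_c == 2 and n_h == 2      # [C][C][H][H] only


# Hypothetical toy fragment CH3-CH2-CH2-CH2-OH: C0 (methyl), C1, C2, C3 (bonded
# to an O that is not in the graph), followed by the hydrogens of each carbon.
elements = ['C', 'C', 'C', 'C'] + ['H'] * 9
edges = [(0, 1), (1, 2), (2, 3),            # carbon backbone
         (0, 4), (0, 5), (0, 6),            # methyl hydrogens on C0
         (1, 7), (1, 8), (2, 9), (2, 10),   # CH2 hydrogens on C1 and C2
         (3, 11), (3, 12)]                  # CH2 hydrogens on C3 (its O is excluded)
toy = nx.Graph(edges)

with_methyl = [n for n in sorted(toy) if is_tail_carbon(toy, elements, n)]
without_methyl = [n for n in sorted(toy)
                  if is_tail_carbon(toy, elements, n, include_terminal_methyl=False)]
print(with_methyl)     # [0, 1, 2]
print(without_methyl)  # [1, 2]
```

C3 is rejected in both cases because its bond to oxygen leaves it with only three neighbors in the C/H graph, which is exactly how the headgroup carbons are filtered out.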


### General algorithm in the code below
* Loop over the residues in the first frame of the trajectory until we encounter a `ucer2` residue
    * we break from the loop after operating on the first `ucer2`, because this is simple for demonstration/testing
* Create a `networkx` graph
* Create a few temporary arrays to store positions and element types, as these are needed to identify the chain indices
* Populate the graph with all C-C and C-H pairs within a cutoff distance (i.e., close enough that they would be bonded; note that bonding information is not included in the mdtraj trajectory)
    * These pairs are added as edges to the graph
* Since we have excluded nitrogen atoms, we end up with two disconnected subgraphs in the `networkx` graph
* For each subgraph, create a list of the indices that match the criteria above
    * Add these to a list, print them to the screen, and compare them with the indices I defined manually
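
A quick geometric check suggests why the 0.18 nm cutoff used in the code separates bonded pairs from non-bonded ones. Assuming typical textbook values of about 0.153 nm for a C-C bond, about 0.109 nm for a C-H bond, and a tetrahedral angle $\theta \approx 109.5^\circ$ (not values taken from the force field used here), the closest non-bonded C···C and C···H pairs in an alkyl chain are 1-3 neighbors, whose separation follows from the law of cosines:

$$
d_{13} = \sqrt{b_1^2 + b_2^2 - 2\,b_1 b_2 \cos\theta}
$$

$$
d_{\mathrm{C\cdots C}} = \sqrt{2\,(0.153)^2\,(1 - \cos 109.5^\circ)} \approx 0.250~\mathrm{nm},
\qquad
d_{\mathrm{C\cdots H}} = \sqrt{(0.153)^2 + (0.109)^2 - 2\,(0.153)(0.109)\cos 109.5^\circ} \approx 0.215~\mathrm{nm}
$$

Both lie clearly above 0.18 nm, while the bonded distances lie below it, leaving some margin for thermal fluctuations. The same geometry shows why H-H pairs are not searched: two hydrogens on the same carbon are only about $2\,(0.109)\sin(109.5^\circ/2) \approx 0.178$ nm apart, essentially at the cutoff, so a distance test alone could add H-H edges that do not correspond to bonds.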


Load up a trajectory of a multilayer membrane containing CER lipids (here CER NS).

```python
traj = md.load('../../4_prod_f305/trimmed_470.xtc',
               top='../../4_prod_f305/confout.gro')
```

```python
for residue in traj.topology.residues:
    if residue.name == 'ucer2':
        # create temporary arrays to store the relevant info; these are easier to
        # access with residue-relative indices than the mdtraj topology
        atoms = list(residue.atoms)
        xyz_temp = traj.xyz[0, [atom.index for atom in atoms], :]  # first frame, shape (n_atoms, 3)
        element_temp = [atom.name[0] for atom in atoms]  # element taken from the first letter of the atom name

        cer_graph = nx.Graph()

        # loop over all C-C and C-H atom pairs in the residue; the pair is checked
        # in both orders, so a hydrogen listed before its carbon is not skipped
        for i in range(len(atoms)):
            for j in range(i + 1, len(atoms)):
                pair = {element_temp[i], element_temp[j]}
                if pair == {'C'} or pair == {'C', 'H'}:
                    # note: assumes the molecule is whole (not split across periodic boundaries)
                    dist = np.linalg.norm(xyz_temp[j] - xyz_temp[i])
                    if dist < 0.18:  # nm; close enough to be bonded
                        cer_graph.add_edge(i, j)  # note we use relative indices within the residue

        chain_ids = []
        for component in nx.connected_components(cer_graph):
            chain_temp = []
            for node in sorted(component):
                # only consider carbons with exactly 4 connections; carbons bonded
                # to N or O have fewer in this C/H-only graph (headgroup)
                if element_temp[node] != 'C' or cer_graph.degree(node) != 4:
                    continue
                neighbors = list(cer_graph.neighbors(node))
                H_neighbors = sorted(n for n in neighbors if element_temp[n] == 'H')
                C_count = len(neighbors) - len(H_neighbors)  # the remaining neighbors are carbons
                # the 4 connections must be [C][C][H][H] or [C][H][H][H]
                if C_count >= 1 and len(H_neighbors) >= 2:
                    chain_temp.append(node)         # the base carbon
                    chain_temp.extend(H_neighbors)  # only its hydrogens; carbons are visited separately
            chain_ids.append(chain_temp)

        if len(chain_ids) != 2:
            print(f'Something went wrong, we should have 2 subgraphs, found {len(chain_ids)}')
        for i, chain in enumerate(chain_ids):
            print(f'chain {i+1} relative indices:')
            print(' '.join(str(cid) for cid in chain))
            print('------')

        # stop after the first ucer2 (simple for demonstration/testing)
        break
```


```
chain 1 relative indices:
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
------
chain 2 relative indices:
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
------
```
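
The double loop over atom pairs in the cell above is quadratic and runs in pure Python, which is fine for a single residue of this size. If it ever becomes a bottleneck, the same C-C/C-H cutoff test can be expressed with NumPy broadcasting. This is a sketch of an alternative, not the method used above; `bonded_pairs` is a hypothetical helper, and it keeps the same assumptions (whole molecules, 0.18 nm cutoff, elements taken from atom names).

```python
import numpy as np


def bonded_pairs(xyz, elements, cutoff=0.18):
    """Return (i, j) pairs with i < j that are C-C or C-H and closer than `cutoff`.

    xyz: (n_atoms, 3) positions in nm; elements: sequence of symbols ('C', 'H', 'N', ...).
    """
    xyz = np.asarray(xyz, dtype=float)
    elements = np.asarray(elements)
    is_c = elements == 'C'
    c_or_h = is_c | (elements == 'H')

    # all pairwise distances via broadcasting: (n, 1, 3) - (1, n, 3) -> (n, n)
    dist = np.linalg.norm(xyz[:, None, :] - xyz[None, :, :], axis=-1)

    # a pair qualifies if one atom is C and the other is C or H (in either order)
    allowed = (is_c[:, None] & c_or_h[None, :]) | (c_or_h[:, None] & is_c[None, :])

    # keep each unordered pair once: upper triangle, k=1 also drops the diagonal
    mask = np.triu(allowed & (dist < cutoff), k=1)
    i, j = np.nonzero(mask)
    return list(zip(i.tolist(), j.tolist()))


# Small synthetic example (coordinates in nm): two carbons, two hydrogens, one nitrogen
xyz = [[0.0, 0.0, 0.0],      # 0: C
       [0.153, 0.0, 0.0],    # 1: C
       [0.0, 0.109, 0.0],    # 2: H
       [0.1, 0.109, 0.0],    # 3: H (within the cutoff of both carbons and of atom 2)
       [-0.147, 0.0, 0.0]]   # 4: N (close to C0, but N is never included)
elements = ['C', 'C', 'H', 'H', 'N']
pairs = bonded_pairs(xyz, elements)
print(pairs)  # [(0, 1), (0, 2), (0, 3), (1, 3)]
```

With the variables from the cell above, `cer_graph.add_edges_from(bonded_pairs(xyz_temp, element_temp))` would replace the double loop. The trade-off is memory: the full n × n distance matrix is negligible for one residue but would not scale to a whole system at once.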


## Checking the output 

When I defined the relative indices manually for CER NS C24, I used the following (inclusive) ranges:

```
start_index1 = 4
end_index1 = 69

start_index2 = 89
end_index2 = 124
```

These match the output above; the only difference is that the manually defined indices exclude the terminal methyl group of each chain. I originally did this to calculate the order parameter S_CH, where the terminal methyl is typically excluded because its C-H bond orientation differs.
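
Rather than comparing the lists by eye, a small set-based check can report exactly which indices differ. The helper `compare_to_manual` is a hypothetical name, and the detected indices are entered here as `range` objects only because the printed output above consists of the contiguous runs 4-73 and 89-128; in practice one would pass the entries of `chain_ids` directly.

```python
def compare_to_manual(auto_indices, start, end):
    """Compare automatically detected indices with a manual inclusive range [start, end].

    Returns (missing, extra): indices in the manual range that were not detected,
    and detected indices that fall outside the manual range.
    """
    auto = set(auto_indices)
    manual = set(range(start, end + 1))  # +1 because the manual end index is inclusive
    return sorted(manual - auto), sorted(auto - manual)


# Detected indices from the first run (contiguous 4-73 and 89-128 in the printed output)
chain1_auto = list(range(4, 74))
chain2_auto = list(range(89, 129))

print(compare_to_manual(chain1_auto, 4, 69))    # ([], [70, 71, 72, 73])
print(compare_to_manual(chain2_auto, 89, 124))  # ([], [125, 126, 127, 128])
```

Nothing from the manual ranges is missing, and the four extra indices per chain are the terminal methyl carbon and its three hydrogens.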

It is trivial to change the algorithm to exclude this terminal methyl group: we simply require exactly two hydrogens on each carbon instead of at least two, as done below. The chain indices then match the manual definitions exactly.

```python
for residue in traj.topology.residues:
    if residue.name == 'ucer2':
        # create temporary arrays to store the relevant info; these are easier to
        # access with residue-relative indices than the mdtraj topology
        atoms = list(residue.atoms)
        xyz_temp = traj.xyz[0, [atom.index for atom in atoms], :]  # first frame, shape (n_atoms, 3)
        element_temp = [atom.name[0] for atom in atoms]  # element taken from the first letter of the atom name

        cer_graph = nx.Graph()

        # loop over all C-C and C-H atom pairs in the residue; the pair is checked
        # in both orders, so a hydrogen listed before its carbon is not skipped
        for i in range(len(atoms)):
            for j in range(i + 1, len(atoms)):
                pair = {element_temp[i], element_temp[j]}
                if pair == {'C'} or pair == {'C', 'H'}:
                    # note: assumes the molecule is whole (not split across periodic boundaries)
                    dist = np.linalg.norm(xyz_temp[j] - xyz_temp[i])
                    if dist < 0.18:  # nm; close enough to be bonded
                        cer_graph.add_edge(i, j)  # note we use relative indices within the residue

        chain_ids = []
        for component in nx.connected_components(cer_graph):
            chain_temp = []
            for node in sorted(component):
                # only consider carbons with exactly 4 connections; carbons bonded
                # to N or O have fewer in this C/H-only graph (headgroup)
                if element_temp[node] != 'C' or cer_graph.degree(node) != 4:
                    continue
                neighbors = list(cer_graph.neighbors(node))
                H_neighbors = sorted(n for n in neighbors if element_temp[n] == 'H')
                C_count = len(neighbors) - len(H_neighbors)  # the remaining neighbors are carbons
                # require exactly [C][C][H][H]: the terminal methyl ([C][H][H][H]) is now excluded
                if C_count == 2 and len(H_neighbors) == 2:
                    chain_temp.append(node)         # the base carbon
                    chain_temp.extend(H_neighbors)  # only its hydrogens; carbons are visited separately
            chain_ids.append(chain_temp)

        if len(chain_ids) != 2:
            print(f'Something went wrong, we should have 2 subgraphs, found {len(chain_ids)}')
        for i, chain in enumerate(chain_ids):
            print(f'chain {i+1} relative indices:')
            print(' '.join(str(cid) for cid in chain))
            print('------')

        # stop after the first ucer2 (simple for demonstration/testing)
        break
```



```
chain 1 relative indices:
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
------
chain 2 relative indices:
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
------
```
