In [1]:
# Load notebooks with required functions
import ipynb
from ipynb.fs.full.get_coordinates import *
from ipynb.fs.full.get_contacts import *

# Load required libraries
import numpy as np
import os
from tqdm import tqdm 
from joblib import Parallel, delayed 
from functools import partial
import mdtraj as md 
import itertools
import pandas as pd 
import warnings #Optional
warnings.filterwarnings("ignore") #Optional

In [2]:
def wcontact_matrix(thresholds, num_cores = 1, prot_name = None, save_to = None, pdb_folder = None, xtc_file = None, top_file = None, start = None, end = None, select_chain = None, name_variable = '__main__'):
    """Compute the weighted contact matrix of a conformational ensemble.

    Each row of the result is the output of ``get_contacts`` for one
    conformation. Conformations come either from a folder of PDB files or
    from an XTC trajectory plus its topology file.

    Parameters
    ----------
    thresholds
        Passed unchanged to ``get_contacts`` for every conformation.
    num_cores : int, default 1
        Number of joblib workers (process-based).
    prot_name : str, optional
        Prefix of the output file name; required when ``save_to`` is given.
    save_to : str, optional
        Output directory. If None, the matrix is returned instead of saved.
    pdb_folder : str, optional
        Folder whose entries are each treated as one conformation.
        Mutually exclusive with ``xtc_file``/``top_file``.
    xtc_file, top_file : str, optional
        Trajectory and topology files. A ``.gro`` topology is converted to an
        mdtraj topology before loading the trajectory.
    start, end, select_chain
        Forwarded to ``get_coordinates`` as ``res_start``, ``res_end`` and
        ``which_chain``.
    name_variable : str, default '__main__'
        The computation only runs when this module's ``__name__`` equals this
        value (presumably a guard for process-based parallelism).

    Returns
    -------
    pandas.DataFrame or None
        If ``save_to`` is None: one row per conformation and one column per
        element returned by ``get_contacts``. Otherwise None, after writing
        ``<save_to>/<prot_name>_wcontmatrix.txt`` (space-separated, no header,
        no index). Any conformation other than the first whose computation
        raises is filled with NaN.

    Raises
    ------
    ValueError
        On an inconsistent combination of input/output arguments, or when no
        conformations are found.
    RuntimeError
        When ``__name__ != name_variable`` prevents the computation.
    """
    # Saving requires a file-name prefix; returning the matrix does not.
    if save_to is not None and prot_name is None:
        raise ValueError('Please set save_to = None or prot_name != None and save_to != None.')

    if xtc_file is None and top_file is None and pdb_folder is not None:
        traj_file = None
        # Each directory entry is one conformation. os.listdir order is
        # arbitrary, and the rows of the result follow it.
        conf_list = os.listdir(pdb_folder)
    elif xtc_file is not None and top_file is not None and pdb_folder is None:
        if top_file.endswith(".gro"):
            top_file = md.formats.GroTrajectoryFile(top_file).topology
        traj_file = md.load_xtc(xtc_file, top = top_file)
        # Conformations are addressed by frame index.
        conf_list = np.arange(len(traj_file))
    else:
        raise ValueError('Please set pdb_folder != None and xtc_file = top_file = None, or pdb_folder = None and xtc_file != None, top_file != None.')

    if len(conf_list) == 0:
        raise ValueError('No conformations found: the PDB folder or trajectory is empty.')

    # Fail loudly instead of leaving the result unbound further down.
    if __name__ != name_variable:
        raise RuntimeError(
            f'wcontact_matrix not computed: __name__ ({__name__!r}) != name_variable ({name_variable!r}).'
        )

    def comp_function(conf_comp, thresholds_comp, pdb_data_comp, traj_data_comp, start_comp, end_comp, sel_chain):
        # Contacts of a single conformation (PDB entry name or frame index).
        coordinates = get_coordinates(conf_name = conf_comp, pdb = pdb_data_comp, traj = traj_data_comp, res_start = start_comp, res_end = end_comp, which_chain = sel_chain)
        return get_contacts(coordinates, thresholds_comp)

    it_function = partial(comp_function, thresholds_comp = thresholds, pdb_data_comp = pdb_folder, traj_data_comp = traj_file, start_comp = start, end_comp = end, sel_chain = select_chain)
    # The first conformation fixes the number of columns, so it must succeed.
    N_pairs = len(it_function(conf_list[0]))

    def it_function_error(conf):
        # Best effort: a failing conformation becomes a NaN row instead of
        # aborting the run. Only Exception is caught, so KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            return it_function(conf)
        except Exception:
            return np.repeat(np.nan, N_pairs)

    # Presumably set so that spawned worker processes inherit the setting.
    os.environ['PYTHONWARNINGS'] = 'ignore'
    rows = Parallel(n_jobs = num_cores, prefer = 'processes')(delayed(it_function_error)(conf) for conf in tqdm(conf_list))
    wcont_data = pd.DataFrame(np.reshape(np.asarray(rows), [len(conf_list), N_pairs]))

    if save_to is None:
        return wcont_data

    wcont_data.to_csv(os.path.join(save_to, '_'.join([prot_name, 'wcontmatrix.txt'])), header = False, index = False, sep = ' ')

        