In [1]:
import time
import MDAnalysis as mda
import numpy as np
import os
from itertools import product
from tqdm import tqdm

# List of base directories and their names
directories_info = [
    {'path': '/pathtofolder/', 'name': 'RGD_280K'},
    {'path': '/pathtofolder/', 'name': 'RGD_370K'}
]

def get_polar_beads(universe, n_chains=360, atoms_per_chain=16):  #adjust n_chains=num of repeats, atoms_per_chain
    indices = [0, 3, 4, 6, 8, 13, 15]
    polar_beads = []
    for n in range(n_chains):
        selection_str = 'index ' + ' '.join(map(str, [i + atoms_per_chain * n for i in indices]))
        polar_beads.extend(universe.select_atoms(selection_str))
    return polar_beads

def get_atom_groups(universe, n_chains=360, atoms_per_chain=16):
    indices = {
        'pi': [10, 11],
        'cation': [2],
        'negative': [5]
    }
    atom_groups = {key: [] for key in indices.keys()}
    for n in range(n_chains):
        for key, offset in indices.items():
            selection_str = 'index ' + ' '.join(map(str, [i + atoms_per_chain * n for i in offset]))
            atom_group = universe.select_atoms(selection_str)
            atom_groups[key].append(atom_group)
    return atom_groups

def get_apolar_beads(universe, n_chains=360, atoms_per_chain=16):
    indices = [1, 14]
    apolar_beads = []
    for n in range(n_chains):
        selection_str = 'index ' + ' '.join(map(str, [i + atoms_per_chain * n for i in indices]))
        apolar_beads.extend(universe.select_atoms(selection_str))
    return apolar_beads

def calculate_hbonds(polar_beads, threshold=5.0):
    inter_hbond_count = 0
    intra_hbond_count = 0
    n_beads = len(polar_beads)
    #groups = [list(range(i, i + 80)) for i in range(0, n_beads, 80)]
    centers = [bead.position for bead in polar_beads]
    for (i, com1), (j, com2) in product(enumerate(centers), repeat=2):
        if i < j:
            distance = np.linalg.norm(com1 - com2)
            if distance < threshold:
                group_i = i // 84
                group_j = j // 84
                if group_i != group_j:
                    inter_hbond_count += 1
                elif abs(i - j) > 3:
                    intra_hbond_count += 1
    return inter_hbond_count, intra_hbond_count

def calculate_ccs(apolar_beads, threshold=5.0):
    inter_cc_count = 0
    intra_cc_count = 0
    n_beads = len(apolar_beads)
    centers = [bead.position for bead in apolar_beads]
    for (i, com1), (j, com2) in product(enumerate(centers), repeat=2):
        if i < j:
            distance = np.linalg.norm(com1 - com2)
            if distance < threshold:
                group_i = i // 24  #num of apolar beads in one repeat*number of repeat in one chain
                group_j = j // 24
                if group_i != group_j:
                    inter_cc_count += 1
                elif abs(i - j) > 0:
                    intra_cc_count += 1
    return inter_cc_count, intra_cc_count

def calculate_distances(centers1, centers2, threshold):
    pairs = []
    for (i, com1), (j, com2) in product(enumerate(centers1), enumerate(centers2)):
        if i != j:
            distance = np.linalg.norm(com1 - com2)
            if distance < threshold:
                pairs.append(((i, j), distance))
    return pairs

def process_frame(universe, polar_beads, atom_groups, chain_groups, apolar_beads):
    def in_same_chain_group(i, j):
        return any(i in group and j in group for group in chain_groups)

    inter_hbond_count, intra_hbond_count = calculate_hbonds(polar_beads)
    inter_cc_count, intra_cc_count = calculate_ccs(apolar_beads)
    
    centers = {key: np.array([group.center_of_geometry() for group in groups]) for key, groups in atom_groups.items()}
    interactions = {
        'pi_cation': calculate_distances(centers['pi'], centers['cation'], 5),
        'q_q': calculate_distances(centers['negative'], centers['cation'], 5),
        'pi_pi': calculate_distances(centers['pi'], centers['pi'], 5)
    }
    
    interaction_results = {}
    for key, pairs in interactions.items():
        same_chain_pairs = []
        diff_chain_pairs = []
        for (i, j), distance in pairs:
            if in_same_chain_group(i, j):
                same_chain_pairs.append(distance)
            else:
                diff_chain_pairs.append(distance)
        interaction_results[key] = {
            'same_chain_pairs': same_chain_pairs,
            'diff_chain_pairs': diff_chain_pairs
        }
    
    return inter_hbond_count, intra_hbond_count, inter_cc_count, intra_cc_count, interaction_results

def analyze_directory(directory_info, stride):
    xtc_file = os.path.join(directory_info['path'], 'solute.xtc')
    tpr_file = os.path.join(directory_info['path'], 'solute.tpr')
    universe = mda.Universe(tpr_file, xtc_file)
    polar_beads = get_polar_beads(universe)
    atom_groups = get_atom_groups(universe)
    apolar_beads = get_apolar_beads(universe)
    chain_groups = [list(range(192 * i, 192 * (i + 1))) for i in range(30)]
    
    inter_frame_counts = []
    intra_frame_counts = []
    inter_cc_frame_counts = []
    intra_cc_frame_counts = []
    frame_data = {key: [] for key in ['pi_cation', 'q_q', 'pi_pi']}
    
    for ts in tqdm(universe.trajectory[2000:7001:stride], desc=f"Processing {directory_info['name']}"):
        inter_hbond_count, intra_hbond_count, inter_cc_count, intra_cc_count, frame_results = process_frame(
            universe, polar_beads, atom_groups, chain_groups, apolar_beads)
        inter_frame_counts.append(inter_hbond_count)
        intra_frame_counts.append(intra_hbond_count)
        inter_cc_frame_counts.append(inter_cc_count)
        intra_cc_frame_counts.append(intra_cc_count)
        for key in frame_data.keys():
            frame_data[key].append(frame_results[key])
    
    summary = {
        'inter_hbonds': {
            'total': sum(inter_frame_counts),
            'average': np.mean(inter_frame_counts) if inter_frame_counts else 0,
            'std': np.std(inter_frame_counts) if inter_frame_counts else 0
        },
        'intra_hbonds': {
            'total': sum(intra_frame_counts),
            'average': np.mean(intra_frame_counts) if intra_frame_counts else 0,
            'std': np.std(intra_frame_counts) if intra_frame_counts else 0
        },
        'inter_ccs': {
            'total': sum(inter_cc_frame_counts),
            'average': np.mean(inter_cc_frame_counts) if inter_cc_frame_counts else 0,
            'std': np.std(inter_cc_frame_counts) if inter_cc_frame_counts else 0
        },
        'intra_ccs': {
            'total': sum(intra_cc_frame_counts),
            'average': np.mean(intra_cc_frame_counts) if intra_cc_frame_counts else 0,
            'std': np.std(intra_cc_frame_counts) if intra_cc_frame_counts else 0
        }
    }
    
    for key, data in frame_data.items():
        total_same_chain_pairs = sum(len(frame['same_chain_pairs']) for frame in data)
        total_diff_chain_pairs = sum(len(frame['diff_chain_pairs']) for frame in data)
        summary[key] = {
            'total_same_chain_pairs': total_same_chain_pairs,
            'total_diff_chain_pairs': total_diff_chain_pairs,
            'average_same_chain_pairs': np.mean([len(frame['same_chain_pairs']) for frame in data]),
            'std_same_chain_pairs': np.std([len(frame['same_chain_pairs']) for frame in data]),
            'average_diff_chain_pairs': np.mean([len(frame['diff_chain_pairs']) for frame in data]),
            'std_diff_chain_pairs': np.std([len(frame['diff_chain_pairs']) for frame in data])
        }
    
    return directory_info['name'], summary

# Main loop over directories
time_between_frames = 1
desired_interval = 2
stride = int(desired_interval / time_between_frames)

for directory_info in directories_info:
    name, summary = analyze_directory(directory_info, stride)
    
    with open(f"{name}_0.5mM_interaction_analysis.txt", 'w') as f:
        f.write("type\tmean_inter\tstd_inter\tmean_intra\tstd_intra\n")
        f.write(f"p-p\t{summary['inter_hbonds']['average']:.2f}\t{summary['inter_hbonds']['std']:.2f}\t"
                f"{summary['intra_hbonds']['average']:.2f}\t{summary['intra_hbonds']['std']:.2f}\n")
        f.write(f"c-c\t{summary['inter_ccs']['average']:.2f}\t{summary['inter_ccs']['std']:.2f}\t"
                f"{summary['intra_ccs']['average']:.2f}\t{summary['intra_ccs']['std']:.2f}\n")
        for interaction_type, data in summary.items():
            if interaction_type not in ['inter_hbonds', 'intra_hbonds', 'inter_ccs', 'intra_ccs']:
                f.write(f"{interaction_type}\t{data['average_diff_chain_pairs']:.2f}\t{data['std_diff_chain_pairs']:.2f}\t"
                        f"{data['average_same_chain_pairs']:.2f}\t{data['std_same_chain_pairs']:.2f}\n")


Processing RGD_280K: 100%|████████████████| 2501/2501 [7:10:09<00:00, 10.32s/it]
Processing RGD_370K: 100%|████████████████| 2501/2501 [7:09:29<00:00, 10.30s/it]
