In [12]:
import os
from src.analysis.utils import get_file_by_extension
from src.plumed.cv import get_contact_map
from typing import Literal, Optional

def get_contact_maps(
    directory: str,
    cutoff: float = 0.8,
    mode: Literal['single-chain', 'two-chain'] = 'two-chain',
    # 'Segment' is quoted: it is imported further down this cell, and an unquoted
    # annotation is evaluated at def time, raising NameError on a fresh kernel.
    spot1_residues: Optional['Segment'] = None,
    spot2_residues: Optional['Segment'] = None,
    ):
    """
    Get contact maps for both the solvated and the equilibrated structure in a run directory.

    Args:
        directory (str): Run directory containing ``*_solvated.pdb`` and
            ``*_equilibrated.pdb``; also passed to ``get_contact_map`` as ``output_dir``.
        cutoff (float): Contact cutoff forwarded to ``get_contact_map``
            (units as defined there — presumably nm; confirm).
        mode: ``'single-chain'`` or ``'two-chain'``, forwarded to ``get_contact_map``.
        spot1_residues, spot2_residues: Optional residue segments forwarded to
            ``get_contact_map`` (used by the caller below to restrict contacts to
            two binding spots).

    Returns:
        tuple: (solvated_contact_map, equilibrated_contact_map)
    """
    states = ('solvated', 'equilibrated')

    # Locate both PDB files up front so a missing file fails before any map is computed.
    pdb_paths = {
        state: get_file_by_extension(directory, f'_{state}.pdb')
        for state in states
    }

    contact_maps = []
    for state in states:
        # get_contact_map expects the bare system name, without the state tag or extension.
        system_stem = (
            os.path.basename(pdb_paths[state])
            .replace(f'_{state}', '')
            .replace('.pdb', '')
        )
        contact_maps.append(get_contact_map(
            filename=system_stem,
            cutoff=cutoff,
            output_dir=directory,
            mode=mode,
            specify_pdb=state,
            spot1_residues=spot1_residues,
            spot2_residues=spot2_residues,
        ))

    solvated_contact_map, equilibrated_contact_map = contact_maps
    return solvated_contact_map, equilibrated_contact_map

from typing import Optional
from src.models import Segment, Residue
import re
def extract_from_system(system: str) -> tuple[int, int, int]:
    """Parse binder and linker lengths from an INF-construct system name.

    Names have the form ``<PREFIX>-B<binder>L<linker>[W]``, for example::

        Q7-B30L4      -> (30, 4, 4)
        Z1-B40L10W    -> (40, 10, 10)
        PQ19-B50L10W  -> (50, 10, 10)

    Only one linker length is encoded in the name, so it is returned for both linkers.

    Args:
        system: System name.

    Returns:
        (binder_length, linker1_length, linker2_length)

    Raises:
        ValueError: If the whole name does not match the expected pattern.
    """
    pattern = r'([A-Z0-9]+)-B(\d+)L(\d+)W?'
    # fullmatch (not match) so trailing text like 'Q7-B30L4X' is rejected
    # instead of being silently parsed as 'Q7-B30L4'.
    match = re.fullmatch(pattern, system)
    if match is None:
        raise ValueError(f"System {system} does not match pattern {pattern}")

    binder_length = int(match.group(2))
    linker1_length = int(match.group(3))
    linker2_length = linker1_length
    return binder_length, linker1_length, linker2_length


def _chain_a_segment(first: int, last: int) -> 'Segment':
    """Build a chain-A ``Segment`` of residues ``first..last`` (inclusive, 1-based)."""
    # index and global_index are deliberately identical (single-chain numbering).
    return Segment(residues=[
        Residue(index=i, global_index=i, chain_id='A', indexing=1)
        for i in range(first, last + 1)
    ])


def get_binding_spots(system: str) -> tuple[Segment, Segment]:
    """
    Get the binding spot segments for a given system.

    Spot 1 is residues ``1..binder_length`` of chain A; presumably the binder is
    N-terminal (confirm against the construct design).

    Spot 2 is a window inside the INF domain, offset by the length of everything that
    precedes it: binder + linker1 + protease + linker2. This offset assumes that
    domain order in chain A numbering — TODO confirm. Names ending in ``W`` use the
    wide window; the others use the default window.

    Args:
        system (str): Name of the system (e.g., 'Q7-B30L4', 'PQ19-B50L10W')

    Returns:
        tuple[Segment, Segment]: (spot1_residues, spot2_residues)

    Raises:
        ValueError: If the system name cannot be parsed (from ``extract_from_system``).
    """
    binder_length, linker1_length, linker2_length = extract_from_system(system)

    # Protease site length: 9 residues for Q7 constructs, 10 otherwise.
    # Compare the prefix exactly rather than substring-matching the whole name.
    prefix = system.split('-', 1)[0]
    protease_length = 9 if prefix == 'Q7' else 10

    # Residues preceding the INF domain (the INF domain itself is 161 residues long).
    non_inf_length = binder_length + linker1_length + protease_length + linker2_length

    # 'W' is the trailing suffix of the name; only check the end so that a 'W' in the
    # prefix cannot select the wide window by accident.
    if system.endswith('W'):
        spot2_start, spot2_end = 55, 135   # wider binding spot
    else:
        spot2_start, spot2_end = 80, 100   # default binding spot

    spot1_residues = _chain_a_segment(1, binder_length)
    spot2_residues = _chain_a_segment(
        non_inf_length + spot2_start + 1,
        non_inf_length + spot2_end,
    )
    return spot1_residues, spot2_residues

In [14]:
# Earlier survey of the FoldingUponBinding systems (two-chain mode, no binding spots),
# kept for reference but not executed:
# print('number of contacts in solvated -> in equilibrated')
# for system_name in [
#     'ASYN-A', 'ASYN-G', 
#     'CD28-A', 'CD28-B', 'CD28-G', 'CD28-P',
#     'P53-1', 'P53-2', 'P53-E',
#     'SUMO-1A', 'SUMO-1C',
# ]:
#     directory = f'/home/jakub/phd/openmm-md/data/241010_FoldingUponBinding/output/{system_name}/241122-Explore'
#     solvated_cmap, equilibrated_cmap = get_contact_maps(directory, cutoff=0.8)
#     solvated_num_contacts = len(solvated_cmap.contacts)
#     equilibrated_num_contacts = len(equilibrated_cmap.contacts)
#     print(f"{system_name}: {solvated_num_contacts} -> {equilibrated_num_contacts}")

# Configuration: override the data root via the INF_OUTPUT_ROOT environment variable
# when running outside the original machine.
INF_OUTPUT_ROOT = os.environ.get(
    'INF_OUTPUT_ROOT',
    '/home/jakub/phd/openmm-md/data/241109_INFconstruct/output',
)
RUN_NAME = '241122-Explore'
CONTACT_CUTOFF = 0.8  # same value as get_contact_maps' default

INF_SYSTEMS = [
    'PQ19-B30L4', 'PQ19-B30L7', 'PQ19-B30L10', 'PQ19-B40L10', 'PQ19-B40L10W', 'PQ19-B50L10W',
    'Q7-B30L4', 'Q7-B30L7', 'Q7-B30L10', 'Q7-B40L10', 'Q7-B40L10W', 'Q7-B50L10W',
    'Z1-B30L4', 'Z1-B30L7', 'Z1-B30L10', 'Z1-B40L10', 'Z1-B40L10W', 'Z1-B50L10W',
]

# system name -> (contacts in solvated, contacts in equilibrated), kept for later cells.
contact_counts = {}
for system_name in INF_SYSTEMS:
    directory = os.path.join(INF_OUTPUT_ROOT, system_name, RUN_NAME)
    spot1_residues, spot2_residues = get_binding_spots(system_name)
    solvated_cmap, equilibrated_cmap = get_contact_maps(
        directory,
        cutoff=CONTACT_CUTOFF,
        mode='single-chain',
        spot1_residues=spot1_residues,
        spot2_residues=spot2_residues,
    )
    solvated_num_contacts = len(solvated_cmap.contacts)
    equilibrated_num_contacts = len(equilibrated_cmap.contacts)
    contact_counts[system_name] = (solvated_num_contacts, equilibrated_num_contacts)
    print(f"{system_name}: {solvated_num_contacts} -> {equilibrated_num_contacts}")

PQ19-B30L4: 8 -> 17
PQ19-B30L7: 6 -> 22
PQ19-B30L10: 4 -> 11
PQ19-B40L10: 6 -> 21
PQ19-B40L10W: 25 -> 31
PQ19-B50L10W: 25 -> 43
Q7-B30L4: 17 -> 25
Q7-B30L7: 19 -> 18
Q7-B30L10: 13 -> 17
Q7-B40L10: 17 -> 15
Q7-B40L10W: 25 -> 41
Q7-B50L10W: 15 -> 27
Z1-B30L4: 16 -> 11
Z1-B30L7: 4 -> 9
Z1-B30L10: 8 -> 8
Z1-B40L10: 1 -> 4
Z1-B40L10W: 12 -> 30
Z1-B50L10W: 24 -> 51
