## Run Definitions

In [None]:
import re
import os

def process_file(input_file, work_dir, residue_number, threshold=0.5, max_residue=None):
    """
    Process a single *_freq.tsv contact-frequency file.

    Non-comment lines with at least three tab-separated fields are kept when
    the contact frequency reaches ``threshold`` and both residues are numbered
    at or below ``max_residue``. For each kept contact, the residue whose
    number equals ``residue_number`` is written first, followed by its
    partner, the interaction type (parsed from the file name; "hplp" is
    reported as "hp") and the contact frequency.

    Files whose interaction type contains "all" or "pc" contribute nothing.

    Args:
        input_file (str): Name of the input file, e.g. "469_hb_interactions_freq.tsv".
        work_dir (str): Directory containing the file.
        residue_number (str): Residue number the file was generated for.
        threshold (float): Minimum contact frequency for a contact to be kept.
        max_residue (int | None): Highest residue number kept for either
            partner. When None, the notebook-level ``arg18_pos`` is used
            (looked up at call time), matching the previous behaviour.

    Returns:
        list[str]: Lines of the form "RES:NUM RES:NUM type frequency".
    """
    # The interaction type is the token immediately before
    # "_interactions_freq.tsv". [^_]+ (instead of the greedy \w+, which also
    # matches "_") keeps earlier underscore-separated parts of the file name,
    # such as the residue number, out of the captured type.
    type_match = re.search(r'_([^_]+)_interactions_freq\.tsv', input_file)
    if type_match:
        interaction_type = type_match.group(1)
        if interaction_type == "hplp":
            interaction_type = "hp"
    else:
        interaction_type = "unknown"

    # Every line of an excluded interaction type would be skipped anyway,
    # so don't read the file at all.
    if 'all' in interaction_type or 'pc' in interaction_type:
        return []

    if max_residue is None:
        # Backward-compatible default: global set in the configuration cell.
        max_residue = arg18_pos

    processed_data = []
    with open(os.path.join(work_dir, input_file), 'r') as infile:
        for line in infile:
            if line.startswith('#'):
                continue

            parts = line.strip().split('\t')
            if len(parts) < 3:
                continue

            residue_1, residue_2 = parts[0], parts[1]
            contact_frequency = float(parts[2])

            if 'all' in residue_1 or 'all' in residue_2:
                continue
            if 'pc' in residue_1 or 'pc' in residue_2:
                continue
            if contact_frequency < threshold:
                continue

            # Drop the first two characters of each label
            # (e.g. "A:ARG:470" -> "ARG:470"); assumes a one-character chain ID.
            residue_1 = residue_1[2:]
            residue_2 = residue_2[2:]

            fields_1 = residue_1.split(':')
            fields_2 = residue_2.split(':')
            if int(fields_1[1]) > max_residue or int(fields_2[1]) > max_residue:
                continue

            # Put the residue this file belongs to first; if neither (or the
            # first) residue is it, keep the original order.
            is_target_1 = residue_1 == f"{fields_1[0]}:{residue_number}"
            is_target_2 = residue_2 == f"{fields_2[0]}:{residue_number}"
            if is_target_2 and not is_target_1:
                residue_1, residue_2 = residue_2, residue_1

            processed_data.append(f"{residue_1} {residue_2} {interaction_type} {contact_frequency}")

    return processed_data


def process_all_files(work_dir, output_file, threshold=0.5):
    """
    Combine every *freq.tsv file in ``work_dir`` into one sorted output file.

    The residue number is taken from the first "<digits>_" in each file name.
    Only files whose residue number lies in [ace_pos, arg18_pos] (notebook
    globals from the configuration cell) are processed with ``process_file``.
    The combined lines are sorted by the residue number of their first
    residue and written to ``work_dir``/``output_file``.

    Args:
        work_dir (str): Directory containing the *_freq.tsv files.
        output_file (str): Name of the combined output file (written to work_dir).
        threshold (float): Minimum contact frequency passed to ``process_file``.
    """
    def sort_key(line):
        """Return the number after the first ':' in the line's first field, or +inf."""
        # Labels look like "ARG:470", so the number follows the colon.
        # (The old pattern "(\d+):" never matched, leaving the data unsorted.)
        match = re.search(r':(\d+)', line.split()[0])
        return int(match.group(1)) if match else float('inf')

    all_processed_data = []

    # Sorted listing makes the processing order (and therefore the order of
    # ties in the stable sort below) independent of the filesystem.
    for filename in sorted(os.listdir(work_dir)):
        if not filename.endswith('freq.tsv'):
            continue

        residue_number_match = re.search(r'(\d+)_.*\.tsv', filename)
        if residue_number_match is None:
            continue
        residue_number = int(residue_number_match.group(1))

        if ace_pos <= residue_number <= arg18_pos:
            all_processed_data.extend(
                process_file(filename, work_dir, str(residue_number), threshold)
            )

    all_processed_data.sort(key=sort_key)

    output_path = os.path.join(work_dir, output_file)
    with open(output_path, 'w') as outfile:
        outfile.writelines(line + '\n' for line in all_processed_data)

    print(f"File processing complete. All data (above threshold) saved to: {output_path}")

## Enter directory and file information

In [None]:
# Location of the *_freq.tsv files; the combined table is written here too.
work_dir = '/mnt/storage1/adam/CGRP/New/CGRP1-37/ligand_receptor_interactions/'

# Inclusive residue-number window used to select input files
# (ace_pos could instead be set to val8_pos).
ace_pos, arg18_pos = 469, 479

# The combined file produced here is also the input of the plotting cell.
output_file = 'combined_processed_freq_sorted_with_interaction_type.tsv'
input_file = output_file

process_all_files(work_dir, output_file, threshold=0.5)

## Make ligand-ligand contact interaction map

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict
import re

def map_residue(residue, offset=455, min_residue=None, max_residue=None):
    """
    Convert a "RES:NUMBER" label into a compact "RES<NUMBER - offset>" label.

    With the default offset of 455, "ARG:470" becomes "ARG15". Residues
    numbered outside [min_residue, max_residue] map to None so callers can
    drop them.

    Args:
        residue (str): Residue label in the form "RES:NUMBER".
        offset (int): Subtracted from the residue number to get the displayed
            position. The default reproduces the previous hard-coded value;
            change it to match your numbering scheme.
        min_residue (int | None): Lowest residue number kept. Defaults to the
            notebook global ``ace_pos``.
        max_residue (int | None): Highest residue number kept. Defaults to the
            notebook global ``arg18_pos``.

    Returns:
        str | None: The mapped label, or None if the label cannot be parsed
        or the residue lies outside the range.
    """
    try:
        res_name, res_number = residue.split(":")
        res_number = int(res_number)
    except (ValueError, AttributeError):
        # AttributeError covers non-string cells (e.g. NaN from pandas).
        print(f"Error processing residue: {residue}")
        return None

    if min_residue is None:
        min_residue = ace_pos
    if max_residue is None:
        max_residue = arg18_pos

    if min_residue <= res_number <= max_residue:
        return f"{res_name}{res_number - offset}"
    return None

def process_tsv_and_create_table(input_file, work_dir):
    """
    Render a residue-by-residue interaction table from the combined file.

    The file is read as space-separated columns residue_1, residue_2,
    interaction_type and contact_frequency. Residue labels are renumbered
    with ``map_residue`` (rows where either residue maps to None are
    dropped), and HID/HIE are shown as HIS. Rows and columns are the residues
    that appear in the residue_1 column, sorted by their mapped number. Each
    off-diagonal cell lists that pair's "type percent%" entries, one per line.
    Diagonal cells are grey, and cells whose residue numbers differ by 4 are
    light coral. The table is saved as
    ``work_dir``/all_interaction_table_lig_lig_helix.png.

    Args:
        input_file (str): Name of the space-separated file to read.
        work_dir (str): Directory containing the file; the image is written here too.
    """
    file_path = os.path.join(work_dir, input_file)

    try:
        df = pd.read_csv(file_path, sep=' ', header=None,
                         names=["residue_1", "residue_2", "interaction_type", "contact_frequency"])
    except Exception as e:
        # Notebook-level boundary: report the problem and stop.
        print(f"Error reading the file: {e}")
        return

    print("Loaded data:")
    print(df.head())

    df['mapped_residue_1'] = df['residue_1'].apply(map_residue)
    df['mapped_residue_2'] = df['residue_2'].apply(map_residue)

    # Replace "HID" or "HIE" with "HIS" in mapped residues
    df['mapped_residue_1'] = df['mapped_residue_1'].str.replace("HID", "HIS").str.replace("HIE", "HIS")
    df['mapped_residue_2'] = df['mapped_residue_2'].str.replace("HID", "HIS").str.replace("HIE", "HIS")

    df = df.dropna(subset=['mapped_residue_1', 'mapped_residue_2'])
    if df.empty:
        # ax.table cannot draw an empty table.
        print("No interactions within the residue range; nothing to plot.")
        return

    def get_residue_number(residue):
        """
        Extracts the last one or two numeric digits at the end of the residue for sorting.
        E.g., "R5112" will be treated as 12.
        """
        match = re.search(r'(\d{1,2})$', residue)
        return int(match.group()) if match else -1

    all_residues = sorted(set(df['mapped_residue_1'].unique()), key=get_residue_number)

    # interaction_dict[residue_1][residue_2] -> ["type NN%", ...]
    interaction_dict = defaultdict(lambda: defaultdict(list))
    for _, row in df.iterrows():
        contact_percent = round(row['contact_frequency'] * 100)
        interaction_dict[row['mapped_residue_1']][row['mapped_residue_2']].append(
            f"{row['interaction_type']} {contact_percent}%"
        )

    interaction_matrix = pd.DataFrame('', index=all_residues, columns=all_residues)
    for residue_1 in all_residues:
        partners = interaction_dict.get(residue_1, {})
        for residue_2 in all_residues:
            if residue_1 != residue_2 and residue_2 in partners:
                interaction_matrix.at[residue_1, residue_2] = '\n'.join(partners[residue_2])

    # Colour by residue-number distance rather than list position, so the
    # "4 apart" highlight stays correct when a residue in the range is absent.
    numbers = [get_residue_number(residue) for residue in all_residues]
    n_residues = len(all_residues)
    cell_colours = []
    for i in range(n_residues):
        row_colours = []
        for j in range(n_residues):
            if i == j:
                row_colours.append('lightgrey')
            elif numbers[i] >= 0 and numbers[j] >= 0 and abs(numbers[i] - numbers[j]) == 4:
                row_colours.append('lightcoral')
            else:
                row_colours.append('white')
        cell_colours.append(row_colours)

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.axis('off')

    table = ax.table(cellText=interaction_matrix.values,
                     colLabels=interaction_matrix.columns,
                     rowLabels=interaction_matrix.index,
                     loc='center',
                     cellLoc='center',
                     cellColours=cell_colours)

    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.5, 8)

    # Give the file an explicit extension so the printed path is the real file
    # (savefig otherwise appends one to an extension-less name).
    output_image_path = os.path.join(work_dir, 'all_interaction_table_lig_lig_helix.png')
    fig.savefig(output_image_path, bbox_inches='tight', pad_inches=0, dpi=300)

    print(f"Table image saved to {output_image_path}")

# Build and save the ligand-ligand interaction table from the combined file.
process_tsv_and_create_table(input_file=input_file, work_dir=work_dir)
