## General Description

1. Create Datasets
2. Create AF jobs
3. Analyze AF output

## 1. Create Datasets

### Library imports and helper functions

In [None]:
import pandas as pd
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import requests
import sys
import re
from typing import List, Tuple, Dict, Any, Union, Optional, Set

def contains_any_annotation(cell_value: Any, annotations_list: List[str]) -> bool:
    """Tell whether a DataFrame cell mentions at least one of the given annotations.

    Matching is plain substring containment (``annotation in cell_value``), so a
    cell such as ``"IPR000225;IPR011989;"`` matches ``"IPR011989"``. Missing
    values (anything ``pd.isna`` reports as NA) never match.

    Args:
        cell_value (Any): Cell value from a DataFrame column, typically a
            delimiter-separated string of accessions, or NA.
        annotations_list (List[str]): Annotation strings to look for.

    Returns:
        bool: True if any annotation occurs in ``cell_value``, False otherwise
        (including for NA cells and an empty ``annotations_list``).
    """
    return (not pd.isna(cell_value)) and any(
        needle in cell_value for needle in annotations_list
    )

### Constants

In [None]:
# Input data locations. These are machine-specific absolute paths; adjust them for your environment.
# STRING physical-links dump after preprocessing (comma-separated; joined on 'pair_id' later)
STRING_PATH = '/home/markus/MPI_local/data/STRING/9606.protein.physical.links.detailed.v12.0.txt_processed.csv'
# UniProtKB proteome export (UP000005640 per the filename), tab-separated
PROTEOME_PATH = '/home/markus/MPI_local/data/Proteome/uniprotkb_proteome_UP000005640_2025_05_28.tsv'
# PROTEOME_PATH = '/home/markus/MPI_local/data/full_PDB/uniprotkb_AND_reviewed_true_2025_07_10.tsv'
# Curated human transcription-factor table (filtered on its 'Is TF?' column below)
TF_DATASET_PATH = '/home/markus/MPI_local/data/human_TFs/DatabaseExtract_v_1.01.csv'
# Ensembl gene -> UniProt cross-reference table, tab-separated
ENSEMBL_MAPPING_PATH = '/home/markus/MPI_local/data/Ensembl_mapping/Homo_sapiens.GRCh38.114.uniprot.tsv/hps/nobackup/flicek/ensembl/production/release_dumps/release-114/ftp_dumps/vertebrates/tsv/homo_sapiens/Homo_sapiens.GRCh38.114.uniprot.tsv'
# JSON cache of AIUPred REST API responses (see the AIUPred annotation cell)
AIUPRED_PATH = '/home/markus/MPI_local/data/AIUPred/AIUPred_data.json'
# DisProt release export, tab-separated
DISPROT_PATH = '/home/markus/MPI_local/data/DisProt/DisProt_release_2024_12 with_ambiguous_evidences.tsv'
# Directory containing iupred3_lib; appended to sys.path before importing it
IUPRED3_PATH = '../../iupred3'

# Upper bound on the combined length of the two sequences in one AlphaFold job.
# It is compared against summed sequence lengths (residue counts) when counting and filtering pairs.
AF_TOKEN_LIMIT = 5120

### Read in Proteome

In [None]:
# Load the proteome table and keep only entries whose 'Reviewed' column equals 'reviewed'.
all_uniprot = pd.read_csv(PROTEOME_PATH, low_memory=False, sep='\t')
uniprot_filtered = all_uniprot[all_uniprot['Reviewed'] == 'reviewed']
# uniprot_filtered = all_proteins

### Create Armadillo dataset

In [None]:
# Filter for Armadillos: a protein is kept if it matches ANY of three signals
# (ARM repeat annotation, an armadillo InterPro entry, or an armadillo Pfam family).

## all the interpro IDs that have the word "armadillo" in the description
arm_accs_ipr = [
    'IPR013636',
    'IPR024574',
    'IPR041322',
    'IPR055241',
    'IPR056252',
    'IPR016617',
    'IPR031524',
    'IPR038739',
    'IPR038905',
    'IPR039868',
    'IPR040268',
    'IPR042462',
    'IPR042834',
    'IPR043379',
    'IPR044282',
    'IPR051303',
    'IPR052441',
    'IPR011989',
    'IPR016024',
    'IPR000225',
    'IPR041209',
    'IPR049152',
    'IPR006911'
]

# all pfam accessions that have "armadillo" in the description (versioned)
arm_accs_pfam = [
    'PF00514.29',
    'PF17822.6',
    'PF08427.15',
    'PF15767.10',
    'PF22915.1',
    'PF04826.19',
    'PF23295.1',
    'PF16629.10',
    'PF18770.7',
    'PF21052.3',
    'PF11841.14',
    'PF14726.11',
    'PF18581.7'
]

# UniProt's Pfam column has no version suffix, so keep only the part before the first "."
arm_accs_pfam = [versioned.split(".")[0] for versioned in arm_accs_pfam]

# Signal 1: "ARM" appears in the Repeat annotation (NA cells never match)
repeat_mask_arm = uniprot_filtered['Repeat'].apply(lambda cell: pd.notna(cell) and "ARM" in str(cell))
print(f"Proteins with 'ARM' in Repeat column: {int(repeat_mask_arm.sum())}")

# Signal 2: any armadillo InterPro accession in the InterPro column
interpro_mask_arm = uniprot_filtered['InterPro'].apply(contains_any_annotation, args=(arm_accs_ipr,))
print(f"Proteins with specified InterPro annotation: {int(interpro_mask_arm.sum())}")

# Signal 3: any armadillo Pfam accession in the Pfam column
pfam_mask_arm = uniprot_filtered['Pfam'].apply(contains_any_annotation, args=(arm_accs_pfam,))
print(f"Proteins with specified Pfam annotation: {int(pfam_mask_arm.sum())}")

# union of the three signals
armadillo_proteins = uniprot_filtered[repeat_mask_arm | interpro_mask_arm | pfam_mask_arm]

print(f"Found {len(armadillo_proteins)} proteins with armadillo domains")

In [None]:
# Export the armadillo accessions, one per line, without index or header.
armadillo_proteins['Entry'].to_csv('armadillo_proteins_entries.txt', index=False, header=False)

### Create Transcription Factor dataset (UniProtFilter)

In [None]:
# def str_tf(x):
#     return "transcription factor" in x.lower()

# def str_t(x):
#     return "transcription" in x.lower()

# # IPR accessions containing "transcription"
# IPR_entries = pd.read_csv("../entry.list", sep="\t")
# tf_accs_ipr = IPR_entries[IPR_entries['ENTRY_NAME'].apply(lambda x: str_t(x))]["ENTRY_AC"].tolist()

# # PFAM accessions containing "transcription factor"
# PFAM_entries = pd.read_csv("../data/pfam_parsed_data.csv", sep=",")
# tf_accs_pfam = PFAM_entries[PFAM_entries['DE'].apply(lambda x: str_t(x))]["AC"].tolist()

# # strip version number since it is not included in the uniprot annotation
# for i in range(len(tf_accs_pfam)):
#     tf_accs_pfam[i] = tf_accs_pfam[i].split(".")[0]

# interpro_mask_tf = reviewed_proteins['InterPro'].apply(lambda x: contains_any_annotation(x, tf_accs_ipr))
# print(f"Proteins with specified InterPro annotation: {len(reviewed_proteins[interpro_mask_tf])}")

# pfam_mask_tf = reviewed_proteins['Pfam'].apply(lambda x: contains_any_annotation(x, tf_accs_pfam))
# print(f"Proteins with specified Pfam annotation: {len(reviewed_proteins[pfam_mask_tf])}")

# txt_mask_tf = reviewed_proteins['Protein names'].apply(lambda x: str_tf(x))
# print(f"Proteins with 'Transcription factor' in the name: {len(reviewed_proteins[txt_mask_tf])}")


# # Combine filters with OR operation
# tf_proteins_uniprot_ds = reviewed_proteins[interpro_mask_tf | pfam_mask_tf | txt_mask_tf]

# print(f"Found {len(tf_proteins_uniprot_ds)} proteins with transcription factor annotation")

### Use existing Transcription Factor dataset

In [None]:
# Curated TF table; keep only rows flagged 'Yes' in the 'Is TF?' column.
human_TFs = pd.read_csv(TF_DATASET_PATH)
human_TFs = human_TFs[human_TFs['Is TF?'] == 'Yes']
len(human_TFs)  # last expression, shown as the notebook cell output

In [None]:
# Ensembl -> UniProt mapping; restrict to rows whose db_name is 'Uniprot/SWISSPROT'.
ensembl_mapping = pd.read_csv(ENSEMBL_MAPPING_PATH, sep='\t')
ensembl_mapping_swissProt = ensembl_mapping[ensembl_mapping['db_name'] == 'Uniprot/SWISSPROT']

In [None]:
human_TFs_gids = human_TFs['Ensembl ID'].tolist()

# Map curated TF Ensembl gene IDs to UniProt accessions, using Swiss-Prot
# cross-references only. `isin` does an exact, hashed lookup instead of
# scanning the whole gene list for every mapping row.
# NOTE(review): the mapping file may hold several rows per gene (e.g. one per
# transcript), so this list can contain repeats. The membership filter below
# is unaffected, but the printed count may exceed the number of distinct accessions.
human_TF_uniprot_accs = ensembl_mapping_swissProt.loc[
    ensembl_mapping_swissProt['gene_stable_id'].isin(human_TFs_gids), 'xref'
].tolist()
print(len(human_TF_uniprot_accs))

# Exact accession match. The previous substring test could accept an entry
# whose accession merely contains a TF accession (a 6-character accession can
# occur inside a 10-character one). `.copy()` gives an independent frame, so
# the column assignments in later cells do not raise SettingWithCopyWarning.
tf_proteins_curated_ds = uniprot_filtered[
    uniprot_filtered['Entry'].isin(human_TF_uniprot_accs)
].copy()
print(len(tf_proteins_curated_ds))

### Filter disordered transcription factors

#### AIUPred

##### Annotation

In [None]:
# API requests (only necessary once, see cache file AIUPRED_PATH)

# url = 'https://aiupred.elte.hu/rest_api'

# AIUPred_data = []
# c = 0
# for acc in tf_proteins_curated_ds['Entry'].tolist():
#     data = {'accession': acc, 'smoothing': 'default'}
#     response = requests.get(url, params=data)
#     if response.status_code == 200:
#         AIUPred_data.append(json.loads(response.text))
#         c += 1
#         if c % 100 == 0:
#             print(c)
#     else:
#         print(f"Failed to fetch data for accession {acc}: {response.status_code}")
        
# with open('AIUPred_data.json', 'a') as json_file:
#     json.dump(AIUPred_data, json_file, indent=2)

In [None]:
# Load the cached AIUPred responses (written by the commented-out API cell above);
# each JSON record becomes one DataFrame row.
with open(AIUPRED_PATH, 'r') as file:
    AIUPred_df = pd.DataFrame(json.load(file))

##### Filter

In [None]:
from typing import List, Tuple

def find_subranges(data: List[float], threshold: float, min_length: int) -> List[Tuple[int, int]]:
    """Locate maximal runs of consecutive values that are >= threshold.

    A run is reported only if it spans at least ``min_length`` positions. Any
    value that does not satisfy ``value >= threshold`` (including NaN)
    terminates a run.

    Args:
        data (List[float]): Per-position scores to scan.
        threshold (float): Inclusive lower bound for a position to belong to a run.
        min_length (int): Minimum number of positions a run needs to be kept.

    Returns:
        List[Tuple[int, int]]: ``(first, last)`` index pairs (both inclusive)
        of the qualifying runs, in ascending order.
    """
    values = list(data)  # positional access regardless of the input container
    total = len(values)
    runs: List[Tuple[int, int]] = []

    pos = 0
    while pos < total:
        # written as `not >=` (rather than `<`) so NaN is treated as "outside a run"
        if not values[pos] >= threshold:
            pos += 1
            continue
        stop = pos
        while stop < total and values[stop] >= threshold:
            stop += 1
        if stop - pos >= min_length:
            runs.append((pos, stop - 1))
        pos = stop

    return runs

In [None]:
# Sanity checks for find_subranges (comparison is >= threshold, end index inclusive):
# 1) one run spanning indices 2..7;  2) min_length longer than any run;
# 3) empty input;  4) two qualifying runs, the second reaching the end of the data
#    (the length-2 run at indices 9..10 is dropped).
assert find_subranges([1,1,2,4,5,6,2,3,1], 2, 3) == [(2,7)]
assert find_subranges([1,1,2,4,5,6,2,3,1], 2, 20) == []
assert find_subranges([], 2, 20) == []
assert find_subranges([1,2,2,2,3,3,3,3,2,3,3,2,1,3,3,3], 3, 3) == [(4,7), (13,15)]

In [None]:
# Minimum number of consecutive positions scoring >= the predictor threshold for a
# stretch to count as a disordered region (passed to find_subranges as min_length).
MIN_LENGTH_DISORDERED_REGION = 20
# Score threshold for AIUPred predictions (only used by the commented-out filter below)
AIUPRED_THRESHOLD = 0.9

# AIUPred_df['ind_disordered_regions'] = AIUPred_df['AIUPred'].apply(lambda x: find_subranges(x, AIUPRED_THRESHOLD, MIN_LENGTH_DISORDERED_REGION))
# AIUPred_df['num_disordered_regions'] = AIUPred_df['ind_disordered_regions'].apply(len)
# tf_proteins_curated_ds_AIUpred = tf_proteins_curated_ds.merge(AIUPred_df, left_on='Entry', right_on='accession', how='inner')
# tf_proteins_curated_ds_AIUpred_diso = tf_proteins_curated_ds_AIUpred[tf_proteins_curated_ds_AIUpred['num_disordered_regions'] > 0]

# print(f'Num proteins before disorder filter: {len(tf_proteins_curated_ds_AIUpred)}')
# print(f'Num proteins after disorder filter: {len(tf_proteins_curated_ds_AIUpred_diso)}')


#### IUPred 3

In [None]:
import sys
sys.path.append(IUPRED3_PATH)
import iupred3_lib

# sequence = tf_proteins_curated_ds['Sequence'].iloc[0]
# print(sequence)
# iupred3_result = iupred3_lib.iupred(sequence, 'long', smoothing='no')
# print(iupred3_result[0])

In [None]:
# TODO: caching -- IUPred3 is recomputed for every TF sequence each time this cell runs
IUPRED3_THRESHOLD = 0.5

# Per-sequence IUPred3 'long' prediction without smoothing; element [0] of the result is
# stored (assumed to be the per-residue score list -- TODO confirm against iupred3_lib).
tf_proteins_curated_ds['iupred3'] = tf_proteins_curated_ds['Sequence'].apply(lambda x: iupred3_lib.iupred(x, 'long', smoothing='no')[0])

# Number of stretches of >= MIN_LENGTH_DISORDERED_REGION positions with score >= IUPRED3_THRESHOLD
tf_proteins_curated_ds['num_disordered_regions'] = tf_proteins_curated_ds['iupred3'].apply(lambda x: len(find_subranges(x, IUPRED3_THRESHOLD, MIN_LENGTH_DISORDERED_REGION)))

# Keep only TFs with at least one such stretch; this is the TF set paired with the armadillos later.
tf_proteins_curated_ds_IUPred3_diso = tf_proteins_curated_ds[tf_proteins_curated_ds['num_disordered_regions'] > 0]

print(f"Number of transcription factors with at least one disordered region (IUPred3, threshold={IUPRED3_THRESHOLD}, min length={MIN_LENGTH_DISORDERED_REGION}): {len(tf_proteins_curated_ds_IUPred3_diso)}")

#### Disprot

In [None]:
disprot_df = pd.read_csv(DISPROT_PATH, sep='\t')

# Make the IDs look like the UniProt 'DisProt' column, which stores them with a
# trailing ';'. Vectorised concatenation leaves a missing ID as NaN instead of
# raising TypeError the way `x + ';'` on a float NaN would.
disprot_df['disprot_id'] = disprot_df['disprot_id'] + ';'

tf_disprot_ids = tf_proteins_curated_ds['DisProt'].dropna().tolist()

# DisProt rows belonging to the TF set. `isin` is a hashed lookup rather than
# a linear scan of tf_disprot_ids for every DisProt row.
disprot_tfs = disprot_df[disprot_df['disprot_id'].isin(tf_disprot_ids)]

In [None]:
# Attach DisProt annotations to each TF via the UniProt 'DisProt' column (this relies on
# disprot_id carrying the trailing ';' added in the previous cell). how='left' keeps TFs
# without a DisProt entry (their DisProt columns become NaN). If disprot_df holds several
# rows for one disprot_id, the matching TF row is repeated once per such row.
tf_proteins_curated_ds_disprot = tf_proteins_curated_ds.merge(disprot_df, how='left', left_on='DisProt', right_on='disprot_id')

### Create all pairs

In [None]:
def create_all_pairs(arm_df: pd.DataFrame, tf_df: pd.DataFrame) -> pd.DataFrame:
    """Create a dataframe with all possible pairs from the cartesian product of arm_df and tf_df.

    Every row of ``arm_df`` is combined with every row of ``tf_df`` (full cross
    join). Columns present in both inputs get the suffixes ``_arm`` and ``_tf``.

    Args:
        arm_df (pd.DataFrame): Armadillo proteins; must contain an 'Entry' column.
        tf_df (pd.DataFrame): Transcription factor proteins; must contain an 'Entry' column.

    Returns:
        pd.DataFrame: One row per (armadillo, TF) combination, plus a ``pair_id``
        column holding ``str(tuple(sorted([ENTRY_ARM, ENTRY_TF])))`` with both
        accessions upper-cased. The key is therefore independent of pair
        orientation and matches the keys used by the STRING/IntAct joins. Empty
        inputs yield an empty DataFrame that still has a ``pair_id`` column.
    """
    # Dummy constant key so that an ordinary merge produces the cartesian product.
    join_key = '_cross_join_key'
    arm_df_temp = arm_df.copy()
    tf_df_temp = tf_df.copy()
    arm_df_temp[join_key] = 1
    tf_df_temp[join_key] = 1

    pairs_df = pd.merge(arm_df_temp, tf_df_temp, on=join_key, suffixes=('_arm', '_tf'))
    pairs_df = pairs_df.drop(columns=join_key)

    # Built with a comprehension instead of DataFrame.apply(axis=1). It is much
    # faster on a large cross product, and apply() on an empty frame returns a
    # DataFrame that cannot be assigned to a single column.
    pairs_df['pair_id'] = [
        str(tuple(sorted([arm_entry.upper(), tf_entry.upper()])))
        for arm_entry, tf_entry in zip(pairs_df['Entry_arm'], pairs_df['Entry_tf'])
    ]

    print(f"Created {len(pairs_df)} possible protein pairs between {len(arm_df)} armadillo proteins and {len(tf_df)} transcription factors")

    return pairs_df

In [None]:
# Cross product of all armadillo proteins with the TFs that passed the IUPred3 disorder filter.
all_pairs = create_all_pairs(armadillo_proteins, tf_proteins_curated_ds_IUPred3_diso)

In [None]:
# Count pairs whose summed sequence length exceeds the AlphaFold limit (True values sum as 1).
num_pairs_over_token_limit = int(((all_pairs['Length_arm'] + all_pairs['Length_tf']) > AF_TOKEN_LIMIT).sum())
print(f"Pairs over token limit: {num_pairs_over_token_limit} ({(num_pairs_over_token_limit/len(all_pairs))*100}%)")

### STRING

In [None]:
# read in the STRING file
# note that the file is a STRING database dump preprocessed with the scripts in /src/STRING
# it should contain columns p1_Uniprot, p2_Uniprot and pair_id; pair_id must use the same
# format as create_all_pairs, because the next cell joins on it
string_df = pd.read_csv(STRING_PATH, sep=',')

In [None]:
# annotate the all_pairs df with the STRING scores
# IMPORTANT: drop rows that don't have a matching STRING entry
all_pairs_w_STRING = pd.merge(all_pairs, string_df, on='pair_id', how='inner')

# Keep the pairs without a STRING match (for inspection) and report coverage.
# The percentage below is merged rows / all_pairs rows. If string_df has several rows
# for one pair_id, the merged count can exceed the number of distinct matched pairs.
# Division by zero if all_pairs is empty.
unmatched_pairs = all_pairs[~all_pairs['pair_id'].isin(all_pairs_w_STRING['pair_id'])]
num_all_pairs = len(all_pairs)
num_all_pairs_w_STRING = len(all_pairs_w_STRING)
print(f"Number of pairs in all_pairs: {num_all_pairs}")
print(f"Number of pairs successfully merged with STRING data: {num_all_pairs_w_STRING} ({(num_all_pairs_w_STRING/num_all_pairs)*100}%)")

### IntAct

In [None]:
# IntAct human interactions (tab-separated; the column names used below look like MITAB headers).
# NOTE(review): this path is hard-coded here rather than defined with the other *_PATH constants.
intact_df = pd.read_csv('../../data/IntAct/human/human.txt', sep='\t')

In [None]:
import re

# clean intact dataset:
# create new column 'intact_score' that has the score from 'Confidence value(s)' parsed
# drop unnecessary columns

def intact_score_filter(x: str) -> float:
    """Parse the IntAct MI score from a confidence-value string.

    Args:
        x (str): Confidence value cell, expected to contain a token like
            ``intact-miscore:<number>`` (possibly among other '|'-separated
            scores). Non-string values, e.g. NaN from an empty cell, are accepted.

    Returns:
        float: The extracted IntAct miscore, or 0.0 if no score is present or
        ``x`` is not a string.
    """
    if not isinstance(x, str):
        # Missing cells arrive as float NaN; re.search would raise TypeError on them.
        return 0.0
    # Accept integer scores ("1") as well as decimal ones ("0.56"). The previous
    # pattern required a '.', so "intact-miscore:1" silently parsed as 0.0.
    match = re.search(r"(?<=intact-miscore:)\d+(?:\.\d*)?", x)
    return float(match.group()) if match else 0.0

print(len(intact_df))

# Columns not needed downstream. drop() returns a new frame, so intact_df itself is left untouched.
unused_intact_columns = [
    'Alias(es) interactor A',
    'Alias(es) interactor B',
    'Interaction detection method(s)',
    'Publication 1st author(s)',
    'Publication Identifier(s)',
    'Taxid interactor A',
    'Taxid interactor B',
    'Biological role(s) interactor A',
    'Biological role(s) interactor B',
    'Experimental role(s) interactor A',
    'Experimental role(s) interactor B',
    'Type(s) interactor A',
    'Type(s) interactor B',
    'Xref(s) interactor A',
    'Xref(s) interactor B',
    'Interaction Xref(s)',
    'Annotation(s) interactor A',
    'Annotation(s) interactor B',
    'Interaction annotation(s)',
    'Host organism(s)',
    'Interaction parameter(s)',
    'Creation date',
    'Update date',
    'Checksum(s) interactor A',
    'Checksum(s) interactor B',
    'Interaction Checksum(s)',
    'Feature(s) interactor A',
    'Feature(s) interactor B',
    'Stoichiometry(s) interactor A',
    'Stoichiometry(s) interactor B',
    'Identification method participant A',
    'Identification method participant B',
    'Expansion method(s)',
]
intact_cleaned = intact_df.drop(columns=unused_intact_columns)
# intact_cleaned = intact_df[intact_df['Interaction type(s)'].apply(lambda x: 'direct interaction' in x)]

# Numeric IntAct MI score parsed out of the confidence string
intact_cleaned['intact_score'] = intact_cleaned['Confidence value(s)'].apply(intact_score_filter)
print(len(intact_cleaned))

In [None]:
# Order-independent pair key in the same format as create_all_pairs:
# str(tuple(sorted([...]))) of upper-cased accessions with the 'uniprotkb:' prefix
# removed. A comprehension is used instead of DataFrame.apply(axis=1) because it is
# much faster, and because apply() on an empty frame returns a DataFrame that cannot
# be assigned to a single column.
intact_cleaned['pair_id'] = [
    str(tuple(sorted([id_a.replace('uniprotkb:', '').upper(), id_b.replace('uniprotkb:', '').upper()])))
    for id_a, id_b in zip(intact_cleaned['#ID(s) interactor A'], intact_cleaned['ID(s) interactor B'])
]

In [None]:
# Keep one row per pair_id: rows are sorted by intact_score descending, so the highest-scoring row is kept.
intact_cleaned = intact_cleaned.sort_values('intact_score', ascending=False).drop_duplicates('pair_id', keep='first')
print(f"Number of unique pairs after deduplication: {len(intact_cleaned)}")

In [None]:
# Inner join: keep only candidate pairs that also occur in the deduplicated IntAct table.
all_pairs_w_IntAct = pd.merge(all_pairs, intact_cleaned, on='pair_id', how='inner')

# Coverage report (division by zero if all_pairs is empty)
num_all_pairs = len(all_pairs)
num_all_pairs_w_IntAct = len(all_pairs_w_IntAct)
print(f"Number of pairs in all_pairs: {num_all_pairs}")
print(f"Number of pairs successfully merged with IntAct data: {num_all_pairs_w_IntAct} ({(num_all_pairs_w_IntAct/num_all_pairs)*100}%)")

In [None]:
# Pairs supported by both STRING and IntAct. Both inputs derive from all_pairs, so their
# shared column names receive pandas' default '_x' / '_y' merge suffixes.
all_pairs_intersect_STRING_IntAct = pd.merge(all_pairs_w_STRING, all_pairs_w_IntAct, on='pair_id', how='inner')

print(len(all_pairs_intersect_STRING_IntAct))

In [None]:
# Pairs supported by STRING or IntAct (outer join); columns from the side without a match are NaN.
all_pairs_union_STRING_IntAct = pd.merge(all_pairs_w_STRING, all_pairs_w_IntAct, on='pair_id', how='outer')

print(len(all_pairs_union_STRING_IntAct))


### Create Structure Dataset

In [None]:
def print_to_fasta(id: str, seq: str, path: str, comment: str = '') -> None:
    """Write a single-record FASTA file named ``<id>.fasta`` into ``path``.

    The header line is ``>id``, or ``>id|comment`` when ``comment`` is truthy.
    The whole sequence follows on one line. ``path`` is created if it does not
    exist, and an existing file with the same name is overwritten.

    Args:
        id (str): Record identifier; used both as the file stem and the header.
        seq (str): Amino acid sequence to write.
        path (str): Target directory.
        comment (str, optional): Extra header text appended after '|'. Defaults to ''.

    Returns:
        None

    Raises:
        OSError: If the directory cannot be created or the file cannot be written.
    """
    import os

    os.makedirs(path, exist_ok=True)
    header_line = f">{id}|{comment}" if comment else f">{id}"
    with open(os.path.join(path, f"{id}.fasta"), 'w') as fasta_handle:
        fasta_handle.write(f"{header_line}\n{seq}\n")

In [None]:
# Write one FASTA file per armadillo protein; the row's 'Reviewed' value is appended to the header.
for idx, row in armadillo_proteins.iterrows():
    print_to_fasta(row['Entry'], row['Sequence'], '../../production1/arm_all_uniprot_rev_fasta', row['Reviewed'])
# for idx, row in tf_proteins_curated_ds_IUPred3_diso.iterrows():
#     print_to_fasta(row['Entry'], row['Sequence'], '../../production1/tf_all_uniprot_rev_fasta', row['Reviewed'])

## 2. Create AF job files
- Create job files for AlphaFold
- Skip jobs that have already been created (no duplicates)

### Imports

In [None]:
from typing import List, Dict, Any, Union, Tuple
import math, random
import os
import json
import copy

### Helper functions

In [None]:
def write_af_jobs_to_individual_files(af_jobs: List[Dict[str, Any]], output_dir: str) -> None:
    """Write each AlphaFold job to its own JSON file in a freshly created directory.

    Each job is saved as ``<output_dir>/<job['name']>.json``. The file holds a
    one-element JSON list, which the AlphaFold parser recognises as the
    alphafoldserver dialect.

    To avoid mixing a new batch with an old one, the function refuses to
    reuse an existing directory. If ``output_dir`` already exists it prints an
    error and returns without writing anything; it does NOT raise.

    Args:
        af_jobs (List[Dict[str, Any]]): AlphaFold job dictionaries; each needs a 'name' key.
        output_dir (str): Directory to create and populate; it must not exist yet.

    Returns:
        None

    Raises:
        OSError: If the directory cannot be created or a job file cannot be written.
        KeyError: If a job has no 'name' key.
    """
    if os.path.exists(output_dir):
        print(f"ERROR: output directory '{output_dir}' already exists!")
        print("Aborting.")
        return
    os.makedirs(output_dir, exist_ok=False)
    for job in af_jobs:
        file_path = os.path.join(output_dir, f"{job['name']}.json")
        with open(file_path, 'w') as f:
            json.dump([job], f, indent=2)  # use [] so AF parser knows it's in alphafoldserver dialect

def sort_rec(obj: Union[List[Any], Dict[str, Any], Any]) -> Any:
    """Build a canonical, order-insensitive form of a nested JSON-like object.

    Lists become sorted lists of their canonicalised items. Dicts become sorted
    lists of ``(key, canonical_value)`` tuples. Every other value is returned
    unchanged. Two jobs that differ only in key order or list order therefore
    canonicalise to equal objects.

    Args:
        obj (Union[List[Any], Dict[str, Any], Any]): Value to canonicalise.

    Returns:
        Any: The canonical form (new containers; the input is not modified).

    Raises:
        TypeError: If a list holds items that cannot be ordered against each other.
    """
    if isinstance(obj, list):
        return sorted(map(sort_rec, obj))
    if isinstance(obj, dict):
        return sorted([(key, sort_rec(value)) for key, value in obj.items()])
    return obj
    
def get_comparable_job(job_data: Dict[str, Any], deep_copy: bool = True) -> Any:
    """Create a comparable representation of the job by removing the name field and sorting the other fields.

    The result is the ``sort_rec`` canonical form of the job without its
    'name'. Two jobs with the same content therefore compare equal regardless
    of their names or of key/list order.

    Args:
        job_data (Dict[str, Any]): AlphaFold job dictionary.
        deep_copy (bool, optional): If True (default), ``job_data`` is left
            untouched. If False, 'name' is removed from ``job_data`` in place.

    Returns:
        Any: Comparable (canonicalised) representation of the job.
    """
    if deep_copy:
        # sort_rec builds brand-new containers and never mutates its input, so a
        # shallow dict without 'name' is enough to keep job_data untouched. There
        # is no need to deep-copy the (potentially very long) sequences.
        comparable = {key: value for key, value in job_data.items() if key != 'name'}
    else:
        # Caller opted out of copying: drop 'name' in place, as before.
        comparable = job_data
        comparable.pop('name', None)
    return sort_rec(comparable)

def collect_created_jobs(results_dir: str) -> List[Dict[str, Any]]:
    """Collect all jobs stored in the top-level .json files of a directory.

    Each file may hold either a JSON list of jobs (the alphafoldserver dialect
    written by write_af_jobs_to_individual_files) or a single job object. Files
    are visited in sorted name order, so the result is deterministic.
    Unreadable or malformed files are reported and skipped. Subdirectories are
    not searched.

    Args:
        results_dir (str): Directory containing AlphaFold job files.

    Returns:
        List[Dict[str, Any]]: The collected AlphaFold job dictionaries.

    Raises:
        OSError: If the directory itself cannot be listed.
    """
    collected_jobs: List[Dict[str, Any]] = []

    # os.listdir order is arbitrary; sort for a reproducible result.
    for file_name in sorted(os.listdir(results_dir)):
        file_path = os.path.join(results_dir, file_name)
        if not (file_name.endswith('.json') and os.path.isfile(file_path)):
            continue
        try:
            with open(file_path, 'r') as f:
                content = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, IOError) as e:
            print(f"Error reading {file_name}: {e}")
            continue

        if isinstance(content, list):
            collected_jobs.extend(content)
        elif isinstance(content, dict):
            # A bare object is one job. `list += dict` would instead append the
            # dict's KEYS as strings, which later breaks get_comparable_job.
            collected_jobs.append(content)
        else:
            print(f"Skipping {file_name}: expected a JSON list or object, got {type(content).__name__}")
    return collected_jobs

def create_alphafold_job(job_name: str, sequence1: str, sequence2: str, dialect: str = 'alphafoldserver') -> Dict[str, Any]:
    """Build an AlphaFold job describing a heterodimer of two protein chains.

    Each sequence becomes one 'proteinChain' entry with ``count`` 1, in the
    order given. 'modelSeeds' is left empty and 'version' is fixed to 1.

    Args:
        job_name (str): Job name, typically "protein1_start-end_protein2_start-end".
        sequence1 (str): Amino acid sequence of the first chain.
        sequence2 (str): Amino acid sequence of the second chain.
        dialect (str, optional): Value for the 'dialect' field. Defaults to 'alphafoldserver'.

    Returns:
        Dict[str, Any]: Job dict with keys name, modelSeeds, sequences, dialect, version (in that order).
    """
    chain_entries = [
        {'proteinChain': {'sequence': chain_seq, 'count': 1}}
        for chain_seq in (sequence1, sequence2)
    ]
    return dict(
        name=job_name,
        modelSeeds=[],
        sequences=chain_entries,
        dialect=dialect,
        version=1,
    )

def create_job_batch_scoreCategories(pair_df: pd.DataFrame, batch_size: int, categories: List[Tuple[float, float]], 
                                job_dirs: List[str], column_name: str, token_limit: int = 5120) -> List[Dict[str, Any]]:
    """Create a new batch of up to batch_size jobs from pair_df, stratified by score category.

    Jobs that already exist in any of ``job_dirs`` are never created again;
    jobs are compared by content and the job name is ignored. Each category is
    an inclusive ``(min_score, max_score)`` range on ``column_name``. For each
    one, roughly ``batch_size / len(categories)`` new jobs are drawn at random,
    without replacement, from the pairs in that range. If a category cannot
    fill its quota, the shortfall is redistributed to the categories after it.
    Pairs whose combined sequence length exceeds ``token_limit`` are skipped.
    Rows with NaN in ``column_name`` fall into no category.

    Note: leftover quota only flows forward, so list the category with the
    largest number of possible jobs last.

    Args:
        pair_df (pd.DataFrame): Protein pairs with columns 'Entry_arm', 'Entry_tf',
            'Sequence_arm', 'Sequence_tf' and ``column_name``. Any index is supported.
        batch_size (int): Total number of jobs to create across all categories.
        categories (List[Tuple[float, float]]): (min_score, max_score) ranges, inclusive.
        job_dirs (List[str]): Directories searched for existing jobs to avoid duplicates.
        column_name (str): Column in pair_df holding the score to stratify on.
        token_limit (int, optional): Maximum combined sequence length of a job. Defaults to 5120.

    Returns:
        List[Dict[str, Any]]: Newly created AlphaFold job dictionaries. There may
        be fewer than batch_size if the categories run out of eligible pairs.

    Raises:
        KeyError: If required columns are missing from pair_df.
    """
    new_jobs: List[Dict[str, Any]] = []

    # Canonical forms of all jobs already on disk; jobs created here are appended too.
    prev_jobs = []
    for job_dir in job_dirs:
        prev_jobs += collect_created_jobs(job_dir)
    prev_jobs = [get_comparable_job(existing, deep_copy=False) for existing in prev_jobs]

    total_created = 0
    num_categories = len(categories)
    category_counts: Dict[str, int] = {}  # jobs created per category, for the summary

    for min_score, max_score in categories:
        # Stop if we've already created enough jobs
        if total_created >= batch_size:
            break

        # Split the remaining budget evenly over this and the following categories,
        # so a shortfall in an earlier category is absorbed by later ones.
        cat_quota = math.floor((batch_size - total_created) / num_categories)
        num_categories -= 1

        category_key = f"{min_score}-{max_score}"
        category_counts[category_key] = 0
        created_in_category = 0

        possible_pairs = pair_df[(pair_df[column_name] >= min_score) & (pair_df[column_name] <= max_score)]

        # Visit candidates in random order without replacement. Positions (not
        # index labels) are shuffled and read with .iloc on the filtered frame.
        # The old code passed index labels to pair_df.iloc, which picks the wrong
        # rows unless pair_df has a 0..n-1 RangeIndex. Shuffling once also
        # replaces an O(n) list.remove per draw.
        candidate_positions = list(range(len(possible_pairs)))
        random.shuffle(candidate_positions)

        for pos in candidate_positions:
            if created_in_category >= cat_quota or total_created >= batch_size:
                break

            row = possible_pairs.iloc[pos]
            armadillo_entry = row['Entry_arm']
            tf_entry = row['Entry_tf']
            armadillo_sequence = row['Sequence_arm']
            tf_sequence = row['Sequence_tf']

            if len(tf_sequence) + len(armadillo_sequence) > token_limit:
                continue

            # Full-length chains: residues 1..len of each sequence
            job_name = f"{armadillo_entry}_1-{len(armadillo_sequence)}_{tf_entry}_1-{len(tf_sequence)}"
            job = create_alphafold_job(job_name, armadillo_sequence, tf_sequence)

            # skip jobs whose content was already created earlier (name ignored)
            job_comparable = get_comparable_job(job)
            if job_comparable in prev_jobs:
                continue

            prev_jobs.append(job_comparable)
            new_jobs.append(job)
            created_in_category += 1
            total_created += 1
            category_counts[category_key] += 1

    # Print the number of jobs created in each category
    print(f"Created {len(new_jobs)} new jobs total.")
    print("Jobs created per category:")
    for category, count in category_counts.items():
        print(f"  Score range {category}: {count} jobs")

    return new_jobs

def create_job_batch_all_pairs(pair_df: pd.DataFrame, batch_size: int, 
                               job_dirs: List[str], token_limit: int = 5120) -> List[Dict[str, Any]]:
    """Create a new batch of AlphaFold jobs by randomly sampling from all protein pairs.

    Pairs are visited in random order without replacement until ``batch_size``
    new jobs exist or every pair has been tried. Pairs whose combined sequence
    length exceeds ``token_limit`` are skipped, as are jobs whose content
    already exists in ``job_dirs`` or was created earlier in this call.

    Args:
        pair_df (pd.DataFrame): DataFrame containing protein pairs with columns:
                               - 'Entry_arm': Armadillo protein entry ID
                               - 'Entry_tf': Transcription factor protein entry ID
                               - 'Sequence_arm': Armadillo protein sequence
                               - 'Sequence_tf': Transcription factor protein sequence
                               Any index is supported.
        batch_size (int): Total number of new jobs to create
        job_dirs (List[str]): List of directories to search for existing jobs to avoid duplicates
        token_limit (int, optional): Maximum total sequence length allowed for a job.
                                   Pairs with combined sequence length exceeding this limit
                                   will be skipped. Defaults to 5120.

    Returns:
        List[Dict[str, Any]]: Newly created AlphaFold job dictionaries in alphafoldserver format
                             (see create_alphafold_job). There may be fewer than batch_size
                             if the eligible pairs run out; a warning is printed in that case.

    Note:
        - Job names follow the format: "{armadillo_entry}_1-{seq_length}_{tf_entry}_1-{seq_length}"
        - Duplicate jobs are detected by content (sequences etc.), ignoring the job name.
        - Running out of pairs no longer raises IndexError (which discarded every job
          created so far); the partial batch is returned instead.

    Raises:
        KeyError: If required columns are missing from pair_df
    """
    new_jobs: List[Dict[str, Any]] = []

    # Canonical forms of all jobs already on disk; jobs created here are appended too.
    prev_jobs = []
    for job_dir in job_dirs:
        prev_jobs += collect_created_jobs(job_dir)
    prev_jobs = [get_comparable_job(existing, deep_copy=False) for existing in prev_jobs]

    # Random order without replacement, over row positions. The old code passed
    # index labels to .iloc, which picks the wrong rows unless pair_df has a
    # 0..n-1 RangeIndex. Shuffling once also replaces an O(n) list.remove per draw.
    candidate_positions = list(range(len(pair_df)))
    random.shuffle(candidate_positions)

    total_created = 0
    for pos in candidate_positions:
        if total_created >= batch_size:
            break

        row = pair_df.iloc[pos]
        armadillo_entry = row['Entry_arm']
        tf_entry = row['Entry_tf']
        armadillo_sequence = row['Sequence_arm']
        tf_sequence = row['Sequence_tf']

        if len(tf_sequence) + len(armadillo_sequence) > token_limit:
            continue

        # Full-length chains: residues 1..len of each sequence
        job_name = f"{armadillo_entry}_1-{len(armadillo_sequence)}_{tf_entry}_1-{len(tf_sequence)}"
        job = create_alphafold_job(job_name, armadillo_sequence, tf_sequence)

        # skip jobs whose content was already created earlier (name ignored)
        job_comparable = get_comparable_job(job)
        if job_comparable in prev_jobs:
            continue

        prev_jobs.append(job_comparable)
        new_jobs.append(job)
        total_created += 1

    if total_created < batch_size:
        print(f"WARNING: only {total_created} of {batch_size} requested jobs could be created; no eligible pairs left.")
    print(f"Created {len(new_jobs)} new jobs total.")

    return new_jobs

def create_job_batch_id_list(pair_df: pd.DataFrame, id_list: List[Tuple[str, str]],
                             job_dirs: List[str], token_limit: int = 5120) -> List[Dict[str, Any]]:
    """Create a batch of AlphaFold jobs from a specific list of protein ID pairs.

    For each ``(id_1, id_2)`` in ``id_list`` the single matching row of ``pair_df`` is
    looked up in either orientation (``id_1`` as armadillo / ``id_2`` as TF, or the
    reverse). A full-length job is built from the two sequences. The job is kept only
    if no equivalent job (as judged by ``get_comparable_job``) exists in ``job_dirs`` or
    was already created earlier in this call, so repeated pairs in ``id_list`` yield a
    single job.

    Args:
        pair_df (pd.DataFrame): DataFrame containing protein pairs with columns:
                               - 'Entry_arm': Armadillo protein entry ID
                               - 'Entry_tf': Transcription factor protein entry ID
                               - 'Sequence_arm': Armadillo protein sequence
                               - 'Sequence_tf': Transcription factor protein sequence
        id_list (List[Tuple[str, str]]): Protein ID pairs to create jobs for (either order).
        job_dirs (List[str]): Directories to search for existing jobs to avoid duplicates.
        token_limit (int, optional): Maximum total sequence length allowed for a job.
                                   Pairs whose combined sequence length exceeds this
                                   limit are skipped (with a message). Defaults to 5120.

    Returns:
        List[Dict[str, Any]]: Newly created AlphaFold job dictionaries, in ``id_list`` order.

    Raises:
        ValueError: If no row, or more than one row, of ``pair_df`` matches a pair.
            ``ValueError`` subclasses ``Exception``, so existing ``except Exception``
            handlers keep working.
        KeyError: If required columns are missing from pair_df.
    """
    # Gather every job already written to disk, then normalise each one so that it can
    # be compared against newly generated jobs.
    collected = []
    for job_dir in job_dirs:
        collected.extend(collect_created_jobs(job_dir))
    prev_jobs = [get_comparable_job(existing, deep_copy=False) for existing in collected]

    new_jobs = []
    for id_1, id_2 in id_list:
        # Accept the pair in either orientation.
        forward = (pair_df['Entry_arm'] == id_1) & (pair_df['Entry_tf'] == id_2)
        backward = (pair_df['Entry_arm'] == id_2) & (pair_df['Entry_tf'] == id_1)
        row_matches = pair_df[forward | backward]
        if len(row_matches) == 0:
            raise ValueError(f"No matching row found for pair ({id_1}, {id_2}) in pair_df.")
        if len(row_matches) > 1:
            raise ValueError(f"Multiple matching rows found for pair ({id_1}, {id_2}) in pair_df.")

        row = row_matches.iloc[0]
        armadillo_entry = row['Entry_arm']
        tf_entry = row['Entry_tf']
        armadillo_sequence = row['Sequence_arm']
        tf_sequence = row['Sequence_tf']

        if len(tf_sequence) + len(armadillo_sequence) > token_limit:
            print(f"Skipping because of token limit: {armadillo_entry}-{tf_entry}")
            continue

        # Full-length constructs: both chains span residues 1..len(sequence).
        job_name = (f"{armadillo_entry}_1-{len(armadillo_sequence)}"
                    f"_{tf_entry}_1-{len(tf_sequence)}")
        job = create_alphafold_job(job_name, armadillo_sequence, tf_sequence)

        # Skip jobs equivalent to one that already exists (on disk or earlier in this call).
        job_comparable = get_comparable_job(job)
        if any(existing == job_comparable for existing in prev_jobs):
            continue
        prev_jobs.append(job_comparable)
        new_jobs.append(job)

    print(f"Created {len(new_jobs)} new jobs total.")

    return new_jobs

### create job files

In [None]:
# Column names holding the interaction-evidence scores used for job selection and plots:
# the STRING 'experimental' channel and the IntAct score column.
STRING_SCORE_COLUMN: str = 'experimental'
INTACT_SCORE_COLUMN: str = 'intact_score'

In [None]:
# Root folder holding all previously written AF job batches (one sub-folder per batch).
AF_JOB_BATCH_ROOT = '../../production1/AF_job_batches'
# Every existing batch folder; their jobs are used to avoid creating duplicates.
BATCH_DIRS = [
    os.path.join(AF_JOB_BATCH_ROOT, d)
    for d in os.listdir(AF_JOB_BATCH_ROOT)
    if 'batch' in d and os.path.isdir(os.path.join(AF_JOB_BATCH_ROOT, d))
]
BATCH_SIZE = 100

### STRING
# Note that the order of the categories is important.
# NOTE(review): an earlier comment said the very large (100,200) category comes last
# "to fill up the remaining jobs", but in this list it is followed by (0,100) and
# (200,300). Confirm which order is intended.
categories = [(900,1000), (800,900), (700,800), (600,700), (500,600), (400,500), (300,400), (100,200), (0,100), (200,300)]
new_af_jobs = create_job_batch_scoreCategories(all_pairs_w_STRING, BATCH_SIZE, categories, BATCH_DIRS, STRING_SCORE_COLUMN, AF_TOKEN_LIMIT)

### IntAct
# categories = [(0.1,0.2), (0.9,1), (0.8,0.9), (0.7,0.8), (0.6,0.7), (0.5,0.6), (0.4,0.5), (0.2,0.3), (0.3,0.4)]
# new_af_jobs = create_job_batch_scoreCategories(all_pairs_w_IntAct, BATCH_SIZE, categories, BATCH_DIRS, INTACT_SCORE_COLUMN, AF_TOKEN_LIMIT)

### all pairs
# new_af_jobs = create_job_batch_all_pairs(all_pairs, BATCH_SIZE, BATCH_DIRS, AF_TOKEN_LIMIT)


### ID list:
# Repeated pairs below are harmless: create_job_batch_id_list skips jobs equivalent
# to ones already created on disk or earlier in the same call.
# id_list_good = [
#     ("Q13285", "A0A2R8YCH5"),
#     ("P04637", "A0A8I5KU01"),
#     ("P04637", "A0A8I5KU01"),
#     ("Q9H3D4", "A0A8I5KU01"),
#     ("Q8NHM5", "A1YPR0"),
#     ("Q9UJU2", "A0A2R8YCH5"),
#     ("Q9UJU2", "A0A2R8YCH5"),
#     ("Q6SJ96", "O14981"),
#     ("Q9NRY4", "O00750"),
#     ("Q6ZRS2", "A0A8V8TQN3")
# ]

# id_list_complex = [
#     ("Q03181", "Q9H3U1"),
#     ("P04637", "A0A994J4J0"),
#     ("Q6SJ96", "O14981"),
#     ("P49450", "O14981"),
#     ("P49450", "O14981"),
#     ("P49450", "O14981"),
#     ("P49450", "O14981"),
#     ("Q9UBG7", "A0A8I5KU01"),
#     ("P19838", "A0A1W2PRG6")
# ]

# new_af_jobs = create_job_batch_id_list(all_pairs, id_list_complex, BATCH_DIRS, AF_TOKEN_LIMIT)


# Target folder for this run; use a new batch number for every new batch.
write_af_jobs_to_individual_files(new_af_jobs, os.path.join(AF_JOB_BATCH_ROOT, 'batch_43'))

## 3. Analyze AF results

### Constants

In [None]:
# Local directory with one sub-folder of AlphaFold output per job
# (presumably synced from the HPC cluster; verify_hpc_results walks its sub-folders).
HPC_RESULT_DIR: str = "/home/markus/MPI_local/HPC_results"

### Verify result completeness

In [None]:
def verify_hpc_results(hpc_results_dir: str, required_files_template: List[str]) -> Dict[str, List[str]]:
    """Report which expected output entries are absent from each job folder.

    Every sub-directory of ``hpc_results_dir`` is treated as one job whose name is the
    folder name. Plain files directly inside ``hpc_results_dir`` are ignored. For each
    job, the placeholder ``JOB_NAME`` in every template entry is replaced by the folder
    name. The resulting path is then checked with ``os.path.exists``, so an entry may
    be either a file or a sub-directory.

    Args:
        hpc_results_dir (str): Directory containing one sub-folder per job.
        required_files_template (List[str]): Expected entry names, possibly containing 'JOB_NAME'.

    Returns:
        Dict[str, List[str]]: Maps each incomplete job folder to the list of entries it lacks,
        in template order. Complete jobs do not appear.
    """
    report: Dict[str, List[str]] = {}
    for entry in os.listdir(hpc_results_dir):
        folder = os.path.join(hpc_results_dir, entry)
        if not os.path.isdir(folder):
            continue
        expected = (template.replace("JOB_NAME", entry) for template in required_files_template)
        absent = [name for name in expected if not os.path.exists(os.path.join(folder, name))]
        if absent:
            report[entry] = absent
    return report

def missing_files_report() -> None:
    """Print which expected AlphaFold output entries are missing under HPC_RESULT_DIR.

    Uses ``verify_hpc_results`` with a fixed list of expected entries ('JOB_NAME' is
    replaced by each job folder name). It prints one line per incomplete job, or a
    confirmation that nothing is missing.
    """
    expected_entries = [
        "JOB_NAME_confidences.json",
        "JOB_NAME_data.json",
        "JOB_NAME_model.cif",
        "JOB_NAME_summary_confidences.json",
        "ranking_scores.csv",
    ] + [f"seed-1_sample-{sample}" for sample in range(5)]  # seed-1_sample-0 .. seed-1_sample-4

    incomplete = verify_hpc_results(HPC_RESULT_DIR, expected_entries)
    if not incomplete:
        print("All files are present.")
        return

    print("Missing files detected:")
    for job, files in incomplete.items():
        print(f"Job: {job}, Missing Files: {files}")

In [None]:
#missing_files_report()

### Helper functions

In [None]:
from functions_analysis import *

### Negatomes

#### IntAct Negatome

In [None]:
# intact_negative = pd.read_csv('../../data/IntAct/human/human_negative.txt', sep='\t')
# intact_negative.drop(['Alias(es) interactor A', 
#                      'Alias(es) interactor B', 
#                      'Interaction detection method(s)',
#                      'Publication 1st author(s)',
#                      'Publication Identifier(s)',
#                      'Taxid interactor A',
#                      'Taxid interactor B',
#                      'Biological role(s) interactor A',
#                      'Biological role(s) interactor B',
#                      'Experimental role(s) interactor A',
#                      'Experimental role(s) interactor B',
#                      'Type(s) interactor A',
#                      'Type(s) interactor B',
#                      'Xref(s) interactor A',
#                      'Xref(s) interactor B',
#                      'Interaction Xref(s)',
#                      'Annotation(s) interactor A',
#                      'Annotation(s) interactor B',
#                      'Interaction annotation(s)',
#                      'Host organism(s)',
#                      'Interaction parameter(s)',
#                      'Creation date',
#                      'Update date',
#                      'Checksum(s) interactor A',
#                      'Checksum(s) interactor B',
#                      'Interaction Checksum(s)',
#                      'Feature(s) interactor A',
#                      'Feature(s) interactor B',
#                      'Stoichiometry(s) interactor A',
#                      'Stoichiometry(s) interactor B',
#                      'Identification method participant A',
#                      'Identification method participant B',
#                      'Expansion method(s)'
#                      ], axis=1, inplace=True)
# intact_negative.loc[:, 'intact_score'] = intact_negative.loc[: , 'Confidence value(s)'].apply(intact_score_filter)
# intact_negative['pair_id'] = intact_negative.apply(lambda row: str(tuple(sorted([row['#ID(s) interactor A'].replace('uniprotkb:', '').split('-')[0], row['ID(s) interactor B'].replace('uniprotkb:', '').split('-')[0]]))), axis=1)
# intact_negative = intact_negative.sort_values('intact_score', ascending=False).drop_duplicates('pair_id', keep='first')
# all_pairs_intact_negative = pd.merge(all_pairs, intact_negative, on='pair_id', how='inner')
# print(len(all_pairs_intact_negative))

#### Blohm negatome2.0

In [None]:
# negatome2 = pd.read_csv('../../data/negatome2.0/combined.txt', sep='\t', names=['ID_1', 'ID_2'])
# negatome2['pair_id'] = negatome2.apply(lambda row: str(tuple(sorted([row['ID_1'].split('-')[0], row['ID_2'].split('-')[0]]))), axis=1)
# all_pairs_negatome2 = pd.merge(all_pairs, negatome2, on='pair_id', how='inner')
# print(len(all_pairs_negatome2))

#### Stelzl 2005 negatome

In [None]:
# stelzl_neg = pd.read_csv('../../data/16169070_neg.mitab', sep='\t')

### Read in HPC results

In [None]:
# Create DataFrame from all job data
# NOTE(review): presumably one row per job summary file found under HPC_RESULT_DIR by
# find_summary_files (imported from functions_analysis). Confirm against that helper.
results_df_uc = pd.DataFrame(data=find_summary_files(HPC_RESULT_DIR))

# Print basic information about the DataFrame
print(f"Total jobs processed: {len(results_df_uc)}")

# Key used to join against the STRING / IntAct annotations in the next cell.
# NOTE(review): presumably the same order-independent str(tuple(sorted(...))) format
# the negatome cells build; confirm in functions_analysis.create_pair_id.
results_df_uc['pair_id'] = results_df_uc.apply(create_pair_id, axis=1)

# Same count as "Total jobs processed"; repeated for the before/after cleaning comparison.
print(f"jobs before cleaning: {len(results_df_uc)}")
# clean_results (functions_analysis) returns the frame that is annotated and analysed below.
results_df = clean_results(results_df_uc)
print(f"jobs after cleaning: {len(results_df)}")

In [None]:
# Attach STRING and IntAct evidence to every modelled pair. Left joins keep every row of
# results_df; pairs absent from a database get NaN in that database's columns.
results_df_annotated = pd.merge(results_df, string_df, on='pair_id', how='left')
results_df_annotated = pd.merge(results_df_annotated, intact_cleaned, on='pair_id', how='left')

# Print information about the merged dataframe
# If "Total rows" exceeds the number of modelled pairs, string_df or intact_cleaned
# contains several rows for the same pair_id (the left joins then duplicate rows).
print(f"Total number of modelled pairs: {len(results_df)}")
print(f"Total rows in merged_df: {len(results_df_annotated)}")
print(f"Rows with annotated data (STRING): {results_df_annotated['combined_score'].notna().sum()}")
print(f"Rows with annotated data (IntAct): {results_df_annotated['intact_score'].notna().sum()}")


# convert all STRING scores from 0-1000 to 0-1 (linear conversion)
# so they can be compared/averaged with the IntAct score (presumably already on 0-1).
STRING_COLS = ['experimental', 'database', 'textmining', 'combined_score']
for col in STRING_COLS:
    results_df_annotated[col] = results_df_annotated[col] / 1000

### Comparing AlphaFold ranking scores with STRING combined scores

In [None]:
def create_scatter_plot_colour(df: pd.DataFrame, x_metric: str, y_metric: str, color_metric: str, 
                        title: str = '', cmap: str = 'viridis', alpha: float = 0.7, size: int = 50, ax=None) -> None:
    """Scatter ``y_metric`` against ``x_metric``, colouring each point by ``color_metric``.

    A colour bar labelled with ``color_metric`` is added next to the axes. A badge in
    the upper-left corner shows ``len(df)``. Rows whose ``color_metric`` is NaN are not
    drawn by matplotlib's scatter by default (``plotnonfinite=False``) but are still
    counted in that badge.

    Args:
        df (pd.DataFrame): DataFrame containing the metrics to plot.
        x_metric (str): Column plotted on the x-axis.
        y_metric (str): Column plotted on the y-axis.
        color_metric (str): Column mapped to point colour.
        title (str, optional): Axes title. '' selects "<y_metric> vs <x_metric>\\nColored by <color_metric>".
        cmap (str, optional): Matplotlib colormap name. Defaults to 'viridis'.
        alpha (float, optional): Point transparency. Defaults to 0.7.
        size (int, optional): Marker size. Defaults to 50.
        ax (matplotlib.axes.Axes, optional): Target axes. If None, a new 10x8 figure is
            created, laid out and shown immediately.

    Returns:
        None
    """
    standalone = ax is None
    if standalone:
        plt.figure(figsize=(10, 8))
        ax = plt.gca()

    points = ax.scatter(
        x=df[x_metric],
        y=df[y_metric],
        c=df[color_metric],
        cmap=cmap,
        alpha=alpha,
        s=size,
        edgecolors='w',  # white outline keeps overlapping points distinguishable
    )

    heading = title if title != '' else f'{y_metric} vs {x_metric}\nColored by {color_metric}'

    # Colour scale for the third metric.
    colour_scale = plt.colorbar(points, ax=ax)
    colour_scale.set_label(color_metric, fontsize=10)

    ax.set_xlabel(x_metric, fontsize=12)
    ax.set_ylabel(y_metric, fontsize=12)
    ax.set_title(heading, fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.7)

    # Row-count badge, anchored in axes-fraction coordinates (upper-left corner).
    ax.annotate(f'Datapoints: {len(df)}', xy=(0.05, 0.95), xycoords='axes fraction',
                fontsize=10, bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))

    # When drawing into caller-provided subplots, layout and display are the caller's job.
    if standalone:
        plt.tight_layout()
        plt.show()


In [None]:
def create_scatter_plot(df: pd.DataFrame, x_metric: str, y_metric: str,
                        title: str = '', alpha: float = 0.7, size: int = 50, ax=None) -> None:
    """Scatter ``y_metric`` against ``x_metric`` for the rows of ``df``.

    A badge in the upper-left corner shows ``len(df)``. Rows with NaN coordinates are
    counted there even though matplotlib does not draw them.

    Args:
        df (pd.DataFrame): DataFrame containing the metrics to plot.
        x_metric (str): Column plotted on the x-axis.
        y_metric (str): Column plotted on the y-axis.
        title (str, optional): Axes title. '' selects "<y_metric> vs <x_metric>".
        alpha (float, optional): Point transparency. Defaults to 0.7.
        size (int, optional): Marker size. Defaults to 50.
        ax (matplotlib.axes.Axes, optional): Target axes. If None, a new 10x8 figure is
            created, laid out and shown immediately.

    Returns:
        None
    """
    standalone = ax is None
    if standalone:
        plt.figure(figsize=(10, 8))
        ax = plt.gca()

    ax.scatter(
        x=df[x_metric],
        y=df[y_metric],
        alpha=alpha,
        s=size,
        edgecolors='w',  # white outline keeps overlapping points distinguishable
    )

    heading = title if title != '' else f'{y_metric} vs {x_metric}'

    ax.set_xlabel(x_metric, fontsize=12)
    ax.set_ylabel(y_metric, fontsize=12)
    ax.set_title(heading, fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.7)

    # Row-count badge, anchored in axes-fraction coordinates (upper-left corner).
    ax.annotate(f'Datapoints: {len(df)}', xy=(0.05, 0.95), xycoords='axes fraction',
                fontsize=10, bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))

    # When drawing into caller-provided subplots, layout and display are the caller's job.
    if standalone:
        plt.tight_layout()
        plt.show()


In [None]:
# One figure per evidence source: each AlphaFold metric against that source's score,
# using only pairs that actually have a score from that source.
for score_column, figure_title in (
    (STRING_SCORE_COLUMN, 'AlphaFold Metrics vs STRING score, NAs dropped'),
    (INTACT_SCORE_COLUMN, 'AlphaFold Metrics vs IntAct score, NAs dropped'),
):
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    scored_pairs = results_df_annotated.dropna(subset=[score_column])
    for panel, af_metric in zip(axes, ('iptm', 'ptm', 'ranking_score')):
        create_scatter_plot(scored_pairs, af_metric, score_column, ax=panel)

    fig.suptitle(figure_title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle
    plt.show()

In [None]:
# Keep pairs without a STRING entry, placing them at score 0 instead of dropping them.
string_zero_filled = results_df_annotated.fillna({STRING_SCORE_COLUMN: 0})

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for panel, af_metric in zip(axes, ('iptm', 'ptm', 'ranking_score')):
    create_scatter_plot(string_zero_filled, af_metric, STRING_SCORE_COLUMN, ax=panel)

fig.suptitle('AlphaFold Metrics vs STRING score, NAs treated as 0', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle
plt.show()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Only pairs with a STRING score; computed once and shared by all panels.
string_scored = results_df_annotated.dropna(subset=[STRING_SCORE_COLUMN])
# (x, y) per panel. Titles follow the "y vs x" convention of create_scatter_plot_colour's
# default title (previously the first two panels read "x vs y").
for panel, (x_col, y_col) in zip(axes, [('ranking_score', 'iptm'), ('ranking_score', 'ptm'), ('ptm', 'iptm')]):
    create_scatter_plot_colour(string_scored, x_col, y_col, STRING_SCORE_COLUMN, f'{y_col} vs {x_col}', ax=panel)

fig.suptitle('AlphaFold Metrics Colored by STRING Score (experiments), NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()


In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# matplotlib's scatter skips points whose colour value is NaN (plotnonfinite=False by
# default), so pairs without a STRING score would silently vanish from these panels.
# Draw them first, in grey, so that "NAs included" actually holds.
string_missing = results_df_annotated[results_df_annotated[STRING_SCORE_COLUMN].isna()]
# (x, y) per panel. Titles follow the "y vs x" convention of create_scatter_plot_colour's default title.
for panel, (x_col, y_col) in zip(axes, [('ranking_score', 'iptm'), ('ranking_score', 'ptm'), ('ptm', 'iptm')]):
    panel.scatter(string_missing[x_col], string_missing[y_col], c='lightgrey', alpha=0.7, s=50,
                  edgecolors='w', label=f'no {STRING_SCORE_COLUMN} score')
    create_scatter_plot_colour(results_df_annotated, x_col, y_col, STRING_SCORE_COLUMN, f'{y_col} vs {x_col}', ax=panel)
    panel.legend(loc='lower right', fontsize=9)

fig.suptitle('AlphaFold Metrics Colored by STRING Score (experiments); NAs included', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()

In [None]:
# NOTE(review): this figure is an exact duplicate of the IntAct figure already drawn at the
# end of the "AlphaFold Metrics vs STRING score, NAs dropped" cell above; consider removing one.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

create_scatter_plot(results_df_annotated.dropna(subset=[INTACT_SCORE_COLUMN]), 'iptm', INTACT_SCORE_COLUMN, ax=axes[0])
create_scatter_plot(results_df_annotated.dropna(subset=[INTACT_SCORE_COLUMN]), 'ptm', INTACT_SCORE_COLUMN, ax=axes[1])
create_scatter_plot(results_df_annotated.dropna(subset=[INTACT_SCORE_COLUMN]), 'ranking_score', INTACT_SCORE_COLUMN, ax=axes[2])

fig.suptitle('AlphaFold Metrics vs IntAct score, NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Only pairs with an IntAct score; computed once and shared by all panels.
intact_scored = results_df_annotated.dropna(subset=[INTACT_SCORE_COLUMN])
# (x, y) per panel. Titles follow the "y vs x" convention of create_scatter_plot_colour's default title.
for panel, (x_col, y_col) in zip(axes, [('ranking_score', 'iptm'), ('ranking_score', 'ptm'), ('ptm', 'iptm')]):
    create_scatter_plot_colour(intact_scored, x_col, y_col, INTACT_SCORE_COLUMN, f'{y_col} vs {x_col}', ax=panel)

fig.suptitle('AlphaFold Metrics Colored by IntAct Score, NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()


In [None]:
# Mean of the STRING score (rescaled to 0-1 in the annotation cell) and the IntAct score.
# NaN propagates through the addition, so only pairs with BOTH scores get a value.
results_df_annotated['avg_STRING_IntAct'] = (results_df_annotated[STRING_SCORE_COLUMN] + results_df_annotated[INTACT_SCORE_COLUMN]) / 2 

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

both_scored = results_df_annotated.dropna(subset=['avg_STRING_IntAct'])
# (x, y) per panel. Titles follow the "y vs x" convention of create_scatter_plot_colour's default title.
for panel, (x_col, y_col) in zip(axes, [('ranking_score', 'iptm'), ('ranking_score', 'ptm'), ('ptm', 'iptm')]):
    create_scatter_plot_colour(both_scored, x_col, y_col, 'avg_STRING_IntAct', f'{y_col} vs {x_col}', ax=panel)

fig.suptitle('AlphaFold Metrics Colored by average of STRING and IntAct score, NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()

### ROC curves

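Creating ROC curves (and computing their AUC) with STRING / IntAct labels alone does not really make sense. My understanding is that both databases contain only **positive** interaction candidates and assign each one a confidence score. So even a low score still means a relatively high probability of interaction, simply because the candidate is in the database. Thresholding these scores therefore does not produce genuine negatives; a negatome is needed for those.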

In [None]:
# Importing the submodule explicitly also binds `sklearn`. A bare `import sklearn` does
# not guarantee that `sklearn.metrics` is loaded on older scikit-learn releases.
import sklearn.metrics

def plot_roc(df: pd.DataFrame, pos: Set[str], neg: Set[str], param_name: str, min_val: float, max_val: float, 
             sampling: int = 1000, direction: str = 'up', plot_title: str = '', ax=None) -> Tuple[List[float], List[float]]:
    """Calculate and plot the ROC curve based on a specified parameter.

    The threshold is swept over ``sampling + 1`` evenly spaced values from ``min_val``
    to ``max_val`` (inclusive). At each threshold the rows of ``df`` are split by
    comparing ``df[param_name]`` with the threshold. The resulting predicted-positive
    and predicted-negative ``pair_id`` sets are intersected with the ground-truth sets
    ``pos`` and ``neg`` to obtain TPR and FPR.

    Rows whose parameter value is NaN fall on neither side of any threshold, so they
    never contribute. Pair ids in ``pos``/``neg`` that are absent from ``df`` are likewise
    ignored. The AUC is computed with ``sklearn.metrics.auc`` (trapezoidal rule). If the
    sweep does not bracket all values of ``param_name``, the curve may not reach (0, 0)
    or (1, 1), and the AUC then covers only part of it.

    Args:
        df (pd.DataFrame): DataFrame containing the parameter to evaluate and a 'pair_id' column.
        pos (Set[str]): Ground-truth positive pair_ids.
        neg (Set[str]): Ground-truth negative pair_ids.
        param_name (str): Column of df used as the classification score.
        min_val (float): Lowest threshold evaluated.
        max_val (float): Highest threshold evaluated; must exceed min_val.
        sampling (int, optional): Number of intervals between min_val and max_val (>= 1). Defaults to 1000.
        direction (str, optional): 'up' treats values >= threshold as positive;
            'down' treats values < threshold as positive. Defaults to 'up'.
        plot_title (str, optional): Plot title; '' selects "ROC Curve for <param_name>".
        ax (matplotlib.axes.Axes, optional): Target axes. If None, the set sizes are
            printed and a new 8x8 figure is created and shown.

    Returns:
        Tuple[List[float], List[float]]: FPR and TPR values, one per threshold, in sweep order.

    Raises:
        ValueError: If direction is not 'up'/'down', sampling < 1, or max_val <= min_val.
    """
    # Validate arguments before any printing or figure creation, so bad input leaves no empty figure.
    if direction not in ('up', 'down'):
        raise ValueError("direction must be either 'up' or 'down'")
    if sampling < 1:
        raise ValueError("sampling must be a positive integer")
    step = (max_val - min_val) / sampling
    if step <= 0:
        raise ValueError("max_val must be greater than min_val")

    if ax is None:
        # Print the size of positive and negative sets for verification
        print(f"Number of positive examples: {len(pos)}")
        print(f"Number of negative examples: {len(neg)}")
        plt.figure(figsize=(8, 8))
        ax = plt.gca()
        show_plot = True
    else:
        show_plot = False

    # Loop-invariant column lookups.
    values = df[param_name]
    pair_ids = df['pair_id']

    tpr_list = []
    fpr_list = []
    for i in range(sampling + 1):
        sep_val = min_val + (i * step)
        at_or_above = set(pair_ids[values >= sep_val].tolist())
        below = set(pair_ids[values < sep_val].tolist())
        if direction == 'up':
            calc_pos, calc_neg = at_or_above, below
        else:
            calc_pos, calc_neg = below, at_or_above

        # Confusion-matrix counts restricted to labelled pairs.
        TP_num = len(calc_pos & pos)
        FP_num = len(calc_pos & neg)
        TN_num = len(calc_neg & neg)
        FN_num = len(calc_neg & pos)

        # max(..., 1) avoids division by zero when a class has no labelled members in df.
        tpr_list.append(TP_num / max(TP_num + FN_num, 1))
        fpr_list.append(FP_num / max(FP_num + TN_num, 1))

    # Plot the ROC curve against the random-guess diagonal.
    ax.plot(fpr_list, tpr_list, 'b-', linewidth=2)
    ax.plot([0, 1], [0, 1], 'k--', linewidth=2)

    # FPR is monotonic along the sweep (decreasing for 'up', increasing for 'down'),
    # which sklearn.metrics.auc accepts in either direction.
    auc_value = sklearn.metrics.auc(fpr_list, tpr_list)

    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    title = plot_title if plot_title else f'ROC Curve for {param_name}'
    ax.set_title(f'{title}\nAUC = {auc_value:.3f}', fontsize=12)

    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])

    # Show plot only if not using subplots
    if show_plot:
        plt.tight_layout()
        plt.show()

    return fpr_list, tpr_list

In [None]:
# STRING_SCORE_CUTOFF = 0.4
# # comparisons like >=, < with NAs don't pass the filter
# string_combined_pos = set(results_df_annotated[results_df_annotated[STRING_SCORE_COLUMN] >= STRING_SCORE_CUTOFF]['pair_id'].to_list())
# string_combined_neg = set(results_df_annotated[results_df_annotated[STRING_SCORE_COLUMN] < STRING_SCORE_CUTOFF]['pair_id'].to_list())

# print(f"Number of positive examples: {len(string_combined_pos)}")
# print(f"Number of negative examples: {len(string_combined_neg)}")

# fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# plot_roc(results_df_annotated, string_combined_pos, string_combined_neg, 'ranking_score', 0, 1, plot_title='Ranking Score', ax=axes[0])
# plot_roc(results_df_annotated, string_combined_pos, string_combined_neg, 'iptm', 0, 1, plot_title='iPTM', ax=axes[1])
# plot_roc(results_df_annotated, string_combined_pos, string_combined_neg, 'ptm', 0, 1, plot_title='PTM', ax=axes[2])

# fig.suptitle(f'ROC Curves for AlphaFold Metrics (STRING cutoff >= {STRING_SCORE_CUTOFF})', fontsize=16)
# plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
# plt.show()

In [None]:
# INTACT_SCORE_CUTOFF = 0.4
# # comparisons like >=, < with NAs don't pass the filter
# intact_combined_pos = set(results_df_annotated[results_df_annotated[INTACT_SCORE_COLUMN] >= INTACT_SCORE_CUTOFF]['pair_id'].to_list())
# intact_combined_neg = set(results_df_annotated[results_df_annotated[INTACT_SCORE_COLUMN] < INTACT_SCORE_CUTOFF]['pair_id'].to_list())

# print(f"Number of positive examples: {len(intact_combined_pos)}")
# print(f"Number of negative examples: {len(intact_combined_neg)}")

# fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# plot_roc(results_df_annotated, intact_combined_pos, intact_combined_neg, 'ranking_score', 0, 1, plot_title='Ranking Score', ax=axes[0])
# plot_roc(results_df_annotated, intact_combined_pos, intact_combined_neg, 'iptm', 0, 1, plot_title='iPTM', ax=axes[1])
# plot_roc(results_df_annotated, intact_combined_pos, intact_combined_neg, 'ptm', 0, 1, plot_title='PTM', ax=axes[2])

# fig.suptitle(f'ROC Curves for AlphaFold Metrics (IntAct cutoff >= {INTACT_SCORE_CUTOFF})', fontsize=16)
# plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
# plt.show()