## General Description

1. Create datasets
2. Create AlphaFold (AF) job files
3. Analyze AF output

## 1. Create Datasets

### Library imports and helper functions

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from functions_filtering import *

### Constants

In [None]:
# Root directory holding the downloaded reference datasets. Change this single
# value to relocate all absolute data paths below.
DATA_DIR = '/home/markus/MPI_local/data'

# Preprocessed STRING v12.0 physical-links dump for taxon 9606 (CSV).
STRING_PATH = f'{DATA_DIR}/STRING/9606.protein.physical.links.detailed.v12.0.txt_processed.csv'
# UniProtKB export of proteome UP000005640 (TSV).
PROTEOME_PATH = f'{DATA_DIR}/Proteome/uniprotkb_proteome_UP000005640_2025_05_28.tsv'
# PROTEOME_PATH = f'{DATA_DIR}/full_UP/uniprotkb_AND_reviewed_true_2025_07_10.tsv'
# Curated human transcription factor table (CSV).
TF_DATASET_PATH = f'{DATA_DIR}/human_TFs/DatabaseExtract_v_1.01.csv'
# Ensembl GRCh38 release 114 gene -> UniProt cross-reference table (TSV).
ENSEMBL_MAPPING_PATH = f'{DATA_DIR}/Ensembl_mapping/Homo_sapiens.GRCh38.114.uniprot.tsv/hps/nobackup/flicek/ensembl/production/release_dumps/release-114/ftp_dumps/vertebrates/tsv/homo_sapiens/Homo_sapiens.GRCh38.114.uniprot.tsv'
AIUPRED_PATH = f'{DATA_DIR}/AIUPred/AIUPred_data.json'
DISPROT_PATH = f'{DATA_DIR}/DisProt/DisProt_release_2024_12 with_ambiguous_evidences.tsv'
# IUPred3 installation, relative to the notebook's working directory.
IUPRED3_PATH = '../../iupred3'

# Upper bound on the combined sequence length (ARM protein + TF) of a pair.
# Used to count over-limit pairs and passed to the AlphaFold job-creation helpers.
AF_TOKEN_LIMIT = 5120

### Read in Proteome

In [None]:
# Load the UniProt proteome export, then keep only rows marked as reviewed.
all_uniprot = pd.read_csv(PROTEOME_PATH, sep='\t', low_memory=False)
uniprot_filtered = all_uniprot.loc[lambda frame: frame['Reviewed'] == 'reviewed']
# uniprot_filtered = all_uniprot

### create Armadillo dataset

In [None]:
# Accession lists that define the ARM (armadillo) domain filters.
from filter_data import arm_accs_ipr, arm_accs_pfam

# UniProt lists Pfam accessions without the ".<version>" suffix, so drop it
# from the filter list before matching.
arm_accs_pfam_stripped = [acc.partition(".")[0] for acc in arm_accs_pfam]

# Criterion 1: "ARM" occurs in the Repeat annotation (missing values never match).
repeat_mask_arm = uniprot_filtered['Repeat'].apply(
    lambda value: False if pd.isna(value) else "ARM" in str(value)
)
print(f"Proteins with 'ARM' in Repeat column: {uniprot_filtered[repeat_mask_arm].shape[0]}")

# Criterion 2: at least one of the ARM InterPro accessions is annotated.
interpro_mask_arm = uniprot_filtered['InterPro'].apply(
    lambda value: contains_any_annotation(value, arm_accs_ipr)
)
print(f"Proteins with specified InterPro annotation: {uniprot_filtered[interpro_mask_arm].shape[0]}")

# Criterion 3: at least one of the (version-stripped) ARM Pfam accessions is annotated.
pfam_mask_arm = uniprot_filtered['Pfam'].apply(
    lambda value: contains_any_annotation(value, arm_accs_pfam_stripped)
)
print(f"Proteins with specified Pfam annotation: {uniprot_filtered[pfam_mask_arm].shape[0]}")

# A protein is kept if it satisfies any of the three criteria.
armadillo_proteins = uniprot_filtered[interpro_mask_arm | pfam_mask_arm | repeat_mask_arm]
print(f"Found {armadillo_proteins.shape[0]} proteins with armadillo domains")

### Create Transcription Factor dataset (UniProtFilter)

In [None]:
# def str_tf(x):
#     return "transcription factor" in x.lower()

# def str_t(x):
#     return "transcription" in x.lower()

# # IPR accessions containing "transcription"
# IPR_entries = pd.read_csv("../entry.list", sep="\t")
# tf_accs_ipr = IPR_entries[IPR_entries['ENTRY_NAME'].apply(lambda x: str_t(x))]["ENTRY_AC"].tolist()

# # PFAM accessions containing "transcription factor"
# PFAM_entries = pd.read_csv("../data/pfam_parsed_data.csv", sep=",")
# tf_accs_pfam = PFAM_entries[PFAM_entries['DE'].apply(lambda x: str_t(x))]["AC"].tolist()

# # strip version number since it is not included in the uniprot annotation
# for i in range(len(tf_accs_pfam)):
#     tf_accs_pfam[i] = tf_accs_pfam[i].split(".")[0]

# interpro_mask_tf = reviewed_proteins['InterPro'].apply(lambda x: contains_any_annotation(x, tf_accs_ipr))
# print(f"Proteins with specified InterPro annotation: {len(reviewed_proteins[interpro_mask_tf])}")

# pfam_mask_tf = reviewed_proteins['Pfam'].apply(lambda x: contains_any_annotation(x, tf_accs_pfam))
# print(f"Proteins with specified Pfam annotation: {len(reviewed_proteins[pfam_mask_tf])}")

# txt_mask_tf = reviewed_proteins['Protein names'].apply(lambda x: str_tf(x))
# print(f"Proteins with 'Transcription factor' in the name: {len(reviewed_proteins[txt_mask_tf])}")


# # Combine filters with OR operation
# tf_proteins_uniprot_ds = reviewed_proteins[interpro_mask_tf | pfam_mask_tf | txt_mask_tf]

# print(f"Found {len(tf_proteins_uniprot_ds)} proteins with transcription factor annotation")

### Use existing Transcription Factor dataset

In [None]:
# Load the curated TF table and keep only rows whose 'Is TF?' flag is 'Yes'.
human_TFs = pd.read_csv(TF_DATASET_PATH).loc[lambda frame: frame['Is TF?'] == 'Yes']
human_TFs.shape[0]

In [None]:
# Ensembl -> UniProt cross-reference table; only the Swiss-Prot rows are kept.
ensembl_mapping = pd.read_csv(ENSEMBL_MAPPING_PATH, sep='\t')
ensembl_mapping_swissProt = ensembl_mapping.loc[ensembl_mapping['db_name'].eq('Uniprot/SWISSPROT')]

In [None]:
# Ensembl gene IDs of the curated transcription factors.
human_TFs_gids = human_TFs['Ensembl ID'].tolist()

# Map the TF gene IDs to UniProt accessions, using only Swiss-Prot
# cross-references to prevent duplicates.
# TODO: use caching
# TODO: adjust ensemble mapping if using other proteome?
# Set-based matching (Series.isin) replaces the former per-row any() scan,
# which was O(rows x TFs) and took several minutes.
human_TF_uniprot_accs = ensembl_mapping_swissProt.loc[
    ensembl_mapping_swissProt['gene_stable_id'].isin(set(human_TFs_gids)), 'xref'
].tolist()
print(len(human_TF_uniprot_accs))

# Exact accession match on the UniProt 'Entry' column. The former substring
# test (`acc in entry`) could also select an entry whose accession merely
# contains a shorter TF accession.
# .copy() gives an independent frame, so later column additions do not act
# on a view of uniprot_filtered.
tf_proteins_curated_ds = uniprot_filtered[
    uniprot_filtered['Entry'].isin(set(human_TF_uniprot_accs))
].copy()
print(len(tf_proteins_curated_ds))

#### IUPred 3

In [None]:
# IUPred3 settings: score threshold, minimum disordered-region length and
# the on-disk cache directory handed to add_iupred3.
IUPRED3_THRESHOLD = 0.5
MIN_LENGTH_DISORDERED_REGION = 20
IUPRED_CACHE_DIR = '/home/markus/MPI_local/production1/IUPred3'

# NOTE(review): 'long' and 'no' are passed positionally to add_iupred3;
# presumably the IUPred3 prediction type and smoothing mode. Confirm in functions_filtering.
tf_proteins_curated_ds = add_iupred3(tf_proteins_curated_ds, 'long', 'no', IUPRED_CACHE_DIR, IUPRED3_THRESHOLD, MIN_LENGTH_DISORDERED_REGION, IUPRED3_PATH)

# Keep only TFs with at least one predicted disordered region.
tf_proteins_curated_ds_IUPred3_diso = tf_proteins_curated_ds.loc[
    tf_proteins_curated_ds['num_disordered_regions'].gt(0)
]

print(
    "Number of transcription factors with at least one disordered region "
    f"(IUPred3, threshold={IUPRED3_THRESHOLD}, min length={MIN_LENGTH_DISORDERED_REGION}): "
    f"{len(tf_proteins_curated_ds_IUPred3_diso)}"
)

#### Disprot

In [None]:
disprot_df = pd.read_csv(DISPROT_PATH, sep='\t')

# Append ';' so the IDs use the same format as the UniProt 'DisProt' column.
# (Vectorized string concatenation instead of a per-row lambda.)
disprot_df['disprot_id'] = disprot_df['disprot_id'] + ';'

tf_disprot_ids = tf_proteins_curated_ds['DisProt'].dropna().tolist()

# DisProt rows belonging to the curated TFs. Set membership via isin is O(n);
# the former `x in list` check per row was O(n * m).
disprot_tfs = disprot_df[disprot_df['disprot_id'].isin(set(tf_disprot_ids))]

In [None]:
# Left join: every curated TF is kept. A TF matching several DisProt rows
# appears once per matching row; TFs without a match get NaN DisProt columns.
tf_proteins_curated_ds_disprot = pd.merge(
    tf_proteins_curated_ds,
    disprot_df,
    how='left',
    left_on='DisProt',
    right_on='disprot_id',
)

### Create all pairs

In [None]:
# Build candidate pairs between the armadillo proteins and the TFs that have at
# least one IUPred3-predicted disordered region.
# NOTE(review): create_all_pairs comes from a star import. Later cells read the
# 'pair_id', 'Length_arm' and 'Length_tf' columns of the result; whether it forms
# the full cross product should be confirmed in the helper module.
all_pairs = create_all_pairs(armadillo_proteins, tf_proteins_curated_ds_IUPred3_diso)

In [None]:
# Count pairs whose combined length exceeds the AlphaFold token budget.
num_pairs_over_token_limit = int(((all_pairs['Length_arm'] + all_pairs['Length_tf']) > AF_TOKEN_LIMIT).sum())
print(f"Pairs over token limit: {num_pairs_over_token_limit} ({(num_pairs_over_token_limit/len(all_pairs))*100}%)")

### STRING

In [None]:
# Load the STRING table. The file is a STRING database dump preprocessed with
# the scripts in /src/STRING and should contain the columns p1_Uniprot,
# p2_Uniprot and pair_id. (',' is read_csv's default separator.)
string_df = pd.read_csv(STRING_PATH)

In [None]:
# Annotate all_pairs with the STRING scores.
# IMPORTANT: the inner join drops pairs that have no matching STRING entry.
all_pairs_w_STRING = pd.merge(all_pairs, string_df, on='pair_id', how='inner')

# Pairs that have no STRING entry at all.
unmatched_pairs = all_pairs[~all_pairs['pair_id'].isin(all_pairs_w_STRING['pair_id'])]
num_all_pairs = len(all_pairs)
num_all_pairs_w_STRING = len(all_pairs_w_STRING)
print(f"Number of pairs in all_pairs: {num_all_pairs}")
print(f"Number of pairs successfully merged with STRING data: {num_all_pairs_w_STRING} ({(num_all_pairs_w_STRING/num_all_pairs)*100}%)")
# Report the unmatched pairs as well; this was computed but never printed.
print(f"Number of pairs without a STRING entry: {len(unmatched_pairs)}")

### IntAct

In [None]:
from functions_intact import *

In [None]:
# Load and clean the human IntAct export via read_clean_intact (functions_intact).
# NOTE(review): later cells merge the result on 'pair_id' and read 'intact_score';
# presumably one row per pair. Confirm in functions_intact. The path is relative
# to the notebook's working directory, unlike the absolute paths in the Constants cell.
intact_cleaned = read_clean_intact('../../data/IntAct/human/human.txt')

In [None]:
# Annotate all_pairs with IntAct data. The inner join keeps only pairs
# that also occur in intact_cleaned.
all_pairs_w_IntAct = all_pairs.merge(intact_cleaned, on='pair_id', how='inner')

num_all_pairs, num_all_pairs_w_IntAct = len(all_pairs), len(all_pairs_w_IntAct)
print(f"Number of pairs in all_pairs: {num_all_pairs}")
print(f"Number of pairs successfully merged with IntAct data: {num_all_pairs_w_IntAct} ({(num_all_pairs_w_IntAct/num_all_pairs)*100}%)")

In [None]:
# Pairs present in both the STRING- and the IntAct-annotated tables. Both
# inputs carry the all_pairs columns, so shared non-key columns receive
# pandas' default _x/_y suffixes.
all_pairs_intersect_STRING_IntAct = all_pairs_w_STRING.merge(all_pairs_w_IntAct, how='inner', on='pair_id')

print(all_pairs_intersect_STRING_IntAct.shape[0])

In [None]:
# Pairs present in either annotated table (outer join on pair_id). The side
# that lacks a pair is filled with NaN.
all_pairs_union_STRING_IntAct = all_pairs_w_STRING.merge(all_pairs_w_IntAct, how='outer', on='pair_id')

print(all_pairs_union_STRING_IntAct.shape[0])

### write to files

In [None]:
# for idx, row in armadillo_proteins.iterrows():
#     print_to_fasta(row['Entry'], row['Sequence'], '../../production1/arm_all_uniprot_rev_fasta', row['Reviewed'])
# for idx, row in tf_proteins_curated_ds_IUPred3_diso.iterrows():
#     print_to_fasta(row['Entry'], row['Sequence'], '../../production1/tf_all_uniprot_rev_fasta', row['Reviewed'])

In [None]:
# armadillo_proteins.to_csv('../../armadillo_proteins.csv', index=False)
# tf_proteins_curated_ds_IUPred3_diso.to_csv('../../transcription_factors.csv', index=False)

## 2. Create AF job files
- Create job files for AlphaFold.
- Avoid creating duplicate jobs.

### constants

### create job files

In [None]:
# Score columns used for binning job batches and for colouring/comparing plots:
# the STRING 'experimental' score column and the IntAct score column.
STRING_SCORE_COLUMN = 'experimental'
INTACT_SCORE_COLUMN = 'intact_score'

In [None]:
from functions_job_creation import *

In [None]:
import os

# Parent directory holding one sub-directory per AlphaFold job batch.
AF_JOB_BATCH_ROOT = '../../production1/AF_job_batches'

# Existing batch directories (sub-directories whose name contains 'batch').
# They are passed to the job-creation helpers, presumably so that pairs already
# present in earlier batches are not created again. Sorted so the order is
# deterministic across runs (os.listdir order is arbitrary).
BATCH_DIRS = sorted(
    os.path.join(AF_JOB_BATCH_ROOT, name)
    for name in os.listdir(AF_JOB_BATCH_ROOT)
    if 'batch' in name and os.path.isdir(os.path.join(AF_JOB_BATCH_ROOT, name))
)
BATCH_SIZE = 1

# Output directory for the jobs created in this run.
# NOTE(review): the batch number is hard-coded; bump it before every new run so
# that new jobs do not land in an existing batch directory.
NEW_BATCH_DIR = os.path.join(AF_JOB_BATCH_ROOT, 'batch_51')

### STRING
# note that the order is important. The category (100,200) is very large, so it is placed towards the end to fill up the remaining jobs
# categories = [(900,1000), (800,900), (700,800), (600,700), (500,600), (400,500), (300,400), (100,200), (0,100), (200,300)]
# new_af_jobs = create_job_batch_scoreCategories(all_pairs_w_STRING, BATCH_SIZE, categories, BATCH_DIRS, STRING_SCORE_COLUMN, AF_TOKEN_LIMIT)

### IntAct
# categories = [(0.1,0.2), (0.9,1), (0.8,0.9), (0.7,0.8), (0.6,0.7), (0.5,0.6), (0.4,0.5), (0.2,0.3), (0.3,0.4)]
# new_af_jobs = create_job_batch_scoreCategories(all_pairs_w_IntAct, BATCH_SIZE, categories, BATCH_DIRS, INTACT_SCORE_COLUMN, AF_TOKEN_LIMIT)

### all pairs
new_af_jobs = create_job_batch_all_pairs(all_pairs, BATCH_SIZE, BATCH_DIRS, AF_TOKEN_LIMIT)


### ID list:
# NOTE(review): some pairs are listed more than once; confirm whether intentional.
id_list_good = [
    ("Q13285", "A0A2R8YCH5"),
    ("P04637", "A0A8I5KU01"),
    ("P04637", "A0A8I5KU01"),
    ("Q9H3D4", "A0A8I5KU01"),
    ("Q8NHM5", "A1YPR0"),
    ("Q9UJU2", "A0A2R8YCH5"),
    ("Q9UJU2", "A0A2R8YCH5"),
    ("Q6SJ96", "O14981"),
    ("Q9NRY4", "O00750"),
    ("Q6ZRS2", "A0A8V8TQN3")
]

# id_list_complex = [
#     ("Q03181", "Q9H3U1"),
#     ("P04637", "A0A994J4J0"),
#     ("Q6SJ96", "O14981"),
#     ("P49450", "O14981"),
#     ("P49450", "O14981"),
#     ("P49450", "O14981"),
#     ("P49450", "O14981"),
#     ("Q9UBG7", "A0A8I5KU01"),
#     ("P19838", "A0A1W2PRG6")
# ]

missing = [('Q13285', 'Q6BTZ4'), ('P04637', 'Q9VL06'), ('P04637', 'Q9VL06'), ('Q9UIF8', 'Q54U63'), ('Q15047', 'Q54U63'), ('Q15047', 'Q5R881'), ('Q9H3D4', 'Q9VL06'), ('P49450', 'Q4WJI7'), ('P49450', 'Q4WJI7'), ('P49450', 'Q4WJI7'), ('P49450', 'Q4WJI7'), ('Q6ZRS2', 'P38811'), ('Q6ZRS2', 'P38811')]
# new_af_jobs = create_job_batch_id_list(all_pairs, missing, BATCH_DIRS, AF_TOKEN_LIMIT)


write_af_jobs_to_individual_files(new_af_jobs, NEW_BATCH_DIR)

## 3. Analyze AF results

### Constants

In [None]:
# Root directory searched for AlphaFold summary files when reading results
# (presumably outputs copied back from the HPC cluster).
HPC_RESULT_DIR = "/home/markus/MPI_local/HPC_results"

### Helper functions

In [None]:
from functions_analysis import *

### Negatomes

#### IntAct Negatome

In [None]:
# intact_negative = pd.read_csv('../../data/IntAct/human/human_negative.txt', sep='\t')
# intact_negative.drop(['Alias(es) interactor A', 
#                      'Alias(es) interactor B', 
#                      'Interaction detection method(s)',
#                      'Publication 1st author(s)',
#                      'Publication Identifier(s)',
#                      'Taxid interactor A',
#                      'Taxid interactor B',
#                      'Biological role(s) interactor A',
#                      'Biological role(s) interactor B',
#                      'Experimental role(s) interactor A',
#                      'Experimental role(s) interactor B',
#                      'Type(s) interactor A',
#                      'Type(s) interactor B',
#                      'Xref(s) interactor A',
#                      'Xref(s) interactor B',
#                      'Interaction Xref(s)',
#                      'Annotation(s) interactor A',
#                      'Annotation(s) interactor B',
#                      'Interaction annotation(s)',
#                      'Host organism(s)',
#                      'Interaction parameter(s)',
#                      'Creation date',
#                      'Update date',
#                      'Checksum(s) interactor A',
#                      'Checksum(s) interactor B',
#                      'Interaction Checksum(s)',
#                      'Feature(s) interactor A',
#                      'Feature(s) interactor B',
#                      'Stoichiometry(s) interactor A',
#                      'Stoichiometry(s) interactor B',
#                      'Identification method participant A',
#                      'Identification method participant B',
#                      'Expansion method(s)'
#                      ], axis=1, inplace=True)
# intact_negative.loc[:, 'intact_score'] = intact_negative.loc[: , 'Confidence value(s)'].apply(intact_score_filter)
# intact_negative['pair_id'] = intact_negative.apply(lambda row: str(tuple(sorted([row['#ID(s) interactor A'].replace('uniprotkb:', '').split('-')[0], row['ID(s) interactor B'].replace('uniprotkb:', '').split('-')[0]]))), axis=1)
# intact_negative = intact_negative.sort_values('intact_score', ascending=False).drop_duplicates('pair_id', keep='first')
# all_pairs_intact_negative = pd.merge(all_pairs, intact_negative, on='pair_id', how='inner')
# print(len(all_pairs_intact_negative))

#### Blohm negatome2.0

In [None]:
# negatome2 = pd.read_csv('../../data/negatome2.0/combined.txt', sep='\t', names=['ID_1', 'ID_2'])
# negatome2['pair_id'] = negatome2.apply(lambda row: str(tuple(sorted([row['ID_1'].split('-')[0], row['ID_2'].split('-')[0]]))), axis=1)
# all_pairs_negatome2 = pd.merge(all_pairs, negatome2, on='pair_id', how='inner')
# print(len(all_pairs_negatome2))

#### Stelzl 2005 negatome

In [None]:
# stelzl_neg = pd.read_csv('../../data/16169070_neg.mitab', sep='\t')

### Read in HPC results

In [None]:
# Build one table from every summary record found under HPC_RESULT_DIR.
results_df_uc = pd.DataFrame(find_summary_files(HPC_RESULT_DIR))

print(f"Total jobs processed: {results_df_uc.shape[0]}")

# Derive the pair identifier row by row with the helper create_pair_id.
results_df_uc['pair_id'] = results_df_uc.apply(create_pair_id, axis=1)

# Clean the raw results and report the row count before and after.
print(f"jobs before cleaning: {results_df_uc.shape[0]}")
results_df = clean_results(results_df_uc)
print(f"jobs after cleaning: {results_df.shape[0]}")

In [None]:
# Left-join the STRING and then the IntAct annotations onto the modelled pairs.
# Pairs without an entry keep NaN in the annotation columns.
results_df_annotated = (
    results_df
    .merge(string_df, on='pair_id', how='left')
    .merge(intact_cleaned, on='pair_id', how='left')
)

# Print information about the merged dataframe
print(f"Total number of modelled pairs: {len(results_df)}")
print(f"Total rows in merged_df: {len(results_df_annotated)}")
print(f"Rows with annotated data (STRING): {results_df_annotated['combined_score'].notna().sum()}")
print(f"Rows with annotated data (IntAct): {results_df_annotated['intact_score'].notna().sum()}")


# Rescale the STRING score columns linearly from 0-1000 to 0-1.
STRING_COLS = ['experimental', 'database', 'textmining', 'combined_score']
results_df_annotated[STRING_COLS] = results_df_annotated[STRING_COLS] / 1000

### Comparing AlphaFold ranking scores with STRING combined scores

In [None]:
from functions_plotting import *

In [None]:
# One figure per score source. Each panel plots an AlphaFold metric against
# that score, using only rows where the score is present.
for score_col, figure_title in (
    (STRING_SCORE_COLUMN, 'AlphaFold Metrics vs STRING score, NAs dropped'),
    (INTACT_SCORE_COLUMN, 'AlphaFold Metrics vs IntAct score, NAs dropped'),
):
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    for ax, metric in zip(axes, ('iptm', 'ptm', 'ranking_score')):
        create_scatter_plot(results_df_annotated.dropna(subset=[score_col]), metric, score_col, ax=ax)

    fig.suptitle(figure_title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
    plt.show()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Rows without a STRING score are kept and plotted at score 0.
for ax, metric in zip(axes, ('iptm', 'ptm', 'ranking_score')):
    create_scatter_plot(results_df_annotated.fillna({STRING_SCORE_COLUMN: 0}), metric, STRING_SCORE_COLUMN, ax=ax)

fig.suptitle('AlphaFold Metrics vs STRING score, NAs treated as 0', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Each panel: (x metric, y metric, panel title), coloured by the STRING score;
# rows without a STRING score are dropped.
for ax, (x_metric, y_metric, panel_title) in zip(axes, (
    ('ranking_score', 'iptm', 'ranking_score vs iptm'),
    ('ranking_score', 'ptm', 'ranking_score vs ptm'),
    ('ptm', 'iptm', 'iptm vs ptm'),
)):
    create_scatter_plot_colour(results_df_annotated.dropna(subset=[STRING_SCORE_COLUMN]), x_metric, y_metric, STRING_SCORE_COLUMN, panel_title, ax=ax)

fig.suptitle('AlphaFold Metrics Colored by STRING Score (experiments), NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()


In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Same panels as the NA-dropped variant, but all rows are passed to the plotting
# helper, including those without a STRING score.
for ax, (x_metric, y_metric, panel_title) in zip(axes, (
    ('ranking_score', 'iptm', 'ranking_score vs iptm'),
    ('ranking_score', 'ptm', 'ranking_score vs ptm'),
    ('ptm', 'iptm', 'iptm vs ptm'),
)):
    create_scatter_plot_colour(results_df_annotated, x_metric, y_metric, STRING_SCORE_COLUMN, panel_title, ax=ax)

fig.suptitle('AlphaFold Metrics Colored by STRING Score (experiments); NAs included', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()

In [None]:
# NOTE(review): this repeats the IntAct scatter figure produced in the earlier
# STRING/IntAct comparison cell.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, metric in zip(axes, ('iptm', 'ptm', 'ranking_score')):
    create_scatter_plot(results_df_annotated.dropna(subset=[INTACT_SCORE_COLUMN]), metric, INTACT_SCORE_COLUMN, ax=ax)

fig.suptitle('AlphaFold Metrics vs IntAct score, NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Each panel: (x metric, y metric, panel title), coloured by the IntAct score;
# rows without an IntAct score are dropped.
for ax, (x_metric, y_metric, panel_title) in zip(axes, (
    ('ranking_score', 'iptm', 'ranking_score vs iptm'),
    ('ranking_score', 'ptm', 'ranking_score vs ptm'),
    ('ptm', 'iptm', 'iptm vs ptm'),
)):
    create_scatter_plot_colour(results_df_annotated.dropna(subset=[INTACT_SCORE_COLUMN]), x_metric, y_metric, INTACT_SCORE_COLUMN, panel_title, ax=ax)

fig.suptitle('AlphaFold Metrics Colored by IntAct Score, NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()


In [None]:
# Mean of the STRING and IntAct scores per row; NaN whenever either score is missing.
results_df_annotated['avg_STRING_IntAct'] = (
    results_df_annotated[STRING_SCORE_COLUMN].add(results_df_annotated[INTACT_SCORE_COLUMN]).div(2)
)

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, (x_metric, y_metric, panel_title) in zip(axes, (
    ('ranking_score', 'iptm', 'ranking_score vs iptm'),
    ('ranking_score', 'ptm', 'ranking_score vs ptm'),
    ('ptm', 'iptm', 'iptm vs ptm'),
)):
    create_scatter_plot_colour(results_df_annotated.dropna(subset=['avg_STRING_IntAct']), x_metric, y_metric, 'avg_STRING_IntAct', panel_title, ax=ax)

fig.suptitle('AlphaFold Metrics Colored by average of STRING and IntAct score, NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
plt.show()

### ROC curves

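Creating ROC curves with STRING or IntAct alone as the ground truth does not really make sense. My understanding is that both databases contain only **positive** interaction candidates and assign each one a confidence score. Even a low score therefore still implies a relatively high probability of interaction, simply because the candidate is in the database. A meaningful ROC analysis needs a separate set of true negatives (e.g. a negatome).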

In [None]:
# Import the submodule explicitly: a bare `import sklearn` does not guarantee
# that `sklearn.metrics` is loaded. This line still binds the name `sklearn`.
import sklearn.metrics
from typing import List, Set, Tuple


def plot_roc(df: pd.DataFrame, pos: Set[str], neg: Set[str], param_name: str, min_val: float, max_val: float,
             sampling: int = 1000, direction: str = 'up', plot_title: str = '', ax=None) -> Tuple[List[float], List[float]]:
    """Calculate and plot the ROC curve based on a specified parameter.

    Sweeps a decision threshold over ``sampling + 1`` evenly spaced values from
    ``min_val`` to ``max_val``. At each threshold, every row of ``df`` is classified
    by comparing ``df[param_name]`` with the threshold. The True Positive Rate
    (TPR) and False Positive Rate (FPR) are then computed against the
    ground-truth sets ``pos`` and ``neg``.

    Rows whose ``param_name`` value is NaN fail both comparisons and are never
    classified. Ground-truth pair_ids that are absent from ``df`` are not counted.

    Args:
        df (pd.DataFrame): DataFrame containing the parameter to evaluate and a pair_id column.
        pos (Set[str]): Positive example pair_ids (ground truth positives).
        neg (Set[str]): Negative example pair_ids (ground truth negatives).
        param_name (str): Column in df used as the classification score.
        min_val (float): Smallest threshold evaluated.
        max_val (float): Largest threshold evaluated.
        sampling (int, optional): Number of intervals between min_val and max_val. Defaults to 1000.
        direction (str, optional): 'up' means values >= threshold are positive;
            'down' means values < threshold are positive. Defaults to 'up'.
        plot_title (str, optional): Title for the plot. If empty, a default title is used.
        ax (matplotlib.axes.Axes, optional): Axes to draw on. If None, a new
            figure is created, shown, and the sizes of pos/neg are printed.

    Returns:
        Tuple[List[float], List[float]]: The FPR and TPR values of the curve,
        one entry per threshold.

    Raises:
        ValueError: If sampling < 1, max_val <= min_val, or direction is
            neither 'up' nor 'down'.
    """
    # Validate everything before creating a figure or dividing by `sampling`.
    # Previously, sampling=0 raised ZeroDivisionError, and a bad direction was
    # only detected after a figure had been opened.
    if sampling < 1:
        raise ValueError("sampling must be a positive integer")
    if max_val <= min_val:
        raise ValueError("max_val must be greater than min_val")
    if direction not in ('up', 'down'):
        raise ValueError("direction must be either 'up' or 'down'")

    if ax is None:
        # Print the size of positive and negative sets for verification
        print(f"Number of positive examples: {len(pos)}")
        print(f"Number of negative examples: {len(neg)}")
        plt.figure(figsize=(8, 8))
        ax = plt.gca()
        show_plot = True
    else:
        show_plot = False

    step = (max_val - min_val) / sampling

    # Look the columns up once instead of on every threshold.
    values = df[param_name]
    pair_ids = df['pair_id']

    tpr_list: List[float] = []
    fpr_list: List[float] = []

    for i in range(sampling + 1):
        sep_val = min_val + (i * step)

        # NaN scores are False in both masks, so those rows land in neither set.
        at_or_above = values >= sep_val
        below = values < sep_val
        if direction == 'up':
            calc_pos = set(pair_ids[at_or_above])
            calc_neg = set(pair_ids[below])
        else:  # 'down'
            calc_pos = set(pair_ids[below])
            calc_neg = set(pair_ids[at_or_above])

        # Confusion matrix counts against the ground-truth sets.
        TP_num = len(calc_pos.intersection(pos))
        FP_num = len(calc_pos.intersection(neg))
        TN_num = len(calc_neg.intersection(neg))
        FN_num = len(calc_neg.intersection(pos))

        # max(..., 1) avoids division by zero when a class is empty.
        TPR = TP_num / max(TP_num + FN_num, 1)
        FPR = FP_num / max(FP_num + TN_num, 1)

        tpr_list.append(TPR)
        fpr_list.append(FPR)

    # Plot the ROC curve
    ax.plot(fpr_list, tpr_list, 'b-', linewidth=2)
    ax.plot([0, 1], [0, 1], 'k--', linewidth=2)  # Diagonal line representing random guess

    # FPR is monotonic in the threshold, which sklearn.metrics.auc requires.
    auc_value = sklearn.metrics.auc(fpr_list, tpr_list)

    # Add labels and title
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)

    title = plot_title if plot_title else f'ROC Curve for {param_name}'
    ax.set_title(f'{title}\nAUC = {auc_value:.3f}', fontsize=12)

    # Add grid and improve appearance
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])

    # Show plot only if not using subplots
    if show_plot:
        plt.tight_layout()
        plt.show()

    return fpr_list, tpr_list

In [None]:
# STRING_SCORE_CUTOFF = 0.4
# # comparisons like >=, < with NAs don't pass the filter
# string_combined_pos = set(results_df_annotated[results_df_annotated[STRING_SCORE_COLUMN] >= STRING_SCORE_CUTOFF]['pair_id'].to_list())
# string_combined_neg = set(results_df_annotated[results_df_annotated[STRING_SCORE_COLUMN] < STRING_SCORE_CUTOFF]['pair_id'].to_list())

# print(f"Number of positive examples: {len(string_combined_pos)}")
# print(f"Number of negative examples: {len(string_combined_neg)}")

# fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# plot_roc(results_df_annotated, string_combined_pos, string_combined_neg, 'ranking_score', 0, 1, plot_title='Ranking Score', ax=axes[0])
# plot_roc(results_df_annotated, string_combined_pos, string_combined_neg, 'iptm', 0, 1, plot_title='iPTM', ax=axes[1])
# plot_roc(results_df_annotated, string_combined_pos, string_combined_neg, 'ptm', 0, 1, plot_title='PTM', ax=axes[2])

# fig.suptitle(f'ROC Curves for AlphaFold Metrics (STRING cutoff >= {STRING_SCORE_CUTOFF})', fontsize=16)
# plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
# plt.show()

In [None]:
# INTACT_SCORE_CUTOFF = 0.4
# # comparisons like >=, < with NAs don't pass the filter
# intact_combined_pos = set(results_df_annotated[results_df_annotated[INTACT_SCORE_COLUMN] >= INTACT_SCORE_CUTOFF]['pair_id'].to_list())
# intact_combined_neg = set(results_df_annotated[results_df_annotated[INTACT_SCORE_COLUMN] < INTACT_SCORE_CUTOFF]['pair_id'].to_list())

# print(f"Number of positive examples: {len(intact_combined_pos)}")
# print(f"Number of negative examples: {len(intact_combined_neg)}")

# fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# plot_roc(results_df_annotated, intact_combined_pos, intact_combined_neg, 'ranking_score', 0, 1, plot_title='Ranking Score', ax=axes[0])
# plot_roc(results_df_annotated, intact_combined_pos, intact_combined_neg, 'iptm', 0, 1, plot_title='iPTM', ax=axes[1])
# plot_roc(results_df_annotated, intact_combined_pos, intact_combined_neg, 'ptm', 0, 1, plot_title='PTM', ax=axes[2])

# fig.suptitle(f'ROC Curves for AlphaFold Metrics (IntAct cutoff >= {INTACT_SCORE_CUTOFF})', fontsize=16)
# plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
# plt.show()