## General Description

1. Create Datasets
2. Create AF jobs
3. Analyze AF output

## 1. Create Datasets

### Library imports and helper functions

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from functions_filtering import *

### Constants

In [None]:
# Absolute paths to the local input datasets used throughout this notebook.
# NOTE(review): these paths are machine-specific; adjust them when running elsewhere.
STRING_PATH = '/home/markus/MPI_local/data/STRING/9606.protein.physical.links.detailed.v12.0.txt_processed.csv'
PROTEOME_PATH = '/home/markus/MPI_local/data/Proteome/uniprotkb_proteome_UP000005640_2025_05_28.tsv'
# PROTEOME_PATH = '/home/markus/MPI_local/data/full_UP/uniprotkb_AND_reviewed_true_2025_07_10.tsv'
TF_DATASET_PATH = '/home/markus/MPI_local/data/human_TFs/DatabaseExtract_v_1.01.csv'
ENSEMBL_MAPPING_PATH = '/home/markus/MPI_local/data/Ensembl_mapping/Homo_sapiens.GRCh38.114.uniprot.tsv/hps/nobackup/flicek/ensembl/production/release_dumps/release-114/ftp_dumps/vertebrates/tsv/homo_sapiens/Homo_sapiens.GRCh38.114.uniprot.tsv'
AIUPRED_PATH = '/home/markus/MPI_local/data/AIUPred/AIUPred_data.json'
DISPROT_PATH = '/home/markus/MPI_local/data/DisProt/DisProt_release_2024_12 with_ambiguous_evidences.tsv'

# Shared project constants live in constants.py. Names used later but not defined in this
# notebook (e.g. AF_TOKEN_LIMIT, IUPRED3_THRESHOLD, arm_accs_pfam) presumably come from
# here. Reload first so edits to constants.py take effect without restarting the kernel.
# NOTE(review): the star import runs after the path assignments above, so any of those
# names that constants.py also defines will be overridden. Keep this order intentional.
import importlib
import constants
importlib.reload(constants)
from constants import *

### Read in Proteome

In [None]:
# Load the proteome export (tab-separated) and keep only the reviewed entries.
all_uniprot = pd.read_csv(PROTEOME_PATH, sep='\t', low_memory=False)
uniprot_filtered = all_uniprot.loc[all_uniprot['Reviewed'].eq('reviewed')]
# uniprot_filtered = all_uniprot

### Create Armadillo dataset

In [None]:
# UniProt stores Pfam accessions without the ".N" version suffix, so drop it from the
# reference list before matching.
arm_accs_pfam_stripped = [accession.partition(".")[0] for accession in arm_accs_pfam]

# Criterion 1: the free-text Repeat column mentions "ARM" (missing values never match).
repeat_mask_arm = uniprot_filtered['Repeat'].apply(lambda repeat: pd.notna(repeat) and "ARM" in str(repeat))
print(f"Proteins with 'ARM' in Repeat column: {int(repeat_mask_arm.sum())}")

# Criterion 2: any of the armadillo InterPro accessions is annotated.
interpro_mask_arm = uniprot_filtered['InterPro'].apply(lambda annotation: contains_any_annotation(annotation, arm_accs_ipr))
print(f"Proteins with specified InterPro annotation: {int(interpro_mask_arm.sum())}")

# Criterion 3: any of the (unversioned) armadillo Pfam accessions is annotated.
pfam_mask_arm = uniprot_filtered['Pfam'].apply(lambda annotation: contains_any_annotation(annotation, arm_accs_pfam_stripped))
print(f"Proteins with specified Pfam annotation: {int(pfam_mask_arm.sum())}")

# A protein counts as armadillo if it satisfies at least one criterion.
armadillo_proteins = uniprot_filtered.loc[repeat_mask_arm | interpro_mask_arm | pfam_mask_arm]

print(f"Found {len(armadillo_proteins)} proteins with armadillo domains")

### Create Transcription Factor dataset (UniProtFilter)

In [None]:
# def str_tf(x):
#     return "transcription factor" in x.lower()

# def str_t(x):
#     return "transcription" in x.lower()

# # IPR accessions containing "transcription"
# IPR_entries = pd.read_csv("../entry.list", sep="\t")
# tf_accs_ipr = IPR_entries[IPR_entries['ENTRY_NAME'].apply(lambda x: str_t(x))]["ENTRY_AC"].tolist()

# # PFAM accessions containing "transcription factor"
# PFAM_entries = pd.read_csv("../data/pfam_parsed_data.csv", sep=",")
# tf_accs_pfam = PFAM_entries[PFAM_entries['DE'].apply(lambda x: str_t(x))]["AC"].tolist()

# # strip version number since it is not included in the uniprot annotation
# for i in range(len(tf_accs_pfam)):
#     tf_accs_pfam[i] = tf_accs_pfam[i].split(".")[0]

# interpro_mask_tf = reviewed_proteins['InterPro'].apply(lambda x: contains_any_annotation(x, tf_accs_ipr))
# print(f"Proteins with specified InterPro annotation: {len(reviewed_proteins[interpro_mask_tf])}")

# pfam_mask_tf = reviewed_proteins['Pfam'].apply(lambda x: contains_any_annotation(x, tf_accs_pfam))
# print(f"Proteins with specified Pfam annotation: {len(reviewed_proteins[pfam_mask_tf])}")

# txt_mask_tf = reviewed_proteins['Protein names'].apply(lambda x: str_tf(x))
# print(f"Proteins with 'Transcription factor' in the name: {len(reviewed_proteins[txt_mask_tf])}")


# # Combine filters with OR operation
# tf_proteins_uniprot_ds = reviewed_proteins[interpro_mask_tf | pfam_mask_tf | txt_mask_tf]

# print(f"Found {len(tf_proteins_uniprot_ds)} proteins with transcription factor annotation")

### Use existing Transcription Factor dataset

In [None]:
# Curated human TF list; keep only rows annotated as genuine transcription factors.
human_TFs = pd.read_csv(TF_DATASET_PATH)
human_TFs = human_TFs.loc[human_TFs['Is TF?'].eq('Yes')]
human_TFs.shape[0]

In [None]:
# Ensembl gene -> UniProt cross-reference table, restricted to Swiss-Prot xrefs.
ensembl_mapping = pd.read_csv(ENSEMBL_MAPPING_PATH, sep='\t')
ensembl_mapping_swissProt = ensembl_mapping.loc[ensembl_mapping['db_name'].eq('Uniprot/SWISSPROT')]

In [None]:
# Map curated TF Ensembl gene IDs to Swiss-Prot accessions, then select those proteins
# from the reviewed proteome.
human_TFs_gids = human_TFs['Ensembl ID'].tolist()

# Use swiss prot accessions to prevent duplicates
# TODO: use caching
# TODO: adjust ensemble mapping if using other proteome?
# Set-based exact matching. The previous per-row any(...) scans were O(rows * ids)
# (about 4 minutes), and the Entry scan used substring matching (`id in x`), which could
# select an Entry that merely contains another accession.
human_TF_uniprot_accs = ensembl_mapping_swissProt.loc[
    ensembl_mapping_swissProt['gene_stable_id'].isin(set(human_TFs_gids)), 'xref'
].tolist()
print(len(human_TF_uniprot_accs))

tf_proteins_curated_ds = uniprot_filtered[uniprot_filtered['Entry'].isin(set(human_TF_uniprot_accs))]
print(len(tf_proteins_curated_ds))

#### IUPred 3

In [None]:
# Annotate the curated TFs with IUPred3 disorder predictions. The 'long' / 'no' options
# are passed through unchanged; see add_iupred3 for their meaning.
tf_proteins_curated_ds = add_iupred3(
    tf_proteins_curated_ds,
    'long',
    'no',
    IUPRED_CACHE_DIR,
    IUPRED3_THRESHOLD,
    MIN_LENGTH_DISORDERED_REGION,
    IUPRED3_PATH,
)

# Keep only TFs with at least one predicted disordered region.
tf_proteins_curated_ds_IUPred3_diso = tf_proteins_curated_ds.loc[tf_proteins_curated_ds['num_disordered_regions'].gt(0)]

print(f"Number of transcription factors with at least one disordered region (IUPred3, threshold={IUPRED3_THRESHOLD}, min length={MIN_LENGTH_DISORDERED_REGION}): {len(tf_proteins_curated_ds_IUPred3_diso)}")

#### DisProt

In [None]:
disprot_df = pd.read_csv(DISPROT_PATH, sep='\t')

# make format the same as in uniprot columns (the UniProt 'DisProt' column values carry a
# trailing ';'); vectorised instead of a per-row lambda
disprot_df['disprot_id'] = disprot_df['disprot_id'] + ';'

tf_disprot_ids = tf_proteins_curated_ds['DisProt'].dropna().tolist()

# Set-based membership instead of scanning the list once per DisProt row (was O(n * m)).
disprot_tfs = disprot_df[disprot_df['disprot_id'].isin(set(tf_disprot_ids))]

In [None]:
# Left-join DisProt records onto the curated TFs through the DisProt cross-reference.
# NOTE(review): if disprot_df holds several rows per disprot_id (e.g. one per annotated
# region), TF rows are duplicated by this join. Confirm that this is intended.
tf_proteins_curated_ds_disprot = pd.merge(
    tf_proteins_curated_ds,
    disprot_df,
    how='left',
    left_on='DisProt',
    right_on='disprot_id',
)

### Create all pairs

In [None]:
# Candidate pair table between armadillo proteins and disordered TFs
# (presumably their cross product; see create_all_pairs).
all_pairs = create_all_pairs(
    armadillo_proteins,
    tf_proteins_curated_ds_IUPred3_diso,
)

In [None]:
# Count candidate pairs whose combined sequence length exceeds the AlphaFold token limit.
num_pairs_over_token_limit = int((all_pairs['Length_arm'] + all_pairs['Length_tf'] > AF_TOKEN_LIMIT).sum())
# Guard against an empty pair table (previously a ZeroDivisionError).
pct_over_token_limit = (num_pairs_over_token_limit / len(all_pairs)) * 100 if len(all_pairs) else 0.0
print(f"Pairs over token limit: {num_pairs_over_token_limit} ({pct_over_token_limit}%)")

### STRING

In [None]:
# Load the pre-processed STRING dump (produced by the scripts in /src/STRING).
# Expected to provide the columns p1_Uniprot, p2_Uniprot and pair_id.
string_df = pd.read_csv(
    STRING_PATH,
    sep=',',
)

In [None]:
# annotate the all_pairs df with the STRING scores
# IMPORTANT: drop rows that don't have a matching STRING entry
all_pairs_w_STRING = pd.merge(all_pairs, string_df, on='pair_id', how='inner')

# Pairs without any STRING record (previously computed but never reported).
unmatched_pairs = all_pairs[~all_pairs['pair_id'].isin(all_pairs_w_STRING['pair_id'])]
num_all_pairs = len(all_pairs)
num_all_pairs_w_STRING = len(all_pairs_w_STRING)
# NOTE(review): this counts merged rows. If string_df repeats a pair_id, the count can
# exceed the number of distinct matched pairs.
# Guard against an empty pair table (previously a ZeroDivisionError).
pct_w_STRING = (num_all_pairs_w_STRING / num_all_pairs) * 100 if num_all_pairs else 0.0
print(f"Number of pairs in all_pairs: {num_all_pairs}")
print(f"Number of pairs successfully merged with STRING data: {num_all_pairs_w_STRING} ({pct_w_STRING}%)")
print(f"Number of pairs without a STRING entry: {len(unmatched_pairs)}")

### IntAct

In [None]:
from functions_intact import *

In [None]:
# Load and clean the human IntAct export (path is relative to the notebook directory).
intact_cleaned = read_clean_intact(
    '../../data/IntAct/human/human.txt',
)

In [None]:
# Inner join: keep only candidate pairs that have an IntAct record.
all_pairs_w_IntAct = pd.merge(all_pairs, intact_cleaned, on='pair_id', how='inner')

num_all_pairs = len(all_pairs)
num_all_pairs_w_IntAct = len(all_pairs_w_IntAct)
# Guard against an empty pair table (previously a ZeroDivisionError).
pct_w_IntAct = (num_all_pairs_w_IntAct / num_all_pairs) * 100 if num_all_pairs else 0.0
print(f"Number of pairs in all_pairs: {num_all_pairs}")
print(f"Number of pairs successfully merged with IntAct data: {num_all_pairs_w_IntAct} ({pct_w_IntAct}%)")
# Reported for consistency with the STRING cell.
print(f"Number of pairs without an IntAct entry: {int((~all_pairs['pair_id'].isin(all_pairs_w_IntAct['pair_id'])).sum())}")

In [None]:
# Pairs supported by both STRING and IntAct.
# NOTE(review): both inputs carry the all_pairs columns, so overlapping names get _x/_y suffixes.
all_pairs_intersect_STRING_IntAct = all_pairs_w_STRING.merge(all_pairs_w_IntAct, how='inner', on='pair_id')

print(all_pairs_intersect_STRING_IntAct.shape[0])

In [None]:
# Pairs supported by STRING or IntAct (outer join on pair_id).
# NOTE(review): shared all_pairs columns appear twice with _x/_y suffixes after this join.
all_pairs_union_STRING_IntAct = all_pairs_w_STRING.merge(all_pairs_w_IntAct, how='outer', on='pair_id')

print(all_pairs_union_STRING_IntAct.shape[0])

### Write to files

In [None]:
# for idx, row in armadillo_proteins.iterrows():
#     print_to_fasta(row['Entry'], row['Sequence'], '../../production1/arm_all_uniprot_rev_fasta', row['Reviewed'])
# for idx, row in tf_proteins_curated_ds_IUPred3_diso.iterrows():
#     print_to_fasta(row['Entry'], row['Sequence'], '../../production1/tf_all_uniprot_rev_fasta', row['Reviewed'])

In [None]:
# armadillo_proteins.to_csv('../../armadillo_proteins.csv', index=False)
# tf_proteins_curated_ds_IUPred3_diso.to_csv('../../transcription_factors.csv', index=False)

## 1.1 PDB reports

### Imports and functions

In [None]:
# Project helper modules used from here on.
import functions_analysis
import functions_job_creation
import functions_filtering
import functions_plotting
import download_functions
import importlib

# Reload each helper module so edits to its source file take effect without restarting
# the kernel. The order is kept as-is in case the modules import from each other.
importlib.reload(functions_analysis)
importlib.reload(functions_job_creation)
importlib.reload(functions_filtering)
importlib.reload(download_functions)
importlib.reload(functions_plotting)

# Re-import the star-exported names. reload() alone does not rebind names that were
# previously pulled into this namespace via `from ... import *`.
from functions_analysis import *
from functions_job_creation import *
from download_functions import *
from functions_filtering import *
from functions_plotting import *


In [None]:
def pdb_report_arm_filter(pdb_report: pd.DataFrame, armadillo_proteins: pd.DataFrame) -> pd.DataFrame:
    """Keep PDB entries that contain at least one armadillo (ARM) protein.

    Each row of ``pdb_report`` (presumably one polymer entity of a PDB entry) whose
    'Accession Code(s)' is one of ``armadillo_proteins['Entry']`` is flagged in a
    boolean ``isARM`` column. All rows of every 'Entry ID' with at least one flagged
    row are returned.

    Args:
        pdb_report: PDB report with 'Accession Code(s)' and 'Entry ID' columns.
        armadillo_proteins: proteins with an 'Entry' column of UniProt accessions.

    Returns:
        A new, filtered DataFrame including the ``isARM`` column. Unlike before, the
        caller's ``pdb_report`` is not modified. The old in-place assignment altered the
        caller's frame and raised SettingWithCopyWarning when given a filtered slice.
    """
    pdb_report = pdb_report.copy()
    # Set lookup instead of scanning a list once per row.
    armadillo_entries = set(armadillo_proteins['Entry'])
    accessions = pdb_report['Accession Code(s)']
    pdb_report['isARM'] = accessions.notna() & accessions.isin(armadillo_entries)

    keep_pdbs = set(pdb_report.loc[pdb_report['isARM'], 'Entry ID'])
    return pdb_report[pdb_report['Entry ID'].isin(keep_pdbs)]

def pdb_report_tf_filter(pdb_report: pd.DataFrame, tf_proteins: pd.DataFrame) -> pd.DataFrame:
    """Keep PDB entries that contain at least one protein from ``tf_proteins``.

    Each row whose 'Accession Code(s)' is one of ``tf_proteins['Entry']`` is flagged in a
    boolean ``isDisoTF`` column. All rows of every 'Entry ID' with at least one flagged
    row are returned.

    Args:
        pdb_report: PDB report with 'Accession Code(s)' and 'Entry ID' columns.
        tf_proteins: proteins (here: disordered TFs) with an 'Entry' column of accessions.

    Returns:
        A new, filtered DataFrame including the ``isDisoTF`` column. The caller's
        ``pdb_report`` is not modified.
    """
    pdb_report = pdb_report.copy()
    # Set lookup instead of scanning a list once per row.
    tf_entries = set(tf_proteins['Entry'])
    accessions = pdb_report['Accession Code(s)']
    pdb_report['isDisoTF'] = accessions.notna() & accessions.isin(tf_entries)

    keep_pdbs = set(pdb_report.loc[pdb_report['isDisoTF'], 'Entry ID'])
    return pdb_report[pdb_report['Entry ID'].isin(keep_pdbs)]

def pdb_report_disorder_filter(pdb_report: pd.DataFrame) -> pd.DataFrame:
    """Keep PDB entries with a non-ARM row that has at least one disordered region.

    An 'Entry ID' is retained (all of its rows) if any of its rows has ``isARM == False``
    and ``num_disordered_regions > 0``.
    """
    if pdb_report.empty:
        # nothing to inspect; return the (empty) selection
        return pdb_report[pdb_report['Entry ID'].isin(set())]

    non_arm_disordered = (pdb_report['isARM'] == False) & (pdb_report['num_disordered_regions'] > 0)  # noqa: E712
    wanted_ids = set(pdb_report.loc[non_arm_disordered, 'Entry ID'])
    return pdb_report[pdb_report['Entry ID'].isin(wanted_ids)]

In [None]:
def annotate_AF_metrics(report_df: pd.DataFrame, results_dir: str) -> pd.DataFrame:
    """Annotate a report DataFrame with AlphaFold metrics from job results.

    Job names in ``report_df['job_name']`` are lower-cased (missing values are left
    as-is). The frame is then inner-joined on 'job_name' with the cleaned summary table
    built from ``results_dir`` via ``find_summary_files`` and ``clean_results``.

    Args:
        report_df (pd.DataFrame): DataFrame to annotate; must contain a 'job_name' column.
        results_dir (str): Directory containing AlphaFold job results.

    Returns:
        pd.DataFrame: A new DataFrame with the columns of the cleaned results table added
        (later cells read 'iptm', 'ptm' and 'ranking_score' from it). Because the join is
        inner, rows without a matching job are dropped. ``report_df`` itself is no longer
        modified; previously its 'job_name' column was overwritten in place.
    """
    report_df = report_df.copy()
    report_df['job_name'] = report_df['job_name'].apply(lambda x: x if pd.isna(x) else str(x).lower())

    results_df = clean_results(pd.DataFrame(data=find_summary_files(results_dir)))

    return report_df.merge(results_df, on='job_name')

### report 1

In [None]:
# PDB report 1 (combined, pre-processed report CSV).
pdb_report_1 = pd.read_csv(
    '/home/markus/MPI_local/data/PDB_reports/1/combined_pdb_reports_processed.csv',
    low_memory=False,
)

In [None]:
# UniProt accessions present in PDB report 1 and in the disordered-TF dataset
pdb_entry_ids = set(pdb_report_1['Accession Code(s)'].dropna())
tf_entry_ids = set(tf_proteins_curated_ds_IUPred3_diso['Entry'])

# Find intersection
common_entries = pdb_entry_ids.intersection(tf_entry_ids)

# Labels now name the variables actually used (previously 'pdb_report_1_two_seq' and
# 'tf_proteins_curated_ds').
print(f"Number of Entry IDs in pdb_report_1: {len(pdb_entry_ids)}")
print(f"Number of Entry IDs in tf_proteins_curated_ds_IUPred3_diso: {len(tf_entry_ids)}")
print(f"Number of Entry IDs that appear in both datasets: {len(common_entries)}")
# Guard against an empty report (previously a ZeroDivisionError).
pct_common = len(common_entries) / len(pdb_entry_ids) * 100 if pdb_entry_ids else 0.0
print(f"Percentage of PDB entries that are also TFs: {pct_common:.2f}%")

In [None]:
# Restrict report 1 to PDB entries that contain at least one disordered TF
# (also adds the boolean 'isDisoTF' column).
pdb_report_1 = pdb_report_tf_filter(
    pdb_report_1,
    tf_proteins_curated_ds_IUPred3_diso,
)

In [None]:
# Further restrict to entries that also contain an armadillo protein
# (adds the boolean 'isARM' column).
pdb_report_1 = pdb_report_arm_filter(
    pdb_report_1,
    armadillo_proteins,
)
# IUPred3 annotation is currently disabled for report 1:
# pdb_report_1 = add_iupred3(pdb_report_1, 'long', 'no', IUPRED_CACHE_DIR, IUPRED3_THRESHOLD, MIN_LENGTH_DISORDERED_REGION, IUPRED3_PATH)

In [None]:
# Entries with exactly two rows. The working assumption is that one row is the ARM
# protein and the other the TF.
entry_counts = pdb_report_1['Entry ID'].value_counts()
entries_appearing_twice = entry_counts.loc[entry_counts.eq(2)].index.tolist()
pdb_report_1_seq2 = pdb_report_1.loc[pdb_report_1['Entry ID'].isin(entries_appearing_twice)]

In [None]:
# Drop PDB entries in which every row is flagged both as ARM and as disordered TF,
# i.e. keep an entry as soon as one of its rows is not flagged as both.
# NOTE(review): pdb_report_1_seq2 and entry_counts were derived before this filter and
# do not reflect it.
keep_pdbs = set(
    pdb_report_1.loc[
        ~((pdb_report_1['isARM'] == True) & (pdb_report_1['isDisoTF'] == True)),  # noqa: E712
        'Entry ID',
    ]
)

pdb_report_1 = pdb_report_1.loc[pdb_report_1['Entry ID'].isin(keep_pdbs)]

In [None]:
# now we have all pairs that have one ARM partner => the other protein must be the TF (candidate), since that was in the original search criteria
# now filter for disordered regions
# pdb_report_1 = pdb_report_disorder_filter(pdb_report_1_seq2)

In [None]:
# Fetch structure files for every remaining PDB entry of report 1.
download_pdb_structures(set(pdb_report_1['Entry ID']))

In [None]:
# Entries with exactly three rows (counts taken from entry_counts computed earlier).
entries_appearing_3 = entry_counts.loc[entry_counts.eq(3)].index.tolist()
pdb_report_1_seq3 = pdb_report_1.loc[pdb_report_1['Entry ID'].isin(entries_appearing_3)]

In [None]:
NATIVE_PATH_PREFIX = "/home/markus/MPI_local/data/PDB/"
HPC_FULL_RESULTS_DIR = "/home/markus/MPI_local/HPC_results_full"
PDB_CACHE = '../../production1/pdb_cache'
DOCKQ_CACHE = '../../production1/dockq_cache'

# Split the two-row entries into their ARM half and their TF half, keeping only entities
# that occur as a single chain.
report_1_arm = pdb_report_1_seq2.loc[
    pdb_report_1_seq2['isARM'].eq(True)
    & pdb_report_1_seq2['Total Number of polymer Entity Instances (Chains) per Entity'].eq(1)
]
report_1_tf = pdb_report_1_seq2.loc[
    pdb_report_1_seq2['isARM'].eq(False)
    & pdb_report_1_seq2['Total Number of polymer Entity Instances (Chains) per Entity'].eq(1)
]
# report_1_arm = pdb_report_1_seq2[(pdb_report_1_seq2['isARM'] == True)]
# report_1_tf = pdb_report_1_seq2[(pdb_report_1_seq2['isARM'] == False)]

# One row per PDB entry (TF columns suffixed _tf, ARM columns _arm). The DockQ helper
# expects the PDB identifier in a 'pdb_id' column.
report_1_pairs = pd.merge(
    report_1_tf, report_1_arm, on='Entry ID', suffixes=('_tf', '_arm')
).rename(columns={'Entry ID': 'pdb_id'})

# print_dockq(report_1_pairs, NATIVE_PATH_PREFIX, HPC_FULL_RESULTS_DIR, all_uniprot, 'pdb', PDB_CACHE, DOCKQ_CACHE)

report_1_pairs = append_dockq(report_1_pairs, NATIVE_PATH_PREFIX, HPC_FULL_RESULTS_DIR, all_uniprot, 'pdb', PDB_CACHE, DOCKQ_CACHE)

In [None]:
# Sort by DockQ (best first) and print one line per PDB pair.
report_1_pairs_sorted = report_1_pairs.sort_values('dockq_score', ascending=False)

print("PDB_ID\t\tDockQ\t\tRelease Date\t\tJob Name")
print("-" * 70)

for _, row in report_1_pairs_sorted.iterrows():
    pdb_id = row['pdb_id']
    dockq = row['dockq_score'] if pd.notna(row['dockq_score']) else 'N/A'
    # Both partners' release dates ("<tf> <arm>"). A missing date is shown as 'N/A';
    # previously NaN + str raised a TypeError.
    release_date = " ".join(
        str(row[col]) if pd.notna(row[col]) else 'N/A'
        for col in ('Release Date_tf', 'Release Date_arm')
    )
    job_name = row.get('job_name', 'N/A')  # Use get() in case column doesn't exist

    print(f"{pdb_id}\t\t{dockq}\t\t{release_date}\t\t{job_name}")

In [None]:
# Attach AlphaFold confidence metrics for the jobs matching each pair.
report_1_pairs = annotate_AF_metrics(
    report_1_pairs,
    '/home/markus/MPI_local/HPC_results_full',
)

In [None]:
# AlphaFold metric vs. DockQ (with correlation), one panel per metric.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
iptm_ax, ptm_ax, ranking_ax = axes
create_scatter_plot(report_1_pairs, 'iptm', 'dockq_score', ax=iptm_ax, corr=True)
create_scatter_plot(report_1_pairs, 'ptm', 'dockq_score', ax=ptm_ax, corr=True)
create_scatter_plot(report_1_pairs, 'ranking_score', 'dockq_score', ax=ranking_ax, corr=True)

### report 2

In [None]:
# PDB report 2 (combined, pre-processed report CSV).
pdb_report_2 = pd.read_csv(
    '/home/markus/MPI_local/data/PDB_reports/2/combined_pdb_reports_processed.csv',
    low_memory=False,
)

In [None]:
# Keep entries containing at least one armadillo protein (adds the 'isARM' column) ...
pdb_report_2 = pdb_report_arm_filter(
    pdb_report_2,
    armadillo_proteins,
)
# ... then annotate the remaining rows with IUPred3 disorder predictions.
pdb_report_2 = add_iupred3(
    pdb_report_2,
    'long',
    'no',
    IUPRED_CACHE_DIR,
    IUPRED3_THRESHOLD,
    MIN_LENGTH_DISORDERED_REGION,
    IUPRED3_PATH,
)

In [None]:
# Entries with exactly two rows. The working assumption is one ARM protein plus one
# partner.
entry_counts = pdb_report_2['Entry ID'].value_counts()
entries_appearing_twice = entry_counts.loc[entry_counts.eq(2)].index.tolist()
pdb_report_2_seq2 = pdb_report_2.loc[pdb_report_2['Entry ID'].isin(entries_appearing_twice)]

In [None]:
# Each remaining entry has an ARM partner; keep those whose non-ARM partner has at
# least one predicted disordered region.
pdb_report_2_seq2 = pdb_report_disorder_filter(
    pdb_report_2_seq2,
)

In [None]:
# Fetch structure files for every retained entry of report 2.
download_pdb_structures(set(pdb_report_2_seq2['Entry ID']))

In [None]:
NATIVE_PATH_PREFIX = "/home/markus/MPI_local/data/PDB/"
HPC_FULL_RESULTS_DIR = "/home/markus/MPI_local/HPC_results_full"
PDB_CACHE = '../../production1/pdb_cache'
DOCKQ_CACHE = '../../production1/dockq_cache'

# Single-chain entities only: the ARM half vs. the non-ARM ("diso") half of each entry.
report_2_arm = pdb_report_2_seq2.loc[
    pdb_report_2_seq2['isARM'].eq(True)
    & pdb_report_2_seq2['Total Number of polymer Entity Instances (Chains) per Entity'].eq(1)
]
report_2_diso = pdb_report_2_seq2.loc[
    pdb_report_2_seq2['isARM'].eq(False)
    & pdb_report_2_seq2['Total Number of polymer Entity Instances (Chains) per Entity'].eq(1)
]

report_2_pairs = pd.merge(report_2_diso, report_2_arm, on='Entry ID', suffixes=('_diso', '_arm'))
# The DockQ helper expects a 'pdb_id' column; restore 'Entry ID' afterwards.
report_2_pairs = report_2_pairs.rename(columns={'Entry ID': 'pdb_id'})
report_2_pairs = append_dockq(report_2_pairs, NATIVE_PATH_PREFIX, HPC_FULL_RESULTS_DIR, all_uniprot, 'pdb', PDB_CACHE, DOCKQ_CACHE)
report_2_pairs = report_2_pairs.rename(columns={'pdb_id': 'Entry ID'})

In [None]:
# Sort by DockQ (best first) and print one line per PDB pair.
report_2_pairs_sorted = report_2_pairs.sort_values('dockq_score', ascending=False)

print("PDB_ID\t\tDockQ\t\tRelease Date\t\tJob Name")
print("-" * 70)

for _, row in report_2_pairs_sorted.iterrows():
    pdb_id = row['Entry ID']
    dockq = row['dockq_score'] if pd.notna(row['dockq_score']) else 'N/A'
    release_date = row['Release Date_diso']
    # The header announces a Job Name column, which was previously never printed.
    job_name = row.get('job_name', 'N/A')  # Use get() in case column doesn't exist

    print(f"{pdb_id}\t\t{dockq}\t\t{release_date}\t\t{job_name}")

In [None]:
# Attach AlphaFold confidence metrics for the jobs matching each pair.
report_2_pairs = annotate_AF_metrics(
    report_2_pairs,
    '/home/markus/MPI_local/HPC_results_full',
)

In [None]:
# Flag pairs whose disordered-partner entry was released on/before the AF training cutoff.
# NOTE(review): plain comparison. This assumes 'Release Date_diso' and AF_TRAINING_CUTOFF
# share a comparable type/format (e.g. both ISO date strings); confirm.
report_2_pairs['in_training_set'] = report_2_pairs['Release Date_diso'].le(AF_TRAINING_CUTOFF)

In [None]:
# iptm / ptm vs. DockQ using 'in_training_set' as the colour column; ranking_score with correlation.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
iptm_ax, ptm_ax, ranking_ax = axes
create_scatter_plot_colour(report_2_pairs, 'iptm', 'dockq_score', 'in_training_set', ax=iptm_ax)
create_scatter_plot_colour(report_2_pairs, 'ptm', 'dockq_score', 'in_training_set', ax=ptm_ax)
create_scatter_plot(report_2_pairs, 'ranking_score', 'dockq_score', ax=ranking_ax, corr=True)

### job creation

In [None]:
# `os` is otherwise imported only in a later cell (Hadeer section), so running the
# notebook top-to-bottom relied on a star import happening to re-export it.
import os

# Existing batch directories (entries containing 'batch' under AF_job_batches and
# PDB_modelling); passed to the job builder, presumably so existing jobs are not recreated.
BATCH_DIRS = [
    os.path.join(root, d)
    for root in ('../../production1/AF_job_batches', '../../production1/PDB_modelling')
    for d in os.listdir(root)
    if os.path.isdir(os.path.join(root, d)) and 'batch' in d
]


# IMPORTANT: check for duplicates that were modelled in production1! They may need to be downloaded extra

job_list = create_job_list_from_filtered_report(pdb_report_2_seq2)
new_af_jobs = create_job_batch_sequences_dict(job_list, BATCH_DIRS, 12000)
write_af_jobs_to_individual_files(new_af_jobs, '../../production1/PDB_modelling/batch_7')

## 1.2 Hadeer approach

In [None]:
# TF/PDB table for the Hadeer approach (tab-separated despite the .csv extension).
hadeer_df = pd.read_csv(
    '/home/markus/MPI_local/downloads/transcription_factors_pdbs.csv',
    sep='\t',
)

In [None]:
import os
import glob
from pathlib import Path

def check_interface(pdb_id, chain_X, chain_Y, data_dir, min_atoms=10, max_distance=5) -> bool:
    """Check whether two chains of a PDB entry form an interface, using Gregor's data.

    Searches ``data_dir`` recursively for ``<PDB_ID>_detailed_interactions.csv`` (PDB ID
    upper-cased). It then counts the rows linking ``chain_X`` and ``chain_Y`` (in either
    Chain_A/Chain_B order) whose 'Distance' is at most ``max_distance``.

    Args:
        pdb_id: PDB identifier; matched case-insensitively against the file name.
        chain_X: chain ID of the first partner.
        chain_Y: chain ID of the second partner.
        data_dir: root directory holding the per-structure interaction CSVs.
        min_atoms: minimum number of qualifying rows (atom pairs) required.
        max_distance: distance cutoff (presumably in Angstrom; confirm against the data).

    Returns:
        bool: True if at least ``min_atoms`` qualifying rows exist. Otherwise False,
        including when no file or more than one file matches (an error is printed).
    """
    files = list(Path(data_dir).rglob(f"{str(pdb_id).upper()}_detailed_interactions.csv"))

    if len(files) > 1:
        print(f"ERROR: for {pdb_id}, multiple files were found.")
        return False
    if not files:
        print(f"ERROR: for {pdb_id}, no files were found.")
        return False

    df = pd.read_csv(files[0], sep=',')

    # Vectorised count of atom pairs between the two chains, in either orientation.
    forward = (df['Chain_A'] == chain_X) & (df['Chain_B'] == chain_Y)
    reverse = (df['Chain_A'] == chain_Y) & (df['Chain_B'] == chain_X)
    close = df['Distance'] <= max_distance
    atom_counter = int(((forward | reverse) & close).sum())

    return atom_counter >= min_atoms

In [None]:
# hadeer_df.drop(columns=['Protein.names', 'Gene.Names', 'Organism', 'Length', 'InterPro', 'Pfam', 'DisProt', 'STRING', 'IntAct', 'Ensembl', 'Repeat', 'HGNC'])
hadeer_df.rename(columns={'chain': 'chain_tf', 'Entry': 'Entry_tf'}, inplace=True)

# For every TF row, record which chains of the same PDB entry are armadillo proteins
# (chains_arm) and their accessions (Entrys_arm). The TF-ARM interface is checked in the
# next step.

hadeer_df['chains_arm'] = None
hadeer_df['Entrys_arm'] = None

# Build the lookup once. Previously armadillo_proteins['Entry'].tolist() was rebuilt and
# scanned for every chain of every row.
# NOTE(review): assumes the mapping values are hashable (e.g. accession strings or None).
armadillo_entry_set = set(armadillo_proteins['Entry'])

for ind, row in hadeer_df.iterrows():

    pdb_id = row['pdb']
    if pd.isna(pdb_id):
        continue

    # normalize pdb id and ensure chain is present
    pdb_id = str(pdb_id).upper().strip()
    chain_tf = row['chain_tf']
    if pd.isna(chain_tf):
        continue

    # chain id -> UniProt accession for this entry (presumably cached in mapping_cache)
    chain_to_uniprot = get_pdb_chains_to_uniprot(pdb_id, '../../production1/mapping_cache')

    arm_pairs = [(chain, acc) for chain, acc in chain_to_uniprot.items() if acc in armadillo_entry_set]

    if arm_pairs:
        hadeer_df.at[ind, 'chains_arm'] = [chain for chain, _ in arm_pairs]
        hadeer_df.at[ind, 'Entrys_arm'] = [acc for _, acc in arm_pairs]
    

In [None]:
# An interface needs at least this many atom pairs within the distance cutoff
# (check_interface's default max_distance is 5).
INTERFACE_MIN_ATOMS = 10

# Flag rows whose TF chain contacts at least one ARM chain of the same entry.
hadeer_df['arm_tf_interface'] = False
valid_rows = hadeer_df[hadeer_df['pdb'].notna() & hadeer_df['chain_tf'].notna() & hadeer_df['chains_arm'].notna()]

for ind in valid_rows.index:
    pdb_id = str(valid_rows.at[ind, 'pdb']).upper().strip()
    chain_tf = valid_rows.at[ind, 'chain_tf']
    chains_arm = valid_rows.at[ind, 'chains_arm']

    if any(
        check_interface(pdb_id, chain_tf, arm_chain, '/home/markus/MPI_local/data/PDB2Net', INTERFACE_MIN_ATOMS)
        for arm_chain in chains_arm
    ):
        hadeer_df.at[ind, 'arm_tf_interface'] = True

In [None]:
# Distinct PDB entries with a confirmed TF-ARM interface.
interface_pdbs = hadeer_df.loc[hadeer_df['arm_tf_interface'] == True, 'pdb'].unique()  # noqa: E712
print(len(interface_pdbs))
display(interface_pdbs)

In [None]:
# Existing batch directories (entries containing 'batch' under AF_job_batches and PDB_modelling).
BATCH_DIRS = [
    os.path.join(root, d)
    for root in ('../../production1/AF_job_batches', '../../production1/PDB_modelling')
    for d in os.listdir(root)
    if os.path.isdir(os.path.join(root, d)) and 'batch' in d
]

# One job set for every PDB entry with a confirmed TF-ARM interface.
interface_pdb_ids = hadeer_df.loc[hadeer_df['arm_tf_interface'] == True, 'pdb'].unique().tolist()  # noqa: E712
new_af_jobs = create_job_batch_from_PDB_IDs(interface_pdb_ids, BATCH_DIRS, 12000)
write_af_jobs_to_individual_files(new_af_jobs, '../../production1/PDB_modelling/batch_6')

## 2. Create AF job files
- Create job files for AlphaFold
- Don't create duplicate jobs

### create job files

In [None]:
# Score columns used when binning pairs by score (the commented-out
# create_job_batch_scoreCategories calls in the next cell) and as the y-axis of the
# section-3 comparison plots.
# NOTE(review): these are assigned before the star import below, so a same-named
# definition in functions_job_creation would override them.
STRING_SCORE_COLUMN = 'experimental'
INTACT_SCORE_COLUMN = 'intact_score'
from functions_job_creation import *

In [None]:
# Existing batch directories (entries containing 'batch' under AF_job_batches and
# PDB_modelling); passed to the job builders, presumably so existing jobs are not recreated.
BATCH_DIRS = [
    os.path.join(root, d)
    for root in ('../../production1/AF_job_batches', '../../production1/PDB_modelling')
    for d in os.listdir(root)
    if os.path.isdir(os.path.join(root, d)) and 'batch' in d
]
BATCH_SIZE = 2000

### STRING (disabled)
# The category order matters: categories listed later fill up the remaining jobs.
# NOTE(review): the old comment said (100,200) comes last, but in the list below it does not.
# categories = [(900,1000), (800,900), (700,800), (600,700), (500,600), (400,500), (300,400), (100,200), (0,100), (200,300)]
# new_af_jobs = create_job_batch_scoreCategories(all_pairs_w_STRING, BATCH_SIZE, categories, BATCH_DIRS, STRING_SCORE_COLUMN, AF_TOKEN_LIMIT)

### IntAct (disabled)
# categories = [(0.1,0.2), (0.9,1), (0.8,0.9), (0.7,0.8), (0.6,0.7), (0.5,0.6), (0.4,0.5), (0.2,0.3), (0.3,0.4)]
# new_af_jobs = create_job_batch_scoreCategories(all_pairs_w_IntAct, BATCH_SIZE, categories, BATCH_DIRS, INTACT_SCORE_COLUMN, AF_TOKEN_LIMIT)

### all pairs (active mode)
new_af_jobs = create_job_batch_all_pairs(all_pairs, BATCH_SIZE, BATCH_DIRS, AF_TOKEN_LIMIT)


### ID list
# NOTE(review): id_list_good repeats ("P04637", "A0A8I5KU01") and ("Q9UJU2", "A0A2R8YCH5")
# and is not read anywhere in this cell.
id_list_good = [
    ("Q13285", "A0A2R8YCH5"),
    ("P04637", "A0A8I5KU01"),
    ("P04637", "A0A8I5KU01"),
    ("Q9H3D4", "A0A8I5KU01"),
    ("Q8NHM5", "A1YPR0"),
    ("Q9UJU2", "A0A2R8YCH5"),
    ("Q9UJU2", "A0A2R8YCH5"),
    ("Q6SJ96", "O14981"),
    ("Q9NRY4", "O00750"),
    ("Q6ZRS2", "A0A8V8TQN3"),
]

# id_list_complex = [
#     ("Q03181", "Q9H3U1"),
#     ("P04637", "A0A994J4J0"),
#     ("Q6SJ96", "O14981"),
#     ("P49450", "O14981"),
#     ("P49450", "O14981"),
#     ("P49450", "O14981"),
#     ("P49450", "O14981"),
#     ("Q9UBG7", "A0A8I5KU01"),
#     ("P19838", "A0A1W2PRG6")
# ]

# missing = [('Q13285', 'Q6BTZ4'), ('P04637', 'Q9VL06'), ('P04637', 'Q9VL06'), ('Q9UIF8', 'Q54U63'), ('Q15047', 'Q54U63'), ('Q15047', 'Q5R881'), ('Q9H3D4', 'Q9VL06'), ('P49450', 'Q4WJI7'), ('P49450', 'Q4WJI7'), ('P49450', 'Q4WJI7'), ('P49450', 'Q4WJI7'), ('Q6ZRS2', 'P38811'), ('Q6ZRS2', 'P38811')]
# new_af_jobs = create_job_batch_id_list(all_pairs, missing, BATCH_DIRS, AF_TOKEN_LIMIT)


write_af_jobs_to_individual_files(new_af_jobs, '../../production1/AF_job_batches/batch_55')

## 3. Analyze AF results

### Constants

In [None]:
# Root directory scanned by find_summary_files for AlphaFold job outputs in section 3.
HPC_RESULT_DIR: str = "/home/markus/MPI_local/HPC_results"

### Helper functions

In [None]:
from functions_analysis import *

### Negatomes

#### IntAct Negatome

In [None]:
# intact_negative = pd.read_csv('../../data/IntAct/human/human_negative.txt', sep='\t')
# intact_negative.drop(['Alias(es) interactor A', 
#                      'Alias(es) interactor B', 
#                      'Interaction detection method(s)',
#                      'Publication 1st author(s)',
#                      'Publication Identifier(s)',
#                      'Taxid interactor A',
#                      'Taxid interactor B',
#                      'Biological role(s) interactor A',
#                      'Biological role(s) interactor B',
#                      'Experimental role(s) interactor A',
#                      'Experimental role(s) interactor B',
#                      'Type(s) interactor A',
#                      'Type(s) interactor B',
#                      'Xref(s) interactor A',
#                      'Xref(s) interactor B',
#                      'Interaction Xref(s)',
#                      'Annotation(s) interactor A',
#                      'Annotation(s) interactor B',
#                      'Interaction annotation(s)',
#                      'Host organism(s)',
#                      'Interaction parameter(s)',
#                      'Creation date',
#                      'Update date',
#                      'Checksum(s) interactor A',
#                      'Checksum(s) interactor B',
#                      'Interaction Checksum(s)',
#                      'Feature(s) interactor A',
#                      'Feature(s) interactor B',
#                      'Stoichiometry(s) interactor A',
#                      'Stoichiometry(s) interactor B',
#                      'Identification method participant A',
#                      'Identification method participant B',
#                      'Expansion method(s)'
#                      ], axis=1, inplace=True)
# intact_negative.loc[:, 'intact_score'] = intact_negative.loc[: , 'Confidence value(s)'].apply(intact_score_filter)
# intact_negative['pair_id'] = intact_negative.apply(lambda row: str(tuple(sorted([row['#ID(s) interactor A'].replace('uniprotkb:', '').split('-')[0], row['ID(s) interactor B'].replace('uniprotkb:', '').split('-')[0]]))), axis=1)
# intact_negative = intact_negative.sort_values('intact_score', ascending=False).drop_duplicates('pair_id', keep='first')
# all_pairs_intact_negative = pd.merge(all_pairs, intact_negative, on='pair_id', how='inner')
# print(len(all_pairs_intact_negative))

#### Blohm negatome2.0

In [None]:
# negatome2 = pd.read_csv('../../data/negatome2.0/combined.txt', sep='\t', names=['ID_1', 'ID_2'])
# negatome2['pair_id'] = negatome2.apply(lambda row: str(tuple(sorted([row['ID_1'].split('-')[0], row['ID_2'].split('-')[0]]))), axis=1)
# all_pairs_negatome2 = pd.merge(all_pairs, negatome2, on='pair_id', how='inner')
# print(len(all_pairs_negatome2))

#### Stelzl 2005 negatome

In [None]:
# stelzl_neg = pd.read_csv('../../data/16169070_neg.mitab', sep='\t')

### Read in HPC results

In [None]:
# Gather every AlphaFold summary file found below HPC_RESULT_DIR (uncleaned).
results_df_uc = pd.DataFrame(find_summary_files(HPC_RESULT_DIR))
print(f"Total jobs processed: {results_df_uc.shape[0]}")

# Derive the pair_id key used to join with STRING / IntAct.
results_df_uc['pair_id'] = results_df_uc.apply(create_pair_id, axis=1)

print(f"jobs before cleaning: {results_df_uc.shape[0]}")
results_df = clean_results(results_df_uc)
print(f"jobs after cleaning: {results_df.shape[0]}")

In [None]:
# Left-join STRING and IntAct evidence onto the modelled pairs (unmatched pairs get NaN).
results_df_annotated = (
    results_df
    .merge(string_df, on='pair_id', how='left')
    .merge(intact_cleaned, on='pair_id', how='left')
)

# Print information about the merged dataframe
print(f"Total number of modelled pairs: {len(results_df)}")
print(f"Total rows in merged_df: {len(results_df_annotated)}")
print(f"Rows with annotated data (STRING): {results_df_annotated['combined_score'].notna().sum()}")
print(f"Rows with annotated data (IntAct): {results_df_annotated['intact_score'].notna().sum()}")


# Rescale STRING scores linearly from 0-1000 to 0-1.
STRING_COLS = ['experimental', 'database', 'textmining', 'combined_score']
results_df_annotated[STRING_COLS] = results_df_annotated[STRING_COLS] / 1000

### Comparing AlphaFold ranking scores with STRING combined scores

In [None]:
from functions_plotting import *

In [None]:
# One figure per database: each AlphaFold metric plotted against that database's score,
# using only the pairs that actually have a score there.
for score_col, figure_title in [
    (STRING_SCORE_COLUMN, 'AlphaFold Metrics vs STRING score, NAs dropped'),
    (INTACT_SCORE_COLUMN, 'AlphaFold Metrics vs IntAct score, NAs dropped'),
]:
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    for ax, metric in zip(axes, ['iptm', 'ptm', 'ranking_score']):
        create_scatter_plot(results_df_annotated.dropna(subset=[score_col]), metric, score_col, ax=ax)

    fig.suptitle(figure_title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # reserve the top strip for the suptitle
    plt.show()

In [None]:
# Same comparison as above, but pairs absent from STRING are kept with a score of 0.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, metric in zip(axes, ['iptm', 'ptm', 'ranking_score']):
    create_scatter_plot(results_df_annotated.fillna({STRING_SCORE_COLUMN: 0}), metric, STRING_SCORE_COLUMN, ax=ax)

fig.suptitle('AlphaFold Metrics vs STRING score, NAs treated as 0', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # reserve the top strip for the suptitle
plt.show()

In [None]:
# Pairwise views of the AlphaFold metrics, coloured by the STRING score
# (pairs without a STRING score are dropped).
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, (x_col, y_col, panel_title) in zip(axes, [('ranking_score', 'iptm', 'ranking_score vs iptm'),
                                                  ('ranking_score', 'ptm', 'ranking_score vs ptm'),
                                                  ('ptm', 'iptm', 'iptm vs ptm')]):
    create_scatter_plot_colour(results_df_annotated.dropna(subset=[STRING_SCORE_COLUMN]),
                               x_col, y_col, STRING_SCORE_COLUMN, panel_title, ax=ax)

fig.suptitle('AlphaFold Metrics Colored by STRING Score (experiments), NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # reserve the top strip for the suptitle
plt.show()


In [None]:
# Pairwise views of the AlphaFold metrics coloured by the STRING score, keeping every
# modelled pair (rows without a STRING score are passed through unchanged).
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, (x_col, y_col, panel_title) in zip(axes, [('ranking_score', 'iptm', 'ranking_score vs iptm'),
                                                  ('ranking_score', 'ptm', 'ranking_score vs ptm'),
                                                  ('ptm', 'iptm', 'iptm vs ptm')]):
    create_scatter_plot_colour(results_df_annotated, x_col, y_col, STRING_SCORE_COLUMN, panel_title, ax=ax)

fig.suptitle('AlphaFold Metrics Colored by STRING Score (experiments); NAs included', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # reserve the top strip for the suptitle
plt.show()

In [None]:
# Each AlphaFold metric against the IntAct score, for pairs that have an IntAct score.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, metric in zip(axes, ['iptm', 'ptm', 'ranking_score']):
    create_scatter_plot(results_df_annotated.dropna(subset=[INTACT_SCORE_COLUMN]), metric, INTACT_SCORE_COLUMN, ax=ax)

fig.suptitle('AlphaFold Metrics vs IntAct score, NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # reserve the top strip for the suptitle
plt.show()

In [None]:
# Pairwise views of the AlphaFold metrics coloured by the IntAct score
# (pairs without an IntAct score are dropped).
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, (x_col, y_col, panel_title) in zip(axes, [('ranking_score', 'iptm', 'ranking_score vs iptm'),
                                                  ('ranking_score', 'ptm', 'ranking_score vs ptm'),
                                                  ('ptm', 'iptm', 'iptm vs ptm')]):
    create_scatter_plot_colour(results_df_annotated.dropna(subset=[INTACT_SCORE_COLUMN]),
                               x_col, y_col, INTACT_SCORE_COLUMN, panel_title, ax=ax)

fig.suptitle('AlphaFold Metrics Colored by IntAct Score, NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # reserve the top strip for the suptitle
plt.show()


In [None]:
# Plain arithmetic mean of the two scores: NaN whenever either score is missing,
# so only pairs present in both STRING and IntAct survive the dropna below.
string_part = results_df_annotated[STRING_SCORE_COLUMN]
intact_part = results_df_annotated[INTACT_SCORE_COLUMN]
results_df_annotated['avg_STRING_IntAct'] = (string_part + intact_part) / 2

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, (x_col, y_col, panel_title) in zip(axes, [('ranking_score', 'iptm', 'ranking_score vs iptm'),
                                                  ('ranking_score', 'ptm', 'ranking_score vs ptm'),
                                                  ('ptm', 'iptm', 'iptm vs ptm')]):
    create_scatter_plot_colour(results_df_annotated.dropna(subset=['avg_STRING_IntAct']),
                               x_col, y_col, 'avg_STRING_IntAct', panel_title, ax=ax)

fig.suptitle('AlphaFold Metrics Colored by average of STRING and IntAct score, NAs dropped', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])  # reserve the top strip for the suptitle
plt.show()

### ROC curves

Creating ROC curves (and their AUC) with STRING / IntAct alone does not really make sense. My understanding is that both databases contain only **positive** interaction candidates and assign each one a confidence score. So even a low score implies a relatively high probability of interaction, simply because the candidate is in the database; a genuine negative set (e.g. a negatome) is needed for a meaningful ROC analysis.

In [None]:
import sklearn

def _trapezoid_auc(x: list[float], y: list[float]) -> float:
    """Return the area under the piecewise-linear curve through the points (x[i], y[i]).

    Matches ``sklearn.metrics.auc`` for a monotonic ``x`` in either direction when ``y`` is
    non-negative: a decreasing ``x`` yields a negative signed trapezoid sum, so its
    absolute value is returned.
    """
    area = sum((x1 - x0) * (y0 + y1) / 2 for x0, x1, y0, y1 in zip(x, x[1:], y, y[1:]))
    return abs(area)


def plot_roc(df: pd.DataFrame, pos: set[str], neg: set[str], param_name: str, min_val: float, max_val: float,
             sampling: int = 1000, direction: str = 'up', plot_title: str = '', ax=None) -> tuple[list[float], list[float]]:
    """Calculate and plot the ROC curve obtained by sweeping a threshold over one score column.

    For each of ``sampling + 1`` evenly spaced thresholds from ``min_val`` to ``max_val``
    (inclusive), the pair_ids in ``df`` are split into predicted positives and predicted
    negatives. That split is compared with the ground-truth sets to give one
    (FPR, TPR) point.

    Args:
        df (pd.DataFrame): DataFrame containing the ``param_name`` column and a ``pair_id`` column.
        pos (set[str]): Ground-truth positive pair_ids (any iterable of pair_ids is accepted).
        neg (set[str]): Ground-truth negative pair_ids (any iterable of pair_ids is accepted).
        param_name (str): Column of ``df`` used as the classification score.
        min_val (float): Smallest threshold evaluated.
        max_val (float): Largest threshold evaluated; must be greater than ``min_val``.
        sampling (int, optional): Number of intervals between min_val and max_val (>= 1). Defaults to 1000.
        direction (str, optional): 'up' means values >= threshold are predicted positive,
            'down' means values < threshold are predicted positive. Defaults to 'up'.
        plot_title (str, optional): Plot title; defaults to "ROC Curve for <param_name>".
        ax (matplotlib.axes.Axes, optional): Axes to draw on. If None, a new 8x8 figure is
            created and shown, and the sizes of ``pos``/``neg`` are printed.

    Returns:
        tuple[list[float], list[float]]: (fpr_list, tpr_list), one entry per threshold in
        increasing threshold order.

    Raises:
        ValueError: if ``direction`` is not 'up' or 'down', ``sampling`` < 1, or
            ``max_val`` is not greater than ``min_val``.

    Notes:
        * Rows whose score is NaN fall on neither side of any threshold and are ignored.
        * pair_ids found in neither ``pos`` nor ``neg`` do not affect the rates.
        * If [min_val, max_val] does not cover every score, the curve may not reach (0, 0)
          and/or (1, 1), and the reported AUC then only covers the swept part of the curve.
    """
    # Validate everything before creating a figure so a bad call leaves no empty plot behind.
    if direction not in ('up', 'down'):
        raise ValueError("direction must be either 'up' or 'down'")
    if sampling < 1:
        raise ValueError("sampling must be at least 1")
    if not max_val > min_val:  # written this way so NaN bounds are rejected too
        raise ValueError("max_val must be greater than min_val")

    if ax is None:
        # Print the size of positive and negative sets for verification
        print(f"Number of positive examples: {len(pos)}")
        print(f"Number of negative examples: {len(neg)}")
        plt.figure(figsize=(8, 8))
        ax = plt.gca()
        show_plot = True
    else:
        show_plot = False

    step = (max_val - min_val) / sampling
    scores = df[param_name]
    pair_ids = df['pair_id']

    tpr_list: list[float] = []
    fpr_list: list[float] = []

    for i in range(sampling + 1):
        sep_val = min_val + i * step

        # NaN scores compare False both ways, so those rows land in neither set.
        at_or_above = set(pair_ids[scores >= sep_val].tolist())
        below = set(pair_ids[scores < sep_val].tolist())
        calc_pos, calc_neg = (at_or_above, below) if direction == 'up' else (below, at_or_above)

        # Confusion-matrix counts against the ground truth
        tp_num = len(calc_pos.intersection(pos))
        fp_num = len(calc_pos.intersection(neg))
        tn_num = len(calc_neg.intersection(neg))
        fn_num = len(calc_neg.intersection(pos))

        # max(..., 1) avoids division by zero when a ground-truth class is absent from df
        tpr_list.append(tp_num / max(tp_num + fn_num, 1))
        fpr_list.append(fp_num / max(fp_num + tn_num, 1))

    # Plot the ROC curve together with the random-guess diagonal
    ax.plot(fpr_list, tpr_list, 'b-', linewidth=2)
    ax.plot([0, 1], [0, 1], 'k--', linewidth=2)

    # FPR is monotonic in the threshold, so the trapezoid rule gives the ROC AUC directly.
    auc_value = _trapezoid_auc(fpr_list, tpr_list)

    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)

    title = plot_title if plot_title else f'ROC Curve for {param_name}'
    ax.set_title(f'{title}\nAUC = {auc_value:.3f}', fontsize=12)

    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])

    # Only show the figure this call created; callers using subplots show it themselves.
    if show_plot:
        plt.tight_layout()
        plt.show()

    return fpr_list, tpr_list

In [None]:
# STRING_SCORE_CUTOFF = 0.4
# # comparisons like >=, < with NAs don't pass the filter
# string_combined_pos = set(results_df_annotated[results_df_annotated[STRING_SCORE_COLUMN] >= STRING_SCORE_CUTOFF]['pair_id'].to_list())
# string_combined_neg = set(results_df_annotated[results_df_annotated[STRING_SCORE_COLUMN] < STRING_SCORE_CUTOFF]['pair_id'].to_list())

# print(f"Number of positive examples: {len(string_combined_pos)}")
# print(f"Number of negative examples: {len(string_combined_neg)}")

# fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# plot_roc(results_df_annotated, string_combined_pos, string_combined_neg, 'ranking_score', 0, 1, plot_title='Ranking Score', ax=axes[0])
# plot_roc(results_df_annotated, string_combined_pos, string_combined_neg, 'iptm', 0, 1, plot_title='iPTM', ax=axes[1])
# plot_roc(results_df_annotated, string_combined_pos, string_combined_neg, 'ptm', 0, 1, plot_title='PTM', ax=axes[2])

# fig.suptitle(f'ROC Curves for AlphaFold Metrics (STRING cutoff >= {STRING_SCORE_CUTOFF})', fontsize=16)
# plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
# plt.show()

In [None]:
# INTACT_SCORE_CUTOFF = 0.4
# # comparisons like >=, < with NAs don't pass the filter
# intact_combined_pos = set(results_df_annotated[results_df_annotated[INTACT_SCORE_COLUMN] >= INTACT_SCORE_CUTOFF]['pair_id'].to_list())
# intact_combined_neg = set(results_df_annotated[results_df_annotated[INTACT_SCORE_COLUMN] < INTACT_SCORE_CUTOFF]['pair_id'].to_list())

# print(f"Number of positive examples: {len(intact_combined_pos)}")
# print(f"Number of negative examples: {len(intact_combined_neg)}")

# fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# plot_roc(results_df_annotated, intact_combined_pos, intact_combined_neg, 'ranking_score', 0, 1, plot_title='Ranking Score', ax=axes[0])
# plot_roc(results_df_annotated, intact_combined_pos, intact_combined_neg, 'iptm', 0, 1, plot_title='iPTM', ax=axes[1])
# plot_roc(results_df_annotated, intact_combined_pos, intact_combined_neg, 'ptm', 0, 1, plot_title='PTM', ax=axes[2])

# fig.suptitle(f'ROC Curves for AlphaFold Metrics (IntAct cutoff >= {INTACT_SCORE_CUTOFF})', fontsize=16)
# plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for subtitle
# plt.show()

## 4. Structure Dataset analysis

In [None]:
import glob
from functions_blastp import *
from download_functions import *
from functions_job_creation import *

In [None]:
def _review_pair_id(row):
    """Return the canonical pair_id string for a review row: the two upper-cased query
    accessions (text before the first '|') sorted and rendered as a tuple string."""
    accessions = sorted(row[col].split('|')[0].upper() for col in ('query_tf', 'query_arm'))
    return str(tuple(accessions))


# Every manual structure-review CSV is stacked into one table.
review_files = glob.glob('/home/markus/MPI_local/production1/structure_reviews/*.csv')
# review_files = glob.glob('/home/markus/MPI_local/production1/structure_reviews/intersect_df - set3.csv')
struct_ds = (
    pd.concat(map(pd.read_csv, review_files), ignore_index=True)
    .drop_duplicates(subset=['pdb_id', 'query_tf', 'query_arm', 'chain_tf', 'chain_arm'])
)
struct_ds['pair_id'] = struct_ds.apply(_review_pair_id, axis=1)

In [None]:
# Write one AlphaFold job JSON per distinct PDB entry in the structure dataset;
# entries for which no sequences could be downloaded are reported and skipped.
pdb_ids = struct_ds['pdb_id'].drop_duplicates().to_list()
for pdb_id in pdb_ids:
    sequences = download_pdb_sequence(pdb_id)
    if len(sequences) != 0:
        job = create_alphafold_job_ms(pdb_id, sequences)
        file_path = os.path.join('/home/markus/MPI_local/production1/PDB_modelling/', f"{pdb_id}.json")
        with open(file_path, 'w') as handle:
            json.dump(job, handle, indent=2)
    else:
        print(f"Empty sequence list: {pdb_id}")

In [None]:
# Remove the per-chain BLAST statistics (and stray index columns) carried over from the
# review CSVs; the BLAST columns are re-attached from the fresh BLAST results further down.
struct_ds = struct_ds.drop(
    labels=[
        "chain_tf", "chain_arm", "%identity_tf", "%identity_arm",
        "evalue_tf", "bit score_tf", "evalue_arm", "subject_tf",
        "bit score_arm", "alignment length_tf", "mismatches_tf", "gap opens_tf", "q. start_tf",
        "q. end_tf", "s. start_tf", "s. end_tf", "subject_arm", "alignment length_arm",
        "mismatches_arm", "gap opens_arm", "q. start_arm", "q. end_arm", "s. start_arm", "s. end_arm",
        "Unnamed: 12", "Unnamed: 5",
    ],
    axis='columns',
)

In [None]:
# Filter cutoffs forwarded to clean_blastp_out; False presumably disables a filter (TODO confirm).
BLAST_IDENTITY_CUTOFF: int | bool = False
BLAST_SCORE_CUTOFF: int | bool = False
BLAST_EVALUE_CUTOFF: float | bool = 0.00001
BLAST_COVERAGE_CUTOFF: float | bool = 0.5
TF_OUTPUT_DIR = "/home/markus/MPI_local/production1/blastp_results/tf_blastp_no_e_lim_new_fmt"
# NOTE(review): ARM_OUTPUT_DIR is byte-identical to TF_OUTPUT_DIR (both ".../tf_blastp_..."),
# so arm_blast_df is built from the TF BLAST output. Looks like a copy-paste slip - confirm
# the intended ARM results directory.
ARM_OUTPUT_DIR = "/home/markus/MPI_local/production1/blastp_results/tf_blastp_no_e_lim_new_fmt"


# Column names for the tabular BLAST output read by read_blast_to_df
columns: List[str] = [
    "query", "subject", "%identity", "alignment length", "mismatches", "gap opens",
    "q. start", "q. end", "s. start", "s. end", "evalue", "bit score",
    "% query coverage per subject", "% query coverage per hsp", "% query coverage per uniq subject",
]

# Both result sets are filtered with the same cutoffs.
blast_filters = dict(
    identity_cutoff=BLAST_IDENTITY_CUTOFF,
    score_cutoff=BLAST_SCORE_CUTOFF,
    evalue_cutoff=BLAST_EVALUE_CUTOFF,
    coverage_cutoff=BLAST_COVERAGE_CUTOFF,
)

tf_blast_df: pd.DataFrame = clean_blastp_out(read_blast_to_df(TF_OUTPUT_DIR, columns), **blast_filters)
print(len(tf_blast_df))
arm_blast_df: pd.DataFrame = clean_blastp_out(read_blast_to_df(ARM_OUTPUT_DIR, columns), **blast_filters)
print(len(arm_blast_df))

In [None]:
# Left-join the cleaned BLAST results for the ARM and TF queries; the suffixes keep the
# two sides' columns apart, and rows without a BLAST match keep NaN.
struct_ds = (
    struct_ds
    .merge(arm_blast_df.add_suffix('_arm'), how='left', left_on='query_arm', right_on='uniprot_id_arm')
    .merge(tf_blast_df.add_suffix('_tf'), how='left', left_on='query_tf', right_on='uniprot_id_tf')
)

In [None]:
# Keep only structures whose pair_id is one of the modelled pairs, printing the row count before and after.
print(len(struct_ds))
struct_ds = struct_ds.loc[struct_ds['pair_id'].isin(all_pairs['pair_id'])]
print(len(struct_ds))

In [None]:
# Keep reviewed structures whose comment says "looks good" or "in complex"
# (case-insensitive; rows with no comment do not match).
good_structs = all_pairs_struct_ds_annotated[
    all_pairs_struct_ds_annotated['comment'].str.contains('looks good|in complex', case=False, na=False)
]

# Distinct PDB IDs among the accepted structures
unique_pdb_ids = good_structs['pdb_id'].unique()
print(f"Found {len(unique_pdb_ids)} unique PDB IDs to download")

pdb_download_dir = "/home/markus/MPI_local/data/PDB"

# download_pdb_structure handles already-present files itself; a truthy return counts as success.
downloaded_count, failed_count = 0, 0
for pdb_id in unique_pdb_ids:
    if not pd.isna(pdb_id):  # NaN IDs are neither downloaded nor counted
        result = download_pdb_structure(pdb_id, pdb_download_dir)
        if result:
            downloaded_count += 1
        else:
            failed_count += 1

print("\nSummary:")
print(f"Successfully processed: {downloaded_count}")
print(f"Failed downloads: {failed_count}")
# every non-NaN ID increments exactly one of the two counters
print(f"Total processed: {downloaded_count + failed_count}")

In [None]:
# Full AlphaFold output tree and the directory holding the downloaded PDB files
# (the same location the download cell writes to).
HPC_FULL_RESULTS_DIR = "/home/markus/MPI_local/HPC_results_full"
NATIVE_PATH_PREFIX = "/home/markus/MPI_local/data/PDB/"

# append_dockq (functions_filtering.py) adds DockQ scores to good_structs.
good_structs = append_dockq(good_structs, NATIVE_PATH_PREFIX, HPC_FULL_RESULTS_DIR, all_uniprot)