# 1. Imports

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from functions_filtering import *

import importlib
import constants
importlib.reload(constants)
from constants import *

import functions_analysis
import functions_job_creation
import functions_filtering
import functions_plotting
import functions_download
import functions_pdb2net

# Reload the module
importlib.reload(functions_analysis)
importlib.reload(functions_job_creation)
importlib.reload(functions_filtering)
importlib.reload(functions_download)
importlib.reload(functions_plotting)
importlib.reload(functions_pdb2net)

from functions_analysis import *
from functions_job_creation import *
from functions_download import *
from functions_filtering import *
from functions_plotting import *
from functions_pdb2net import *

# 2. Filter

In [None]:
# Load the combined, pre-processed PDB report table into `rep`.
# low_memory=False lets pandas infer each column's dtype from the whole file instead of
# per chunk, which avoids mixed-type columns in large CSVs.
rep = pd.read_csv('/home/markus/MPI_local/data/PDB_reports/4/combined_pdb_reports_processed.csv', low_memory=False)

In [None]:
# Remove rows whose 'Sequence' is missing (NaN); the length filter below measures it.
# Row counts are printed before and after so the loss is visible in the notebook output.
rows_before_seq_filter = len(rep)
print(f"Before removing NaN sequences: {rows_before_seq_filter}")
rep = rep.dropna(axis=0, subset=['Sequence'])
print(f"After removing NaN sequences: {len(rep)}")

In [None]:
# Length filter: keep only rows whose sequence has at least MIN_LENGTH characters.
MIN_LENGTH = 20

print(f"Before length filter: {len(rep)}")
seq_lengths = rep['Sequence'].fillna('').str.len()
rep['Sequence length'] = seq_lengths
rep = rep.loc[rep['Sequence length'].ge(MIN_LENGTH)]
print(f"After length filter: {len(rep)}")

In [None]:
# Run the IUPred3-based disorder annotation over `rep` and keep the returned frame.
# NOTE(review): add_iupred3 comes from one of the star-imported functions_* modules; it
# presumably adds the 'num_disordered_regions' column that the filters below read. The
# positional arguments ('long', 'no', cache dir, threshold, minimum disordered-region
# length, IUPred3 path) are interpreted there — confirm their meaning in that module.
rep = add_iupred3(rep, 'long', 'no', IUPRED_CACHE_DIR, IUPRED3_THRESHOLD, MIN_LENGTH_DISORDERED_REGION, IUPRED3_PATH, debug=False)

In [None]:
# filter out all PDBs that don't have at least 1 disordered chain
# A row counts as disordered when num_disordered_regions > 0; every row of an entry is kept
# if any of its rows qualifies. Vectorised instead of a Python-level iterrows() loop.
print(len(rep))
keep_pdbs = set(rep.loc[rep['num_disordered_regions'] > 0, 'Entry ID'])

rep = rep[rep['Entry ID'].isin(keep_pdbs)]

print(len(rep))

In [None]:
# filter out all PDBs that don't have at least 1 ordered chain
# A row counts as ordered when num_disordered_regions == 0; every row of an entry is kept
# if any of its rows qualifies. Vectorised instead of a Python-level iterrows() loop.
print(len(rep))
keep_pdbs = set(rep.loc[rep['num_disordered_regions'] == 0, 'Entry ID'])

rep = rep[rep['Entry ID'].isin(keep_pdbs)]

print(len(rep))

In [None]:
# write into directory for PDB2Net processing
PDB2NET_PREFIX = '/home/markus/PDB2Net/in/'
path_prefix = PDB2NET_PREFIX + 'pipeline2/'
# Each row points at '<path_prefix><lower-case Entry ID>.cif'. Built with assign() so the
# filtered `rep` is replaced by a new frame rather than written into a slice in place.
rep = rep.assign(model_path=rep['Entry ID'].apply(lambda entry_id: path_prefix + entry_id.lower() + '.cif'))

# PDB2Net input list: the unique model paths under a 'model_path' header, no index column.
rep['model_path'].drop_duplicates().to_csv(PDB2NET_PREFIX + 'pipeline2.csv', index=False)

# download pdb structures for pdb2net (one download per distinct entry)
download_pdb_structures(set(rep['Entry ID']), path_prefix, 'cif', '/home/markus/MPI_local/data/PDB', debug=False)

In [None]:
# annotate interface
# An interface is defined as at least INTERFACE_MIN_ATOMS atoms lying within
# INTERFACE_MAX_DISTANCE Å (Ångström) of each other across the two chains; both values
# are passed to get_interfaces_pdb2net.
INTERFACE_MIN_ATOMS = 10
INTERFACE_MAX_DISTANCE = 5 # cannot be raised without regenerating the PDB2Net data

In [None]:
# PDB2Net output folder for this pipeline run (timestamped directory name).
PATH = '/home/markus/MPI_local/data/PDB2Net/pipeline2/2025-08-28_19-38-42'
# Only extract interfaces for entries that survived the filters above.
surviving_entries = set(rep['Entry ID'])
interfaces_df = get_interfaces_pdb2net(PATH, INTERFACE_MIN_ATOMS, INTERFACE_MAX_DISTANCE, surviving_entries)

In [None]:
# Deduplicate interfaces based on UniProt IDs: rows whose 'Uniprot IDs' map to the same key
# under normalize_uniprot_pair are collapsed, keeping the first occurrence.
# NOTE(review): this runs before the ordered–disordered chain filter in the next cell, so a
# UniProt pair whose first occurrence fails that filter is lost even if a later duplicate
# would pass it — consider deduplicating after filtering.
print(f"Before deduplication: {len(interfaces_df)} interfaces")

pair_key = interfaces_df['Uniprot IDs'].apply(normalize_uniprot_pair)
interfaces_df = interfaces_df.loc[~pair_key.duplicated(keep='first').to_numpy()]

print(f"After deduplication: {len(interfaces_df)} interfaces")

In [None]:
# use only interactions between an ordered and a disordered chain
import ast

# Per-chain lookup table keyed by (Entry ID, Asym ID); sorted so MultiIndex lookups are fast.
rep_reindex = rep.set_index(['Entry ID', 'Asym ID']).sort_index()

print(len(interfaces_df))


def _is_ordered_disordered(row, chain_table):
    """Return True if the interface in *row* pairs an ordered chain with a disordered one.

    A chain is ordered when its 'num_disordered_regions' is 0 and disordered when it is >= 1.
    'Interface ID' holds the partner chain IDs as a Python-literal string (e.g. "['A', 'B']");
    the first two entries are used. Interfaces whose entry/chain is absent from *chain_table*
    are rejected.
    """
    chain_ids = ast.literal_eval(row['Interface ID'])
    chain_1, chain_2 = chain_ids[0], chain_ids[1]
    try:
        n_dis_1 = int(chain_table.loc[(row['Entry ID'], chain_1), 'num_disordered_regions'])
        n_dis_2 = int(chain_table.loc[(row['Entry ID'], chain_2), 'num_disordered_regions'])
    except KeyError:
        # Entry/chain not in the filtered report table: skip this interface.
        return False
    return (n_dis_1 == 0 and n_dis_2 >= 1) or (n_dis_2 == 0 and n_dis_1 >= 1)


# Build the mask first and filter once, rather than dropping rows (one full copy per drop)
# from the DataFrame while iterating over it.
keep = np.array(
    [_is_ordered_disordered(row, rep_reindex) for _, row in interfaces_df.iterrows()],
    dtype=bool,
)
interfaces_df = interfaces_df.loc[keep]

print(len(interfaces_df))

In [None]:
# annotate interfaces with data: each interface gets the first 'Release Date' listed for its
# PDB entry in `rep` (NaN if the entry is not present).
release_by_entry = rep.groupby('Entry ID')['Release Date'].first()
interfaces_df['Release Date'] = interfaces_df['Entry ID'].map(release_by_entry)

In [None]:
# Collect the 'Uniprot IDs' of every remaining interface and persist the list with IPython's
# %store magic, so another notebook can load it with `%store -r up_ids_structure_ds`.
up_ids_structure_ds = interfaces_df['Uniprot IDs'].tolist()
%store up_ids_structure_ds

# 3. Job Creation

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from functions_filtering import *

import importlib
import constants
importlib.reload(constants)
from constants import *

import functions_analysis
import functions_job_creation
import functions_filtering
import functions_plotting
import functions_download
import functions_pdb2net

# Reload the module
importlib.reload(functions_analysis)
importlib.reload(functions_job_creation)
importlib.reload(functions_filtering)
importlib.reload(functions_download)
importlib.reload(functions_plotting)
importlib.reload(functions_pdb2net)

from functions_analysis import *
from functions_job_creation import *
from functions_download import *
from functions_filtering import *
from functions_plotting import *
from functions_pdb2net import *

In [None]:
import os

# Root of the AlphaFold job batches; every existing sub-directory whose name contains
# 'batch' is handed to the job builder (presumably to avoid re-creating existing jobs — confirm).
BATCH_ROOT = '../../production1/PDB_modelling'
BATCH_DIRS = [
    os.path.join(BATCH_ROOT, d)
    for d in os.listdir(BATCH_ROOT)
    if 'batch' in d and os.path.isdir(os.path.join(BATCH_ROOT, d))
]

# create_job_batch_from_PDB_chains reads the chain pair from a column named 'Chains'.
interfaces_df.rename(columns={'Interface ID': 'Chains'}, inplace=True)
try:
    new_af_jobs, reference_jobs = create_job_batch_from_PDB_chains(interfaces_df, BATCH_DIRS, 5120)
    write_af_jobs_to_individual_files(new_af_jobs, os.path.join(BATCH_ROOT, 'batch_12'))
    write_af_jobs_to_individual_files(reference_jobs, os.path.join(BATCH_ROOT, 'reference'), 'alphafold3', True)
finally:
    # Restore the original column name even if job creation fails, so later cells still
    # find 'Interface ID'.
    interfaces_df.rename(columns={'Chains': 'Interface ID'}, inplace=True)

# 4. Analysis

In [None]:
# Restore canonical column names in case a previous (possibly interrupted) cell left them
# renamed: 'Chains' (job creation) -> 'Interface ID', 'pdb_id' (DockQ step) -> 'Entry ID'.
# Columns that are not present are left untouched.
interfaces_df.rename(columns={'Chains': 'Interface ID', 'pdb_id': 'Entry ID'}, inplace=True)

In [None]:
import ast

NATIVE_PATH_PREFIX = "/home/markus/MPI_local/data/PDB/"
HPC_FULL_RESULTS_DIR = "/home/markus/MPI_local/HPC_results_full"
PDB_CACHE = '../../production1/pdb_cache'
DOCKQ_CACHE = '../../production1/dockq_cache_p2'
JOB_DIR = '/home/markus/MPI_local/production1/PDB_modelling'

# Job names are '<Entry ID>_<chain>_<chain>'. 'Interface ID' stores the chain IDs as a
# Python-literal string (e.g. "['A', 'B']"), parsed safely with ast.literal_eval, not eval.
interfaces_df['job_name'] = interfaces_df.apply(
    lambda row: row['Entry ID'] + "_" + "_".join(ast.literal_eval(row['Interface ID'])), axis=1
)

# rename columns for compatibility with append_dockq_two_chainIDs(), which expects 'pdb_id';
# the finally clause restores 'Entry ID' even if the DockQ step raises.
interfaces_df.rename(columns={'Entry ID': 'pdb_id'}, inplace=True)
try:
    interfaces_df, no_model = append_dockq_two_chainIDs(interfaces_df, NATIVE_PATH_PREFIX, HPC_FULL_RESULTS_DIR, JOB_DIR, PDB_CACHE, DOCKQ_CACHE)
finally:
    interfaces_df.rename(columns={'pdb_id': 'Entry ID'}, inplace=True)

In [None]:
# Show the interfaces for which append_dockq_two_chainIDs reported no model (`no_model`),
# then search the production directory for their job files via find_job_files
# (presumably to diagnose missing or failed AlphaFold runs — confirm in functions_* module).
print(no_model)
find_job_files(no_model, '/home/markus/MPI_local/production1')

In [None]:
# Attach AlphaFold metrics from the HPC results directory (same path the DockQ cell uses).
interfaces_df = annotate_AF_metrics(interfaces_df, HPC_FULL_RESULTS_DIR)
# Compare as datetimes rather than strings: lexicographic comparison breaks as soon as the two
# sides use different formats (e.g. '2021-09-30T00:00:00' vs '2021-09-30'). NaN dates -> False.
interfaces_df['in_training_set'] = pd.to_datetime(interfaces_df['Release Date']) <= pd.to_datetime(AF_TRAINING_CUTOFF)

In [None]:
# ipTM and pTM vs DockQ, coloured by training-set membership, then ranking_score vs DockQ.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for ax, metric in zip(axes, ['iptm', 'ptm']):
    create_scatter_plot_colour(interfaces_df, metric, 'dockq_score', 'in_training_set', ax=ax, corr=True)
create_scatter_plot(interfaces_df, 'ranking_score', 'dockq_score', ax=axes[2], corr=True)

In [None]:
# Cross-chain PAE: take the two off-diagonal entries [0][1] and [1][0] of each
# chain_pair_pae_min matrix, then record their larger and smaller value.
_pae_offdiag = interfaces_df['chain_pair_pae_min'].map(lambda m: (m[0][1], m[1][0]))
interfaces_df['interface_pae_max'] = _pae_offdiag.map(max)
interfaces_df['interface_pae_min'] = _pae_offdiag.map(min)

In [None]:
# Interface PAE (max and min across directions) vs DockQ. Two panels only: the previous
# 1x3 grid left an empty third axis. figsize keeps the same width per panel.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
create_scatter_plot(interfaces_df, 'interface_pae_max', 'dockq_score', ax=axes[0], corr=True)
create_scatter_plot(interfaces_df, 'interface_pae_min', 'dockq_score', ax=axes[1], corr=True)

In [None]:
# Larger/smaller of the diagonal entries [0][0] and [1][1] of each chain_pair_iptm matrix.
# NOTE(review): AlphaFold 3 documents the diagonal of chain_pair_iptm as the per-chain pTM and
# the off-diagonal entries as pairwise ipTM; the PAE cell above uses off-diagonal entries —
# confirm the diagonal is intended here.
_iptm_diag = interfaces_df['chain_pair_iptm'].map(lambda m: (m[0][0], m[1][1]))
interfaces_df['chain_pair_iptm_max'] = _iptm_diag.map(max)
interfaces_df['chain_pair_iptm_min'] = _iptm_diag.map(min)


In [None]:
# chain_pair_iptm max/min vs DockQ. Two panels only: the previous 1x3 grid left an empty
# third axis. figsize keeps the same width per panel.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
create_scatter_plot(interfaces_df, 'chain_pair_iptm_max', 'dockq_score', ax=axes[0], corr=True)
create_scatter_plot(interfaces_df, 'chain_pair_iptm_min', 'dockq_score', ax=axes[1], corr=True)