# 1. Imports

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from functions_filtering import *

import importlib
import constants
importlib.reload(constants)
from constants import *

import functions_analysis
import functions_job_creation
import functions_filtering
import functions_plotting
import functions_download
import functions_pdb2net

# Reload the module
importlib.reload(functions_analysis)
importlib.reload(functions_job_creation)
importlib.reload(functions_filtering)
importlib.reload(functions_download)
importlib.reload(functions_plotting)
importlib.reload(functions_pdb2net)

from functions_analysis import *
from functions_job_creation import *
from functions_download import *
from functions_filtering import *
from functions_plotting import *
from functions_pdb2net import *

# 2. Filter

In [None]:
# Load the processed, combined PDB report table; every later filtering cell
# narrows this `rep` frame further.
rep = pd.read_csv(
    '/home/markus/MPI_local/data/PDB_reports/4/combined_pdb_reports_processed.csv',
    low_memory=False,
)

In [None]:
# Discard every row whose 'Sequence' field is missing and report the row
# count before and after.
print(f"Before removing NaN sequences: {rep.shape[0]}")
rep = rep.dropna(axis='index', subset=['Sequence'])
print(f"After removing NaN sequences: {rep.shape[0]}")

In [None]:
# Length filter: keep only rows whose sequence has at least MIN_LENGTH characters.
MIN_LENGTH = 20

print(f"Before length filter: {len(rep)}")
# Missing sequences count as length 0, so they are removed by the filter too.
sequence_lengths = rep['Sequence'].fillna('').str.len()
# `assign` returns a new frame instead of writing a column into `rep`, which
# may itself be a filtered result of an earlier cell (chained-assignment warning).
rep = rep.assign(**{'Sequence length': sequence_lengths})
# Explicit copy so that later cells can add columns to `rep` without warnings.
rep = rep[rep['Sequence length'] >= MIN_LENGTH].copy()
print(f"After length filter: {len(rep)}")

In [None]:
# Annotate `rep` via add_iupred3; later cells read the 'num_disordered_regions'
# column (presumably added by this call - confirm in functions_filtering).
rep = add_iupred3(
    rep,
    'long',
    'no',
    IUPRED_CACHE_DIR,
    IUPRED3_THRESHOLD,
    MIN_LENGTH_DISORDERED_REGION,
    IUPRED3_PATH,
    debug=False,
)

In [None]:
# Keep only PDB entries that contain at least one chain with a disordered
# region (num_disordered_regions > 0). All rows of a kept entry are retained.
print(len(rep))

# Vectorized replacement for a row-by-row iterrows() loop: select the entry IDs
# of all disordered rows in one pass.
has_disordered_region = rep['num_disordered_regions'] > 0
keep_pdbs = set(rep.loc[has_disordered_region, 'Entry ID'])

rep = rep[rep['Entry ID'].isin(keep_pdbs)]

print(len(rep))

In [None]:
# Keep only PDB entries that contain at least one chain without any
# disordered region (num_disordered_regions == 0). All rows of a kept entry
# are retained.
print(len(rep))

# Vectorized replacement for a row-by-row iterrows() loop: select the entry IDs
# of all fully ordered rows in one pass.
is_ordered = rep['num_disordered_regions'] == 0
keep_pdbs = set(rep.loc[is_ordered, 'Entry ID'])

rep = rep[rep['Entry ID'].isin(keep_pdbs)]

print(len(rep))

In [None]:
# Prepare the PDB2Net input: one lower-cased '<entry>.cif' path per row,
# written (de-duplicated) to pipeline2.csv.
PDB2NET_PREFIX = '/home/markus/PDB2Net/in/'
path_prefix = PDB2NET_PREFIX + 'pipeline2/'
rep['model_path'] = rep['Entry ID'].apply(
    lambda entry_id: path_prefix + entry_id.lower() + '.cif'
)

# Each structure file is listed once, in order of first appearance.
unique_model_paths = rep['model_path'].drop_duplicates()
unique_model_paths.to_csv(PDB2NET_PREFIX + 'pipeline2.csv', index=False)

# Fetch the structures for PDB2Net.
entry_ids = set(rep['Entry ID'])
download_pdb_structures(
    entry_ids,
    path_prefix,
    'cif',
    '/home/markus/MPI_local/data/PDB',
    debug=False,
)

In [None]:
# Interface annotation parameters.
# Two chains form an interface when at least INTERFACE_MIN_ATOMS atoms lie
# within INTERFACE_MAX_DISTANCE angstrom of each other. Both values are passed
# to get_interfaces_pdb2net in the next cell.
INTERFACE_MIN_ATOMS = 10
INTERFACE_MAX_DISTANCE = 5 # cannot be raised without regenerating the PDB2Net data

In [None]:
# Directory of the PDB2Net run whose output is read for the interfaces.
PATH = '/home/markus/MPI_local/data/PDB2Net/pipeline2/2025-08-28_19-38-42'
# The set of entry IDs still present in `rep` is passed as the last argument.
interfaces_df = get_interfaces_pdb2net(
    PATH,
    INTERFACE_MIN_ATOMS,
    INTERFACE_MAX_DISTANCE,
    set(rep['Entry ID']),
)

In [None]:
# Collapse interfaces that share the same normalized UniProt pair, keeping the
# first occurrence of each pair.
print(f"Before deduplication: {len(interfaces_df)} interfaces")

# The helper column only exists inside this chain and is removed again.
interfaces_df = (
    interfaces_df
    .assign(normalized_uniprot=interfaces_df['Uniprot IDs'].apply(normalize_uniprot_pair))
    .drop_duplicates(subset=['normalized_uniprot'])
    .drop(columns='normalized_uniprot')
)

print(f"After deduplication: {len(interfaces_df)} interfaces")

In [None]:
# Keep only interfaces between one ordered chain (no disordered regions) and
# one disordered chain (at least one disordered region).
import json

# Index the report by (entry, chain) so each chain's disorder count can be
# looked up directly.
rep_reindex = rep.set_index(['Entry ID', 'Asym ID'])

print(len(interfaces_df))

# One boolean per interface row, in row order. The frame is filtered once at
# the end instead of being mutated with drop(inplace=True) while iterating.
keep_flags = []
for _, row in interfaces_df.iterrows():
    # 'Interface ID' is stored as a Python-style list string, e.g. "['A', 'B']";
    # swapping the quotes makes it parseable as JSON.
    interface_id = json.loads(row['Interface ID'].replace('\'', '"'))
    chainID_1 = interface_id[0]
    chainID_2 = interface_id[1]
    try:
        disorder_chain1 = int(rep_reindex.loc[(row['Entry ID'], chainID_1), 'num_disordered_regions'])
        disorder_chain2 = int(rep_reindex.loc[(row['Entry ID'], chainID_2), 'num_disordered_regions'])
    except KeyError:
        # Entry/chain is not in the filtered report: discard the interface.
        keep_flags.append(False)
        continue

    ordered_with_disordered = (
        (disorder_chain1 == 0 and disorder_chain2 >= 1)
        or (disorder_chain2 == 0 and disorder_chain1 >= 1)
    )
    keep_flags.append(ordered_with_disordered)

# An explicit bool array keeps this a row mask even when the frame is empty.
interfaces_df = interfaces_df[np.asarray(keep_flags, dtype=bool)]

print(len(interfaces_df))

In [None]:
# Attach a release date to every interface: the first non-null 'Release Date'
# found among the report rows of the same entry.
release_date_by_entry = rep.groupby('Entry ID')['Release Date'].first()
interfaces_df['Release Date'] = interfaces_df['Entry ID'].map(release_date_by_entry)

In [None]:
# Flag interfaces released on or before AF_TRAINING_CUTOFF.
# NOTE(review): assumes 'Release Date' and AF_TRAINING_CUTOFF are directly
# comparable (same type/format) - confirm against constants.
interfaces_df['in_training_set'] = interfaces_df['Release Date'].le(AF_TRAINING_CUTOFF)

In [None]:
# Collect the 'Uniprot IDs' value of every remaining interface as a plain list
# and hand it to IPython's %store magic (presumably so another notebook can
# restore it with `%store -r` - confirm where it is consumed).
up_ids_structure_ds = interfaces_df['Uniprot IDs'].tolist()
%store up_ids_structure_ds

# 3. Job Creation

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from functions_filtering import *

import importlib
import constants
importlib.reload(constants)
from constants import *

import functions_analysis
import functions_job_creation
import functions_filtering
import functions_plotting
import functions_download
import functions_pdb2net

# Reload the module
importlib.reload(functions_analysis)
importlib.reload(functions_job_creation)
importlib.reload(functions_filtering)
importlib.reload(functions_download)
importlib.reload(functions_plotting)
importlib.reload(functions_pdb2net)

from functions_analysis import *
from functions_job_creation import *
from functions_download import *
from functions_filtering import *
from functions_plotting import *
from functions_pdb2net import *

In [None]:
# Create AlphaFold jobs for the filtered interfaces and write them to disk.
import os  # explicit import: do not rely on `os` leaking in through a star import

PDB_MODELLING_DIR = '../../production1/PDB_modelling'

# Every sub-directory of the modelling directory whose name contains 'batch';
# passed to create_job_batch_from_PDB_chains below.
BATCH_DIRS = [
    os.path.join(PDB_MODELLING_DIR, d)
    for d in os.listdir(PDB_MODELLING_DIR)
    if 'batch' in d and os.path.isdir(os.path.join(PDB_MODELLING_DIR, d))
]

# create_job_batch_from_PDB_chains is given the frame with a 'Chains' column,
# so the column is renamed for the call and restored afterwards - also when
# job creation or writing fails, so other cells keep seeing 'Interface ID'.
interfaces_df.rename(columns={'Interface ID': 'Chains'}, inplace=True)
try:
    new_af_jobs, reference_jobs = create_job_batch_from_PDB_chains(interfaces_df, BATCH_DIRS, 5120)
    write_af_jobs_to_individual_files(new_af_jobs, os.path.join(PDB_MODELLING_DIR, 'batch_12'))
    write_af_jobs_to_individual_files(reference_jobs, os.path.join(PDB_MODELLING_DIR, 'reference'), 'alphafold3', True)
finally:
    interfaces_df.rename(columns={'Chains': 'Interface ID'}, inplace=True)

# 4. Analysis

In [None]:
# Start the analysis from a copy of the interface table. The renames normalize
# column names left over from other cells ('Chains', 'pdb_id'); columns that
# are absent are simply ignored by rename.
results_df = interfaces_df.copy().rename(
    columns={'Chains': 'Interface ID', 'pdb_id': 'Entry ID'}
)

In [None]:
# Score every interface with DockQ via append_dockq_two_chainIDs.
import ast

NATIVE_PATH_PREFIX = "/home/markus/MPI_local/data/PDB/"
HPC_FULL_RESULTS_DIR = "/home/markus/MPI_local/HPC_results_full"
PDB_CACHE = '../../production1/pdb_cache'
DOCKQ_CACHE = '../../production1/dockq_cache_p2'
JOB_DIR = '/home/markus/MPI_local/production1/PDB_modelling'

# Job name = '<entry>_<chain1>_<chain2>'. 'Interface ID' holds a list literal
# as text; ast.literal_eval parses it without executing arbitrary code (unlike
# eval). A list comprehension also works on an empty frame, where a row-wise
# apply would not return a Series.
results_df['job_name'] = [
    entry_id + "_" + "_".join(ast.literal_eval(chains))
    for entry_id, chains in zip(results_df['Entry ID'], results_df['Interface ID'])
]

# The frame is handed to append_dockq_two_chainIDs with a 'pdb_id' column and
# the name is restored on the returned frame.
results_df.rename(columns={'Entry ID': 'pdb_id'}, inplace=True)
results_df, no_model = append_dockq_two_chainIDs(results_df, NATIVE_PATH_PREFIX, HPC_FULL_RESULTS_DIR, JOB_DIR, PDB_CACHE, DOCKQ_CACHE)
results_df.rename(columns={'pdb_id': 'Entry ID'}, inplace=True)

In [None]:
# Show the second value returned by append_dockq_two_chainIDs (presumably the
# jobs for which no model was found - confirm) and look their job files up
# below the production directory.
print(no_model)
find_job_files(no_model, '/home/markus/MPI_local/production1')

In [None]:
# Add AlphaFold confidence metrics to results_df; later cells read 'iptm',
# 'ptm', 'ranking_score', 'chain_pair_pae_min' and 'chain_pair_iptm'
# (presumably added here - confirm in functions_analysis).
# NOTE(review): the path literal duplicates HPC_FULL_RESULTS_DIR from the
# DockQ cell; keep the two in sync.
results_df = annotate_AF_metrics(results_df, '/home/markus/MPI_local/HPC_results_full')

In [None]:
# Three side-by-side panels: ipTM and pTM against DockQ (coloured by
# training-set membership), then ranking score against DockQ.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for panel, score_column in zip(axes[:2], ('iptm', 'ptm')):
    create_scatter_plot_colour(results_df, score_column, 'dockq_score', 'in_training_set', ax=panel, corr=True)
create_scatter_plot(results_df, 'ranking_score', 'dockq_score', ax=axes[2], corr=True)

In [None]:
# Reduce each 'chain_pair_pae_min' matrix to its two off-diagonal entries
# ([0][1] and [1][0]) and store the larger and the smaller of the two.
results_df['interface_pae_max'] = results_df['chain_pair_pae_min'].apply(
    lambda pae: max(pae[0][1], pae[1][0])
)
results_df['interface_pae_min'] = results_df['chain_pair_pae_min'].apply(
    lambda pae: min(pae[0][1], pae[1][0])
)

In [None]:
# Interface PAE (max and min) against DockQ. Only two panels are drawn, so the
# figure is created with two axes instead of three (the third stayed empty).
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
create_scatter_plot(results_df, 'interface_pae_max', 'dockq_score', ax=axes[0], corr=True)
create_scatter_plot(results_df, 'interface_pae_min', 'dockq_score', ax=axes[1], corr=True)

In [None]:
# Store the larger and the smaller of the two diagonal entries ([0][0] and
# [1][1]) of each 'chain_pair_iptm' matrix.
# NOTE(review): this reads the diagonal, whereas the interface-PAE columns are
# built from the off-diagonal entries ([0][1], [1][0]) - confirm the diagonal
# is the intended quantity for a column named chain_pair_iptm_*.
results_df['chain_pair_iptm_max'] = results_df['chain_pair_iptm'].apply(
    lambda iptm: max(iptm[0][0], iptm[1][1])
)
results_df['chain_pair_iptm_min'] = results_df['chain_pair_iptm'].apply(
    lambda iptm: min(iptm[0][0], iptm[1][1])
)


In [None]:
# chain_pair_iptm max and min against DockQ. Only two panels are drawn, so the
# figure is created with two axes instead of three (the third stayed empty).
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
create_scatter_plot(results_df, 'chain_pair_iptm_max', 'dockq_score', ax=axes[0], corr=True)
create_scatter_plot(results_df, 'chain_pair_iptm_min', 'dockq_score', ax=axes[1], corr=True)

## PR and ROC curves (positive class: dockQ >= 0.80)

In [None]:
from sklearn.metrics import precision_recall_curve, roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np

# Evaluate how well each confidence metric separates high-DockQ models:
# PR-AUC, ROC-AUC, and the F1-optimal and Youden-optimal cutoffs per metric.

# Define the metrics to evaluate
metrics = ['iptm', 'ranking_score', 'ptm', 'chain_pair_iptm_max', 'chain_pair_iptm_min',
           'interface_pae_max', 'interface_pae_min']

# Create binary labels based on dockQ.
# A missing dockq_score compares as False and would become label 0 after
# astype(int); `has_dockq` records which rows really carry a label so rows
# without DockQ are excluded rather than counted as negatives.
has_dockq = results_df['dockq_score'].notna()
y_true = (results_df['dockq_score'] >= 0.80).astype(int)
y_true_labelled = y_true[has_dockq]

# Report number of 1s and 0s among the rows that have a DockQ score
print("Dataset Class Distribution:")
print(f"Total samples: {len(y_true_labelled)}")
print(f"Positive samples (dockQ >= 0.80): {y_true_labelled.sum()} ({y_true_labelled.mean():.3f})")
print(f"Negative samples (dockQ < 0.80): {len(y_true_labelled) - y_true_labelled.sum()} ({1 - y_true_labelled.mean():.3f})")
print()

print("AUC Results and Optimal Cutoffs:")
print("=" * 60)

auc_results = {}

for metric in metrics:
    # Higher scores must mean "better model" for the curve functions, so PAE
    # metrics (lower is better) are negated and converted back for reporting.
    lower_is_better = 'pae' in metric
    y_scores = -results_df[metric] if lower_is_better else results_df[metric]

    # Keep only rows that have both a score and a DockQ label
    mask = has_dockq & y_scores.notna()
    y_scores_clean = y_scores[mask]
    y_true_clean = y_true[mask]

    if len(y_scores_clean) == 0:
        print(f"{metric}: No valid data")
        continue

    # PR and ROC curves are undefined when only one class is present
    if y_true_clean.nunique() < 2:
        print(f"{metric}: Only one class present, skipping")
        continue

    # Calculate PR curve and AUC
    precision, recall, pr_thresholds = precision_recall_curve(y_true_clean, y_scores_clean)
    pr_auc = auc(recall, precision)

    # Calculate ROC curve and AUC
    fpr, tpr, roc_thresholds = roc_curve(y_true_clean, y_scores_clean)
    roc_auc = auc(fpr, tpr)

    # Find optimal cutoffs
    # 1. F1-score optimal cutoff (from PR curve). The last precision/recall
    #    point has no threshold, hence the [:-1].
    f1_denominator = precision[:-1] + recall[:-1]
    # Where precision + recall == 0 the numerator is 0 as well; dividing by 1
    # there yields F1 = 0 without a divide-by-zero warning.
    f1_denominator[f1_denominator == 0] = 1.0
    f1_scores = 2 * (precision[:-1] * recall[:-1]) / f1_denominator
    optimal_f1_idx = np.argmax(f1_scores)
    optimal_f1_threshold = pr_thresholds[optimal_f1_idx]
    optimal_f1_score = f1_scores[optimal_f1_idx]

    # 2. Youden's J statistic optimal cutoff (from ROC curve)
    # J = sensitivity + specificity - 1 = tpr - fpr
    youden_j = tpr - fpr
    optimal_youden_idx = np.argmax(youden_j)
    optimal_youden_threshold = roc_thresholds[optimal_youden_idx]
    optimal_youden_j = youden_j[optimal_youden_idx]

    # Convert back to original scale for PAE metrics (undo the negation)
    if lower_is_better:
        optimal_f1_threshold_original = -optimal_f1_threshold
        optimal_youden_threshold_original = -optimal_youden_threshold
    else:
        optimal_f1_threshold_original = optimal_f1_threshold
        optimal_youden_threshold_original = optimal_youden_threshold

    auc_results[metric] = {
        'PR_AUC': pr_auc,
        'ROC_AUC': roc_auc,
        'optimal_f1_threshold': optimal_f1_threshold_original,
        'optimal_f1_score': optimal_f1_score,
        'optimal_youden_threshold': optimal_youden_threshold_original,
        'optimal_youden_j': optimal_youden_j
    }

    print(f"{metric}:")
    print(f"  PR-AUC: {pr_auc:.4f}")
    print(f"  ROC-AUC: {roc_auc:.4f}")
    print(f"  Optimal F1 cutoff: {optimal_f1_threshold_original:.4f} (F1 = {optimal_f1_score:.4f})")
    print(f"  Optimal Youden cutoff: {optimal_youden_threshold_original:.4f} (J = {optimal_youden_j:.4f})")
    print()

In [None]:
# Plot Precision-Recall and ROC curves for ranking_score, using the binary
# labels `y_true` (dockQ >= 0.80) created in the evaluation cell.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Get ranking_score data
metric = 'ranking_score'
y_scores = results_df[metric]

# Keep only rows with both a score and a DockQ value. `y_true` is an int
# series and never NaN, so the mask has to look at dockq_score itself;
# otherwise rows without DockQ are silently treated as negatives.
mask = y_scores.notna() & results_df['dockq_score'].notna()
y_scores_clean = y_scores[mask]
y_true_clean = y_true[mask]

# Calculate and plot PR curve
precision, recall, _ = precision_recall_curve(y_true_clean, y_scores_clean)
pr_auc = auc(recall, precision)
ax1.plot(recall, precision, label=f'PR Curve (AUC = {pr_auc:.4f})')
ax1.set_xlabel('Recall')
ax1.set_ylabel('Precision')
ax1.set_title(f'Precision-Recall Curve - {metric}')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Calculate and plot ROC curve
fpr, tpr, _ = roc_curve(y_true_clean, y_scores_clean)
roc_auc = auc(fpr, tpr)
ax2.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.4f})')
ax2.plot([0, 1], [0, 1], 'k--', label='Random')
ax2.set_xlabel('False Positive Rate')
ax2.set_ylabel('True Positive Rate')
ax2.set_title(f'ROC Curve - {metric}')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print summary statistics. The labels state the threshold that y_true was
# actually built with (0.80), and the last figure is the positive-class
# fraction (the precision of a random classifier), not an accuracy.
print(f"\nSummary for {metric}:")
print(f"Total samples: {len(y_true_clean)}")
print(f"Positive samples (dockQ >= 0.80): {y_true_clean.sum()}")
print(f"Negative samples (dockQ < 0.80): {len(y_true_clean) - y_true_clean.sum()}")
print(f"Positive fraction (PR baseline): {y_true_clean.mean():.4f}")