In [23]:
import os
import json
import pickle
import numpy as np
import pandas as pd
from Bio import PDB

def reverse_and_scale_matrix(matrix: np.ndarray, pae_cutoff: float = 12.0) -> np.ndarray:
    """
    Linearly invert and rescale a matrix so that low values score high.

    A value of 0 maps to 1, a value equal to pae_cutoff maps to 0, and anything
    above pae_cutoff is clipped to 0. Negative inputs are clipped to 1.

    Args:
    - matrix (np.ndarray): Input values; any array-like (e.g. nested lists) is accepted.
    - pae_cutoff (float): Positive threshold at and above which values become 0.

    Returns:
    - np.ndarray: Array of the same shape as the input with values in [0, 1].

    Raises:
    - ValueError: If pae_cutoff is not strictly positive (the scaling divides by it).
    """
    if pae_cutoff <= 0:
        raise ValueError(f"pae_cutoff must be positive, got {pae_cutoff}")

    # np.asarray leaves ndarrays untouched and lets plain lists be used as well
    values = np.asarray(matrix)

    # Scale to [0, 1] between 0 and the cutoff, then clamp everything outside that range
    scaled_matrix = (pae_cutoff - values) / pae_cutoff
    return np.clip(scaled_matrix, 0, 1)


def _first_chain_length(pdb_file_path: str) -> int:
    """Return the residue count of the first chain in the first model of a PDB file."""
    structure = PDB.PDBParser(QUIET=True).get_structure("example", pdb_file_path)
    for model in structure:
        for chain in model:
            return sum(1 for _ in chain.get_residues())
    raise ValueError(f"No chains found in {pdb_file_path}")


def _local_interaction_metrics(pae: np.ndarray, protein_a_len: int, pae_cutoff: float):
    """Return (LIS, LIA) computed from the two off-diagonal (inter-chain) PAE blocks."""
    scaled_pae = reverse_and_scale_matrix(pae, pae_cutoff)
    below_cutoff = pae < pae_cutoff

    first = slice(None, protein_a_len)
    rest = slice(protein_a_len, None)

    block_scores = []
    interaction_area = 0
    for rows, cols in ((first, rest), (rest, first)):
        mask = below_cutoff[rows, cols]
        interaction_area += int(np.count_nonzero(mask))
        selected = scaled_pae[rows, cols][mask]
        # A block with no residue pair under the cutoff contributes a score of 0
        block_scores.append(np.mean(selected) if selected.size > 0 else 0)

    return (block_scores[0] + block_scores[1]) / 2, interaction_area


def process_alphafold_output(base_directory: str, Protein_1 = "A", Protein_2 = "B", pae_cutoff: float = 12.0, lis_threshold: float = 0.203, lia_threshold: float = 3432.0, model_numbers: int = 5, recycling_numbers: int = 5) -> "tuple[pd.DataFrame, pd.DataFrame]":
    """
    Process AlphaFold output files and return per-prediction interaction metrics.

    For every model/prediction pair the function reads
    `pae_model_{m}_multimer_v3_pred_{r}.json`, `unrelaxed_model_{m}_multimer_v3_pred_{r}.pdb`
    and `result_model_{m}_multimer_v3_pred_{r}.pkl` from base_directory.

    Args:
    - base_directory (str): Base directory where the AlphaFold output files are stored.
    - Protein_1: Label written to the 'Protein_1' column of every row.
    - Protein_2: Label written to the 'Protein_2' column of every row.
    - pae_cutoff (float): PAE values at or above this threshold are ignored (scored as 0).
    - lis_threshold (float): Minimum LIS (Local Interaction Score) for the filtered table.
    - lia_threshold (float): Minimum LIA (Local Interaction Area) for the filtered table.
    - model_numbers (int): Number of models; models 1..model_numbers are read.
    - recycling_numbers (int): Number of predictions per model; indices 0..recycling_numbers-1 are read.

    Returns:
    - tuple[pd.DataFrame, pd.DataFrame]: (all predictions, predictions whose LIS and LIA
      both reach their thresholds).

    Raises:
    - FileNotFoundError: If one of the expected output files is missing.
    - ValueError: If a PDB file contains no chains.
    """
    columns = [
        'Protein_1', 'Protein_2', 'LIS', 'LIA', 'ipTM', 'Confidence', 'pTM',
        'pLDDT', 'Model', 'Recycle', 'saved folder', 'pdb', 'pkl',
    ]
    records = []

    for model_num in range(1, model_numbers + 1):
        for recycling_num in range(recycling_numbers):
            pae_name = f"pae_model_{model_num}_multimer_v3_pred_{recycling_num}.json"
            pae_file_path = os.path.join(base_directory, pae_name)
            with open(pae_file_path, 'r') as file:
                json_data = json.load(file)
            pae = np.array(json_data[0]['predicted_aligned_error'])

            pdb_name = f"unrelaxed_model_{model_num}_multimer_v3_pred_{recycling_num}.pdb"
            pdb_file_path = os.path.join(base_directory, pdb_name)
            # The PAE matrix is split positionally: the first `protein_a_len` rows/columns are
            # treated as the first protein. The boundary therefore has to be the length of
            # whichever chain comes first in the file, whatever its chain ID is.
            protein_a_len = _first_chain_length(pdb_file_path)

            pkl_name = f"result_model_{model_num}_multimer_v3_pred_{recycling_num}.pkl"
            pkl_file_path = os.path.join(base_directory, pkl_name)
            # NOTE(security): pickle.load can execute arbitrary code; only use result files
            # produced by a trusted AlphaFold run.
            with open(pkl_file_path, 'rb') as file:
                d = pickle.load(file)
            iptm = d.get('iptm')
            ptm = d.get('ptm')
            plddt = np.mean(d.get('plddt'))

            lis, lia = _local_interaction_metrics(pae, protein_a_len, pae_cutoff)

            records.append({
                'Protein_1': Protein_1,
                'Protein_2': Protein_2,
                'LIS': round(lis, 3), # Local Interaction Score (LIS)
                'LIA': lia, # Local Interaction Area (LIA)
                'ipTM': round(float(iptm), 3),
                'Confidence': round(float(iptm*0.8 + ptm*0.2),3),
                'pTM': round(float(ptm), 3),
                'pLDDT': round(plddt, 2),
                'Model': model_num,
                'Recycle': recycling_num,
                'saved folder': os.path.dirname(pdb_file_path),
                'pdb': os.path.basename(pdb_file_path),
                'pkl': os.path.basename(pkl_file_path),
            })

    # Build the frame in one go so each column gets a proper dtype; passing `columns`
    # keeps the schema intact even when no prediction was requested.
    result_df = pd.DataFrame(records, columns=columns)

    # Filter rows based on the specified LIS and LIA thresholds
    result_df_filtered = result_df[(result_df['LIS'] >= lis_threshold) & (result_df['LIA'] >= lia_threshold)]

    return result_df, result_df_filtered

# Reference table of threshold statistics, one row per metric.
# Row layout: (Metric, Optimal Threshold, Specificity, Sensitivity, AUC, Youden's Index)
_metric_columns = ['Metric', 'Optimal Threshold', 'Specificity', 'Sensitivity', 'AUC', "Youden's Index"]
_metric_rows = [
    ('Average LIS',        0.0734,      0.926011561, 0.786516854, 0.910601632, 0.712528415),
    ('Best LIS',           0.203,       0.919075145, 0.730337079, 0.890710312, 0.649412223),
    ('Average LIA',        1610.4,      0.876300578, 0.767790262, 0.888699097, 0.64409084),
    ('Best LIA',           3432,        0.855491329, 0.775280899, 0.86613626,  0.630772228),
    ('Average ipTM',       0.322,       0.937572254, 0.711610487, 0.891301336, 0.649182741),
    ('Best ipTM',          0.38,        0.823121387, 0.734082397, 0.862501353, 0.557203784),
    ('Average Confidence', 0.3672,      0.951445087, 0.674157303, 0.859056959, 0.62560239),
    ('Best Confidence',    0.432,       0.85433526,  0.685393258, 0.84053387,  0.539728519),
    ('Average pDockQ',     0.133427109, 0.804624277, 0.68164794,  0.79601221,  0.486272218),
    ('Best pDockQ',        0.148516258, 0.865895954, 0.666666667, 0.818267628, 0.53256262),
    ('Average pDockQ2',    0.015093248, 0.917919075, 0.692883895, 0.841622827, 0.61080297),
    ('Best pDockQ2',       0.02106924,  0.895953757, 0.670411985, 0.832053863, 0.566365742),
]

# Pivot the rows into the column-oriented dict of lists, then build the frame rounded to 3 decimals
metrics_data = {
    column: list(values)
    for column, values in zip(_metric_columns, zip(*_metric_rows))
}
metrics_data_df = pd.DataFrame(metrics_data).round(3)



In [24]:
# Example usage: set `base_directory` to the folder holding the AlphaFold output files.
base_directory = 'YOUR ALPHAFOLD OUTPUT FOLDER'

# Labels copied into the 'Protein_1' / 'Protein_2' columns of the result tables
Protein_1 = "A"
Protein_2 = "B"

# Cut-offs passed to the analysis
pae_cutoff = 12
lis_threshold = 0.203  # might be too strict for weak or transient interaction
lia_threshold = 3432.0

# First table: every prediction; second table: predictions passing both thresholds
total_prediction, positive_prediction = process_alphafold_output(
    base_directory,
    Protein_1=Protein_1,
    Protein_2=Protein_2,
    pae_cutoff=pae_cutoff,
    lis_threshold=lis_threshold,
    lia_threshold=lia_threshold,
)


In [None]:
# Write DataFrames to an Excel file with three sheets.
# The styling below relies on XlsxWriter-only calls (`add_format`, `write`, `set_column`),
# so the engine is pinned rather than left to pandas' default choice (which falls back to
# openpyxl when XlsxWriter is not installed and would then fail on `add_format`).
with pd.ExcelWriter('alphafold_predictions.xlsx', engine='xlsxwriter') as writer:
    positive_prediction.to_excel(writer, sheet_name='Positive_PPI', index=False)
    total_prediction.to_excel(writer, sheet_name='Total_Prediction', index=False)

    # pandas writes header cells with its own cell format, and a row format set through
    # `set_row` cannot override a cell format. Skip the automatic header (data starts on
    # row 1) and write the header cells ourselves so the wrap/centre format really applies.
    metrics_data_df.to_excel(writer, sheet_name='Optimal Thresholds', index=False, header=False, startrow=1)

    # Get the workbook and the worksheet
    workbook = writer.book
    worksheet = writer.sheets['Optimal Thresholds']

    header_format = workbook.add_format({'text_wrap': True, 'valign': 'vcenter', 'align': 'center', 'bold': True})

    for i, col in enumerate(metrics_data_df.columns):
        # Styled header cell
        worksheet.write(0, i, col, header_format)

        # Auto-adjust column width to the longest of the header and the cell texts
        column_len = metrics_data_df[col].astype(str).str.len().max()
        column_width = max(column_len, len(col))
        worksheet.set_column(i, i, column_width + 2)
