In [1]:
import json
import os
import pandas as pd
from Bio import PDB
import matplotlib.pyplot as plt



In [1]:
def _residue_plddt_for_model(json_path, cif_path):
    """Aggregate per-atom pLDDT values of one model into per-residue statistics.

    Parameters
    ----------
    json_path : str
        Path to a JSON file whose top-level ``'atom_plddts'`` entry holds one
        pLDDT value per atom.
    cif_path : str
        Path to the mmCIF structure file of the same model.

    Returns
    -------
    pandas.DataFrame
        One row per ``(chain_id, res_id)`` with columns ``chain_id``,
        ``res_id``, ``res_name``, ``mean_plddt``, ``std_plddt`` and
        ``cv_plddt`` (``std_plddt / mean_plddt``).

    Raises
    ------
    ValueError
        If the number of pLDDT values differs from the number of atoms parsed
        from the structure.
    """
    with open(json_path) as json_data:
        plddts = json.load(json_data)['atom_plddts']

    structure = PDB.MMCIFParser(QUIET=True).get_structure('model', cif_path)

    # One (chain, residue name, residue number) tuple per atom.
    # NOTE(review): assumes 'atom_plddts' is ordered exactly like Biopython's
    # model -> chain -> residue -> atom iteration -- TODO confirm.
    atom_to_residue = [
        (chain.id, residue.resname, residue.id[1])
        for model in structure
        for chain in model
        for residue in chain
        for _atom in residue
    ]

    # Fail with a message that names the offending files instead of the generic
    # pandas "arrays must be of the same length" error.
    if len(atom_to_residue) != len(plddts):
        raise ValueError(
            f"Atom count mismatch: {len(plddts)} pLDDT values in {json_path} "
            f"but {len(atom_to_residue)} atoms in {cif_path}"
        )

    atom_df = pd.DataFrame(atom_to_residue, columns=['chain_id', 'res_name', 'res_id'])
    atom_df['plddts'] = plddts

    residue_plddt = atom_df.groupby(['chain_id', 'res_id']).agg(
        res_name=('res_name', 'first'),
        mean_plddt=('plddts', 'mean'),
        # pandas' default sample std (ddof=1): residues with a single atom get NaN.
        std_plddt=('plddts', 'std'),
    ).reset_index()

    # Coefficient of variation of the atom pLDDTs within each residue.
    residue_plddt['cv_plddt'] = residue_plddt['std_plddt'] / residue_plddt['mean_plddt']
    return residue_plddt


def calculate_residue_plddt(directory, base_name, n_models=5):
    """Plot per-residue pLDDT statistics for the models of one prediction job.

    For each model index ``i`` in ``range(n_models)`` the function looks for
    ``fold_{base_name}_full_data_{i}.json`` and ``fold_{base_name}_model_{i}.cif``
    inside ``directory`` (presumably AlphaFold 3 server output -- verify against
    the data source). A model is skipped unless both files exist. Two figures
    are shown:

    1. a boxplot of the per-residue coefficient of variation, grouped by model;
    2. mean pLDDT against residue number with standard-deviation error bars,
       one line per model and chain.

    Parameters
    ----------
    directory : str
        Directory containing the JSON and CIF files.
    base_name : str
        Job name embedded in the file names.
    n_models : int, optional
        Number of model indices to look for, starting at 0. Defaults to 5.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If no model has both of its files in ``directory``, or if a model's
        pLDDT count does not match its atom count.
    """
    loaded_indices = []
    per_model_frames = []
    for i in range(n_models):
        json_path = os.path.join(directory, f"fold_{base_name}_full_data_{i}.json")
        cif_path = os.path.join(directory, f"fold_{base_name}_model_{i}.cif")

        # Both files are needed; silently skip models that are incomplete.
        if not (os.path.exists(json_path) and os.path.exists(cif_path)):
            continue

        residue_plddt = _residue_plddt_for_model(json_path, cif_path)
        residue_plddt['model'] = f"model_{i}"
        per_model_frames.append(residue_plddt)
        loaded_indices.append(i)

    # pd.concat([]) would raise an uninformative "No objects to concatenate".
    if not per_model_frames:
        raise ValueError(
            f"No model files found for base name '{base_name}' in '{directory}' "
            f"(looked for fold_{base_name}_full_data_<i>.json and "
            f"fold_{base_name}_model_<i>.cif, i in 0..{n_models - 1})"
        )

    all_residue_plddt_df = pd.concat(per_model_frames, ignore_index=True)

    # Boxplot of CV per model. The Axes is passed explicitly: without ``ax=``
    # pandas opens its own figure when ``by`` is used, leaving a blank figure
    # behind and ignoring the requested figsize.
    fig, ax = plt.subplots(figsize=(15, 6))
    all_residue_plddt_df.boxplot(column='cv_plddt', by='model', grid=False, ax=ax)
    ax.set_title(f"CVs of pLDDT for {base_name} Models")
    fig.suptitle('')  # Remove pandas' automatic "Boxplot grouped by ..." title
    ax.set_xlabel('Model')
    ax.set_ylabel('CV of per-residue pLDDT')
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    plt.show()

    # Mean pLDDT vs residue position with standard-deviation error bars.
    fig, ax = plt.subplots(figsize=(15, 6))
    for i in loaded_indices:
        model_data = all_residue_plddt_df[all_residue_plddt_df['model'] == f"model_{i}"]
        for chain in model_data['chain_id'].unique():
            chain_data = model_data[model_data['chain_id'] == chain]
            ax.errorbar(
                chain_data['res_id'],
                chain_data['mean_plddt'],
                yerr=chain_data['std_plddt'],
                label=f"Model {i} Chain {chain}",
            )

    ax.set_title(f"pLDDT vs Residue Position for {base_name} Models")
    ax.set_xlabel('Residue Position')
    ax.set_ylabel('Mean pLDDT')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

    return None