In [3]:
from pyfoldx.structure.structure import Structure
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import time
import plotly.graph_objects as go

def create_structure(pdb_id):
    """
    Build a ``Structure`` object for a PDB identifier.

    The identifier is handed to ``Structure`` through its ``code`` keyword.
    """
    structure = Structure(code=pdb_id)
    return structure

def get_chain_sequence(structure, chain_id):
    """
    Return the sequence of one chain, as reported by ``structure.getSequence``.

    ``chain_id`` is forwarded as the ``chain`` keyword argument.
    """
    sequence = structure.getSequence(chain=chain_id)
    return sequence

def calculate_ddg(chain_sequence, chain_id, mutants, structure):
    """
    Calculates ΔΔG values for all specified mutations.

    Every position of ``chain_sequence`` is mutated, one substitution at a
    time, to each residue in ``mutants`` other than the wild-type residue at
    that position.

    Args:
        chain_sequence: Sequence of the chain; it is indexed by position and
            its length fixes the number of columns.
        chain_id: Chain identifier embedded in every mutation string.
        mutants: Iterable of residue codes to mutate to. These become the row
            labels of the result.
        structure: Object exposing ``mutate(mutation, verbose=False)`` that
            returns a ``(ddGs, mutModels, wtModels)`` triple, where
            ``ddGs['total']`` holds the ΔΔG of the mutation.

    Returns:
        A DataFrame with one row per mutant residue and one column per
        1-based position. Cells where the mutant equals the wild-type
        residue are never assigned and remain NaN.
    """
    # Materialise once so a one-shot iterable can be reused for every position.
    mutant_residues = list(mutants)
    positions = range(1, len(chain_sequence) + 1)
    df = pd.DataFrame(index=mutant_residues, columns=positions)

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during a long run.
    start_time = time.perf_counter()

    # NOTE(review): positions are 1-based indices into chain_sequence and are
    # used directly as residue numbers in the mutation string; confirm this
    # matches the structure's own residue numbering.
    for residue_number in tqdm(positions, desc="Processing Sequence", colour='#2bceee'):
        wt_residue = chain_sequence[residue_number - 1]

        for mutant_residue in mutant_residues:
            # Substituting a residue with itself is not a mutation.
            if mutant_residue == wt_residue:
                continue

            mutation = f"{wt_residue}{chain_id}{residue_number}{mutant_residue};"

            # Only the ΔΔG table is needed; the returned models are discarded.
            ddGs, _, _ = structure.mutate(mutation, verbose=False)
            df.at[mutant_residue, residue_number] = ddGs['total'].values[0]

    # Print timing information
    total_time = time.perf_counter() - start_time
    print(f"Total time for all mutations: {total_time:.2f} seconds")

    return df

def plot_heatmap(value_df, chain_sequence, output_html="heatmap.html"):
    """
    Creates and saves a heatmap of ΔΔG values to an HTML file.

    Args:
        value_df: DataFrame of ΔΔG values. Rows are mutant residues and
            columns are residue positions; the tick labels assume the columns
            are numbered 1..len(chain_sequence), as calculate_ddg produces.
        chain_sequence: Sequence of the chain, used for the x-axis tick labels
            and for the sequence annotation under the plot.
        output_html: Path of the HTML file to write.

    Returns:
        None. The figure is written to ``output_html`` and a confirmation
        message is printed.
    """
    # Cells that are not numeric (e.g. unassigned wild-type cells) become NaN.
    df_numeric = value_df.apply(pd.to_numeric, errors='coerce')

    # Create heatmap
    heatmap = go.Figure(data=go.Heatmap(
        z=df_numeric.values,
        x=df_numeric.columns,
        y=df_numeric.index,
        colorscale='RdBu',
        colorbar=dict(
            title='ΔΔG (kcal/mol)',
            tickvals=[-2, 0, 5],
            ticktext=['-2', '0', '5'],
            thickness=20
        ),
        zmin=-2,
        zmax=5
    ))

    # The heatmap's x values are the 1-based residue positions, so the tick
    # for residue i must sit at x == i. Placing ticks at 0..N-1 would shift
    # every label one column to the left of its data.
    residue_positions = list(range(1, len(chain_sequence) + 1))
    residue_labels = [
        f"{aa}{position}"
        for position, aa in zip(residue_positions, chain_sequence)
    ]

    # Update layout with better spacing and readability
    heatmap.update_layout(
        title=dict(
            text="Heatmap of ΔΔG for Mutants at Each Residue",
            x=0.5,
            y=0.95
        ),
        xaxis_title="Residue Position",
        yaxis_title="Mutant Amino Acid",
        xaxis=dict(
            tickmode='array',
            tickvals=residue_positions,
            ticktext=residue_labels,
            tickangle=45,
            tickfont=dict(size=10)
        ),
        height=600,
        width=1000,
        margin=dict(t=100, b=100, l=100, r=50)
    )

    # Add protein sequence below the plot area (paper coordinates)
    heatmap.add_annotation(
        x=0.5,
        y=-0.15,
        xref="paper",
        yref="paper",
        text=f"Protein Sequence: {''.join(chain_sequence)}",
        showarrow=False,
        font=dict(size=12),
        align="center"
    )

    heatmap.write_html(output_html)
    print(f"Heatmap saved to {output_html}")
    

def run_pipeline(pdb_id, chain_id, mutants, output_html="heatmap.html"):
    """
    Execute the whole workflow end to end.

    Steps, each announced with a progress message:
      1. build the structure for ``pdb_id``;
      2. fetch the sequence of chain ``chain_id``;
      3. compute the ΔΔG table for ``mutants``;
      4. render the table as a heatmap written to ``output_html``.
    """
    print(f"Initializing structure for PDB ID: {pdb_id}")
    protein = create_structure(pdb_id)

    print(f"Retrieving chain sequence for chain ID: {chain_id}")
    sequence = get_chain_sequence(protein, chain_id)

    print("Calculating ΔΔG values...")
    ddg_table = calculate_ddg(sequence, chain_id, mutants, protein)

    print("Generating heatmap...")
    plot_heatmap(ddg_table, sequence, output_html)

    print("Pipeline complete!")

# Example Usage
if __name__ == "__main__":
    # Inputs for the example run; edit these to analyse another protein.
    pdb_id = "2L09"  # PDB entry to load
    chain_id = "A"  # chain whose residues are mutated
    mutants = "ACDEFGHIKLMNPQRSTVWY"  # residue codes to substitute at each position
    output_html = "heatmap_2L09.html"  # destination of the rendered heatmap

    run_pipeline(
        pdb_id=pdb_id,
        chain_id=chain_id,
        mutants=mutants,
        output_html=output_html,
    )


Initializing structure for PDB ID: 2L09
Retrieving chain sequence for chain ID: A
Calculating ΔΔG values...


Processing Sequence:   0%|          | 0/62 [00:00<?, ?it/s]

Total time for all mutations: 7398.76 seconds
Generating heatmap...
Heatmap saved to heatmap_2L09.html
Pipeline complete!
