# Read and Analyze AlphaFold Server Outputs

1. Run AlphaFold Server
2. Download ZIP files from Google Drive
3. Place ZIP files in the defined folder
4. Run Script

## Imports

In [None]:
%pip install -q numpy pandas matplotlib py3Dmol openpyxl Bio

In [None]:
import zipfile
import json
import py3Dmol
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import display, HTML

## Configuration

In [None]:
class Config:
    """User-editable settings for the notebook."""

    # Folder that holds the ZIP archives downloaded from AlphaFold Server
    AF_DIR = "AlphaFold"
    # Suffix used to select which model's files are read from each archive
    # (matched against "*_model_<N>.cif" and the corresponding JSON files)
    CANDIDATE_NUM = "1"

config = Config()

# NOTE: `dir` shadows the Python builtin of the same name; the name is kept
# because the later cells of this notebook refer to it.
dir = Path(config.AF_DIR)
if not dir.exists():
    print(f"Creating directory: {dir}")
    # parents=True lets AF_DIR be a nested path; exist_ok=True avoids a crash
    # if the folder appears between the check and the call.
    dir.mkdir(parents=True, exist_ok=True)

print(f"Searching for AlphaFold Server .zip outputs in: {dir}")

# Find all zip files in the directory
zip_files = list(dir.glob("*.zip"))

if not zip_files:
    print(f"\nERROR: No .zip files found in: {dir}")
    print("Please make sure to place your AlphaFold Server outputs there before running this cell.")
else:
    print(f"Found {len(zip_files)} candidate zip files to analyze.")

## Data Analysis

In [None]:
import shutil
from Bio.PDB import MMCIFParser, PDBIO
from IPython.display import Image
# --- ANALYSIS PARAMETERS ---
# Loop residue numbers are 1-based and inclusive; they are used as-is for the
# Bio.PDB residue lookup and converted to 0-based indices for the PAE matrix.
LOOP_START_RESIDUE = 62
LOOP_END_RESIDUE = 69
# 0-based PAE row/column index at which chain B starts; everything before it
# is treated as chain A.
CHAIN_SPLIT_POINT = 188  # Adjust based on your specific protein complex
# Chain that carries the loop of interest
LOOP_CHAIN_ID = 'A'
# Number of flanking rows shown on each side of the loop in the focused PAE plot
WINDOW_SIZE = 10

# One dict per successfully analyzed candidate; consumed by the later cells
results_list = []

# --- LOOP THROUGH EACH ZIP FILE AND ANALYZE ---
for zip_file_path in sorted(zip_files):
    candidate_name = zip_file_path.stem
    display(HTML(f"<hr><h2>Analyzing Candidate: {candidate_name} (Model {config.CANDIDATE_NUM})</h2>"))

    temp_split = candidate_name.split('_')

    # Fields 1, 3 and 4 are read below, so at least five '_'-separated fields
    # are needed (a check for four would raise IndexError on temp_split[4]).
    if len(temp_split) >= 5:
        ligand_name = temp_split[1]
        target_name = temp_split[3]
        variant_name = temp_split[4]
    else:
        print("WARNING: Candidate names do not match the expected format {ligand}_variant_{target}_{variant}. Cannot create comparison matrices.")
        ligand_name = "Chain A"
        target_name = "Chain B"
        variant_name = "Unknown"

    # Unzip the contents into a subdirectory
    unzip_dir = dir / candidate_name
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(unzip_dir)

    # Find the files for the selected model
    try:
        scores_file = next(unzip_dir.glob(f"*_summary_confidences_{config.CANDIDATE_NUM}.json"))
        full_data_file = next(unzip_dir.glob(f"*_full_data_{config.CANDIDATE_NUM}.json"))
        structure_file = next(unzip_dir.glob(f"*_model_{config.CANDIDATE_NUM}.cif"))
    except StopIteration:
        print(f"WARNING: Could not find all required output files for rank {config.CANDIDATE_NUM} in {candidate_name}. Skipping.")
        # Remove the extracted folder on the skip path too, as the success path does
        shutil.rmtree(unzip_dir, ignore_errors=True)
        continue

    # --- EXTRACT AND DISPLAY KEY SCORES ---
    with open(scores_file, 'r') as f:
        scores_data = json.load(f)

    ipTM_score = scores_data.get("iptm")

    with open(full_data_file, 'r') as f:
        pae_data = json.load(f)

    plddt_array = pae_data.get("atom_plddts")
    mean_plddt = np.mean(plddt_array) if plddt_array else 0

    # --- TARGETED pLDDT ANALYSIS FOR THE MODIFIED LOOP (CIF-based) ---
    print(f"\n--- Targeted pLDDT Analysis (Loop {LOOP_START_RESIDUE}-{LOOP_END_RESIDUE}) ---")

    loop_plddt_values = []
    loop_mean_plddt = 0.0  # 0.0 marks "not computed"
    structure = None       # kept so the PDB export below can reuse the parsed structure

    try:
        # 1. Parse the mmCIF file once
        parser = MMCIFParser()
        structure = parser.get_structure(candidate_name, str(structure_file))

        # First model of the file, then the chain that carries the loop
        chain = structure[0][LOOP_CHAIN_ID]

        # 2. Loop through the residues in the target range
        for res_num in range(LOOP_START_RESIDUE, LOOP_END_RESIDUE + 1):
            # Bio.PDB requires a full residue ID: (' ', res_num, ' ') for standard residues
            residue_id = (' ', res_num, ' ')

            try:
                residue = chain[residue_id]
            except KeyError:
                # Handle cases where a residue might be missing (e.g., if it's truncated)
                print(f"Warning: Residue {LOOP_CHAIN_ID}:{res_num} not found in structure.")
                continue

            # 3. Collect the pLDDT (stored in the B-factor column) for all atoms
            loop_plddt_values.extend(atom.get_bfactor() for atom in residue)

        # 4. Calculate the mean pLDDT for the loop
        if loop_plddt_values:
            loop_mean_plddt = np.mean(loop_plddt_values)

    except Exception as e:
        print(f"ERROR during CIF parsing for loop pLDDT: {e}")
        loop_mean_plddt = 0.0

    if loop_mean_plddt > 0:
        print(f"Loop ({LOOP_CHAIN_ID}:{LOOP_START_RESIDUE}-{LOOP_END_RESIDUE}) Mean pLDDT: {loop_mean_plddt:.2f} (Targeted loop confidence)")
    else:
        print("Loop pLDDT calculation failed or loop was not found.")

    results_list.append({
        "Candidate": candidate_name,
        "ipTM": ipTM_score,
        "Mean pLDDT": mean_plddt,
        "Loop pLDDT": loop_mean_plddt,
    })

    print("\n--- Confidence Scores ---")
    if ipTM_score is not None:
        print(f"Interface pTM (ipTM): {ipTM_score:.4f}")
        if ipTM_score > 0.85:
            print("Interpretation: High confidence binding prediction.")
        else:
            print("Interpretation: Low confidence binding prediction.")

    if mean_plddt > 0:
        print(f"Mean pLDDT: {mean_plddt:.2f} (Overall structure confidence)")

    # --- PLOT THE PAE MATRIX FROM RAW JSON DATA ---
    pae_matrix = pae_data.get('pae')
    if pae_matrix:
        print("\n--- Predicted Aligned Error (PAE) Plot ---")
        pae_matrix_np = np.array(pae_matrix)
        total_len = len(pae_matrix_np)

        # Rows of the focused plot: the loop (converted to 0-based indices)
        # plus WINDOW_SIZE rows on each side, clipped to chain A.
        PAE_START_A = max(0, LOOP_START_RESIDUE - 1 - WINDOW_SIZE)
        PAE_END_A = min(CHAIN_SPLIT_POINT, LOOP_END_RESIDUE + WINDOW_SIZE)

        # Columns of the focused plot: all of chain B
        PAE_START_B = CHAIN_SPLIT_POINT
        PAE_END_B = total_len

        # Full PAE plot
        plt.figure(figsize=(8, 6))
        plt.imshow(pae_matrix_np, cmap='Greens_r', vmin=0, vmax=30)
        pae_ax = plt.gca()

        # Dashed lines on the pixel edge between the two chains
        plt.axvline(x=CHAIN_SPLIT_POINT - 0.5, color='black', linestyle='--', linewidth=2)
        plt.axhline(y=CHAIN_SPLIT_POINT - 0.5, color='black', linestyle='--', linewidth=2)

        plt.colorbar(label="Expected Position Error (Å)")
        # Extra padding leaves room for the chain labels drawn just above the axes
        plt.title(f"PAE Plot for {candidate_name} (Model {config.CANDIDATE_NUM})", pad=20)

        # Chain labels: x is given in data coordinates (matrix index) and y as
        # an axes fraction. Passing data-unit values together with transAxes
        # would place the text far outside the figure.
        label_transform = pae_ax.get_xaxis_transform()
        mid_A = (CHAIN_SPLIT_POINT - 1) / 2
        pae_ax.text(mid_A, 1.01, 'Chain A', ha='center', va='bottom', transform=label_transform, fontsize=10)
        if CHAIN_SPLIT_POINT < total_len:
            mid_B = (CHAIN_SPLIT_POINT + total_len - 1) / 2
            pae_ax.text(mid_B, 1.01, 'Chain B', ha='center', va='bottom', transform=label_transform, fontsize=10)

        plt.xlabel("Scored Residue")
        plt.ylabel("Aligned Residue")
        pae_output_path = dir / f"PAE_{candidate_name}_model_{config.CANDIDATE_NUM}.png"
        plt.savefig(pae_output_path)
        plt.close()

        display(Image(filename=str(pae_output_path)))

        # Focused PAE
        focused_pae = pae_matrix_np[PAE_START_A : PAE_END_A, PAE_START_B : PAE_END_B]

        if focused_pae.size > 0:
            # 1-based inclusive residues START..END are rows START-1 .. END-1,
            # i.e. the half-open slice [START-1 : END].
            pae_matrix_loop_only = pae_matrix_np[LOOP_START_RESIDUE - 1 : LOOP_END_RESIDUE, PAE_START_B : PAE_END_B]
            mean_interface_pae_loop = np.mean(pae_matrix_loop_only)
            print(f"Mean Interface PAE (Loop vs {target_name}): {mean_interface_pae_loop:.2f} Å")

            plt.figure(figsize=(14, 6))
            plt.imshow(focused_pae, cmap='bwr', vmin=0, vmax=20, aspect='auto')

            # Loop boundaries in focused-plot row coordinates. The -0.5 puts
            # each line on the pixel edge (same convention as the chain
            # divider in the full plot) instead of through a pixel centre.
            y_loop_start = LOOP_START_RESIDUE - 1 - PAE_START_A - 0.5
            y_loop_end = LOOP_END_RESIDUE - PAE_START_A - 0.5

            plt.axhline(y=y_loop_start, color='green', linestyle='-', linewidth=1.5, label='Loop Start/End')
            plt.axhline(y=y_loop_end, color='green', linestyle='-', linewidth=1.5)

            plt.colorbar(label="Expected Position Error (Å)")
            plt.title(f"Focused PAE: {ligand_name} Loop ({LOOP_START_RESIDUE}-{LOOP_END_RESIDUE}) vs {target_name} (Model {config.CANDIDATE_NUM})")

            # Tick labels show 1-based positions in the full matrix
            y_ticks = np.arange(0, focused_pae.shape[0], step=5)
            y_labels = [str(int(PAE_START_A + t + 1)) for t in y_ticks]
            plt.yticks(y_ticks, y_labels)

            x_ticks = np.arange(0, focused_pae.shape[1], step=100)
            x_labels = [str(int(PAE_START_B + t + 1)) for t in x_ticks]
            plt.xticks(x_ticks, x_labels)

            plt.xlabel(f"{target_name} Residue")
            plt.ylabel(f"{ligand_name} Residue")

            focused_pae_output_path = dir / f"PAE_focused_{candidate_name}_model_{config.CANDIDATE_NUM}.png"
            plt.savefig(focused_pae_output_path)
            plt.close()

            display(Image(filename=str(focused_pae_output_path)))
        else:
            mean_interface_pae_loop = 0.0
            print("Warning: Interface PAE matrix slice was empty.")

        # Attach the interface score to the entry appended above for this candidate
        results_list[-1]["Interface PAE (Loop)"] = mean_interface_pae_loop

    print("\n--- Saving Structure to PDB Format ---")

    # Convert mmCIF to PDB
    pdb_output_path = dir / f"PDB_{candidate_name}_model_{config.CANDIDATE_NUM}.pdb"

    try:
        # Reuse the structure parsed for the loop analysis; parse again only
        # if that earlier attempt did not produce one.
        if structure is None:
            structure = MMCIFParser().get_structure(candidate_name, str(structure_file))

        io = PDBIO()
        io.set_structure(structure)
        io.save(str(pdb_output_path))
        print(f"Successfully saved PDB file to: {pdb_output_path}")

    except Exception as e:
        print(f"ERROR: Failed to convert CIF to PDB for {candidate_name}. Error: {e}")

    # --- VISUALIZE THE 3D STRUCTURE ---
    print("\n--- Predicted 3D Structure (Colored by pLDDT) ---")
    with open(structure_file, 'r') as f:
        structure_data = f.read()

    view = py3Dmol.view(width=800, height=600)
    view.addModel(structure_data, "mmcif")
    # Chain A: cartoon colored by the B-factor property (the pLDDT values)
    view.setStyle({'chain': 'A'}, {'cartoon': {'colorscheme': {
        'prop': 'b', 'gradient': 'roygb', 'min': 40, 'max': 100,
    }}})
    view.setStyle({'chain': 'B'}, {'sphere': {}})
    view.zoomTo()
    view.show()

    shutil.rmtree(unzip_dir)  # Clean up the unzipped directory
    print(f"Cleaned up temporary files for candidate: {candidate_name}")

### Compare across targets

In [None]:
import pandas as pd

# --- POST-PROCESSING: GENERATE COMPARISON MATRICES WITH GLOBAL COLORING ---
print("\n" + "="*70)
print("--- Comparison Matrix Analysis with Global Heatmaps ---")
print("="*70)

# Convert the list of results into a pandas DataFrame
df = pd.DataFrame(results_list)

# Features where HIGHER is BETTER (high values shown green)
HIGH_IS_BETTER = ["ipTM", "Mean pLDDT", "Loop pLDDT"]
# Features where LOWER is BETTER (reversed colormap: low values shown green)
LOW_IS_BETTER = ["Interface PAE (Loop)"]

features_to_pivot = HIGH_IS_BETTER + LOW_IS_BETTER

# feature name -> styled pivot table (Variant rows x Target columns)
comparison_matrices = {}

# Split candidate names on '_'; columns 3 and 4 of the split are used as
# Target and Variant. An empty results list has no 'Candidate' column at all.
if 'Candidate' in df.columns:
    temp_split = df['Candidate'].str.split('_', expand=True)
else:
    temp_split = pd.DataFrame()

# Column 4 is read below, so at least five fields are required.
if temp_split.shape[1] >= 5:
    df['Target'] = temp_split[3]
    df['Variant'] = temp_split[4]

    for feature in features_to_pivot:
        if feature not in df.columns:
            print(f"Skipping {feature}: Column not found in DataFrame.")
            continue

        # A. Global min/max of the whole column, so every cell of the matrix
        #    is colored on the same scale
        global_min = df[feature].min()
        global_max = df[feature].max()

        # B. Colormap direction depends on whether high or low values are good
        if feature in HIGH_IS_BETTER:
            cmap = 'RdYlGn'
        elif feature in LOW_IS_BETTER:
            cmap = 'RdYlGn_r'
        else:
            cmap = 'Blues'  # Default fallback

        # C. Pivot: one row per variant, one column per target
        pivot_table = df.pivot_table(
            index='Variant',
            columns='Target',
            values=feature,
            aggfunc='mean'
        )

        # D. Apply the global background gradient
        gradient_kwargs = dict(cmap=cmap, vmin=global_min, vmax=global_max)
        styled_table = pivot_table.style.background_gradient(**gradient_kwargs)

        comparison_matrices[feature] = styled_table

        print(f"\n--- Comparison Matrix: {feature} ---")
        display(styled_table)

        # Save the styled table
        file_stem = f"Comparison_Matrix_{feature.replace(' ', '_')}"
        styled_table.to_excel(dir / f"{file_stem}.xlsx", engine='openpyxl')

        # Styler.hide() modifies the styler it is called on, so the LaTeX
        # export uses its own styler; the one stored in comparison_matrices
        # keeps its index.
        latex_table = (
            pivot_table.style.background_gradient(**gradient_kwargs)
            .hide(axis="index")
            .to_latex(
                caption=f"AlphaFold {feature} Summary",
                label=f"tab:af_feature_{feature.replace(' ', '_')}_sum",
                convert_css=True,
            )
        )
        with open(dir / f"{file_stem}.tex", 'w') as f:
            f.write(latex_table)

    print("\nMatrix generation complete. Global coloring applied based on the min/max of each feature across all data points.")
else:
    # Skip matrix generation instead of calling exit(), which would shut down
    # the notebook kernel.
    print("WARNING: Candidate names do not match the expected format {ligand}_variant_{target}_{variant}. Cannot create comparison matrices.")

### Feature discovery

In [None]:
from matplotlib.colors import Normalize, to_hex

# --- Helper Function for Coloring (UPDATED) ---
def color_cells(val, cmap, norm):
    """
    Return the CSS for one table cell, colored according to a numeric value.

    Args:
        val: Numeric value of the cell. Missing values (None/NaN) get no styling.
        cmap: Callable mapping a normalized value to an (r, g, b[, a]) sequence
            with channels in 0..1 (e.g. a matplotlib colormap).
        norm: Callable mapping ``val`` to the input expected by ``cmap``
            (e.g. a matplotlib ``Normalize`` instance).

    Returns:
        A string ``'background-color: #rrggbb; color: #XXXXXX'`` whose text
        color is white on dark backgrounds and black on light ones, or ``''``
        when ``val`` is missing.
    """
    # A missing value has no meaningful position on the color scale
    if pd.isna(val):
        return ''

    # Get the color from the colormap; alpha is dropped because the result is
    # written as an opaque hex color
    rgba = cmap(norm(val))
    r, g, b = (int(channel * 255) for channel in rgba[:3])

    # Hex string (e.g. '#aabbcc'), which Excel understands
    hex_color = f'#{r:02x}{g:02x}{b:02x}'

    # Perceived brightness (YIQ formula)
    brightness = (r * 299 + g * 587 + b * 114) / 1000

    # White text on dark backgrounds, black text on light ones
    text_color = '#FFFFFF' if brightness < 128 else '#000000'

    return f'background-color: {hex_color}; color: {text_color}'

# Folder that receives the styled Excel tables written by the following cells
output_dir = Path("./analysis_output")
output_dir.mkdir(exist_ok=True)

# Candidate names are split on '_'; fields 3 and 4 are stored as Target and Variant
temp_split = df['Candidate'].str.split(pat='_', expand=True)
df['Target'] = temp_split[3]
df['Variant'] = temp_split[4]

# Score columns for which larger values are better
HIGH_IS_BETTER = ["ipTM", "Mean pLDDT", "Loop pLDDT"]
# Score columns for which smaller values are better
LOW_IS_BETTER = ["Interface PAE (Loop)"]
features_to_analyze = [*HIGH_IS_BETTER, *LOW_IS_BETTER]

In [None]:
# --- PRE-PROCESSING: CALCULATE LOCAL Z-SCORES (FOR ALL ANALYSES NOW) ---
# This Z-score is calculated *across targets* for each variant.
for feature in features_to_analyze:
    # transform() computes the z-score within each variant group and aligns it
    # back to the rows of df; groups without spread (std not > 0) get 0.
    df[f'{feature}_local_zscore'] = df.groupby('Variant')[feature].transform(
        lambda x: (x - x.mean()) / x.std() if x.std() > 0 else 0
    )

# --- ANALYSIS 1: MOST SIGNIFICANT TARGET (USING LOCAL Z-SCORE) ---
print("\n" + "="*80)
print("--- Analysis 1: Most Significant Target per Variant (Z-Score Across Targets) ---")
print("="*80)

z_score_summary_text, z_score_summary_numeric = {}, {}
for feature in features_to_analyze:
    z_col = f'{feature}_local_zscore'
    text, numeric = {}, {}
    for variant, group in df.groupby('Variant'):
        # idxmax/idxmin cannot pick a row from an all-NaN group, so variants
        # without any usable score are left out for this feature.
        z_values = group[z_col].dropna()
        if z_values.empty:
            continue
        best_idx = z_values.idxmax() if feature in HIGH_IS_BETTER else z_values.idxmin()
        best_row = df.loc[best_idx]
        z_val = best_row[z_col]
        text[variant] = f"{best_row['Target']} (Z={z_val:.2f})"
        numeric[variant] = z_val
    z_score_summary_text[feature] = text
    z_score_summary_numeric[feature] = numeric

# Variants skipped for a feature appear as empty text / NaN numeric cells
z_score_text_df = pd.DataFrame(z_score_summary_text).fillna('')
z_score_text_df.index.name = "Variant"
z_score_numeric_df = pd.DataFrame(z_score_summary_numeric)
z_score_numeric_df.index.name = "Variant"

# CSS table with the same shape as the text table; one color scale for all cells
styler_df = pd.DataFrame('', index=z_score_text_df.index, columns=z_score_text_df.columns)
g_min, g_max = z_score_numeric_df.min().min(), z_score_numeric_df.max().max()
norm = Normalize(vmin=g_min, vmax=g_max)
for feature in styler_df.columns:
    cmap = plt.get_cmap('RdYlGn' if feature in HIGH_IS_BETTER else 'RdYlGn_r')
    # Missing values stay unstyled instead of being colored like a real score
    styler_df[feature] = z_score_numeric_df[feature].apply(
        lambda val: '' if pd.isna(val) else color_cells(val, cmap, norm)
    )
styled_summary = z_score_text_df.style.apply(lambda x: styler_df, axis=None)
display(styled_summary)
excel_path = output_dir / "Significant_Targets_A_Summary_Styled.xlsx"
styled_summary.to_excel(excel_path, engine='openpyxl')
print(f"\nStyled summary table saved to: {excel_path}")


# --- ANALYSIS 2: VARIANT RANKING BY LOCAL Z-SCORE ---
print("\n" + "="*80)
print("--- Analysis 2: Styled Variant Ranking for Each Target (by Local Z-Score/Selectivity) ---")
print("="*80)
for feature in features_to_analyze:
    z_col = f'{feature}_local_zscore'
    text_data, numeric_data = {}, {}
    for target_name, group in df.groupby('Target'):
        sort_ascending = feature in LOW_IS_BETTER
        sorted_group = group.sort_values(by=z_col, ascending=sort_ascending)
        # Series (not lists) so that targets with different numbers of
        # variants are padded when combined; equal-length lists would be
        # required otherwise and the DataFrame constructor would raise.
        text_data[target_name] = pd.Series(
            [f"{row['Variant']} (Z={row[z_col]:.2f})" for _, row in sorted_group.iterrows()]
        )
        numeric_data[target_name] = pd.Series(sorted_group[z_col].tolist())

    rank_df_text = pd.DataFrame(text_data).fillna('')
    rank_df_numeric = pd.DataFrame(numeric_data)

    # 1-based rank index, shared by both tables so the styles line up
    rank_df_text.index = rank_df_numeric.index = rank_df_text.index + 1
    rank_df_text.index.name = "Rank"

    cmap = plt.get_cmap('RdYlGn' if feature in HIGH_IS_BETTER else 'RdYlGn_r')
    norm = Normalize(vmin=rank_df_numeric.min().min(), vmax=rank_df_numeric.max().max())
    styled_rank_df = rank_df_text.style.apply(
        lambda x: rank_df_numeric.map(
            lambda val: '' if pd.isna(val) else color_cells(val, cmap, norm)
        ),
        axis=None,
    )

    print(f"\n--- Local Z-Score Ranks for: {feature} ---")
    display(styled_rank_df)

    excel_path = output_dir / f"Ranking_Local_ZScore_Target_Styled_{feature.replace(' ', '_')}.xlsx"
    styled_rank_df.to_excel(excel_path, engine='openpyxl')
    print(f"Styled Local Z-Score ranking table saved to: {excel_path}")

In [None]:
def _ranked_variant_tables(frame, score_col, ascending, describe):
    """Build per-target rank tables: (text table, numeric table), index 1..n named "Rank".

    ``describe`` turns one row into its cell text. Each target's column is a
    Series, so targets with fewer variants are padded ('' in the text table,
    NaN in the numeric table) instead of breaking the DataFrame constructor.
    """
    text_cols, numeric_cols = {}, {}
    for target, group in frame.groupby('Target'):
        ordered = group.sort_values(by=score_col, ascending=ascending)
        text_cols[target] = pd.Series([describe(row) for _, row in ordered.iterrows()])
        numeric_cols[target] = pd.Series(ordered[score_col].tolist())

    text_df = pd.DataFrame(text_cols).fillna('')
    numeric_df = pd.DataFrame(numeric_cols)
    # Same 1-based index on both tables so the styles line up cell by cell
    text_df.index = numeric_df.index = text_df.index + 1
    text_df.index.name = "Rank"
    return text_df, numeric_df


def _style_by_value(text_df, numeric_df, cmap):
    """Color each cell of ``text_df`` by the matching value in ``numeric_df``."""
    norm = Normalize(vmin=numeric_df.min().min(), vmax=numeric_df.max().max())
    # Padded (NaN) cells stay unstyled
    css = numeric_df.map(lambda val: '' if pd.isna(val) else color_cells(val, cmap, norm))
    return text_df.style.apply(lambda _: css, axis=None)


# --- PRE-PROCESSING: CALCULATE Z-SCORES ---
# Global z-score over all rows of each feature; 0 when the column has no
# spread (std not > 0), the same rule as the per-variant z-score.
for feature in features_to_analyze:
    mean, std = df[feature].mean(), df[feature].std()
    df[f'{feature}_zscore'] = (df[feature] - mean) / std if std > 0 else 0.0

# --- ANALYSIS 1: MOST SIGNIFICANT TARGET SUMMARY ---
print("\n" + "="*80)
print("--- Analysis 1: Most Statistically Significant Target for Each Variant ---")
print("="*80)

z_score_summary_text, z_score_summary_numeric = {}, {}
for feature in features_to_analyze:
    z_col = f'{feature}_zscore'
    text, numeric = {}, {}
    for variant, group in df.groupby('Variant'):
        # idxmax/idxmin cannot pick a row from an all-NaN group, so variants
        # without any usable score are left out for this feature.
        z_values = group[z_col].dropna()
        if z_values.empty:
            continue
        # Most desirable z-score: highest or lowest, depending on the feature
        best_idx = z_values.idxmax() if feature in HIGH_IS_BETTER else z_values.idxmin()
        best_row = group.loc[best_idx]
        z_val = best_row[z_col]
        text[variant] = f"{best_row['Target']} (Z={z_val:.2f})"
        numeric[variant] = z_val
    z_score_summary_text[feature] = text
    z_score_summary_numeric[feature] = numeric

z_score_text_df = pd.DataFrame(z_score_summary_text).fillna('')
z_score_text_df.index.name = "Variant"
z_score_numeric_df = pd.DataFrame(z_score_summary_numeric)
z_score_numeric_df.index.name = "Variant"

# --- Styling Logic for Analysis 1 ---
# CSS table with the same shape as the text table
styler_df = pd.DataFrame('', index=z_score_text_df.index, columns=z_score_text_df.columns)

# Global min and max z-score for consistent color scaling
g_min, g_max = z_score_numeric_df.min().min(), z_score_numeric_df.max().max()
norm = Normalize(vmin=g_min, vmax=g_max)

for feature in styler_df.columns:
    # Colormap direction depends on whether high or low is better
    cmap = plt.get_cmap('RdYlGn' if feature in HIGH_IS_BETTER else 'RdYlGn_r')
    styler_df[feature] = z_score_numeric_df[feature].apply(
        lambda val: '' if pd.isna(val) else color_cells(val, cmap, norm)
    )

styled_summary = z_score_text_df.style.apply(lambda x: styler_df, axis=None)

print("This table shows the most significant target, colored by its Z-score.")
display(styled_summary)

excel_path = output_dir / "Significant_Targets_Summary_B_Styled.xlsx"
styled_summary.to_excel(excel_path, engine='openpyxl')
print(f"\nStyled summary table saved to: {excel_path}")


# --- ANALYSIS 2: VARIANT RANKING BY Z-SCORE (STYLED) ---
print("\n" + "="*80)
print("--- Analysis 2: Styled Variant Ranking for Each Target (by Z-Score) ---")
print("="*80)

for feature in features_to_analyze:
    z_col = f'{feature}_zscore'
    rank_df_text, rank_df_numeric = _ranked_variant_tables(
        df,
        score_col=z_col,
        ascending=feature in LOW_IS_BETTER,
        describe=lambda row, col=z_col: f"{row['Variant']} (Z={row[col]:.2f})",
    )
    cmap = plt.get_cmap('RdYlGn' if feature in HIGH_IS_BETTER else 'RdYlGn_r')
    styled_rank_df = _style_by_value(rank_df_text, rank_df_numeric, cmap)

    print(f"\n--- Z-Score Ranks for: {feature} ---")
    display(styled_rank_df)

    excel_path = output_dir / f"Ranking_ZScore_Styled_{feature.replace(' ', '_')}.xlsx"
    styled_rank_df.to_excel(excel_path, engine='openpyxl')
    print(f"Styled Z-Score ranking table saved to: {excel_path}")


# --- ANALYSIS 3: VARIANT RANKING BY RAW SCORE (STYLED) ---
print("\n" + "="*80)
print("--- Analysis 3: Styled Variant Ranking for Each Target (by Raw Score) ---")
print("="*80)

for feature in features_to_analyze:
    rank_df_text, rank_df_numeric = _ranked_variant_tables(
        df,
        score_col=feature,
        ascending=feature in LOW_IS_BETTER,
        describe=lambda row, col=feature: f"{row['Variant']} (Score={row[col]:.3f})",
    )
    cmap = plt.get_cmap('RdYlGn' if feature in HIGH_IS_BETTER else 'RdYlGn_r')
    styled_rank_df = _style_by_value(rank_df_text, rank_df_numeric, cmap)

    print(f"\n--- Raw Score Ranks for: {feature} ---")
    display(styled_rank_df)

    excel_path = output_dir / f"Ranking_RawScore_Styled_{feature.replace(' ', '_')}.xlsx"
    styled_rank_df.to_excel(excel_path, engine='openpyxl')
    print(f"Styled Raw Score ranking table saved to: {excel_path}")

## Save Summary

In [None]:
# --- FINAL SUMMARY ---
if results_list:
    display(HTML("<hr><h1>Final Summary Ranking</h1>"))
    summary_df = pd.DataFrame(results_list)
    # Best loop confidence first
    summary_df_sorted = summary_df.sort_values(by="Loop pLDDT", ascending=False).reset_index(drop=True)

    # Column -> colormap. The PAE column uses the reversed map because lower
    # values are better there.
    gradient_columns = {
        'ipTM': 'viridis',
        'Mean pLDDT': 'viridis',
        'Loop pLDDT': 'viridis',
        'Interface PAE (Loop)': 'viridis_r',
    }

    # Style the dataframe for better readability
    styled_summary_df = summary_df_sorted.style.format({'ipTM': '{:.4f}', 'Mean pLDDT': '{:.2f}'})
    for column, cmap_name in gradient_columns.items():
        # 'Interface PAE (Loop)' only exists when at least one candidate had
        # PAE data; skip absent columns instead of raising KeyError.
        if column in summary_df_sorted.columns:
            styled_summary_df = styled_summary_df.background_gradient(subset=[column], cmap=cmap_name)
    display(styled_summary_df)

    styled_summary_df.to_excel(dir / 'AlphaFold_Summary.xlsx', engine='openpyxl', index=False)

    latex_label = "tab:af_multi_sum"
    latex_table = styled_summary_df.hide(axis="index").to_latex(caption="AlphaFold Summary", label=latex_label, convert_css=True)
    # Tidy candidate names for LaTeX (drop "fold_" / ".result", underscores -> spaces)
    latex_table = latex_table.replace("fold_", "").replace("_", " ").replace(".result", "")
    # The blanket underscore replacement above also rewrites the label; restore
    # it so the table can still be referenced by the label requested above.
    latex_table = latex_table.replace(latex_label.replace("_", " "), latex_label)
    with open(dir / "AlphaFold_Summary.tex", 'w', encoding='utf-8') as f:
        f.write(latex_table)