<a href="https://colab.research.google.com/github/eoinleen/Protein-design-random/blob/main/Rank_BindCraft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
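
To make the scoring rule in the cell below easier to follow, here is a minimal sketch of the distance-from-threshold idea. It is a simplified restatement of `calculate_metric_score`, not a replacement: it leaves out the NaN penalty and the special divisor the notebook uses for zero or negative thresholds. The thresholds and weights are copied from `filter_config`; the design values and the helper name `metric_score` are made up for illustration.

```python
def metric_score(value, threshold, higher_better, weight):
    """Signed, weighted distance from the threshold (simplified sketch).

    Positive = better than threshold (bonus capped at +weight),
    negative = violation (penalty is not capped).
    """
    scale = max(abs(threshold), 0.1)  # avoid dividing by a tiny threshold
    margin = (value - threshold) if higher_better else (threshold - value)
    return weight * min(margin / scale, 1.0)


# (threshold, higher_is_better, weight), copied from filter_config
config = {
    "i_pTM": (0.5, True, 2.0),
    "Average_Binder_RMSD": (3.5, False, 1.5),
    "Unrelaxed_Clashes": (5, False, 3.0),
}

# Hypothetical design: two metrics pass, the clash count fails
design = {"i_pTM": 0.65, "Average_Binder_RMSD": 1.75, "Unrelaxed_Clashes": 12}

scores = {name: metric_score(design[name], *config[name]) for name in config}
composite = sum(scores.values())
normalized = composite / len(scores)  # same normalization as the notebook

for name, score in scores.items():
    print(f"{name:>20}: {score:+.2f}")
print(f"{'normalized_score':>20}: {normalized:+.2f}")
```

With these made-up values the single clash violation outweighs both bonuses, so the normalized score ends up negative. That is the effect the 3.0 clash weight is meant to have.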

In [5]:
#!/usr/bin/env python3
"""
BindCraft Near-Miss Design Analysis & Comprehensive Ranking
===========================================================

Purpose
-------
Analyze BindCraft design CSVs with comprehensive scoring based on all metrics.
Each design receives a weighted score based on how far each metric is from its threshold.

Scoring System
--------------
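- Each metric present in the CSV contributes a weighted term to a composite score
- Each term is the metric's distance from its threshold, normalized by the threshold
- Bonuses for beating a threshold are capped at 1x the weight; penalties for violations are uncapped
- Missing values (NaN) are penalized at 2x the weight
- The unrelaxed clash count (threshold 5) carries the highest weight, 3.0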

Ranking
-------
Designs are ranked by:
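1) High-clash designs (>5 unrelaxed clashes) placed after all others (they are not dropped)
2) Normalized composite score (composite score / number of metrics evaluated)
3) Violation count as tiebreaker
4) Raw composite score, then individual metrics (i_pTM, binder pLDDT)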

Input
-----
- Upload a BindCraft CSV when prompted

Output
------
- bindcraft_comprehensive_analysis.xlsx with comprehensive scoring (falls back to .csv if Excel export fails)
- Color-coded Excel sheets showing metric performance
"""

import pandas as pd
import numpy as np
import io
from google.colab import files
import sys

# Try to import/install xlsxwriter
try:
    import xlsxwriter
    print("✅ xlsxwriter already installed")
except ImportError:
    print("Installing xlsxwriter for Excel formatting...")
    !pip install xlsxwriter -q
    import xlsxwriter
    print("✅ xlsxwriter installed successfully")

# ------------------------------------------------------------------
# 1. COMPREHENSIVE FILTER CRITERIA WITH WEIGHTS
# ------------------------------------------------------------------
filter_config = {
    # Critical metrics (weight = 2.0)
    "Average_pLDDT": {"threshold": 0.8, "higher": True, "weight": 2.0},
    "1_pLDDT": {"threshold": 0.8, "higher": True, "weight": 2.0},
    "Average_Binder_pLDDT": {"threshold": 0.8, "higher": True, "weight": 2.0},
    "Average_i_pTM": {"threshold": 0.5, "higher": True, "weight": 2.0},
    "1_i_pTM": {"threshold": 0.5, "higher": True, "weight": 2.0},
    "i_pTM": {"threshold": 0.5, "higher": True, "weight": 2.0},

    # Important structural metrics (weight = 1.5)
    "Average_pTM": {"threshold": 0.55, "higher": True, "weight": 1.5},
    "1_pTM": {"threshold": 0.55, "higher": True, "weight": 1.5},
    "Average_i_pAE": {"threshold": 0.35, "higher": False, "weight": 1.5},
    "1_i_pAE": {"threshold": 0.35, "higher": False, "weight": 1.5},
    "i_pAE": {"threshold": 0.35, "higher": False, "weight": 1.5},
    "Average_Binder_RMSD": {"threshold": 3.5, "higher": False, "weight": 1.5},

    # Interface quality metrics (weight = 1.5)
    "Average_ShapeComplementarity": {"threshold": 0.6, "higher": True, "weight": 1.5},
    "1_ShapeComplementarity": {"threshold": 0.55, "higher": True, "weight": 1.5},
    "Average_n_InterfaceHbonds": {"threshold": 3, "higher": True, "weight": 1.5},
    "Average_n_InterfaceResidues": {"threshold": 7, "higher": True, "weight": 1.5},

    # Energy metrics (weight = 1.2)
    "Average_dG": {"threshold": 0, "higher": False, "weight": 1.2},
    "Average_Binder_Energy_Score": {"threshold": 0, "higher": False, "weight": 1.2},
    "Average_dSASA": {"threshold": 1, "higher": True, "weight": 1.2},

    # Secondary metrics (weight = 1.0)
    "2_pLDDT": {"threshold": 0.8, "higher": True, "weight": 1.0},
    "Average_Surface_Hydrophobicity": {"threshold": 0.35, "higher": False, "weight": 1.0},
    "Average_n_InterfaceUnsatHbonds": {"threshold": 4, "higher": False, "weight": 1.0},
    "Average_Binder_Loop%": {"threshold": 90, "higher": False, "weight": 1.0},
    "Average_Hotspot_RMSD": {"threshold": 6, "higher": False, "weight": 1.0},

    # Penalty metrics (weight = 0.8)
    "Average_InterfaceAAs_K": {"threshold": 3, "higher": False, "weight": 0.8},
    "Average_InterfaceAAs_M": {"threshold": 3, "higher": False, "weight": 0.8},
}

# ------------------------------------------------------------------
# 2. LOAD CSV
# ------------------------------------------------------------------
print("\n" + "="*60)
print("Please upload a BindCraft CSV (e.g. final_design_stats.csv)")
print("="*60)
uploaded = files.upload()
csv_name = list(uploaded.keys())[0]
df = pd.read_csv(io.BytesIO(uploaded[csv_name]))

print(f"\n✅ Loaded {csv_name}")
print(f"   - {len(df)} designs")
print(f"   - {len(df.columns)} columns")

# Find the unrelaxed clashes column
column_P = None
if len(df.columns) > 15:
    column_P = df.columns[15]  # Column P
    print(f"\nColumn P identified as: '{column_P}'")

clash_columns = [col for col in df.columns if "clash" in col.lower() or "unrelaxed" in col.lower()]
if clash_columns:
    print(f"Clash-related columns found: {clash_columns}")
    if column_P in clash_columns:
        rank_unrelaxed_clashes = column_P
    else:
        rank_unrelaxed_clashes = clash_columns[0]
else:
    rank_unrelaxed_clashes = column_P

if rank_unrelaxed_clashes:
    filter_config[rank_unrelaxed_clashes] = {"threshold": 5, "higher": False, "weight": 3.0}  # High weight for clashes
    print(f"✅ Using '{rank_unrelaxed_clashes}' for clash filtering (threshold: ≤5, weight: 3.0)")

# Show metrics present
print("\nMetrics found in data:")
metrics_present = []
for metric in filter_config.keys():
    if metric in df.columns:
        metrics_present.append(metric)
print(f"   - {len(metrics_present)}/{len(filter_config)} filter metrics present")

# ------------------------------------------------------------------
# 3. COMPREHENSIVE SCORING SYSTEM
# ------------------------------------------------------------------
def calculate_metric_score(value, threshold, higher_better, weight):
    """
    Calculate a score for a single metric based on its distance from threshold.
    Returns a weighted score where:
    - Positive scores = better than threshold
    - Negative scores = worse than threshold
    - Magnitude indicates how far from threshold
    """
    if pd.isna(value):
        return -weight * 2  # Heavy penalty for missing values

    if higher_better:
        # For metrics where higher is better
        if value >= threshold:
            # Bonus for exceeding threshold (capped at 2x threshold for normalization)
            normalized = min((value - threshold) / max(threshold, 0.1), 1.0)
            return weight * normalized
        else:
            # Penalty for being below threshold
            normalized = (threshold - value) / max(threshold, 0.1)
            return -weight * normalized
    else:
        # For metrics where lower is better
        if value <= threshold:
            # Bonus for being below threshold
            if threshold <= 0:
                # For zero or negative thresholds (e.g. dG, whose threshold is 0)
                normalized = min(abs(value - threshold) / max(abs(threshold), 1), 1.0)
            else:
                # For positive thresholds
                normalized = min((threshold - value) / max(threshold, 0.1), 1.0)
            return weight * normalized
        else:
            # Penalty for exceeding threshold
            normalized = (value - threshold) / max(abs(threshold), 0.1)
            return -weight * normalized

def analyze_design_comprehensive(row):
    """Analyze each design with comprehensive scoring"""
    violations = 0
    failed = []
    composite_score = 0
    metric_scores = {}

    for metric, crit in filter_config.items():
        if metric not in row.index:
            continue

        val = row[metric]
        thresh = crit["threshold"]
        higher = crit["higher"]
        weight = crit.get("weight", 1.0)

        # Calculate metric score
        score = calculate_metric_score(val, thresh, higher, weight)
        metric_scores[metric] = score
        composite_score += score

        # Track violations
        if pd.isna(val):
            violations += 1
            failed.append(f"{metric}(NaN)")
        elif (higher and val < thresh) or (not higher and val > thresh):
            violations += 1
            failed.append(f"{metric}({val:.3f})")

    return pd.Series([
        violations,
        ", ".join(failed),
        composite_score,
        len(metric_scores)  # Number of metrics evaluated
    ])

print("\nCalculating comprehensive scores...")
df[["violation_count", "failed_metrics", "composite_score", "metrics_evaluated"]] = \
    df.apply(analyze_design_comprehensive, axis=1)

# Normalize composite score by number of metrics evaluated
df["normalized_score"] = df.apply(
    lambda x: x["composite_score"] / x["metrics_evaluated"] if x["metrics_evaluated"] > 0 else -999,
    axis=1
)

# ------------------------------------------------------------------
# 4. CLASSIFICATION WITH SCORE RANGES
# ------------------------------------------------------------------
df["passes_all_filters"] = df["violation_count"] == 0

def classify_with_score(row):
    """Classify based on violations and score"""
    v = row["violation_count"]
    score = row["normalized_score"]

    if v == 0:
        if score > 0.5:
            return "EXCELLENT"
        elif score > 0.2:
            return "PASS"
        else:
            return "PASS_MARGINAL"
    elif v == 1:
        return "NEAR_MISS_1"
    elif v == 2:
        return "NEAR_MISS_2"
    else:
        return "NEAR_MISS_3+"

df["filter_status"] = df.apply(classify_with_score, axis=1)

# ------------------------------------------------------------------
# 5. ADVANCED RANKING WITH COMPREHENSIVE SCORING
# ------------------------------------------------------------------
# Create high clash flag
if rank_unrelaxed_clashes and rank_unrelaxed_clashes in df.columns:
    # Vectorized flag; NaN > 5 evaluates to False, so missing values map to 0
    df["high_clash_flag"] = (df[rank_unrelaxed_clashes] > 5).astype(int)

    high_clash_count = len(df[df["high_clash_flag"] == 1])
    if high_clash_count > 0:
        print(f"\n⚠️ {high_clash_count} designs with >5 clashes will be heavily penalized")
else:
    df["high_clash_flag"] = 0

# Calculate percentile ranks for key metrics
key_metrics = ["composite_score", "normalized_score", "violation_count"]
for metric in key_metrics:
    if metric in df.columns:
        df[f"{metric}_percentile"] = df[metric].rank(
            ascending=(metric == "violation_count"),
            pct=True
        ) * 100

print("\nScore distribution:")
print(f"  Best normalized score: {df['normalized_score'].max():.3f}")
print(f"  Median normalized score: {df['normalized_score'].median():.3f}")
print(f"  Worst normalized score: {df['normalized_score'].min():.3f}")

# Multi-level sorting
sort_cols = [
    "high_clash_flag",      # High-clash designs sort last (not dropped)
    "normalized_score",      # Primary: normalized composite score
    "violation_count",       # Secondary: number of violations
    "composite_score"        # Tertiary: raw composite score
]
sort_asc = [True, False, True, False]

# Add specific metric columns if available
for metric in ["i_pTM", "1_i_pTM", "Average_i_pTM", "Binder_pLDDT", "Average_Binder_pLDDT"]:
    if metric in df.columns:
        sort_cols.append(metric)
        sort_asc.append(False)

# Apply sorting
df = df.sort_values(by=sort_cols, ascending=sort_asc).reset_index(drop=True)

# Add final rank
df.insert(0, "final_rank", df.index + 1)

# Add score tier (bands are exclusive: "Top 25%" means ranked between 10% and 25%)
def assign_tier(rank, total):
    percentile = (rank / total) * 100
    if percentile <= 10:
        return "Top 10%"
    elif percentile <= 25:
        return "Top 25%"
    elif percentile <= 50:
        return "Top 50%"
    else:
        return "Bottom 50%"

df.insert(1, "score_tier", df["final_rank"].map(lambda r: assign_tier(r, len(df))))

# ------------------------------------------------------------------
# 6. EXCEL OUTPUT WITH COMPREHENSIVE FORMATTING
# ------------------------------------------------------------------
try:
    output_file = "bindcraft_comprehensive_analysis.xlsx"
    print("\nCreating comprehensive Excel report...")

    with pd.ExcelWriter(output_file, engine="xlsxwriter") as writer:
        workbook = writer.book

        # Define formats
        formats = {
            "red": workbook.add_format({"bg_color": "#FFC7CE", "font_color": "#9C0006"}),
            "green": workbook.add_format({"bg_color": "#C6EFCE", "font_color": "#006100"}),
            "yellow": workbook.add_format({"bg_color": "#FFEB9C", "font_color": "#9C5700"}),
            "blue": workbook.add_format({"bg_color": "#DAE8FC", "font_color": "#00509E"}),
            "dark_red": workbook.add_format({"bg_color": "#FF6B6B", "font_color": "#FFFFFF", "bold": True}),
            "dark_green": workbook.add_format({"bg_color": "#4CAF50", "font_color": "#FFFFFF", "bold": True}),
            "light_blue": workbook.add_format({"bg_color": "#E3F2FD", "font_color": "#1976D2"}),
        }

        # Write sheets
        sheets_config = [
            ("Summary", df),
            ("Metrics_Colored", df),
            ("Top_Performers", df[df["score_tier"].isin(["Top 10%", "Top 25%"])]),
            ("Score_Analysis", df[["final_rank", "Design", "normalized_score", "composite_score",
                                   "violation_count", "filter_status", "score_tier"]])
        ]

        for sheet_name, sheet_df in sheets_config:
            sheet_df.to_excel(writer, sheet_name=sheet_name, index=False)
            worksheet = writer.sheets[sheet_name]

            # Apply conditional formatting
            first_row = 1
            last_row = len(sheet_df)

            # Format all metric columns
            for metric, crit in filter_config.items():
                if metric not in sheet_df.columns:
                    continue

                col_idx = sheet_df.columns.get_loc(metric)
                thresh = crit["threshold"]
                higher = crit["higher"]

                if higher:
                    # Green for good, red for bad
                    worksheet.conditional_format(
                        first_row, col_idx, last_row, col_idx,
                        {"type": "cell", "criteria": ">=", "value": thresh, "format": formats["green"]}
                    )
                    worksheet.conditional_format(
                        first_row, col_idx, last_row, col_idx,
                        {"type": "cell", "criteria": "<", "value": thresh, "format": formats["red"]}
                    )
                else:
                    # Green for good, red for bad
                    worksheet.conditional_format(
                        first_row, col_idx, last_row, col_idx,
                        {"type": "cell", "criteria": "<=", "value": thresh, "format": formats["green"]}
                    )
                    worksheet.conditional_format(
                        first_row, col_idx, last_row, col_idx,
                        {"type": "cell", "criteria": ">", "value": thresh, "format": formats["red"]}
                    )

            # Format score columns
            if "normalized_score" in sheet_df.columns:
                score_col = sheet_df.columns.get_loc("normalized_score")
                worksheet.conditional_format(
                    first_row, score_col, last_row, score_col,
                    {"type": "3_color_scale",
                     "min_color": "#FF0000",
                     "mid_color": "#FFFF00",
                     "max_color": "#00FF00"}
                )

            # Format tier column
            if "score_tier" in sheet_df.columns:
                tier_col = sheet_df.columns.get_loc("score_tier")
                worksheet.conditional_format(
                    first_row, tier_col, last_row, tier_col,
                    {"type": "text", "criteria": "containing", "value": "Top 10%", "format": formats["dark_green"]}
                )
                worksheet.conditional_format(
                    first_row, tier_col, last_row, tier_col,
                    {"type": "text", "criteria": "containing", "value": "Top 25%", "format": formats["green"]}
                )

            # Format status column
            if "filter_status" in sheet_df.columns:
                status_col = sheet_df.columns.get_loc("filter_status")
                worksheet.conditional_format(
                    first_row, status_col, last_row, status_col,
                    {"type": "text", "criteria": "containing", "value": "EXCELLENT", "format": formats["dark_green"]}
                )
                worksheet.conditional_format(
                    first_row, status_col, last_row, status_col,
                    {"type": "text", "criteria": "containing", "value": "PASS", "format": formats["green"]}
                )
                worksheet.conditional_format(
                    first_row, status_col, last_row, status_col,
                    {"type": "text", "criteria": "containing", "value": "NEAR_MISS", "format": formats["yellow"]}
                )

            # Auto-adjust column widths
            for idx, column in enumerate(sheet_df.columns):
                try:
                    max_len = min(sheet_df[column].astype(str).map(len).max() + 2, 50)
                    max_len = max(max_len, len(column) + 2)
                    worksheet.set_column(idx, idx, max_len)
                except Exception:  # fall back to a fixed width; don't swallow KeyboardInterrupt
                    worksheet.set_column(idx, idx, 15)

        print(f"✅ Created comprehensive report with {len(sheets_config)} sheets")

    files.download(output_file)
    excel_success = True

except Exception as e:
    print(f"⚠️ Excel creation failed: {e}")
    print("Falling back to CSV output...")
    excel_success = False

    output_file = "bindcraft_comprehensive_analysis.csv"
    df.to_csv(output_file, index=False)
    files.download(output_file)

# ------------------------------------------------------------------
# 7. COMPREHENSIVE SUMMARY STATISTICS
# ------------------------------------------------------------------
print(f"\n{'='*60}")
print("COMPREHENSIVE ANALYSIS COMPLETE!")
print(f"{'='*60}")

print("\nOverall Statistics:")
print(f"  Total designs: {len(df)}")
print(f"  Metrics evaluated: {len(metrics_present)}")
print(f"  Average normalized score: {df['normalized_score'].mean():.3f}")

print("\nPerformance Distribution:")
for tier in ["Top 10%", "Top 25%", "Top 50%", "Bottom 50%"]:
    count = len(df[df["score_tier"] == tier])
    print(f"  {tier}: {count} designs")

print("\nFilter Status Breakdown:")
for status in df["filter_status"].unique():
    count = len(df[df["filter_status"] == status])
    avg_score = df[df["filter_status"] == status]["normalized_score"].mean()
    print(f"  {status}: {count} designs (avg score: {avg_score:.3f})")

if rank_unrelaxed_clashes and "high_clash_flag" in df.columns:
    high_clash = df[df["high_clash_flag"] == 1]
    if len(high_clash) > 0:
        print(f"\n⚠️ Designs penalized for high clashes (>5): {len(high_clash)}")

print("\nTop 5 Performers:")
top_cols = ["final_rank", "Design", "normalized_score", "violation_count", "score_tier"]
if rank_unrelaxed_clashes and rank_unrelaxed_clashes in df.columns:
    top_cols.append(rank_unrelaxed_clashes)
print(df[top_cols].head(5).to_string(index=False))

if excel_success:
    print(f"\n✅ Downloaded: {output_file}")
    print("\nExcel report includes:")
    print("  • Summary - Complete data with all metrics")
    print("  • Metrics_Colored - Visual highlighting of all thresholds")
    print("  • Top_Performers - Top 25% of designs")
    print("  • Score_Analysis - Detailed scoring breakdown")
else:
    print(f"\n✅ Downloaded: {output_file} (CSV format)")

✅ xlsxwriter already installed

Please upload a BindCraft CSV (e.g. final_design_stats.csv)


Saving trajectory_stats.csv to trajectory_stats (4).csv

✅ Loaded trajectory_stats (4).csv
   - 35 designs
   - 44 columns

Column P identified as: 'Unrelaxed_Clashes'
Clash-related columns found: ['Unrelaxed_Clashes', 'Relaxed_Clashes']
✅ Using 'Unrelaxed_Clashes' for clash filtering (threshold: ≤5, weight: 3.0)

Metrics found in data:
   - 3/27 filter metrics present

Calculating comprehensive scores...

⚠️ 29 designs with >5 clashes will be heavily penalized

Score distribution:
  Best normalized score: 1.083
  Median normalized score: -0.062
  Worst normalized score: -7.798

Creating comprehensive Excel report...
✅ Created comprehensive report with 4 sheets




COMPREHENSIVE ANALYSIS COMPLETE!

Overall Statistics:
  Total designs: 35
  Metrics evaluated: 3
  Average normalized score: -0.917

Performance Distribution:
  Top 10%: 3 designs
  Top 25%: 5 designs
  Top 50%: 9 designs
  Bottom 50%: 18 designs

Filter Status Breakdown:
  EXCELLENT: 6 designs (avg score: 0.874)
  NEAR_MISS_1: 29 designs (avg score: -1.288)

⚠️ Designs penalized for high clashes (>5): 29

Top 5 Performers:
 final_rank                   Design  normalized_score  violation_count score_tier  Unrelaxed_Clashes
          1  AMPK_Pocket_l89_s720529          1.082857                0    Top 10%                  3
          2 AMPK_Pocket_l108_s482734          0.924762                0    Top 10%                  4
          3  AMPK_Pocket_l84_s830749          0.912381                0    Top 10%                  4
          4 AMPK_Pocket_l102_s373164          0.876190                0    Top 25%                  5
          5 AMPK_Pocket_l109_s828745          0.805714       