<a href="https://colab.research.google.com/github/eoinleen/Protein-design-random/blob/main/Extracting_stats_from_AF3_directories_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
# ========================================================================
# AlphaFold3 K11-Ub2 Binder Analysis Script
# ========================================================================
#
# PURPOSE:
# This script analyzes AlphaFold3 prediction results for K11-Ub2 binder designs.
# It extracts confidence metrics, calculates pLDDT values, and generates
# visualizations to help evaluate and compare different binder designs.
#
# HOW TO USE THIS SCRIPT:
# 1. Upload this script to Google Colab
# 2. Change the 'base_dir' variable in the main() function to point to your data folder
#    (e.g., '/content/drive/MyDrive/My_AF3_Results')
# 3. Run the script
# 4. Check the output directory for results (CSV files, plots, and log file)
#
# INPUT DATA REQUIREMENTS:
# - JSON confidence files (format: *summary_confidences_*.json)
# - PDB/CIF structure files in the same directories as their JSON files
# - File naming should include fold_id and model numbers
#
# WHAT THIS SCRIPT DOES:
# 1. Finds all JSON confidence files in the specified directory and subdirectories
# 2. Parses each JSON file to extract metrics like ipTM, pTM, and chain interactions
# 3. Locates the corresponding PDB/CIF file for each JSON file
# 4. Calculates per-residue and per-chain average pLDDT values from PDB files
# 5. Computes averages across models (0-4) for each fold_id/model_type combination
# 6. Generates visualizations of:
#    - Binder Quality (pLDDT) vs Binding Interface Quality (ipTM)
#    - Binder interaction with Ub1 vs Ub2
# 7. Creates a detailed log file documenting the analysis process
# 8. Saves raw and averaged data as CSV files
#
# OUTPUTS:
# - af3_raw_results.csv: All metrics for each individual model
# - af3_averaged_results.csv: Metrics averaged across models for each design
# - binder_quality_vs_interface.png: Plot of binder pLDDT vs interface ipTM
# - binder_ub1_vs_ub2_interaction.png: Plot of binder interaction with Ub1 vs Ub2
# - af3_analysis_log_*.txt: Detailed log of the analysis process
#
# KEY METRICS EXPLAINED:
# - iptm: Interface predicted TM-score, measures quality of interface prediction
# - binder_ub1_iptm: Interface ipTM between binder (A) and first ubiquitin (B)
# - binder_ub2_iptm: Interface ipTM between binder (A) and second ubiquitin (C)
# - binder_ub_iptm: Average of binder_ub1_iptm and binder_ub2_iptm
# - binder_avg_plddt: Average confidence score for the binder, calculated
#   by averaging B-factor values per residue, then averaging those per chain
# - composite_score: Average of avg_iptm and avg_binder_ub_iptm
#
# TROUBLESHOOTING:
# - If no JSON files are found: Check the 'base_dir' path and make sure files match expected patterns
# - If no PDB files are matched: Check filenames and directory structure
# - If missing data in CSV: Check log file for warnings about specific metrics
# - If no visualizations: Make sure required metrics are available
#
# UNDERSTANDING THE CHAIN NAMES:
# - Chain A: The binder
# - Chain B: First ubiquitin (Ub1)
# - Chain C: Second ubiquitin (Ub2)
#
# REQUIREMENTS:
# - Python packages: pandas, numpy, matplotlib, seaborn, biopython
#
# ========================================================================
# Authors:
# - Claude from Anthropic (Primary implementation and documentation)
# - Your Name (Project design, use case definition, and testing)
# - St Jimmy (Moral support and asking questions no one else would think to ask)
#
# Date: 2025-03-08
# Version: 1.0
# ========================================================================
#
# Note for Jimmy: Yes, this is the right script. Just run it and it will work.
# No, don't change anything unless you know what you're doing. Yes, it's
# supposed to take that long to run. No, the warnings are normal.
#
# ========================================================================

def main():
    """
    Main execution function that processes data and generates outputs.
    """
    # Set the base directory to your mounted Google Drive folder
    # Update this path to match your Google Drive structure
    base_dir = '/content/drive/MyDrive/PDB-files/AF3-recalc_folds_2025_03_08_19_33'
    output_dir = '/content/drive/MyDrive/PDB-files/AF3-recalc_folds_2025_03_08_19_33/AF3_ANALYSIS_RESULTS'

    # Create output directory and set up logging
    os.makedirs(output_dir, exist_ok=True)
    log_file = setup_logging(output_dir)

    logging.info("Starting AlphaFold3 K11-Ub2 binder analysis...")
    logging.info(f"Base directory: {base_dir}")
    logging.info(f"Output directory: {output_dir}")

    # Process all folders
    df = process_folder(base_dir)

    # Check if we have data
    if df.empty:
        logging.error("No data was found or processed. Please check your file paths and formats.")
        return None, None, log_file

    # Save the raw DataFrame to CSV
    raw_csv_path = os.path.join(output_dir, 'af3_raw_results.csv')
    df.to_csv(raw_csv_path, index=False)
    logging.info(f"Raw analysis results saved to {raw_csv_path}")

    # Check for any rows with missing critical data
    critical_cols = ['fold_id', 'model_type', 'model_num', 'iptm']
    available_critical_cols = [col for col in critical_cols if col in df.columns]

    if available_critical_cols:
        missing_data = df[df[available_critical_cols].isna().any(axis=1)]
        if not missing_data.empty:
            logging.warning(f"{len(missing_data)} rows have missing critical data")
            logging.warning("Sample of rows with missing data:")
            for idx, row in missing_data.head().iterrows():
                logging.warning(f"  Row {idx}: {row[available_critical_cols].to_dict()}")

    # Calculate averages across models 0-4 for each fold
    try:
        logging.info("Calculating model averages...")
        avg_df = calculate_model_averages(df)

        if avg_df.empty:
            logging.error("Failed to calculate model averages. Check the logs for details.")
            return df, None, log_file

        # Save the averaged DataFrame to CSV
        avg_csv_path = os.path.join(output_dir, 'af3_averaged_results.csv')
        avg_df.to_csv(avg_csv_path, index=False)
        logging.info(f"Averaged analysis results saved to {avg_csv_path}")

        # Generate visualizations for the averaged data
        generate_visualizations(df, avg_df, output_dir)

        # Display summary statistics for the averaged data
        logging.info("Summary Statistics for Averaged Data:")

        # List of metrics to show in summary
        summary_metrics = ['avg_iptm', 'avg_binder_ub_iptm', 'avg_binder_plddt']
        available_metrics = [m for m in summary_metrics if m in avg_df.columns]

        if available_metrics:
            # Calculate descriptive statistics
            summary_stats = avg_df[available_metrics].describe()
            logging.info("\nDescriptive statistics:")
            for stat, values in summary_stats.iterrows():
                logging.info(f"  {stat}: {values.to_dict()}")
        else:
            logging.warning("No metrics available for summary statistics")

        # Generate specific average metrics requested
        logging.info("\n===== Average ipTM for binders to both ubiquitins =====")
        if 'avg_binder_ub_iptm' in avg_df.columns:
            # Filter out NaN values
            valid_data = avg_df.dropna(subset=['avg_binder_ub_iptm'])
            if not valid_data.empty:
                avg_value = valid_data['avg_binder_ub_iptm'].mean()
                logging.info(f"Overall average binder-Ub ipTM: {avg_value:.4f}")
                logging.info(f"Number of valid data points: {len(valid_data)}/{len(avg_df)}")

                # Get top 10 binders by avg_binder_ub_iptm
                top_binders = valid_data.sort_values('avg_binder_ub_iptm', ascending=False).head(10)
                logging.info("\nTop 10 binders by average interaction with both ubiquitins:")
                for idx, row in top_binders.iterrows():
                    logging.info(f"  Fold {row['fold_id']} ({row['model_type']}): {row['avg_binder_ub_iptm']:.4f}")
            else:
                logging.warning("No valid binder-Ub ipTM data available")
        else:
            logging.warning("No binder-Ub ipTM data available")

        # Generate a list of the top performing binders (overall ipTM and binder-Ub interaction)
        logging.info("\n===== Overall Best Binders =====")
        # Create a composite score that weights both overall ipTM and binder-Ub interaction
        if 'avg_iptm' in avg_df.columns and 'avg_binder_ub_iptm' in avg_df.columns:
            # Filter to rows with both metrics available
            valid_data = avg_df.dropna(subset=['avg_iptm', 'avg_binder_ub_iptm'])

            if not valid_data.empty:
                # Calculate composite score
                valid_data['composite_score'] = (valid_data['avg_iptm'] + valid_data['avg_binder_ub_iptm']) / 2
                logging.info("Calculated composite score as (avg_iptm + avg_binder_ub_iptm) / 2")

                top_overall = valid_data.sort_values('composite_score', ascending=False).head(10)
                logging.info("\nTop 10 binders by composite score (overall ipTM + binder-Ub interaction):")
                for idx, row in top_overall.iterrows():
                    logging.info(f"  Fold {row['fold_id']} ({row['model_type']}): " +
                                 f"composite={row['composite_score']:.4f}, " +
                                 f"ipTM={row['avg_iptm']:.4f}, " +
                                 f"binder-Ub={row['avg_binder_ub_iptm']:.4f}")
            else:
                logging.warning("No valid data available for composite scoring")
        else:
            logging.warning("Required metrics not available for composite scoring")

        # Add metric calculation explanations to the log
        logging.info("\n===== Metric Calculation Details =====")
        logging.info("1. iptm - Interface predicted TM-score from AlphaFold3 JSON")
        logging.info("2. binder_ub1_iptm - Interface ipTM between binder (chain A) and first ubiquitin (chain B)")
        logging.info("3. binder_ub2_iptm - Interface ipTM between binder (chain A) and second ubiquitin (chain C)")
        logging.info("4. binder_ub_iptm - Average of binder_ub1_iptm and binder_ub2_iptm")
        logging.info("5. binder_avg_plddt - Average pLDDT (per-residue confidence) for the binder (chain A)")
        logging.info("   Calculation: Average of B-factor values from PDB file, first averaged per residue then per chain")
        logging.info("6. composite_score - Average of avg_iptm and avg_binder_ub_iptm")

    except Exception as e:
        logging.error(f"Error in calculating or visualizing averages: {e}")
        import traceback
        logging.error(traceback.format_exc())
        avg_df = None

    logging.info("\nAnalysis complete!")
    logging.info(f"Log file saved to: {log_file}")

    return df, avg_df, log_file

# This can be run in Google Colab
if __name__ == "__main__":
    try:
        # Mount Google Drive
        drive.mount('/content/drive')

        # Run the main analysis
        df, avg_df, log_file = main()

        # Display the first few rows of the averaged DataFrame if available
        if avg_df is not None and not avg_df.empty:
            print("\nPreview of averaged results:")
            display(avg_df.head())
        else:
            print("\nNo averaged results available to display")

        print(f"\nDetailed log saved to: {log_file}")

    except Exception as e:
        print(f"Error in main execution: {e}")
        import traceback
        traceback.print_exc()# Install required packages
!pip install -q biopython pandas numpy matplotlib seaborn

import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import PDB
from google.colab import drive
import glob
import re
import datetime
import logging

# Set up logging
def setup_logging(output_dir):
    """Set up detailed logging for the analysis process"""
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(output_dir, f'af3_analysis_log_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}.txt')

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()  # Also output to console
        ]
    )

    logging.info("=" * 80)
    logging.info("AlphaFold3 K11-Ub2 Binder Analysis Log")
    logging.info("=" * 80)
    logging.info(f"Analysis started at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    return log_file

def analyze_confidence_json(json_file):
    """
    Analyze AlphaFold3 JSON confidence files and extract key metrics.
    """
    logging.info(f"Processing JSON file: {os.path.basename(json_file)}")

    try:
        with open(json_file, 'r') as f:
            data = json.load(f)

        # Extract the base name from the file path
        base_name = os.path.basename(json_file)

        # Initialize default values
        fold_id = "unknown"
        model_type = "unknown"
        model_num = "0"

        # Try to extract fold_id (numeric part after "fold_")
        fold_match = re.search(r'fold_(\d+)', base_name)
        if fold_match:
            fold_id = fold_match.group(1)
            logging.info(f"  Extracted fold_id: {fold_id}")
        else:
            # Try alternative patterns
            parts = base_name.split('_')
            if parts and parts[0].isdigit():
                fold_id = parts[0]
                logging.info(f"  Extracted fold_id from filename start: {fold_id}")

        # Try to extract model_type (mpnnX or other identifier)
        if "fold_" in base_name:
            model_match = re.search(r'fold_\d+_([^_]+)', base_name)
            if model_match:
                model_type = model_match.group(1)
                logging.info(f"  Extracted model_type: {model_type}")
        elif len(base_name.split('_')) > 1:
            model_type = base_name.split('_')[1]
            logging.info(f"  Extracted model_type from second part: {model_type}")

        # Extract model number (last digit before .json)
        model_num_match = re.search(r'confidences_(\d+)\.json$', base_name)
        if model_num_match:
            model_num = model_num_match.group(1)
            logging.info(f"  Extracted model_num: {model_num}")
        else:
            alt_match = re.search(r'_(\d+)\.json$', base_name)
            if alt_match:
                model_num = alt_match.group(1)
                logging.info(f"  Extracted model_num from alternative pattern: {model_num}")

        # Extract key metrics
        logging.info("  Extracting metrics from JSON data...")

        # Core metrics
        results = {
            'fold_id': fold_id,
            'model_type': model_type,
            'model_num': model_num,
            'iptm': data.get('iptm', None),
            'ptm': data.get('ptm', None),
            'ranking_score': data.get('ranking_score', None),
            'has_clash': data.get('has_clash', None),
            'fraction_disordered': data.get('fraction_disordered', None),
        }

        logging.info(f"  Core metrics - ipTM: {results['iptm']}, pTM: {results['ptm']}, ranking_score: {results['ranking_score']}")

        # Extract chain-specific metrics
        chain_ptm = data.get('chain_ptm', [])
        if len(chain_ptm) >= 3:
            results['binder_ptm'] = chain_ptm[0]  # Chain A (binder)
            results['ub1_ptm'] = chain_ptm[1]     # Chain B (first ubiquitin)
            results['ub2_ptm'] = chain_ptm[2]     # Chain C (second ubiquitin)
            logging.info(f"  Chain-specific pTM - Binder: {chain_ptm[0]}, Ub1: {chain_ptm[1]}, Ub2: {chain_ptm[2]}")
        else:
            logging.warning(f"  chain_ptm data incomplete or missing: {chain_ptm}")

        # Extract interaction metrics between binder and ubiquitins
        chain_pair_iptm = data.get('chain_pair_iptm', [])
        if len(chain_pair_iptm) >= 3 and len(chain_pair_iptm[0]) >= 3:
            # A-B interaction (binder to first Ub)
            results['binder_ub1_iptm'] = chain_pair_iptm[0][1]
            # A-C interaction (binder to second Ub)
            results['binder_ub2_iptm'] = chain_pair_iptm[0][2]
            # B-C interaction (between ubiquitins)
            results['ub1_ub2_iptm'] = chain_pair_iptm[1][2]

            logging.info(f"  Chain interaction ipTM - Binder-Ub1: {chain_pair_iptm[0][1]}, Binder-Ub2: {chain_pair_iptm[0][2]}, Ub1-Ub2: {chain_pair_iptm[1][2]}")
        else:
            logging.warning(f"  chain_pair_iptm data incomplete or missing: {chain_pair_iptm}")

        # Extract PAE information
        chain_pair_pae_min = data.get('chain_pair_pae_min', [])
        if len(chain_pair_pae_min) >= 3 and len(chain_pair_pae_min[0]) >= 3:
            # A-B interaction (binder to first Ub)
            results['binder_ub1_pae_min'] = chain_pair_pae_min[0][1]
            # A-C interaction (binder to second Ub)
            results['binder_ub2_pae_min'] = chain_pair_pae_min[0][2]
            # B-C interaction (between ubiquitins)
            results['ub1_ub2_pae_min'] = chain_pair_pae_min[1][2]

            logging.info(f"  Chain interaction PAE min - Binder-Ub1: {chain_pair_pae_min[0][1]}, Binder-Ub2: {chain_pair_pae_min[0][2]}, Ub1-Ub2: {chain_pair_pae_min[1][2]}")
        else:
            logging.warning(f"  chain_pair_pae_min data incomplete or missing: {chain_pair_pae_min}")

        return results
    except Exception as e:
        logging.error(f"Error processing JSON file {json_file}: {e}")
        # Return minimal information to avoid breaking the pipeline
        return {
            'fold_id': os.path.basename(json_file).split('_')[1] if '_' in os.path.basename(json_file) else "error",
            'model_type': "error",
            'model_num': "error",
            'error_message': str(e)
        }

def calculate_plddt_from_pdb(pdb_file):
    """
    Calculate the average pLDDT for each chain from PDB/CIF file.
    """
    logging.info(f"Calculating pLDDT from {os.path.basename(pdb_file)}")

    try:
        # Read the file content
        with open(pdb_file, 'r') as f:
            content = f.read()

        # Extract atom lines
        atomLines = [line for line in content.split('\n') if line.startswith('ATOM')]

        if not atomLines:
            logging.warning(f"No ATOM lines found in {pdb_file}")
            return {}

        logging.info(f"  Found {len(atomLines)} ATOM lines")

        # Parse atom data based on the AlphaFold3 mmCIF-like format
        # Format example: ATOM 1    N N   . GLY A 1 1   ? -14.504 -13.150 -32.624 1.00 59.06 1   A 1
        atomData = []
        for line in atomLines:
            parts = [p for p in line.split() if p]

            if len(parts) < 15:
                continue

            try:
                atomData.append({
                    'atomType': parts[2],
                    'residueType': parts[5],
                    'chainId': parts[6],
                    'residueNum': int(parts[8]),
                    'bFactor': float(parts[14])
                })
            except (ValueError, IndexError) as e:
                logging.warning(f"  Error parsing line: {line}, error: {e}")
                continue

        logging.info(f"  Successfully parsed {len(atomData)} atoms")

        # Group by chain
        chainData = {}
        for atom in atomData:
            chainId = atom['chainId']
            if chainId not in chainData:
                chainData[chainId] = []
            chainData[chainId].append(atom)

        logging.info(f"  Found chains: {list(chainData.keys())}")

        # Calculate per-chain average pLDDT
        results = {}

        for chainId, atoms in chainData.items():
            # Group by residue
            residueGroups = {}
            for atom in atoms:
                residueNum = atom['residueNum']
                if residueNum not in residueGroups:
                    residueGroups[residueNum] = []
                residueGroups[residueNum].append(atom['bFactor'])

            num_residues = len(residueGroups)
            logging.info(f"  Chain {chainId}: {len(atoms)} atoms in {num_residues} residues")

            # Calculate per-residue averages
            residueAvgs = [sum(bFactors)/len(bFactors) for bFactors in residueGroups.values()]

            # Calculate chain average (average of residue averages)
            if residueAvgs:
                chainAvg = sum(residueAvgs) / len(residueAvgs)

                # Map chain IDs to our standard names
                if chainId == 'A':
                    results['binder_avg_plddt'] = chainAvg
                    logging.info(f"  Binder (Chain A) average pLDDT: {chainAvg:.2f}")
                elif chainId == 'B':
                    results['ub1_avg_plddt'] = chainAvg
                    logging.info(f"  Ub1 (Chain B) average pLDDT: {chainAvg:.2f}")
                elif chainId == 'C':
                    results['ub2_avg_plddt'] = chainAvg
                    logging.info(f"  Ub2 (Chain C) average pLDDT: {chainAvg:.2f}")

        return results
    except Exception as e:
        logging.error(f"Error processing {pdb_file}: {e}")
        return {}

def match_pdb_to_json(json_file, pdb_dir):
    """
    Find the corresponding PDB/CIF file for a given JSON confidence file.
    """
    base_name = os.path.basename(json_file)
    logging.info(f"Finding matching PDB/CIF file for {base_name}")

    # Try to extract model number from the JSON filename
    model_num = None
    model_match = re.search(r'confidences_(\d+)', base_name)
    if model_match:
        model_num = model_match.group(1)
        logging.info(f"  Extracted model number: {model_num}")
    else:
        # Try alternative pattern
        alt_match = re.search(r'_(\d+)\.json$', base_name)
        if alt_match:
            model_num = alt_match.group(1)
            logging.info(f"  Extracted model number (alt pattern): {model_num}")

    # Try to extract the fold ID
    fold_id = None
    fold_match = re.search(r'fold_(\d+)', base_name)
    if fold_match:
        fold_id = fold_match.group(1)
        logging.info(f"  Extracted fold ID: {fold_id}")
    elif base_name.split('_')[0].isdigit():
        fold_id = base_name.split('_')[0]
        logging.info(f"  Extracted fold ID from filename start: {fold_id}")

    # If we have fold_id, try to find a matching PDB file
    if fold_id:
        # Generate patterns from most specific to most general
        patterns = []

        # If we have a model number, try patterns with it first
        if model_num:
            patterns.append(f"*{fold_id}*_model_{model_num}.cif")
            patterns.append(f"*{fold_id}*model_{model_num}*.cif")

        # Add more general patterns
        patterns.append(f"*{fold_id}*_model_*.cif")
        patterns.append(f"*{fold_id}*model_*.cif")
        patterns.append(f"*{fold_id}*.cif")

        # Try each pattern
        for pattern in patterns:
            logging.info(f"  Trying pattern: {pattern}")
            matching_pdbs = glob.glob(os.path.join(pdb_dir, pattern))

            if matching_pdbs:
                logging.info(f"  Found {len(matching_pdbs)} matching files")

                # If multiple matches, try to find the best one
                if len(matching_pdbs) > 1 and model_num:
                    for pdb_file in matching_pdbs:
                        if f"model_{model_num}" in pdb_file:
                            logging.info(f"  Best match: {os.path.basename(pdb_file)}")
                            return pdb_file

                logging.info(f"  Selected match: {os.path.basename(matching_pdbs[0])}")
                return matching_pdbs[0]
            else:
                logging.info(f"  No matches for pattern: {pattern}")

    # If we get here, no matching PDB was found
    logging.warning(f"No matching PDB/CIF file found for {base_name}")
    return None

def process_folder(base_dir):
    """
    Process all JSON confidence files in a directory and its subdirectories.
    """
    logging.info(f"Processing folder: {base_dir}")

    # Find all JSON confidence files
    json_files = glob.glob(os.path.join(base_dir, "**", "*summary_confidences_*.json"), recursive=True)

    if not json_files:
        logging.warning(f"No JSON confidence files found in {base_dir}")
        # Try a more general pattern
        json_files = glob.glob(os.path.join(base_dir, "**", "*.json"), recursive=True)
        if json_files:
            logging.info(f"Found {len(json_files)} JSON files with general pattern")

    all_results = []
    processed_count = 0
    total_files = len(json_files)

    logging.info(f"Found {total_files} JSON files to process")

    for json_file in json_files:
        processed_count += 1
        if processed_count % 10 == 0 or processed_count == 1 or processed_count == total_files:
            logging.info(f"Processing file {processed_count}/{total_files}: {os.path.basename(json_file)}")

        # Process the JSON file
        json_results = analyze_confidence_json(json_file)

        # Try to find matching PDB/CIF file
        pdb_dir = os.path.dirname(json_file)
        pdb_file = match_pdb_to_json(json_file, pdb_dir)

        if pdb_file:
            # Calculate pLDDT values
            plddt_results = calculate_plddt_from_pdb(pdb_file)
            # Merge with JSON results
            json_results.update(plddt_results)
        else:
            logging.warning(f"  No matching PDB file found for {os.path.basename(json_file)}")

        all_results.append(json_results)

    # Convert to DataFrame
    df = pd.DataFrame(all_results)
    logging.info(f"Created DataFrame with {len(df)} rows and {len(df.columns)} columns")

    # Log column information
    logging.info(f"Columns in DataFrame: {list(df.columns)}")

    # Check for missing values in key columns
    for col in ['iptm', 'binder_ub1_iptm', 'binder_ub2_iptm', 'binder_avg_plddt']:
        if col in df.columns:
            null_count = df[col].isna().sum()
            logging.info(f"Column '{col}': {len(df) - null_count} non-null values, {null_count} null values")

    return df

def calculate_model_averages(df):
    """
    Calculate averages across the models for each fold_id/model_type combination.
    """
    logging.info("Calculating model averages")

    # Make sure we have the required columns
    required_cols = ['fold_id', 'model_type', 'model_num']
    for col in required_cols:
        if col not in df.columns:
            logging.error(f"Required column '{col}' not found in data")
            return pd.DataFrame()  # Return empty dataframe

    # Count unique fold_id and model_type combinations
    unique_combinations = df.groupby(['fold_id', 'model_type']).size().reset_index(name='model_count')
    logging.info(f"Found {len(unique_combinations)} unique fold_id/model_type combinations")

    # Make sure model_num is a string
    df['model_num'] = df['model_num'].astype(str)

    # Group by fold_id and model_type
    grouped = df.groupby(['fold_id', 'model_type'])

    avg_results = []

    # Track models that don't have all 5 models (0-4)
    incomplete_models = []

    logging.info("Processing each fold_id/model_type group...")

    for i, ((fold_id, model_type), group) in enumerate(grouped):
        logging.info(f"Processing group {i+1}/{len(unique_combinations)}: {fold_id}/{model_type}")

        # Get all model numbers for this group
        model_nums = sorted(group['model_num'].unique())
        logging.info(f"  Models in group: {', '.join(model_nums)}")

        # Check if we have all 5 models (0-4)
        expected_models = ['0', '1', '2', '3', '4']
        missing_models = [m for m in expected_models if m not in model_nums]

        if missing_models:
            logging.warning(f"  Missing models: {', '.join(missing_models)}")
            incomplete_models.append((fold_id, model_type, missing_models))

        # Calculate interface metric between binder and both ubiquitins
        if 'binder_ub1_iptm' in group.columns and 'binder_ub2_iptm' in group.columns:
            group['binder_ub_iptm'] = (group['binder_ub1_iptm'] + group['binder_ub2_iptm']) / 2
            logging.info("  Calculated binder_ub_iptm as average of binder_ub1_iptm and binder_ub2_iptm")

        # List of metrics to average
        metrics_to_avg = [
            # Overall metrics
            'iptm', 'ptm', 'ranking_score', 'has_clash', 'fraction_disordered',

            # Chain-specific metrics
            'binder_ptm', 'ub1_ptm', 'ub2_ptm',

            # Interface metrics
            'binder_ub1_iptm', 'binder_ub2_iptm', 'binder_ub_iptm', 'ub1_ub2_iptm',

            # pLDDT metrics
            'binder_avg_plddt', 'ub1_avg_plddt', 'ub2_avg_plddt',

            # PAE metrics
            'binder_ub1_pae_min', 'binder_ub2_pae_min', 'ub1_ub2_pae_min'
        ]

        # Initialize the result dictionary
        avg_data = {
            'fold_id': fold_id,
            'model_type': model_type,
            'num_models': len(group),
            'model_nums': ','.join(model_nums)
        }

        # Calculate average for each metric if it exists
        logging.info("  Calculating averages for metrics:")

        for metric in metrics_to_avg:
            if metric in group.columns:
                # Skip metrics with all NaN values
                if group[metric].isna().all():
                    logging.warning(f"    Metric '{metric}' has all NaN values")
                    avg_data[f'avg_{metric}'] = None
                    continue

                # Count non-null values
                non_null_count = group[metric].count()
                if non_null_count < len(group):
                    logging.warning(f"    Metric '{metric}' has {non_null_count}/{len(group)} non-null values")

                # Special handling for boolean-like metrics
                if metric == 'has_clash' or metric == 'fraction_disordered':
                    avg_value = group[metric].mean() > 0.5
                    avg_data[f'avg_{metric}'] = avg_value
                    logging.info(f"    {metric}: {avg_value} (boolean conversion)")
                else:
                    # Regular numeric average
                    avg_value = group[metric].mean()
                    avg_data[f'avg_{metric}'] = avg_value
                    logging.info(f"    {metric}: {avg_value:.4f}")

        # Find the best model in this group based on ranking score or ipTM
        if 'ranking_score' in group.columns and not group['ranking_score'].isna().all():
            best_idx = group['ranking_score'].idxmax()
            best_model = group.loc[best_idx, 'model_num']
            avg_data['best_model_num'] = best_model
            logging.info(f"  Best model by ranking_score: {best_model}")

            # Add best model metrics
            for metric in ['iptm', 'ranking_score', 'binder_ub_iptm']:
                if metric in group.columns and best_idx in group.index and not pd.isna(group.loc[best_idx, metric]):
                    avg_data[f'best_model_{metric}'] = group.loc[best_idx, metric]
                    logging.info(f"    best_model_{metric}: {group.loc[best_idx, metric]:.4f}")
        elif 'iptm' in group.columns and not group['iptm'].isna().all():
            best_idx = group['iptm'].idxmax()
            best_model = group.loc[best_idx, 'model_num']
            avg_data['best_model_num'] = best_model
            logging.info(f"  Best model by iptm: {best_model}")

            # Add best model metrics
            for metric in ['iptm', 'binder_ub_iptm']:
                if metric in group.columns and best_idx in group.index and not pd.isna(group.loc[best_idx, metric]):
                    avg_data[f'best_model_{metric}'] = group.loc[best_idx, metric]
                    logging.info(f"    best_model_{metric}: {group.loc[best_idx, metric]:.4f}")

        avg_results.append(avg_data)

    # Convert to DataFrame
    avg_df = pd.DataFrame(avg_results)
    logging.info(f"Created averages DataFrame with {len(avg_df)} rows and {len(avg_df.columns)} columns")

    # Report on missing models
    if incomplete_models:
        total_incomplete = len(incomplete_models)
        logging.warning(f"\nFound {total_incomplete} fold_id/model_type combinations with missing models")
        logging.warning("Top 5 incomplete models:")
        for i, (fold_id, model_type, missing) in enumerate(incomplete_models[:5]):
            logging.warning(f"  {fold_id} {model_type}: Missing models {', '.join(missing)}")

        if total_incomplete > 5:
            logging.warning(f"  ...and {total_incomplete - 5} more")

    # Log column information
    logging.info(f"Columns in averages DataFrame: {list(avg_df.columns)}")

    return avg_df

def generate_visualizations(df, avg_df, output_dir):
    """
    Generate visualizations for binder quality and interactions
    """
    logging.info("Generating visualizations")
    os.makedirs(output_dir, exist_ok=True)

    # Focus on just the two main visualizations:
    # 1. Binder Quality vs Binding Interface Quality
    if all(col in avg_df.columns for col in ['avg_binder_plddt', 'avg_binder_ub_iptm']):
        # Filter for valid data
        valid_data = avg_df.dropna(subset=['avg_binder_plddt', 'avg_binder_ub_iptm'])

        logging.info(f"Plotting Binder Quality vs Binding Interface Quality: {len(valid_data)}/{len(avg_df)} data points")

        if valid_data.shape[0] >= 2:
            plt.figure(figsize=(12, 10))

            # Log any rows with missing values
            missing_data = avg_df[avg_df['avg_binder_plddt'].isna() | avg_df['avg_binder_ub_iptm'].isna()]
            if not missing_data.empty:
                logging.warning(f"Found {len(missing_data)} rows with missing values for quality plot:")
                for idx, row in missing_data.iterrows():
                    logging.warning(f"  {row['fold_id']} {row['model_type']}: " +
                                    f"pLDDT={row.get('avg_binder_plddt', 'N/A')}, " +
                                    f"ipTM={row.get('avg_binder_ub_iptm', 'N/A')}")

            # Determine color values
            color_values = valid_data['avg_ranking_score'] if 'avg_ranking_score' in valid_data.columns else None

            logging.info("Creating scatter plot")
            scatter = plt.scatter(
                valid_data['avg_binder_plddt'],
                valid_data['avg_binder_ub_iptm'],
                c=color_values,
                cmap='viridis',
                alpha=0.9,
                s=100,
                edgecolors='black',
                linewidths=0.5
            )

            if color_values is not None:
                plt.colorbar(scatter, label='Average Ranking Score')

            plt.title('Binder Quality vs Binding Interface Quality', fontsize=16)
            plt.xlabel('Average Binder pLDDT', fontsize=14)
            plt.ylabel('Average Binder-Ub Interface ipTM', fontsize=14)
            plt.grid(alpha=0.3)

            # Add fold_id annotations with better positioning and visibility
            for i, row in valid_data.iterrows():
                plt.annotate(
                    row['fold_id'],
                    (row['avg_binder_plddt'], row['avg_binder_ub_iptm']),
                    fontsize=9,
                    xytext=(5, 5),  # Small offset from point
                    textcoords='offset points',
                    fontweight='bold',
                    bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7, ec="none")
                )

            # Save both regular and high-res versions
            plt.tight_layout()
            plot_path = os.path.join(output_dir, 'binder_quality_vs_interface.png')
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            logging.info(f"Saved plot to {plot_path}")
            plt.close()

    # 2. Binder interaction with Ub1 vs Ub2
    if all(col in avg_df.columns for col in ['avg_binder_ub1_iptm', 'avg_binder_ub2_iptm']):
        # Filter for valid data
        valid_data = avg_df.dropna(subset=['avg_binder_ub1_iptm', 'avg_binder_ub2_iptm'])

        logging.info(f"Plotting Binder Interaction with Ub1 vs Ub2: {len(valid_data)}/{len(avg_df)} data points")

        if valid_data.shape[0] >= 2:
            plt.figure(figsize=(10, 8))

            # Determine color values
            color_values = valid_data['avg_iptm'] if 'avg_iptm' in valid_data.columns else None

            logging.info("Creating scatter plot for Ub1 vs Ub2 interaction")
            scatter = plt.scatter(
                valid_data['avg_binder_ub1_iptm'],
                valid_data['avg_binder_ub2_iptm'],
                c=color_values,
                cmap='viridis',
                alpha=0.8,
                s=80,
                edgecolors='black',
                linewidths=0.5
            )

            if color_values is not None:
                plt.colorbar(scatter, label='Average Overall ipTM')

            plt.title('Average Binder Interaction with Ub1 vs Ub2')
            plt.xlabel('Avg Binder-Ub1 ipTM')
            plt.ylabel('Avg Binder-Ub2 ipTM')
            plt.grid(alpha=0.3)

            # Add fold_id annotations with better visibility
            for i, row in valid_data.iterrows():
                plt.annotate(
                    row['fold_id'],
                    (row['avg_binder_ub1_iptm'], row['avg_binder_ub2_iptm']),
                    fontsize=9,
                    xytext=(5, 5),
                    textcoords='offset points',
                    bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7)
                )

            plt.tight_layout()
            plot_path = os.path.join(output_dir, 'binder_ub1_vs_ub2_interaction.png')
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            logging.info(f"Saved plot to {plot_path}")
            plt.close()

    logging.info("Visualization generation complete")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Preview of averaged results:


Unnamed: 0,fold_id,model_type,num_models,model_nums,avg_iptm,avg_ptm,avg_ranking_score,avg_has_clash,avg_fraction_disordered,avg_binder_ptm,...,avg_binder_avg_plddt,avg_ub1_avg_plddt,avg_ub2_avg_plddt,avg_binder_ub1_pae_min,avg_binder_ub2_pae_min,avg_ub1_ub2_pae_min,best_model_num,best_model_iptm,best_model_ranking_score,best_model_binder_ub_iptm
0,1010,24,5,1234,0.854,0.89,0.862,False,False,0.878,...,92.741425,90.393442,90.287202,1.24,1.47,1.674,0,0.86,0.87,0.845
1,110287,mpnn3,5,1234,0.704,0.774,0.716,False,False,0.758,...,78.077954,89.142291,87.848592,2.31,2.838,2.332,0,0.71,0.72,0.665
2,150447,mpnn3,5,1234,0.788,0.844,0.8,False,False,0.856,...,86.565249,90.071347,87.909204,1.712,2.592,2.014,0,0.8,0.81,0.765
3,164418,mpnn2,5,1234,0.296,0.582,0.354,False,False,0.854,...,82.539048,78.21539,79.281446,8.942,7.788,14.482,0,0.34,0.39,0.32
4,164418,mpnn3,5,1234,0.856,0.888,0.862,False,False,0.884,...,91.233013,89.212333,89.441769,1.266,1.594,1.752,0,0.86,0.87,0.85



Detailed log saved to: /content/drive/MyDrive/PDB-files/AF3-recalc_folds_2025_03_08_19_33/AF3_ANALYSIS_RESULTS/af3_analysis_log_20250308_214622.txt
