<a href="https://colab.research.google.com/github/eoinleen/Protein-design-random/blob/main/Copy_of_Total_analysis_RFdiff_v3_ind-err_578.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
RFdiffusion Structure Analysis and Sequence Extraction Tool
========================================================
Created: January 31, 2025
Authors: Original Analysis - Dr. Eoin Leen, University of Leeds
         Visualization & Integration - Claude AI & Dr. Eoin Leen
Version: 2.0

Purpose:
--------
Combined pipeline for:
1. Structural analysis of PDB files
2. AF2 score visualization
3. Sequence extraction and formatting
4. Generation of publication-ready visualizations

Input Required:
-------------
1. Directory containing PDB files
2. af2_scores.csv file in same directory containing:
   - design: Design number
   - n: Sequence number
   - seq: Sequences in format "sequence1/sequence2"
   - i_pae: iPAE scores
   - Other AF2 metrics

Output Generated:
---------------
1. PowerPoint presentation with:
   - Slide 1: Structure-function correlation plots
   - Slide 2: iPAE score visualization
   - Slide 3: Top 10 sequences by iPAE score
   - Slide 4: Detailed interface analysis for structures with i_PAE < 7.5
2. Combined FASTA file with all sequences
3. CSV file with combined structural analysis

Analysis Parameters:
------------------
1. Hydrogen Bonds:
   - Distance cutoff: O-N distance < 3.5 Å
   - Calculated between backbone atoms only
   - Only inter-chain H-bonds counted

2. Salt Bridges:
   - Distance cutoff: < 4.0 Å between any atoms of residue pairs
   - Residue pairs considered:
     * Acidic: ASP, GLU
     * Basic: LYS, ARG, HIS
   - Only inter-chain salt bridges counted

3. Hydrophobic Contacts:
   - Distance cutoff: < 5.0 Å between any atoms of residue pairs
   - Hydrophobic residues considered:
     * ALA, VAL, LEU, ILE, MET, PHE, TRP, PRO
   - Only inter-chain contacts counted

4. Buried Surface Area:
   - Calculated using FreeSASA with its default parameters
   - Uses FreeSASA's default atomic radii (ProtOr classifier)
   - Process:
     * First calculates SASA for entire complex
     * Then calculates SASA for each chain individually
     * BSA = |Sum of individual chain SASAs - Complex SASA| / 2
   - Units: Å²
   - Inter-chain burial only (interface area)
   - Probe radius: 1.4 Å (water molecule)
   - Resolution: FreeSASA default (Lee & Richards algorithm, 20 slices/atom)

5. Interface Analysis (for structures with i_PAE < 7.5):
   - Core Region: Residues with >90% SASA burial upon complex formation
   - Rim Region: Residues with 10-90% SASA burial
   - Residue Classification:
     * Hydrophobic: ALA, VAL, LEU, ILE, MET, PHE, TRP, PRO
     * Polar: SER, THR, ASN, GLN, TYR, CYS
     * Charged: ASP, GLU, LYS, ARG, HIS
   - SASA calculated using FreeSASA with default parameters
     * Probe radius: 1.4 Å
     * Per-residue SASA summed from atomic areas

6. Clash Score:
   - Calculated as clashes per 1000 atoms
   - Clash defined as: non-bonded atoms closer than sum of vdW radii - 0.4 Å
   - Only inter-chain clashes considered
   - Hydrogen atoms excluded
   - Van der Waals radii used:
     * C: 1.7 Å
     * N: 1.55 Å
     * O: 1.52 Å
     * S: 1.8 Å
     * P: 1.8 Å
     * Halogens: F: 1.47 Å, Cl: 1.75 Å, Br: 1.85 Å, I: 1.98 Å
"""

# Install required packages
!pip install -q biopython pandas freesasa numpy matplotlib seaborn python-pptx plotly kaleido

# Import required libraries
import os
import sys
import time
import pandas as pd
import matplotlib.pyplot as plt
from google.colab import files, drive
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from Bio import PDB
from Bio.PDB.PDBIO import PDBIO
from Bio.PDB.Polypeptide import is_aa
from Bio.PDB.Structure import Structure
import freesasa
import numpy as np
import seaborn as sns
from pptx import Presentation
from pptx.util import Inches, Cm, Pt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Exception type used by the PDB validation / loading helpers
class StructureValidationError(Exception):
    """Signals that a file or parsed structure failed a PDB sanity check."""
# ===============================
# Structure Analysis Functions
# ===============================

def validate_pdb_file(file_path: str) -> bool:
    """
    Check that a path exists and that its first line looks like a PDB record.

    Args:
        file_path: Path of the candidate PDB file.

    Returns:
        True when the first line contains 'HEADER', 'ATOM' or 'MODEL'.

    Raises:
        FileNotFoundError: The path does not exist.
        StructureValidationError: The first line carries none of the markers,
            or the file cannot be decoded as text.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"PDB file not found: {file_path}")

    accepted_markers = ('HEADER', 'ATOM', 'MODEL')
    try:
        with open(file_path, 'r') as handle:
            opening_line = handle.readline()
    except UnicodeDecodeError:
        raise StructureValidationError(f"Not a valid text file: {file_path}")

    # Substring match (not startswith), same as the markers were always tested
    for marker in accepted_markers:
        if marker in opening_line:
            return True
    raise StructureValidationError(f"Invalid PDB: {file_path}")

def safe_structure_load(parser: PDB.PDBParser, file_path: str) -> Optional[Structure]:
    """
    Parse a PDB file, returning None instead of raising on any failure.

    The file is first checked with validate_pdb_file(); the parsed structure
    must then contain at least one model, and that first model at least one
    chain. Every failure is reported on stdout.

    Args:
        parser: Parser used to read the file.
        file_path: Path of the PDB file.

    Returns:
        The parsed structure, or None when validation or parsing failed.
    """
    try:
        validate_pdb_file(file_path)
        loaded = parser.get_structure('protein', file_path)

        models = list(loaded.get_models())
        if len(models) == 0:
            raise StructureValidationError("No models")

        first_model_chains = list(models[0].get_chains())
        if len(first_model_chains) == 0:
            raise StructureValidationError("No chains")
    except Exception as e:
        print(f"Error loading {file_path}: {str(e)}")
        return None
    return loaded

def calculate_buried_surface_area(pdb_file: str) -> Tuple[Optional[float], Optional[Dict[str, float]]]:
    """
    Calculate the surface area buried between the chains of a complex.

    The SASA of the whole file and of every chain written out on its own are
    computed with FreeSASA (default parameters); the buried area is
    |sum(chain SASA) - complex SASA| / 2.

    Args:
        pdb_file: Path of the PDB file holding the complex.

    Returns:
        (buried_surface_area, chain_areas) where chain_areas maps chain ID to
        the SASA of that chain in isolation, or (None, None) when the file
        cannot be loaded, has fewer than two chains, or the calculation fails
        (the reason is printed).
    """
    import tempfile  # local import so the block does not depend on file-level imports

    parser = PDB.PDBParser(QUIET=True)
    structure = safe_structure_load(parser, pdb_file)
    if not structure:
        return None, None
    try:
        # Only the first model is used so a multi-model file is not counted
        # several times. NOTE(review): FreeSASA is expected to read only the
        # first MODEL of pdb_file by default as well - confirm.
        chains = list(next(structure.get_models()).get_chains())
        if len(chains) < 2:
            print(f"Warning: {pdb_file} has fewer than 2 chains")
            return None, None

        combined_structure = freesasa.Structure(pdb_file)
        total_area = freesasa.calc(combined_structure).totalArea()

        chain_areas: Dict[str, float] = {}
        io = PDBIO()
        # A private temporary directory avoids name clashes in the working
        # directory and is removed even when a step below raises.
        with tempfile.TemporaryDirectory(prefix="bsa_") as tmp_dir:
            for index, chain in enumerate(chains):
                single_chain = PDB.Structure.Structure('temp')
                single_model = PDB.Model.Model(0)
                single_chain.add(single_model)
                single_model.add(chain.copy())

                # Index-based name: chain IDs may be blank or not filename-safe
                chain_path = os.path.join(tmp_dir, f"chain_{index}.pdb")
                io.set_structure(single_chain)
                io.save(chain_path)

                chain_structure = freesasa.Structure(chain_path)
                chain_areas[chain.id] = freesasa.calc(chain_structure).totalArea()

        total_individual_area = sum(chain_areas.values())
        buried_surface_area = abs(total_individual_area - total_area) / 2
        return buried_surface_area, chain_areas

    except Exception as e:
        print(f"Error calculating BSA for {pdb_file}: {str(e)}")
        return None, None

def calculate_hydrogen_bonds(structure: Structure, cutoff: float = 3.5) -> int:
    """
    Count backbone hydrogen bonds between different chains.

    A bond is counted for every inter-chain residue pair whose backbone O and
    backbone N are closer than ``cutoff``. Both orientations are tested
    (O on the first chain with N on the second, and N on the first chain with
    O on the second), so the count does not depend on chain ordering.

    Args:
        structure: Parsed structure to analyse.
        cutoff: Maximum O-N distance in Å for a bond.

    Returns:
        Number of inter-chain backbone O-N pairs within the cutoff, or 0 if an
        error occurs (the error is printed).
    """
    try:
        chains = list(structure.get_chains())
        # Filter to amino-acid residues once per chain instead of once per pair
        residues_by_chain = [
            [res for res in chain.get_residues() if is_aa(res)]
            for chain in chains
        ]

        h_bonds = 0
        for chain1, residues1 in zip(chains, residues_by_chain):
            for chain2, residues2 in zip(chains, residues_by_chain):
                # Visit each unordered chain pair once
                if chain1.id >= chain2.id:
                    continue
                for res1 in residues1:
                    for res2 in residues2:
                        if 'O' in res1 and 'N' in res2 and res1['O'] - res2['N'] < cutoff:
                            h_bonds += 1
                        # Reverse orientation: donor N on chain1, acceptor O on chain2
                        if 'N' in res1 and 'O' in res2 and res1['N'] - res2['O'] < cutoff:
                            h_bonds += 1
        return h_bonds
    except Exception as e:
        print(f"Error calculating H-bonds: {str(e)}")
        return 0
def calculate_hydrophobic_contacts(structure: Structure) -> int:
    """
    Count inter-chain contacts between hydrophobic residues.

    Both partners must be one of ALA, VAL, LEU, ILE, MET, PHE, TRP, PRO, and
    a pair counts once when its closest pair of atoms is under 5.0 Å apart.
    Any exception is reported on stdout and the function returns 0.
    """
    try:
        apolar_names = {'ALA', 'VAL', 'LEU', 'ILE', 'MET', 'PHE', 'TRP', 'PRO'}

        def is_apolar(residue):
            return is_aa(residue) and residue.get_resname() in apolar_names

        contact_total = 0
        for first_chain in structure.get_chains():
            for second_chain in structure.get_chains():
                # Each unordered chain pair is visited once
                if first_chain.id >= second_chain.id:
                    continue
                for res_a in filter(is_apolar, first_chain.get_residues()):
                    for res_b in filter(is_apolar, second_chain.get_residues()):
                        # Seeded with infinity so the fold matches a running minimum
                        separations = [float('inf')] + [
                            atom_a - atom_b
                            for atom_a in res_a.get_atoms()
                            for atom_b in res_b.get_atoms()
                        ]
                        if min(separations) < 5.0:
                            contact_total += 1
        return contact_total
    except Exception as e:
        print(f"Error calculating hydrophobic contacts: {str(e)}")
        return 0

def calculate_salt_bridges(structure: Structure) -> int:
    """
    Count inter-chain salt bridges.

    A residue pair counts once when one partner is acidic (ASP/GLU), the other
    basic (LYS/ARG/HIS), and their closest pair of atoms is under 4.0 Å apart.
    Any exception is reported on stdout and the function returns 0.
    """
    try:
        # -1 marks acidic, +1 basic; a product of -1 means opposite charges
        charge_sign = {'ASP': -1, 'GLU': -1, 'LYS': 1, 'ARG': 1, 'HIS': 1}

        def closest_approach(res_a, res_b):
            nearest = float('inf')
            for atom_a in res_a.get_atoms():
                for atom_b in res_b.get_atoms():
                    nearest = min(nearest, atom_a - atom_b)
            return nearest

        bridge_total = 0
        for first_chain in structure.get_chains():
            for second_chain in structure.get_chains():
                if first_chain.id >= second_chain.id:
                    continue
                for res_a in first_chain.get_residues():
                    if not is_aa(res_a):
                        continue
                    sign_a = charge_sign.get(res_a.get_resname(), 0)
                    for res_b in second_chain.get_residues():
                        if not is_aa(res_b):
                            continue
                        sign_b = charge_sign.get(res_b.get_resname(), 0)
                        if sign_a * sign_b == -1 and closest_approach(res_a, res_b) < 4.0:
                            bridge_total += 1
        return bridge_total
    except Exception as e:
        print(f"Error calculating salt bridges: {str(e)}")
        return 0

def save_results_as_df(results: List[Dict[str, Any]], output_file: str) -> pd.DataFrame:
    """
    Convert per-file analysis results to a DataFrame and save it as CSV.

    The design and variant numbers are parsed from file names of the form
    ``design<D>_n<N>.pdb``; entries whose name cannot be parsed are reported
    and skipped. A missing (None/0) buried surface area is stored as 0.

    Args:
        results: One dict per PDB file with the keys 'file_name',
            'buried_surface_area', 'hydrogen_bonds', 'hydrophobic_contacts'
            and 'salt_bridges'.
        output_file: Destination CSV path.

    Returns:
        The DataFrame sorted by design and variant number. With no usable
        entries it is empty but still has all six columns, so the CSV gets a
        header and later merges on ['design', 'n'] keep working.
    """
    columns = ['design', 'n', 'buried_surface_area',
               'hydrogen_bonds', 'hydrophobic_contacts', 'salt_bridges']

    analysis_data = []
    for result in results:
        filename = result['file_name'].replace('.pdb', '')
        try:
            design_num = int(filename.split('design')[1].split('_')[0])
            variant_num = int(filename.split('_n')[1])
            row = {
                'design': design_num,
                'n': variant_num,
                'buried_surface_area': result['buried_surface_area'] or 0,
                'hydrogen_bonds': result['hydrogen_bonds'],
                'hydrophobic_contacts': result['hydrophobic_contacts'],
                'salt_bridges': result['salt_bridges']
            }
        except (IndexError, ValueError, KeyError) as e:
            print(f"Error parsing filename {filename}: {str(e)}")
            continue
        analysis_data.append(row)

    # Explicit columns: an empty list would otherwise give a column-less frame
    # and sort_values(['design', 'n']) would raise KeyError.
    df = pd.DataFrame(analysis_data, columns=columns)
    df = df.sort_values(['design', 'n']).reset_index(drop=True)
    df.to_csv(output_file, index=False)
    print(f"Saved structure analysis to {output_file}")
    return df

def merge_with_af2_scores(structure_df: pd.DataFrame, af2_scores_file: str) -> pd.DataFrame:
    """
    Attach structural metrics to the AF2 score table.

    The AF2 CSV is the left side of a left join on ('design', 'n'), so every
    AF2 row is kept and rows without a structural match get NaN metrics.
    The result is sorted by design then n with a fresh 0-based index.
    """
    join_keys = ['design', 'n']
    scores = pd.read_csv(af2_scores_file)
    combined = scores.merge(structure_df, on=join_keys, how='left')
    return combined.sort_values(join_keys).reset_index(drop=True)
# ===============================
# Interface Analysis Functions
# ===============================

def calculate_residue_sasa(structure: Structure, chain_id: str, complex: bool = True) -> Dict[str, float]:
    """
    Calculate the total SASA of each residue in one chain with FreeSASA.

    Args:
        structure: Parsed structure; its first model (index 0) is used.
        chain_id: Identifier of the chain whose residues are reported.
        complex: If True the SASA is computed with all chains present; if
            False the chain is copied into a structure of its own first, so
            the values describe the isolated chain.

    Returns:
        Dict mapping "<RESNAME>_<residue number>" to the residue's total SASA
        (Å²). Residues for which FreeSASA reports no area are left out.

    Raises:
        KeyError: If model 0 or ``chain_id`` is not present in ``structure``.
    """
    import tempfile  # local import so the block does not depend on file-level imports

    if complex:
        sasa_input = structure
    else:
        sasa_input = PDB.Structure.Structure('temp')
        isolated_model = PDB.Model.Model(0)
        sasa_input.add(isolated_model)
        # copy(): adding the original chain object would re-parent it and
        # silently modify the caller's structure.
        isolated_model.add(structure[0][chain_id].copy())

    io = PDBIO()
    io.set_structure(sasa_input)

    # FreeSASA reads from a file; a private temporary directory avoids name
    # clashes and is removed even if the calculation raises.
    with tempfile.TemporaryDirectory(prefix="sasa_") as tmp_dir:
        pdb_path = os.path.join(tmp_dir, "sasa_input.pdb")
        io.save(pdb_path)
        freesasa_struct = freesasa.Structure(pdb_path)
        result = freesasa.calc(freesasa_struct)
        # residueAreas() is a nested dict: chain label -> residue number (str)
        # -> ResidueArea, whose .total is the residue's SASA.
        areas_for_chain = result.residueAreas().get(chain_id, {})

    residue_sasa: Dict[str, float] = {}
    for residue in sasa_input[0][chain_id]:
        _, seq_num, insertion_code = residue.id
        # NOTE(review): FreeSASA keys are assumed to be the residue number
        # followed by the insertion code, if any - confirm for such files.
        area = areas_for_chain.get(f"{seq_num}{insertion_code.strip()}")
        if area is None:
            # e.g. hetero residues, which FreeSASA skips by default
            continue
        residue_sasa[f"{residue.get_resname()}_{seq_num}"] = area.total
    return residue_sasa

def analyze_interface_details(structure: Structure) -> Dict:
    """
    Classify interface residues as core or rim and summarise their composition.

    For every chain the per-residue SASA of the isolated chain is compared
    with its SASA inside the complex. Residues with less than 0.1 Å² of SASA
    in the isolated chain are ignored. Of the rest, those losing more than 90%
    of their SASA are core, those losing more than 10% (up to 90%) are rim.
    Only residues in the hydrophobic, polar or charged groups are counted.

    Returns:
        Dict with 'core_count', 'rim_count', 'core_rim_ratio' and the rounded
        percentage of each residue group within core and rim
        ('core_hydrophobic', 'core_polar', 'core_charged', 'rim_hydrophobic',
        'rim_polar', 'rim_charged'). Empty groups use a denominator of 1.
    """
    group_members = {
        'hydrophobic': {'ALA', 'VAL', 'LEU', 'ILE', 'MET', 'PHE', 'TRP', 'PRO'},
        'polar': {'SER', 'THR', 'ASN', 'GLN', 'TYR', 'CYS'},
        'charged': {'ASP', 'GLU', 'LYS', 'ARG', 'HIS'},
    }
    # Flat lookup: three-letter residue name -> group label
    group_of = {
        name: label
        for label, names in group_members.items()
        for name in names
    }

    core_tally = dict.fromkeys(group_members, 0)
    rim_tally = dict.fromkeys(group_members, 0)

    for chain in structure.get_chains():
        # Isolated chain first, then the chain in the context of the complex
        isolated = calculate_residue_sasa(structure, chain.id, complex=False)
        bound = calculate_residue_sasa(structure, chain.id, complex=True)

        for residue_key, free_area in isolated.items():
            if free_area < 0.1:
                # Already buried within its own chain
                continue

            label = group_of.get(residue_key.split('_')[0])
            if label is None:
                continue

            buried_pct = (free_area - bound.get(residue_key, 0)) / free_area * 100
            if buried_pct > 90:
                core_tally[label] += 1
            elif buried_pct > 10:
                rim_tally[label] += 1

    n_core = sum(core_tally.values())
    n_rim = sum(rim_tally.values())

    summary = {
        'core_count': n_core,
        'rim_count': n_rim,
        'core_rim_ratio': n_core / max(1, n_rim),
    }
    for label in ('hydrophobic', 'polar', 'charged'):
        summary[f'core_{label}'] = round(100 * core_tally[label] / max(1, n_core))
    for label in ('hydrophobic', 'polar', 'charged'):
        summary[f'rim_{label}'] = round(100 * rim_tally[label] / max(1, n_rim))
    return summary

def create_clash_score(structure: Structure) -> float:
    """
    Calculate the inter-chain clash score of a structure.

    Two atoms from different chains clash when they are closer than the sum
    of their van der Waals radii minus 0.4 Å. Only elements in the radius
    table are considered, which excludes hydrogens.

    Args:
        structure: Parsed structure to analyse.

    Returns:
        Clashes per 1000 considered atoms. Every atom is counted once in the
        denominator regardless of how many chain pairs it takes part in;
        0.0 when no atoms are considered.
    """
    # Van der Waals radii (Å). Keys are upper-case because element symbols
    # are compared after upper-casing ('CL', 'BR'), the form Biopython's
    # PDB parser stores them in.
    vdw_radii = {
        'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8,
        'P': 1.8, 'F': 1.47, 'CL': 1.75, 'BR': 1.85, 'I': 1.98
    }

    # Per chain: (atom, radius) for every atom with a known radius
    heavy_atoms_by_chain = []
    for chain in structure.get_chains():
        chain_atoms = []
        for atom in chain.get_atoms():
            radius = vdw_radii.get((atom.element or '').strip().upper())
            if radius is not None:
                chain_atoms.append((atom, radius))
        heavy_atoms_by_chain.append(chain_atoms)

    clash_count = 0
    for i, atoms1 in enumerate(heavy_atoms_by_chain):
        for atoms2 in heavy_atoms_by_chain[i + 1:]:
            for atom1, radius1 in atoms1:
                for atom2, radius2 in atoms2:
                    if atom1 - atom2 < radius1 + radius2 - 0.4:
                        clash_count += 1

    # Count each atom once; summing per chain pair would inflate the
    # denominator for structures with three or more chains.
    total_atoms = sum(len(chain_atoms) for chain_atoms in heavy_atoms_by_chain)

    return (1000 * clash_count) / max(1, total_atoms)
# ===============================
# Visualization Functions
# ===============================

def _binder_sequence(seq_field: Any) -> Optional[str]:
    """Return the second chain of a "chain1/chain2" seq entry, or None if it is absent."""
    if not isinstance(seq_field, str):
        return None
    chains = seq_field.split('/')
    if len(chains) < 2:
        return None
    return chains[1].strip()


def _style_paragraph(paragraph: Any) -> None:
    """Apply the monospaced 8 pt formatting used on the text slides."""
    paragraph.font.name = 'Courier New'
    paragraph.font.size = Pt(8)
    paragraph.line_spacing = 1.0


def _add_text_slide(prs: Any, title_text: str) -> Any:
    """Add a titled slide with one large empty text box and return its text frame."""
    slide = prs.slides.add_slide(prs.slide_layouts[5])
    slide.shapes.title.text = title_text
    textbox = slide.shapes.add_textbox(Cm(2), Cm(4), Cm(17), Cm(20))
    text_frame = textbox.text_frame
    text_frame.clear()
    return text_frame


def _add_correlation_slide(prs: Any, df: pd.DataFrame, image_path: str) -> None:
    """Slide 1: scatter plots of six metrics against i_PAE, rendered to image_path."""
    slide = prs.slides.add_slide(prs.slide_layouts[5])

    panels = [
        ('i_ptm', 'iPTM'),
        ('rmsd', 'RMSD (Å)'),
        ('buried_surface_area', 'Buried Surface Area (Å²)'),
        ('hydrogen_bonds', '# of Hydrogen Bonds'),
        ('hydrophobic_contacts', '# of Hydrophobic Contacts'),
        ('salt_bridges', '# of Salt Bridges'),
    ]

    fig, axes = plt.subplots(3, 2, figsize=(8.27, 11.69))
    try:
        for ax, (y_var, title) in zip(axes.flatten(), panels):
            ax.set_title(title)
            ax.set_facecolor('white')
            if y_var not in df.columns:
                # Keep the panel but say why it is empty instead of failing
                ax.text(0.5, 0.5, f"'{y_var}' not available",
                        ha='center', va='center', transform=ax.transAxes)
                continue
            sns.scatterplot(data=df, x='i_pae', y=y_var, ax=ax,
                            color='black', marker='x', s=16)
            ax.set_xlabel('i_PAE')
            ax.set_ylabel(title)

        fig.patch.set_facecolor('white')
        fig.tight_layout()
        fig.savefig(image_path, bbox_inches='tight', dpi=300, facecolor='white')
    finally:
        plt.close(fig)

    # An explicit width keeps the picture inside the slide; the height is
    # scaled proportionally.
    left = Cm(2)
    slide.shapes.add_picture(image_path, left, Cm(2),
                             width=prs.slide_width - 2 * left)


def _add_ipae_slide(prs: Any, df: pd.DataFrame, image_path: str,
                    n_panels: int = 4, designs_per_panel: int = 8,
                    seqs_per_design: int = 64) -> None:
    """
    Slide 2: bar charts of i_PAE per row, split over stacked panels.

    The layout assumes df is sorted by design and holds seqs_per_design rows
    per design; each panel shows designs_per_panel consecutive designs.
    """
    slide = prs.slides.add_slide(prs.slide_layouts[5])

    fig = make_subplots(
        rows=n_panels,
        cols=1,
        vertical_spacing=0.08,
        subplot_titles=[
            f"Designs {i * designs_per_panel}-{(i + 1) * designs_per_panel - 1}"
            for i in range(n_panels)
        ]
    )

    rows_per_subplot = designs_per_panel * seqs_per_design
    colors = ['black', 'red']

    for i in range(n_panels):
        start_idx = i * rows_per_subplot
        chunk = df.iloc[start_idx:start_idx + rows_per_subplot]

        # One trace per design so neighbouring designs alternate in colour
        for design_num in chunk['design'].unique():
            design_rows = chunk[chunk['design'] == design_num]
            fig.add_trace(
                go.Bar(
                    x=design_rows.index,
                    y=design_rows['i_pae'],
                    showlegend=False,
                    marker_color=colors[int(design_num) % 2],
                    width=1,
                ),
                row=i + 1,
                col=1
            )

        fig.update_yaxes(
            range=[0, 30],
            title_text='iPAE' if i == 1 else None,
            row=i + 1,
            col=1
        )

        design_numbers = sorted(chunk['design'].unique())
        fig.update_xaxes(
            tickmode='array',
            ticktext=design_numbers,
            # One tick in the middle of each design's block of rows
            tickvals=[start_idx + (j * seqs_per_design) + seqs_per_design // 2
                      for j in range(len(design_numbers))],
            row=i + 1,
            col=1,
            title_text='Design Number' if i == n_panels - 1 else None
        )

    fig.update_layout(
        title='iPAE Scores by Design Number and Sequence (Scale: 0-30)',
        height=1000,
        width=1200,
        showlegend=False,
        margin=dict(t=50, b=50, r=150, l=50),
        paper_bgcolor='white',
        plot_bgcolor='white'
    )
    fig.write_image(image_path)

    left = Cm(1)
    slide.shapes.add_picture(image_path, left, Cm(1),
                             width=prs.slide_width - 2 * left)


def _add_top_sequences_slide(prs: Any, df: pd.DataFrame, top_n: int = 10) -> None:
    """Slide 3: FASTA-style listing of the top_n rows with the lowest i_PAE."""
    text_frame = _add_text_slide(prs, "Top 10 Sequences (Lowest i_PAE Scores)")
    best_rows = df.nsmallest(top_n, 'i_pae')[['design', 'n', 'i_pae', 'seq']]

    for _, row in best_rows.iterrows():
        sequence = _binder_sequence(row['seq'])
        if sequence is None:
            sequence = "(sequence unavailable)"
        p = text_frame.add_paragraph()
        p.text = f">d{row['design']}n{row['n']} (i_PAE: {row['i_pae']:.4f})\n{sequence}"
        _style_paragraph(p)


def _find_pdb_file(output_dir: str, design: Any, n: Any) -> str:
    """Return the path of design<design>_n<n>.pdb, looking in output_dir and then its parent."""
    file_name = f"design{design}_n{n}.pdb"
    parent_candidate = os.path.join(os.path.dirname(output_dir), file_name)
    for candidate in (os.path.join(output_dir, file_name), parent_candidate):
        if os.path.exists(candidate):
            return candidate
    # Nothing found: hand back a path so the loader can report the failure
    return parent_candidate


def _add_interface_slide(prs: Any, df: pd.DataFrame, output_dir: str,
                         ipae_cutoff: float = 7.5) -> None:
    """Slide 4: interface composition and clash score for rows with i_PAE below the cutoff."""
    text_frame = _add_text_slide(
        prs, f"Detailed Interface Analysis (Structures with i_PAE < {ipae_cutoff})")

    low_ipae_structures = df[df['i_pae'] < ipae_cutoff].sort_values('i_pae')
    if len(low_ipae_structures) == 0:
        p = text_frame.add_paragraph()
        p.text = f"No structures found with i_PAE < {ipae_cutoff}"
        _style_paragraph(p)
        return

    parser = PDB.PDBParser(QUIET=True)
    for _, row in low_ipae_structures.iterrows():
        pdb_file = _find_pdb_file(output_dir, row['design'], row['n'])
        structure = safe_structure_load(parser, pdb_file)
        if not structure:
            continue

        try:
            interface_analysis = analyze_interface_details(structure)
            clash_score = create_clash_score(structure)
        except Exception as e:
            # One failing structure should not cost the whole presentation
            print(f"Interface analysis failed for {pdb_file}: {str(e)}")
            continue

        p = text_frame.add_paragraph()
        p.text = (
            f"Structure d{row['design']}n{row['n']} (i_PAE: {row['i_pae']:.2f})\n"
            f"Buried Surface Area: {row['buried_surface_area']:.1f} Å²\n"
            f"Clash Score: {clash_score:.2f}\n"
            f"Interface Analysis:\n"
            f"  Core Residues: {interface_analysis['core_count']}\n"
            f"  Rim Residues: {interface_analysis['rim_count']}\n"
            f"  Core/Rim ratio: {interface_analysis['core_rim_ratio']:.2f}\n"
            f"  Core Composition:\n"
            f"    Hydrophobic: {interface_analysis['core_hydrophobic']}%\n"
            f"    Polar: {interface_analysis['core_polar']}%\n"
            f"    Charged: {interface_analysis['core_charged']}%\n"
            f"  Rim Composition:\n"
            f"    Hydrophobic: {interface_analysis['rim_hydrophobic']}%\n"
            f"    Polar: {interface_analysis['rim_polar']}%\n"
            f"    Charged: {interface_analysis['rim_charged']}%\n"
            f"----------------------------------------\n"
        )
        _style_paragraph(p)


def create_pptx_plots(df: pd.DataFrame, output_dir: str, timestamp: str):
    """
    Create an A4-portrait PowerPoint presentation with four slides:
    1. Structure-function correlation plots
    2. iPAE visualization
    3. Top 10 lowest iPAE sequences
    4. Detailed interface analysis for low iPAE structures

    Args:
        df: Merged AF2/structure table. Reads the columns 'design', 'n',
            'i_pae', 'seq' and 'buried_surface_area', plus the plotted metric
            columns where present.
        output_dir: Directory that receives the .pptx file and the temporary
            images; PDB files for slide 4 are looked up here first, then in
            its parent directory.
        timestamp: Text inserted into the output file name.

    The presentation is saved as
    "<output_dir name>_<timestamp>_analysis.pptx" inside output_dir. The
    temporary PNG files are removed even when a slide fails to build.
    """
    prs = Presentation()
    prs.slide_width = Cm(21)
    prs.slide_height = Cm(29.7)

    temp_img1 = os.path.join(output_dir, 'temp_plots1.png')
    temp_img2 = os.path.join(output_dir, 'temp_plots2.png')

    # normpath so a trailing slash does not produce an empty base name
    output_basename = os.path.basename(os.path.normpath(output_dir))
    pptx_path = os.path.join(output_dir, f"{output_basename}_{timestamp}_analysis.pptx")

    try:
        print("Creating correlation plots...")
        _add_correlation_slide(prs, df, temp_img1)

        print("Creating iPAE visualization...")
        _add_ipae_slide(prs, df, temp_img2)

        print("Creating top 10 sequences slide...")
        _add_top_sequences_slide(prs, df)

        print("Creating interface analysis for low i_PAE structures...")
        _add_interface_slide(prs, df, output_dir)

        prs.save(pptx_path)
    finally:
        # Clean up temporary files whether or not the slides were built
        for temp_img in (temp_img1, temp_img2):
            if os.path.exists(temp_img):
                os.remove(temp_img)

    print(f"Saved PowerPoint to {pptx_path}")
# ===============================
# Main Processing Functions
# ===============================

def process_multiple_pdb_files(pdb_directory: str, af2_scores_file: Optional[str] = None) -> pd.DataFrame:
    """
    Main processing function that:
    1. Analyzes all PDB files in directory
    2. Merges with AF2 scores
    3. Generates visualizations and outputs
    4. Saves sequences to FASTA

    Args:
        pdb_directory: Directory holding the ``*.pdb`` files; all output files
            are written here, prefixed with the directory name and the date.
        af2_scores_file: Optional path of the AF2 score CSV. When it is None
            or does not exist only the structural analysis is produced.

    Returns:
        The merged AF2/structure table when AF2 scores were used, otherwise
        the structural table. An empty DataFrame when the directory holds no
        PDB files or none of them could be loaded.

    Raises:
        FileNotFoundError: If ``pdb_directory`` does not exist.
    """
    if not os.path.exists(pdb_directory):
        raise FileNotFoundError(f"Directory not found: {pdb_directory}")

    # Get timestamp for file naming
    timestamp = time.strftime("%y%m%d")

    parser = PDB.PDBParser(QUIET=True)
    # Sorted so the processing order (and the progress log) is reproducible
    pdb_files = sorted(f for f in os.listdir(pdb_directory) if f.endswith('.pdb'))

    if not pdb_files:
        print(f"No PDB files found in {pdb_directory}")
        return pd.DataFrame()

    total_files = len(pdb_files)
    print(f"Processing {total_files} PDB files...")

    # Process each PDB file
    results = []
    for idx, file_name in enumerate(pdb_files, 1):
        pdb_file = os.path.join(pdb_directory, file_name)
        print(f"Processing file {idx}/{total_files}: {file_name}")

        structure = safe_structure_load(parser, pdb_file)
        if not structure:
            continue

        # Calculate structural parameters
        buried_surface_area, chain_areas = calculate_buried_surface_area(pdb_file)
        results.append({
            'file_name': file_name,
            'buried_surface_area': buried_surface_area,
            'hydrogen_bonds': calculate_hydrogen_bonds(structure),
            'hydrophobic_contacts': calculate_hydrophobic_contacts(structure),
            'salt_bridges': calculate_salt_bridges(structure),
            'chain_areas': chain_areas
        })

    if not results:
        # Nothing loadable: stop here rather than build outputs from no data
        print(f"None of the PDB files in {pdb_directory} could be loaded")
        return pd.DataFrame()

    # Save structural analysis
    # normpath so a trailing slash does not produce an empty base name
    output_basename = os.path.basename(os.path.normpath(pdb_directory))
    structure_csv = os.path.join(pdb_directory, f"{output_basename}_{timestamp}_structure.csv")
    structure_df = save_results_as_df(results, structure_csv)

    # Without AF2 scores the structural table is the final result
    if not (af2_scores_file and os.path.exists(af2_scores_file)):
        return structure_df

    print(f"Merging with AF2 scores from {af2_scores_file}")
    final_df = merge_with_af2_scores(structure_df, af2_scores_file)

    # Save combined analysis
    combined_csv = os.path.join(pdb_directory, f"{output_basename}_{timestamp}_combined.csv")
    final_df.to_csv(combined_csv, index=False)
    print(f"Saved combined results to {combined_csv}")

    # Create PowerPoint plots
    create_pptx_plots(final_df, pdb_directory, timestamp)

    # Save sequences to FASTA
    fasta_path = os.path.join(pdb_directory, f"{output_basename}_{timestamp}_sequences.fasta")
    skipped = 0
    with open(fasta_path, 'w') as f:
        for _, row in final_df.iterrows():
            seq_field = row['seq'] if 'seq' in row.index else None
            # Entries must look like "chain1/chain2"; the second chain is
            # written. Missing or malformed entries are skipped.
            if not isinstance(seq_field, str) or '/' not in seq_field:
                skipped += 1
                continue
            sequence = seq_field.split('/')[1].strip()
            header = f">d{row['design']}n{row['n']}"
            f.write(f"{header}\n{sequence}\n")
    if skipped:
        print(f"Skipped {skipped} rows without a 'sequence1/sequence2' entry")
    print(f"Saved sequences to {fasta_path}")

    return final_df

# ===============================
# Main Execution
# ===============================

if __name__ == "__main__":
    # Google Drive has to be mounted before any /content/drive path resolves
    drive.mount('/content/drive')

    # Folder with the PDB files; af2_scores.csv is looked up in the same folder
    pdb_directory = '/content/drive/MyDrive/PDB-files/202501xx/3NOB-70-110-all_pdb'  # Update this path
    af2_scores_path = os.path.join(pdb_directory, 'af2_scores.csv')

    scores_available = os.path.exists(af2_scores_path)
    if not scores_available:
        # Passing None makes the pipeline skip the merge, plots and FASTA export
        af2_scores_path = None
        print("No AF2 scores file found - will generate structure analysis only")

    print("\nStarting analysis...")
    print(f"Processing PDB files from: {pdb_directory}")

    try:
        results_df = process_multiple_pdb_files(pdb_directory, af2_scores_path)
    except Exception as e:
        # Report, then re-raise so the full traceback is still shown
        print(f"\nError during analysis: {str(e)}")
        raise
    else:
        print("\nAnalysis completed successfully!")



IndentationError: unindent does not match any outer indentation level (<tokenize>, line 578)