# Notebook for filtering minibinder designs by binding-site location

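Workbook that parses mmCIF design files and filters minibinder designs by the distance between the binder and a chosen receptor binding site.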

In [None]:
import numpy as np
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio.PDB import *
import pandas as pd
import os
import shutil
import matplotlib.pyplot as plt

# Functions
def parse_mmcif_coordinates(mmcif_file_path, model_index=None):
    """
    Parse an mmCIF file and extract per-atom data into parallel numpy arrays.

    Each atom visited while walking the model -> chain -> residue -> atom
    hierarchy adds one entry to every array, so position ``k`` in each array
    describes the same atom. No residues are filtered out (hetero residues,
    if present, are kept).

    Args:
        mmcif_file_path: Path to the .cif file
        model_index: Optional Biopython model id to read (0 is the first
            model). The default ``None`` reads atoms from every model, as
            before; atoms of the same chain in different models then share
            one chain ID in the output.

    Returns:
        dict with keys:
            'coordinates'     -- array of shape (N, 3); shape (0, 3) when the
                                 structure contains no atoms
            'atom_names'      -- str array of shape (N,)
            'residue_names'   -- str array of shape (N,)
            'residue_numbers' -- int array of shape (N,); residue sequence
                                 number only (insertion codes are dropped)
            'chain_ids'       -- str array of shape (N,)
    """
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure('structure', mmcif_file_path)

    # Iterate either all models (original behaviour) or only the requested one.
    models = structure if model_index is None else [structure[model_index]]

    coordinates = []
    atom_names = []
    residue_names = []
    residue_numbers = []
    chain_ids = []

    for model in models:
        for chain in model:
            chain_id = chain.get_id()
            for residue in chain:
                # Per-residue values are looked up once and repeated per atom.
                residue_name = residue.get_resname()
                residue_number = residue.get_id()[1]
                for atom in residue:
                    coordinates.append(atom.get_coord())
                    atom_names.append(atom.get_name())
                    residue_names.append(residue_name)
                    residue_numbers.append(residue_number)
                    chain_ids.append(chain_id)

    coord_array = np.array(coordinates)
    if coord_array.size == 0:
        # np.array([]) has shape (0,); give callers the documented (0, 3).
        coord_array = np.empty((0, 3))

    # Explicit dtypes so empty results are still str/int arrays that compare
    # cleanly against chain IDs and residue numbers (np.array([]) is float).
    return {
        'coordinates': coord_array,
        'atom_names': np.array(atom_names, dtype=str),
        'residue_names': np.array(residue_names, dtype=str),
        'residue_numbers': np.array(residue_numbers, dtype=int),
        'chain_ids': np.array(chain_ids, dtype=str),
    }

def calculate_COM(coordinates, weights=None):
    """
    Calculate the center of a set of atomic coordinates.

    With ``weights=None`` (the default) every atom counts equally, so the
    result is the geometric centroid, not a mass-weighted center of mass.
    Pass per-atom masses as ``weights`` to obtain a true center of mass.

    Args:
        coordinates: array-like of shape (N, 3) where N is the number of atoms.
        weights: Optional array-like of shape (N,) with one weight per atom.

    Returns:
        numpy array of shape (3,) with the (weighted) mean position.
    """
    # np.average without weights is exactly np.mean, so the default result is
    # unchanged from the previous unweighted implementation.
    return np.average(np.asarray(coordinates), axis=0, weights=weights)

def get_chain_coordinates(mmcif_data, chain_id):
    """
    Select every atom that belongs to one chain.

    Args:
        mmcif_data: Dictionary from parse_mmcif_coordinates
        chain_id: Chain identifier (e.g., 'A', 'B')

    Returns:
        dict: The same five keys as the input, each restricted to the atoms
        whose chain ID equals ``chain_id`` (empty arrays if it is absent).
    """
    keep = mmcif_data['chain_ids'] == chain_id
    fields = ('coordinates', 'atom_names', 'residue_names',
              'residue_numbers', 'chain_ids')
    return {field: mmcif_data[field][keep] for field in fields}

def get_residue_coordinates(mmcif_data, chain_id, residue_number):
    """
    Select the atoms of a single residue within a single chain.

    Args:
        mmcif_data: Dictionary from parse_mmcif_coordinates
        chain_id: Chain identifier (e.g., 'A', 'B')
        residue_number: Residue number to match exactly

    Returns:
        dict: The same five keys as the input, restricted to atoms whose chain
        ID is ``chain_id`` and whose residue number is ``residue_number``.
    """
    in_chain = mmcif_data['chain_ids'] == chain_id
    is_residue = mmcif_data['residue_numbers'] == residue_number
    selection = in_chain & is_residue

    subset = {}
    for key in ('coordinates', 'atom_names', 'residue_names',
                'residue_numbers', 'chain_ids'):
        subset[key] = mmcif_data[key][selection]
    return subset

def get_residue_range_coordinates(mmcif_data, chain_id, start_res, end_res):
    """
    Select the atoms of a contiguous residue-number window within one chain.

    Args:
        mmcif_data: Dictionary from parse_mmcif_coordinates
        chain_id: Chain identifier (e.g., 'A', 'B')
        start_res: First residue number of the window (inclusive)
        end_res: Last residue number of the window (inclusive)

    Returns:
        dict: The same five keys as the input, restricted to atoms of
        ``chain_id`` with start_res <= residue number <= end_res.
    """
    numbers = mmcif_data['residue_numbers']
    in_window = (numbers >= start_res) & (numbers <= end_res)
    mask = (mmcif_data['chain_ids'] == chain_id) & in_window

    keys = ['coordinates', 'atom_names', 'residue_names',
            'residue_numbers', 'chain_ids']
    return dict(zip(keys, (mmcif_data[k][mask] for k in keys)))

def calculate_distance(coord1, coord2):
    """
    Calculate the Euclidean distance between two points.

    Args:
        coord1: array-like of shape (3,) (numpy array, list, or tuple)
        coord2: array-like of shape (3,)

    Returns:
        float: Euclidean distance
    """
    # asarray lets plain lists/tuples work too; `list - list` raises TypeError.
    # No dtype is forced, so numpy inputs keep their original precision.
    return np.linalg.norm(np.asarray(coord1) - np.asarray(coord2))

def process_all_cif_files(cif_directory, binder_chain, receptor_chain, receptor_residue_range):
    """
    Measure, for each CIF file in a directory, the distance between the
    binder-chain centroid and the receptor binding-site centroid.

    Files are processed in sorted filename order so the row order of the
    result is reproducible (``os.listdir`` returns names in arbitrary order).
    Every .cif file yields exactly one row. Files that fail to parse, or that
    have no atoms for the binder chain or for the binding-site residues, get
    ``distance_angstroms = NaN`` instead of being dropped.

    Args:
        cif_directory: Path to directory containing CIF files
        binder_chain: Chain ID of the binder (e.g., 'A')
        receptor_chain: Chain ID of the receptor (e.g., 'B')
        receptor_residue_range: (start_res, end_res) of the binding site,
            both ends inclusive

    Returns:
        pandas.DataFrame: One row per .cif file with columns file_name,
        binder_chain, receptor_chain, binding_site_residues,
        distance_angstroms, binder_atom_count and bindsite_atom_count. The
        columns are present even when the directory has no .cif files.
    """
    start_res, end_res = receptor_residue_range[0], receptor_residue_range[1]
    site_label = f"{start_res}-{end_res}"
    columns = [
        'file_name', 'binder_chain', 'receptor_chain', 'binding_site_residues',
        'distance_angstroms', 'binder_atom_count', 'bindsite_atom_count',
    ]

    def make_row(filename, distance, binder_atom_count, bindsite_atom_count):
        # Single definition of the record layout for every outcome
        # (success, missing atoms, processing error).
        return {
            'file_name': filename,
            'binder_chain': binder_chain,
            'receptor_chain': receptor_chain,
            'binding_site_residues': site_label,
            'distance_angstroms': distance,
            'binder_atom_count': binder_atom_count,
            'bindsite_atom_count': bindsite_atom_count,
        }

    cif_files = sorted(f for f in os.listdir(cif_directory) if f.endswith('.cif'))
    n_files = len(cif_files)
    print(f"Found {n_files} CIF files to process...")

    results = []
    for i, filename in enumerate(cif_files, start=1):
        try:
            mmcif_data = parse_mmcif_coordinates(os.path.join(cif_directory, filename))
            binder_coords = get_chain_coordinates(mmcif_data, binder_chain)['coordinates']
            site_coords = get_residue_range_coordinates(
                mmcif_data, receptor_chain, start_res, end_res
            )['coordinates']

            if len(binder_coords) > 0 and len(site_coords) > 0:
                distance = calculate_distance(
                    calculate_COM(binder_coords), calculate_COM(site_coords)
                )
            else:
                # Say exactly what is missing: an absent binder chain and an
                # empty binding-site window both end up here.
                missing = []
                if len(binder_coords) == 0:
                    missing.append(f"binder chain {binder_chain}")
                if len(site_coords) == 0:
                    missing.append(f"binding-site residues {site_label} of chain {receptor_chain}")
                print(f"Warning: no atoms found for {' and '.join(missing)} in {filename}")
                distance = np.nan

            row = make_row(filename, distance, len(binder_coords), len(site_coords))
        except Exception as e:
            # Deliberately broad: one malformed file must not abort the batch.
            print(f"Error processing {filename}: {e}")
            row = make_row(filename, np.nan, 0, 0)
        results.append(row)

        # Progress update
        if i % 10 == 0:
            print(f"Processed {i}/{n_files} files...")

    # Explicit columns keep the schema even when `results` is empty.
    return pd.DataFrame(results, columns=columns)

# Three-letter residue name -> one-letter amino-acid code, built once at import
# time instead of on every call. Names not listed here are emitted as 'X'.
_AA_THREE_TO_ONE = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLU': 'E', 'GLN': 'Q', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
    # Common modified amino acids
    'MSE': 'M',  # Selenomethionine
    'CSE': 'C',  # Selenocysteine
    'PYL': 'O',  # Pyrrolysine
    'SEC': 'U',  # Selenocysteine
}


def get_chain_sequence(mmcif_data, chain_id):
    """
    Extract the one-letter amino acid sequence of a specific chain.

    Residues are identified by (residue number, residue name) pairs taken
    from the per-atom arrays. Duplicates are removed in first-seen order,
    then residues are sorted by residue number (a stable sort). Residue names
    missing from the lookup table become 'X' and a warning is printed.

    Args:
        mmcif_data: Dictionary from parse_mmcif_coordinates
        chain_id: Chain identifier (e.g., 'A', 'B')

    Returns:
        dict: 'sequence' (str), 'residue_numbers' (list), 'residue_names'
        (list), 'chain_id', and 'length' (== len(sequence)). Empty sequence
        and lists if the chain is absent.
    """
    in_chain = mmcif_data['chain_ids'] == chain_id
    pairs = zip(mmcif_data['residue_numbers'][in_chain],
                mmcif_data['residue_names'][in_chain])

    # dict.fromkeys collapses one-entry-per-atom into one-entry-per-residue
    # while preserving first-seen order; sorted() is stable.
    unique_residues = sorted(dict.fromkeys(pairs), key=lambda residue: residue[0])

    residue_numbers = [number for number, _ in unique_residues]
    residue_names = [name for _, name in unique_residues]

    letters = []
    for res_num, res_name in unique_residues:
        letter = _AA_THREE_TO_ONE.get(res_name)
        if letter is None:
            # Handle non-standard residues
            print(f"Warning: Unknown residue {res_name} at position {res_num} in chain {chain_id}")
            letter = 'X'  # Unknown amino acid
        letters.append(letter)
    sequence = ''.join(letters)

    return {
        'sequence': sequence,
        'residue_numbers': residue_numbers,
        'residue_names': residue_names,
        'chain_id': chain_id,
        'length': len(sequence)
    }

def get_sequence_from_file(mmcif_file_path, chain_id):
    """
    Read a CIF file and return the amino acid sequence of one of its chains.

    Convenience wrapper: parses the file with parse_mmcif_coordinates and
    hands the result to get_chain_sequence.

    Args:
        mmcif_file_path: Path to the .cif file
        chain_id: Chain identifier (e.g., 'A', 'B')

    Returns:
        dict: The dictionary produced by get_chain_sequence (sequence,
        residue_numbers, residue_names, chain_id, length).
    """
    return get_chain_sequence(parse_mmcif_coordinates(mmcif_file_path), chain_id)

def extract_sequences_from_all_files(cif_directory, chain_id):
    """
    Extract one chain's sequence from every CIF file in a directory.

    Files are processed in sorted filename order so the row order of the
    result is reproducible (``os.listdir`` order is arbitrary). A file that
    fails to process still gets a row with an empty sequence.

    Args:
        cif_directory: Path to directory containing CIF files
        chain_id: Chain identifier (e.g., 'A', 'B')

    Returns:
        pandas.DataFrame: One row per .cif file with columns file_name,
        chain_id, sequence, sequence_length, residue_count, first_residue and
        last_residue. The columns are present even when no .cif files exist.
    """
    columns = [
        'file_name', 'chain_id', 'sequence', 'sequence_length',
        'residue_count', 'first_residue', 'last_residue',
    ]

    cif_files = sorted(name for name in os.listdir(cif_directory) if name.endswith('.cif'))
    n_files = len(cif_files)
    print(f"Extracting sequences from {n_files} CIF files for chain {chain_id}...")

    results = []
    for i, filename in enumerate(cif_files, start=1):
        try:
            file_path = os.path.join(cif_directory, filename)
            sequence_data = get_sequence_from_file(file_path, chain_id)
            numbers = sequence_data['residue_numbers']
            row = {
                'file_name': filename,
                'chain_id': chain_id,
                'sequence': sequence_data['sequence'],
                'sequence_length': sequence_data['length'],
                'residue_count': len(numbers),
                'first_residue': numbers[0] if numbers else None,
                'last_residue': numbers[-1] if numbers else None,
            }
        except Exception as e:
            # Best-effort batch: record an empty row and keep going.
            print(f"Error processing {filename}: {str(e)}")
            row = {
                'file_name': filename,
                'chain_id': chain_id,
                'sequence': '',
                'sequence_length': 0,
                'residue_count': 0,
                'first_residue': None,
                'last_residue': None,
            }
        results.append(row)

        # Progress update
        if i % 10 == 0:
            print(f"Processed {i}/{n_files} files...")

    # Explicit columns keep the schema even when `results` is empty.
    return pd.DataFrame(results, columns=columns)

def extract_rank_number(filename):
    """
    Extract the integer rank from a design filename, for use as a sort key.

    Example: 'rank0073_GLP1_ICL3_1stRun_3780.cif' -> 73

    The rank is the text between the leading 'rank' prefix and the first '_'.
    Non-string inputs (e.g. NaN from pandas), names that do not start with
    'rank', and names whose rank text is not an integer all return
    ``float('inf')`` so they sort after every ranked file.
    """
    if not isinstance(filename, str) or not filename.startswith('rank'):
        return float('inf')  # Put non-rank files at the end

    # 'rank0073_GLP1_...' -> 'rank0073' -> '0073'
    rank_text = filename.split('_', 1)[0][len('rank'):]
    try:
        return int(rank_text)
    except ValueError:
        # Narrow catch (was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit).
        return float('inf')

User-defined variables

In [None]:
# --- Chains ------------------------------------------------------------------
# Chain IDs of the designed binder (usually chain A) and of the receptor.
chain_ID_binder = 'A'
chain_ID_receptor = 'B'

# Inclusive (start, end) residue-number window on the receptor chain that
# defines the target binding site.
residue_range = (332, 346)

# --- Inputs ------------------------------------------------------------------
# Directory holding all of the final ranked design CIF files.
cif_file_path = 'final_ranked_designs/final_100_designs/'
# BoltzGen filtered-metrics CSV file.
boltzgen_csv = 'final_ranked_designs/final_designs_metrics_100.csv'

# --- Outputs and filtering ---------------------------------------------------
# Directory that receives the filtered CIF files, result CSVs and histogram.
output_path = 'bindsite_filtered_designs/'
# Designs whose binder-to-binding-site centroid distance is at or below this
# cutoff pass the filter (in Angstroms).
distance_threshold = 25.0

In [None]:
# Process all CIF files with the defined parameters
print("Processing CIF files...")
print(f"Binder chain: {chain_ID_binder}")
print(f"Receptor chain: {chain_ID_receptor}")
print(f"Binding site residue range: {residue_range}")
print(f"Distance threshold: {distance_threshold} Å")

# Calculate distances for all files (one row per CIF file)
distance_df = process_all_cif_files(
    cif_file_path,
    chain_ID_binder,
    chain_ID_receptor,
    residue_range
)

# Fail early with a clear message. An empty result would otherwise surface
# later as a KeyError on 'distance_angstroms' or a ZeroDivisionError in the
# pass-rate calculation below.
if distance_df.empty:
    raise RuntimeError(f"No .cif files found in {cif_file_path!r}; nothing to filter.")

print(f"\nCompleted processing. Results shape: {distance_df.shape}")
print("\nFirst 5 results:")
print(distance_df.head())

print(f"\nDistance statistics:")
print(distance_df['distance_angstroms'].describe())

# Filter designs based on distance threshold. Rows with a NaN distance
# (processing errors / missing atoms) compare False and are excluded.
filtered_df = distance_df[distance_df['distance_angstroms'] <= distance_threshold].copy()
n_without_distance = int(distance_df['distance_angstroms'].isna().sum())

print(f"\nFiltering results:")
print(f"Total designs processed: {len(distance_df)}")
print(f"Designs without a distance (errors / missing atoms): {n_without_distance}")
print(f"Designs within {distance_threshold} Å: {len(filtered_df)}")
print(f"Percentage passing filter: {len(filtered_df)/len(distance_df)*100:.1f}%")

# Sort by distance (closest first); a stable sort keeps ties in input order
filtered_df = filtered_df.sort_values('distance_angstroms', kind='mergesort').reset_index(drop=True)

print(f"\nTop 10 closest designs:")
print(filtered_df[['file_name', 'distance_angstroms']].head(10))

# Save results to CSV files
os.makedirs(output_path, exist_ok=True)

# Save all results
all_results_file = os.path.join(output_path, 'all_distance_calculations.csv')
distance_df.to_csv(all_results_file, index=False)
print(f"\nAll results saved to: {all_results_file}")

# Display summary statistics for filtered designs
print(f"\nFiltered designs distance statistics:")
print(filtered_df['distance_angstroms'].describe())

# Copy filtered CIF files to the output directory. Source files that no
# longer exist are reported rather than skipped silently.
os.makedirs(output_path, exist_ok=True)  # idempotent; safe if cells run out of order
missing_cif_files = []
for filename in filtered_df['file_name']:
    src_path = os.path.join(cif_file_path, filename)
    dest_path = os.path.join(output_path, filename)
    if os.path.exists(src_path):
        shutil.copy(src_path, dest_path)
    else:
        missing_cif_files.append(filename)
if missing_cif_files:
    print(f"Warning: {len(missing_cif_files)} filtered CIF file(s) not found in {cif_file_path}: {missing_cif_files}")
print(f"\nFiltered CIF files copied to: {output_path}")

# Histogram of binder-to-binding-site distances for every design that has one
# (NaN rows dropped). The dashed red line marks the filtering threshold.
histogram_file = os.path.join(output_path, 'distance_distribution_histogram.tiff')
valid_distances = distance_df['distance_angstroms'].dropna()

plt.figure(figsize=(8, 6))
plt.hist(valid_distances, bins=50, color='skyblue', edgecolor='black')
plt.axvline(distance_threshold, color='red', linestyle='dashed', linewidth=1)
plt.title('Distribution of Binder to Binding Site Distances')
plt.xlabel('Distance (Å)')
plt.ylabel('Number of Designs')

plt.savefig(histogram_file, dpi=200)
print(f"Distance distribution histogram saved to: {histogram_file}")

plt.show()

# Extract sequences from filtered binders and append to filtered_df
print(f"\nExtracting binder sequences for {len(filtered_df)} filtered designs...")

# Fixed column layout for the per-design sequence table. Passing it to
# pd.DataFrame keeps a 'file_name' column even when no design passed the
# distance filter, so the merge below cannot fail with a KeyError.
sequence_columns = ['file_name', 'binder_sequence', 'binder_length',
                    'binder_first_residue', 'binder_last_residue']

sequences_data = []
n_filtered = len(filtered_df)
for position, filename in enumerate(filtered_df['file_name'], start=1):
    try:
        # Extract sequence for the binder chain from this design's CIF file
        file_path = os.path.join(cif_file_path, filename)
        sequence_data = get_sequence_from_file(file_path, chain_ID_binder)
        binder_residue_numbers = sequence_data['residue_numbers']
        sequences_data.append({
            'file_name': filename,
            'binder_sequence': sequence_data['sequence'],
            'binder_length': sequence_data['length'],
            'binder_first_residue': binder_residue_numbers[0] if binder_residue_numbers else None,
            'binder_last_residue': binder_residue_numbers[-1] if binder_residue_numbers else None
        })
    except Exception as e:
        # Best-effort: record an empty sequence and continue with the next file
        print(f"Error extracting sequence from {filename}: {str(e)}")
        sequences_data.append({
            'file_name': filename,
            'binder_sequence': '',
            'binder_length': 0,
            'binder_first_residue': None,
            'binder_last_residue': None
        })

    # Progress update, based on loop position rather than the DataFrame index
    if position % 10 == 0:
        print(f"Processed sequences for {position}/{n_filtered} files...")

# Convert sequences data to DataFrame (schema fixed even if empty)
sequences_df = pd.DataFrame(sequences_data, columns=sequence_columns)

# Merge with filtered_df; a left merge keeps filtered_df's row order
filtered_df_with_sequences = filtered_df.merge(sequences_df, on='file_name', how='left')

print(f"\nSequence extraction completed!")
print(f"Enhanced dataframe shape: {filtered_df_with_sequences.shape}")
print(f"New columns added: {[col for col in filtered_df_with_sequences.columns if col not in filtered_df.columns]}")

# Display some sequence statistics
if len(sequences_df) > 0:
    valid_sequences = sequences_df[sequences_df['binder_length'] > 0]
    if len(valid_sequences) > 0:
        print(f"\nBinder sequence statistics:")
        print(f"Average length: {valid_sequences['binder_length'].mean():.1f}")
        print(f"Min length: {valid_sequences['binder_length'].min()}")
        print(f"Max length: {valid_sequences['binder_length'].max()}")

        print(f"\nFirst 5 binder sequences:")
        for idx, row in valid_sequences.head().iterrows():
            print(f"{row['file_name']}: {row['binder_sequence'][:50]}{'...' if len(row['binder_sequence']) > 50 else ''}")

# Reorder by rank number in filename (rank0001_XXX.cif, rank0002_XXX.cif, ...)
print(f"\nReordering results by rank number...")

# Add rank column for sorting; names without a parsable 'rankNNNN' prefix get
# inf from extract_rank_number and therefore sort last.
filtered_df_with_sequences['rank_number'] = filtered_df_with_sequences['file_name'].apply(extract_rank_number)

# Sort by rank number. kind='mergesort' is stable, so rows with equal rank
# (e.g. several unranked files at inf) keep their closest-distance-first
# order instead of the arbitrary order of pandas' default quicksort.
filtered_df_with_sequences_sorted = (
    filtered_df_with_sequences
    .sort_values('rank_number', kind='mergesort')
    .reset_index(drop=True)
)

# Save the reordered results
reordered_results_file = os.path.join(output_path, f'filtered_designs_with_sequences_ranked_{distance_threshold}A.csv')
filtered_df_with_sequences_sorted.to_csv(reordered_results_file, index=False)
print(f"Reordered results saved to: {reordered_results_file}")

print(f"\nFirst 10 files in reordered results:")
print(filtered_df_with_sequences_sorted[['file_name', 'distance_angstroms']].head(10))

# Update the main filtered_df to include sequences and proper ordering
filtered_df = filtered_df_with_sequences_sorted.copy()
print(f"\nfiltered_df updated with sequence information and rank ordering")