 # Boltz-2 NIM for Protein/Ligand Co-Folding and Affinity Prediction, including MSA-Search NIM for Alignments



 ## Demo to run in Google Colab environment



 [MSA-Search](https://docs.nvidia.com/nim/bionemo/msa-search/latest/overview.html) performs Multiple Sequence Alignment (MSA): it compares a query amino acid sequence against protein databases, aligning similar sequences to identify conserved regions despite differences in length or motifs. The resulting alignments enhance structural prediction models like AlphaFold2 and OpenFold by leveraging the structural similarity of homologous sequences.



 [Boltz-2](https://docs.nvidia.com/nim/bionemo/boltz2/latest/index.html) NIM delivers advanced biomolecular structure and binding affinity predictions for proteins, RNA, DNA, and other molecules. Built on the Boltz-2 architecture, it enables accurate modeling of complex structures and quantifies molecular interactions across diverse configurations.



 29Aug2025

 ## 1.1 Set Up the Environment

In [None]:
!pip install matplotlib numpy pandas seaborn sklearn tqdm httpx "fastapi[standard]"

In [None]:
import os, requests, re, json
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

from pathlib import Path
from typing import Dict, Any, Optional
from time import perf_counter
from tqdm import tqdm
from datetime import datetime

from google.colab import userdata, files


 ### Define Input File with SMILES data and Output Directory

In [None]:
# Input dataset with SMILES strings and measured pIC50 values.
# Edit to match your dataset path. Ensure you use the `raw` URL path if the file is hosted on GitHub.
#csv_file = "https://raw.githubusercontent.com/bf-nv/bionemo_tutorials/refs/heads/main/RORc_SMILES_and_pIC50.csv"

CSV_FILE = "RORc_SMILES_and_pIC50.csv"

# Directory that receives the A3M alignment files and the prediction CSV files
OUTPUT_DIR = "/content/output"


 ## 1.2 Set Up `output` Directory and `API_KEY`

In [None]:
# Read the API key from the Colab `userdata` store (expects an entry named `API_KEY`)
API_KEY = userdata.get('API_KEY')

# Start from a clean output directory: remove the results of any previous run,
# then (re)create the directory. Creation happens unconditionally so the
# directory also exists on the very first run, when there is nothing to remove.
if os.path.exists(OUTPUT_DIR):
    shutil.rmtree(OUTPUT_DIR)
os.makedirs(OUTPUT_DIR, exist_ok=True)


 ## 1.3 Define `MSA-Search` Functions

In [None]:
# Databases queried by default; also the keys expected under `alignments` in the response
MSA_DATABASES = ['Uniref30_2302', 'colabfold_envdb_202108', 'PDB70_220313']

def msa_search(sequence, API_KEY, databases=MSA_DATABASES):
    """
    Submit a query sequence to the hosted MSA-Search NIM and return its JSON response.

    Args:
        sequence (str): Query amino acid sequence
        API_KEY (str): Bearer token sent in the `Authorization` header
        databases (list): Names of the databases to search (not modified here)

    Returns:
        dict: Decoded JSON body of the response

    Raises:
        requests.exceptions.HTTPError: If the service answers with a 4xx/5xx status
    """
    msa_search_url = "https://health.api.nvidia.com/v1/biology/colabfold/msa-search/predict"
    payload = {
        "sequence": sequence,
        "databases": databases,
        "e_value": 0.0001,
        "iterations": 1,
        "max_msa_sequences": 10000,
        "run_structural_template_search": False,
        "output_alignment_formats": ["a3m"],
    }
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "content-type": "application/json",
        "NVCF-POLL-SECONDS": "300",
    }
    # Call MSA-Search NIM
    response = requests.post(msa_search_url, json=payload, headers=headers)
    # Surface HTTP failures (bad key, quota, server error) as exceptions instead
    # of handing an error payload back to the caller as if it were a result.
    response.raise_for_status()
    return response.json()


def parse_sequences(input_string, n, sequence):
    """
    Parse the output of alignments from the MSA-Search NIM to be used downstream

    The first record of the input (the query block) is not returned: the
    pattern below only matches records preceded by a newline.

    Args:
        input_string (str): The output file of alignments in a string format
        n (int): The maximum number of alignments to return
        sequence (str): The query sequence for alignment (currently unused)

    Returns:
        list: At most `n` tuples of (header, aligned_sequence, score), in input
              order. `header` is `>` plus the first tab-separated field of the
              record header, `score` is the second field converted to int, and
              `aligned_sequence` has lowercase letters removed.

    Raises:
        IndexError: If a record header has no second tab-separated field
        ValueError: If that second field is not an integer
    """
    # Regex to find blocks starting with `>` and then followed by a sequence.
    # The leading `\n` is what excludes the first (query) record.
    pattern = re.compile(r'\n>(.*?)\n(.*?)(?=\n>|\Z)', re.DOTALL)
    parsed_records = []
    for match in pattern.finditer(input_string.strip()):
        header_fields = match.group(1).split('\t')
        # The name is the first tab-separated field, the score the second
        name_full = header_fields[0]
        sw_score = header_fields[1]
        sequence_raw = match.group(2).strip()
        # Drop lowercase letters; keep uppercase letters and non-letters (e.g. `-`)
        aligned_sequence = ''.join(
            char for char in sequence_raw if char.isupper() or not char.isalpha()
        )
        parsed_records.append((f'>{name_full}', aligned_sequence, int(sw_score)))
    # Each list entry is one complete alignment (header and sequence together),
    # so the limit is `n` entries, not `n * 2`.
    return parsed_records[:n]


def validate_a3m_format(alignments_string):
    """
    Check whether a string is laid out as alternating header/sequence lines.

    After stripping surrounding whitespace, lines at even positions must start
    with `>` and lines at odd positions must not, and may only contain the 20
    standard amino acid letters (either case) or `-`.

    Args:
        alignments_string (str): String containing alignments

    Returns:
        bool: True if every line satisfies the layout above and there are at
              least two lines, False otherwise
    """
    rows = alignments_string.strip().split('\n')
    if len(rows) < 2:
        return False

    allowed_residues = frozenset('ACDEFGHIKLMNPQRSTVWY-')
    for position, row in enumerate(rows):
        expects_header = position % 2 == 0
        # A header where a sequence belongs (or the reverse) is invalid
        if expects_header != row.startswith('>'):
            return False
        # Sequence rows are compared case-insensitively against the alphabet
        if not expects_header and not allowed_residues.issuperset(row.upper()):
            return False

    return True


def write_alignments_to_a3m(alignments_data, uniprot_id, output_dir):
    """
    Persist alignment data as `<uniprot_id>_msa_alignments.a3m` inside `output_dir`.

    The output directory is created when missing. After writing, the function
    attempts a Colab `files.download` of the new file; a failed download is
    reported but does not abort the call.

    Args:
        alignments_data: Either a list of strings (joined with newlines) or a
            single string containing the alignments
        uniprot_id (str): Identifier used as the file name prefix
        output_dir (str): Directory for the output a3M file

    Returns:
        str: Path to the created a3M file

    Raises:
        ValueError: If the data is neither list nor str, is blank, or has no `>`
        IOError: If the file does not exist after writing
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    output_path = target_dir / f"{uniprot_id}_msa_alignments.a3m"

    # Normalise both accepted input shapes to one string
    if isinstance(alignments_data, list):
        a3m_text = '\n'.join(alignments_data)
    elif isinstance(alignments_data, str):
        a3m_text = alignments_data
    else:
        raise ValueError("alignments_data must be either a list or string")

    if not a3m_text.strip():
        raise ValueError("Empty alignment data provided")

    # Every record header contributes one `>`
    record_count = a3m_text.count('>')
    if record_count == 0:
        raise ValueError("No sequences found in alignment data")

    # A layout problem is only reported; the file is written regardless
    if not validate_a3m_format(a3m_text):
        print("Warning: Alignment data may not follow strict A3M format")
        print("Proceeding with file creation...")

    print(f"Writing {record_count} sequences to A3M format: {output_path}")

    try:
        # Guarantee that the file ends with a newline
        trailing_newline = '' if a3m_text.endswith('\n') else '\n'
        with open(output_path, 'w', encoding='utf-8') as handle:
            handle.write(a3m_text)
            handle.write(trailing_newline)

        if not output_path.exists():
            raise IOError(f"Failed to create file {output_path}")

        size_in_bytes = output_path.stat().st_size
        print("Successfully created A3M file:")
        print(f"File: {output_path}")
        print(f"Size: {size_in_bytes:,} bytes")
        print(f"Sequences: {record_count}")

        # Best-effort download to the user's machine
        try:
            files.download(str(output_path))
            print(f"File downloaded successfully: {output_path}")
        except Exception as download_error:
            print(f"Warning: Could not download file automatically: {download_error}")
            print(f"File is available at: {output_path}")

        return str(output_path)
    except Exception as e:
        print(f"Error writing A3M file: {e}")
        raise



def process_msa_alignments(msa_response_dict, sequence, uniprot_id, output_dir, databases=MSA_DATABASES, max_sequences_per_db=10000):
    """
    Process MSA alignments from multiple databases and merge them into A3M format.

    Alignments from all databases are pooled, sorted by their score in
    descending order, and written after a leading `>query_sequence` record.

    Args:
        msa_response_dict (dict): MSA response data; alignments are read from
            `['alignments'][<database>]['a3m']['alignment']`
        sequence (str): Query sequence for alignment
        uniprot_id (str): Identifier of the protein, used in the file name
        output_dir (str): Output directory for the A3M file
        databases (list): List of database names to process
        max_sequences_per_db (int): Maximum number of sequences to parse per database

    Returns:
        str: Path to the created A3M file

    Raises:
        KeyError: If a requested database is missing from the response
    """
    scored_alignments = []
    for database in databases:
        print(f"Parsing results from database: {database}")
        # Pull string of alignments stored in json output for specific dataset
        database_a3m = msa_response_dict['alignments'][database]['a3m']['alignment']
        database_hits = parse_sequences(database_a3m, max_sequences_per_db, sequence)
        print(f"Number of sequences aligned: {len(database_hits)}")
        scored_alignments.extend(database_hits)

    # Sort all the alignments based off of the alignment score (third tuple field)
    scored_alignments.sort(key=lambda hit: hit[2], reverse=True)

    # Flatten to alternating header / sequence entries, query record first
    merged_alignments_protein = [f">query_sequence\n{sequence}"]
    for header, aligned_sequence, _score in scored_alignments:
        merged_alignments_protein.append(header)
        merged_alignments_protein.append(aligned_sequence)

    # Report records (query + hits), not the length of the flattened list,
    # which holds two entries per hit.
    print(f"Total merged alignments: {len(scored_alignments) + 1}")

    # Write merged_alignments_protein to a3M format
    return write_alignments_to_a3m(
        merged_alignments_protein,
        uniprot_id,
        output_dir
    )


 ## 1.4 Run `MSA-Search` and save `A3M` alignment file output to local directory

 ### Provide Sequence Information



 **NOTE:** Ensure the sequence is a string and does not contain any whitespace, special characters, or carriage returns.

In [None]:
# Example sequence using human RORc from PDB:4wqp_A
# http://rcsb.org/structure/4WQP

# >4WQP_1|Chain A|Nuclear receptor ROR-gamma|Homo sapiens (9606)
# `uniprot_id` is the identifier used as the prefix of the generated A3M file name
uniprot_id = "4wqp_1"
sequence = "MHHHHHHGENLYFQGSAPYASLTEIEHLVQSVCKSYRETCQLRLEDLLRQRSNIFSREEVTGYQRKSMWEMWERCAHHLTEAIQYVVEFAKRLSGFMELCQNDQIVLLKAGAMEVVLVRMCRAYNADNRTVFFEGKYGGMELFRALGCSELISSIFDFSHSLSALHFSEDEIALYTALVLINAHRPGLQEKRKVEQLQYNLELAFHHHLCKTHRQSILAKLPPKGKLRSLCSQHVERLQIFQHLHPIVVQAAFPPLYKELFSGNS"
# List of (identifier, sequence) pairs processed by the MSA-Search loop
sequences = [(uniprot_id, sequence)]

# NOTE: Ensure the sequence is a string and does not contain any whitespace, special characters, or carriage returns.


In [None]:
# Run MSA-Search for every (identifier, sequence) pair and write one A3M file each.
# A failure for one protein is reported and the loop moves on to the next.
# NOTE: `seq_id` keeps the last identifier after the loop; the Boltz-2
# configuration cell uses it to build the MSA file name.
for seq_id, seq in tqdm(sequences):
    try:
        print(f"\nProcessing protein: {seq_id}")
        print(f"Sequence length: {len(seq)}")

        # Call MSA-Search NIM
        msa_response_dict = msa_search(seq, API_KEY)

        # Check if the response contains the expected data
        if 'alignments' not in msa_response_dict:
            print(f"Warning: No alignments found for {seq_id}")
            continue

        # Process and create A3M file
        a3m_file_path = process_msa_alignments(msa_response_dict, seq, seq_id, OUTPUT_DIR)
        print(f"Successfully processed {seq_id} -> {a3m_file_path}")

    except Exception as e:
        # Report and carry on with the remaining proteins
        print(f"Error processing {seq_id}: {e}")
        continue


 ## 1.5 List all created `A3M` alignment files

In [None]:
# Report every A3M file found in the output directory, in sorted order
import glob
a3m_files = sorted(glob.glob(f"{OUTPUT_DIR}/*.a3m"))
print(f"Created {len(a3m_files)} A3M files:")
for file_path in a3m_files:
    a3m_path = Path(file_path)
    print(f"  - {a3m_path.name} ({a3m_path.stat().st_size:,} bytes)")

print(f"\nAll A3M files are available in: {OUTPUT_DIR}")
print("Files have been automatically downloaded to your machine.")


 ## 1.6 If needed, trigger download of all `A3M` alignment files

In [None]:
# Trigger a Colab download for each A3M file collected in `a3m_files`
for file_path in a3m_files:
    files.download(file_path)


 ## 2.1 Set-up `Boltz-2` Environment

 ### Configuration Constants

In [None]:
# MSA configuration
# The A3M file is looked up in the directory the MSA-Search step wrote to,
# under the name derived from the last processed `seq_id`.
MSA_DATA_DIR = OUTPUT_DIR
MSA_FILE_NAME = f"{seq_id}_msa_alignments.a3m"
# Convert the directory to a Path before joining: `str / str` raises TypeError
MSA_FILE_PATH = Path(MSA_DATA_DIR) / MSA_FILE_NAME
# MSA is used for predictions only when the alignment file is actually present
MSA_STATUS = MSA_FILE_PATH.exists()

# Boltz2 API configuration
# The request URL is BOLTZ2_BASE_URL + BOLTZ2_ENDPOINT; REQUEST_TIMEOUT is the
# per-request timeout in seconds passed to `requests.post`.
BOLTZ2_BASE_URL = "http://localhost:8000"
BOLTZ2_ENDPOINT = "/biology/mit/boltz2/predict"
REQUEST_TIMEOUT = 300  # 5 minutes

# Boltz2 prediction parameters
# These key/value pairs are merged verbatim into every request payload.
BOLTZ2_CONFIG = {
    "recycling_steps": 3,
    "sampling_steps": 20,
    "diffusion_samples": 1,
    "step_scale": 1.64,
    "without_potentials": True
}

# Required output columns for protein data
# Checked against the per-protein DataFrame before it is written to CSV.
REQUIRED_OUTPUT_COLUMNS = {
    'smiles', 'uniprot_id', 'fasta_uniprot_seq', 'pic50',
    'boltz2_plddt', 'boltz2_pic50', 'boltz2_pic50_conf', 
    'boltz2_msa', 'boltz2_runtime'
}

# CSV timestamp format
# `strftime` pattern used in per-protein result file names (year_month_day).
CSV_TIMESTAMP_FORMAT = "%Y_%m_%d"


 ## 2.2 Use `CSV_FILE` as Source of SMILES and pIC50 Data for Query

In [None]:
# Check if file exists before loading. An http(s) URL (e.g. a GitHub `raw`
# link) cannot be checked on the local filesystem, so it is passed straight to
# pandas, which fetches it itself.
csv_is_url = str(CSV_FILE).startswith(("http://", "https://"))
if not csv_is_url and not Path(CSV_FILE).exists():
    raise FileNotFoundError(f"CSV file not found: {CSV_FILE}")

# Load the dataset
try:
    df = pd.read_csv(CSV_FILE, low_memory=False)
    print(f"Successfully loaded dataset: {df.shape}")
    print(f"Columns: {list(df.columns)}")
    print("\nFirst few rows:")
    print(df.head())
except Exception as e:
    # Keep the original exception attached as the cause for easier debugging
    raise IOError(f"Error loading CSV file {CSV_FILE}: {e}") from e

# Validate that the dataset has the expected structure
if df.empty:
    raise ValueError("Dataset is empty")


 ## 2.3 Define `Boltz-2` Functions

In [None]:
def boltz2_nim_query(
    input_data: Dict[str, Any], 
    base_url: str = None
) -> Dict[str, Any]:
    """
    POST a prediction request to the Boltz2 NIM and return the decoded JSON reply.

    Args:
        input_data: Dictionary containing the prediction request data
        base_url: Base URL of the NIM service; falls back to BOLTZ2_BASE_URL

    Returns:
        Dictionary containing the prediction response

    Raises:
        requests.exceptions.RequestException: On timeout, connection failure or
            an error status. The message carries the response status and text
            when a response is attached to the underlying error.
    """
    service_root = BOLTZ2_BASE_URL if base_url is None else base_url
    endpoint_url = f"{service_root}{BOLTZ2_ENDPOINT}"

    try:
        reply = requests.post(
            endpoint_url,
            json=input_data,
            headers={"Content-Type": "application/json"},
            timeout=REQUEST_TIMEOUT,
        )
        reply.raise_for_status()
        return reply.json()
    except requests.exceptions.RequestException as e:
        # Timeouts get their own short message
        if isinstance(e, requests.exceptions.Timeout):
            raise requests.exceptions.RequestException(f"Request timed out after {REQUEST_TIMEOUT} seconds")

        # Everything else: describe the error and whatever the server sent back
        message_parts = [f"Error querying NIM: {e}"]
        failed_reply = getattr(e, 'response', None)
        if failed_reply is not None:
            message_parts.append(f"Response status: {failed_reply.status_code}")
            if hasattr(failed_reply, 'text'):
                message_parts.append(f"Response text: {failed_reply.text}")
        raise requests.exceptions.RequestException("\n".join(message_parts))


def create_payload(seq: str, smile_: str, msa_content: str = None) -> Dict[str, Any]:
    """Assemble the request body for one protein-ligand Boltz2 NIM query.

    The body holds a single protein polymer with id "A", a single ligand with
    id "L1" and affinity prediction enabled, plus every entry of BOLTZ2_CONFIG.

    Args:
        seq (str): Protein sequence
        smile_: SMILES string for the ligand
        msa_content (str, optional): MSA alignment content as string; attached
            to the polymer only when non-empty

    Returns:
        Dict[str, Any]: Payload for Boltz2 NIM query
    """
    protein_chain = {"id": "A", "molecule_type": "protein", "sequence": seq}

    # Attach the alignment only when there is one
    if msa_content:
        alignment_entry = {"alignment": msa_content, "format": "a3m"}
        protein_chain["msa"] = {"uniref90": {"a3m": alignment_entry}}

    ligand_entry = {"smiles": smile_, "id": "L1", "predict_affinity": True}

    request_body = {"polymers": [protein_chain], "ligands": [ligand_entry]}
    # Prediction parameters come from the shared configuration
    request_body.update(BOLTZ2_CONFIG)
    return request_body


def append_default_values(plddt_list, pic50_list, pic50_conf_list, msa_list, time_list, pic50_groundtruth, msa_status, time_val=None, csv_file=None):
    """Pad every result list with one placeholder entry for a failed prediction.

    Keeps the five result lists the same length as the number of processed
    rows. `pic50_groundtruth` is accepted for call-site symmetry but is not
    stored in any of the lists.

    Args:
        plddt_list: Receives None
        pic50_list: Receives None
        pic50_conf_list: Receives None
        msa_list: Receives `msa_status`
        time_list: Receives `time_val`
        pic50_groundtruth: Ground truth pIC50 value (unused)
        msa_status: Whether MSA was used
        time_val: Runtime value to append (optional)
        csv_file: CSV path mentioned in the log line (optional, falls back to
            the global CSV_FILE)
    """
    source_csv = CSV_FILE if csv_file is None else csv_file

    placeholders = (
        (plddt_list, None),
        (pic50_list, None),
        (pic50_conf_list, None),
        (msa_list, msa_status),
        (time_list, time_val),
    )
    for target_list, placeholder in placeholders:
        target_list.append(placeholder)

    # Log which CSV the failed row came from
    if source_csv:
        print(f"Default values appended for failed prediction (CSV source: {source_csv})")


def process_msa_file(msa_file_path: Path = None) -> str:
    """
    Read an MSA file and return its content with surrounding whitespace removed.

    Args:
        msa_file_path: Path to MSA file, as `Path` or `str`
            (default: global MSA_FILE_PATH)

    Returns:
        MSA alignment data as string

    Raises:
        FileNotFoundError: If MSA file doesn't exist
        IOError: If the file cannot be read or decoded as UTF-8, or is empty
    """
    if msa_file_path is None:
        msa_file_path = MSA_FILE_PATH
    # Accept plain strings as well as Path objects
    msa_file_path = Path(msa_file_path)

    if not msa_file_path.exists():
        raise FileNotFoundError(f"MSA file not found: {msa_file_path}")

    # Only the read itself is guarded, so the emptiness error below is not
    # re-wrapped into a misleading "Error reading" message.
    try:
        content = msa_file_path.read_text(encoding="utf-8").strip()
    except UnicodeDecodeError as e:
        raise IOError(f"Error decoding MSA file {msa_file_path}: {e}") from e
    except Exception as e:
        raise IOError(f"Error reading MSA file {msa_file_path}: {e}") from e

    if not content:
        raise IOError(f"MSA file is empty: {msa_file_path}")
    return content


def validate_response(result: Dict[str, Any], seq_id: str) -> tuple[Optional[float], list, list]:
    """
    Pull the pLDDT mean and the ligand "L1" affinity values out of an API response.

    A missing, empty or non-numeric `complex_plddt_scores` entry is tolerated
    (reported, pLDDT returned as None); missing or malformed affinity data is
    an error.

    Args:
        result: API response dictionary
        seq_id: Protein sequence identifier for error reporting

    Returns:
        Tuple of (plddt_score, pic50_values, pic50_confidence_values)

    Raises:
        ValueError: If affinity data is missing or is not given as lists
    """
    # Affinity data for ligand L1 must be present
    if 'affinities' not in result:
        raise ValueError(f"Missing 'affinities' key in response for {seq_id}")
    if 'L1' not in result['affinities']:
        raise ValueError(f"Missing 'L1' ligand data in affinities for {seq_id}")

    # pLDDT: mean over the reported scores, None when unavailable
    mean_plddt = None
    if 'complex_plddt_scores' in result and result['complex_plddt_scores']:
        try:
            mean_plddt = float(np.mean(result['complex_plddt_scores']))
        except (TypeError, ValueError) as e:
            print(f"Warning: Error calculating pLDDT mean for {seq_id}: {e}")
    else:
        print(f"Warning: Missing or empty pLDDT scores for {seq_id}")

    # Affinity values for ligand L1
    ligand_affinity = result['affinities']['L1']
    try:
        predicted_pic50 = ligand_affinity['affinity_pic50']
        binding_probability = ligand_affinity['affinity_probability_binary']
    except KeyError as e:
        raise ValueError(f"Missing required affinity data key {e} for {seq_id}")

    # Both values are expected to be lists
    if not (isinstance(predicted_pic50, list) and isinstance(binding_probability, list)):
        raise ValueError(f"Expected lists for affinity data, got {type(predicted_pic50)} and {type(binding_probability)}")

    return mean_plddt, predicted_pic50, binding_probability


def save_protein_results_to_csv(protein_data: list, protein_id: str, output_dir: str = None, msa_status: bool = None):
    """
    Write the ligand results of one protein to a timestamped CSV file.

    Nothing is written (and nothing is raised) when `protein_data` is empty.

    Args:
        protein_data: List of dictionaries, one per ligand
        protein_id: Identifier for the protein, used in the file name
        output_dir: Directory to save the CSV file (default: global OUTPUT_DIR)
        msa_status: Whether MSA was used for predictions, used in the file name
            (default: global MSA_STATUS)

    Raises:
        ValueError: If protein_data is not a list of dictionaries
        IOError: If building or writing the table fails, including when
            required columns are missing
    """
    if not protein_data:
        return

    is_list_of_dicts = isinstance(protein_data, list) and all(
        isinstance(entry, dict) for entry in protein_data
    )
    if not is_list_of_dicts:
        raise ValueError("protein_data must be a list of dictionaries")

    # Fall back to the notebook-wide settings
    output_dir = OUTPUT_DIR if output_dir is None else output_dir
    msa_status = MSA_STATUS if msa_status is None else msa_status

    try:
        results_table = pd.DataFrame(protein_data)

        # Every configured output column has to be present
        missing_columns = REQUIRED_OUTPUT_COLUMNS - set(results_table.columns)
        if missing_columns:
            raise ValueError(f"Missing required columns in protein data: {missing_columns}")

        destination_dir = Path(output_dir)
        destination_dir.mkdir(parents=True, exist_ok=True)

        # File name carries protein id, MSA flag and the current date
        date_stamp = datetime.now().strftime(CSV_TIMESTAMP_FORMAT)
        filepath = destination_dir / f"Boltz2_Predictions_{protein_id}_MSA_{msa_status}_BFauber_{date_stamp}.csv"

        results_table.to_csv(filepath, index=False)
        print(f"Saved results for protein {protein_id} to {filepath} ({len(protein_data)} ligands)")
    except Exception as e:
        error_msg = f"Error saving protein {protein_id} results: {e}"
        print(error_msg)
        raise IOError(error_msg)
    

def _build_ligand_record(smile_, seq_id, seq, pic50_groundtruth, msa_status, time_indv,
                         plddt_indiv=None, pic50_indiv=None, pic50_conf_indiv=None):
    """Build one row of the per-protein results CSV; predictions default to None."""
    return {
        'smiles': smile_,
        'uniprot_id': seq_id,
        'fasta_uniprot_seq': seq,
        'pic50': pic50_groundtruth,
        'boltz2_plddt': round(plddt_indiv, 2) if plddt_indiv is not None else None,
        'boltz2_pic50': round(pic50_indiv[0], 2) if pic50_indiv and pic50_indiv[0] is not None else None,
        'boltz2_pic50_conf': round(pic50_conf_indiv[0], 2) if pic50_conf_indiv and pic50_conf_indiv[0] is not None else None,
        'boltz2_msa': msa_status,
        'boltz2_runtime': time_indv
    }


def run_queries(df: pd.DataFrame, smiles_field: str, seq_id_field: str, seq_field: str, pic50_field: str, output_dir: str = None, msa_status: bool = None):
    """
    Run Boltz2 NIM queries for all protein-ligand pairs.

    Rows are processed in DataFrame order. Consecutive rows with the same
    sequence ID are treated as one protein: its MSA is loaded once and its rows
    are saved together to a per-protein CSV when the ID changes (and once more
    after the last row). Every input row contributes exactly one entry to each
    returned list; failed or skipped rows contribute placeholder values.

    Args:
        df: DataFrame containing protein-ligand data
        smiles_field: Column name for SMILES strings
        seq_id_field: Column name for sequence IDs
        seq_field: Column name for protein sequences
        pic50_field: Column name for ground truth pIC50 values
        output_dir: Directory to save protein-specific CSV files (default: global OUTPUT_DIR)
        msa_status: Whether to use MSA data (default: global MSA_STATUS)

    Returns:
        Tuple of result lists: (plddt_list, pic50_list, pic50_conf_list, msa_list, time_list)

    Raises:
        ValueError: If required columns are missing or the DataFrame is empty
        IOError: If writing a per-protein CSV fails
    """
    # Use global constants if not provided
    if output_dir is None:
        output_dir = OUTPUT_DIR
    if msa_status is None:
        msa_status = MSA_STATUS

    # Validate that required columns exist
    required_columns = [smiles_field, seq_id_field, seq_field, pic50_field]
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    # Validate DataFrame is not empty
    if df.empty:
        raise ValueError("DataFrame is empty")

    # Initialize result lists
    plddt_list, pic50_list, pic50_conf_list, time_list, msa_list = [], [], [], [], []
    # Track previous protein to avoid redundant processing
    seq_id_prev = None
    protein_msa = None
    # Set when the MSA of the current protein could not be loaded
    msa_load_failed = False
    # Track current protein data for CSV saving
    current_protein_data = []
    current_protein_id = None

    # Process each protein-ligand pair
    for smile_, seq_id, seq, pic50_groundtruth in tqdm(
        zip(df[smiles_field], df[seq_id_field], df[seq_field], df[pic50_field]),
        desc="Processing protein-ligand pairs",
        total=len(df)
    ):
        # Check if we need to process a new protein
        if seq_id_prev != seq_id:
            # Save previous protein results to CSV if we have data
            if current_protein_data and current_protein_id:
                save_protein_results_to_csv(current_protein_data, current_protein_id, output_dir, msa_status)
            # Reset for new protein
            seq_id_prev = seq_id
            protein_msa = None
            msa_load_failed = False
            current_protein_id = seq_id
            current_protein_data = []
            # Load MSA data if needed
            if msa_status:
                try:
                    protein_msa = process_msa_file(MSA_FILE_PATH)
                except (FileNotFoundError, IOError) as e:
                    print(f"Warning: {e}, skipping protein {seq_id}")
                    msa_load_failed = True

        # Skip EVERY ligand of a protein whose MSA failed to load. Otherwise
        # the later ligands would be predicted without MSA while still being
        # recorded as MSA-based results.
        if msa_load_failed:
            append_default_values(plddt_list, pic50_list, pic50_conf_list, msa_list, time_list, pic50_groundtruth, msa_status, csv_file=CSV_FILE)
            continue

        # Create payload for this query
        payload = create_payload(seq, smile_, protein_msa)
        # Reset per row so a failure never reports the previous row's runtime
        time_indv = None
        try:
            # Query the NIM
            t0 = perf_counter()
            result = boltz2_nim_query(payload)
            time_indv = round(perf_counter() - t0, 3)
            # Extract and validate results
            plddt_indiv, pic50_indiv, pic50_conf_indiv = validate_response(result, seq_id)
            # Build the CSV row first: if rounding fails, nothing has been
            # appended yet and the failure path adds exactly one placeholder.
            record = _build_ligand_record(
                smile_, seq_id, seq, pic50_groundtruth, msa_status, time_indv,
                plddt_indiv, pic50_indiv, pic50_conf_indiv
            )
        except Exception as e:
            print(f"Failed to get prediction for {seq_id}: {e}")
            append_default_values(plddt_list, pic50_list, pic50_conf_list, msa_list, time_list, pic50_groundtruth, msa_status, time_indv, csv_file=CSV_FILE)
            # Add failed ligand data to current protein data
            current_protein_data.append(
                _build_ligand_record(smile_, seq_id, seq, pic50_groundtruth, msa_status, time_indv)
            )
        else:
            # Save results to lists
            plddt_list.append(plddt_indiv)
            pic50_list.append(pic50_indiv)
            pic50_conf_list.append(pic50_conf_indiv)
            msa_list.append(msa_status)
            time_list.append(time_indv)
            # Store current ligand data for CSV
            current_protein_data.append(record)

    # Save the last protein's results
    if current_protein_data and current_protein_id:
        save_protein_results_to_csv(current_protein_data, current_protein_id, output_dir, msa_status)
    return plddt_list, pic50_list, pic50_conf_list, msa_list, time_list



 ## 2.4 Run `Boltz-2` Query

In [None]:
# Query Boltz-2 for every row of `df`; the string arguments name the columns
# holding the SMILES, protein ID, protein sequence and measured pIC50.
plddt_list, pic50_list, pic50_conf_list, msa_list, time_list = run_queries(
    df, "smiles", "uniprot_id", "fasta_uniprot_seq", "pic50", 
    OUTPUT_DIR, msa_status=MSA_STATUS)


 ## 2.5 Results to a DataFrame

In [None]:
# Every result list must hold exactly one entry per DataFrame row; fail with a
# clear message instead of letting the column assignment below fail.
expected_length = len(df)
actual_lengths = [len(plddt_list), len(pic50_list), len(pic50_conf_list), len(msa_list), len(time_list)]
if not all(length == expected_length for length in actual_lengths):
    raise ValueError(f"Result lists have different lengths. Expected: {expected_length}, Got: {actual_lengths}")

# Add Boltz2 prediction results to the original `CSV_FILE` DataFrame.
# Missing predictions are stored as NaN in float columns (not None in object
# columns) so that numeric operations such as `np.isnan` work downstream.
df['boltz2_plddt'] = np.array([round(x, 2) if x is not None else np.nan for x in plddt_list], dtype=float)
df['boltz2_pic50'] = np.array([round(x[0], 2) if x and x[0] is not None else np.nan for x in pic50_list], dtype=float)
df['boltz2_pic50_conf'] = np.array([round(x[0], 2) if x and x[0] is not None else np.nan for x in pic50_conf_list], dtype=float)
df['boltz2_msa'] = np.array(msa_list)
df['boltz2_runtime'] = np.array([t if t is not None else np.nan for t in time_list], dtype=float)



 ## 2.6 Plot Results



 ### Actual pIC50 vs Boltz2 Predicted pIC50 with regression lines and R-squared values

In [None]:
# Matplotlib plot with regression lines and R-squared values for MSA cohorts
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

BLUE = (120/255, 94/255, 240/255) # `indigo` from IBM
ORANGE = (254/255, 97/255, 0) # from IBM
RED = (220/255, 38/255, 127/255) # `magenta` from IBM
cmap_full = [ORANGE, BLUE, RED]

# Create figure
plt.figure(figsize=(6, 5))

# Define colors for each cohort, keyed by the string form of the MSA flag
colors = {'False': cmap_full[1], 'True': cmap_full[0]}
FALLBACK_COLOR = '#333333'

# Numeric views of both columns; anything non-numeric (e.g. None) becomes NaN
actual_pic50 = pd.to_numeric(df['pic50'], errors='coerce')
predicted_pic50 = pd.to_numeric(df['boltz2_pic50'], errors='coerce')

# Get unique MSA values
msa_values = df['boltz2_msa'].unique()

# Plot data and regression lines for each cohort
for msa_type in msa_values:
    # The flag values are booleans while the color table is keyed by strings,
    # so convert before the lookup; points and line share the same color.
    cohort_color = colors.get(str(msa_type), FALLBACK_COLOR)

    # Filter data for this cohort
    mask = df['boltz2_msa'] == msa_type
    x_data = actual_pic50[mask].to_numpy(dtype=float)
    y_data = predicted_pic50[mask].to_numpy(dtype=float)

    # Remove any NaN values
    valid_mask = ~(np.isnan(x_data) | np.isnan(y_data))
    x_clean = x_data[valid_mask].reshape(-1, 1)
    y_clean = y_data[valid_mask]

    if len(x_clean) > 1:  # Need at least 2 points for regression
        # Scatter plot
        plt.scatter(x_clean.flatten(), y_clean,
                    color=cohort_color,
                    alpha=0.8,
                    s=100,
                    label=f'{msa_type} (n={len(x_clean)})')

        # Fit linear regression
        reg_model = LinearRegression()
        reg_model.fit(x_clean, y_clean)

        # Calculate R-squared
        y_pred = reg_model.predict(x_clean)
        r2 = r2_score(y_clean, y_pred)

        # Create regression line
        x_range = np.linspace(x_clean.min(), x_clean.max(), 100).reshape(-1, 1)
        y_range = reg_model.predict(x_range)

        # Plot regression line
        plt.plot(x_range.flatten(), y_range,
                 color=cohort_color,
                 linewidth=2, linestyle='--',
                 label=f'{msa_type} R² = {r2:.2f}')

# Add diagonal reference line (perfect prediction)
min_val = min(actual_pic50.min(), predicted_pic50.min())
max_val = max(actual_pic50.max(), predicted_pic50.max())
plt.plot([min_val, max_val], [min_val, max_val], 
         'k--', alpha=0.8, linewidth=1, label='Perfect Prediction')

# Formatting
plt.xlabel(r"Actual pIC$_{50}$", fontsize=14)
plt.ylabel(r"Boltz-2 Predicted pIC$_{50}$", fontsize=14)
plt.title(r"Boltz-2 pIC$_{50}$ vs Actual pIC$_{50}$", fontsize=16)
plt.legend(fontsize=12, framealpha=0.9)
plt.grid(True, alpha=0.3)
plt.tick_params(labelsize=12)
plt.tight_layout()
plt.show()



 ## 2.7 Save Results

In [None]:
# Save combined results to CSV
# The file name records whether MSA was used (value of MSA_STATUS).
output_filename = f"Boltz2_Predictions_ManyProteins_MSA_{MSA_STATUS}.csv"
output_filepath = Path(OUTPUT_DIR) / output_filename
df.to_csv(output_filepath, index=False)

# Short summary of what was written
print(f"Saved combined results to: {output_filepath}")
print(f"DataFrame shape: {df.shape}")
print("\nFirst few rows:")
print(df.head())


