# Running MSA-Search NIM in Google Colab Environment

Runs MSA-Search NIM and saves result as `A3M` format file.

MSA-Search NIM: https://build.nvidia.com/colabfold/msa-search

13Aug2025

## 1.1 Set Up the Environment

In [1]:
!pip install pandas numpy seaborn matplotlib httpx "fastapi[standard]"

Collecting fastapi-cli>=0.0.8 (from fastapi-cli[standard]>=0.0.8; extra == "standard"->fastapi[standard])
  Downloading fastapi_cli-0.0.8-py3-none-any.whl.metadata (6.3 kB)
Collecting email-validator>=2.0.0 (from fastapi[standard])
  Downloading email_validator-2.2.0-py3-none-any.whl.metadata (25 kB)
Collecting dnspython>=2.0.0 (from email-validator>=2.0.0->fastapi[standard])
  Downloading dnspython-2.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting rich-toolkit>=0.14.8 (from fastapi-cli>=0.0.8->fastapi-cli[standard]>=0.0.8; extra == "standard"->fastapi[standard])
  Downloading rich_toolkit-0.15.0-py3-none-any.whl.metadata (1.0 kB)
Collecting fastapi-cloud-cli>=0.1.1 (from fastapi-cli[standard]>=0.0.8; extra == "standard"->fastapi[standard])
  Downloading fastapi_cloud_cli-0.1.5-py3-none-any.whl.metadata (3.2 kB)
Collecting httptools>=0.6.3 (from uvicorn[standard]>=0.12.0; extra == "standard"->fastapi[standard])
  Downloading httptools-0.6.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1

In [2]:
import json
import os
import requests
import re
import shutil
from google.colab import userdata

import asyncio
from typing import Any, Dict, Optional
from pathlib import Path
from enum import StrEnum
import logging
import sys

## 1.2 Set Up `output` Directory and `API_KEY`

**NOTE:** Be sure to follow the steps in the README to embed your NVIDIA `API_KEY` into your Google Colab environment.

In [5]:
def prepare_output_directory(output):
    """
    Prepare the output directory
    output: str, the output directory
    return: None
    """
    # Overwrite the output directory
    if os.path.exists(output):
        shutil.rmtree(output)
    os.makedirs(output)

In [6]:
API_KEY = userdata.get('API_KEY')

# Prepare `output_dir` for saving files
output_dir = "/content/output"
prepare_output_directory(output_dir)

#### Define Protein Sequence and Databases to use for MSA-Search

In [7]:
sequence = "MHHHHHHGENLYFQGSAPYASLTEIEHLVQSVCKSYRETCQLRLEDLLRQRSNIFSREEVTGYQRKSMWEMWERCAHHLTEAIQYVVEFAKRLSGFMELCQNDQIVLLKAGAMEVVLVRMCRAYNADNRTVFFEGKYGGMELFRALGCSELISSIFDFSHSLSALHFSEDEIALYTALVLINAHRPGLQEKRKVEQLQYNLELAFHHHLCKTHRQSILAKLPPKGKLRSLCSQHVERLQIFQHLHPIVVQAAFPPLYKELFSGNS"

databases = ['Uniref30_2302', 'colabfold_envdb_202108', 'PDB70_220313']

## 1.3 Set Up and Run `MSA-Search`

In [None]:
msa_search_url = "https://health.api.nvidia.com/v1/biology/colabfold/msa-search/predict"
payload = {
    "sequence": sequence,
    "databases": databases,
    "e_value": 0.0001,
    "iterations": 1,
    "max_msa_sequences": 10000,
    "run_structural_template_search": False,
    "output_alignment_formats": ["a3m"],
}

headers = {
    "Authorization": f"Bearer {API_KEY}",
    "content-type": "application/json",
    "NVCF-POLL-SECONDS": "300",
}
# Call MSA-Search NIM
response = requests.post(msa_search_url, json=payload, headers=headers)
msa_response_dict = response.json()
print(f"MSA response : \n {msa_response_dict}")

with open('raw_msa_output.json', 'w') as json_file:
    json.dump(msa_response_dict, json_file, indent=4)

## 1.4 Merge and Sort Alignments from `MSA-Search` into a Single `A3M` File

In [9]:
def parse_sequences(input_string, n, sequence):
    """
    Parse the output of alignments from the MSA-Search NIM to be used downstream

    Args:
        input_string (str): The output file of alignments in a string format
        n (int): The amount of alignments to return from the output when parsing
        sequence (str): The query sequence for alignment

    Returns:
        list: A list of alignment identifiers and sequences, starting with the query,
              where the amount of sequences is given by n
    """
    # Output is parsed to have a line for the sequence id and sequence itself so `n` returns correlates to n*2 lines
    n = n * 2

    # First, handle the `Query` block separately
    lines = input_string.strip().split('\n')

    # Now process the rest of the lines
    remaining_string = "\n".join(lines[:])

    # Regex to find blocks starting with `>` and then followed by a sequence.
    pattern = re.compile(r'\n>(.*?)\n(.*?)(?=\n>|\Z)', re.DOTALL)

    matches = pattern.finditer(remaining_string)

    output_list = []
    output_list_to_order = []

    for num_match, match in enumerate(matches):
        # The name is the first capturing group, split by tab and take the first part
        name_full = match.group(1).split('\t')[0]
        SW_score = match.group(1).split('\t')[1]

        # The sequence is the second capturing group
        sequence_raw = match.group(2).strip()
        sequence = ''.join(char for char in sequence_raw if char.isupper() or not char.isalpha())

        # Store the aligned sequence in the list of outputs by name, sequence, Smith-Waterman score
        output_list_to_order.append((f'>{name_full}', sequence, int(SW_score)))

    output_lines = output_list_to_order[:n]

    return output_lines


def write_alignments_to_a3m(alignments_data, output_file_path, description="MSA alignments"):
    """
    Write alignment data to a3M format file.

    Args:
        alignments_data: Either a list of alternating headers/sequences or a string containing alignments
        output_file_path (str): Path for the output a3M file
        description (str): Description for the file

    Returns:
        str: Path to the created a3M file
    """
    output_path = Path(output_file_path)

    # Handle both list and string input formats
    if isinstance(alignments_data, list):
        alignments_string = '\n'.join(alignments_data)
    elif isinstance(alignments_data, str):
        alignments_string = alignments_data
    else:
        raise ValueError("alignments_data must be either a list or string")

    # Count sequences for reporting
    sequence_count = alignments_string.count('>')

    print(f"Writing {sequence_count} sequences to a3M format: {output_path}")

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            # Write the alignments
            f.write(alignments_string)

            # Ensure file ends with newline
            if not alignments_string.endswith('\n'):
                f.write('\n')

        # Verify the file was created successfully
        if output_path.exists():
            file_size = output_path.stat().st_size
            print(f"Successfully created a3M file:")
            print(f"File: {output_path}")
            print(f"Size: {file_size:,} bytes")
            print(f"Sequences: {sequence_count}")

            return str(output_path)
        else:
            raise IOError(f"Failed to create file {output_path}")

    except Exception as e:
        print(f"Error writing a3M file: {e}")
        raise


def process_msa_alignments(msa_response_dict, databases, sequence, max_sequences_per_db=10000, output_file="merged_alignments_protein.a3m"):
    """
    Process MSA alignments from multiple databases and merge them into A3M format.

    Args:
        msa_response_dict (dict): MSA response data containing alignments
        databases (list): List of database names to process
        sequence (str): Query sequence for alignment
        max_sequences_per_db (int): Maximum number of sequences to parse per database
        output_file (str): Output A3M file path

    Returns:
        tuple: (merged_alignments_protein, a3m_file_path)
            - merged_alignments_protein: List of merged alignments
            - a3m_file_path: Path to the created A3M file
    """
    all_parsed_dataset_output = []

    for num_done, database in enumerate(databases):
        print(f"Parsing results from database: {database}")

        # Pull string of alignments stored in json output for specific dataset
        a3m_dict_msa_search = msa_response_dict['alignments'][database]['a3m']['alignment']

        a3m_dict_msa_search_parsed = parse_sequences(a3m_dict_msa_search, max_sequences_per_db, sequence)

        num_sequences_aligned = (len(a3m_dict_msa_search_parsed))
        print(f"Number of sequences aligned: {num_sequences_aligned}")

        all_parsed_dataset_output.extend(a3m_dict_msa_search_parsed)

    # Sort all the alignments based off of the alignment score
    all_parsed_dataset_output.sort(key=lambda x: x[2], reverse=True)

    # Now that the alignments across all datasets are sorted, reformat each entry to name and sequence
    sorted_parsed_output_formatted = []
    for align_tuple in all_parsed_dataset_output:
        sorted_parsed_output_formatted.append(align_tuple[0])
        sorted_parsed_output_formatted.append(align_tuple[1])

    merged_alignments_protein = [f">query_sequence\n{sequence}"]
    merged_alignments_protein.extend(sorted_parsed_output_formatted)

    print(f"Total merged alignments: {len(merged_alignments_protein)}")

    # Write merged_alignments_protein to a3M format
    a3m_file_path = write_alignments_to_a3m(
        merged_alignments_protein,
        output_file,
        description=f"Merged protein alignments from MSA-Search NIM ({', '.join(databases)})"
    )

    return merged_alignments_protein, a3m_file_path


def write_filtered_a3m(alignments_data, output_file_path, max_sequences=None, min_length=None, description="Filtered MSA alignments"):
    """
    Write alignment data to a3M format with optional filtering.

    Args:
        alignments_data: String containing alignments in FASTA-like format
        output_file_path (str): Path for the output a3M file
        max_sequences (int, optional): Maximum number of sequences to include
        min_length (int, optional): Minimum sequence length (excluding gaps)
        description (str): Description for the file

    Returns:
        str: Path to the created a3M file
    """
    output_path = Path(output_file_path)

    # Parse sequences from the input data
    if isinstance(alignments_data, str):
        lines = alignments_data.strip().split('\n')
    else:
        lines = '\n'.join(alignments_data).strip().split('\n')

    sequences = []
    current_header = None
    current_sequence = ""

    for line in lines:
        line = line.strip()
        if line.startswith('>'):
            # Save previous sequence if it exists
            if current_header is not None:
                sequences.append((current_header, current_sequence))
            current_header = line
            current_sequence = ""
        else:
            current_sequence += line

    # Don't forget the last sequence
    if current_header is not None:
        sequences.append((current_header, current_sequence))

    print(f"Parsed {len(sequences)} sequences from input data")

    # Apply filters
    filtered_sequences = []

    for header, sequence in sequences:
        # Apply minimum length filter (count non-gap characters)
        if min_length is not None:
            non_gap_length = len(sequence.replace('-', '').replace('.', ''))
            if non_gap_length < min_length:
                continue

        filtered_sequences.append((header, sequence))

        # Apply maximum sequences limit
        if max_sequences is not None and len(filtered_sequences) >= max_sequences:
            break

    print(f"After filtering: {len(filtered_sequences)} sequences")
    if max_sequences:
        print(f"Max sequences limit: {max_sequences}")
    if min_length:
        print(f"Min length filter: {min_length}")

    # Write to a3M format
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            # Write sequences
            for header, sequence in filtered_sequences:
                f.write(f"{header}\n{sequence}\n")

        # Report success
        file_size = output_path.stat().st_size
        print(f"Successfully created filtered a3M file:")
        print(f"File: {output_path}")
        print(f"Size: {file_size:,} bytes")
        print(f"Sequences: {len(filtered_sequences)}")

        return str(output_path)

    except Exception as e:
        print(f"Error writing filtered a3M file: {e}")
        raise


def analyze_a3m_file(file_path):
    """
    Analyze an a3M file and provide statistics.

    Args:
        file_path (str): Path to the a3M file
    """
    file_path = Path(file_path)

    if not file_path.exists():
        print(f"File not found: {file_path}")
        return

    print(f"Analyzing a3M file: {file_path.name}")

    try:
        with open(file_path, 'r') as f:
            lines = f.readlines()

        # Count statistics
        total_lines = len(lines)
        comment_lines = sum(1 for line in lines if line.startswith('#'))
        sequence_headers = sum(1 for line in lines if line.startswith('>'))
        sequence_lines = total_lines - comment_lines - sequence_headers

        # Calculate sequence lengths
        sequence_lengths = []
        current_sequence = ""

        for line in lines:
            line = line.strip()
            if line.startswith('#'):
                continue
            elif line.startswith('>'):
                if current_sequence:
                    sequence_lengths.append(len(current_sequence))
                current_sequence = ""
            else:
                current_sequence += line

        # Don't forget the last sequence
        if current_sequence:
            sequence_lengths.append(len(current_sequence))

        # File statistics
        file_size = file_path.stat().st_size

        print(f"File Statistics:")
        print(f"File size: {file_size:,} bytes")
        print(f"Total lines: {total_lines}")
        print(f"Comment lines: {comment_lines}")
        print(f"Sequence headers: {sequence_headers}")
        print(f"Sequence lines: {sequence_lines}")

        if sequence_lengths:
            avg_length = sum(sequence_lengths) / len(sequence_lengths)
            min_length = min(sequence_lengths)
            max_length = max(sequence_lengths)

            print(f"Sequence Statistics:")
            print(f"Number of sequences: {len(sequence_lengths)}")
            print(f"Average length: {avg_length:.1f}")
            print(f"Length range: {min_length} - {max_length}")

            # Show first sequence as example
            with open(file_path, 'r') as f:
                content = f.read()

            # Find first sequence
            lines = content.split('\n')
            for i, line in enumerate(lines):
                if line.startswith('>') and not line.startswith('#'):
                    header = line
                    sequence = ""
                    j = i + 1
                    while j < len(lines) and not lines[j].startswith('>'):
                        if not lines[j].startswith('#'):
                            sequence += lines[j].strip()
                        j += 1

                    print(f"First sequence example:")
                    print(f"Header: {header}")
                    print(f"Length: {len(sequence)}")
                    print(f"Preview: {sequence[:80]}{'...' if len(sequence) > 80 else ''}")
                    break

    except Exception as e:
        print(f"Error analyzing file: {e}")


### Parse the MSA alignment results to merge results from all datasets into a single `A3M` format file

In [10]:
merged_alignments_protein, a3m_file_path = process_msa_alignments(
    msa_response_dict,
    databases,
    sequence,
    max_sequences_per_db=10000,
    output_file="merged_alignments_protein.a3m"
    )

Parsing results from database: Uniref30_2302
Number of sequences aligned: 100
Parsing results from database: colabfold_envdb_202108
Number of sequences aligned: 100
Parsing results from database: PDB70_220313
Number of sequences aligned: 88
Total merged alignments: 577
Writing 289 sequences to a3M format: merged_alignments_protein.a3m
Successfully created a3M file:
File: merged_alignments_protein.a3m
Size: 81,161 bytes
Sequences: 289


### Download the `A3M` file to your local machine

In [11]:
from google.colab import files
files.download("raw_msa_output.json")
files.download("merged_alignments_protein.a3m")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## 1.5 Analyze the `A3M` Format File

In [None]:
# Analyze all created a3M files
print("=" * 60)
print("A3M FILE ANALYSIS")
print("=" * 60)

files_to_analyze = [
    "merged_alignments_protein.a3m",
]

for file_name in files_to_analyze:
    if Path(file_name).exists():
        analyze_a3m_file(file_name)
        print("-" * 40)
    else:
        print(f"File not found: {file_name}")
        print("-" * 40)

## 1.6 Visualize the MSA Results

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import Counter
import seaborn as sns
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches


def visualize_msa_alignment(merged_alignments_protein, top_n=5, figsize=(20, 12)):
    """
    Visualize multiple sequence alignment results emphasizing top-5 residues at each position.

    Args:
        merged_alignments_protein (list): List of alternating headers and sequences
        top_n (int): Number of top residues to highlight at each position
        figsize (tuple): Figure size for the plot
    """

    # Parse the alignments into a more structured format
    sequences = []
    headers = []

    for i, item in enumerate(merged_alignments_protein):
        if i % 2 == 0:  # Header line
            headers.append(item)
        else:  # Sequence line
            sequences.append(item)

    # Get the query sequence (first sequence)
    query_sequence = sequences[0]
    alignment_length = len(query_sequence)

    print(f"Alignment length: {alignment_length}")
    print(f"Number of sequences: {len(sequences)}")
    print(f"Query sequence: {query_sequence}")

    # Create position-wise residue frequency analysis
    position_residues = {}
    position_gaps = {}

    for pos in range(alignment_length):
        residues_at_pos = []
        gaps_at_pos = 0

        for seq in sequences:
            if pos < len(seq):
                residue = seq[pos]
                if residue == '-' or residue == '.':
                    gaps_at_pos += 1
                else:
                    residues_at_pos.append(residue)

        # Count residue frequencies
        residue_counts = Counter(residues_at_pos)
        position_residues[pos] = residue_counts
        position_gaps[pos] = gaps_at_pos

    # Create the visualization
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=figsize,
                                         gridspec_kw={'height_ratios': [1, 2, 1]})

    # 1. Top panel: Query sequence with position numbers
    ax1.set_title('Query Sequence with Position Numbers', fontsize=14, fontweight='bold')
    ax1.set_xlim(0, alignment_length)
    ax1.set_ylim(0, 1)

    # Add position numbers every 10 positions
    for pos in range(0, alignment_length, 10):
        ax1.text(pos, 0.5, str(pos+1), ha='center', va='center', fontsize=8,
                bbox=dict(boxstyle="round,pad=0.2", facecolor="lightblue", alpha=0.7))

    # Add query sequence residues
    for pos, residue in enumerate(query_sequence):
        color = 'red' if residue == '-' else 'blue'
        ax1.text(pos, 0.2, residue, ha='center', va='center', fontsize=10,
                fontweight='bold', color=color)

    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['left'].set_visible(False)
    ax1.spines['bottom'].set_visible(False)

    # 2. Middle panel: Top-N residue frequency heatmap
    ax2.set_title(f'Top-{top_n} Residue Frequencies at Each Position', fontsize=14, fontweight='bold')

    # Prepare data for heatmap
    top_residues = set()
    for pos in range(alignment_length):
        if pos in position_residues:
            top_residues.update([res for res, _ in position_residues[pos].most_common(top_n)])

    top_residues = sorted(list(top_residues))
    heatmap_data = np.zeros((len(top_residues), alignment_length))

    for pos in range(alignment_length):
        if pos in position_residues:
            residue_counts = position_residues[pos]
            total_non_gaps = sum(residue_counts.values())
            if total_non_gaps > 0:
                for i, residue in enumerate(top_residues):
                    if residue in residue_counts:
                        heatmap_data[i, pos] = residue_counts[residue] / total_non_gaps

    # Create heatmap
    im = ax2.imshow(heatmap_data, cmap='YlOrRd', aspect='auto', interpolation='nearest')

    # Set labels
    ax2.set_yticks(range(len(top_residues)))
    ax2.set_yticklabels(top_residues)
    ax2.set_ylabel('Residues', fontsize=12)

    # Set x-axis labels every 10 positions
    ax2.set_xticks(range(0, alignment_length, 10))
    ax2.set_xticklabels([str(i+1) for i in range(0, alignment_length, 10)])
    ax2.set_xlabel('Position', fontsize=12)

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax2, shrink=0.8)
    cbar.set_label('Frequency', fontsize=10)

    # 3. Bottom panel: Gap analysis and conservation score
    ax3.set_title('Gap Analysis and Conservation Score', fontsize=14, fontweight='bold')

    # Calculate conservation score (percentage of most common residue at each position)
    conservation_scores = []
    for pos in range(alignment_length):
        if pos in position_residues and position_residues[pos]:
            most_common_count = max(position_residues[pos].values())
            total_non_gaps = sum(position_residues[pos].values())
            if total_non_gaps > 0:
                conservation_scores.append(most_common_count / total_non_gaps)
            else:
                conservation_scores.append(0)
        else:
            conservation_scores.append(0)

    # Plot conservation scores
    x_positions = range(alignment_length)
    ax3.plot(x_positions, conservation_scores, 'b-', linewidth=2, alpha=0.8, label='Conservation Score')

    # Plot gap percentages
    gap_percentages = [position_gaps[pos] / len(sequences) * 100 for pos in range(alignment_length)]
    ax3.plot(x_positions, gap_percentages, 'r--', linewidth=2, alpha=0.8, label='Gap Percentage')

    # Add horizontal line at 50% conservation
    ax3.axhline(y=0.5, color='gray', linestyle=':', alpha=0.7, label='50% Conservation Threshold')

    ax3.set_xlabel('Position', fontsize=12)
    ax3.set_ylabel('Score/Percentage', fontsize=12)
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Set x-axis labels every 10 positions
    ax3.set_xticks(range(0, alignment_length, 10))
    ax3.set_xticklabels([str(i+1) for i in range(0, alignment_length, 10)])

    plt.tight_layout()
    plt.show()

    # Print detailed statistics
    print("\n" + "="*60)
    print("DETAILED ALIGNMENT STATISTICS")
    print("="*60)

    # Position-wise top residues
    print(f"\nTop-{top_n} residues at key positions:")
    key_positions = [0, alignment_length//4, alignment_length//2, 3*alignment_length//4, alignment_length-1]

    for pos in key_positions:
        if pos < alignment_length:
            print(f"\nPosition {pos+1}:")
            if pos in position_residues:
                top_res = position_residues[pos].most_common(top_n)
                for residue, count in top_res:
                    percentage = (count / sum(position_residues[pos].values())) * 100
                    print(f"  {residue}: {count} ({percentage:.1f}%)")
            else:
                print("  No data available")

    # Overall statistics
    total_gaps = sum(position_gaps.values())
    total_positions = alignment_length * len(sequences)
    gap_percentage = (total_gaps / total_positions) * 100

    print(f"\nOverall Statistics:")
    print(f"Total positions: {total_positions}")
    print(f"Total gaps: {total_gaps}")
    print(f"Overall gap percentage: {gap_percentage:.2f}%")

    # Most conserved positions
    conservation_by_position = [(pos, conservation_scores[pos]) for pos in range(alignment_length)]
    conservation_by_position.sort(key=lambda x: x[1], reverse=True)

    print(f"\nTop 10 most conserved positions:")
    for i, (pos, score) in enumerate(conservation_by_position[:10]):
        if score > 0:
            print(f"  Position {pos+1}: {score:.3f} ({score*100:.1f}%)")

    return fig, (ax1, ax2, ax3)


In [None]:
fig, axes = visualize_msa_alignment(merged_alignments_protein, top_n=5)

NameError: name 'merged_alignments_protein' is not defined