In [None]:
# -*- coding: utf-8 -*-
"""
AlphaFlow - Jing et al. 2024 Colab Inference

Automatically generated based on discussions and repository analysis.
This notebook runs AlphaFlow/ESMFlow inference using the method described in:
arXiv:2402.04845v2 (https://arxiv.org/abs/2402.04845)
Code from: https://github.com/bjing2016/alphaflow
Installation fixes adapted from: https://github.com/CamelCaseCam/alphaflow-but-it-works

**Disclaimers:**
*   Requires a GPU runtime (T4, P100, V100, A100). Check Runtime > Change runtime type.
*   Setup (Cell 2) is time-consuming due to Miniconda install and AlphaFold parameter downloads (~43GB).
*   Requires significant disk space for parameters. Colab's free tier disk space might be insufficient; consider Colab Pro.
*   AlphaFlow models need an MSA for each sequence: Cell 4 lets you either generate it through the ColabFold API or upload a pre-computed A3M file. ESMFlow models ignore the MSA.
*   This is based on community findings and repository code; use at your own discretion.
"""

# ==============================================================================
# Cell 1: Introduction & Preliminaries
# ==============================================================================
# @title 1. Introduction & Preliminaries
# @markdown ## AlphaFlow (Flow Matching for Ensembles - Jing et al. 2024)
# @markdown This notebook runs inference using the AlphaFlow/ESMFlow models, which combine AlphaFold/ESMFold with flow matching to generate conformational ensembles.
# @markdown - **Paper:** [arXiv:2402.04845v2](https://arxiv.org/abs/2402.04845)
# @markdown - **Code:** [github.com/bjing2016/alphaflow](https://github.com/bjing2016/alphaflow)
# @markdown - **Installation Fixes:** Adapted from [github.com/CamelCaseCam/alphaflow-but-it-works](https://github.com/CamelCaseCam/alphaflow-but-it-works)
# @markdown - **Note**: This colab is hacked together with the use of Gemini and ChatGPT
# @markdown - **Colab version Github Repository**: https://github.com/mar4mn/ColabFlow
# @markdown ---
# @markdown ### **Important Notes:**
# @markdown *   **GPU Required:** Ensure you have selected a GPU runtime (Runtime -> Change runtime type -> GPU).
# @markdown *   **Setup Time:** Cell 2 (Setup) will take some time (potentially 5 minutes) due to Miniconda installation and downloading large AlphaFold parameter files (~43 GB).
# @markdown *   **CUDA/JAX:** Setup attempts to install compatible library versions, but issues can still arise depending on the Colab backend environment.
# @markdown ---

import os
import sys
import time
import subprocess
import shutil
from IPython.display import display, HTML

# Styled status messages, rendered as colored HTML in the cell output
def _emit_status(color, label, message):
    """Display `message` in `color`, prefixed with a bold `label:` tag."""
    display(HTML(f'<font color="{color}"><b>{label}:</b> {message}</font>'))

def print_info(message):
    """Show an informational message in blue."""
    _emit_status("blue", "INFO", message)

def print_warning(message):
    """Show a warning message in orange."""
    _emit_status("orange", "WARNING", message)

def print_error(message):
    """Show an error message in red."""
    _emit_status("red", "ERROR", message)

def print_success(message):
    """Show a success message in green."""
    _emit_status("green", "SUCCESS", message)

# --- GPU check: `nvidia-smi` only runs successfully when a GPU runtime is attached ---
try:
    subprocess.check_output(['nvidia-smi'])
    print_info('GPU detected.')
except Exception:
    # Reported but deliberately non-fatal; the user can switch runtime and re-run.
    print_error('No GPU detected. Please enable GPU acceleration in Runtime -> Change runtime type.')

# --- Disk space check (rough estimate, sizes converted from bytes to GiB) ---
total, used, free = shutil.disk_usage("/")
total_gb, used_gb, free_gb = (n_bytes / (2**30) for n_bytes in (total, used, free))
print_info(f"Disk Space: Total={total_gb:.1f} GB, Used={used_gb:.1f} GB, Free={free_gb:.1f} GB")
# Threshold covers ~43GB params + ~5GB conda env + code + outputs
if free_gb < 60:
    print_warning("Available disk space is low (< 60 GB). Setup or execution might fail. Consider Colab Pro.")

print_success("Preliminaries complete. Proceed to Cell 2 for setup.")


In [None]:
#@title 2. Setup Environment, Dependencies & AlphaFold Parameters
#@markdown This cell performs the main setup following the recipe from the `CamelCaseCam/alphaflow-but-it-works` fork:
#@markdown 1. Installs Miniconda (if needed).
#@markdown 2. Creates Conda environment (if needed).
#@markdown 3. Clones code repositories (if needed).
#@markdown 4. Installs/Updates Python dependencies via pip (runs if setup marker is missing).
#@markdown 5. Downloads/Extracts AlphaFold parameters (if needed).
#@markdown ---
#@markdown **Options:**
force_setup_rerun = False #@param {type:"boolean"}
#@markdown ---
#@markdown **Note:** Setup can take 5+ minutes on the first run. If the marker file `.setup_complete` exists, most steps will be skipped unless `force_setup_rerun` is checked. Pip installs will run if the marker is missing, even if the Conda env exists.

import os
import sys
import time
import subprocess
import shutil
from IPython.display import display, HTML # For styled print functions

# --- Configuration ---
# Marker file written once the whole setup has finished successfully
SETUP_COMPLETE_MARKER = "/content/.setup_complete"

# Conda environment and pinned package versions
CONDA_ENV_NAME, PYTHON_VERSION = "alphaflow", "3.9"
NUMPY_VERSION, PANDAS_VERSION, TORCH_VERSION = "1.26.4", "1.5.3", "2.3.1"

# Code checkouts (OpenFold is cloned inside the AlphaFlow checkout)
ALPHAFLOW_CODE_DIR = "/content/alphaflow_code"
OPENFOLD_CODE_DIR = os.path.join(ALPHAFLOW_CODE_DIR, "openfold")

# AlphaFold parameters; the check file sits directly in the data dir
AF_PARAMS_DIR = "/content/alphafold/data"
AF_PARAMS_CHECK_FILE = os.path.join(AF_PARAMS_DIR, "params_model_1.npz")

# Locations inside the Conda environment
CONDA_PREFIX = "/usr/local/envs/" + CONDA_ENV_NAME
CONDA_BIN_PATH = CONDA_PREFIX + "/bin"
CONDA_PIP_PATH, CONDA_PYTHON_PATH = (CONDA_BIN_PATH + "/" + tool for tool in ("pip", "python"))

# --- Helper functions to display styled messages ---
def _emit_status(color, label, message):
    """Display `message` in `color`, prefixed with a bold `label:` tag."""
    display(HTML(f'<font color="{color}"><b>{label}:</b> {message}</font>'))

def print_info(message):
    """Show an informational message in blue."""
    _emit_status("blue", "INFO", message)

def print_warning(message):
    """Show a warning message in orange."""
    _emit_status("orange", "WARNING", message)

def print_error(message):
    """Show an error message in red."""
    _emit_status("red", "ERROR", message)

def print_success(message):
    """Show a success message in green."""
    _emit_status("green", "SUCCESS", message)

# --- Helper function to run shell commands ---
def run_shell_command(command, description, check_return_code=True):
    """Run `command` through the shell and report how it went.

    Output is captured rather than streamed; it is only printed when the
    command fails (STDERR first, then STDOUT).

    Args:
        command: Shell command line, executed with `shell=True`.
        description: Human-readable label used in the status messages.
        check_return_code: When True, a non-zero exit status raises.

    Returns:
        A `(success, process)` tuple: `success` is True for exit status 0 and
        `process` is the `subprocess.CompletedProcess`.

    Raises:
        RuntimeError: If the command fails and `check_return_code` is True.
    """
    print(f"--- Running: {description} ---")
    started_at = time.time()
    process = subprocess.run(command, shell=True, capture_output=True, text=True)
    elapsed_time = time.time() - started_at

    if process.returncode == 0:
        print_success(f"Finished: {description} (Took {elapsed_time:.2f} seconds)")
        return True, process

    # Failure path: surface everything the command printed to help debugging.
    print_error(f"Failed: {description} (Return Code: {process.returncode})")
    for stream_label, stream_text in (("STDERR", process.stderr), ("STDOUT", process.stdout)):
        print(f"--- {stream_label} ---")
        print(stream_text)
    if check_return_code:
        raise RuntimeError(f"Setup failed during: {description}")
    return False, process

# --- Idempotency Check ---
# A marker file records a completed setup so that re-running this cell is cheap.
if force_setup_rerun and os.path.exists(SETUP_COMPLETE_MARKER):
    print_info("`force_setup_rerun` is True. Removing marker file to force setup.")
    os.remove(SETUP_COMPLETE_MARKER)

if os.path.exists(SETUP_COMPLETE_MARKER):
    print_success("Setup marker file found. Assuming environment is already set up.")
    print_info("To force a re-run, check the `force_setup_rerun` box and run this cell again.")
    # Ensure conda env is in PATH for subsequent cells if kernel restarted
    if CONDA_BIN_PATH not in os.environ['PATH']:
        os.environ['PATH'] = f"{CONDA_BIN_PATH}:{os.environ['PATH']}"
    # Add code path just in case
    if ALPHAFLOW_CODE_DIR not in sys.path:
        sys.path.append(ALPHAFLOW_CODE_DIR)

else:
    print_info("Setup marker file not found or forced rerun. Starting setup checks...")
    setup_errors = [] # Collect errors (nothing in this cell appends to it)

    try:
        # --- Start Actual Setup ---
        # Every step goes through run_shell_command, which raises RuntimeError on
        # failure; the except block below reports it and the marker is not written.

        # Check Colab's CUDA version for PyTorch install
        print("--- Checking Colab Host CUDA Version ---")
        torch_cuda_suffix = "+cu118" # Default assumption
        nvcc_path = shutil.which('nvcc')
        if nvcc_path:
            nvcc_output = subprocess.run("nvcc --version | grep release", shell=True, capture_output=True, text=True).stdout
            if "release 12." in nvcc_output:
                torch_cuda_suffix = "+cu121"
                print_info("Host CUDA seems to be 12.x. Will install PyTorch with cu121.")
            elif "release 11." in nvcc_output:
                torch_cuda_suffix = "+cu118"
                print_info("Host CUDA seems to be 11.x. Will install PyTorch with cu118.")
            else:
                print_warning("Could not reliably determine host CUDA version. Assuming cu118.")
        else:
            print_warning("nvcc not found on host. Assuming cu118 for PyTorch.")

        # 1. Install Miniconda if not present
        if not os.path.exists("/usr/local/bin/conda"):
            run_shell_command(
                "wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && "
                "chmod +x Miniconda3-latest-Linux-x86_64.sh && "
                "./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local && "
                "rm Miniconda3-latest-Linux-x86_64.sh",
                "Install Miniconda"
            )
            os.environ['PATH'] = f"/usr/local/bin:{os.environ['PATH']}"
            py_version_long = f"{sys.version_info.major}.{sys.version_info.minor}"
            sys.path.append(f'/usr/local/lib/python{py_version_long}/site-packages/')
        else:
            print_info("Miniconda already installed.")
        # Ensure conda is in PATH
        if '/usr/local/bin' not in os.environ['PATH']:
            os.environ['PATH'] = f"/usr/local/bin:{os.environ['PATH']}"

        # Accept Anaconda channel ToS (required for non-interactive conda)
        run_shell_command(
            "conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main",
            "Accept ToS for pkgs/main",
            check_return_code=True,
        )
        run_shell_command(
            "conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r",
            "Accept ToS for pkgs/r",
            check_return_code=True,
        )

        # 2. Create Conda Environment if not present
        if not os.path.exists(CONDA_PREFIX):
            run_shell_command(f"conda create -n {CONDA_ENV_NAME} python={PYTHON_VERSION} -y", f"Create Conda env {CONDA_ENV_NAME}")
        else:
            print_info(f"Conda environment '{CONDA_ENV_NAME}' already exists.")
        # Ensure env bin is in PATH
        if CONDA_BIN_PATH not in os.environ['PATH']:
            os.environ['PATH'] = f"{CONDA_BIN_PATH}:{os.environ['PATH']}"

        # 3. Clone Repositories (only if not already present)
        if not os.path.exists(ALPHAFLOW_CODE_DIR):
            run_shell_command(f"git clone https://github.com/CamelCaseCam/alphaflow-but-it-works.git {ALPHAFLOW_CODE_DIR}", "Clone alphaflow-but-it-works")
        else:
            print_info(f"Directory '{ALPHAFLOW_CODE_DIR}' already exists, skipping clone.")

        if not os.path.exists(OPENFOLD_CODE_DIR):
            # Clone into the relative path "openfold" from inside the AlphaFlow checkout
            os.chdir(ALPHAFLOW_CODE_DIR)
            run_shell_command(f"git clone https://github.com/CamelCaseCam/openfold-but-fixed-for-alphaflow.git openfold", "Clone openfold-but-fixed-for-alphaflow")
            os.chdir('/content') # Go back
        else:
            print_info(f"Directory '{OPENFOLD_CODE_DIR}' already exists, skipping clone.")

        # 4. Install/Update Python Dependencies (Run these pip installs every time marker is missing)
        print_info("Installing/Updating Python dependencies into Conda environment...")
        os.chdir(ALPHAFLOW_CODE_DIR)
        run_shell_command(f"{CONDA_PIP_PATH} install --upgrade pip", "Upgrade pip") # Ensure pip is recent
        run_shell_command(f"{CONDA_PIP_PATH} install numpy=={NUMPY_VERSION} pandas=={PANDAS_VERSION}", "Install pinned numpy & pandas")
        # The wheel index name is the CUDA suffix without its leading "+", e.g. "cu121"
        torch_index_url = f"--index-url https://download.pytorch.org/whl/{torch_cuda_suffix[1:]}"
        run_shell_command(f"{CONDA_PIP_PATH} install torch=={TORCH_VERSION}{torch_cuda_suffix} {torch_index_url}", f"Install PyTorch {TORCH_VERSION} for host CUDA ({torch_cuda_suffix})")
        run_shell_command(f"{CONDA_PIP_PATH} install biopython==1.81 dm-tree==0.1.8 modelcif==0.7 ml-collections==0.1.1 scipy==1.10.1 absl-py einops", "Install core Python libs")
        # Zlib install via conda - check if needed if env exists? Assume it's needed if pip installs run.
        run_shell_command(f"conda install -n {CONDA_ENV_NAME} -c conda-forge zlib -y", f"Install/Verify Zlib in {CONDA_ENV_NAME}")
        # numpy is force-reinstalled at the pinned version after each batch of installs
        run_shell_command(f"{CONDA_PIP_PATH} install --force-reinstall numpy=={NUMPY_VERSION}", "Re-install pinned numpy")
        run_shell_command(f"{CONDA_PIP_PATH} install pytorch_lightning==2.0.4 fair-esm mdtraj==1.9.9 wandb requests", "Install lightning, esm, mdtraj, wandb, requests")
        run_shell_command(f"{CONDA_PIP_PATH} install --force-reinstall numpy=={NUMPY_VERSION}", "Re-install pinned numpy (2nd)")

        # --- IMPORTANT: Add build-essential for OpenFold compilation ---
        run_shell_command("apt-get update && apt-get install -y build-essential", "Install build-essential (system-wide dev tools)")
        run_shell_command(f"{CONDA_PIP_PATH} cache purge", "Clear pip cache") # Clear cache before OpenFold install

        os.chdir(OPENFOLD_CODE_DIR)
        run_shell_command(f"{CONDA_PIP_PATH} install --upgrade setuptools wheel", "Install/Upgrade setuptools and wheel") # Upgrade build deps
        run_shell_command(f"{CONDA_PIP_PATH} install --no-deps --config-settings editable_mode=compat --no-build-isolation -e .", "Install OpenFold (editable)")
        os.chdir(ALPHAFLOW_CODE_DIR)
        run_shell_command(f"{CONDA_PIP_PATH} install modelcif", "Install modelcif (again?)")
        run_shell_command(f"{CONDA_PIP_PATH} install --force-reinstall numpy=={NUMPY_VERSION}", "Re-install pinned numpy (final)")
        os.chdir('/content') # Return to base directory
        run_shell_command(f"{CONDA_PIP_PATH} install py3Dmol", "Install py3Dmol for visualization")

        # 5. Download AlphaFold Parameters (only if check file not already present)
        AF_PARAMS_TAR = "/content/alphafold_params_2022-12-06.tar"
        # The archive is downloaded under a ".part" name and renamed only after the
        # download command succeeds, so an interrupted download can never be
        # mistaken for a complete tarball on the next run.
        AF_PARAMS_TAR_PARTIAL = f"{AF_PARAMS_TAR}.part"
        AF_PARAMS_URL = "https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar"

        if not os.path.exists(AF_PARAMS_CHECK_FILE):
            print_info("AlphaFold parameter check file not found. Proceeding with download and extraction.")
            os.makedirs(AF_PARAMS_DIR, exist_ok=True) # Ensure directory exists

            if not os.path.exists(AF_PARAMS_TAR):
                print_info("Downloading AlphaFold parameters (~43 GB)...")
                finalize_download = f"mv {AF_PARAMS_TAR_PARTIAL} {AF_PARAMS_TAR}"
                if shutil.which("aria2c"):
                    # --allow-overwrite restarts cleanly over a stale .part file that
                    # has no aria2 control file instead of writing to a renamed copy.
                    run_shell_command(
                        f"aria2c -q -x 16 -s 16 -k 1M --allow-overwrite=true -d /content/ "
                        f"-o {os.path.basename(AF_PARAMS_TAR_PARTIAL)} {AF_PARAMS_URL} && {finalize_download}",
                        "Download AlphaFold params (aria2c)"
                    )
                else:
                    run_shell_command(
                        f"wget -q -O {AF_PARAMS_TAR_PARTIAL} {AF_PARAMS_URL} && {finalize_download}",
                        "Download AlphaFold params (wget)"
                    )
            else:
                print_info("Tar file already exists, proceeding to extraction.")

            print_info(f"Extracting ALL contents of {AF_PARAMS_TAR} into {AF_PARAMS_DIR}...")
            run_shell_command(f"tar -xf {AF_PARAMS_TAR} -C {AF_PARAMS_DIR}", "Extract AlphaFold params")

            print_info(f"Verifying extraction by checking for: {AF_PARAMS_CHECK_FILE}")
            if os.path.exists(AF_PARAMS_CHECK_FILE):
                print_success("Extraction successful. Check file found.")
                run_shell_command(f"rm {AF_PARAMS_TAR}", "Remove params tarball", check_return_code=False) # Don't fail setup if tar removal fails
            else:
                print_error(f"Extraction verification failed. Expected file not found: {AF_PARAMS_CHECK_FILE}")
                print_warning("Listing contents of extraction directory for debugging:")
                os.system(f"ls -lR {AF_PARAMS_DIR}") # List contents for debugging
                raise RuntimeError("Parameter extraction failed verification.")
        else:
            print_info(f"AlphaFold parameter check file '{os.path.basename(AF_PARAMS_CHECK_FILE)}' already exists, skipping download/extraction.")

        # --- Final Verification & Marker File ---
        print("\n--- Verifying environment ---")
        print("Python version:")
        run_shell_command(f"{CONDA_PYTHON_PATH} --version", "Check Python version")
        print("\nPyTorch CUDA status:")
        # The doubled braces survive the f-string as literal "{}" placeholders that
        # the Python process started by the shell fills in via str.format.
        run_shell_command(f"{CONDA_PYTHON_PATH} -c 'import torch; print(\"PyTorch version: {{}}\".format(torch.__version__)); print(\"CUDA available: {{}}\".format(torch.cuda.is_available())); print(\"CUDA version detected by Torch: {{}}\".format(torch.version.cuda))'", "Check PyTorch CUDA status")
        # Add alphaflow_code to python path (once, even if this cell is re-run)
        if ALPHAFLOW_CODE_DIR not in sys.path:
            sys.path.append(ALPHAFLOW_CODE_DIR)

        # Create marker file to indicate successful completion
        with open(SETUP_COMPLETE_MARKER, "w") as f:
            f.write(f"Setup completed successfully at {time.ctime()}")
        print_success("Setup complete. Marker file created.")

    except Exception as e:
        # Report instead of re-raising; the missing marker file signals the failure.
        print_error(f"An error occurred during setup: {e}")
        print_warning("Setup did not complete successfully. Marker file was not created. Please fix the error and re-run.")

    finally:
        # Ensure we are in the right directory even if errors occurred
        os.chdir('/content')
        # Final check for marker file
        if os.path.exists(SETUP_COMPLETE_MARKER):
            print_success("Environment 'alphaflow' should be ready.")
        else:
            print_error("Setup failed or was interrupted. Environment may not be functional.")

--- Checking Colab Host CUDA Version ---


--- Running: Upgrade pip ---


--- Running: Install pinned numpy & pandas ---


--- Running: Install PyTorch 2.3.1 for host CUDA (+cu121) ---


--- Running: Install core Python libs ---


--- Running: Install/Verify Zlib in alphaflow ---


--- Running: Re-install pinned numpy ---


--- Running: Install lightning, esm, mdtraj, wandb, requests ---


--- Running: Re-install pinned numpy (2nd) ---


--- Running: Install build-essential (system-wide dev tools) ---


--- Running: Clear pip cache ---


--- Running: Install/Upgrade setuptools and wheel ---


--- Running: Install OpenFold (editable) ---


--- Running: Install modelcif (again?) ---


--- Running: Re-install pinned numpy (final) ---


--- Running: Install py3Dmol for visualization ---


--- Running: Download AlphaFold params (wget) ---


--- Running: Extract AlphaFold params ---


--- Running: Remove params tarball ---



--- Verifying environment ---
Python version:
--- Running: Check Python version ---



PyTorch CUDA status:
--- Running: Check PyTorch CUDA status ---


In [5]:
#@title 3. Download Fine-tuned AlphaFlow/ESMFlow Model Weights
#@markdown Select the specific fine-tuned model weights you want to use for inference.
#@markdown The URLs are taken from the `bjing2016/alphaflow` repository README.

import os
import requests
#import GPy # Need GPy for ml_collections often
from google.colab import files
from IPython.display import display, HTML # For styled print functions

# --- Helper functions to display styled messages ---
def _emit_status(color, label, message):
    """Display `message` in `color`, prefixed with a bold `label:` tag."""
    display(HTML(f'<font color="{color}"><b>{label}:</b> {message}</font>'))

def print_info(message):
    """Show an informational message in blue."""
    _emit_status("blue", "INFO", message)

def print_warning(message):
    """Show a warning message in orange."""
    _emit_status("orange", "WARNING", message)

def print_error(message):
    """Show an error message in red."""
    _emit_status("red", "ERROR", message)

def print_success(message):
    """Show a success message in green."""
    _emit_status("green", "SUCCESS", message)

# Model weights dictionary (from README).
# Every checkpoint lives under the same bucket prefix, so only the file names are listed.
_WEIGHTS_BASE_URL = "https://storage.googleapis.com/alphaflow/params/"
_WEIGHT_FILES = (
    ("AlphaFlow-PDB_base", "alphaflow_pdb_base_202402.pt"),
    ("AlphaFlow-PDB_distilled", "alphaflow_pdb_distilled_202402.pt"),
    ("AlphaFlow-MD_base", "alphaflow_md_base_202402.pt"),
    ("AlphaFlow-MD_distilled", "alphaflow_md_distilled_202402.pt"),
    ("AlphaFlow-MD+Templates_base", "alphaflow_md_templates_base_202402.pt"),
    ("AlphaFlow-MD+Templates_distilled", "alphaflow_md_templates_distilled_202402.pt"),
    ("AlphaFlow-MD+Templates_12l-base", "alphaflow_12l_md_templates_base_202406.pt"),
    ("AlphaFlow-MD+Templates_12l-distilled", "alphaflow_12l_md_templates_distilled_202406.pt"),
    ("ESMFlow-PDB_base", "esmflow_pdb_base_202402.pt"),
    ("ESMFlow-PDB_distilled", "esmflow_pdb_distilled_202402.pt"),
    ("ESMFlow-MD_base", "esmflow_md_base_202402.pt"),
    ("ESMFlow-MD_distilled", "esmflow_md_distilled_202402.pt"),
    ("ESMFlow-MD+Templates_base", "esmflow_md_templates_base_202402.pt"),
    ("ESMFlow-MD+Templates_distilled", "esmflow_md_templates_distilled_202402.pt"),
)
# Maps the model name shown in the form to its download URL (insertion order preserved)
model_weights_urls = {
    model_name: _WEIGHTS_BASE_URL + weights_file
    for model_name, weights_file in _WEIGHT_FILES
}

# @markdown ---
# @markdown **Select Model:** Choose the desired fine-tuned model weights.
selected_model_name = "AlphaFlow-MD_base" # @param ["AlphaFlow-PDB_base", "AlphaFlow-PDB_distilled", "AlphaFlow-MD_base", "AlphaFlow-MD_distilled", "AlphaFlow-MD+Templates_base", "AlphaFlow-MD+Templates_distilled", "AlphaFlow-MD+Templates_12l-base", "AlphaFlow-MD+Templates_12l-distilled", "ESMFlow-PDB_base", "ESMFlow-PDB_distilled", "ESMFlow-MD_base", "ESMFlow-MD_distilled", "ESMFlow-MD+Templates_base", "ESMFlow-MD+Templates_distilled"]
# @markdown ---

model_weights_url = model_weights_urls[selected_model_name]
model_weights_filename = os.path.basename(model_weights_url)
model_weights_dir = "/content/model_weights"
model_weights_path = os.path.join(model_weights_dir, model_weights_filename)

os.makedirs(model_weights_dir, exist_ok=True)

# Download the weights if they don't exist
if not os.path.exists(model_weights_path):
    print_info(f"Downloading weights for {selected_model_name} from {model_weights_url}...")
    try:
        response = requests.get(model_weights_url, stream=True)
        response.raise_for_status() # Raise an exception for bad status codes
        with open(model_weights_path, "wb") as f:
            # Simple download without progress bar
            f.write(response.content)
        print_success(f"Weights downloaded successfully to {model_weights_path}")
    except Exception as e:
        print_error(f"Failed to download weights: {e}")
else:
    print_info(f"Model weights '{model_weights_filename}' already exist.")

# Store for later use
%store selected_model_name
%store model_weights_path

Stored 'selected_model_name' (str)
Stored 'model_weights_path' (str)


In [7]:
#@title 4. Configure Inference Run
#@markdown Define the parameters for your AlphaFlow/ESMFlow inference run.

import os
import re
import sys # Needed for sys.exit or raising errors effectively
from typing import List, Tuple
from IPython.display import display, HTML # For styled print functions

# --- Helper functions to display styled messages ---
def _emit_status(color, label, message):
    """Display `message` in `color`, prefixed with a bold `label:` tag."""
    display(HTML(f'<font color="{color}"><b>{label}:</b> {message}</font>'))

def print_info(message):
    """Show an informational message in blue."""
    _emit_status("blue", "INFO", message)

def print_warning(message):
    """Show a warning message in orange."""
    _emit_status("orange", "WARNING", message)

def print_error(message):
    """Show an error message in red."""
    _emit_status("red", "ERROR", message)

def print_success(message):
    """Show a success message in green."""
    _emit_status("green", "SUCCESS", message)

# --- Sequence Sanitization Function ---
def sanitize_and_validate_sequences(sequence_input: str) -> Tuple[List[str], List[str]]:
    """
    Sanitizes and validates protein sequences provided in a single string.

    Sequences can be single or multiple, separated by colons. Each
    colon-separated segment is processed as follows:
    1. Leading/trailing whitespace is removed; segments that are empty
       afterwards (e.g. from "SEQ1::SEQ2") are skipped silently.
    2. The segment is converted to uppercase.
    3. Every character that is not one of the 20 standard amino-acid letters
       (ACDEFGHIKLMNPQRSTVWY) is removed. This drops whitespace, digits and
       gap symbols, and also non-standard residue letters such as B, J, O, U,
       X and Z.
    A segment that has no characters left after step 3 produces an error.

    Args:
        sequence_input: The raw input string containing one or more sequences.
            `None` is treated like an empty string.

    Returns:
        A tuple containing:
        - list[str]: A list of sanitized and validated sequences.
        - list[str]: A list of error messages encountered during processing.
                     If this list is empty, all sequences were processed successfully.
    """
    # Empty, whitespace-only or missing input: report it and stop before
    # calling .split(), which would fail on None.
    if not sequence_input or not sequence_input.strip():
        return [], ["Sequence input is empty or contains only whitespace."]

    sanitized_sequences = []
    errors = []

    for i, raw_seq in enumerate(sequence_input.split(':')):
        stripped_seq = raw_seq.strip()
        if not stripped_seq:
            continue

        # Keep only the 20 standard residue letters (see docstring).
        cleaned_seq = re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', '', stripped_seq.upper())

        if not cleaned_seq:
            # Quote at most the first 30 characters of the offending segment.
            original_segment_for_error = raw_seq[:30] + ('...' if len(raw_seq) > 30 else '')
            errors.append(f"Sequence segment {i+1} ('{original_segment_for_error}') "
                          f"contained no valid amino acid characters after cleaning.")
            continue

        sanitized_sequences.append(cleaned_seq)

    # Nothing usable and no specific error yet: the input held only separators
    # and whitespace (e.g. " : : ").
    if not sanitized_sequences and not errors:
        errors.append("No valid sequences found after processing the input. Input might contain only separators or invalid characters.")

    return sanitized_sequences, errors

# @markdown ---
# @markdown **Job Name:** Base name used for output files and directories. Unique IDs will be generated for multiple sequences (e.g., jobname_1, jobname_2).
jobname = "1UBQ" # @param {type:"string"}
# @markdown ---
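# @markdown **Input Sequence(s):** Paste your protein sequence(s) here. Sequences are sanitized: they are converted to uppercase and every character that is not one of the 20 standard amino-acid letters (`ACDEFGHIKLMNPQRSTVWY`) is removed. This includes spaces, digits, gap symbols and non-standard residue codes such as `X`, `B`, `Z` or `U`.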
# @markdown - **Single sequence:** `MGYQV INTNSQ...` (Spaces/gaps will be removed)
# @markdown - **Multiple sequences:** Separate sequences with a colon (**`:`**). Example: `SEQWENCE: SE QWENCE`
# @markdown - **Do NOT use FASTA format here.** Provide only the raw amino acid strings.
sequence_input = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG" # @param {type:"string"}
# @markdown ---
# @markdown **MSA Input (Required for AlphaFlow, ignored for ESMFlow):**
# @markdown - **`Generate MSA (via ColabFold API)`:** Uses the ColabFold server (via `scripts/mmseqs_query.py`) to generate MSAs **for each sanitized sequence**. This typically uses UniRef30 + Environmental databases with filtering. Low-level MMseqs2 parameters are handled by the API.
# @markdown - **`Upload Custom A3M`:** You will be prompted to upload **one A3M file for each sequence** in the next cell. This bypasses MSA generation.
msa_mode = "Generate MSA (via ColabFold API)" # @param ["Generate MSA (via ColabFold API)", "Upload Custom A3M"]
# @markdown ---
# @markdown **Sampling Parameters:** (Applied to each sequence during AlphaFlow/ESMFlow inference)
num_samples = 10 # @param {type:"integer", min:1, max:250}
# @markdown `inference_steps`: Corresponds to `N` in the paper's Algorithm 2. Default is 10. Lower values might be faster but less accurate/diverse.
inference_steps = 10 # @param {type:"integer", min:1, max:50}
# @markdown `tmax`: Endpoint for flow integration (default 1.0). Truncating (e.g., 0.2) increases precision, reduces diversity (See paper Appendix B.1).
tmax = 1.0 #@param {type:"slider", min:0.1, max:1.0, step:0.1}
# @markdown ---
# @markdown **Template Input (Conditional - for Model Choice):**
# @markdown *This is only relevant if you selected an "MD+Templates" model in the previous cell.*
# @markdown - If checked **and** using a "+Templates" model, you will be prompted to upload **one template file (e.g., PDB/CIF) for each sequence** in Cell 5.
# @markdown - This setting controls **user-uploaded templates** used directly by the MD model, it does *not* enable PDB70 searching during MSA generation via the ColabFold API.
use_md_template = False # @param {type:"boolean"}
# @markdown ---
# @markdown **Performance Options (AlphaFlow/ESMFlow Inference):**
# @markdown *Recommended for PDB models for potentially better results.*
self_cond = True # @param {type:"boolean"}
resample_msa_per_sample = True # @param {type:"boolean"}
# @markdown *Required for distilled models.*
distilled_model_mode = False #@param {type:"boolean"}
# @markdown If using a distilled model, set `noisy_first = True` and `no_diffusion = True`.
noisy_first = False #@param {type:"boolean"}
no_diffusion = False #@param {type:"boolean"}
# @markdown ---

# --- Validate Inputs ---

# Jobname: fall back to a default when empty, then collapse every run of
# non-word characters into a single underscore.
if not jobname:
    jobname = "alphaflow_run"
    print_warning(f"Jobname empty, using default: {jobname}")
jobname = re.sub(r'\W+', '_', jobname)

# Sequences: sanitize, then abort the cell if anything was rejected.
sanitized_sequences, validation_errors = sanitize_and_validate_sequences(sequence_input)

if validation_errors:
    print_error("Sequence validation failed:")
    for error_msg in validation_errors:
        display(HTML(f'<font color="red">- {error_msg}</font>')) # Same styling as print_error, without the prefix
    # Raising stops execution of the cell
    raise ValueError("Invalid sequence input detected. Please check errors above and correct the 'sequence_input' field.")

# Reaching this point means every sequence is valid and sanitized.
num_sequences = len(sanitized_sequences)
if num_sequences == 1:
    print_info(f"Successfully sanitized and validated 1 sequence.")
elif num_sequences > 1:
    print_info(f"Successfully sanitized and validated {num_sequences} sequences.")


# Check template consistency (applies generally, template files handled per-sequence later)
# Ensure selected_model_name is loaded if it was set in a previous cell (e.g., Cell 3)
try:
    %store -r selected_model_name
except KeyError:
    print_warning("Variable 'selected_model_name' not found in store. Template/Distilled checks might be inaccurate if model was selected earlier.")
    selected_model_name = "" # Assign default empty string

is_template_model = "+Templates" in selected_model_name
if use_md_template and not is_template_model:
    print_warning(f"Template usage enabled (`use_md_template=True`), but selected model '{selected_model_name}' may not be a '+Templates' variant. Template uploads will be requested in Cell 5 but might be ignored by the model.")
if not use_md_template and is_template_model:
    print_warning(f"Template usage disabled (`use_md_template=False`), but selected model '{selected_model_name}' appears to be a '+Templates' variant. Model might expect templates, but none will be provided via upload.")

# Check distilled settings consistency
is_distilled_model = "distilled" in selected_model_name
if distilled_model_mode and not is_distilled_model:
     print_warning(f"Distilled mode options enabled, but selected model '{selected_model_name}' may not be a distilled variant.")
if not distilled_model_mode and is_distilled_model:
     print_warning(f"Distilled mode options disabled, but selected model '{selected_model_name}' appears to be a distilled variant. Ensure --noisy_first and --no_diffusion are set if needed.")
if distilled_model_mode and (not noisy_first or not no_diffusion):
     print_warning("Distilled mode selected, but --noisy_first and/or --no_diffusion are not checked. These are usually required for distilled models.")

# --- Store variables for later cells ---
# `%store` saves each variable so that later cells can restore it with
# `%store -r` (Cell 5 does this for jobname, msa_mode, use_md_template and
# selected_model_name).
%store jobname
# Store the list of sanitized sequences - this is likely more useful downstream
%store sanitized_sequences
# Raw, unsanitized input string as entered in the form
%store sequence_input
%store msa_mode
# Sampling parameters
%store num_samples
%store inference_steps
%store tmax
%store use_md_template
# Performance / distilled-model options
%store self_cond
%store resample_msa_per_sample
%store distilled_model_mode
%store noisy_first
%store no_diffusion
# Also store selected_model_name if it is defined in this session; Cell 6 uses it.
if 'selected_model_name' in locals() or 'selected_model_name' in globals():
     %store selected_model_name

print_success("Configuration set and sequences sanitized. Proceed to the next cell.")

Stored 'jobname' (str)
Stored 'sanitized_sequences' (list)
Stored 'sequence_input' (str)
Stored 'msa_mode' (str)
Stored 'num_samples' (int)
Stored 'inference_steps' (int)
Stored 'tmax' (float)
Stored 'use_md_template' (bool)
Stored 'self_cond' (bool)
Stored 'resample_msa_per_sample' (bool)
Stored 'distilled_model_mode' (bool)
Stored 'noisy_first' (bool)
Stored 'no_diffusion' (bool)
Stored 'selected_model_name' (str)


In [8]:
#@title 5. Prepare Input Files (Using scripts.mmseqs_query)

import os
import shutil
import pandas as pd
import re
import random
from google.colab import files
import time
import zipfile # Might still be needed if the script downloads a zip, but likely handled internally
import subprocess # Using subprocess for better control over script execution
from IPython.display import display, HTML # For styled print functions

# --- Helper function to display styled messages ---
def print_info(message):
    """Show `message` in the cell output as a blue line prefixed with a bold 'INFO:'."""
    markup = f'<font color="blue"><b>INFO:</b> {message}</font>'
    display(HTML(markup))


def print_warning(message):
    """Show `message` in the cell output as an orange line prefixed with a bold 'WARNING:'."""
    markup = f'<font color="orange"><b>WARNING:</b> {message}</font>'
    display(HTML(markup))


def print_error(message):
    """Show `message` in the cell output as a red line prefixed with a bold 'ERROR:'."""
    markup = f'<font color="red"><b>ERROR:</b> {message}</font>'
    display(HTML(markup))


def print_success(message):
    """Show `message` in the cell output as a green line prefixed with a bold 'SUCCESS:'."""
    markup = f'<font color="green"><b>SUCCESS:</b> {message}</font>'
    display(HTML(markup))

# --- Load variables from previous cell ---
%store -r jobname
#%store -r sequence_input # Raw input string (potentially multiple sequences separated by :)
%store -r msa_mode
%store -r use_md_template
%store -r selected_model_name
# Needed to determine mode for predict.py

# --- Define Paths and Create Directories ---
# **Minimal Change:** Use the variable names Cell 6 expects directly
alphaflow_repo_path = "/content/alphaflow_code" # Assuming the repo is cloned here
run_dir = f"/content/{jobname}"
output_dir = os.path.join(run_dir, "output") # predict.py outputs here (shared)
msa_dir = os.path.join(run_dir, "msas") # Base dir for MSAs (using name Cell 6 expects)
template_dir_path = None # Initialize to None (using name Cell 6 expects)
if use_md_template and "+Templates" in selected_model_name:
    template_dir_path = os.path.join(run_dir, "templates") # Assign path only if templates will be used

os.makedirs(run_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
# Clean potential previous run artifacts for base MSA/Template dirs
if os.path.exists(msa_dir): shutil.rmtree(msa_dir)
if template_dir_path and os.path.exists(template_dir_path): shutil.rmtree(template_dir_path) # Only remove if path was defined
os.makedirs(msa_dir, exist_ok=True) # Recreate base MSA dir
if template_dir_path: os.makedirs(template_dir_path, exist_ok=True) # Recreate template dir only if needed

print_info(f"Run directory: {run_dir}")
print_info(f"Output directory: {output_dir}")
print_info(f"Base MSA directory (msa_dir): {msa_dir}")
if template_dir_path:
    print_info(f"Base Template directory (template_dir_path): {template_dir_path}")
else:
    print_info("Template directory not created (template usage not indicated or model incompatible).")


# --- Process Sequences and Prepare Files ---
# Individual sequences are separated by ':' in the raw input string.
input_sequences = sequence_input.split(':')
num_sequences = len(input_sequences)
print_info(f"Found {num_sequences} sequence(s) to process.")

# One {'name', 'seqres'} record per accepted sequence; becomes the input CSV.
csv_data = []

# --- Select the predict.py mode from the model name ---
# Any model whose name contains "ESMFlow" runs in 'esmfold' mode; everything
# else runs in 'alphafold' mode, which is the only mode that prepares MSAs.
uses_esmflow = "ESMFlow" in selected_model_name
model_mode_for_predict = "esmfold" if uses_esmflow else "alphafold"
if uses_esmflow:
    print_info("ESMFlow model selected - MSA processing will be skipped.")


# For every ':'-separated input sequence: clean and validate it, record it for
# the input CSV, obtain its MSA (AlphaFlow mode only) and, when a template
# directory was set up, collect one uploaded template structure.
# Invalid sequences are skipped; MSA/template failures abort the cell.
for i, seq_raw in enumerate(input_sequences):
    seq_index = i + 1
    print(f"\n--- Processing Sequence {seq_index}/{num_sequences} ---")

    # --- Generate Unique ID and Clean Sequence ---
    fasta_id = f"{jobname}_{seq_index}" # Unique ID for this sequence
    seqres = seq_raw.strip().upper().replace(" ", "").replace("\n", "").replace("\t", "")

    # Validate sequence.
    # The FASTA-header check has to run BEFORE the character check: a string
    # starting with '>' can never consist solely of A-Z, so in the opposite order
    # the more specific message would be unreachable.
    if not seqres:
        print_error(f"Sequence {seq_index} is empty after cleaning. Skipping.")
        continue # Skip to the next sequence in the loop
    elif seqres.startswith(">"):
        print_error(f"Sequence {seq_index} input started with '>'. Please provide only raw sequences. Skipping.")
        continue
    elif not re.match(r"^[A-Z]+$", seqres):
        print_error(f"Sequence {seq_index} '{seqres[:20]}...' contains invalid characters. Skipping.")
        continue

    print_info(f"Sequence ID: {fasta_id}, Length: {len(seqres)}")
    # Add data for this sequence to our list for the final CSV
    csv_data.append({'name': fasta_id, 'seqres': seqres})

    # --- Prepare MSA (if AlphaFlow mode) ---
    if model_mode_for_predict == "alphafold":
        # Final expected location of the A3M file: <msa_dir>/<id>/a3m/<id>.a3m
        target_a3m_path = os.path.join(msa_dir, fasta_id, "a3m", f"{fasta_id}.a3m")
        # Parent directories must exist before a file is moved/generated there.
        os.makedirs(os.path.dirname(target_a3m_path), exist_ok=True)

        if msa_mode == "Upload Custom A3M":
            print(f"\nPlease upload the A3M MSA file for sequence {seq_index} (ID: {fasta_id}):")
            uploaded_msa = files.upload()
            if not uploaded_msa:
                print_error(f"No A3M file uploaded for sequence {fasta_id}. AlphaFlow requires MSAs.")
                raise RuntimeError(f"MSA upload failed for sequence {fasta_id}.")
            else:
                # Only the first uploaded file is used.
                uploaded_a3m_name = list(uploaded_msa.keys())[0]
                shutil.move(uploaded_a3m_name, target_a3m_path) # Move to the final target location
                print_success(f"A3M file for {fasta_id} saved to: {target_a3m_path}")

        elif msa_mode == "Generate MSA (via ColabFold API)":
            print_info(f"Generating MSA for sequence {fasta_id} using scripts.mmseqs_query...")

            # The script takes its queries from a CSV, so write a one-row CSV
            # for this sequence only; it is removed again in `finally`.
            temp_csv_path = os.path.join(run_dir, f"temp_query_{fasta_id}.csv")
            temp_df = pd.DataFrame([{'name': fasta_id, 'seqres': seqres}])
            temp_df.to_csv(temp_csv_path, index=False)
            print_info(f"Created temporary input CSV: {temp_csv_path}")

            # Run as a module (`-m scripts.mmseqs_query`) from the repository
            # root (see `cwd=` below), writing into the base MSA directory.
            cmd = [
                "python", "-m", "scripts.mmseqs_query",
                "--split", temp_csv_path,
                "--outdir", msa_dir
            ]
            print_info(f"Running command: {' '.join(cmd)}")

            try:
                # check=True turns a non-zero exit code into CalledProcessError;
                # the run is abandoned after 900 seconds.
                result = subprocess.run(cmd, capture_output=True, text=True, check=True, cwd=alphaflow_repo_path, timeout=900)
                print_info("Script stdout:")
                print(result.stdout)
                if result.stderr:
                    print_warning("Script stderr:")
                    print(result.stderr)

                # A zero exit code is not trusted on its own: the A3M file must
                # actually exist at the expected location.
                if os.path.exists(target_a3m_path):
                    print_success(f"Successfully generated MSA. Output A3M found at: {target_a3m_path}")
                else:
                    print_error(f"MSA generation script finished, but the expected output file was not found: {target_a3m_path}")
                    raise RuntimeError(f"MSA generation failed for {fasta_id}: Output file missing.")

            except subprocess.CalledProcessError as e:
                print_error(f"Error executing mmseqs_query script for {fasta_id}. Return code: {e.returncode}")
                print_error("Stdout:")
                print(e.stdout)
                print_error("Stderr:")
                print(e.stderr)
                # Chain the original exception so the traceback keeps the cause.
                raise RuntimeError(f"MSA generation failed for {fasta_id} due to script error.") from e
            except subprocess.TimeoutExpired as e:
                print_error(f"MSA generation script timed out for {fasta_id}.")
                raise RuntimeError(f"MSA generation timed out for {fasta_id}.") from e
            finally:
                # Cleanup Temporary CSV (runs on success and on every failure path)
                if os.path.exists(temp_csv_path):
                    os.remove(temp_csv_path)
                    print_info(f"Removed temporary input CSV: {temp_csv_path}")
        else:
            print_error(f"Invalid msa_mode selected: {msa_mode}")
            raise ValueError("Invalid MSA mode.")
    # --- End MSA Handling ---

    # --- Prepare Template (Conditional) ---
    # template_dir_path is only set when templates were requested and the model
    # name contains "+Templates", so this test covers both conditions.
    if template_dir_path:
        print(f"\nPlease upload the single PDB/mmCIF template structure for sequence {seq_index} (ID: {fasta_id}):")
        uploaded_template = files.upload()
        if not uploaded_template:
            print_error(f"Template usage was selected, but no template file was uploaded for sequence {fasta_id}.")
            raise RuntimeError(f"Template upload failed for sequence {fasta_id}.")
        else:
            # Only the first uploaded file is used.
            uploaded_template_name = list(uploaded_template.keys())[0]
            # Anything not ending in ".cif" is stored with a ".pdb" extension.
            template_file_extension = ".pdb"
            if uploaded_template_name.lower().endswith(".cif"):
                template_file_extension = ".cif"
            # Save template under the unique fasta_id directly in the base template dir
            target_template_name = f"{fasta_id}{template_file_extension}"
            template_path = os.path.join(template_dir_path, target_template_name)
            shutil.move(uploaded_template_name, template_path)
            print_success(f"Template file for {fasta_id} saved to: {template_path}")
    elif use_md_template: # If user checked the box but model is wrong or template_dir_path is None
        print_warning(f"Template usage selected, but the chosen model is not a '+Templates' variant or template dir wasn't created. Skipping template upload for sequence {fasta_id}.")
    # --- End Template Handling ---

# --- Finalize: Write CSV and Store Paths ---
if not csv_data:
     print_error("No valid sequences were processed. Cannot proceed.")
     raise ValueError("No sequences to run inference on.")

input_csv_path = os.path.join(run_dir, "input.csv")
input_df = pd.DataFrame(csv_data)
input_df.to_csv(input_csv_path, index=False)
print_success(f"\nInput CSV containing {len(csv_data)} sequence(s) saved to: {input_csv_path}")
print(input_df.head()) # Show the first few entries of the CSV

# --- Store paths for Cell 6 ---
# **Minimal Change:** Store using the exact names Cell 6 expects.
# No renaming needed at the end.
%store input_csv_path
%store msa_dir
# Storing the base MSA directory path under the name 'msa_dir'
%store template_dir_path
# Storing the base template directory path (or None) under the name 'template_dir_path'
%store output_dir
%store model_mode_for_predict

print_success("\nInput file preparation complete for all sequences. Proceed to run inference.")

# --- NO RENAMING NEEDED ---
# The variables msa_dir and template_dir_path were used directly and stored with the correct names.


--- Processing Sequence 1/1 ---






  0%|          | 0/150 [elapsed: 00:00 remaining: ?]
SUBMIT:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]
COMPLETE:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:00 remaining: 00:00]
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:01 remaining: 00:00]



     name                                             seqres
0  1UBQ_1  MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFA...
Stored 'input_csv_path' (str)
Stored 'msa_dir' (str)
Stored 'template_dir_path' (NoneType)
Stored 'output_dir' (str)
Stored 'model_mode_for_predict' (str)


In [9]:
#@title 6. Run AlphaFlow / ESMFlow Inference

import os
import subprocess
import time
from IPython.display import display, HTML # For styled print functions

# --- Helper function to display styled messages ---
def print_info(message):
    """Display `message` as blue HTML text with a bold 'INFO:' prefix."""
    html_line = HTML(f'<font color="blue"><b>INFO:</b> {message}</font>')
    display(html_line)


def print_warning(message):
    """Display `message` as orange HTML text with a bold 'WARNING:' prefix."""
    html_line = HTML(f'<font color="orange"><b>WARNING:</b> {message}</font>')
    display(html_line)


def print_error(message):
    """Display `message` as red HTML text with a bold 'ERROR:' prefix."""
    html_line = HTML(f'<font color="red"><b>ERROR:</b> {message}</font>')
    display(html_line)


def print_success(message):
    """Display `message` as green HTML text with a bold 'SUCCESS:' prefix."""
    html_line = HTML(f'<font color="green"><b>SUCCESS:</b> {message}</font>')
    display(html_line)

# --- Load variables ---
%store -r input_csv_path
%store -r msa_dir
%store -r template_dir_path
%store -r output_dir
%store -r model_weights_path
%store -r num_samples
%store -r inference_steps
%store -r tmax
%store -r self_cond
%store -r resample_msa_per_sample
%store -r distilled_model_mode
%store -r noisy_first
%store -r no_diffusion
%store -r model_mode_for_predict # 'alphafold' or 'esmfold'

# --- Construct the command for predict.py ---
# Navigate to the code directory
os.chdir('/content/alphaflow_code')

# Get the correct python executable from the conda env
conda_python_path = "/usr/local/envs/alphaflow/bin/python"

# Check if predict.py exists
predict_script_path = "/content/alphaflow_code/predict.py"
if not os.path.exists(predict_script_path):
    print_error(f"Inference script not found at {predict_script_path}. Check repository structure.")
    raise FileNotFoundError("predict.py not found.")

cmd = [
    conda_python_path, predict_script_path,
    f"--mode={model_mode_for_predict}",
    f"--input_csv={input_csv_path}",
    f"--weights={model_weights_path}",
    f"--samples={num_samples}",
    f"--outpdb={output_dir}",
    f"--steps={inference_steps}",
    f"--tmax={tmax}",
]

# Add AlphaFlow specific args
if model_mode_for_predict == "alphafold":
    if msa_dir and os.path.isdir(msa_dir): # Check if dir exists
        cmd.append(f"--msa_dir={msa_dir}")
    else:
        # If MSA generation failed or upload was skipped, this will error
        print_error("AlphaFlow mode selected, but MSA directory is missing or invalid.")
        raise FileNotFoundError("MSA directory not prepared correctly.")
    if self_cond:
        cmd.append("--self_cond")
    if resample_msa_per_sample:
        cmd.append("--resample")

# Add template arg if applicable
if template_dir_path and os.path.isdir(template_dir_path): # Check if dir exists
    cmd.append(f"--templates_dir={template_dir_path}")

# Add distilled args if applicable
# Check based on selected model name OR the checkbox
%store -r selected_model_name
is_distilled_model_selected = "distilled" in selected_model_name
if distilled_model_mode or is_distilled_model_selected:
     if noisy_first: cmd.append("--noisy_first")
     if no_diffusion: cmd.append("--no_diffusion")
     print_info("Distilled model flags added (based on checkbox or model name).")


# --- Execute the command ---
print("--- Running Inference ---")
command_str = " ".join(cmd)
print(f"Command: {command_str}\n")

start_time = time.time()
# Use subprocess.Popen for real-time output
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True)

# Print output line by line
while True:
    line = process.stdout.readline()
    if not line:
        break
    print(line.strip()) # Print the output from predict.py

process.wait() # Wait for the process to complete
elapsed_time = time.time() - start_time

if process.returncode == 0:
    print_success(f"\nInference finished successfully! (Took {elapsed_time:.2f} seconds)")
else:
    print_error(f"\nInference failed with return code {process.returncode}. Check logs above.")
    # Consider raising error: raise RuntimeError("Inference failed.")

# Change back to /content directory
os.chdir('/content')

no stored variable or alias #
no stored variable or alias alphafold
no stored variable or alias or
no stored variable or alias esmfold
--- Running Inference ---
Command: /usr/local/envs/alphaflow/bin/python /content/alphaflow_code/predict.py --mode=alphafold --input_csv=/content/1UBQ/input.csv --weights=/content/model_weights/alphaflow_md_base_202402.pt --samples=30 --outpdb=/content/1UBQ/output --steps=10 --tmax=1.0 --msa_dir=/content/1UBQ/msas --self_cond --resample

__import__("pkg_resources").declare_namespace(__name__)
2025-06-15 08:39:54,162 [f610dcf98224:16011] [INFO] Loading the model
2025-06-15 08:40:02,629 [f610dcf98224:16011] [INFO] Model has been loaded

0%|          | 0/30 [00:00<?, ?it/s]
3%|▎         | 1/30 [00:38<18:29, 38.27s/it]
7%|▋         | 2/30 [01:15<17:28, 37.44s/it]
10%|█         | 3/30 [01:53<17:01, 37.83s/it]
13%|█▎        | 4/30 [02:32<16:39, 38.46s/it]
17%|█▋        | 5/30 [03:11<16:03, 38.53s/it]
20%|██        | 6/30 [03:50<15:29, 38.71s/it]
23%|██▎       

In [11]:
#@title 7. Visualize Sample & Download Results
# @markdown  **Note:** When visualizing the results in PyMOL, don't forget to run `intra_fit <object>` (replace `<object>` with the name of your loaded object).

import sys

!{sys.executable} -m pip install py3Dmol
import py3Dmol

# (The cell title is declared once, on the first line of this cell. This cell displays only the first sample and then packages the results for download.)

import os
import glob
import shutil
from google.colab import files
import py3Dmol
from IPython.display import display, clear_output, HTML

# --- Helper function to display styled messages ---
def print_info(message):
    """Emit a blue 'INFO:' message into the cell output."""
    rendered = f'<font color="blue"><b>INFO:</b> {message}</font>'
    display(HTML(rendered))


def print_warning(message):
    """Emit an orange 'WARNING:' message into the cell output."""
    rendered = f'<font color="orange"><b>WARNING:</b> {message}</font>'
    display(HTML(rendered))


def print_error(message):
    """Emit a red 'ERROR:' message into the cell output."""
    rendered = f'<font color="red"><b>ERROR:</b> {message}</font>'
    display(HTML(rendered))


def print_success(message):
    """Emit a green 'SUCCESS:' message into the cell output."""
    rendered = f'<font color="green"><b>SUCCESS:</b> {message}</font>'
    display(HTML(rendered))

# --- Load variables ---
try:
    %store -r jobname
    %store -r output_dir
except KeyError:
    print_error("Could not load 'jobname' or 'output_dir'. Please re-run previous cells.")
    # jobname = "alphaflow_sample_1"
    # output_dir = f"/content/{jobname}/output"

# --- Find output PDB files ---
print(f"--- Searching for PDB files in {output_dir} ---")
# Sorted list of '*.pdb' files directly inside output_dir; empty when the
# directory does not exist.
if os.path.isdir(output_dir):
    pdb_files = sorted(glob.glob(os.path.join(output_dir, '*.pdb')))
else:
    pdb_files = []
num_pdb_files = len(pdb_files)

if num_pdb_files == 0:
    print_warning("No PDB output files found in the output directory.")
else:
    print_info(f"Found {num_pdb_files} PDB file(s).")

    # --- Function to display a PDB file ---
    def show_pdb_styled(pdb_path):
        """Render the PDB file at `pdb_path` in an 800x600 py3Dmol viewer.

        The structure is drawn as a cartoon using the 'chain' colour scheme.
        Any exception raised while reading or rendering is reported with
        print_error instead of being propagated.
        """
        try:
            viewer = py3Dmol.view(width=800, height=600)
            with open(pdb_path, 'r') as handle:
                structure_text = handle.read()
            viewer.addModel(structure_text, 'pdb')
            viewer.setStyle({'cartoon': {'colorscheme': 'chain'}})
            viewer.zoomTo()
            viewer.show()
            print_success(f"Displayed: {os.path.basename(pdb_path)}")
        except Exception as err:
            print_error(f"Error displaying PDB: {err}")

    # --- Show only the first sample ---
    print("\n--- Displaying First Sample ---")
    first_pdb = pdb_files[0]
    show_pdb_styled(first_pdb)

# --- Package results ---
print("\n--- Packaging results for download ---")
archive_base_name = f"/content/{jobname}_alphaflow_output"
output_archive_filename = f"{archive_base_name}.zip"

# Only package when the output directory exists and holds at least one entry.
has_results = os.path.isdir(output_dir) and len(os.listdir(output_dir)) > 0

if not has_results:
    print_warning(f"Output directory '{output_dir}' is empty or does not exist. Skipping packaging and download.")
else:
    try:
        # make_archive appends '.zip' to the base name, producing
        # output_archive_filename from the contents of output_dir.
        shutil.make_archive(archive_base_name, 'zip', output_dir)
        print_success(f"Output files zipped into: {output_archive_filename}")

        print("\nInitiating download of results archive...")
        files.download(output_archive_filename)
    except Exception as err:
        # Zipping or the browser download failed; report it and let the cell finish.
        print_error(f"Failed to package or download results: {err}")

print("\nProcessing complete.")

--- Searching for PDB files in /content/1UBQ/output ---



--- Displaying First Sample ---



--- Packaging results for download ---



Initiating download of results archive...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


Processing complete.
