# Hackable AlphaFold 3 without Docker or MSAs!

This Jupyter notebook provides a lightweight, hackable way to run AlphaFold 3. Experiment with structure prediction directly on your laptop or single GPU server, without the overhead of massive MSA databases or Docker. Just define your sequences and start predicting!

**This notebook allows you to:**
- Easily change input parameters and the molecular system definition.
- Step through the configuration and execution.
- Inspect intermediate variables by adding print statements or debugging code.
- Choose whether to run the data pipeline, inference, or both.

In [15]:
# Environment Variables Reminder
# IMPORTANT: Set necessary environment variables for JAX/XLA before running.
# These are usually set in your shell environment *before* launching Jupyter.
# If you need to set them for the current session (less ideal), you can use os.environ:

import os

# Example for NVIDIA A100/H100 (Compute Capability 8.0+):
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', "") + " --xla_gpu_enable_triton_gemm=false"
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "true"
os.environ['XLA_CLIENT_MEM_FRACTION'] = "0.95"

# Example for NVIDIA V100 (Compute Capability 7.x):
# os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', "") + " --xla_disable_hlo_passes=custom-kernel-fusion-rewriter"

# It's best to set these *before* JAX initializes its backend.
# Restart the kernel if you change these here and JAX has already been imported.

from IPython.display import display, Markdown
display(Markdown(f"**Current XLA_FLAGS:** `{os.environ.get('XLA_FLAGS')}`"))
display(Markdown(
    "**Note:** If you need to change `XLA_FLAGS` for your GPU, "
    "it's best to set it in your shell *before* starting Jupyter "
    "or restart the kernel after setting it with `os.environ`."
))


**Current XLA_FLAGS:** ` --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_triton_gemm=false`

**Note:** If you need to change `XLA_FLAGS` for your GPU, it's best to set it in your shell *before* starting Jupyter or restart the kernel after setting it with `os.environ`.

# Imports and Initial Setup

In [16]:
import os
import pathlib
import shutil
import multiprocessing
import datetime
import functools
import json
import textwrap
import time
import csv
import dataclasses
from typing import Sequence, Callable, overload

# AlphaFold 3 specific imports
from alphafold3.common import folding_input
from alphafold3.common import resources
from alphafold3.constants import chemical_components
import alphafold3.cpp
from alphafold3.data import featurisation
from alphafold3.data import pipeline
from alphafold3.jax.attention import attention
from alphafold3.model import features
from alphafold3.model import model
from alphafold3.model import params
from alphafold3.model import post_processing
from alphafold3.model.components import utils

import haiku as hk
import jax
from jax import numpy as jnp
import numpy as np

# --- JAX Configuration (Optional) ---
# Set JAX platform (e.g., 'gpu' or 'cpu').
# jax.config.update('jax_platform_name', 'gpu')
# Enable x64 for certain operations if needed, though AF3 typically uses float32.
# jax.config.update('jax_enable_x64', True)

print(f"JAX devices: {jax.devices()}")


JAX devices: [CudaDevice(id=0)]


# Configuration Parameters (Mimicking command-line flags)

In [17]:
# --- Input and Output Paths ---
# If providing input directly in this cell, `json_path_notebook` can be None.
# If loading from a directory of JSONs, set `input_dir_notebook`.
json_path_notebook = None  # Or set to a file path string e.g., "my_input.json"
input_dir_notebook = None  # Or set to a directory path string e.g., "my_af_inputs/"
output_dir_notebook = "af_output_notebook" # REQUIRED: Where all outputs will be saved.

# Path to the AlphaFold 3 model parameters
_DEFAULT_MODEL_DIR_NB = pathlib.Path.cwd() / 'models'
model_dir_notebook = _DEFAULT_MODEL_DIR_NB.as_posix()

# --- Control which stages to run ---
run_data_pipeline_notebook = True  # Set to False if MSAs/templates are in input_dict or JSON
run_inference_notebook = True    # Set to False to only run data pipeline

# --- Binary Paths (for data pipeline if run_data_pipeline_notebook is True) ---
jackhmmer_binary_path_notebook = shutil.which('jackhmmer')
nhmmer_binary_path_notebook = shutil.which('nhmmer')
hmmalign_binary_path_notebook = shutil.which('hmmalign')
hmmsearch_binary_path_notebook = shutil.which('hmmsearch')
hmmbuild_binary_path_notebook = shutil.which('hmmbuild')

# --- Database Paths (for data pipeline if run_data_pipeline_notebook is True) ---
_DEFAULT_DB_DIR_NB = pathlib.Path(os.environ.get('HOME', '.')) / 'public_databases'
db_dirs_notebook = [_DEFAULT_DB_DIR_NB.as_posix()] # List of paths to search for DBs

# These paths use a placeholder ${DB_DIR} which will be replaced by one of the paths in db_dirs_notebook.
small_bfd_database_path_notebook = '${DB_DIR}/bfd-first_non_consensus_sequences.fasta'
mgnify_database_path_notebook = '${DB_DIR}/mgy_clusters_2022_05.fa'
uniprot_cluster_annot_database_path_notebook = '${DB_DIR}/uniprot_all_2021_04.fa'
uniref90_database_path_notebook = '${DB_DIR}/uniref90_2022_05.fa'
ntrna_database_path_notebook = '${DB_DIR}/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta'
rfam_database_path_notebook = '${DB_DIR}/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta'
rna_central_database_path_notebook = '${DB_DIR}/rnacentral_active_seq_id_90_cov_80_linclust.fasta'
pdb_database_path_notebook = '${DB_DIR}/mmcif_files' # Directory of mmCIFs
seqres_database_path_notebook = '${DB_DIR}/pdb_seqres_2022_09_28.fasta'

# --- CPU Counts for MSA tools (for data pipeline) ---
jackhmmer_n_cpu_notebook = min(multiprocessing.cpu_count(), 8)
nhmmer_n_cpu_notebook = min(multiprocessing.cpu_count(), 8)

# --- Data Pipeline Configuration ---
resolve_msa_overlaps_notebook = True
max_template_date_notebook = '2021-09-30' # YYYY-MM-DD
conformer_max_iterations_notebook = None # Use RDKit default

# --- JAX Inference Performance Tuning ---
jax_compilation_cache_dir_notebook = None # Or path to a JAX cache directory e.g., "./jax_cache"
gpu_device_notebook = 0 # Index of the GPU to use
buckets_notebook = [256, 512, 768, 1024, 1280, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120] # List of ints
flash_attention_implementation_notebook: attention.Implementation = 'xla' # 'triton', 'cudnn', or 'xla'
num_recycles_notebook = 10
num_diffusion_samples_notebook = 5
num_seeds_notebook = None # If int, generates N seeds from the first seed in input. If None, uses seeds from input.

# --- Output Controls ---
save_embeddings_notebook = False
save_distogram_notebook = False
force_output_dir_notebook = False # If True, reuses output_dir_notebook even if non-empty.


# Input Definition (Define your input as a Python dictionary here)

In [18]:

# This example is the same as the 2PV7 example from the README
# Modify this dictionary to define your prediction target.
input_dict = {
  "name": "2PV7_notebook_example", # It's good to give a unique name
  "sequences": [
    {
      "protein": {
        # For multiple identical protein chains, provide a list of IDs: e.g., "id": ["A", "B"],
        "id": "A",
        "sequence": "GMRESYANENQFGFKTINSDIHKIVIVGGYGKLGGLFARYLRASGYPISILDREDWAVAESILANADVVIVSVPINLTLETIERLKPYLTENMLLADLTSVKREPLAKMLEVHTGAVLGLHPMFGADIASMAKQVVVRCDGRFPERYEWLLEQIQIWGAKIYQTNATEHDHNMTYIQALRHFSTFANGLHLSKQPINLANLLALSSPIYRLELAMIGRLFAQDAELYADIIMDKSENLAVIETLKQTYDEALTFFENNDRQGFIDAFHKVRDWFGDYSEQFLKESRQLLQQANDLKQG",
        # Optional: MSA-free, template-free prediction
        "unpairedMsa": "",
        "pairedMsa": "",
        "templates": [],
        # Optional: Provide precomputed MSAs or templates if run_data_pipeline_notebook is False
        # "unpairedMsa": "CONTENTS_OF_MSA_FILE_OR_PATH_TO_MSA_FILE",
        # "pairedMsa": "CONTENTS_OF_MSA_FILE_OR_PATH_TO_MSA_FILE",
        # "templates": [
        #   {
        #     "mmcif": "CONTENTS_OF_CIF_FILE_OR_PATH_TO_CIF_FILE",
        #     "queryIndices": [0, 1, 2],    # 0-indexed query residue indices matching template
        #     "templateIndices": [10, 11, 12] # 0-indexed template residue indices
        #   }
        # ]
        # Optional: PTMs
        # "modifications": [
        #    {"ptmType": "ACE", "ptmPosition": 1}, # N-terminal acetylation at residue 1
        #    {"ptmType": "PHO", "ptmPosition": 15} # Phosphorylation at residue 15
        # ]
      }
    }
    # --- Other examples of sequence types ---
    # {
    #   "ligand": {
    #     "id": "L",
    #     "ccdCodes": ["ATP"] # Can be a list for multi-component ligands
    #     # Or use SMILES:
    #     # "smiles": "Cc1ccccc1"
    #   }
    # },
    # {
    #   "rna": {
    #     "id": "R",
    #     "sequence": "AUGGCUAG",
    #     # "modifications": [
    #     #    {"modificationType": "PSU", "basePosition": 3} # Pseudouridine at base 3
    #     # ]
    #   }
    # },
    # {
    #   "dna": {
    #     "id": "D",
    #     "sequence": "ATGCGTTA",
    #     # "modifications": [...]
    #   }
    # }
  ],
  "modelSeeds": [1], # A list of one or more integer seeds. If num_seeds_notebook is set, only the first seed is used as a base.
  "dialect": "alphafold3",
  "version": folding_input.JSON_VERSION # Uses the latest version from the library
  # Optional: Define inter-chain bonds
  # "bondedAtomPairs": [
  #   [["A", 10, "SG"], ["L", 1, "C1"]] # Bond between CYS 10 (chain A) SG and Ligand (chain L) C1
  # ]
}

# Optional: Define a user-specific Chemical Component Dictionary (CCD)
# This should be a string in mmCIF format.
user_ccd_string_notebook = None
# Example:
# user_ccd_string_notebook = """
# data_MYLIGAND
# _chem_comp.id                                    MYLIGAND
# _chem_comp.name                                  "MY CUSTOM LIGAND"
# _chem_comp.type                                  NON-POLYMER
# _chem_comp.formula                               "C6 H12 O6"
# _chem_comp.mon_nstd_parent_comp_id               ?
# _chem_comp.pdbx_synonyms                         ?
# _chem_comp.formula_weight                        180.156
# # ... (atoms and bonds for MYLIGAND) ...
# """

# --- Convert Python dict to folding_input.Input object ---
# This section prepares the `current_fold_input` object which will be processed.
# If `json_path_notebook` or `input_dir_notebook` is set later, this `current_fold_input` might be overridden.
current_fold_input = None
if input_dict:
    try:
        input_json_string = json.dumps(input_dict)
        current_fold_input = folding_input.Input.from_json(input_json_string)
        if user_ccd_string_notebook:
            current_fold_input = dataclasses.replace(current_fold_input, user_ccd=user_ccd_string_notebook)

        # If num_seeds_notebook is set globally, adjust the current_fold_input
        if num_seeds_notebook is not None and current_fold_input:
            if len(current_fold_input.rng_seeds) != 1:
                raise ValueError(
                    "If num_seeds_notebook is set, the input_dict should contain only one seed in 'modelSeeds'."
                )
            print(f"Expanding input '{current_fold_input.name}' to {num_seeds_notebook} seeds based on global setting.")
            current_fold_input = current_fold_input.with_multiple_seeds(num_seeds_notebook)

        print(f"Successfully created input object: {current_fold_input.name}")
        print(f"  Number of chains: {len(current_fold_input.chains)}")
        print(f"  RNG Seeds: {current_fold_input.rng_seeds}")
    except Exception as e:
        print(f"Error creating folding_input.Input from input_dict: {e}")
        current_fold_input = None # Ensure it's None if creation fails
else:
    print("No input_dict provided in this cell. Ensure json_path_notebook or input_dir_notebook is set if you intend to run.")



Successfully created input object: 2PV7_notebook_example
  Number of chains: 1
  RNG Seeds: (1,)


# Helper Functions & Classes (Ported from run_alphafold.py)

In [19]:
# --- ModelRunner Class ---
class ModelRunner:
  """Helper class to run structure prediction stages."""
  def __init__(
      self,
      config: model.Model.Config,
      device: jax.Device,
      model_dir: pathlib.Path,
  ):
    self._model_config = config
    self._device = device
    self._model_dir = model_dir

  @functools.cached_property
  def model_params(self) -> hk.Params:
    """Loads model parameters from the model directory."""
    return params.get_model_haiku_params(model_dir=self._model_dir)

  @functools.cached_property
  def _model(
      self,
  ) -> Callable[[jnp.ndarray, features.BatchDict], model.ModelResult]:
    """Loads model parameters and returns a jitted model forward pass."""
    @hk.transform
    def forward_fn(batch):
      return model.Model(self._model_config)(batch)
    return functools.partial(
        jax.jit(forward_fn.apply, device=self._device), self.model_params
    )

  def run_inference(
      self, featurised_example: features.BatchDict, rng_key: jnp.ndarray
  ) -> model.ModelResult:
    """Computes a forward pass of the model on a featurised example."""
    featurised_example = jax.device_put(
        jax.tree_util.tree_map(
            jnp.asarray, utils.remove_invalidly_typed_feats(featurised_example)
        ),
        self._device,
    )
    result = self._model(rng_key, featurised_example)
    result = jax.tree.map(np.asarray, result)
    result = jax.tree.map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x,
        result,
    )
    result = dict(result)
    identifier = self.model_params['__meta__']['__identifier__'].tobytes()
    result['__identifier__'] = identifier
    return result

  def extract_inference_results(
      self,
      batch: features.BatchDict,
      result: model.ModelResult,
      target_name: str,
  ) -> list[model.InferenceResult]:
    """Extracts inference results from model outputs."""
    return list(
        model.Model.get_inference_result(
            batch=batch, result=result, target_name=target_name
        )
    )

  def extract_embeddings(
      self, result: model.ModelResult, num_tokens: int
  ) -> dict[str, np.ndarray] | None:
    """Extracts embeddings from model outputs."""
    embeddings = {}
    if 'single_embeddings' in result:
      embeddings['single_embeddings'] = result['single_embeddings'][
          :num_tokens
      ].astype(np.float16)
    if 'pair_embeddings' in result:
      embeddings['pair_embeddings'] = result['pair_embeddings'][
          :num_tokens, :num_tokens
      ].astype(np.float16)
    return embeddings or None

  def extract_distogram(
      self, result: model.ModelResult, num_tokens: int
  ) -> np.ndarray | None:
    """Extracts distogram from model outputs."""
    if 'distogram' not in result['distogram']:
      return None
    distogram = result['distogram']['distogram'][:num_tokens, :num_tokens, :]
    return distogram

# --- ResultsForSeed Dataclass ---
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ResultsForSeed:
  seed: int
  inference_results: Sequence[model.InferenceResult]
  full_fold_input: folding_input.Input
  embeddings: dict[str, np.ndarray] | None = None
  distogram: np.ndarray | None = None

# --- make_model_config_notebook ---
def make_model_config_notebook(
    *,
    flash_attention_implementation: attention.Implementation = 'triton',
    num_diffusion_samples: int = 5,
    num_recycles: int = 10,
    return_embeddings: bool = False,
    return_distogram: bool = False,
) -> model.Model.Config:
  config = model.Model.Config()
  config.global_config.flash_attention_implementation = flash_attention_implementation
  config.heads.diffusion.eval.num_samples = num_diffusion_samples
  config.num_recycles = num_recycles
  config.return_embeddings = return_embeddings
  config.return_distogram = return_distogram
  return config

# --- predict_structure_notebook ---
def predict_structure_notebook(
    fold_input: folding_input.Input,
    model_runner: ModelRunner,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
) -> Sequence[ResultsForSeed]:
  print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...')
  featurisation_start_time = time.time()
  ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd)
  featurised_examples = featurisation.featurise_input(
      fold_input=fold_input,
      buckets=buckets,
      ccd=ccd,
      verbose=True,
      ref_max_modified_date=ref_max_modified_date,
      conformer_max_iterations=conformer_max_iterations,
      resolve_msa_overlaps=resolve_msa_overlaps,
  )
  print(f'Featurising data took {time.time() - featurisation_start_time:.2f} seconds.')

  print(f'Running model inference for {len(fold_input.rng_seeds)} seed(s)...')
  all_inference_start_time = time.time()
  all_results_for_seeds = []
  for seed, example in zip(fold_input.rng_seeds, featurised_examples):
    print(f'  Running inference for seed {seed}...')
    inference_start_time = time.time()
    rng_key = jax.random.PRNGKey(seed)
    result = model_runner.run_inference(example, rng_key)
    print(f'  Inference for seed {seed} took {time.time() - inference_start_time:.2f} seconds.')

    print(f'  Extracting results for seed {seed}...')
    extract_time = time.time()
    inference_results_list = model_runner.extract_inference_results(
        batch=example, result=result, target_name=fold_input.name
    )
    num_tokens = len(inference_results_list[0].metadata['token_chain_ids'])
    embeddings = model_runner.extract_embeddings(result=result, num_tokens=num_tokens)
    distogram = model_runner.extract_distogram(result=result, num_tokens=num_tokens)
    print(f'  Extraction for seed {seed} took {time.time() - extract_time:.2f} seconds.')

    all_results_for_seeds.append(
        ResultsForSeed(
            seed=seed,
            inference_results=inference_results_list,
            full_fold_input=fold_input, # Store the (potentially data-pipelined) input
            embeddings=embeddings,
            distogram=distogram,
        )
    )
  print(f'Total model inference and extraction took {time.time() - all_inference_start_time:.2f} seconds.')
  return all_results_for_seeds

# --- write_fold_input_json_notebook ---
def write_fold_input_json_notebook(
    fold_input: folding_input.Input,
    output_dir: os.PathLike[str] | str,
) -> None:
  os.makedirs(output_dir, exist_ok=True)
  path = os.path.join(output_dir, f'{fold_input.sanitised_name()}_data.json')
  print(f'Writing model input JSON to {path}')
  with open(path, 'wt') as f:
    f.write(fold_input.to_json())

# --- write_outputs_notebook ---
def write_outputs_notebook(
    all_results_for_seeds: Sequence[ResultsForSeed],
    output_dir: os.PathLike[str] | str,
    job_name: str,
) -> None:
  ranking_scores = []
  max_ranking_score = None
  max_ranking_result = None
  output_terms_path = pathlib.Path(alphafold3.cpp.__file__).parent / 'OUTPUT_TERMS_OF_USE.md'
  output_terms = output_terms_path.read_text() if output_terms_path.exists() else "# AlphaFold 3 Output Terms of Use\n..."


  os.makedirs(output_dir, exist_ok=True)
  for results_for_seed_item in all_results_for_seeds:
    seed = results_for_seed_item.seed
    for sample_idx, result_item in enumerate(results_for_seed_item.inference_results):
      sample_dir = os.path.join(output_dir, f'seed-{seed}_sample-{sample_idx}')
      os.makedirs(sample_dir, exist_ok=True)
      post_processing.write_output(
          inference_result=result_item,
          output_dir=sample_dir,
          name=f'{job_name}_seed-{seed}_sample-{sample_idx}',
      )
      ranking_score = float(result_item.metadata['ranking_score'])
      ranking_scores.append((seed, sample_idx, ranking_score))
      if max_ranking_score is None or ranking_score > max_ranking_score:
        max_ranking_score = ranking_score
        max_ranking_result = result_item

    if embeddings_item := results_for_seed_item.embeddings:
      embeddings_dir = os.path.join(output_dir, f'seed-{seed}_embeddings')
      os.makedirs(embeddings_dir, exist_ok=True)
      post_processing.write_embeddings(
          embeddings=embeddings_item,
          output_dir=embeddings_dir,
          name=f'{job_name}_seed-{seed}',
      )

    if (distogram_item := results_for_seed_item.distogram) is not None:
      distogram_dir = os.path.join(output_dir, f'seed-{seed}_distogram')
      os.makedirs(distogram_dir, exist_ok=True)
      distogram_path = os.path.join(distogram_dir, f'{job_name}_seed-{seed}_distogram.npz')
      with open(distogram_path, 'wb') as f:
        np.savez_compressed(f, distogram=distogram_item.astype(np.float16))

  if max_ranking_result is not None:
    post_processing.write_output(
        inference_result=max_ranking_result,
        output_dir=output_dir,
        terms_of_use=output_terms,
        name=job_name,
    )
    with open(os.path.join(output_dir, f'{job_name}_ranking_scores.csv'), 'wt') as f:
      writer = csv.writer(f)
      writer.writerow(['seed', 'sample', 'ranking_score'])
      writer.writerows(ranking_scores)

# --- replace_db_dir_notebook ---
import string # Ensure string is imported
def replace_db_dir_notebook(path_with_db_dir: str, db_dirs: Sequence[str]) -> str:
  template = string.Template(path_with_db_dir)
  if 'DB_DIR' in template.get_identifiers():
    for db_dir in db_dirs:
      path = template.substitute(DB_DIR=db_dir)
      if os.path.exists(path):
        return path
    raise FileNotFoundError(f'{path_with_db_dir} with ${{DB_DIR}} not found in any of {db_dirs}.')
  if not os.path.exists(path_with_db_dir):
    raise FileNotFoundError(f'{path_with_db_dir} does not exist.')
  return path_with_db_dir

# --- process_fold_input_notebook (with overloads) ---
@overload
def process_fold_input_notebook(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    model_runner: None,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
    force_output_dir: bool = False,
) -> folding_input.Input: ...

@overload
def process_fold_input_notebook(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    model_runner: ModelRunner,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
    force_output_dir: bool = False,
) -> Sequence[ResultsForSeed]: ...

def process_fold_input_notebook(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    model_runner: ModelRunner | None,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
    ref_max_modified_date: datetime.date | None = None,
    conformer_max_iterations: int | None = None,
    resolve_msa_overlaps: bool = True,
    force_output_dir: bool = False,
) -> folding_input.Input | Sequence[ResultsForSeed]:
  print(f'\nRunning fold job {fold_input.name}...')
  if not fold_input.chains:
    raise ValueError('Fold input has no chains.')

  if (not force_output_dir and os.path.exists(output_dir) and os.listdir(output_dir)):
    new_output_dir = f'{output_dir}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}'
    print(f'Output will be written in {new_output_dir} since {output_dir} is non-empty.')
    output_dir = new_output_dir
  else:
    print(f'Output will be written in {output_dir}')

  processed_fold_input = fold_input # Keep a reference to the potentially modified input
  if data_pipeline_config is None:
    print('Skipping data pipeline...')
  else:
    print('Running data pipeline...')
    data_pipeline_start_time = time.time()
    processed_fold_input = pipeline.DataPipeline(data_pipeline_config).process(fold_input)
    print(f'Data pipeline took {time.time() - data_pipeline_start_time:.2f} seconds.')

  write_fold_input_json_notebook(processed_fold_input, output_dir) # Write the (possibly augmented) input

  if model_runner is None:
    print('Skipping model inference...')
    final_output = processed_fold_input
  else:
    all_inference_results = predict_structure_notebook(
        fold_input=processed_fold_input, # Use the processed input
        model_runner=model_runner,
        buckets=buckets,
        ref_max_modified_date=ref_max_modified_date,
        conformer_max_iterations=conformer_max_iterations,
        resolve_msa_overlaps=resolve_msa_overlaps,
    )
    print(f'Writing outputs for {len(processed_fold_input.rng_seeds)} seed(s)...')
    write_outputs_notebook(
      all_results_for_seeds=all_inference_results,
        output_dir=output_dir,
        job_name=processed_fold_input.sanitised_name(),
    )
    final_output = all_inference_results

  print(f'Fold job {processed_fold_input.name} done, output written to {os.path.abspath(output_dir)}\n')
  return final_output


# Main Execution Logic

In [20]:

# --- JAX Compilation Cache ---
if jax_compilation_cache_dir_notebook:
    jax.config.update('jax_compilation_cache_dir', jax_compilation_cache_dir_notebook)
    print(f"JAX compilation cache enabled at: {jax_compilation_cache_dir_notebook}")

# --- Validate Prerequisite Configurations ---
if not run_inference_notebook and not run_data_pipeline_notebook:
    raise ValueError("At least one of run_inference_notebook or run_data_pipeline_notebook must be True.")

# --- Create Output Directory ---
# The actual job-specific output directory will be a subdirectory of this.
# This global output_dir_notebook is where all job folders will reside.
try:
    os.makedirs(output_dir_notebook, exist_ok=True)
    print(f"Global output directory: {os.path.abspath(output_dir_notebook)}")
except OSError as e:
    print(f"Failed to create global output directory {output_dir_notebook}: {e}")
    raise

# --- Setup Data Pipeline Config (if running data pipeline) ---
data_pipeline_config_obj = None
if run_data_pipeline_notebook:
    print("Setting up Data Pipeline Configuration...")
    max_template_date_obj = datetime.date.fromisoformat(max_template_date_notebook)
    expand_path = lambda p: replace_db_dir_notebook(p, db_dirs_notebook)
    try:
        data_pipeline_config_obj = pipeline.DataPipelineConfig(
            jackhmmer_binary_path=jackhmmer_binary_path_notebook,
            nhmmer_binary_path=nhmmer_binary_path_notebook,
            hmmalign_binary_path=hmmalign_binary_path_notebook,
            hmmsearch_binary_path=hmmsearch_binary_path_notebook,
            hmmbuild_binary_path=hmmbuild_binary_path_notebook,
            small_bfd_database_path=expand_path(small_bfd_database_path_notebook),
            mgnify_database_path=expand_path(mgnify_database_path_notebook),
            uniprot_cluster_annot_database_path=expand_path(uniprot_cluster_annot_database_path_notebook),
            uniref90_database_path=expand_path(uniref90_database_path_notebook),
            ntrna_database_path=expand_path(ntrna_database_path_notebook),
            rfam_database_path=expand_path(rfam_database_path_notebook),
            rna_central_database_path=expand_path(rna_central_database_path_notebook),
            pdb_database_path=expand_path(pdb_database_path_notebook),
            seqres_database_path=expand_path(seqres_database_path_notebook),
            jackhmmer_n_cpu=jackhmmer_n_cpu_notebook,
            nhmmer_n_cpu=nhmmer_n_cpu_notebook,
            max_template_date=max_template_date_obj,
        )
        print("Data Pipeline Configuration ready.")
    except FileNotFoundError as e:
        print(f"ERROR: Database or binary path not found: {e}")
        print("Please ensure all database and binary paths in Cell 3 are correct and accessible.")
        data_pipeline_config_obj = None # Prevent further execution if paths are wrong
        # raise # Optionally re-raise to stop execution

# --- Setup Model Runner (if running inference) ---
model_runner_obj = None
if run_inference_notebook:
    print("Setting up Model Runner...")
    local_devices = jax.local_devices()
    if not local_devices:
        raise RuntimeError("No JAX devices found. Ensure JAX is installed correctly for your hardware (CPU/GPU).")

    # Try to use GPU if available, otherwise CPU (with a warning)
    gpu_devices = [d for d in local_devices if d.platform.upper() == 'GPU']
    if gpu_devices:
        if gpu_device_notebook >= len(gpu_devices):
            print(f"Warning: GPU device index {gpu_device_notebook} out of range. Found {len(gpu_devices)} GPUs. Using GPU 0.")
            selected_gpu_idx = 0
        else:
            selected_gpu_idx = gpu_device_notebook
        selected_device = gpu_devices[selected_gpu_idx]
        print(f"Using GPU: {selected_device}")
    else:
        print("WARNING: No GPU detected by JAX. Inference will run on CPU and may be very slow or OOM.")
        cpu_devices = [d for d in local_devices if d.platform.upper() == 'CPU']
        if not cpu_devices:
             raise RuntimeError("No CPU devices found by JAX.")
        selected_device = cpu_devices[0] # Use the first CPU device
        print(f"Using CPU: {selected_device}")


    # GPU Compute Capability Check (informational)
    if selected_device.platform.upper() == 'GPU':
        try:
            compute_capability = float(selected_device.compute_capability)
            print(f"  GPU Compute Capability: {compute_capability}")
            if compute_capability < 6.0:
                display(Markdown("**WARNING:** AlphaFold 3 ideally requires GPU compute capability 6.0 or higher."))
            elif 7.0 <= compute_capability < 8.0:
                xla_flags = os.environ.get('XLA_FLAGS', "")
                required_flag = '--xla_disable_hlo_passes=custom-kernel-fusion-rewriter'
                if required_flag not in xla_flags:
                    display(Markdown(f"**WARNING:** For GPU compute capability 7.x, `XLA_FLAGS` should include `{required_flag}`."))
                if flash_attention_implementation_notebook != 'xla':
                    display(Markdown("**WARNING:** For GPU compute capability 7.x, `flash_attention_implementation_notebook` should be set to `'xla'`."))
        except Exception as e:
            print(f"  Could not determine GPU compute capability: {e}")

    print("Building model and loading parameters...")
    model_runner_obj = ModelRunner(
        config=make_model_config_notebook(
            flash_attention_implementation=flash_attention_implementation_notebook,
            num_diffusion_samples=num_diffusion_samples_notebook,
            num_recycles=num_recycles_notebook,
            return_embeddings=save_embeddings_notebook,
            return_distogram=save_distogram_notebook,
        ),
        device=selected_device,
        model_dir=pathlib.Path(model_dir_notebook),
    )
    try:
        _ = model_runner_obj.model_params # This triggers loading
        print("Model parameters loaded successfully.")
    except FileNotFoundError as e:
        print(f"ERROR: Model parameters not found at {model_dir_notebook}: {e}")
        print("Please ensure model_dir_notebook in Cell 3 points to the correct directory.")
        model_runner_obj = None # Prevent further execution
        # raise # Optionally re-raise

# --- Determine inputs to process ---
fold_inputs_to_process = []
if input_dir_notebook:
    print(f"Loading inputs from directory: {input_dir_notebook}")
    fold_inputs_to_process.extend(
        folding_input.load_fold_inputs_from_dir(pathlib.Path(input_dir_notebook))
    )
elif json_path_notebook:
    print(f"Loading input from JSON file: {json_path_notebook}")
    fold_inputs_to_process.extend(
        folding_input.load_fold_inputs_from_path(pathlib.Path(json_path_notebook))
    )
elif current_fold_input: # Input defined in Cell 4
    print(f"Using input defined in notebook cell: {current_fold_input.name}")
    fold_inputs_to_process.append(current_fold_input)

if not fold_inputs_to_process:
    if not (run_data_pipeline_notebook and data_pipeline_config_obj is None) and \
       not (run_inference_notebook and model_runner_obj is None): # Avoid error if setup failed
        raise ValueError(
            "No inputs to process. Define `input_dict` in Cell 4, or set `json_path_notebook` or `input_dir_notebook` in Cell 3."
        )
    else:
        print("Skipping processing due to earlier setup errors.")


# --- Display AlphaFold 3 Usage Notice ---
notice_md = textwrap.dedent(f"""\
    Running AlphaFold 3. Please note that standard AlphaFold 3 model
    parameters are only available under terms of use provided at
    https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.
    If you do not agree to these terms and are using AlphaFold 3 derived
    model parameters, cancel execution of AlphaFold 3 inference with
    CTRL-C (or stop button in Jupyter), and do not use the model parameters.
    """)
display(Markdown(notice_md))

# --- Process each input ---
num_processed_fold_inputs = 0
if (run_data_pipeline_notebook and data_pipeline_config_obj) or \
   (run_inference_notebook and model_runner_obj) or \
   (not run_data_pipeline_notebook and not run_inference_notebook): # Allow if both are false (just writes JSON)

    for f_input_to_process in fold_inputs_to_process:
        # Apply num_seeds_notebook if it was set globally and not already applied
        # This handles cases where input comes from file/dir and num_seeds_notebook is active
        if num_seeds_notebook is not None and f_input_to_process is not current_fold_input:
            if len(f_input_to_process.rng_seeds) != 1:
                 raise ValueError("If num_seeds_notebook is set, inputs from files/directories should also contain only one seed.")
            print(f"Expanding input '{f_input_to_process.name}' to {num_seeds_notebook} seeds based on global setting.")
            f_input_to_process = f_input_to_process.with_multiple_seeds(num_seeds_notebook)

        # Determine the specific output directory for this job, inside the global output_dir_notebook
        job_specific_output_dir = os.path.join(output_dir_notebook, f_input_to_process.sanitised_name())

        process_fold_input_notebook(
            fold_input=f_input_to_process,
            data_pipeline_config=data_pipeline_config_obj if run_data_pipeline_notebook else None,
            model_runner=model_runner_obj if run_inference_notebook else None,
            output_dir=job_specific_output_dir,
            buckets=tuple(buckets_notebook) if buckets_notebook else None,
            ref_max_modified_date=datetime.date.fromisoformat(max_template_date_notebook) if run_inference_notebook else None,
            conformer_max_iterations=conformer_max_iterations_notebook if run_inference_notebook else None,
            resolve_msa_overlaps=resolve_msa_overlaps_notebook if run_data_pipeline_notebook else True,
            force_output_dir=force_output_dir_notebook,
        )
        num_processed_fold_inputs += 1
else:
    print("Skipping main processing loop due to errors in Data Pipeline or Model Runner setup.")


if num_processed_fold_inputs > 0:
    display(Markdown(f"### ✅ Done running {num_processed_fold_inputs} fold job(s)."))
    display(Markdown(f"Outputs saved to subdirectories within: `{os.path.abspath(output_dir_notebook)}`"))
elif not fold_inputs_to_process:
     display(Markdown("No input was processed."))
else:
    display(Markdown("Processing was skipped due to setup issues. Please check error messages above."))



Global output directory: /home/ckj24/alphafold3/af_output_notebook
Setting up Data Pipeline Configuration...
ERROR: Database or binary path not found: ${DB_DIR}/bfd-first_non_consensus_sequences.fasta with ${DB_DIR} not found in any of ['/home/ckj24/public_databases'].
Please ensure all database and binary paths in Cell 3 are correct and accessible.
Setting up Model Runner...
Using GPU: cuda:0
  GPU Compute Capability: 8.0
Building model and loading parameters...
Model parameters loaded successfully.
Using input defined in notebook cell: 2PV7_notebook_example


Running AlphaFold 3. Please note that standard AlphaFold 3 model
parameters are only available under terms of use provided at
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.
If you do not agree to these terms and are using AlphaFold 3 derived
model parameters, cancel execution of AlphaFold 3 inference with
CTRL-C (or stop button in Jupyter), and do not use the model parameters.



Running fold job 2PV7_notebook_example...
Output will be written in af_output_notebook/2pv7_notebook_example
Skipping data pipeline...
Writing model input JSON to af_output_notebook/2pv7_notebook_example/2pv7_notebook_example_data.json
Featurising data with 1 seed(s)...
Featurising data with seed 1.
Featurising data with seed 1 took 4.76 seconds.
Featurising data took 4.76 seconds.
Running model inference for 1 seed(s)...
  Running inference for seed 1...
  Inference for seed 1 took 72.25 seconds.
  Extracting results for seed 1...
  Extraction for seed 1 took 0.39 seconds.
Total model inference and extraction took 72.64 seconds.
Writing outputs for 1 seed(s)...
Fold job 2PV7_notebook_example done, output written to /home/ckj24/alphafold3/af_output_notebook/2pv7_notebook_example



### ✅ Done running 1 fold job(s).

Outputs saved to subdirectories within: `/home/ckj24/alphafold3/af_output_notebook`