# AlphaGenome Finetuning Tutorial

This notebook demonstrates how to finetune an AlphaGenome model on custom
genomic tracks.

**What you'll learn:**

-   How to define custom track metadata for finetuning
-   How to set up the data pipeline for training
-   How to initialize and configure the model with new output heads
-   How to run the training loop with JAX/Haiku
-   How to use the finetuned model for inference

### Prerequisites

Install the AlphaGenome Research package.

In [None]:
from IPython.display import clear_output
! PIP_NO_BINARY=pyBigWig pip install git+https://github.com/google-deepmind/alphagenome_research.git
clear_output()

## 1. Environment Setup

First, hide all accelerators from TensorFlow. TensorFlow is used only for data
loading, so it should not claim GPU or TPU memory that JAX needs for training.

In [None]:
import tensorflow as tf

# Hide local GPUs/TPUs from TensorFlow, which is only used for data loading.
tf.config.set_visible_devices([], 'GPU')
tf.config.set_visible_devices([], 'TPU')

## 2. Imports

Import the necessary libraries:

-   `alphagenome_research.finetuning` contains the finetuning utilities
-   `alphagenome_research.model` provides the model architecture and metadata
    handling
-   `alphagenome.data` provides genomic data utilities

In [None]:
import dataclasses
from datetime import datetime
import os
import pprint

from alphagenome.data import fold_intervals
from alphagenome.data import genome
from alphagenome.visualization import plot_components
from alphagenome_research.finetuning import dataset as dataset_lib
from alphagenome_research.finetuning import finetune
from alphagenome_research.model import dna_model
from alphagenome_research.model.metadata import metadata as metadata_lib
from etils import epath
import huggingface_hub
import haiku as hk
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P
import numpy as np
import optax
import orbax.checkpoint as ocp
import pandas as pd

## 3. Configuration

Define the key hyperparameters for finetuning:

-   `LEARNING_RATE`: Step size used by the optimizer.
-   `MODEL_VERSION`: Which pretrained fold to use (FOLD_0 through FOLD_3).
-   `NUM_TRAIN_STEPS`: Number of optimization steps to run.
-   `SEQUENCE_LENGTH`: Length of input DNA sequences (2**20 = 1,048,576 bp,
    roughly 1 Mb). Training requires at least 2**17.
-   `BATCH_SIZE`: Number of samples per device.
-   `ORGANISM`: Target organism for predictions. Hard-coded to human for now.
-   `SAVE_CHECKPOINT_DIR`: Directory under which finetuned checkpoints are
    saved.

In [None]:
LEARNING_RATE = 1e-3
NUM_TRAIN_STEPS = 1000
MODEL_VERSION = dna_model.ModelVersion.FOLD_0
SEQUENCE_LENGTH = int(2**20)
BATCH_SIZE = 1  # Per device
ORGANISM = dna_model.Organism.HOMO_SAPIENS
SAVE_CHECKPOINT_DIR = '/tmp/checkpoint'
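
The constraints described in this section can be checked before any data is loaded. The sketch below uses a hypothetical helper, `check_config`, which is not part of `alphagenome_research`. It encodes only the 2**17 minimum stated above and derives the global batch size from the per-device batch size and a device count supplied by the caller.

```python
MIN_TRAIN_SEQUENCE_LENGTH = 2**17  # Minimum stated in this tutorial.


def check_config(sequence_length: int, batch_size: int, num_devices: int) -> int:
  """Validates the settings and returns the global batch size.

  Hypothetical helper for illustration; not a library function.
  """
  if sequence_length < MIN_TRAIN_SEQUENCE_LENGTH:
    raise ValueError(
        f'sequence_length={sequence_length} is below the training minimum '
        f'of {MIN_TRAIN_SEQUENCE_LENGTH}.'
    )
  if batch_size < 1 or num_devices < 1:
    raise ValueError('batch_size and num_devices must both be at least 1.')
  # BATCH_SIZE is per device, so the iterator needs batch_size * num_devices.
  return batch_size * num_devices


global_batch_size = check_config(
    sequence_length=2**20, batch_size=1, num_devices=4
)
print(global_batch_size)
```

In the notebook, `num_devices` would be `jax.local_device_count()`, matching the batch size passed to the dataset iterator.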

## 4. Track Metadata

Define which genomic tracks to finetune on. Each track is described by the
following fields:

-   `name`: Human-readable name.
-   `output_type`: Output type of assay (e.g., `RNA_SEQ`, `DNASE`, `CHIP_TF`,
    `ATAC`). One of `dna_model.OutputType`.
-   `strand`: Strand orientation (`+`, `-`, or `.` for unstranded)
-   `nonzero_mean`: Optional mean of non-zero values (used for normalization).
-   `file_path`: Path to the BigWig file containing the track data.
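
The `nonzero_mean` field can be illustrated with a short NumPy sketch. It is an assumption about how one might compute the value, not the library's own routine: `nonzero_mean` below is a hypothetical helper, and the array stands in for coverage values that would normally be read from a BigWig file.

```python
import numpy as np


def nonzero_mean(values: np.ndarray) -> float:
  """Returns the mean over finite, non-zero entries (NaN if there are none)."""
  values = np.asarray(values, dtype=np.float64)
  # BigWig readers often report uncovered positions as NaN, so drop those too.
  mask = np.isfinite(values) & (values != 0)
  if not mask.any():
    return float('nan')
  return float(values[mask].mean())


# Placeholder coverage values, not real track data.
coverage = np.array([0.0, 2.0, 0.0, 4.0, np.nan])
print(nonzero_mean(coverage))  # 3.0
```

Averaging only over covered positions keeps sparse tracks, where most positions are zero, from ending up with a misleadingly small scale.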

In [None]:
! pushd /tmp && curl \
  -C - \
  -Z -O https://storage.googleapis.com/alphagenome/reference/encode/hg38/ENCFF018EZY.bigWig \
  -O https://storage.googleapis.com/alphagenome/reference/encode/hg38/ENCFF904TSK.bigWig \
  -O https://storage.googleapis.com/alphagenome/reference/encode/hg38/ENCFF218CLQ.bigWig && popd

TRACK_METADATA = pd.DataFrame(
    data=[
        [
            'RNA_SEQ',
            'UBERON:0000948 total RNA-seq',
            '+',
            '/tmp/ENCFF018EZY.bigWig',
        ],
        [
            'RNA_SEQ',
            'UBERON:0000948 total RNA-seq',
            '-',
            '/tmp/ENCFF904TSK.bigWig',
        ],
        [
            'DNASE',
            'EFO:0005337 DNase-seq',
            '.',
            '/tmp/ENCFF218CLQ.bigWig',
        ],
    ],
    columns=['output_type', 'name', 'strand', 'file_path'],
)
TRACK_METADATA

### Build Output Metadata

Convert the track DataFrame into an `AlphaGenomeOutputMetadata` object that
configures the model's output heads.

In [None]:
def build_output_metadata(
    track_metadata: pd.DataFrame,
) -> metadata_lib.AlphaGenomeOutputMetadata:
  """Builds AlphaGenomeOutputMetadata from the track metadata DataFrame.

  Args:
    track_metadata: A pandas DataFrame containing metadata for the tracks,
      including 'output_type', 'name', 'strand', and 'file_path'.

  Returns:
    An AlphaGenomeOutputMetadata holding one DataFrame per output type.

  Raises:
    ValueError: If a required column is missing or an output type is unknown.
  """
  required_cols = {'file_path', 'name', 'output_type', 'strand'}
  if not required_cols.issubset(track_metadata.columns):
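    # ValueError does not apply %-style formatting, so build the message here.
    missing = required_cols - set(track_metadata.columns)
    raise ValueError(
        f'track_metadata must have columns {sorted(required_cols)}. '
        f'Missing: {sorted(missing)}.'
    )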
  metadata = {}
  for output_type, df_group in track_metadata.groupby('output_type'):
    try:
      output_type = dna_model.OutputType[str(output_type)]
    except KeyError as e:
      raise ValueError(f'Unknown output_type: {output_type}') from e
    metadata[output_type.name.lower()] = df_group
  return metadata_lib.AlphaGenomeOutputMetadata(**metadata)


output_metadata = {
    dna_model.Organism.HOMO_SAPIENS: build_output_metadata(TRACK_METADATA)
}
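
The grouping step inside `build_output_metadata` can be tried on its own with pandas. In the sketch below, a plain dict stands in for `AlphaGenomeOutputMetadata`, and the track names are placeholders. It shows only how rows are split by `output_type` and keyed by the lower-cased type name.

```python
import pandas as pd

# Placeholder rows; the names are illustrative only.
tracks = pd.DataFrame(
    [
        ['RNA_SEQ', 'example RNA-seq', '+'],
        ['RNA_SEQ', 'example RNA-seq', '-'],
        ['DNASE', 'example DNase-seq', '.'],
    ],
    columns=['output_type', 'name', 'strand'],
)

# One DataFrame per output type, keyed like the metadata fields
# ('rna_seq', 'dnase', ...).
grouped = {
    str(output_type).lower(): group
    for output_type, group in tracks.groupby('output_type')
}

for key, group in grouped.items():
  print(key, len(group), list(group['strand']))
```

Each group keeps its stranded tracks together, so the two RNA-seq rows end up in the same DataFrame and the DNase row in another.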

## 5. Data Pipeline

Set up the training data iterator. This loads genomic intervals and
corresponding track values from the specified BigWig files.

In [None]:
ds_iter = finetune.get_dataset_iterator(
    batch_size=BATCH_SIZE * jax.local_device_count(),
    sequence_length=SEQUENCE_LENGTH,
    output_metadata=output_metadata[ORGANISM],
    organism=ORGANISM,
    model_version=MODEL_VERSION,
    subset=fold_intervals.Subset.TRAIN,
)

In [None]:
batch = next(ds_iter)
pprint.pprint(jax.tree.map(np.shape, batch))

## 6. Model Initialization

Load the pretrained AlphaGenome checkpoint and initialize new output heads for
the finetuning tracks.

In [None]:
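# Build the fold name outside the f-string: reusing the same quote character
# inside an f-string expression is only valid on Python 3.12+.
fold_name = MODEL_VERSION.name.lower().replace('_', '-')
repo = f'google/alphagenome-{fold_name}'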
checkpoint_path = huggingface_hub.snapshot_download(repo_id=repo)
checkpointer = ocp.StandardCheckpointer()
params_base, state_base = checkpointer.restore(checkpoint_path)

### Set Up Device Mesh for Data Parallelism

In [None]:
num_devices = jax.local_device_count()
devices = mesh_utils.create_device_mesh((num_devices,))
mesh = Mesh(devices, axis_names=('data',))
data_sharding = P('data')
replicated_sharding = P()
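
To see what the two partition specs mean in practice, the sketch below places a toy array on a one-axis mesh. It uses `jax.sharding.NamedSharding` explicitly rather than the `jax.set_mesh` context used in this notebook, and the array shapes are arbitrary placeholders. On a machine with a single device, the "sharded" array simply has one shard.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.local_devices())
mesh = Mesh(devices, axis_names=('data',))

# Two examples per device; the leading axis is split along the 'data' axis.
batch = jnp.zeros((len(devices) * 2, 8))
sharded = jax.device_put(batch, NamedSharding(mesh, P('data')))

# An empty PartitionSpec replicates the full array on every device.
weights = jnp.ones((8,))
replicated = jax.device_put(weights, NamedSharding(mesh, P()))

print('per-device batch shape:', sharded.addressable_shards[0].data.shape)
print('per-device weight shape:', replicated.addressable_shards[0].data.shape)
```

Each device receives a slice of the batch but a full copy of the weights, which is the data-parallel layout used for training here.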

### Initialize New Output Heads

Create the forward function configured for our finetuning tracks and initialize
the new head parameters.

In [None]:
forward_fn = finetune.get_forward_fn(output_metadata)
with jax.set_mesh(mesh):
  batch = jax.device_put(batch, data_sharding)
  params_ft, state_ft = jax.jit(
      forward_fn.init,
      in_shardings=(replicated_sharding, data_sharding),
      out_shardings=replicated_sharding,
  )(jax.random.PRNGKey(0), batch)

### Merge Pretrained Trunk with New Heads

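Perform weight surgery: keep the pretrained trunk parameters and replace the
head parameters with the newly initialized ones. The same cell then creates the
Adam optimizer and a jitted training step that keeps parameters, state, and
optimizer state replicated while sharding each batch across devices.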

In [None]:
params_ft_head = hk.data_structures.filter(
    lambda module_name, *_: 'head' in module_name, params_ft
)
params_base_no_head = hk.data_structures.filter(
    lambda module_name, *_: 'head' not in module_name, params_base
)
params = hk.data_structures.merge(params_base_no_head, params_ft_head)
state = state_base
optimizer = optax.adam(LEARNING_RATE)
opt_state = optimizer.init(params)
train_step = jax.jit(
    finetune.get_train_step(forward_fn.apply, optimizer),
    in_shardings=(
        replicated_sharding,
        replicated_sharding,
        replicated_sharding,
        data_sharding,
    ),
    out_shardings=(
        replicated_sharding,
        replicated_sharding,
        replicated_sharding,
        replicated_sharding,
    ),
)
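
The weight surgery can be illustrated without Haiku, since Haiku parameters are a two-level mapping from module name to parameter name to array. The sketch below uses plain dicts with hypothetical module names (`trunk/conv`, `head/old_tracks`, `head/new_tracks`). `filter_modules` and `merge` are simplified stand-ins for `hk.data_structures.filter` and `hk.data_structures.merge`, whose predicate also receives the parameter name and value.

```python
import numpy as np


def filter_modules(predicate, params):
  """Keeps only the modules whose name satisfies `predicate`."""
  return {
      name: dict(module_params)
      for name, module_params in params.items()
      if predicate(name)
  }


def merge(*trees):
  """Merges module dicts; later trees win when keys collide."""
  merged = {}
  for tree in trees:
    for name, module_params in tree.items():
      merged.setdefault(name, {}).update(module_params)
  return merged


# Hypothetical module names; the real checkpoint uses its own naming.
pretrained = {
    'trunk/conv': {'w': np.ones((3, 3))},
    'head/old_tracks': {'w': np.ones((3, 5))},
}
freshly_initialized = {
    'trunk/conv': {'w': np.zeros((3, 3))},
    'head/new_tracks': {'w': np.zeros((3, 2))},
}

trunk = filter_modules(lambda name: 'head' not in name, pretrained)
new_heads = filter_modules(lambda name: 'head' in name, freshly_initialized)
merged_params = merge(trunk, new_heads)
print(sorted(merged_params))
```

The merged tree keeps the pretrained trunk weights, drops the old head, and takes the new head from the fresh initialization.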

## 7. Training Loop

Set up checkpointing and run the finetuning training loop.

### Configure Checkpoint Directory

In [None]:
path_suffix = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_dir = epath.Path(SAVE_CHECKPOINT_DIR) / path_suffix
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir = str(checkpoint_dir)
print('Saving to', checkpoint_dir)

In [None]:
checkpointer = ocp.StandardCheckpointer()


def save(weights, idx):
  ckpt_path = os.path.join(checkpoint_dir, 'checkpoint_{:05d}'.format(idx))
  print(f'Saving checkpoint to {ckpt_path}')
  checkpointer.save(ckpt_path, weights)
  checkpointer.wait_until_finished()
  return ckpt_path

### Run Training

In [None]:
loss = []
step = 0
for step in range(NUM_TRAIN_STEPS):
  try:
    batch = next(ds_iter)
  except StopIteration:
    print('Dataset exhausted')
    break
  with jax.set_mesh(mesh):
    batch = jax.device_put(batch, data_sharding)
    params, state, opt_state, scalars = train_step(
        params, state, opt_state, batch
    )
  loss.append(scalars['loss'])
  if step % 10 == 1:
    print('loss', step, loss[-1])
ckpt_path = save((params, state), step + 1)
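
Per-step losses are noisy at a per-device batch size of 1, so a smoothed curve is easier to read than the raw printout. The sketch below is one way to do that with NumPy. `moving_average` is a hypothetical helper and the loss values are placeholders, not results from this notebook.

```python
import numpy as np


def moving_average(values, window: int) -> np.ndarray:
  """Returns the trailing mean over `window` consecutive values."""
  values = np.asarray(values, dtype=np.float64)
  if window < 1 or window > values.size:
    raise ValueError('window must be between 1 and the number of values.')
  kernel = np.ones(window) / window
  # 'valid' keeps only positions where the full window fits.
  return np.convolve(values, kernel, mode='valid')


# Placeholder losses for illustration only.
example_losses = [4.0, 2.0, 3.0, 1.0, 2.0]
print(moving_average(example_losses, window=2))
```

With the notebook's `loss` list, converting it first via `np.asarray(jax.device_get(loss))` would bring the values to the host before smoothing.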

## 8. Inference with Finetuned Model

Load the finetuned checkpoint into a `DnaModel` for inference and compare
predictions against ground truth.

In [None]:
# Load default organism settings but overwrite with fine-tuned output metadata.
default_settings_human = dna_model.default_organism_settings()[
    dna_model.Organism.HOMO_SAPIENS
]
settings_human_finetune = dataclasses.replace(
    default_settings_human,
    metadata=output_metadata[dna_model.Organism.HOMO_SAPIENS],
)
model = dna_model.create(
    ckpt_path,
    organism_settings={
        dna_model.Organism.HOMO_SAPIENS: settings_human_finetune
    },
)

### Select Test Interval

In [None]:
interval = genome.Interval(
    chromosome='chr21', start=46125238, end=46126738
).resize(SEQUENCE_LENGTH)
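
Resizing expands the 1,500 bp interval to the full model input length around its midpoint. The arithmetic can be sketched in plain Python. `resize_about_center` is a hypothetical helper, and its rounding for odd widths is an assumption that may differ from `genome.Interval.resize`.

```python
def resize_about_center(start: int, end: int, width: int) -> tuple[int, int]:
  """Returns (start, end) of a window of `width` centered on the midpoint.

  Illustrative only: no clipping to chromosome bounds is applied.
  """
  center = (start + end) // 2
  new_start = center - width // 2
  return new_start, new_start + width


new_start, new_end = resize_about_center(46125238, 46126738, 2**20)
print(new_start, new_end, new_end - new_start)
```

For the interval above this gives a window from 45,601,700 to 46,650,276, which is exactly 2**20 bp wide.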

### Generate Predictions

In [None]:
preds = model.predict_interval(
    interval,
    requested_outputs=[dna_model.OutputType.RNA_SEQ],
    ontology_terms=None,
)

### Load Ground Truth Tracks

In [None]:
true_tracks = dataset_lib.MultiTrackExtractor(
    output_metadata[ORGANISM], sequence_length=SEQUENCE_LENGTH
).extract(interval)

### Visualize Predictions vs Ground Truth

In [None]:
def compact_dict(**kwargs):
  return {k: v for k, v in kwargs.items() if v is not None}


def plot(*, interval, predictions, targets=None):
  if targets is None:
    colors = {'pred': 'black'}
  else:
    colors = {'pred': 'black', 'true': 'red'}
  fig = plot_components.plot(
      [
          plot_components.OverlaidTracks(
              tdata=compact_dict(
                  pred=predictions.rna_seq,
                  true=dataclasses.replace(
                      predictions.rna_seq,
                      values=targets['rna_seq'].astype(np.float32),
                  )
                  if targets is not None
                  else None,
              ),
              colors=colors,
          ),
      ],
      interval=interval.resize(int(2**11)),
  )
  return fig


_ = plot(predictions=preds, interval=interval, targets=true_tracks)
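
A numeric score can complement the overlay plot. The sketch below computes a per-track Pearson correlation with NumPy, assuming both inputs are arrays of shape `(positions, tracks)`. `pearson_per_track` is a hypothetical helper and the data is synthetic. In the notebook, the inputs would presumably be `preds.rna_seq.values` and `true_tracks['rna_seq']`.

```python
import numpy as np


def pearson_per_track(pred: np.ndarray, true: np.ndarray) -> np.ndarray:
  """Returns Pearson r for each track (column) of two equal-shaped arrays."""
  pred = np.asarray(pred, dtype=np.float64)
  true = np.asarray(true, dtype=np.float64)
  if pred.shape != true.shape:
    raise ValueError(f'Shape mismatch: {pred.shape} vs {true.shape}.')
  pred_centered = pred - pred.mean(axis=0)
  true_centered = true - true.mean(axis=0)
  denom = np.sqrt(
      (pred_centered**2).sum(axis=0) * (true_centered**2).sum(axis=0)
  )
  # A constant track has zero variance; its correlation is reported as NaN.
  with np.errstate(invalid='ignore', divide='ignore'):
    return (pred_centered * true_centered).sum(axis=0) / denom


# Synthetic stand-ins for ground truth and predictions (2 tracks, 2048 bp).
rng = np.random.default_rng(0)
true_values = rng.poisson(2.0, size=(2048, 2)).astype(np.float64)
pred_values = true_values + rng.normal(0.0, 0.5, size=true_values.shape)
print(pearson_per_track(pred_values, true_values))
```

A single held-out interval gives only a rough indication; a fuller evaluation would average over many intervals from the validation or test folds.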