## Installation and Imports


In [None]:
import random

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from flax import nnx
from mpl_toolkits.axes_grid1 import make_axes_locatable


---
# A. Pre-Trained Model


In [None]:
from nucleotide_transformer_v3.pretrained import get_pretrained_ntv3_model

# Fetch the pre-trained checkpoint together with its tokenizer and config.
# The saved layer / attention-map selections determine which extra entries
# ("embeddings_6", "attention_map_layer_6_number_1") are read from the model
# outputs later in this notebook.
pretrained_model, tokenizer, config = get_pretrained_ntv3_model(
    model_name="NTv3_100M_pre",
    embeddings_layers_to_save=(6,),
    attention_maps_to_save=((6, 1),),
    use_bfloat16=True,  # bfloat16 lowers memory usage; the default is False
)

# Assemble the configuration summary first, then emit it in a single print.
config_summary = [
    "\nModel config:",
    f"  - Embedding dim: {config.embed_dim}",
    f"  - Num layers: {config.num_layers}",
    f"  - Attention heads: {config.attention_heads}",
    f"  - Num transformer blocks: {len(pretrained_model.transformer_blocks)}",
    f"  - Num downsamples: {config.num_downsamples}",
]
print("\n".join(config_summary))


### Model Architecture Visualization

The NTv3 model uses a **U-Net-like architecture** with the following key components:

- **Downsampling tower**: Convolutional layers that reduce sequence length by factors of 2
- **Transformer blocks**: Self-attention layers that process compressed representations  
- **Upsampling tower**: Deconvolutional layers that restore original sequence resolution
- **Skip connections**: Connect downsampling and upsampling layers for information flow

Below, we visualize the complete model structure using Flax's `nnx.display()` function, which shows the hierarchical organization of all model components, their shapes, and parameter counts.


In [None]:
# Render the loaded model as a hierarchical tree with Flax NNX's built-in
# display helper.
nnx.display(pretrained_model)

## 1. Tokenize DNA sequences

NTv3 uses single-nucleotide tokenization. Sequence lengths should be multiples of `2^num_downsamples` (128 for 7 downsamples).


In [None]:
# Example DNA sequences: two random A/T/C/G strings of 2**15 = 32,768 nucleotides.
SEQUENCE_LENGTH = 2**15
NUM_SEQUENCES = 2

# The model expects lengths that are multiples of 2**num_downsamples; fail early
# here instead of deep inside the forward pass if the constant above is changed.
assert SEQUENCE_LENGTH % (2 ** config.num_downsamples) == 0, (
    "SEQUENCE_LENGTH must be a multiple of 2**num_downsamples"
)

# Dedicated, seeded generator so Restart & Run All reproduces the same sequences
# without touching the global `random` state.
sequence_rng = random.Random(0)
sequences = [
    "".join(sequence_rng.choices("ATCG", k=SEQUENCE_LENGTH))
    for _ in range(NUM_SEQUENCES)
]

# Tokenize the whole batch at once
tokens = tokenizer.batch_np_tokenize(sequences)

print(f"Input shape: {tokens.shape}")
print(f"Sequence length: {len(sequences[0])} nucleotides")
print(f"Pad token ID: {tokenizer.pad_token_id}")
print(f"Mask token ID: {tokenizer.mask_token_id}")


## 2. Run inference


In [None]:
%%time
# Forward pass on the tokenized batch; the %%time cell magic reports how long it takes.
# `outs` is indexed by string keys in the following cells
# (e.g. "embedding", "embeddings_6", "attention_map_layer_6_number_1").
outs = pretrained_model(tokens)

print(f"\nOutput keys: {outs.keys()}")


## 3. Retrieve embeddings


In [None]:
# Get embeddings from layer 6 (the layer requested via `embeddings_layers_to_save=(6,)`
# when the model was loaded)
embeddings = outs["embeddings_6"]
print(f"Embeddings shape: {embeddings.shape}")
# NOTE(review): the axis names printed below are the expected layout; the middle
# axis of the layer-6 embeddings may be the downsampled length rather than the
# nucleotide length — confirm against the printed shape.
print(f"  (batch_size, sequence_length, embed_dim)")

# Final embeddings after deconv tower
final_embeddings = outs["embedding"]
print(f"\nFinal embeddings shape: {final_embeddings.shape}")


### Compute mean embeddings

For sequence-level representations, we can compute mean embeddings (excluding padding).


In [None]:
# True wherever the token is a real nucleotide; the trailing axis lets the mask
# broadcast across the embedding dimension.
is_not_padding = tokens != tokenizer.pad_token_id
padding_mask = jnp.expand_dims(is_not_padding, axis=-1)

# Zero out padded positions so they do not contribute to the sum
masked_embeddings = final_embeddings * padding_mask

# Mean over the sequence axis = sum of kept positions / number of kept positions
sequences_lengths = padding_mask.sum(axis=1)
mean_embeddings = masked_embeddings.sum(axis=1) / sequences_lengths

print(f"Mean embeddings shape: {mean_embeddings.shape}")
print(f"  (batch_size, embed_dim)")


## 4. Plot attention maps

In [None]:
def plot_attention_map(ax, attention_map, seq_length, base_pairs_per_token, title):
    """Draw one attention map on ``ax`` with tick labels expressed in base pairs.

    Args:
        ax: Matplotlib axes to draw on.
        attention_map: 2D attention matrix for a single sequence, indexed in
            token units on both axes.
        seq_length: Number of non-padding nucleotides in the sequence.
        base_pairs_per_token: Number of base pairs covered by one token.
        title: Title shown above the panel.
    """
    # Keep only the tokens that cover real (non-padding) nucleotides. The slice is
    # in token units, consistent with the tick positions computed below.
    # NOTE(review): assumes the saved map has one row/column per downsampled token
    # and no special token at index 0 — confirm against the map's shape.
    num_tokens = -(-seq_length // base_pairs_per_token)  # ceiling division
    # Cast to float32 so matplotlib can render the map even if it is bfloat16.
    image = ax.imshow(
        np.asarray(attention_map[:num_tokens, :num_tokens], dtype=np.float32)
    )

    # Colorbar in a narrow axes attached to the right of the panel
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    ax.figure.colorbar(image, cax=cax, orientation="vertical")

    # Ticks at start, middle, end: labels in base pairs, positions in tokens
    bp_labels = np.array([0, seq_length // 2, max(seq_length - 1, 0)])
    token_ticks = bp_labels // base_pairs_per_token
    ax.set_xticks(token_ticks)
    ax.set_yticks(token_ticks)
    ax.set_xticklabels(bp_labels, rotation=45)
    ax.set_yticklabels(bp_labels)
    ax.set_xlabel("Position (bp)")
    ax.set_ylabel("Position (bp)")
    ax.set_title(title)


num_downsamples = config.num_downsamples
base_pairs_per_token = 2 ** num_downsamples
attention_maps = outs["attention_map_layer_6_number_1"]

# Number of non-padding nucleotides in each sequence of the batch
padding_mask = jnp.expand_dims(tokens != tokenizer.pad_token_id, axis=-1)
sequences_lengths = jnp.sum(padding_mask, axis=1)

# One panel per sequence in the batch; squeeze=False keeps `axes` 2D even for a
# single sequence.
num_sequences = tokens.shape[0]
fig, axes = plt.subplots(
    nrows=1, ncols=num_sequences, figsize=(7 * num_sequences, 6), squeeze=False
)

for seq_index, ax in enumerate(axes[0]):
    plot_attention_map(
        ax,
        attention_maps[seq_index],
        seq_length=int(sequences_lengths[seq_index][0]),
        base_pairs_per_token=base_pairs_per_token,
        title=f"Attention map (layer 6, number 1) - sequence {seq_index}",
    )

fig.tight_layout()