## Installation and Imports


In [None]:
import gzip
import os
import re
import shutil

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import requests
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pyfaidx import Fasta

---
# Post-Trained Model

## 1. Load post-trained model


In [None]:
from nucleotide_transformer_v3.pretrained import get_posttrained_ntv3_model

# Build the post-trained NTv3 model; the loader also hands back the matching
# tokenizer and the model configuration object used throughout the notebook.
# NOTE(review): (6, 1) presumably selects the attention map that is later read
# back under the key "attention_map_layer_6_number_1" - confirm with the library.
posttrained_model, tokenizer, config = get_posttrained_ntv3_model(
    model_name="NTv3_100M_post",
    embeddings_layers_to_save=(6,),
    attention_maps_to_save=((6, 1),),
    use_bfloat16=True,  # bfloat16 lowers memory usage; the default is False
)

In [None]:
# Summarise the main hyper-parameters stored on the config object
print("Post-trained model configuration:")
for label, attr in (
    ("Embedding dim", "embed_dim"),
    ("Num layers", "num_layers"),
    ("Keep center fraction", "keep_target_center_fraction"),
):
    print(f"  - {label}: {getattr(config, attr)}")

# Species the model can be conditioned on
print(f"\nAvailable species: {posttrained_model.supported_species}")

# Names of the BED elements the model predicts (only the first 10 are shown)
bed_element_names = config.bed_elements_names
n_bed_elements = len(bed_element_names)
print(f"\nBED elements ({n_bed_elements}): {bed_element_names[:10]}...")

# Names of the bigwig tracks registered under the "human" species key
# (falls back to an empty list when the key is absent; only the first 5 are shown)
bigwig_names = config.bigwigs_per_species.get("human", [])
n_bigwig_tracks = len(bigwig_names)
print(f"\nBigwig tracks for human ({n_bigwig_tracks}): {bigwig_names[:5]}...")


## 2. Get genomic data
### Utility functions for fetching genomic data


In [None]:
# Pattern accepting strings made only of A/C/G/T/N in either case
DNA_RE = re.compile(r"^[ACGTNacgtn]+$")

def sanitize_dna(seq: str) -> str:
    """Return `seq` upper-cased, with every character outside ACGTN turned into N."""
    allowed = frozenset("ACGTN")
    # Upper-case first so lowercase bases are kept rather than masked.
    return "".join(base if base in allowed else "N" for base in seq.upper())

def download_ucsc_chrom_fasta(
    chrom: str, assembly: str, out_dir: str = None, timeout: float = 60.0
) -> str:
    """Download a single chromosome FASTA from UCSC and return local path.

    Args:
        chrom: Chromosome name used in the UCSC file name (e.g. "chr19").
        assembly: UCSC assembly name used in the URL (e.g. "hg38").
        out_dir: Directory for the downloaded files; defaults to "./<assembly>".
        timeout: Seconds passed to ``requests.get`` so a stalled connection
            cannot hang the notebook indefinitely.

    Returns:
        Path of the uncompressed ``<chrom>.fa`` file. An existing file at that
        path is reused without touching the network.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        OSError / EOFError: If the downloaded archive cannot be decompressed.
    """
    if out_dir is None:
        out_dir = f"./{assembly}"
    os.makedirs(out_dir, exist_ok=True)

    gz_path = os.path.join(out_dir, f"{chrom}.fa.gz")
    fa_path = os.path.join(out_dir, f"{chrom}.fa")

    if os.path.exists(fa_path):
        print(f"Using cached FASTA: {fa_path}")
        return fa_path

    # UCSC chromosome fasta URL
    url = f"https://hgdownload.soe.ucsc.edu/goldenPath/{assembly}/chromosomes/{chrom}.fa.gz"
    print(f"Downloading: {url}")

    # Write to temporary names and rename only once each step is complete, so an
    # interrupted download or failed decompression never leaves a truncated
    # file that the cache check above would later mistake for a valid FASTA.
    gz_tmp = gz_path + ".part"
    fa_tmp = fa_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(gz_tmp, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        os.replace(gz_tmp, gz_path)

        # Decompress in a streaming fashion instead of loading the whole
        # chromosome into memory at once.
        print("Decompressing...")
        with gzip.open(gz_path, "rb") as fin, open(fa_tmp, "wb") as fout:
            shutil.copyfileobj(fin, fout)
        os.replace(fa_tmp, fa_path)
    finally:
        # After a successful run both temporaries have already been renamed.
        for tmp_path in (gz_tmp, fa_tmp):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    print(f"Saved to: {fa_path}")
    return fa_path

def fetch_window_sequence(chrom: str, start: int, end: int, fasta_path: str) -> str:
    """Fetch [start, end) sequence from fasta.

    Args:
        chrom: Name of the record to read from the FASTA file.
        start: Start of the window, used as a Python slice bound.
        end: End of the window (exclusive), used as a Python slice bound.
        fasta_path: Path to the FASTA file opened with pyfaidx.

    Returns:
        The window passed through ``sanitize_dna`` (upper-cased, non-ACGTN -> N).
    """
    # Use the Fasta object as a context manager so the underlying file handle
    # is released as soon as the window has been read.
    with Fasta(fasta_path, rebuild=False) as fasta:
        seq = fasta[chrom][start:end].seq
    return sanitize_dna(seq)


In [None]:
# Genomic region to analyse - edit these four values to look at another locus
assembly, chrom = "hg38", "chr19"
start, end = 6_700_000, 6_831_072  # 131,072 bp window (multiple of 128)

# The window length must be a multiple of 128
window_len = end - start
assert window_len % 128 == 0, f"Window length ({window_len}) must be a multiple of 128"
print(f"Window: {chrom}:{start:,}-{end:,}\nWindow length: {window_len:,} bp")

# Get the chromosome FASTA (downloaded from UCSC unless already cached locally)
fasta_path = download_ucsc_chrom_fasta(chrom, assembly)

# Read the [start, end) window from the FASTA
seq = fetch_window_sequence(chrom, start, end, fasta_path)

print(f"\nSequence preview: {seq[:60]}...")
print(f"Sequence length: {len(seq):,} bp")
print(f"Valid DNA: {DNA_RE.match(seq) is not None}")

# Guard against a short read from the FASTA
assert len(seq) == window_len, "Fetched sequence length mismatch"


## 3. Prepare inputs for post-trained model


In [None]:
# Species condition and tokenized DNA (a batch containing the single window)
species = "human"
tokens = tokenizer.batch_np_tokenize([seq])
species_token = posttrained_model.encode_species(species)

# Report what will be fed to the model
print(f"Tokens shape: {tokens.shape}")
print(f"Species token: {species_token} with shape {species_token.shape}")

In [None]:
%%time
# Run post-trained model inference
posttrained_outs = posttrained_model(
    tokens=tokens,
    species_tokens=species_token
)
print(f"\nOutput keys: {posttrained_outs.keys()}")

In [None]:
# Report the shape of each prediction head that is present in the model output.
# The axis legends are plain strings: they contain no placeholders, so they do
# not need to be f-strings.
if "bigwig_tracks_logits" in posttrained_outs:
    bigwig_logits = posttrained_outs["bigwig_tracks_logits"]
    print(f"Bigwig tracks logits shape: {bigwig_logits.shape}")
    print("  (batch_size, sequence_length, num_tracks)")

if "bed_tracks_logits" in posttrained_outs:
    bed_logits = posttrained_outs["bed_tracks_logits"]
    print(f"\nBED tracks logits shape: {bed_logits.shape}")
    print("  (batch_size, sequence_length, num_elements, num_classes)")


## 4. Plot Attention Maps 

In [None]:
# Attention-map resolution: base pairs covered by one position of the map
num_downsamples = config.num_downsamples
base_pairs_per_token = 2 ** num_downsamples

# Cast to float32 for plotting (important when use_bfloat16=True)
attention_map = np.array(
    posttrained_outs["attention_map_layer_6_number_1"], dtype=np.float32
)

# Number of non-padding tokens in the first sequence of the batch
padding_mask = jnp.expand_dims(tokens != tokenizer.pad_token_id, axis=-1)
sequences_lengths = jnp.sum(padding_mask, axis=1)
seq_length = int(sequences_lengths[0, 0])

# Tick labels at start / middle / end (bp) and their positions on the map
x_labels = np.array([0, seq_length // 2, max(seq_length - 1, 0)])
y_labels = x_labels.copy()
x_ticks, y_ticks = (labels // base_pairs_per_token for labels in (x_labels, y_labels))

# Attention map of the first sequence.
# NOTE(review): the slice starts at index 1 and is bounded by `seq_length`,
# which the tick code above treats as base pairs while the map is indexed in
# downsampled tokens - confirm the offset and bound match the model's layout.
attn_slice = attention_map[0, 1 : (seq_length + 1), 1 : (seq_length + 1)]

# Clip the colour scale to the 5th-95th percentile so outliers do not wash out
# the map; other colormaps such as 'viridis', 'plasma', 'inferno' or 'magma'
# can be swapped in for 'hot'.
vmin, vmax = np.percentile(attn_slice, [5, 95])

fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(attn_slice, cmap='hot', vmin=vmin, vmax=vmax, aspect='auto')

# Colour bar in its own axis, attached to the right of the heat map
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax, orientation="vertical")

# Axis ticks are placed in map coordinates but labelled in base pairs
ax.set_xticks(x_ticks)
ax.set_yticks(y_ticks)
ax.set_xticklabels(x_labels, rotation=45)
ax.set_yticklabels(y_labels)
ax.set(xlabel="Position (bp)", ylabel="Position (bp)")

fig.tight_layout()

## 5. Visualization of Post-trained Predictions

The post-trained model outputs predictions only for the **middle 37.5%** of the input sequence (the fraction stored in `config.keep_target_center_fraction`).

We'll visualize:
1. Functional genomics tracks (bigwig predictions)
2. Genomic element annotations (BED predictions)


In [None]:
def plot_tracks(tracks, start, end, chrom, assembly, height=1.0):
    """Plot functional genomics tracks as a column of filled curves.

    Args:
        tracks: Mapping of panel title -> 1-D sequence of signal values.
        start: Genomic coordinate of the first value of every track.
        end: Genomic coordinate just past the last value (exclusive).
        chrom: Chromosome name, used only in the x-axis label.
        assembly: Assembly name, used only in the x-axis label.
        height: Height in inches given to each panel.

    Returns:
        The matplotlib Figure holding one panel per track.

    Raises:
        ValueError: If `tracks` is empty (there is nothing to lay out).
    """
    if not tracks:
        # Fail early with a readable message instead of an obscure layout error
        # raised from inside matplotlib for a zero-row grid.
        raise ValueError("plot_tracks() needs at least one track to plot")

    n_tracks = len(tracks)
    # squeeze=False always yields a 2-D axes array, so one track needs no
    # special casing.
    fig, axes = plt.subplots(
        n_tracks, 1, figsize=(20, height * n_tracks), sharex=True, squeeze=False
    )
    axes = axes[:, 0]

    for ax, (title, y) in zip(axes, tracks.items()):
        # The window is half-open, [start, end): keep `end` out of the grid so
        # the last value is not drawn on the first coordinate past the window.
        x = np.linspace(start, end, num=len(y), endpoint=False)
        ax.fill_between(x, y)
        ax.set_title(title)
        ax.set_ylabel("Signal")
        sns.despine(top=True, right=True, bottom=True, ax=ax)

    axes[-1].set_xlabel(f"{chrom}:{start:,}-{end:,} ({assembly})")
    # Lay out this figure explicitly rather than whichever one is "current".
    fig.tight_layout()
    return fig


In [None]:
# Tracks to plot, as (track_name, track_id) pairs in display order
tracks_to_plot = dict(
    [
        ("K562 RNA-seq", "ENCSR056HPM"),
        ("K562 DNAse", "ENCSR921NMD"),
        ("K562 H3k4me3", "ENCSR000DWD"),
        ("K562 CTCF", "ENCSR000AKO"),
        ("HepG2 RNA-seq", "ENCSR561FEE_P"),
        ("HepG2 DNAse", "ENCSR000EJV"),
        ("HepG2 H3k4me3", "ENCSR000AMP"),
        ("HepG2 CTCF", "ENCSR000BIE"),
    ]
)

# Keep the tracks whose id is known to the model and warn about the others
available_tracks = {
    name: track_id
    for name, track_id in tracks_to_plot.items()
    if track_id in bigwig_names
}
for name, track_id in tracks_to_plot.items():
    if track_id not in bigwig_names:
        print(f"Warning: Track '{name}' ({track_id}) not available")

print(f"\nAvailable tracks to plot: {len(available_tracks)}")

# Genomic elements to plot
elements_to_plot = [
    "protein_coding_gene",
    "exon",
    "intron",
    "splice_donor",
    "splice_acceptor",
]

# Split the requested elements into those the model predicts and those it lacks
available_elements, missing_elements = [], []
for elem in elements_to_plot:
    bucket = available_elements if elem in bed_element_names else missing_elements
    bucket.append(elem)

if len(missing_elements) > 0:
    print(f"Warning: Missing elements: {missing_elements}")

print(f"Available elements to plot: {available_elements}")

### Extract and plot predictions


In [None]:
# Convert outputs to float32 numpy arrays and drop the batch axis.
# The model is loaded with use_bfloat16=True, so the raw outputs are presumably
# bfloat16; casting here mirrors the attention-map cell and keeps the softmax
# and the plots below in a dtype numpy and matplotlib handle natively.
bigwig = np.array(posttrained_outs["bigwig_tracks_logits"], dtype=np.float32)[0]  # shape: (seq_len, num_tracks)
bed_logits_np = np.array(posttrained_outs["bed_tracks_logits"], dtype=np.float32)[0]  # shape: (seq_len, num_elements, 2)

print(f"Bigwig predictions shape: {bigwig.shape}")
print(f"BED predictions shape: {bed_logits_np.shape}")

# Calculate prediction window coordinates:
# the model predicts only for the central `keep_fraction` of the input window
keep_fraction = config.keep_target_center_fraction
prediction_start = start + int(window_len * (1 - keep_fraction) / 2)
prediction_end = prediction_start + int(window_len * keep_fraction)

print(f"\nPrediction window: {chrom}:{prediction_start:,}-{prediction_end:,}")
print(f"Prediction length: {prediction_end - prediction_start:,} bp")

# Extract functional tracks for selected bigwig tracks
bigwig_tracks = {}
for track_name, track_id in available_tracks.items():
    track_idx = bigwig_names.index(track_id)
    bigwig_tracks[track_name] = bigwig[:, track_idx]

# Convert BED logits to probabilities with a softmax over the class axis
# (the per-position maximum is subtracted first for numerical stability)
exp = np.exp(bed_logits_np - bed_logits_np.max(axis=-1, keepdims=True))
bed_probs = exp / exp.sum(axis=-1, keepdims=True)

# Extract positive class probabilities for selected elements
bed_tracks = {}
for element_name in available_elements:
    element_idx = bed_element_names.index(element_name)
    bed_tracks[element_name] = bed_probs[:, element_idx, 1]  # positive class probability

# One figure per prediction type; a group with nothing selected is skipped
# instead of crashing inside the plotting helper.
for tracks, title, panel_height in (
    (bigwig_tracks, "Bigwig Track Predictions", 1.2),
    (bed_tracks, "Genomic Element Predictions (probability)", 1.0),
):
    if not tracks:
        print(f"Skipping '{title}': nothing available to plot")
        continue
    fig = plot_tracks(tracks, prediction_start, prediction_end, chrom, assembly, height=panel_height)
    fig.suptitle(title, y=1.02, fontsize=14)
    plt.show()