In [1]:
# EpiBERT caqTL Prediction Example
# This notebook demonstrates how to use EpiBERT to predict chromatin accessibility
# quantitative trait loci (caqTL) effects - i.e., how genetic variants affect
# chromatin accessibility.

import os
import warnings
from pathlib import Path

# Quiet the noisiest warning sources so the notebook output stays readable.
# The filters are registered in the same order as before: tensorflow_addons
# UserWarnings first, then every FutureWarning.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module="tensorflow_addons",
)
warnings.filterwarnings("ignore", category=FutureWarning)

# Set before the TensorFlow import below so the C++ backend picks it up.
# Level "2" filters out INFO and WARNING messages, leaving errors only.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Core imports
import tensorflow as tf
from tensorflow.keras import mixed_precision
import numpy as np
import pandas as pd

# EpiBERT imports (using refactored modules)
from epibert.models import epibert_atac_pretrain as epibert
from epibert.analysis import interval_plotting_consolidated as utils

# Enumerate the GPUs TensorFlow can see. Uses the stable
# `tf.config.list_physical_devices` entry point (the same call the prediction
# cell uses to pick a device) rather than the legacy `experimental` alias.
gpus = tf.config.list_physical_devices('GPU')

# Apply the Keras 'mixed_bfloat16' policy globally before any model is built.
# NOTE(review): bfloat16 keeps only a few significant digits, so reported
# scores may look coarsely quantized -- confirm this is acceptable for scoring.
mixed_precision.set_global_policy('mixed_bfloat16')



In [2]:
# Configuration -- edit the paths below to match your environment.

# Reference genome and gene annotation.
fasta_file = '/home/jupyter/EpiBERT/example_usage/hg38_erccpatch.fa'
gtf_file = '/home/jupyter/data/hg38.refGene.gtf'  # available from UCSC or gs://epibert/example_data

# Trained EpiBERT checkpoints; obtain them from gs://epibert/models/pretrained.
checkpoint1_path = "model1/ckpt-45"  # epibert/models/pretrained/model1
checkpoint2_path = "model2/ckpt-45"  # epibert/models/pretrained/model2

# Example inputs. Create your own in data_processing.ipynb or download the
# provided examples from gs://epibert/example_data.
atac_file = "K562.adjust.bed.gz"  # ATAC-seq data (processed bedgraph format)
motif_activity = '/home/jupyter/EpiBERT/example_usage/ENCFF135AEX.motifs.tsv'

# FASTA extractor that is later handed to the input-preparation helper.
fasta_extractor = utils.FastaStringExtractor(fasta_file)

# Build two EpiBERT model instances, one for each checkpoint.
model1, model2 = (epibert.epibert_atac_pretrain() for _ in range(2))

# Wrap both models and their checkpoint paths in the ensemble helper.
epibert_model = utils.epibert_ensembl_model(
    model1,
    model2,
    checkpoint1_path,
    checkpoint2_path,
)
print("EpiBERT models loaded successfully")

Loaded checkpoints.
EpiBERT models loaded successfully


In [3]:
# Test variant: position chr12:9764948 with alternate allele T
# (a C -> T transition in the CD69 enhancer region).
variant = ('chr12:9764948', 'T')

# Build everything needed to score the variant. The helper is responsible for:
#   - extracting the reference and alternate sequences
#   - processing the ATAC-seq data for the region
#   - preparing the motif enrichment features
#   - creating the input tensors for the model
(
    inputs,
    inputs_mut,
    masked_atac,
    motif,
    target_atac,
    masked_atac_reshape,
    mask,
    mask_centered,
    interval_resize,
) = utils.return_inputs_caqtl_score(variant, atac_file, motif_activity, fasta_extractor)


In [4]:
# Variant effect prediction
# -------------------------

# Prefer the first GPU when TensorFlow reports one; otherwise run on the CPU.
if tf.config.list_physical_devices('GPU'):
    device = '/GPU:0'
else:
    device = '/CPU:0'
print(f"Using device: {device}")

with tf.device(device):
    # One call scores the reference and alternate inputs together; the result
    # unpacks into the wild-type track, the mutant track, and the caqTL score.
    output, output_mut, caqtl_score = epibert_model.ca_qtl_score(inputs, inputs_mut)

# Summarize what came back.
print("✓ Prediction completed")
print(f"✓ Wild-type prediction shape: {output.shape}")
print(f"✓ Mutant prediction shape: {output_mut.shape}")
print(f"✓ caqTL score: {caqtl_score:.2f}")

Using device: /GPU:0
✓ Prediction completed
✓ Wild-type prediction shape: (4092,)
✓ Mutant prediction shape: (4092,)
✓ caqTL score: -36.00


In [None]:
# Plot the predicted accessibility across the whole analyzed region

# Each entry maps a track label to a (values, colour) pair:
#   wild_type_atac - prediction for the reference sequence
#   mutant_atac    - prediction for the alternate sequence
#   difference     - wild-type minus mutant, scaled 10x so it is visible
accessibility_delta = output - output_mut
tracks = dict(
    wild_type_atac=(output, 'green'),
    mutant_atac=(output_mut, 'blue'),
    difference=(accessibility_delta * 10, 'red'),
)

# Draw the tracks together with gene annotations from the GTF file.
# NOTE(review): the meaning of the trailing 1500 is defined by
# plot_tracks_with_genes and is not visible here -- confirm before changing it.
utils.plot_tracks_with_genes(tracks, gtf_file, interval_resize, 1500)
print("✓ Full region plot generated")

In [None]:
# Zoomed view around the variant
# ------------------------------
print("Creating zoomed view around variant...")

# Slice 100 consecutive prediction bins (indices 2000-2099) out of each track.
# At 128 bp per bin this spans roughly 12.8 kb of genome, not 100 bp.
# The variant is expected at the center of the prediction window, and bins
# 2000-2100 bracket the midpoint of the 4092-bin track reported above.
zoom_start = 2000
zoom_end = 2100
wild_type_zoom = output[zoom_start:zoom_end]
mutant_zoom = output_mut[zoom_start:zoom_end]
difference_zoom = (output - output_mut)[zoom_start:zoom_end] * 10
tracks_zoomed = {
    'wild_type_atac': (wild_type_zoom, 'green'),
    'mutant_atac': (mutant_zoom, 'blue'),
    'difference': (difference_zoom, 'red'),
}

# Convert the bin indices into genomic coordinates for the plot.
bp_per_bin = 128  # each prediction position corresponds to 128 bp
# NOTE(review): the extra 2 bins presumably account for bins cropped from the
# start of the model output (4096 -> 4092) -- TODO confirm.
bin_offset = 2
zoomed_interval = (
    interval_resize[0],
    interval_resize[1] + bp_per_bin * (bin_offset + zoom_start),
    interval_resize[1] + bp_per_bin * (bin_offset + zoom_end),
)

utils.plot_tracks_with_genes(tracks_zoomed, gtf_file, zoomed_interval, 1500)
print(f"✓ Zoomed plot generated for region: {zoomed_interval[1]}-{zoomed_interval[2]}")

In [None]:
# Display the caqTL score and a plain-language reading of its sign
# ================================================================
print(f"caqTL Score: {caqtl_score:.2f}")
print("\nInterpretation:")
print("- Positive scores indicate increased accessibility")
print("- Negative scores indicate decreased accessibility")
print("- Larger absolute values indicate stronger effects")

# A score of exactly zero is neither a gain nor a loss. Previously it fell
# through to the "loss" wording, which misreports a neutral variant.
if caqtl_score > 0:
    effect_summary = "a gain of accessibility"
elif caqtl_score < 0:
    effect_summary = "a loss of accessibility"
else:
    effect_summary = "no change in accessibility"
print(f"- This variant shows {effect_summary}")

In [10]:
# Multiple Variant Scoring Example
# ==============================
print("Analyzing multiple variants...")

# Variants to test: different alleles at the same position plus nearby positions.
variant_list = [
    ('chr12:9764948', 'T'),  # Original variant
    ('chr12:9764948', 'G'),  # Different allele at same position
    ('chr12:9764948', 'A'),  # Another allele at same position
    ('chr12:9764958', 'C'),  # Nearby variant (+10bp)
    ('chr12:9764949', 'T'),  # Adjacent variant (+1bp)
    ('chr12:9764100', 'T'),  # Distant variant in same region
]

print(f"Testing {len(variant_list)} variants...")


def score_variants(variants):
    """Score each (position, alt_allele) variant and return {variant: caqTL score}.

    The work runs inside a function so that the per-variant intermediates
    (inputs, predictions, interval) stay local. At notebook top level the loop
    would overwrite `variant`, `inputs`, `output`, `output_mut`, `caqtl_score`
    and `interval_resize`, so re-running the plotting or interpretation cells
    afterwards would silently describe the last variant in the list.

    Args:
        variants: sequence of (position string, alternate allele) tuples.

    Returns:
        dict mapping each variant tuple to the score returned by
        `epibert_model.ca_qtl_score`.
    """
    scores = {}
    total = len(variants)
    for idx, candidate in enumerate(variants, 1):
        print(f"  Processing variant {idx}/{total}: {candidate[0]} -> {candidate[1]}")

        # Only the first two returned values (reference and alternate inputs)
        # are needed for scoring; the remaining outputs are discarded.
        ref_inputs, alt_inputs, *_unused = utils.return_inputs_caqtl_score(
            candidate, atac_file, motif_activity, fasta_extractor
        )

        # ca_qtl_score returns (wild-type prediction, mutant prediction, score).
        _wt_pred, _mut_pred, score = epibert_model.ca_qtl_score(ref_inputs, alt_inputs)
        scores[candidate] = score
    return scores


# Score each variant one at a time.
# TODO: figure out how to batch the predictions instead of looping.
with tf.device(device):
    variant_scores = score_variants(variant_list)

print("Scoring completed")

Analyzing multiple variants...
Testing 6 variants...
  Processing variant 1/6: chr12:9764948 -> T
  Processing variant 2/6: chr12:9764948 -> G
  Processing variant 3/6: chr12:9764948 -> A
  Processing variant 4/6: chr12:9764958 -> C
  Processing variant 5/6: chr12:9764949 -> T
  Processing variant 6/6: chr12:9764100 -> T
✓ Batch scoring completed


In [13]:
# Show the raw score dictionary, keyed by (position, alternate allele) tuples.
# The bare final expression lets Jupyter render the dict as the cell's output.
print("\nScores:")
variant_scores


Scores:


{('chr12:9764948', 'T'): -36.0,
 ('chr12:9764948', 'G'): -20.0,
 ('chr12:9764948', 'A'): -36.0,
 ('chr12:9764958', 'C'): 0.0,
 ('chr12:9764949', 'T'): 10.0,
 ('chr12:9764100', 'T'): -0.75}