In [None]:
# EpiBERT Gradients and Gene Analysis Workflow
# This notebook demonstrates a complete workflow for analyzing model gradients and gene predictions
# using EpiBERT. It covers:
# 1. Model gradient analysis for single genomic intervals
# 2. Batch predictions for multiple genes
# 3. Parsing and evaluating predictions

import sys
import os
import subprocess
import warnings
from pathlib import Path
import random
import time
from datetime import datetime

# Make the notebook's parent directory importable so the project-local
# modules imported further down can be resolved.
notebook_dir = Path.cwd()
repo_root = notebook_dir.parent
sys.path.append(str(repo_root))

# Quieter output: set TensorFlow's C++ log level (done here, ahead of the
# `import tensorflow` below) and ignore UserWarnings from tensorflow_addons.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings("ignore", category=UserWarning, module='tensorflow_addons')

# Core imports
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from tensorflow import strings as tfs
from tensorflow.keras import mixed_precision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from scipy import stats
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import precision_recall_curve, auc, average_precision_score
import kipoiseq

import src.models.epibert_rampage_finetune as epibert
import training_utils_rampage_finetune as training_utils
import analysis.interval_and_plotting_utilities as utils

# Use the 'seaborn-v0_8' matplotlib style for every figure in the notebook.
plt.style.use('seaborn-v0_8')

# Enable memory growth on each GPU TensorFlow reports. With no GPU the loop
# body never runs; a RuntimeError from TensorFlow is reported, not raised.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print(f"GPU memory growth setup failed: {e}")

# Use mixed precision (bfloat16 policy) for better performance
mixed_precision.set_global_policy('mixed_bfloat16')



In [None]:
# Configuration - Update these paths for your environment
# =======================================================

# --- Model parameters --------------------------------------------------------
SEQUENCE_LENGTH = 524_288                   # passed to the model as input_length
RESOLUTION = 4
NUM_BINS = SEQUENCE_LENGTH // RESOLUTION    # 131072
OUTPUT_LENGTH = NUM_BINS // 32              # 4096
CROP_SIZE = 1600
# Forwarded unchanged to utils.return_all_inputs_simple.
# NOTE(review): the exact meaning of this index range is not visible here.
MASK_INDICES = '2041-2053'

# --- Input files (update these for your setup) -------------------------------
FASTA_FILE = '/home/jupyter/reference/hg38_erccpatch.fa'
ATAC_FILE = '/home/jupyter/datasets/ATAC/HG_K562.bed.gz'  # see data processing notebook
RNA_FILE = '/home/jupyter/datasets/ATAC/HG_K562.rampage.bed.gz'  # bed formatted RAMPAGE file
TF_FILE = '/home/jupyter/datasets/ATAC/HG_K562.tsv'
ENHANCER_FILE = '/home/jupyter/datasets/eg/hg38_eg.bed'

# --- Model checkpoint --------------------------------------------------------
CHECKPOINT_PATH = 'gs://genformer_europe_west_copy/524k/rampage_finetune/models/genformer_524k_LR1-5.0e-04_LR2-5.0e-04_C-512_640_640_768_896_1024_T-8_motif-True_9_m7o9qhwt/ckpt-16'

# --- Output locations (created on first run, reused afterwards) ---------------
TEMP_DIR = 'temp_files'
OUTPUT_DIR = 'output_gradients'
for _directory in (TEMP_DIR, OUTPUT_DIR):
    os.makedirs(_directory, exist_ok=True)

# --- Genes to analyse: name -> "chrom:start-end" interval --------------------
GENES_DICT = {
    # chr12
    "HNRNPA1": "chr12:54279939-54281439",
    "NFE2": "chr12:54300287-54301787",
    "COPZ1": "chr12:54324339-54325839",
    "ITGA5": "chr12:54418516-54420016",
    # chr19
    "WDR83OS": "chr19:12668901-12670401",
    "DHPS": "chr19:12681137-12682637",
    "C19orf43": "chr19:12734025-12735525",
    "JUNB": "chr19:12790745-12792245",
    "PRDX2": "chr19:12801160-12802660",
    "RNASEH2A": "chr19:12805863-12807363",
    # chr8 / chrX
    "MYC": "chr8:127735318-127736818",
    "GATA1": "chrX:48785823-48787323",
}

# Seed Python's, NumPy's and TensorFlow's random generators with one shared
# value for reproducibility.
SEED = 6
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)



In [None]:
# Initialize Model and Data Extractors

# Sequence extractor reading from the reference FASTA.
fasta_extractor = utils.FastaStringExtractor(FASTA_FILE)

# Device choice: first GPU when TensorFlow sees one, otherwise the CPU.
# The original note says the model is too large to run on CPU, so the CPU
# fallback is only a last resort.
has_gpu = bool(tf.config.list_physical_devices('GPU'))
device = '/GPU:0' if has_gpu else '/CPU:0'

# Constructor arguments collected in one place. output_length (4096) equals
# OUTPUT_LENGTH from the configuration cell.
model_kwargs = dict(
    kernel_transformation='relu_kernel_transformation',
    dropout_rate=0.20,
    pointwise_dropout_rate=0.10,
    input_length=SEQUENCE_LENGTH,
    output_length=4096,
    final_output_length=896,
    num_heads=8,
    numerical_stabilizer=1e-7,
    max_seq_length=4096,
    seed=SEED,
    norm=True,
    BN_momentum=0.90,
    normalize=True,
    use_rot_emb=True,
    num_transformer_layers=8,
    final_point_scale=6,
    filter_list_seq=[512, 640, 640, 768, 896, 1024],
    filter_list_atac=[32, 64],
    predict_atac=True,
)

# Build the model and restore the fine-tuned weights on the chosen device.
with tf.device(device):
    model = epibert.epibert_rampage_finetune(**model_kwargs)
    model.load_weights(CHECKPOINT_PATH)

In [None]:
# Single Interval Analysis - HNRNPA1 Example

# Gene examined in detail; its interval string comes from GENES_DICT.
example_gene = "HNRNPA1"
interval = GENES_DICT[example_gene]

print(f"Analyzing interval: {example_gene} at {interval}")

# One helper call returns every input/target array for the interval.
print("Extracting inputs for interval analysis...")
(
    inputs,
    masked_atac,
    target_atac,
    target_atac_uncropped,
    rna_arr,
    masked_atac_reshape,
    mask,
    mask_centered,
) = utils.return_all_inputs_simple(
    interval, ATAC_FILE, RNA_FILE,
    SEQUENCE_LENGTH, NUM_BINS, RESOLUTION,
    TF_FILE, CROP_SIZE, OUTPUT_LENGTH,
    fasta_extractor, MASK_INDICES,
    None,  # No strategy needed for single GPU
)

print(" Inputs extracted:")
for _label, _tensor in (
    ("Sequence", inputs[0]),
    ("ATAC", masked_atac),
    ("RNA", rna_arr),
    ("Target ATAC", target_atac),
):
    print(f"  - {_label} shape: {_tensor.shape}")

# Input attributions from the model's own gradient helper.
# NOTE(review): the original comment called this "integrated gradients";
# confirm against contribution_input_grad_dist_simple.
print("Computing gradients...")
with tf.device(device):
    (
        seq,
        seq_grads,
        atac_grads,
        prediction,
        att_matrices,
        att_matrices_norm,
    ) = model.contribution_input_grad_dist_simple(inputs, mask)

print("✓ Gradients computed:")
for _label, _tensor in (
    ("Sequence gradients", seq_grads),
    ("ATAC gradients", atac_grads),
    ("Prediction", prediction),
    ("Attention matrices", att_matrices),
):
    print(f"  - {_label} shape: {_tensor.shape}")


In [None]:
# Visualize Predictions vs Ground Truth

print("Creating prediction comparison plot...")

# First channel of the measured RNA track and of the model prediction
# (the prediction also drops its leading axis).
actual_rna = rna_arr[:, 0]
predicted_rna = prediction[0, :, 0]

# Track name -> (values, colour), as expected by utils.plot_tracks.
tracks = {
    'actual_rna': (actual_rna, 'red'),
    'predicted_rna': (predicted_rna, 'blue'),
}

plt.figure(figsize=(15, 6))
utils.plot_tracks(tracks, 0, 896, 25)
plt.title(f'{example_gene} - RNA Expression: Predicted vs Actual')
plt.xlabel('Genomic Position (bins)')
plt.ylabel('Expression Level')
plt.legend(['Actual RNA', 'Predicted RNA'])
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Pearson r between the two tracks; pearsonr returns the coefficient first.
pearson_result = pearsonr(actual_rna.numpy(), predicted_rna.numpy())
correlation = pearson_result[0]
print(f"Pearson correlation (predicted vs actual): {correlation:.3f}")


In [None]:
# Process and Visualize Gradients

print("Processing gradients and enhancer annotations...")

# ATAC attribution: |gradient| * masked input on the first channel, then sum
# each run of 32 consecutive entries so the track has 4096 bins.
grad_input = tf.abs(atac_grads[0][:, 0]) * masked_atac[:, 0]
reshaped_grad = tf.reduce_sum(tf.reshape(grad_input, [4096, 32]), axis=1)

# Sequence attribution: |gradient| * sequence input, summed over axis 1 and
# then over each run of 128 consecutive positions.
seq_grad_input = tf.reduce_sum(
    tf.reshape(
        tf.reduce_sum(tf.abs(seq_grads) * seq[0][0, :, :], axis=1),
        [-1, 128]
    ),
    axis=1
)

# Load enhancer annotations if available. This is best-effort: any ordinary
# failure (missing file, bad index, shape mismatch) just disables the enhancer
# tracks. Only `Exception` is caught, so Ctrl-C / kernel interrupts still
# propagate, and the reason is printed rather than silently dropped.
try:
    enhancer_file = f'/home/jupyter/datasets/eg/hg38_eg.{example_gene}.bed.gz'
    eg = utils.return_eg(interval, enhancer_file, SEQUENCE_LENGTH)
    eg_grouped = tf.reduce_max(tf.reshape(eg, [4096, 128]), axis=1)

    # Load significance annotations
    sig_file = f'/home/jupyter/datasets/eg/hg38_eg.{example_gene}.sig.bed.gz'
    eg_sig = utils.return_eg(interval, sig_file, SEQUENCE_LENGTH)
    eg_grouped_sig = tf.reduce_max(tf.reshape(eg_sig, [4096, 128]), axis=1)

    enhancer_available = True
    print("✓ Enhancer annotations loaded")
except Exception as exc:
    enhancer_available = False
    print(f"! Enhancer annotations not available ({type(exc).__name__}: {exc})")

# Normalize each track by its own maximum for visualization.
# NOTE(review): an all-zero track makes the maximum 0 and yields NaNs here.
atac_grads_norm = reshaped_grad / tf.reduce_max(reshaped_grad)
seq_grads_norm = seq_grad_input / tf.reduce_max(seq_grad_input)
atac_signal_norm = target_atac_uncropped[:, 0] / tf.reduce_max(target_atac_uncropped[:, 0])


In [None]:
# Create Comprehensive Gradient Visualization
# ==========================================

# Base tracks (name -> (values, colour)) and their legend entries, same order.
tracks = {
    'atac_gradients': (atac_grads_norm, 'blue'),
    'sequence_gradients': (seq_grads_norm, 'green'),
    'atac_signal': (atac_signal_norm, 'orange'),
    'combined_gradients': (atac_grads_norm + seq_grads_norm, 'purple'),
}
legend_labels = [
    'ATAC Gradients',
    'Sequence Gradients',
    'ATAC Signal',
    'Combined Gradients',
]

# Enhancer tracks and labels are appended only when the annotations loaded.
if enhancer_available:
    tracks.update({
        'enhancers': (eg_grouped, 'red'),
        'significant_enhancers': (eg_grouped_sig, 'pink'),
    })
    legend_labels += ['Enhancers', 'Significant Enhancers']

# Draw all tracks across the 4096 bins.
plt.figure(figsize=(20, 12))
utils.plot_tracks(tracks, 0, 4096, 1.0)
plt.title(f'{example_gene} - Gradient Analysis Across Genomic Region')
plt.xlabel('Genomic Position (128bp bins)')
plt.ylabel('Normalized Signal')
plt.legend(legend_labels, loc='upper right')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("✓ Gradient visualization completed")


In [None]:
# Batch Processing: Prepare Enhancer Files

print("Preparing enhancer files for batch analysis...")

def prepare_enhancer_files(gene_name, enhancer_file, temp_dir):
    """Build a sorted, encoded, bgzip-compressed and tabix-indexed BED for one gene.

    Steps:
      1. ``grep`` rows of ``enhancer_file`` matching ``gene_name`` and sort
         them by column 1, then numerically by column 2, into
         ``{temp_dir}/{gene_name}.eg.bed``.
      2. ``awk`` writes columns 1-3 plus the 1-based row number (the
         "encoding") to ``{temp_dir}/{gene_name}.eg.encoded.bed``.
      3. ``bgzip -f`` compresses that file and ``tabix -f`` indexes it.

    Args:
        gene_name: Pattern handed to ``grep``. It is a regular expression and
            matches anywhere in the line, so a name that is a substring of
            another gene's name also selects that gene's rows.
        enhancer_file: Path of the BED file to filter.
        temp_dir: Existing directory that receives the intermediate files.

    Returns:
        Path of the ``.eg.encoded.bed.gz`` file, or ``None`` when no rows
        matched or any external tool (grep/sort/awk/bgzip/tabix) failed.
    """
    import shlex  # local import: only needed here, for shell quoting

    gene_bed_file = f"{temp_dir}/{gene_name}.eg.bed"
    encoded_file = f"{temp_dir}/{gene_name}.eg.encoded.bed"
    compressed_file = f"{encoded_file}.gz"

    # Extract enhancers for this gene. A shell is needed for the pipe, so
    # every interpolated value is quoted.
    pipeline = (
        f"grep -e {shlex.quote(str(gene_name))} {shlex.quote(str(enhancer_file))} "
        f"| sort -k1,1 -k2,2n > {shlex.quote(gene_bed_file)}"
    )
    result = subprocess.run(pipeline, shell=True, capture_output=True, text=True)

    # The pipeline's exit status is that of `sort`, which succeeds even when
    # `grep` matched nothing, so "no enhancers" is detected by an empty
    # output file.
    found_rows = (
        result.returncode == 0
        and os.path.exists(gene_bed_file)
        and os.path.getsize(gene_bed_file) > 0
    )
    if not found_rows:
        print(f"Warning: No enhancers found for {gene_name}")
        detail = result.stderr.strip()
        if detail:
            print(f"  ({detail})")
        return None

    # Remaining steps run without a shell; (argv, file receiving stdout or None).
    awk_program = r'{OFS="\t"}{print $1,$2,$3,NR}'  # add encoding column
    steps = [
        (["awk", awk_program, gene_bed_file], encoded_file),
        (["bgzip", "-f", encoded_file], None),      # compress
        (["tabix", "-f", compressed_file], None),   # index
    ]
    for argv, stdout_path in steps:
        try:
            if stdout_path is None:
                completed = subprocess.run(argv, stderr=subprocess.PIPE, text=True)
            else:
                with open(stdout_path, "w") as sink:
                    completed = subprocess.run(
                        argv, stdout=sink, stderr=subprocess.PIPE, text=True
                    )
        except OSError as exc:
            # Typically the executable is not installed / not on PATH.
            print(f"Warning: could not run {argv[0]} for {gene_name}: {exc}")
            return None
        if completed.returncode != 0:
            print(f"Warning: {argv[0]} failed for {gene_name}: {completed.stderr.strip()}")
            return None

    return compressed_file

# Build the per-gene enhancer files; genes for which preparation returned
# nothing are left out of the mapping.
enhancer_files = {}
for gene_name in GENES_DICT:
    print(f"  Processing enhancers for {gene_name}...")
    enhancer_file = prepare_enhancer_files(gene_name, ENHANCER_FILE, TEMP_DIR)
    if not enhancer_file:
        continue
    enhancer_files[gene_name] = enhancer_file

print(f"✓ Prepared enhancer files for {len(enhancer_files)} genes")


In [None]:
# Batch Processing: Compute Gradients for All Genes
print("Starting batch gradient computation...")

def process_gene_gradients(gene_name, interval, enhancer_file):
    """Score each annotated enhancer of one gene by summed ATAC gradient.

    Extracts the model inputs for ``interval``, computes input gradients,
    bins the |gradient| * input ATAC signal into 4096 bins, scales it by its
    maximum, and sums the bins belonging to each enhancer id found in
    ``enhancer_file``. The per-enhancer scores are written to
    ``{OUTPUT_DIR}/{gene_name}.eg.preds.tsv`` with columns ``encoding`` and
    ``grad_out``.

    Returns:
        Number of enhancers scored.
    """
    print(f"  Processing {gene_name}...")

    # The helper returns eight values; only inputs, masked ATAC and mask
    # are used below.
    (
        inputs, masked_atac, _target_atac, _target_atac_uncropped,
        _rna_arr, _masked_atac_reshape, mask, _mask_centered,
    ) = utils.return_all_inputs_simple(
        interval, ATAC_FILE, RNA_FILE, SEQUENCE_LENGTH, NUM_BINS, RESOLUTION,
        TF_FILE, CROP_SIZE, OUTPUT_LENGTH, fasta_extractor, MASK_INDICES, None
    )

    # Six outputs come back; only the ATAC gradients are kept.
    with tf.device(device):
        (
            _seq, _seq_grads, atac_grads,
            _prediction, _att_matrices, _att_matrices_norm,
        ) = model.contribution_input_grad_dist_simple(inputs, mask)

    # |gradient| * input on channel 0, summed over runs of 32 -> 4096 bins.
    weighted_grad = tf.abs(atac_grads[0][:, 0]) * masked_atac[:, 0]
    binned_grad = tf.reduce_sum(tf.reshape(weighted_grad, [4096, 32]), axis=1)

    # Enhancer id per bin: max id within each run of 128 positions.
    eg_track = utils.return_eg(interval, enhancer_file, SEQUENCE_LENGTH)
    bin_ids = tf.reduce_max(tf.reshape(eg_track, [4096, 128]), axis=1).numpy()
    present_ids = np.unique(bin_ids)

    # Scale so the largest bin equals 1.
    scaled_grad = binned_grad / tf.reduce_max(binned_grad)

    # One record per non-zero enhancer id (0 marks bins with no enhancer).
    enhancer_scores = [
        {
            'encoding': int(enhancer_id),
            'grad_out': tf.reduce_sum(
                tf.gather(scaled_grad, np.where(bin_ids == enhancer_id)[0])
            ).numpy(),
        }
        for enhancer_id in present_ids[present_ids != 0]
    ]

    # Persist the scores as a tab-separated table.
    output_file = f"{OUTPUT_DIR}/{gene_name}.eg.preds.tsv"
    pd.DataFrame(enhancer_scores).to_csv(output_file, sep='\t', index=False)

    return len(enhancer_scores)

# Run the gradient scoring for every gene, recording how many enhancers each
# one produced (0 when it was skipped or failed).
total_enhancers = 0
processing_results = {}

for gene_name, interval in GENES_DICT.items():
    # Genes without a prepared enhancer file are recorded as 0 and skipped.
    if gene_name not in enhancer_files:
        print(f"    ✗ {gene_name}: No enhancer file available")
        processing_results[gene_name] = 0
        continue

    try:
        n_enhancers = process_gene_gradients(gene_name, interval, enhancer_files[gene_name])
    except Exception as e:
        # A failure for one gene is reported and does not stop the batch.
        print(f"    ✗ {gene_name}: Error - {str(e)}")
        processing_results[gene_name] = 0
    else:
        processing_results[gene_name] = n_enhancers
        total_enhancers += n_enhancers
        print(f"    ✓ {gene_name}: {n_enhancers} enhancers processed")

genes_with_scores = sum(1 for n in processing_results.values() if n > 0)

print("\n✓ Batch processing completed:")
print(f"  - Genes processed: {genes_with_scores}")
print(f"  - Total enhancers: {total_enhancers}")
print(f"  - Output directory: {OUTPUT_DIR}")


In [None]:
# Load and Parse Prediction Results
# =================================

print("Loading and parsing prediction results...")

def load_and_parse_gene_results(gene_name, temp_dir, output_dir):
    """Load one gene's gradient scores and join them to its enhancer annotations.

    Reads ``{output_dir}/{gene_name}.eg.preds.tsv`` (columns ``encoding``,
    ``grad_out``) and the header-less five-column BED
    ``{temp_dir}/{gene_name}.eg.bed``. The fourth BED column is parsed as
    ``<gene_name>_<TRUE|FALSE>_<abc_score>_<distance>``. Rows are joined on
    ``encoding``, the 1-based row number in the BED file.

    Args:
        gene_name: Gene whose result files are loaded.
        temp_dir: Directory holding the ``.eg.bed`` file.
        output_dir: Directory holding the ``.eg.preds.tsv`` file.

    Returns:
        Merged DataFrame with ``true`` as int (1/0), ``abc_score`` as float
        and ``distance`` as int, or ``None`` if either input file is missing.

    Raises:
        ValueError: if a label in the ``true`` field is neither ``TRUE`` nor
            ``FALSE``.
    """
    pred_file = f"{output_dir}/{gene_name}.eg.preds.tsv"
    enhancer_file = f"{temp_dir}/{gene_name}.eg.bed"
    if not (os.path.exists(pred_file) and os.path.exists(enhancer_file)):
        return None

    preds = pd.read_csv(pred_file, sep='\t')

    enhancer_df = pd.read_csv(enhancer_file, sep='\t', header=None)
    enhancer_df.columns = ['chrom', 'start', 'stop', 'info', 'blank']

    # Split from the right into exactly four fields so that a gene name that
    # itself contains '_' stays intact in the first field.
    enhancer_df[['gene_name', 'true', 'abc_score', 'distance']] = \
        enhancer_df['info'].str.rsplit('_', n=3, expand=True)

    # Encoding = 1-based row number, matching the NR column that the
    # preparation step writes with awk.
    enhancer_df['encoding'] = range(1, len(enhancer_df) + 1)

    # Keep only enhancers that received a gradient score.
    merged_df = enhancer_df.merge(preds, on='encoding', how='inner')

    # Explicit label mapping: avoids relying on `replace` silently changing
    # the column dtype, and surfaces unexpected labels instead of leaving
    # strings mixed in with the integers.
    labels = merged_df['true'].map({'TRUE': 1, 'FALSE': 0})
    if labels.isna().any():
        unexpected = sorted(merged_df.loc[labels.isna(), 'true'].astype(str).unique())
        raise ValueError(
            f"Unexpected label(s) in 'true' field for {gene_name}: {unexpected}"
        )
    merged_df['true'] = labels.astype(int)
    merged_df['abc_score'] = merged_df['abc_score'].astype(float)
    merged_df['distance'] = merged_df['distance'].astype(int)

    return merged_df

# Gather the parsed table of every gene that produced at least one score.
all_results = []
for gene_name in GENES_DICT:
    if processing_results.get(gene_name, 0) <= 0:
        continue
    gene_results = load_and_parse_gene_results(gene_name, TEMP_DIR, OUTPUT_DIR)
    if gene_results is None:
        continue
    all_results.append(gene_results)
    print(f"  ✓ {gene_name}: {len(gene_results)} enhancers loaded")

# Stack the per-gene tables; combined_df stays None when nothing was loaded
# so that later cells can test for it.
if not all_results:
    print("✗ No results to combine")
    combined_df = None
else:
    combined_df = pd.concat(all_results, ignore_index=True)
    print("\n✓ Combined dataset created:")
    print(f"  - Total enhancers: {len(combined_df)}")
    print(f"  - Genes represented: {combined_df['gene_name'].nunique()}")
    print(f"  - Positive examples: {combined_df['true'].sum()}")
    print(f"  - Positive rate: {combined_df['true'].mean():.3f}")


In [None]:
# Filter Data and Evaluate Performance
# ====================================

if combined_df is None:
    print("✗ No combined dataset available for evaluation")
else:
    print("Filtering data and evaluating performance...")

    # Distance window (inclusive on both ends), common in enhancer studies.
    max_distance = 100000  # 100kb
    min_distance = 1000    # 1kb

    in_window = combined_df['distance'].between(min_distance, max_distance)
    filtered_df = combined_df[in_window].copy()

    print(f"✓ Applied distance filter ({min_distance}-{max_distance} bp):")
    print(f"  - Filtered enhancers: {len(filtered_df)}")
    print(f"  - Positive examples: {filtered_df['true'].sum()}")
    print(f"  - Positive rate: {filtered_df['true'].mean():.3f}")

    # Evaluation needs at least one row and at least one positive label.
    can_evaluate = len(filtered_df) > 0 and filtered_df['true'].sum() > 0
    if not can_evaluate:
        print("✗ Insufficient data for evaluation")
    else:
        def _pr_metrics(y_true, y_score):
            """Return (precision, recall, area under the PR curve, average precision)."""
            precision, recall = precision_recall_curve(y_true, y_score)[:2]
            return precision, recall, auc(recall, precision), average_precision_score(y_true, y_score)

        # EpiBERT gradient scores
        precision_grad, recall_grad, auprc_grad, ap_grad = _pr_metrics(
            filtered_df['true'], filtered_df['grad_out']
        )
        print("\n✓ EpiBERT Gradients Performance:")
        print(f"  - AUPRC: {auprc_grad:.4f}")
        print(f"  - Average Precision: {ap_grad:.4f}")

        # ABC scores, evaluated the same way for comparison
        precision_abc, recall_abc, auprc_abc, ap_abc = _pr_metrics(
            filtered_df['true'], filtered_df['abc_score']
        )
        print("\n✓ ABC Scores Performance:")
        print(f"  - AUPRC: {auprc_abc:.4f}")
        print(f"  - Average Precision: {ap_abc:.4f}")

        # Left panel: PR curves. Right panel: score histograms per class.
        fig, (ax_pr, ax_hist) = plt.subplots(1, 2, figsize=(12, 8))

        ax_pr.plot(recall_grad, precision_grad, label=f'EpiBERT Gradients (AP={ap_grad:.3f})', linewidth=2)
        ax_pr.plot(recall_abc, precision_abc, label=f'ABC Scores (AP={ap_abc:.3f})', linewidth=2)
        ax_pr.set_xlabel('Recall')
        ax_pr.set_ylabel('Precision')
        ax_pr.set_title('Precision-Recall Curves')
        ax_pr.legend()
        ax_pr.grid(True, alpha=0.3)

        for class_value, class_label in ((1, 'True Enhancers'), (0, 'Non-Enhancers')):
            class_scores = filtered_df.loc[filtered_df['true'] == class_value, 'grad_out']
            ax_hist.hist(class_scores, alpha=0.7, label=class_label, bins=30, density=True)
        ax_hist.set_xlabel('EpiBERT Gradient Score')
        ax_hist.set_ylabel('Density')
        ax_hist.set_title('Score Distribution')
        ax_hist.legend()
        ax_hist.grid(True, alpha=0.3)

        fig.tight_layout()
        plt.show()

        # Relative change in average precision, ABC as the baseline.
        print("\n✓ Performance Comparison:")
        print(f"  - EpiBERT improvement over ABC: {((ap_grad - ap_abc) / ap_abc * 100):+.1f}%")


In [None]:
# Export Results and Summary
# ==========================
import json  # only this cell needs it; imported up front rather than mid-block

if combined_df is not None:
    print("Exporting results and creating summary...")

    # Save the distance-filtered table produced by the evaluation cell.
    output_file = f"{OUTPUT_DIR}/combined_results.tsv"
    filtered_df.to_csv(output_file, sep='\t', index=False)
    print(f"✓ Combined results saved: {output_file}")

    # The AP values exist only if the evaluation branch ran; look them up by
    # name so this cell also works when evaluation was skipped.
    epibert_ap = globals().get('ap_grad')
    abc_ap = globals().get('ap_abc')

    # Summary statistics. pandas/NumPy reductions return NumPy scalars, and
    # json cannot serialize numpy integers, so every value is converted to a
    # built-in int/float before dumping.
    summary_stats = {
        'total_genes': len(GENES_DICT),
        'processed_genes': sum(1 for n in processing_results.values() if n > 0),
        'total_enhancers': len(filtered_df),
        'positive_enhancers': int(filtered_df['true'].sum()),
        'positive_rate': float(filtered_df['true'].mean()),
        'epibert_ap': float(epibert_ap) if epibert_ap is not None else None,
        'abc_ap': float(abc_ap) if abc_ap is not None else None,
        'distance_filter': f"{min_distance}-{max_distance}bp"
    }

    # Save summary
    summary_file = f"{OUTPUT_DIR}/analysis_summary.json"
    with open(summary_file, 'w') as f:
        json.dump(summary_stats, f, indent=2)
    print(f"✓ Analysis summary saved: {summary_file}")

    # Display final summary
    print(f"\n" + "="*60)
    print("WORKFLOW SUMMARY")
    print("="*60)
    print(f"Genes analyzed: {summary_stats['processed_genes']}/{summary_stats['total_genes']}")
    print(f"Enhancers evaluated: {summary_stats['total_enhancers']}")
    print(f"Positive examples: {summary_stats['positive_enhancers']} ({summary_stats['positive_rate']:.1%})")

    if summary_stats['epibert_ap'] is not None:
        print(f"EpiBERT AP: {summary_stats['epibert_ap']:.4f}")
        if summary_stats['abc_ap'] is not None:
            print(f"ABC AP: {summary_stats['abc_ap']:.4f}")
            # Relative improvement is undefined for a zero baseline.
            if summary_stats['abc_ap'] != 0:
                improvement = ((summary_stats['epibert_ap'] - summary_stats['abc_ap']) / summary_stats['abc_ap'] * 100)
                print(f"Improvement: {improvement:+.1f}%")

    print(f"Distance filter: {summary_stats['distance_filter']}")
    print(f"Output directory: {OUTPUT_DIR}")
    print("="*60)

else:
    print("✗ No results to export")

print("\n✓ Workflow completed successfully!")
