# Project: Visualizing and decoding chromatin dynamics with Atacformer


## Introduction

We have developed a model called Atacformer, a transformer-based neural network trained on a large corpus of single-cell ATAC-seq (scATAC-seq) data. Atacformer learns biologically meaningful representations (embeddings) of individual cells based on their chromatin accessibility profiles.

When new scATAC-seq data is passed through Atacformer, it produces a low-dimensional embedding for each cell. These embeddings can be used for visualization, clustering, classification, and interpretability analyses, giving a data-driven view of each cell's regulatory identity.

This project applies Atacformer to a rich time-course dataset to visualize cellular trajectories, predict developmental time, and interpret which genomic regions drive embedding formation using attention maps.

## Goal

The goal of this project is to apply Atacformer to the [GSE242421](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE242421) dataset to:

1. Project and visualize cell embeddings over time  
2. Train a model to predict each cell’s timepoint from its Atacformer embedding  
3. Extract and explore attention scores to interpret which genomic regions are most important at each stage  

## Project Steps

### Preparation

1. Read the paper (PMC10592962) to understand the biological context and timeline of reprogramming  
2. Explore the Geniml library (https://docs.bedbase.org/geniml/) for manipulating scATAC fragment files  
3. Familiarize yourself with Atacformer input format, expected preprocessing steps, and model interface  

### Step 1: Project Timepoint-Specific Embeddings

1. Download and prepare data:  
   - Get the 8 scATAC-seq fragment files (one per timepoint) from GSE242421  
   - Organize by timepoint (day 0 through day 14, every 2 days)  

2. Preprocess:
   - Run QC on the fragments, probably with SnapATAC2
   - Use Geniml to tokenize the fragment files  
   - Format data as needed for Atacformer input  

3. Project through Atacformer:  
   - Embed the cells from each timepoint with the pretrained model

4. Visualize:  
   - Fit a UMAP on all embeddings together  
   - For each timepoint:  
     - Plot all original Atacformer model cells in light gray  
     - Overlay current timepoint in color  
     - Save as a time-lapse frame to visualize trajectory over time  
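The per-timepoint frames described above can be sketched with matplotlib. This assumes one UMAP fitted on all cells and a per-cell array of sampling days. It draws every embedded cell from this dataset as the light-gray background, which is one reading of "all original Atacformer model cells"; if that phrase means the model's reference cells, pass their projected coordinates instead. `save_timelapse_frames` is a hypothetical helper, not part of Geniml.

```python
import os

import matplotlib.pyplot as plt
import numpy as np


def save_timelapse_frames(umap_xy, days, out_dir="frames"):
    """Save one PNG per timepoint: all cells in light gray, that timepoint in color.

    umap_xy: (n_cells, 2) UMAP coordinates fitted on all embeddings together
    days:    (n_cells,) sampling day of each cell
    Returns the written file paths, ordered by day.
    """
    umap_xy = np.asarray(umap_xy)
    days = np.asarray(days)
    os.makedirs(out_dir, exist_ok=True)
    unique_days = np.unique(days)  # sorted ascending
    colors = plt.cm.viridis(np.linspace(0, 1, len(unique_days)))
    paths = []
    for day, color in zip(unique_days, colors):
        fig, ax = plt.subplots(figsize=(6, 5))
        # Same gray background in every frame, so the axes match and only the colored cells move
        ax.scatter(umap_xy[:, 0], umap_xy[:, 1], s=2, color="lightgray")
        mask = days == day
        ax.scatter(umap_xy[mask, 0], umap_xy[mask, 1], s=4, color=color, label=f"Day {day}")
        ax.set_title(f"Day {day}")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.legend(loc="upper right")
        path = os.path.join(out_dir, f"frame_day{int(day):02d}.png")
        fig.savefig(path, dpi=100)
        plt.close(fig)  # free memory; frames are saved, not displayed
        paths.append(path)
    return paths


if __name__ == "__main__":
    # Synthetic coordinates standing in for a real UMAP
    rng = np.random.default_rng(0)
    xy = rng.normal(size=(300, 2))
    demo_days = np.repeat([0, 2, 4], 100)
    print(save_timelapse_frames(xy, demo_days, out_dir="frames_demo"))
```

In the notebook, the per-cell days could come from `np.repeat(timepoint_days[:len(data_objects)], [d.n_obs for d in data_objects])`, passed together with `combined_umap`; the numbered PNGs can then be stitched into an animation with any image tool.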

### Step 2: Integrated Analysis — Predicting Reprogramming Time

In addition to visualizing individual timepoints, perform an integrated analysis using all cells at once.

1. Create dataset:  
   - Combine all cells’ embeddings into one matrix  
   - Assign each cell its corresponding timepoint as a label (e.g., its sampling day: 0, 2, ..., 14)  

2. Train a classifier:  
   - Use 80% of the data to train a model that predicts timepoint from the embedding (options: logistic regression, random forest, MLP)  
   - Hold out 20% for testing  

3. Evaluate performance:  
   - Report accuracy, confusion matrix, and regression error (e.g., RMSE if using a regressor)  

4. Feature and region importance:  
   - Identify which embedding features or original genomic regions are most predictive of time  
   - Use SHAP, feature importances, or attention-based explanations if possible  
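Because the timepoint labels are ordered days, it can help to compare a classifier with a regressor that treats time as continuous, which is where RMSE is most meaningful. A minimal scikit-learn sketch, assuming an embedding matrix `X` and per-cell days; Ridge regression and snapping its output to the nearest observed day are illustrative choices not in the original plan, and `compare_classifier_and_regressor` is a hypothetical name.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def compare_classifier_and_regressor(X, days, seed=0):
    """Predict sampling day from embeddings with a classifier and with a regressor.

    Returns test accuracy and RMSE (in days) for both. The regressor's continuous
    output is snapped to the nearest observed day so its accuracy is comparable.
    """
    days = np.asarray(days)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, days, test_size=0.2, random_state=seed, stratify=days
    )
    clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
    reg = make_pipeline(StandardScaler(), Ridge(alpha=1.0))
    clf.fit(X_tr, y_tr)
    reg.fit(X_tr, y_tr)

    y_clf = clf.predict(X_te)
    y_reg = reg.predict(X_te)  # continuous day estimates
    observed = np.unique(days)
    # Nearest observed day for each continuous prediction
    y_reg_snapped = observed[np.abs(y_reg[:, None] - observed[None, :]).argmin(axis=1)]

    return {
        "clf_accuracy": accuracy_score(y_te, y_clf),
        "clf_rmse_days": float(np.sqrt(mean_squared_error(y_te, y_clf))),
        "reg_accuracy": accuracy_score(y_te, y_reg_snapped),
        "reg_rmse_days": float(np.sqrt(mean_squared_error(y_te, y_reg))),
    }


if __name__ == "__main__":
    # Synthetic stand-in for Atacformer embeddings: dimension 0 drifts with time
    rng = np.random.default_rng(0)
    demo_days = np.repeat([0, 2, 4, 6], 100)
    X_demo = rng.normal(size=(400, 16))
    X_demo[:, 0] += 2 * demo_days
    print(compare_classifier_and_regressor(X_demo, demo_days))
```

With the notebook's variables this would be called as `compare_classifier_and_regressor(all_embeddings, days_per_cell)`, where the hypothetical `days_per_cell` holds each cell's sampling day.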

### Step 3: Explore Attention Scores — Interpreting the Embedding

We're now going inside the black box to explore which genomic regions the model is using to form each cell’s embedding.

#### What are attention scores?

- In a transformer, each token corresponds to a genomic region (e.g., a genome bin or a peak)  
- Attention scores tell us which regions the model focused on when embedding a cell  
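For intuition, the standard scaled dot-product attention of a single head can be written in a few lines of NumPy. Each row of the weight matrix is one token's attention distribution; summing a column gives how much attention that token (region) receives, which is the quantity ranked in the steps below. This is the textbook formulation, not Atacformer's internal code, and the random `Q` and `K` stand in for projected token embeddings.

```python
import numpy as np


def attention_weights(Q, K):
    """Scaled dot-product attention weights, softmax(Q K^T / sqrt(d)), for one head.

    Q, K: (seq_len, d) query and key matrices. Row i says how token i spreads its
    attention over all tokens, so each row sums to 1.
    """
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)
    scores = scores - scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


def attention_received(weights):
    """Total attention each token (column) receives, summed over all query tokens (rows)."""
    return weights.sum(axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q = rng.normal(size=(4, 8))  # 4 tokens (regions), 8-dimensional queries
    K = rng.normal(size=(4, 8))
    W = attention_weights(Q, K)
    print(W.sum(axis=1))          # each row sums to 1
    print(attention_received(W))  # column sums; together they add up to 4
```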

#### What to do

1. Extract per-cell attention scores:  
   - From the model’s forward pass, extract attention matrices per layer and head  
   - For each cell, identify which input tokens (regions) had the highest cumulative attention  

2. Aggregate to pseudobulks:  
   - Group cells by timepoint (or cluster)  
   - Average attention scores across cells to get a pseudobulk attention profile  
   - This gives you a ranked list of regions per timepoint, based on how much attention they received  

3. Compare and correlate:  
   - Compare attention-ranked regions across timepoints  
   - Correlate high-attention regions with:  
     - Known enhancers or promoters  
     - Transcription factor binding motifs (e.g., ASCL1, NEUROD1)  
     - External ChIP-seq/ATAC-seq annotations  
     - Regions known to be active during reprogramming (from the original paper)  

4. Optional visualization:  
   - Generate genome-browser–style attention tracks per timepoint  
   - Highlight which loci (e.g., Pou3f4, Myt1l) are gaining or losing attention over time  
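The aggregation and comparison steps above can be sketched in plain pandas. One detail matters: each cell has its own set of tokens, so position *i* in one cell is a different region than position *i* in another, and scores must be pooled by region (token ID), not by position. The sketch assumes that, for each cell, you already have its region IDs and a per-token score such as total attention received. `pseudobulk_by_region` and `topk_jaccard` are hypothetical helpers, and Jaccard overlap of top-k sets is just one simple way to compare timepoints.

```python
import itertools
from collections import defaultdict

import pandas as pd


def pseudobulk_by_region(cell_tokens, cell_scores, cell_groups):
    """Average per-cell attention per region within each group (e.g., sampling day).

    cell_tokens: list of region-ID lists, one per cell
    cell_scores: list of score lists aligned with cell_tokens (e.g., attention received)
    cell_groups: group label per cell
    Returns one row per (group, region), ranked within each group.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for tokens, scores, group in zip(cell_tokens, cell_scores, cell_groups):
        for region, score in zip(tokens, scores):
            # Key by region ID, not by token position: positions differ between cells
            totals[(group, region)] += score
            counts[(group, region)] += 1

    df = pd.DataFrame(
        [(g, r, totals[(g, r)] / counts[(g, r)], counts[(g, r)]) for g, r in totals],
        columns=["group", "region", "mean_attention", "n_cells"],
    )
    df["rank"] = (
        df.groupby("group")["mean_attention"]
        .rank(ascending=False, method="first")
        .astype(int)
    )
    return df.sort_values(["group", "rank"]).reset_index(drop=True)


def topk_jaccard(profiles, k=100):
    """Jaccard overlap of the top-k regions between every pair of groups."""
    groups = sorted(profiles["group"].unique())
    top = {
        g: set(profiles[profiles["group"] == g].nsmallest(k, "rank")["region"])
        for g in groups
    }
    out = pd.DataFrame(1.0, index=groups, columns=groups)
    for a, b in itertools.combinations(groups, 2):
        union = top[a] | top[b]
        overlap = len(top[a] & top[b]) / len(union) if union else 0.0
        out.loc[a, b] = out.loc[b, a] = overlap
    return out


if __name__ == "__main__":
    demo = pseudobulk_by_region(
        cell_tokens=[[10, 20], [20, 30], [10]],
        cell_scores=[[0.6, 0.4], [0.9, 0.1], [1.0]],
        cell_groups=["Day 0", "Day 0", "Day 2"],
    )
    print(demo)
    print(topk_jaccard(demo, k=2))
```

Mapping region token IDs back to genomic coordinates (needed for the enhancer, motif, and ChIP-seq comparisons) depends on the tokenizer's vocabulary and is not shown here.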

## Additional Notes

- If attention scores are difficult to access directly, contact Nathan LeRoy, author of Atacformer.
- Consider simplifying attention score aggregation (e.g., averaging over heads/layers) for interpretability
- If Atacformer performs poorly on the fibroblast data (possibly because fibroblasts were under-represented in its training corpus), an unsupervised fine-tuning step beforehand may help the foundation model.
- Extracting attention scores may be difficult or impossible because the model uses FlashAttention, which does not materialize the full attention matrix. This needs further exploration.


## Let's begin!

### Download all the files  
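Data is organized by day, from D0 to D14 in 2-day steps. Only D0 and D2 are downloaded below; uncomment the remaining lines to fetch every timepoint.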

In [None]:
!wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7763nnn/GSM7763395/suppl/GSM7763395_D0.frag.bed.gz
!wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7763nnn/GSM7763396/suppl/GSM7763396_D2.frag.bed.gz
# !wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7763nnn/GSM7763397/suppl/GSM7763397_D4.frag.bed.gz
# !wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7763nnn/GSM7763398/suppl/GSM7763398_D6.frag.bed.gz
# !wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7763nnn/GSM7763399/suppl/GSM7763399_D8.frag.bed.gz
# !wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7763nnn/GSM7763400/suppl/GSM7763400_D10.frag.bed.gz
# !wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7763nnn/GSM7763401/suppl/GSM7763401_D12.frag.bed.gz
# !wget ftp://ftp.ncbi.nlm.nih.gov/geo/samples/GSM7763nnn/GSM7763402/suppl/GSM7763402_D14.frag.bed.gz
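The eight URLs above differ only in the sample ID and the day, so they can be generated instead of commented in and out. A small sketch, assuming the consecutive GSM IDs shown in the commands above (GSM7763395 for D0 through GSM7763402 for D14); `geo_fragment_urls` is a hypothetical helper.

```python
def geo_fragment_urls(first_gsm=7763395, days=range(0, 16, 2)):
    """Build the GEO FTP URLs of the per-day fragment files.

    Assumes consecutive GSM IDs, one per day, following the pattern
    .../GSM7763nnn/GSM7763395/suppl/GSM7763395_D0.frag.bed.gz
    """
    urls = []
    for offset, day in enumerate(days):
        gsm = f"GSM{first_gsm + offset}"
        # GEO groups samples in folders named after all but the last three digits
        folder = gsm[:-3] + "nnn"
        urls.append(
            "ftp://ftp.ncbi.nlm.nih.gov/geo/samples/"
            f"{folder}/{gsm}/suppl/{gsm}_D{day}.frag.bed.gz"
        )
    return urls


if __name__ == "__main__":
    for url in geo_fragment_urls():
        print(url)
```

Each file could then be fetched with the standard library, e.g. `urllib.request.urlretrieve(url, url.rsplit("/", 1)[-1])`, or with `!wget {url}` inside the notebook.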


### Quality Control on dataset

In [None]:
!pip install --upgrade snapatac2
!pip install kaleido

While importing fragments, SnapATAC2 computes only basic QC metrics, such as the number of unique fragments per cell, the fraction of duplicated reads, and the fraction of mitochondrial reads.
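To make these metrics concrete, here is a pandas sketch that computes comparable per-barcode numbers straight from a fragment file. It assumes the common 10x-style five-column layout (chrom, start, end, barcode, read count); the definitions are illustrative and may differ in detail from SnapATAC2's, and `basic_fragment_qc` is a hypothetical helper. For the full GEO files, SnapATAC2 remains the faster route.

```python
import io

import pandas as pd


def basic_fragment_qc(fragment_file):
    """Per-barcode QC from a fragment file with columns chrom, start, end, barcode, count.

    Illustrative definitions (SnapATAC2's exact formulas may differ):
      n_fragment: unique fragments per barcode
      frac_dup:   1 - unique fragments / total reads supporting them
      frac_mito:  fraction of a barcode's unique fragments on chrM
    """
    df = pd.read_csv(
        fragment_file, sep="\t", header=None, comment="#", usecols=range(5),
        names=["chrom", "start", "end", "barcode", "count"],
    )
    df["is_mito"] = df["chrom"].isin(["chrM", "MT"])
    per_cell = df.groupby("barcode").agg(
        n_fragment=("count", "size"),
        n_reads=("count", "sum"),
        n_mito=("is_mito", "sum"),
    )
    per_cell["frac_dup"] = 1 - per_cell["n_fragment"] / per_cell["n_reads"]
    per_cell["frac_mito"] = per_cell["n_mito"] / per_cell["n_fragment"]
    return per_cell[["n_fragment", "frac_dup", "frac_mito"]]


if __name__ == "__main__":
    demo = io.StringIO(
        "chr1\t100\t200\tAAA\t2\n"
        "chr1\t300\t400\tAAA\t1\n"
        "chrM\t10\t50\tAAA\t1\n"
        "chr2\t5\t80\tCCC\t4\n"
    )
    print(basic_fragment_qc(demo))
```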

In [None]:
import snapatac2 as snap

fragment_files = [
    "GSM7763395_D0.frag.bed.gz",
    "GSM7763396_D2.frag.bed.gz",
    # "GSM7763397_D4.frag.bed.gz",
    # "GSM7763398_D6.frag.bed.gz",
    # "GSM7763399_D8.frag.bed.gz",
    # "GSM7763400_D10.frag.bed.gz",
    # "GSM7763401_D12.frag.bed.gz",
    # "GSM7763402_D14.frag.bed.gz"
]

data_objects = [
    snap.pp.import_fragments(file, chrom_sizes=snap.genome.hg38, sorted_by_barcode=False)
    for file in fragment_files
]


In [None]:
# write to h5ad file
out_filenames = []
for fragment_file, data in zip(fragment_files, data_objects):
    base_name = fragment_file.replace(".frag.bed.gz", "")
    out_filename = base_name + ".h5ad"
    out_filenames.append(out_filename)
    data.write(out_filename)
    print(f"Saved {out_filename}")

Compute TSS enrichment (TSSe) scores and plot them to identify high-quality cells worth keeping.

In [None]:
# tsse() works in place: it stores per-cell scores in data.obs["tsse"] and returns None
for data in data_objects:
    snap.metrics.tsse(data, snap.genome.hg38)
    snap.pl.tsse(data, interactive=True)

In [None]:
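# Filter on QC, then overwrite the .h5ad files so tokenization reads only filtered cells
for data, out_filename in zip(data_objects, out_filenames):
    snap.pp.filter_cells(data, min_counts=5000, min_tsse=10, max_counts=100000)
    data.write(out_filename)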
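The `.h5ad` files written from `import_fragments` hold fragments in `obsm` but no feature matrix (`n_vars` is 0), so a tokenizer has no regions to map. One option, assumed here rather than taken from the Atacformer docs, is `snap.pp.add_tile_matrix(data, bin_size=500)`, which fills `X` with genome bins whose names typically look like `chr1:0-500`. If the tokenizer reads coordinates from `chr`/`start`/`end` columns in `.var` (also an assumption; check the Geniml documentation), a small parser can add them. `region_names_to_bed_columns` is a hypothetical helper.

```python
import pandas as pd


def region_names_to_bed_columns(var_names):
    """Split region names like 'chr1:0-500' into chr / start / end columns.

    Returns a DataFrame indexed by the original names, e.g. for
    data.var = region_names_to_bed_columns(data.var_names).
    """
    names = pd.Index(var_names)
    parts = pd.Series(names.astype(str)).str.extract(
        r"^(?P<chr>[^:]+):(?P<start>\d+)-(?P<end>\d+)$"
    )
    bad = parts.isna().any(axis=1).to_numpy()
    if bad.any():
        raise ValueError(f"Unrecognized region names, e.g. {list(names[bad][:5])}")
    parts["start"] = parts["start"].astype(int)
    parts["end"] = parts["end"].astype(int)
    parts.index = names
    return parts


if __name__ == "__main__":
    print(region_names_to_bed_columns(["chr1:0-500", "chr1:500-1000", "chrX:1000-1500"]))
```

In the notebook this would run per object before writing, e.g. `snap.pp.add_tile_matrix(data, bin_size=500)` followed by `data.var = region_names_to_bed_columns(data.var_names)`.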

### Tokenization of data

In [None]:
!pip install --upgrade snapatac2 --use-deprecated=legacy-resolver
!pip install kaleido --use-deprecated=legacy-resolver
!pip install geniml[ml] --use-deprecated=legacy-resolver

In [None]:
!uv pip install git+https://github.com/databio/geniml.git

In [None]:
!python -c "from geniml import __version__; print(__version__)"


In [None]:
!pip install numpy

In [None]:
!pip install --upgrade huggingface_hub

In [None]:
from geniml.atacformer import AtacformerForCellClustering

atac_model = AtacformerForCellClustering.from_pretrained("databio/atacformer-base-hg38")
atac_model = atac_model.to("cuda")

In [None]:
# Embed each timepoint with Atacformer
import scanpy as sc

from gtars.tokenizers import Tokenizer
# Assumption: tokenize_anndata is exported by geniml.tokenization in your geniml
# version; if this import fails, check dir(geniml.tokenization)
from geniml.tokenization import tokenize_anndata

tokenizer = Tokenizer.from_pretrained("databio/atacformer-base-hg38")
data_objects = []
cell_embeddings = []

for out_filename in out_filenames:
    data = sc.read_h5ad(out_filename)
    data_objects.append(data)
    # Map each cell's accessible regions onto the model vocabulary.
    # Note: files written straight from import_fragments contain fragments only
    # (n_vars == 0), so a cell-by-region matrix must be added before this step.
    tokens = tokenize_anndata(data, tokenizer)
    input_ids = [t["input_ids"] for t in tokens]
    cell_embedding = atac_model.encode_tokenized_cells(
        input_ids=input_ids,
        batch_size=32,  # adjust based on your GPU memory
    )
    cell_embeddings.append(cell_embedding)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from umap import UMAP

for i,data in enumerate(data_objects):
  data.obsm["X_atacformer"] = cell_embeddings[i].cpu().numpy()

all_embeddings = np.vstack([data.obsm['X_atacformer'] for data in data_objects])
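timepoint_days = list(range(0, 16, 2))  # one sampling day per fragment file: D0, D2, ..., D14

# Label each cell with the day of the file it came from (files are loaded in day order)
dataset_labels = []
for day, data in zip(timepoint_days, data_objects):
    dataset_labels.extend([f"Day {day}"] * data.n_obs)
dataset_labels = np.array(dataset_labels)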

umap_model = UMAP(n_neighbors=15, random_state=42)
combined_umap = umap_model.fit_transform(all_embeddings)

plt.figure(figsize=(8,6))
sns.scatterplot(
    x=combined_umap[:,0],
    y=combined_umap[:,1],
    hue=dataset_labels,
    palette='tab10',
    s=10
)
plt.title("UMAP of all datasets combined")
plt.show()


### Step 2: Integrated Analysis — Predicting Reprogramming Time

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

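# One label per cell: the sampling day of its fragment file (timepoint_days is in file order)
cell_days = np.concatenate([np.full(d.n_obs, day) for day, d in zip(timepoint_days, data_objects)])

X_train, X_test, y_train, y_test = train_test_split(
    all_embeddings, cell_days,
    test_size=0.2,
    random_state=42,
    stratify=cell_days,  # keep day proportions equal in train and test
)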

In [None]:
models = {
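    'LR': LogisticRegression(
        max_iter=1000,
        random_state=42,
        # multi_class='ovr' is deprecated in recent scikit-learn; the default
        # (multinomial) handles the multi-day labels directly
    ),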
    'RF': RandomForestClassifier(
        n_estimators=100,
        random_state=42,
        n_jobs=-1
    ),
    'MLP': MLPClassifier(
        hidden_layer_sizes=(128, 64),
        max_iter=500,
        random_state=42,
        early_stopping=True,
        validation_fraction=0.1
    )
}

# scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train and evaluate each model
results = {}
for name, model in models.items():
    print(f"\n{'='*50}")
    print(f"Training {name}...")

    # LR and MLP are sensitive to feature scale; RF is not
    if name in ('LR', 'MLP'):
        X_tr, X_te = X_train_scaled, X_test_scaled
    else:
        X_tr, X_te = X_train, X_test

    # Train and predict
    model.fit(X_tr, y_train)
    y_pred_train = model.predict(X_tr)
    y_pred_test = model.predict(X_te)

    # Class probabilities (kept for later analysis)
    y_prob_test = model.predict_proba(X_te) if hasattr(model, 'predict_proba') else None

    # Classification metrics
    train_acc = accuracy_score(y_train, y_pred_train)
    test_acc = accuracy_score(y_test, y_pred_test)

    # Labels are days, so predicted classes are days too: regression-style errors
    # measure how many days off each prediction is, for every model
    test_rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))
    test_mae = mean_absolute_error(y_test, y_pred_test)
    test_r2 = r2_score(y_test, y_pred_test)

    results[name] = {
        'model': model, 'train_accuracy': train_acc, 'test_accuracy': test_acc,
        'test_rmse': test_rmse, 'test_mae': test_mae, 'test_r2': test_r2,
        'y_pred_test': y_pred_test, 'y_prob_test': y_prob_test,
    }

    print(f"Train acc: {train_acc:.3f} | Test acc: {test_acc:.3f} | RMSE: {test_rmse:.2f} d | "
          f"MAE: {test_mae:.2f} d | R²: {test_r2:.3f}")

In [None]:
# Compare all models in one plot
fig, axes = plt.subplots(1, len(results), figsize=(5*len(results), 4))
if len(results) == 1:
    axes = [axes]

unique_timepoints = sorted(np.unique(y_test))
labels = [f'Day {int(tp)}' for tp in unique_timepoints]

for idx, (name, result) in enumerate(results.items()):
    cm = confusion_matrix(y_test, result['y_pred_test'], labels=unique_timepoints)

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, ax=axes[idx])

    axes[idx].set_title(f'{name}\nAcc: {result["test_accuracy"]:.3f}')
    axes[idx].set_xlabel('Predicted')
    if idx == 0:
        axes[idx].set_ylabel('Actual')

plt.tight_layout()
plt.show()

In [None]:
# https://www.geeksforgeeks.org/machine-learning/shap-a-comprehensive-guide-to-shapley-additive-explanations/

In [None]:
!pip install xgboost shap pandas scikit-learn matplotlib ipywidgets

In [None]:
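import shap

# Explain the random forest with TreeExplainer (fast for tree models; LR and MLP would
# need a slower model-agnostic explainer). RF was trained on unscaled X, so use X_test.
explainer = shap.TreeExplainer(results['RF']['model'])
sv = explainer.shap_values(X_test[:200])  # subset of held-out cells for speed

# Older SHAP returns one array per class, newer a (cells, features, classes) array
sv = np.stack(sv, axis=-1) if isinstance(sv, list) else sv
importance = np.abs(sv).reshape(sv.shape[0], sv.shape[1], -1).mean(axis=(0, 2))

top = np.argsort(importance)[::-1][:20]  # 20 most important embedding dimensions
plt.barh([f"dim {d}" for d in top[::-1]], importance[top[::-1]])
plt.xlabel("mean |SHAP value| (over cells and classes)")
plt.title("Embedding dimensions predictive of day (RF)")
plt.show()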

### Step 3: Explore Attention Scores — Interpreting the Embedding

In [None]:
import torch

# Forward pass for a single cell to inspect attention.
# Assumption: the Atacformer forward() accepts output_attentions=True and returns
# outputs.attentions; this may not work if the model runs FlashAttention (see notes above).
cell_ids = list(input_ids[0])  # token IDs of the first cell in the last embedded file
atac_model.eval()
with torch.no_grad():
    outputs = atac_model(
        input_ids=torch.tensor([cell_ids]).to("cuda"),
        output_attentions=True,
    )

# outputs.attentions: one tensor per layer, each (batch_size, num_heads, seq_len, seq_len)
cell_attention = torch.stack(outputs.attentions)[:, 0]  # (num_layers, num_heads, seq_len, seq_len)

# Average across layers and heads
avg_attention = cell_attention.mean(dim=(0, 1))  # (seq_len, seq_len)

# Rows are query positions, columns are key positions: summing over rows gives
# the total attention each token receives
token_importance = avg_attention.sum(dim=0)

# Top 10 most attended tokens (positions), mapped back to their region token IDs
top_tokens = torch.topk(token_importance, k=min(10, token_importance.numel()))
top_positions = top_tokens.indices.cpu().numpy()
print("Top token positions:", top_positions)
print("Top region token IDs:", [cell_ids[p] for p in top_positions])
print("Top attention scores:", top_tokens.values.cpu().numpy())

In [None]:
import pandas as pd
import torch
from collections import defaultdict
from tqdm.auto import tqdm


def pseudobulk_attention(input_ids_list, cell_groups, batch_size=16, top_k=100, pad_token_id=0):
    """Rank regions per group (e.g., timepoint) by average attention received.

    input_ids_list: list of token-ID lists, one per cell
    cell_groups:    group label for each cell, in the same order
    pad_token_id:   assumption; replace with the tokenizer's real padding ID
    Scores are pooled by region token ID, not by position, because each cell
    has a different set of tokens.
    """
    print(f"🔍 Analyzing attention for {len(input_ids_list)} cells...")
    atac_model.eval()
    sums = defaultdict(lambda: defaultdict(float))   # group -> token ID -> summed attention
    counts = defaultdict(lambda: defaultdict(int))   # group -> token ID -> cells containing it
    n_cells = defaultdict(int)

    for i in tqdm(range(0, len(input_ids_list), batch_size)):
        batch = [list(ids) for ids in input_ids_list[i:i + batch_size]]

        # Pad batch to the same length and mask the padding
        max_len = max(len(ids) for ids in batch)
        padded = [ids + [pad_token_id] * (max_len - len(ids)) for ids in batch]
        masks = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch]

        # Same assumption as above: forward() must support output_attentions=True
        with torch.no_grad():
            outputs = atac_model(
                input_ids=torch.tensor(padded).to("cuda"),
                attention_mask=torch.tensor(masks).to("cuda"),
                output_attentions=True,
            )

        # Average across layers (dim 0) and heads (dim 2): (batch, seq_len, seq_len)
        attn = torch.stack(outputs.attentions).mean(dim=(0, 2))

        for j, ids in enumerate(batch):
            n = len(ids)
            # Attention each real token receives from real (non-padding) queries
            received = attn[j, :n, :n].sum(dim=0).cpu().numpy()
            group = cell_groups[i + j]
            n_cells[group] += 1
            for token_id, score in zip(ids, received):
                sums[group][token_id] += float(score)
                counts[group][token_id] += 1

    print("📊 Creating pseudobulk profiles...")
    profiles = {}
    for group_name in sums:
        print(f"  Processing {group_name}: {n_cells[group_name]} cells")
        # Average over the cells in which each region is present
        avg_scores = pd.Series(
            {t: sums[group_name][t] / counts[group_name][t] for t in sums[group_name]}
        )
        top = avg_scores.sort_values(ascending=False).head(top_k)
        profiles[group_name] = pd.DataFrame({
            'region_token_id': top.index,
            'attention_score': top.values,
            'rank': range(1, len(top) + 1),
            'n_cells_with_region': [counts[group_name][t] for t in top.index],
            'n_cells': n_cells[group_name],
        })
    return profiles


# Example (hypothetical labels): all cells from the last embedded file share one day
# profiles = pseudobulk_attention(input_ids, ["last_day"] * len(input_ids))