In [1]:
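# The BPNet pipeline saves the model in TensorFlow SavedModel format (saved_model.pb); convert it to a Keras .h5 file for compatibility with this pipeline.
# NOTE: this cell is disabled. To use it, run it after the configuration cell below has defined `base_dir`,
# and point `model_path` at the SavedModel location first — the configuration cell sets `model_path` to the .h5 output path.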
# import tensorflow as tf

# # Load the SavedModel
# model = tf.keras.models.load_model(model_path)

# # Convert and reset model_path
# model_path = f'{base_dir}/resources/model.h5'
# if hasattr(model, 'save'):
#     model.save(model_path)

# Run Pipeline

## Set Filepaths and Variables

In [1]:
# Base directory for all inputs / outputs.
# Set the SPARSENET_P53_DIR environment variable to run from a different location;
# when it is unset (or empty) the original cluster path is used.
import os

base_dir = os.environ.get("SPARSENET_P53_DIR") or "/scratch/users/jgalante/SparseNet/p53"

# Paths to resources
peaks_bed = f"{base_dir}/resources/peaks_inliers.bed"
seqs = f'{base_dir}/resources/genome.fa'
signal_plus = f"{base_dir}/resources/experiment_plus.bw"
signal_minus = f"{base_dir}/resources/experiment_minus.bw"
ctl_plus = f"{base_dir}/resources/control_plus.bw"
ctl_minus = f"{base_dir}/resources/control_minus.bw"
negatives_bed = f"{base_dir}/resources/gc_negatives.bed"
# Keras .h5 model; a BPNet SavedModel must first be converted (see the disabled conversion cell at the top).
model_path = f'{base_dir}/resources/model.h5'

# Output Directory
out_dir = f'{base_dir}/results'
pwm_root_dir = f"{out_dir}/pwms"

In [2]:
# Set variables
# SAE training hyperparameters (passed to SAETrainer / train_all in the training cell)
expansion_factor = 4.0   # SAE latent width as a multiple of base_channels
topk_fraction = 0.05     # passed to SAETopK as k_fraction (presumably fraction of latents kept active)
center_length = 1000     # passed to SAETrainer as center_len
epochs = 3
inner_bs = 16384         # passed to train_all as inner_bs
lr = 1e-3

# Channel width that expansion_factor multiplies.
# NOTE(review): presumably the number of BPNet filters per layer — confirm against the model.
base_channels = 64

# pwm settings
num_samples_per_node = 1000
latent_dim = int(expansion_factor * base_channels)
# Currently equal to latent_dim, i.e. no latent node is excluded from PWM construction.
num_top_nodes = latent_dim

## Create Datasets

In [4]:
# Prepare training data and other information
import numpy
import os
import torch
from bpnetlite.io import PeakGenerator
from scripts import deterministic_data_loaders as ddl
from scripts import models as mds
from scripts import SAE_trainer as st
from scripts import save_activations as sa
from scripts import plot_activation_frequencies as nf
from scripts import create_PWM_for_nodes as pwm

# Seed the global numpy / torch RNGs so that randomness not controlled by an
# explicit random_state (presumably including SAE weight initialisation in the
# training cell) is reproducible under Restart & Run All. The same constant is
# passed to PeakGenerator below, so one value controls all seeding.
random_seed = 12345
numpy.random.seed(random_seed)
torch.manual_seed(random_seed)

# Shuffled training loader over peaks plus negatives; every option is passed
# straight through to bpnetlite's PeakGenerator.
training_data = PeakGenerator(
    peaks = peaks_bed,
    negatives = negatives_bed,
    sequences = seqs,
    signals = [signal_plus, signal_minus],
    controls = [ctl_plus, ctl_minus],
    chroms = None,
    in_window = 2114,
    out_window = 1000,
    max_jitter = 128,
    negative_ratio = 0.33,
    reverse_complement = True,
    shuffle = True,
    min_counts = None,
    max_counts = None,
    summits = False,
    exclusion_lists = None,
    random_state = random_seed,
    pin_memory = True,
    num_workers = 0,
    batch_size = 64,
    verbose = True
)

# Deterministic (non-shuffled) loader used to push every peak, then every
# negative, through the model so the captured SAE activations line up with the
# row order of the two BED files.
sae_testing_data = ddl.DeterministicPeakGenerator(
    # inputs — order matters: peaks first, then negatives
    peaks = [peaks_bed, negatives_bed],
    sequences = seqs,
    signals = [signal_plus, signal_minus],
    chroms = None,
    # window sizes, matching the training loader
    in_window = 2114,
    out_window = 1000,
    # loader options
    batch_size = 64,
    pin_memory = True,
    verbose = True
)


Loading Loci:   0%|          | 0/12845 [00:00<?, ?it/s]

Loading Loci: 100%|██████████| 12845/12845 [00:08<00:00, 1557.97it/s]


## Train SAEs and Annotate Their Nodes with PWMs

In [None]:
# Output locations shared by every step in this cell, defined once so the
# training, loading and PWM steps cannot drift apart.
models_dir = f"{out_dir}/models"
activations_dir = f"{out_dir}/activations"

# Instantiate the trainer object
trainer = st.SAETrainer(model_path=model_path, device="cuda", center_len=center_length)

# Train an SAE for each layer on all training data with the configured hyperparameters
trainer.train_all(
    train_loader=training_data,
    sae_cls=mds.SAETopK,
    sae_kwargs={"latent_multiplier": expansion_factor, "k_fraction": topk_fraction},
    save_dir=models_dir,
    logs_dir=f"{models_dir}/logs",
    epochs=epochs,
    inner_bs=inner_bs,
    lr=lr,
    log_every=50,
)

# Load the trained SAEs that trainer.train_all() wrote to models_dir
sae_models = mds.load_saes_from_dir(save_dir=models_dir, layers=trainer.layers, device=trainer.device)

# Run Top-K collection over the deterministic data loader and write the results to activations_dir
meta = sa.collect_topk_indices_to_disk_from_trainer(
    trainer=trainer,
    sae_models=sae_models,
    loader=sae_testing_data,
    out_dir=activations_dir,
)

# Plot node activation frequencies. Reuse latent_dim from the config cell
# rather than recomputing 64 * expansion_factor here.
nf.plot_node_activation_frequencies(
    num_layers=len(trainer.layers),
    latent_dim=latent_dim,
    data_dir=activations_dir,
)

# Build PWMs for the top nodes of every layer from the saved activations;
# results are written under pwm_root_dir.
pwm.compute_pwms_for_all_layers(
    trainer=trainer,
    loader=sae_testing_data,
    activations_dir=activations_dir,
    pwm_root_dir=pwm_root_dir,
    latent_dim=latent_dim,
    num_top_nodes=num_top_nodes,
    num_samples_per_node=num_samples_per_node
)

# Plots

## Plot nodes of multiple positions in one layer

In [None]:
from scripts import plot_position_nodes as plotpos

# Layer whose per-position node activations and PWM logos are visualized
layer_to_view = 8
# Derived from pwm_root_dir (config cell) so this cell follows any change to
# where compute_pwms_for_all_layers writes its output.
pwm_dir_for_layer = f"{pwm_root_dir}/layer{layer_to_view}"

viz = plotpos.SequenceNodeVisualizer(
    activations_dir=f"{out_dir}/activations",
    pwm_dir=pwm_dir_for_layer,
    loader=sae_testing_data,
)

viz.plot_sample(sample_idx=0, layer_idx=layer_to_view, top_n_nodes=50)
# Optional zoom-ins on specific position ranges of the sample:
# viz.show_position_range_logos(start_pos=160, end_pos=170, threshold=0.0)
# viz.show_position_range_logos(start_pos=355, end_pos=365, threshold=0.0)
# viz.show_position_range_logos(start_pos=490, end_pos=501, threshold=0.0)

## Plot top nodes of multiple positions over all layers

In [None]:
# Show the top nodes at positions 160-170 of sample 0 across every layer.
# PWMs are read from pwm_root_dir (config cell) so this stays consistent with
# where compute_pwms_for_all_layers wrote them.
viz.show_position_range_all_layers(
    sample_idx=0,
    start_pos=160,
    end_pos=170,
    pwm_root_dir=pwm_root_dir,
    threshold=0.0,
    max_nodes_per_pos=3
)

# Variant Testing

Compare marginalizing (inserting) a p53 motif into a background sequence with p53 motifs in their natural sequence context.

Take a p53 motif from one sample and mutate each base pair to see how the active nodes and their magnitudes respond:
- characterize the diversity of nodes and magnitudes activated by each variant
- i.e., are these changes encoded as magnitude differences in the later layers, and are they reflected more in which nodes are active (node index) in the earlier layers?

In [None]:
# Take one canonical p53 motif from the samples (e.g. CATG CCCGGG ATG).
# Substitute each of the 4 bases at every motif position (planned as 16 x 4 inputs;
# note the 13-bp example above would give 13 x 4 = 52 inputs, reference base included).
# For each input, track the active node(s) and their magnitudes at every motif position.
# Also track differences at the flanking positions.

# Then take a background sequence and marginalize (insert) a p53 motif into it.
# Test whether node usage and magnitudes differ between the marginalized motif and samples that contain the motif natively.

# I couldn't find any example of the canonical motif outside an ERV, which is interesting in itself.
# Candidate regions to use: chr3:32364186-32364542
# chr8:71,530,288-71,531,938 (specific sub-region of the range below)
# chr8:68,226,153-72,600,681 (full range)

In [None]:
# After doing this and understanding the results in a controlled setting,
# repeat the same analysis with ChromBPNet on the sequences used for the variant-effect analysis.
# Then, based on node usage, try to assess which sequences might be more biologically meaningful.

# Transposon Testing

In [None]:
# p53 binding sites come from ERV events — can our network detect which of these sites are ERVs?
# Do these sites look more like marginalized sequences or like something else? (TODO: refine this question.)
# Could these be mapping errors?