In [1]:
############################################################
################### 1. General setup #######################
################### Mostly not model specific ##############
############################################################

import sys
sys.path.append('../src')
import argparse
import torch
from SAE_models import get_cfg, TopKSAE, VanillaSAE, JumpReLUSAE, BatchTopKSAE
from SAE_training import SAETraining
from torch.utils.data import DataLoader
import numpy as np
import json

# SpliceAI specific imports
from datasets import LiaoDatasetEmbedded
from SpliceAI import SpliceAI

In [27]:
#Add spliceAI specific arguments

spliceai_args = {
    'k': 32,  # Number of filters in convolutional layers
    'wsets': [11],  # Kernel widths for each MegaBlock
    'dsets': [1],  # Dilation rates for each MegaBlock
    'nt_dims': 4,  # Number of nucleotide dimensions
    'output_dim': 3,  # Number of output dimensions
    'dropout_rate': None,  # Dropout rate
    'block_count': 4,  # Number of MegaBlocks
    'hook_point': 'mb 1',  # Embedding layer to use
    'embedding_dim': 64,  # Embedding dimension
    'input_length': 176,  # Input length
    'positions_to_use': [0,75],  # Positions to use for training
    'csv_path': '../data/Liao_Dataset/liao_training_set.csv',  # Path to training set CSV file
    'plasmid_path': '../data/Liao_Dataset/liao_plasmid.gbk',  # Path to plasmid file
    'test_csv_path': '../data/Liao_Dataset/liao_test_set.csv',  # Path to test set CSV file
    'flanking_len': 6,  # Flanking length
    'add_context_len': True,  # Are positions relative to input length rather than context length
    'auto_find_bc_pos': True,  # Enable automatic finding of BC positions
    'auto_find_ex_pos': True,  # Enable automatic finding of exon positions
    'preload': True,  # Enable data preloading
    'preload_embeddings': True,  # Enable embeddings preloading
    'num_workers': 0,  # Number of workers for data loading
    'spliceai_batch_size': 100,  # Batch size for training
    'h5_file': '../SpliceAI_Models/SpliceNet80_g1.h5',  # Path to SpliceAI model file
    'model_name': 'SpliceAI'  # Name of the model
}

cfg = get_cfg(epochs=10, top_k=10, top_k_aux=50, aux_penalty=1/10, act_size=64, dic_size=128, **spliceai_args)

cfg

{'seed': 49,
 'batch_size': 4096,
 'lr': 0.0003,
 'l1_coeff': 0,
 'beta1': 0.9,
 'beta2': 0.99,
 'max_grad_norm': 100000,
 'dtype': torch.float32,
 'act_size': 64,
 'dict_size': 12288,
 'wandb_project': 'sparse_autoencoders',
 'input_unit_norm': True,
 'perf_log_freq': 1000,
 'sae_type': 'topk',
 'checkpoint_freq': 10000,
 'n_batches_to_dead': 5,
 'warmstart_batches': 1000,
 'warmstart_start_factor': 0.0001,
 'warmstart_end_factor': 1,
 'scheduler': 'RedOnPlateau',
 'weight_decay': 0.0001,
 'reduceLROnPlateau_factor': 0.1,
 'reduceLROnPlateau_patience': 4,
 'reduceLROnPlateau_threshold': 0.0001,
 'reduceLROnPlateau_cooldown': 0,
 'reduceLROnPlateau_min': 0,
 'reduceLROnPlateau_eps': 1e-08,
 'epochs': 10,
 'training_set_batches': 1000,
 'outpath': './out/',
 'min_delta': 0,
 'patience': 10,
 'accelerator': 'auto',
 'devices': 'auto',
 'include_checkpointing': True,
 'include_early_stopping': True,
 'track_LR': True,
 'top_k': 10,
 'top_k_aux': 50,
 'aux_penalty': 0.1,
 'bandwidth': 0.00

In [24]:
trainer = SAETraining(cfg)

Seed set to 49


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [25]:
# Setup the SpliceAI embedding dataset

def one_hot_encode(x):
    """
    Convert DNA sequence to one-hot encoded tensor.
    
    Parameters
    ----------
    x : numpy.ndarray
        Input sequence array
        
    Returns
    -------
    torch.Tensor
        One-hot encoded tensor of shape (4, sequence_length)
    """
    var_ar = x[:,None] == np.array(['A','C','G', 'T'])
    var_ar = var_ar.T
    var_t = torch.Tensor(var_ar).float()
    return(var_t)


spliceai_model = SpliceAI(k=cfg['k'], 
                    wsets=cfg['wsets'], 
                    dsets=cfg['dsets'], 
                    nt_dims=cfg['nt_dims'], 
                    output_dim=cfg['output_dim'], 
                    dropout_rate=cfg['dropout_rate'], 
                    block_count=cfg['block_count'], 
                    embedding_layer=cfg['hook_point'], 
                    embedding_dim=cfg['embedding_dim'], 
                    input_length=cfg['input_length'], 
                    positions_to_use=cfg['positions_to_use'])

spliceai_model.load_from_h5_file(cfg['h5_file'])
spliceai_model.mode = 'embed'

cfg['context_len'] = spliceai_model.cl

if cfg['add_context_len']:
    cfg['positions_to_use'] = [pos+cfg['context_len']//2 for pos in cfg['positions_to_use']]


transform_x = one_hot_encode

full_train_ds = LiaoDatasetEmbedded(cfg['csv_path'], 
                            cfg['plasmid_path'], 
                            cfg['context_len']+cfg['flanking_len'], 
                            spliceai_model, 
                            cfg['auto_find_bc_pos'],
                            auto_find_ex_pos=cfg['auto_find_ex_pos'], 
                            batch_size=cfg['batch_size'], 
                            transform_x=transform_x, 
                            preload=cfg['preload'], 
                            preload_embeddings=cfg['preload_embeddings'], 
                            trainer=trainer.trainer, 
                            num_workers=cfg['num_workers'])
full_train_ds.open()
train_size = int(len(full_train_ds)*0.8)
ids = np.random.permutation(len(full_train_ds))
train_ds = torch.utils.data.dataset.Subset(full_train_ds, ids[:train_size])
train_dl = torch.utils.data.dataloader.DataLoader(train_ds, batch_size=cfg['spliceai_batch_size'], num_workers=cfg['num_workers'])
val_ds = torch.utils.data.dataset.Subset(full_train_ds, ids[train_size:])
val_dl = torch.utils.data.dataloader.DataLoader(val_ds, batch_size=cfg['spliceai_batch_size'], num_workers=cfg['num_workers'])

test_ds = LiaoDatasetEmbedded(cfg['test_csv_path'], 
                            cfg['plasmid_path'], 
                            cfg['context_len']+cfg['flanking_len'], 
                            spliceai_model, 
                            cfg['auto_find_bc_pos'],
                            auto_find_ex_pos=cfg['auto_find_ex_pos'], 
                            batch_size=cfg['batch_size'], 
                            transform_x=transform_x, 
                            preload=cfg['preload'], 
                            preload_embeddings=cfg['preload_embeddings'], 
                            trainer=trainer.trainer, 
                            num_workers=cfg['num_workers'])
test_ds.open()
test_dl = torch.utils.data.dataloader.DataLoader(test_ds, batch_size=cfg['spliceai_batch_size'], num_workers=cfg['num_workers'])
    

'LOCUS       Exported                8612 bp DNA     circular SYN 10-JUN-2020\n'
Found locus 'Exported' size '8612' residue_type 'DNA'
Some fields may be wrong.
/Users/jackdesmarais/miniconda3/envs/architecture_search_env/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/Users/jackdesmarais/miniconda3/envs/architecture_search_env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

'LOCUS       Exported                8612 bp DNA     circular SYN 10-JUN-2020\n'
Found locus 'Exported' size '8612' residue_type 'DNA'
Some fields may be wrong.
/Users/jackdesmarais/miniconda3/envs/architecture_search_env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [26]:
############################################################
################### 5. Training ############################
################### Not Model specific #####################
############################################################
cfg['training_set_batches'] = len(train_dl)


if cfg['sae_type'] == 'topk':
    model = TopKSAE(cfg)
elif cfg['sae_type'] == 'vanilla':
    model = VanillaSAE(cfg)
elif cfg['sae_type'] == 'jumprelu':
    model = JumpReLUSAE(cfg)
elif cfg['sae_type'] == 'batch_topk':
    model = BatchTopKSAE(cfg)

final_model = trainer.train(model, train_dl, val_dl)


############################################################
################### 6. testing/validation ##################
################### Not Model specific #####################
############################################################

val_metrics = trainer.validate(val_dl)

print(val_metrics)
with open(cfg['outpath'] + f"{cfg['name']}_{cfg['seed']}_val_metrics.json", 'w') as f:
    json.dump(val_metrics, f)

test_metrics = trainer.test(test_dl)

print(test_metrics)
with open(cfg['outpath'] + f"{cfg['name']}_{cfg['seed']}_test_metrics.json", 'w') as f:
    json.dump(test_metrics, f)

/Users/jackdesmarais/miniconda3/envs/architecture_search_env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]


  | Name         | Type | Params | Mode
---------------------------------------------
  | other params | n/a  | 1.6 M  | n/a 
---------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.341     Total estimated model params size (MB)
0         Modules in train mode
0         Modules in eval mode


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_aux_loss         0.003313531633466482
       val_l0_norm                 32.0
       val_l1_loss                  0.0
       val_l1_norm          1.1899683475494385
       val_l2_loss         0.0058435541577637196
        val_loss           0.009157088585197926
  val_num_dead_features       11983.41796875
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/jackdesmarais/miniconda3/envs/architecture_search_env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined