In [None]:
############################################################
################### 1. General setup #######################
################### Mostly not model specific ##############
############################################################

import sys
sys.path.append('../src')
import argparse
import torch
from SAE_models import get_cfg, TopKSAE, VanillaSAE, JumpReLUSAE, BatchTopKSAE
from SAE_training import SAETraining
from torch.utils.data import DataLoader
import numpy as np
import json

# SpliceAI specific imports
from datasets import LiaoDatasetEmbedded
from SpliceAI import SpliceAI

In [None]:
# SpliceAI-specific settings. They are handed to get_cfg() below to build the run
# configuration `cfg` that every later cell reads from.
# NOTE(review): later cells also read cfg['batch_size'], cfg['seed'], cfg['name'] and
# cfg['outpath'], none of which are set here — presumably get_cfg() supplies them; confirm.

spliceai_args = {
    # --- SpliceAI network architecture ---
    'k': 32,                     # filters per convolutional layer
    'wsets': [11],               # kernel width(s), one entry per MegaBlock
    'dsets': [1],                # dilation rate(s), one entry per MegaBlock
    'nt_dims': 4,                # nucleotide (input channel) dimensions
    'output_dim': 3,             # output dimensions of the network
    'dropout_rate': None,        # dropout rate (None here)
    'block_count': 4,            # number of MegaBlocks
    'hook_point': 'mb 1',        # layer whose activations serve as embeddings
    'embedding_dim': 32,         # embedding dimension
    'input_length': 176,         # input length
    'positions_to_use': [0, 75], # positions used for training
    # --- Liao dataset files and sequence layout ---
    'csv_path': './data/Liao_Dataset/liao_training_set.csv',   # training-set CSV
    'plasmid_path': './data/Liao_Dataset/liao_plasmid.gbk',    # plasmid file
    'test_csv_path': './data/Liao_Dataset/liao_test_set.csv',  # test-set CSV
    'flanking_len': 6,           # flanking length, added to the context length for the datasets
    'add_context_len': True,     # if True, positions_to_use is later shifted by context_len // 2
    'auto_find_bc_pos': True,    # automatically locate BC positions
    'auto_find_ex_pos': True,    # automatically locate exon positions
    # --- loading / batching ---
    'preload': True,             # preload the data
    'preload_embeddings': True,  # preload the embeddings
    'num_workers': 16,           # workers for the datasets and DataLoaders
    'spliceai_batch_size': 100,  # batch size of the train/val/test DataLoaders
    # --- pretrained model ---
    'h5_file': 'SpliceAI_Models/SpliceNet80_g1.h5',  # SpliceAI weights file
    'model_name': 'SpliceAI',    # name of the model
}

cfg = get_cfg(spliceai_args)

# Show the resulting configuration as the cell output.
cfg

: 

In [None]:
# Build the SAE training wrapper from the run configuration. Later cells pass its
# `.trainer` attribute to the embedded datasets and call its train / validate / test methods.
trainer = SAETraining(cfg)

: 

In [None]:
# Set up the SpliceAI embedding dataset

def one_hot_encode(x):
    """
    One-hot encode a nucleotide sequence over the alphabet A, C, G, T.

    Parameters
    ----------
    x : numpy.ndarray
        1-D array of single-character nucleotide labels.

    Returns
    -------
    torch.Tensor
        Float tensor of shape (4, sequence_length); rows are ordered A, C, G, T.
        A character outside that alphabet yields an all-zero column.
    """
    alphabet = np.array(['A', 'C', 'G', 'T'])
    # Broadcasting gives a (sequence_length, 4) boolean match matrix;
    # transposing puts the nucleotide channel first.
    matches = (x[:, None] == alphabet).T
    return torch.Tensor(matches).float()


# Keep a pristine copy of the user-supplied positions. The shift applied further down
# overwrites cfg['positions_to_use'], so without this snapshot re-running the cell would
# hand already-shifted positions to SpliceAI and shift cfg a second time.
if 'positions_to_use_raw' not in cfg:
    cfg['positions_to_use_raw'] = list(cfg['positions_to_use'])

# Rebuild the SpliceAI network from the architecture settings, load the pretrained
# weights and set its mode to 'embed'.
spliceai_model = SpliceAI(k=cfg['k'],
                          wsets=cfg['wsets'],
                          dsets=cfg['dsets'],
                          nt_dims=cfg['nt_dims'],
                          output_dim=cfg['output_dim'],
                          dropout_rate=cfg['dropout_rate'],
                          block_count=cfg['block_count'],
                          embedding_layer=cfg['hook_point'],
                          embedding_dim=cfg['embedding_dim'],
                          input_length=cfg['input_length'],
                          positions_to_use=cfg['positions_to_use_raw'])

spliceai_model.load_from_h5_file(cfg['h5_file'])
spliceai_model.mode = 'embed'

# The context length is only known once the model exists.
cfg['context_len'] = spliceai_model.cl

# Always derive the (optionally shifted) positions from the pristine copy so the
# result is the same no matter how often this cell is executed.
if cfg['add_context_len']:
    half_context = cfg['context_len'] // 2
    cfg['positions_to_use'] = [pos + half_context for pos in cfg['positions_to_use_raw']]
else:
    cfg['positions_to_use'] = list(cfg['positions_to_use_raw'])


transform_x = one_hot_encode


def _build_embedded_dataset(csv_path):
    """Create and open a LiaoDatasetEmbedded for `csv_path`.

    All other settings (plasmid file, sequence length, SpliceAI model, loader options)
    are shared between the training and test sets and are read from the notebook-level
    `cfg`, `spliceai_model`, `transform_x` and `trainer`.
    """
    dataset = LiaoDatasetEmbedded(csv_path,
                                  cfg['plasmid_path'],
                                  cfg['context_len'] + cfg['flanking_len'],
                                  spliceai_model,
                                  cfg['auto_find_bc_pos'],
                                  auto_find_ex_pos=cfg['auto_find_ex_pos'],
                                  batch_size=cfg['batch_size'],
                                  transform_x=transform_x,
                                  preload=cfg['preload'],
                                  preload_embeddings=cfg['preload_embeddings'],
                                  trainer=trainer.trainer,
                                  num_workers=cfg['num_workers'])
    dataset.open()
    return dataset


def _build_loader(dataset):
    """Wrap `dataset` in a DataLoader using the batch size / worker count from cfg."""
    return DataLoader(dataset,
                      batch_size=cfg['spliceai_batch_size'],
                      num_workers=cfg['num_workers'])


full_train_ds = _build_embedded_dataset(cfg['csv_path'])

# 80/20 train/validation split. A generator seeded from cfg['seed'] makes the split
# reproducible across kernel restarts instead of depending on global RNG state.
train_size = int(len(full_train_ds) * 0.8)
split_rng = np.random.default_rng(cfg['seed'])
ids = split_rng.permutation(len(full_train_ds))

train_ds = torch.utils.data.Subset(full_train_ds, ids[:train_size])
val_ds = torch.utils.data.Subset(full_train_ds, ids[train_size:])

# NOTE(review): DataLoader defaults to shuffle=False, so the training order is the fixed
# permutation above and is not reshuffled between epochs — confirm this is intended.
train_dl = _build_loader(train_ds)
val_dl = _build_loader(val_ds)

test_ds = _build_embedded_dataset(cfg['test_csv_path'])
test_dl = _build_loader(test_ds)
    

: 

In [None]:
############################################################
################### 5. Training ############################
################### Not Model specific #####################
############################################################

def _save_metrics(metrics, split):
    """Write `metrics` as JSON to <outpath>/<name>_<seed>_<split>_metrics.json.

    Parameters
    ----------
    metrics : object
        JSON-serialisable metrics as returned by the trainer.
    split : str
        Label of the evaluated split ('val' or 'test'); becomes part of the file name.

    Returns
    -------
    pathlib.Path
        Location of the written file.
    """
    from pathlib import Path  # local import keeps this cell self-contained

    # Join path components properly so a missing trailing separator in cfg['outpath']
    # cannot glue the directory name onto the file name, and make sure the directory
    # exists so results are not lost after a long training run.
    out_dir = Path(cfg['outpath'])
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"{cfg['name']}_{cfg['seed']}_{split}_metrics.json"
    with open(out_file, 'w') as f:
        json.dump(metrics, f)
    return out_file


# NOTE(review): `model` is not defined in any cell visible in this notebook (the SAE
# classes imported at the top are never instantiated), so on a fresh kernel this line
# raises NameError. The SAE model must be constructed before this cell — TODO confirm
# where that is supposed to happen.
final_model = trainer.train(model, train_dl, val_dl)


############################################################
################### 6. testing/validation ##################
################### Not Model specific #####################
############################################################

val_metrics = trainer.validate(val_dl)

print(val_metrics)
_save_metrics(val_metrics, 'val')

test_metrics = trainer.test(test_dl)

print(test_metrics)
_save_metrics(test_metrics, 'test')