# Introduction

This notebook is a guide to training custom Multiplexer models. The following sections demonstrate how the BelugaMultiplexer model was trained and include starter code to generate data, define model parameters, and perform back-propagation. Since models may vary in input and output dimensionality, many parameters are left for users to fill in according to the size of their own model.

# Model Training

For dimensions to match the pre-written code in both the training notebook and the command line interface tool, the user's models must follow this dimension format: 

For the user's **Base Model**: \
--inputs must be of shape `[batch_size, 4, sequence_length]` \
--the output must be of shape `[batch_size, predicted_features]` 

If the user's base model is already trained and does not match this input/output format (for example, it may include an extra dimension), it is recommended that users adjust the dimensions of the input/output in the forward method of their model.
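As an alternative to editing the model's own `forward`, the adjustment can live in a thin wrapper. The following is a minimal sketch, assuming a hypothetical base model whose input carries an extra singleton dimension (`[batch_size, 4, 1, sequence_length]`) and whose output is `[batch_size, 1, predicted_features]`. `ToyBase` and `ShapeAdapter` are illustrative names and are not part of this repository.

```python
import torch
from torch import nn

class ToyBase(nn.Module):
    """Hypothetical pre-trained model expecting [batch, 4, 1, seq_len] inputs."""
    def __init__(self, seq_len=20, n_features=3):
        super().__init__()
        self.linear = nn.Linear(4 * seq_len, n_features)

    def forward(self, x):
        # x: [batch, 4, 1, seq_len] -> output: [batch, 1, n_features]
        return torch.sigmoid(self.linear(x.flatten(1))).unsqueeze(1)

class ShapeAdapter(nn.Module):
    """Wraps a model so it accepts [batch, 4, seq_len] and returns [batch, n_features]."""
    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, x):
        x = x.unsqueeze(2)       # add the extra dimension the base model expects
        out = self.base(x)
        return out.squeeze(1)    # drop the extra output dimension

adapted = ShapeAdapter(ToyBase())
out = adapted(torch.zeros(5, 4, 20))
print(out.shape)
```

The wrapped model can then be passed wherever this notebook expects a base model, since its input and output shapes follow the required format.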

For the user's **Multiplexer Model**: \
--inputs must be in the shape `[batch_size, 4, sequence_length]` \
--the output must be in the shape `[batch_size, predicted_features, 4, sequence_length]` 

where: \
--`batch_size` is the number of sequences in the same tensor. \
--`sequence_length` is the number of base pairs in a sequence. \
--`predicted_features` is the number of predictions the base model makes for one sequence.
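A quick way to confirm these shapes before training is to push a dummy batch through both models. The sketch below uses two toy stand-ins (`ToyBase` and `ToyMultiplexer` are hypothetical and much smaller than the Beluga models); swapping in your own models should leave the two assertions unchanged.

```python
import torch
from torch import nn

batch_size, sequence_length, predicted_features = 2, 50, 7

class ToyBase(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(4, predicted_features, kernel_size=5, padding=2)

    def forward(self, x):
        # average over positions -> one value per feature
        return torch.sigmoid(self.conv(x).mean(dim=2))

class ToyMultiplexer(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(4, predicted_features * 4, kernel_size=5, padding=2)

    def forward(self, x):
        out = self.conv(x)  # [batch, predicted_features*4, seq_len]
        return out.reshape(x.shape[0], predicted_features, 4, x.shape[2])

x = torch.zeros(batch_size, 4, sequence_length)
base_out = ToyBase()(x)
mux_out = ToyMultiplexer()(x)

assert base_out.shape == (batch_size, predicted_features)
assert mux_out.shape == (batch_size, predicted_features, 4, sequence_length)
```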

#### 1A) Import packages

In [1]:
import math
import pyfasta
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np

import sys
sys.path.append('../models')
from BelugaMultiplexer import BelugaMultiplexer
from Beluga import Beluga

#### 1B) Generate Data

The following methods were used to generate training data for Beluga Multiplexer. Given a position on a chromosome, `all_mutations` extracts the 2,000 base-pair reference sequence centered there and builds the set of alternative sequences that represent every single-position mutation (four per position, one of which reproduces the reference base). Each base pair is one-hot encoded (e.g. 'A' = [1,0,0,0]), so a 4x2000 tensor represents the 2,000 base-pair reference sequence and an 8000x4x2000 tensor represents the alternative sequences. The sequence length is an adjustable parameter; it was set to 2000 for Beluga Multiplexer training.
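The layout of the two tensors is easier to see on a three-base sequence. This is a minimal sketch that mirrors the encoding and the `4*j + k` ordering used by `all_mutations`, but takes the sequence as a string instead of reading it from the genome; `toy_all_mutations` is a hypothetical helper written for illustration only.

```python
import torch

# Channel order follows the notebook's encoding: A, G, C, T.
ENCODING = {'A': [1, 0, 0, 0], 'G': [0, 1, 0, 0], 'C': [0, 0, 1, 0], 'T': [0, 0, 0, 1]}
ORDER = ['A', 'G', 'C', 'T']

def toy_all_mutations(seq):
    length = len(seq)
    # one column per base pair -> (4, length)
    ref = torch.tensor([ENCODING[b] for b in seq], dtype=torch.float32).T
    # start from 4*length copies of the reference -> (4*length, 4, length)
    muts = ref.repeat(4 * length, 1, 1)
    for j in range(length):
        for k, base in enumerate(ORDER):
            # layer 4*j + k replaces position j with nucleotide k
            muts[4 * j + k, :, j] = torch.tensor(ENCODING[base], dtype=torch.float32)
    return ref, muts

ref, muts = toy_all_mutations("ACG")
print(ref.shape, muts.shape)
```

Each layer differs from the reference in at most one column, and layer `4*j + k` is identical to the reference whenever nucleotide `k` is already the base at position `j`.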

In [2]:
def all_mutations(pos, chrome_num, length):
    """
    returns an encoded sequence and every possible mutation
    
    Args:
        pos : int
            Center of the sequence
        
        chrome_num : string
            The chromosome the sequence is sampled from
        
        length : int
            The number of base-pairs sampled
    
    Returns:
        encoded_ref: (4,length) sized encoding of the sequence drawn from 'chrome_num' centered at 'pos'
        
        mutations: (4*length, 4, length) sized encoding representing all possible mutations of the reference allele

    """
    
    if length % 2 == 0:
        lower = int(length/2) - 1
        upper = int(length/2)
        
    else:
        lower = int(length/2)
        upper = int(length/2)
    
    seq = genome.sequence({'chr': chrome_num, 'start': pos - lower , 'stop': pos + upper})


    #encode the sequence
    mydict = {'A': torch.tensor([1, 0, 0, 0]), 'G': torch.tensor([0, 1, 0, 0]),
            'C': torch.tensor([0, 0, 1, 0]), 'T': torch.tensor([0, 0, 0, 1]),
            'N': torch.tensor([0, 0, 0, 0]), 'H': torch.tensor([0, 0, 0, 0]),
            'a': torch.tensor([1, 0, 0, 0]), 'g': torch.tensor([0, 1, 0, 0]),
            'c': torch.tensor([0, 0, 1, 0]), 't': torch.tensor([0, 0, 0, 1]),
            'n': torch.tensor([0, 0, 0, 0]), '-': torch.tensor([0, 0, 0, 0])}
    
    
    #this dictionary returns a list of possible mutations for each nucleotide
    mutation_dict = {'a': ['a','g', 'c', 't'], 'A':['a','g', 'c', 't'],
                    'c': ['a','g', 'c', 't'], 'C':['a','g', 'c', 't'],
                    'g': ['a','g', 'c', 't'], 'G':['a','g', 'c', 't'],
                    't': ['a','g', 'c', 't'], 'T':['a','g', 'c', 't'],
                    'n': ['n', 'n', 'n', 'n'], 'N':['n', 'n', 'n', 'n'],
                    '-': ['n', 'n', 'n', 'n']}
    
    #each column is the encoding for each nucleotide in the original seq
    encoded_ref = torch.zeros((4, len(seq)))
    for i in range(len(seq)):
        #this implements the encoding
        encoded_ref[:,i] = mydict[seq[i]]

    
    mutations = torch.tile(encoded_ref, (length*4, 1, 1)) 
    
    for j in range(len(seq)):
        #for each element in the original sequence, create 4 "mutation layers" 
        i = j*4
        mutations[i, :, j] = mydict[mutation_dict[seq[j]][0]]
        mutations[i + 1, :, j] = mydict[mutation_dict[seq[j]][1]]
        mutations[i + 2, :, j] = mydict[mutation_dict[seq[j]][2]]
        mutations[i + 3, :, j] = mydict[mutation_dict[seq[j]][3]]
        
        
    return encoded_ref, mutations

The training data was then generated by taking both the reference and alternative sequences and passing them through the Beluga model in batches. We then calculated a tensor that represented the log-fold difference between these predictions, and used it as our training target: given a reference, we wanted to train the BelugaMultiplexer model to predict the log-fold difference between the reference and all possible alternative predictions made by Beluga.
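As a small worked example of this target, the sketch below applies the same formula as `log_fold` to hand-picked probabilities (the values are made up for illustration). An unchanged prediction gives a target of 0, and for probabilities away from 0 and 1 the result is approximately the difference of the two logits.

```python
import torch

def log_fold(alt, ref, eps=1e-6):
    # same formula as in the notebook; eps keeps the ratio finite at 0 and 1
    top = (alt + eps) * (1 - ref + eps)
    bot = (1 - alt + eps) * (ref + eps)
    return torch.log(top / bot)

ref = torch.tensor([[0.2, 0.5]])     # reference prediction, shape (1, features)
alt = torch.tensor([[0.2, 0.5],      # alternative with unchanged predictions
                    [0.8, 0.5]])     # alternative that raises the first feature
target = log_fold(alt, ref)          # ref broadcasts over the alternatives

logit_difference = torch.logit(alt) - torch.logit(ref)
print(target)
print(logit_difference)
```

Raising a prediction from 0.2 to 0.8 gives a target of roughly 2 * log(4), about 2.77, while the unchanged feature stays at 0.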

In this notebook, `training_data` draws sequences from every chromosome except 'chr8' and 'chr9', while `validation_data` draws DNA sequences only from 'chr8' and 'chr9'.

The methods `training_data` and `validation_data` each create a single sample, while `gen_training_data` and `gen_validation_data` let users specify how many samples to generate.
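Both sampling methods pick a chromosome with probability proportional to its length, so longer chromosomes contribute more sequences. A minimal sketch of that step, using made-up chromosome names and lengths rather than real hg19 values:

```python
import numpy as np

# Hypothetical chromosome lengths, for illustration only.
toy_lengths = {'chrA': 300, 'chrB': 100}

names = list(toy_lengths)
sizes = np.array([toy_lengths[n] for n in names], dtype=float)
probs = sizes / sizes.sum()            # normalized lengths sum to 1

rng = np.random.default_rng(0)
draws = rng.choice(names, size=10_000, p=probs)
frac_A = np.mean(draws == 'chrA')      # empirical share of draws from chrA
print(probs, frac_A)
```

With these lengths, about three quarters of the draws come from `chrA`.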

In [3]:
def log_fold(alt, ref):
    """
    Returns the log fold change of alt relative to ref
    
    equals: log(((alt+1e-6) * (1-ref+1e-6)) / ((1-alt+1e-6) * (ref+1e-6)))
    """
    
    e = 10**(-6)
    top = (alt + e)*(1 - ref + e)
    bot = (1 - alt + e) * (ref + e)
    
    return torch.log(top/bot)


    
def training_data(length, training_CHRS, model, batch_size = 64, device = 'cuda'):
    """
    generates 1 training sample. The input is a randomly sampled sequence and the target is the log-fold 
    change between the model's predictions for every alternative sequence and for the reference. 
    
    Args:
        length : int
            The number of base-pairs in one generated sequence
        
        training_CHRS : array
            Set of chromosomes to be considered for training
        
        model : Neural Network Object
            The model that makes the predictions
        
        batch_size : int
            Size of batch that model takes in
        
        device: 'cpu' or 'cuda'
    
    
    Returns:
        1 training sample (input, target)
    
    """
  
    #the base model is only used for inference here
    model.eval()


    #Sample chromosome with probability proportional to its length
    size = np.array([len(genome[i]) for i in training_CHRS])
    size_normalized = size/np.sum(size)
    training_chromosome = np.random.choice(training_CHRS, p = size_normalized)


    #Reject samples with too many 'N' values
    N_count = 11
    while N_count > 10:
        if length % 2 == 0:
            lower = int(length/2) - 1
            upper = int(length/2)
        
        else:
            lower = int(length/2)
            upper = int(length/2)
    
        pos = np.random.randint(lower, len(genome[training_chromosome]) - upper)
        seq = genome.sequence({'chr': training_chromosome, 'start': pos - lower , 'stop': pos + upper})
        
        #accept the sequence only if it contains at most 10 N's
        N_count = seq.count("N") 
            

    ref_arr_encoded, alt_arr_encoded = all_mutations(pos, training_chromosome, length)
 
    with torch.no_grad():
        reference_pred = model.forward(ref_arr_encoded.unsqueeze(0).float().to(device))
    
 
    alt_pred_arr = []
    #math.ceil keeps the final partial batch when length*4 is not a multiple of batch_size
    for i in range(math.ceil(length*4/batch_size)):
        inputs = alt_arr_encoded[i*batch_size : (i+1)*batch_size].to(device).float()
        with torch.no_grad():
            alt_pred_arr.append(model.forward(inputs))        
    alt_predictions = torch.vstack(alt_pred_arr)

    

    return ref_arr_encoded, log_fold(alt_predictions, reference_pred)




def validation_data(length, validation_CHRS,  model, batch_size = 16, device = 'cuda'):
    """
    generates 1 validation sample. The input is a randomly sampled sequence and the target is the log-fold
    change between the model's predictions for every alternative sequence and for the reference. 
    
    Args:
        length : int
            The number of base-pairs in one generated sequence
        
        validation_CHRS : array
            Set of chromosomes to be considered for validation testing
        
        model : Neural Network Object
            The model that makes the predictions
        
        batch_size : int
            Size of batch that model takes in
        
        device: 'cpu' or 'cuda'
    
    
    Returns: 2 torch.tensor 
        1 validation sample (input, target)
    
    """

    #the base model is only used for inference here
    model.eval()
 
    #Sample chromosome with probability proportional to its length
    val_probability = np.array([len(genome[i]) for i in validation_CHRS])
    val_prob_normalized = val_probability/np.sum(val_probability)
    val_chromosome = np.random.choice(validation_CHRS, p = val_prob_normalized)
    
    
    #Reject samples with too many 'N' values
    N_count = 11
    while N_count > 10:
        if length % 2 == 0:
            lower = int(length/2) - 1
            upper = int(length/2)
        
        else:
            lower = int(length/2)
            upper = int(length/2)
    
        pos = np.random.randint(lower, len(genome[val_chromosome]) - upper)
        seq = genome.sequence({'chr': val_chromosome, 'start': pos - lower , 'stop': pos + upper})
        
        #accept the sequence only if it contains at most 10 N's
        N_count = seq.count("N")  
            
    ref_arr_encoded, alt_arr_encoded = all_mutations(pos, val_chromosome, length)
    
    
    ref_input = ref_arr_encoded.unsqueeze(0).to(device).float()
    with torch.no_grad():
        reference = model.forward(ref_input)

    alt_pred_arr = []
    #math.ceil keeps the final partial batch when length*4 is not a multiple of batch_size
    for i in range(math.ceil(length*4/batch_size)):
        inputs = alt_arr_encoded[i*batch_size : (i+1)*batch_size].to(device).float()
        with torch.no_grad():
            alt_pred_arr.append(model.forward(inputs))        
    alt_predictions = torch.vstack(alt_pred_arr)


    return ref_arr_encoded, log_fold( alt_predictions, reference)
    


`gen_training_data` and `gen_validation_data` generate multiple samples of training data and validation data, respectively. Within these methods, users can specify the number of sequences to generate, the length of each sequence, the original model the multiplexer is trained on, and the dimension of the original model's output. Each method returns two tensors, the inputs and the targets, which together hold `num_seqs` (input, target) pairs that can be used directly to train or validate the multiplexer model.
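The stacking these methods perform, and the memory it implies, can be sketched with random tensors in place of real samples. `fake_sample` below is a hypothetical stand-in for `training_data`; only the shapes match the notebook.

```python
import torch

num_seqs, length, model_output_dim = 3, 10, 5

def fake_sample():
    """Hypothetical stand-in for training_data: random tensors with the right shapes."""
    return torch.rand(4, length), torch.rand(4 * length, model_output_dim)

inputs = torch.zeros((num_seqs, 4, length))
targets = torch.zeros((num_seqs, 4 * length, model_output_dim))
for i in range(num_seqs):
    inputs[i], targets[i] = fake_sample()

# Size of one float32 target at the notebook's settings (length 2000, 2002 features):
# 4*2000 alternatives x 2002 features x 4 bytes
bytes_per_target = 4 * 2000 * 2002 * 4
print(bytes_per_target / 1e6, "MB")
```

At the Beluga settings each target takes about 64 MB, so `num_seqs` has a direct effect on how much memory one batch of targets needs.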

In [4]:
def gen_training_data(num_seqs, length, training_CHRS, model, model_output_dim, batch_size = 16, device = 'cuda'):
    """
    Generates num_seqs # of training samples by calling training_data
    
    Args:
        num_seqs : int
            The number of training samples generated by the method
        
        length : int
            The number of base-pairs in one generated sequence
        
        training_CHRS : array
            Set of chromosomes to be considered for training
        
        model : Neural Network Object
            The model that makes the predictions
        
        model_output_dim: int
            The dimension of the output prediction made by the model
        
        batch_size: int
            batch_size used when generating training target
        
        device: 'cpu' or 'cuda'
        
    Returns:
        training_input_arr: an array of model inputs used as training data
        
        target_arr: an array of targets used for training data
    
    """
    training_input_arr = torch.zeros((num_seqs, 4, length))
    target_arr = torch.zeros((num_seqs, length*4, model_output_dim))
    for i in range(num_seqs):
        training_input, target = training_data(length, training_CHRS, model, batch_size, device)
        training_input_arr[i, :, :] = training_input
        target_arr[i, :, :] = target
        
        
    return training_input_arr.float(), target_arr.float()
    

    
    
def gen_validation_data(num_seqs, length, validation_CHRS, model, model_output_dim, batch_size = 16, device = 'cuda'):
    """
    Generates num_seqs # of validation samples by calling validation_data
    
    Args:
        num_seqs: int
            The number of validation samples generated by the method
        
        length : int
            The number of base-pairs in one generated sequence
        
        validation_CHRS : array
            Set of chromosomes to be considered for validation
        
        model : Neural Network Object
            The model that makes the predictions
        
        model_output_dim : int
            The dimension of the output prediction made by the model
        
        batch_size : int
            batch_size used when generating validation target
        
        device: 'cpu' or 'cuda'
        
    Returns:
        val_input_arr : torch.tensor
            An array of model inputs
        
        val_target_arr : torch.tensor
            An array of labels for the model input data
    
    """
    val_input_arr = torch.zeros((num_seqs, 4, length))
    val_target_arr = torch.zeros((num_seqs, length*4, model_output_dim))
    
    for i in range(num_seqs):
        val_input, target = validation_data(length, validation_CHRS, model, batch_size, device)
        val_input_arr[i, :, :] = val_input
        val_target_arr[i, :, :] = target
        
    return val_input_arr.float(), val_target_arr.float()
  


class Multiplexer_Data(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.length = x.shape[0]
        
    def __getitem__(self, index):
        sample = self.x[index,:,:], self.y[index, :, :]
        return sample
    
    def __len__(self):
        return self.length
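A short usage sketch of this kind of dataset with a `DataLoader`, using zero tensors in place of generated data. `PairDataset` is a hypothetical re-statement of `Multiplexer_Data` so that the example runs on its own.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    """Same idea as Multiplexer_Data: index i returns the i-th (input, target) pair."""
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return self.x.shape[0]

x = torch.zeros(8, 4, 10)    # 8 sequences of length 10
y = torch.zeros(8, 40, 5)    # 4*10 alternatives, 5 features each
loader = DataLoader(PairDataset(x, y), batch_size=2)

# the loader groups samples along a new leading batch dimension
batch_shapes = [(tuple(xb.shape), tuple(yb.shape)) for xb, yb in loader]
print(batch_shapes[0], len(batch_shapes))
```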


#### 1C) Model Training

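The model was trained using the Adam optimizer and MSE loss. Each epoch here is one optimizer step on a freshly generated batch of `num_seqs` sequences. Validation is run every 200 epochs; if the validation loss improves, the model and optimizer parameters can be saved (the save calls are commented out in the code). 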
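The validate-and-save pattern can be tried in isolation on a toy problem. This is a minimal sketch, assuming a tiny linear model and random data in place of the multiplexer and Beluga targets; the checkpoint file name and the validation interval of 5 are arbitrary choices for the example.

```python
import os
import tempfile
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)                        # toy stand-in for the multiplexer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_function = nn.MSELoss()

x_val, y_val = torch.randn(16, 3), torch.randn(16, 1)
ckpt_path = os.path.join(tempfile.mkdtemp(), "toy_params.pth")  # hypothetical file name
validate_every = 5                             # the notebook uses 200
lowest_loss = float("inf")

for step in range(20):
    model.train()
    optimizer.zero_grad()
    x, y = torch.randn(8, 3), torch.randn(8, 1)    # fresh batch every step
    loss = loss_function(model(x), y)
    loss.backward()
    optimizer.step()

    if step % validate_every == 0:
        model.eval()
        with torch.no_grad():
            val_loss = loss_function(model(x_val), y_val).item()
        if val_loss < lowest_loss:             # keep only the best parameters so far
            lowest_loss = val_loss
            torch.save(model.state_dict(), ckpt_path)
```

Storing the validation loss as a Python float (via `.item()`) keeps the comparison with `lowest_loss` independent of the device the model runs on.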

In [5]:
def train( val_data, Multiplexer_model, Student_model, length, training_CHRS, model_output_dim,
          optimizer , loss_function, epochs = 10000, num_seqs = 16, batch_size = 16, device = 'cuda'):
    """
    Trains the Multiplexer Model
    
    Args:
        val_data : torch.utils.data.dataloader.DataLoader object
            Validation data
        
        Multiplexer_model: Neural Network object
            A multiplexer model that learns to predict the Student_model's response to every mutation
        
        Student_model: Neural Network object
            The base model (e.g. Beluga) that the multiplexer model learns from
        
        length : int
            The number of base-pairs in one generated sequence
            
        
        training_CHRS : array
            Set of chromosomes to be considered for training
        
        model_output_dim : int
            The dimension of the output prediction made by the model
        
        optimizer : pytorch optimizer object
            The optimizer function used to train the Multiplexer_model
        
        loss_function : pytorch loss function object
            The loss function used to train the Multiplexer_model
        
        epochs : int
            Number of epochs the Multiplexer is trained on
        
        num_seqs : int
            The number of training samples generated by the method
        
        batch_size: int
            Batch_size used when generating training target from the Student model
        
        device: 'cpu' or 'cuda'
    
    
    """
    
    lowest_loss = float("inf")
    training_loss = 0
    
    
    for epoch in range(epochs):      
        Multiplexer_model.train()
        optimizer.zero_grad()
        training_loss = 0

        x,y = gen_training_data(num_seqs, length, training_CHRS, Student_model, model_output_dim, batch_size, device)
        y = y.transpose(1,2)


        yhat = Multiplexer_model.forward(x.to(device))
        y = torch.reshape(y , (yhat.shape[0], model_output_dim, 4, length)).to(device)
        
        training_loss = loss_function(yhat, y)
        

        #Update params
        training_loss.backward()
        optimizer.step()
            
        if epoch % 200 == 0:
            print("Training loss on Epoch ", epoch, "is ", training_loss.item())

            
        #validation test    
        if epoch % 200 == 0:
            Multiplexer_model.eval()
            validation_loss = 0
            
            for x,y in val_data:
                with torch.no_grad():
                    #match the layout used for the training targets above
                    y = y.transpose(1,2)

                    yhat = Multiplexer_model.forward(x.to(device))
                    y = torch.reshape(y , (yhat.shape[0], model_output_dim, 4, length)).to(device)
                    validation_loss += loss_function(yhat, y) 
                    
            print("Validation loss on epoch ", epoch, "is ", validation_loss.item())
            if validation_loss < lowest_loss:
                lowest_loss = validation_loss
                ###Uncomment the save methods to save the state_dict and optimizer
                #torch.save(Multiplexer_model.state_dict(), "Multiplexer_params.pth")
                #torch.save(optimizer.state_dict(), "Multiplexer_optim.pth")
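The transpose and reshape applied to the training target in `train` can be traced on a toy tensor. In the sketch below every entry of the target stores its own index along the `4*length` axis, so the position it ends up in can be read back after the two operations. The sizes are made up for illustration.

```python
import torch

num_seqs, length, model_output_dim = 1, 3, 2

# Target laid out as in gen_training_data: (num_seqs, 4*length, model_output_dim).
# Each entry holds its flat index along the 4*length axis.
flat_index = torch.arange(4 * length, dtype=torch.float32)
y = flat_index.view(1, 4 * length, 1).expand(num_seqs, 4 * length, model_output_dim)

y = y.transpose(1, 2)                                          # (num_seqs, model_output_dim, 4*length)
y = torch.reshape(y, (num_seqs, model_output_dim, 4, length))  # (num_seqs, model_output_dim, 4, length)

# reshape splits the 4*length axis in row-major order:
# flat index i ends up at [i // length, i % length] of the last two axes
print(y[0, 0])
```

This ordering is worth keeping in mind when mapping entries of the multiplexer output back to individual alternative sequences.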
               

#### 1D) Final Components

After running all of the code above, load in the Genome and Beluga parameters, generate a set of validation data, and start training the model!

In [None]:
genome = pyfasta.Fasta('../../data/hg19.fa')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


##Load in Base model with trained parameters
Beluga_model = Beluga().to(device)
Beluga_model.load_state_dict(torch.load('../../data/deepsea.beluga.pth', map_location = device))

#define size of validation data
val_num_seqs = 8 #number of sequences to use as validation
model_output_dim = 2002 #number of features predicted by the base model
length = 2000 #length of the model input
batch_size = 2 #batch_size used to create training and validation data

#create training and validation splits
training_CHRS = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7','chr10', 'chr11', 'chr12', 'chr13', 
            'chr14', 'chr15', 'chr16', 'chr17','chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX','chrY']

validation_CHRS = ['chr8', 'chr9']


#generate validation data
val_arr, val_labels_arr = gen_validation_data(val_num_seqs, length, validation_CHRS, Beluga_model, model_output_dim, batch_size, device)
validation_data_obj = Multiplexer_Data(val_arr, val_labels_arr)
#named validation_loader so the validation_data function defined above is not overwritten
validation_loader = DataLoader( validation_data_obj, batch_size)


#Define Multiplexer model and training hyper-parameters
BM = BelugaMultiplexer().to(device)
optimizer = torch.optim.Adam(BM.parameters(), lr = 0.001)  
loss_function = nn.MSELoss()
epochs = 10
num_seqs = 2

print("started training")
train( validation_loader, BM, Beluga_model, length, training_CHRS, model_output_dim,
          optimizer , loss_function, epochs , num_seqs , batch_size , device)




started training


The code above serves as a template for training a Multiplexer model. Many features are adjustable, and it is recommended that you adapt the code to your specific model. 