# Preparation

## Import packages

In [1]:
# Inputs – from Nuo
import pickle
import torch
from torch import nn
import numpy as np
from torch.nn import Module
from sklearn import metrics 
import matplotlib.pyplot as plt

## Load training and test datasets

### Full-length transcripts

In [87]:
# Full-length transcripts.
# NOTE: the 80nt cell binds the same two names, so whichever of the two cells
# runs last decides which dataset the rest of the notebook uses.
# pickle.load can execute arbitrary code -- only load files you trust.
with open("proc/training_dataset.pkl", "rb") as train_pkl:
    training_dataset = pickle.load(train_pkl)
with open("proc/test_dataset.pkl", "rb") as test_pkl:
    test_dataset = pickle.load(test_pkl)

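### 80nt transcripts

Run either this cell or the full-length cell above, not both. Both bind `training_dataset` and `test_dataset`, so whichever ran last is what the rest of the notebook uses.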

In [2]:
# 80nt transcripts.
# NOTE: the full-length cell binds the same two names, so whichever of the two
# cells runs last decides which dataset the rest of the notebook uses.
# pickle.load can execute arbitrary code -- only load files you trust.
with open("proc/training_dataset_80nt.pkl", "rb") as train_pkl:
    training_dataset = pickle.load(train_pkl)
with open("proc/test_dataset_80nt.pkl", "rb") as test_pkl:
    test_dataset = pickle.load(test_pkl)

### Define NtDataset

In [7]:
# Dataset – from Hannah and Lucas

class NtDataset(torch.utils.data.Dataset):
    """Nucleotide sequence + splice-site label dataset.

    Wraps an indexable collection whose items are ``(nt_seq, strengths)``
    pairs. Each item is returned as a pair of int64 tensors:
    ``(token_ids, labels)``, where ``token_ids[i]`` is the id of
    ``nt_seq[i]`` according to ``self.map``.

    Subclassing ``torch.utils.data.Dataset`` makes it usable with
    ``torch.utils.data.DataLoader``. Plain ``for x in dataset`` iteration
    still works and stops at the ``IndexError`` raised by the underlying
    collection.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Nucleotide -> token id. Four tokens, so an embedding needs vocab_size=4.
        self.map = {'A': 0, 'G': 1, 'C': 2, 'T': 3}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(token_ids, labels)`` for sample ``idx``.

        Raises:
            IndexError: propagated from the wrapped collection (ends iteration).
            KeyError: if the sequence contains a character not in ``self.map``
                (e.g. lowercase letters or 'N').
        """
        # Index outside the try so IndexError is not confused with a bad nucleotide.
        sample = self.dataset[idx]
        nt_seq, strengths = sample[0], sample[1]

        try:
            token_ids = [self.map[nt] for nt in nt_seq]
        except KeyError as err:
            raise KeyError(
                f"unsupported nucleotide {err.args[0]!r} in sample {idx}; "
                f"expected one of {sorted(self.map)}"
            ) from err

        # Explicit dtype: torch.tensor([]) would otherwise be float32, which
        # nn.Embedding rejects.
        tokens = torch.tensor(token_ids, dtype=torch.long)
        labels = torch.tensor(strengths).long()
        return tokens, labels

## Model

In [8]:
class SpliceFormer(nn.Module):
    """Transformer encoder that predicts a splice-site class for every token.

    Token ids are embedded, passed through a stack of
    ``nn.TransformerEncoderLayer`` and projected to ``n_classes`` logits per
    position.

    Args:
        vocab_size: number of distinct token ids (4 for the A/G/C/T map in
            ``NtDataset``).
        model_dim: embedding / hidden size (``d_model``); must be divisible by
            ``n_attn_heads``.
        n_attn_heads: attention heads per encoder layer.
        n_encoder_layers: number of stacked encoder layers.
        hidden_act: activation for the feed-forward sublayer, e.g.
            ``nn.ReLU()`` (``nn.TransformerEncoderLayer`` also accepts
            ``"relu"`` / ``"gelu"``).
        dropout: dropout probability inside each encoder layer.
        dim_feedforward: width of the feed-forward sublayer; ``None`` (default)
            uses ``model_dim`` as before.
        n_classes: output classes per position (default 3).
        use_positional_encoding: add fixed sinusoidal position encodings to
            the embeddings. Off by default to keep the original behaviour.
            Without it the encoder has no notion of position: permuting the
            input only permutes the output.
    """

    def __init__(
        self,
        vocab_size: int,
        model_dim: int,
        n_attn_heads: int,
        n_encoder_layers: int,
        hidden_act: Module,
        dropout: float,
        dim_feedforward: "int | None" = None,
        n_classes: int = 3,
        use_positional_encoding: bool = False,
    ) -> None:

        super().__init__()

        # Template layer: nn.TransformerEncoder builds its layers as deep copies
        # of it, so this instance's own parameters are not used in forward().
        # Kept (and registered first) so existing checkpoints and parameter
        # order stay unchanged.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=n_attn_heads,
            dim_feedforward=model_dim if dim_feedforward is None else dim_feedforward,
            dropout=dropout,
            activation=hidden_act,
            batch_first=True)

        self.embedding = nn.Embedding(
            vocab_size, embedding_dim=model_dim)

        self.vocab_size = vocab_size
        self.n_classes = n_classes
        self.use_positional_encoding = use_positional_encoding

        self.encoder = nn.TransformerEncoder(
            encoder_layer=self.encoder_layer, num_layers=n_encoder_layers)

        self.out_layer = nn.Linear(in_features=model_dim, out_features=n_classes, bias=False)

    @staticmethod
    def _sinusoidal_encoding(seq_len, dim, device, dtype):
        """Fixed sine/cosine position table of shape ``(seq_len, dim)``.

        Even columns hold ``sin(pos * f_k)`` and odd columns hold
        ``cos(pos * f_k)``, with ``f_k = 10000 ** (-2k / dim)``.
        """
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        freqs = torch.pow(
            10000.0, -torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        table = torch.zeros(seq_len, dim, device=device)
        table[:, 0::2] = torch.sin(position * freqs)
        # For odd dim there is one fewer cosine column than sine column.
        table[:, 1::2] = torch.cos(position * freqs[: dim // 2])
        return table.to(dtype)

    def forward(self, inputs):
        """Return per-position class logits.

        Args:
            inputs: integer token ids, shape ``(seq_len,)`` (unbatched, as the
                training loop passes them) or ``(batch, seq_len)``.

        Returns:
            Logits of shape ``inputs.shape + (n_classes,)``.
        """
        x_emb = self.embedding(inputs)  # (..., seq_len, model_dim)

        if self.use_positional_encoding:
            x_emb = x_emb + self._sinusoidal_encoding(
                x_emb.shape[-2], x_emb.shape[-1], x_emb.device, x_emb.dtype)

        encoder_output = self.encoder(x_emb)
        return self.out_layer(encoder_output)

# Training

In [11]:
# Number of training sequences to use for quick experiments;
# set to None to train on the full dataset (list[:None] is the whole list).
N_TRAIN_SAMPLES = 200
training_dataset_subset = training_dataset[:N_TRAIN_SAMPLES]

In [None]:
# Training
# Each sequence is passed to the model unbatched; gradients are accumulated
# over BATCH_SIZE sequences before every optimizer step.

torch.manual_seed(0)  # reproducible weight init, dropout masks and shuffling

n_epochs = 40
BATCH_SIZE = 32      # sequences per optimizer step (gradient accumulation)
LOG_EVERY = 20000    # sequences between mid-epoch progress reports / checkpoints

nucleotide_loader = NtDataset(training_dataset_subset)
loss_fn = nn.CrossEntropyLoss()
splice_model = SpliceFormer(vocab_size=4,
                            model_dim=64,
                            n_attn_heads=2,
                            n_encoder_layers=2,
                            hidden_act=nn.ReLU(),
                            dropout=0.1)

optimizer = torch.optim.AdamW(splice_model.parameters(), lr=0.00001)

epoch_mean_loss = float('nan')
for epoch in range(n_epochs):
    optimizer.zero_grad()
    epoch_losses = []
    grad_iter = 0  # sequences accumulated since the last optimizer step

    # Visit the sequences in a fresh random order every epoch.
    for step, sample_idx in enumerate(torch.randperm(len(nucleotide_loader)).tolist()):
        inputs, labels = nucleotide_loader[sample_idx]
        outputs = splice_model(inputs)          # per-nucleotide class logits
        sample_loss = loss_fn(outputs, labels)

        # Divide by BATCH_SIZE so the accumulated gradient is a batch mean;
        # the unscaled per-sequence loss is what gets logged.
        (sample_loss / BATCH_SIZE).backward()
        epoch_losses.append(sample_loss.item())
        grad_iter += 1

        if grad_iter == BATCH_SIZE:
            optimizer.step()
            optimizer.zero_grad()
            grad_iter = 0

        if (step + 1) % LOG_EVERY == 0:
            recent = epoch_losses[-LOG_EVERY:]
            print(f'epoch: {epoch}, step: {step + 1}, loss: {sum(recent) / len(recent):.4f}')
            torch.save(splice_model.state_dict(), f'proc/tbh_model_e{epoch}_s{step + 1}.pth')

    # Apply the trailing partial batch instead of letting its gradients leak
    # into the first step of the next epoch. Rescale so it is a mean over the
    # grad_iter sequences actually accumulated.
    if grad_iter:
        for param in splice_model.parameters():
            if param.grad is not None:
                param.grad.mul_(BATCH_SIZE / grad_iter)
        optimizer.step()
        optimizer.zero_grad()

    if epoch_losses:
        epoch_mean_loss = sum(epoch_losses) / len(epoch_losses)
    print(f'epoch: {epoch}, mean loss: {epoch_mean_loss:.4f}')
    torch.save(splice_model.state_dict(), f'proc/tbh_model_e{epoch}.pth')

print(f"Finished training!\nFinal epoch mean loss: {epoch_mean_loss:.4f}")
torch.save(splice_model.state_dict(), 'proc/tbh_model_final.pth')

In [None]:
# Evaluate the model trained above on the held-out test set.
# To evaluate a saved checkpoint instead, construct SpliceFormer with the same
# hyperparameters as in training and call
# .load_state_dict(torch.load('proc/tbh_model_final.pth')).
test_loader = NtDataset(test_dataset)
splice_test = splice_model
splice_test.eval()  # disable dropout so predictions are deterministic

y_pred, y_actual = [], []

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = splice_test(inputs)           # per-nucleotide class logits
        predicted = outputs.argmax(dim=-1)      # predicted class id per position
        y_pred.extend(predicted.tolist())
        y_actual.extend(labels.tolist())

# Fix the class set so the matrix stays 3x3 even if a class never occurs.
class_ids = [0, 1, 2]
confusion_matrix = metrics.confusion_matrix(y_actual, y_pred, labels=class_ids)
cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=class_ids)

fig, ax = plt.subplots()
cm_display.plot(ax=ax, cmap='GnBu')
ax.set_title('SpliceFormer per-nucleotide confusion matrix (test set)')
plt.show()