# Lab 3: RNNs and LSTMs
In this tutorial we create neural networks using [Recurrent Neural Networks (RNN)](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html) and [Long Short-Term Memory (LSTM)](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) cells. To keep the amount of new information limited, the data and its pre-processing are identical to the first lab (Perceptron). Some of the required code blocks are empty and need your input to complete the model. A few additional questions at the end challenge you to play around with the code and try things for yourselves.

During the session, you will create an RNN to translate DNA sequences into protein sequences.

In [None]:

# import pytorch
import torch
import torch.nn as nn
from torch import Tensor
from torch import optim
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
import lightning as L

# import basic functionality
import random
import numpy as np
import pandas as pd
import itertools

# libraries for plotting
import seaborn as sns
import matplotlib.pyplot as plt

import Bio
from Bio import SeqIO


# Step 1: Pre-processing the data
Here we download and pre-process the dataset. As before, we only consider DNA sequences that are protein coding, contain a whole number of codons, have a start and a stop codon, and do not contain any uncertain nucleotides. Finally, we remove duplicates and randomly shuffle the sequences. Nothing has changed since the last lab, so you can simply execute these steps and move on to Step 2.
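As a quick illustration of these filters, the sketch below applies them to raw DNA strings in plain Python. `passes_filters` is a hypothetical helper, not part of the lab code: it skips the annotation checks on the FASTA header and assumes the standard genetic code, in which ATG is the only codon translated to M and TAA, TAG and TGA are the stop codons.

```python
# Hypothetical helper (not part of the lab code) that applies the Step 1 sequence filters.
# Assumes the standard genetic code: ATG is the only codon translated to 'M',
# and TAA, TAG and TGA are the stop codons translated to '*'.
STOP_CODONS = {"TAA", "TAG", "TGA"}


def passes_filters(dna: str) -> bool:
    """Return True if a DNA coding sequence would survive the Step 1 filters."""
    if len(dna) % 3 != 0:            # whole number of codons
        return False
    if dna[:3] != "ATG":             # start codon, i.e. the protein starts with M
        return False
    if dna[-3:] not in STOP_CODONS:  # stop codon, i.e. the protein ends with *
        return False
    if "N" in dna:                   # no undetermined nucleotides
        return False
    return True


toy_seqs = ["ATGGCTTAA", "ATGGCTTA", "CTGGCTTAA", "ATGNCTTAA", "ATGGCTGCT"]
print([passes_filters(s) for s in toy_seqs])  # [True, False, False, False, False]
```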

In [None]:

# download and unpack human DNA coding sequences
############################

!mkdir -p ~/all_seqs
%cd ~/

!wget -P ~/all_seqs/ https://ftp.ensembl.org/pub/current_fasta/homo_sapiens/cds/Homo_sapiens.GRCh38.cds.all.fa.gz
!gzip -df "all_seqs/Homo_sapiens.GRCh38.cds.all.fa.gz"


In [None]:

# function that loads and processes a FASTA file containing coding sequences
def load_species_cds(file_name):
    dna_seq = []
    prot_seq = []
    for record in SeqIO.parse(file_name, "fasta"):
        # ensure that sequences are protein coding
        if 'gene_biotype:protein_coding' in record.description:
            if 'transcript_biotype:protein_coding' in record.description:
                if ' cds ' in record.description:
                    if len(record.seq) % 3 == 0:
                        dna_seq.append(str(record.seq))
                        prot_seq.append(str(record.seq.translate()))
                        
    # keep sequences that start with a start codon (M) and end with a stop codon (*)
    dna_seq_cod = []
    prot_seq_cod = []
    for i in range(len(prot_seq)):
        if (prot_seq[i][0] == 'M') and (prot_seq[i][-1] == '*'):
            dna_seq_cod.append(dna_seq[i])
            prot_seq_cod.append(prot_seq[i])

    # avoid sequences with undetermined/uncertain nucleotides
    # filter the DNA and protein lists with the same mask so that the pairs stay aligned
    keep = ['N' not in dna for dna in dna_seq_cod]
    dna_seq_cod = [dna for dna, k in zip(dna_seq_cod, keep) if k]
    prot_seq_cod = [prot for prot, k in zip(prot_seq_cod, keep) if k]
 
    # remove duplicates and randomly mix the list of sequences
    seqs = list(zip(dna_seq_cod, prot_seq_cod))
    seqs = list(set(seqs))
    random.shuffle(seqs)
    dna_seq_cod, prot_seq_cod = zip(*seqs)

    # pack samples as a list of dictionaries and return result
    seq_data = [{'dna':dna_seq_cod[i],'prot':prot_seq_cod[i]} for i in range(len(dna_seq_cod))]
    return seq_data
    

In [None]:

# load the human coding sequences
print('loading human proteins')
seq_data = load_species_cds("all_seqs/Homo_sapiens.GRCh38.cds.all.fa")

# take a look at some sequences
[seq_data[i]['dna'][0:20]+'...'+seq_data[i]['dna'][-20:] for i in range(5)]

In [None]:
print('number of sequences: ', len(seq_data))

# Step 2: Encoding the sequences
Having prepared the coding sequences and their translations, we convert them into a numeric (vector) representation. To do so, we first construct a language class that stores the words of each language and converts between words and their indices/encodings. Its functions store any sequence of words (codons or bases for DNA, amino acids for proteins) as a numeric representation. Sequence-to-sequence models often extend every sentence with a start-of-sentence token (`<SOS>`) and an end-of-sentence token (`<EOS>`) so that the model knows when to start and stop translating; the `Language` class below only defines a padding token, since every coding sequence we kept already starts with a start codon (M) and ends with a stop codon (*).

Padding extends every sentence to the same length with an empty "word" (`<PAD>`) that is not translated, so that all sentences have numerical representations of identical length and can be used as input for the model. The language supports a simple encoding of words as numbers (indices) or as numerical vectors (one-hot encoding). The `encode` function converts an input sentence to the chosen encoding and pads or truncates it to a given maximum length.
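The encoding steps can be sketched in plain Python on toy data. The helper names (`split_words`, `build_vocab`, `encode_sentence`) are hypothetical and only mirror the logic of the `Language` class: index 0 is reserved for `<PAD>`, sentences are split into words of fixed length, and encoded sentences are truncated or padded to `max_len`.

```python
# Hypothetical helpers that mirror the logic of the Language class on toy data.
# Index 0 is reserved for <PAD>, as in Language.__init__.


def split_words(sentence, word_len):
    """Split a sentence into consecutive words of word_len characters (codons if word_len=3)."""
    return [sentence[i:i + word_len] for i in range(0, len(sentence), word_len)]


def build_vocab(sentences, word_len):
    """Assign the next free index to every new word, like Language.addWord."""
    word2index = {"<PAD>": 0}
    for sentence in sentences:
        for word in split_words(sentence, word_len):
            word2index.setdefault(word, len(word2index))
    return word2index


def encode_sentence(sentence, word2index, word_len, max_len, as_one_hot=False):
    """Encode a sentence as indices (or one-hot vectors), truncated/padded to max_len."""
    indices = [word2index[word] for word in split_words(sentence, word_len)]
    indices = indices[:max_len] + [word2index["<PAD>"]] * (max_len - len(indices))
    if not as_one_hot:
        return indices
    size = len(word2index)
    return [[1 if j == i else 0 for j in range(size)] for i in indices]


toy_dna = ["ATGGCTTAA", "ATGTAA"]
toy_vocab = build_vocab(toy_dna, word_len=3)
print(toy_vocab)                                       # {'<PAD>': 0, 'ATG': 1, 'GCT': 2, 'TAA': 3}
print(encode_sentence("ATGTAA", toy_vocab, 3, 4))      # [1, 3, 0, 0] (padded)
print(encode_sentence("ATGGCTTAA", toy_vocab, 3, 2))   # [1, 2] (truncated)
print(encode_sentence("ATGTAA", toy_vocab, 3, 4, as_one_hot=True)[0])  # [0, 1, 0, 0]
```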

In [None]:

# class to store a language
class Language:
    # initialize the language, as standard we have a padding to equalize sentence lengths (PAD)
    def __init__(self, name, codon_len):
        self.name = name
        self.word2index = {"<PAD>": 0}
        self.encoding = {}
        self.index2word = {0: "<PAD>"}
        self.n_words = 1  # count the <PAD> token
        self.codon_length = codon_len

    # function to add sentence to language (add new words in the sentence to our language)
    def addSentence(self, sentence):
        for word in [sentence[i:i+self.codon_length] for i in range(0, len(sentence), self.codon_length)]:
            self.addWord(word)

    # function to add word to language
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1
            
    # function to convert indices to one-hot encodings (i.e., 3 becomes [0,0,0,1,0,0,...])
    def as_one_hot(self):
        for key in self.word2index:
            new_val = np.zeros(len(self.word2index),dtype=np.int32)
            new_val[self.word2index[key]] = 1
            self.encoding[key] = new_val
    
    # function to convert indices to simple encodings
    def as_encoding(self):
        self.encoding = self.word2index

    # function to encode (and pad) a sentence
    # we use this to take an input sentence and convert it to a sequence of arrays that represent that sentence in a given language
    # in the context of proteins, think of this as encoding the bases or codons
    def encode(self, sentence, max_len):
        pad = [self.encoding["<PAD>"]]

        # split sentence in blocks of a given codon_length
        sentence_split = [sentence[i:i+self.codon_length] for i in range(0, len(sentence), self.codon_length)]
            
        # encode sentence in the given language
        sentence_encoded = [self.encoding[word] for word in sentence_split]

        # only pad or truncate if a maximum length is specified
        if max_len is not None:
            if len(sentence_split) < max_len: 
                # sentence is shorter than max length; pad to maximum length
                n_pads = max_len - len(sentence_split)
                return torch.Tensor(np.array(sentence_encoded + pad * n_pads))
            else: 
                # sentence is longer than max length; truncate
                sentence_truncated = sentence_encoded[:max_len]
                return torch.Tensor(np.array(sentence_truncated))
        else:
            return torch.Tensor(np.array(sentence_encoded))


In [None]:
########################
### step2a: create languages for DNA and protein sequences
########################

# create a language for DNA and protein sequences
dna_lang = Language(name="dna", codon_len = 3) # IMPORTANT: here we set the codon length to 3!
prot_lang = Language(name="prot", codon_len = 1)

# split the sequence data ('seq_data') that we defined above into sensible training, validation and test sets
# think about how much data would realistically be necessary to learn the problem of translating DNA sequences
train_set, val_set, test_set, _ = torch.utils.data.random_split(..., ...)

# memorize the dna and protein languages by parsing all sequences
for cur_seq in train_set:
    dna_lang.addSentence(cur_seq['dna'])
    prot_lang.addSentence(cur_seq['prot'])

# create a one-hot encoding for all codons and a simple (index) encoding for all amino acids
# call the appropriate functions for each of the two languages
dna_lang.FUNCTION_CALL()
prot_lang.FUNCTION_CALL()
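Splitting a dataset with `torch.utils.data.random_split` can be sketched on toy data. The fractions and variable names are hypothetical, chosen only to show the mechanics: integer lengths must add up to the dataset size, so the leftover samples go into a fourth subset that is ignored. Recent PyTorch versions also accept fractions that sum to 1 directly.

```python
import torch
from torch.utils.data import random_split

# Toy stand-in for seq_data (hypothetical): 100 DNA/protein pairs.
toy_data = [{"dna": "ATG" * (i + 1), "prot": "M" * (i + 1)} for i in range(100)]

# Hypothetical train/validation/test fractions; whatever is left over goes into a
# fourth subset that is discarded, like the '_' output in the cell above.
fractions = [0.1, 0.05, 0.05]
split_lengths = [round(f * len(toy_data)) for f in fractions]
split_lengths.append(len(toy_data) - sum(split_lengths))

# A seeded generator makes the random split reproducible between runs.
split_gen = torch.Generator().manual_seed(0)
toy_train, toy_val, toy_test, toy_rest = random_split(toy_data, split_lengths, generator=split_gen)

print(split_lengths)                                 # [10, 5, 5, 80]
print(len(toy_train), len(toy_val), len(toy_test))   # 10 5 5
print(toy_train[0].keys())                           # dict_keys(['dna', 'prot'])
```

Each subset is a `Subset` that still yields the original dictionaries, which is why the language-building loop above can read `cur_seq['dna']` directly.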


In [None]:

# here we define a function for encoding a dataset of dna and protein sequences
def encode_dataset(dataset, dna_lang, prot_lang, max_length):
    dataset_encoded = [
        { 
          'dna'  : dna_lang.encode(dataset[i]['dna'], max_len=max_length),
          'prot' : prot_lang.encode(dataset[i]['prot'], max_len=max_length)
        } for i in range(len(dataset))
    ]
    return dataset_encoded
    

In [None]:
########################
### step2b: encode your sequences here
########################

# define maximum number of codons
# we truncate any sequence longer than this length, and pad any sequence shorter than this length
# think about a sensible length for the input sequences
max_length = ...

# encode the training and validation data
train_set_encoded = FUNCTION_CALL(..., dna_lang, prot_lang, ...) 
val_set_encoded = FUNCTION_CALL(..., dna_lang, prot_lang, ...)


In [None]:

# take a look at the encoding of a DNA
train_set[0]['dna'], train_set_encoded[0]['dna'].shape


In [None]:

# take a look at the encoding of a protein
train_set[0]['prot'], train_set_encoded[0]['prot'].shape


In [None]:

# define dataloader for the encoded sequences
def get_dataloader(dataset, batch_size):
    cur_sampler = RandomSampler(dataset)
    cur_dataloader = DataLoader(dataset, sampler=cur_sampler, batch_size=batch_size, drop_last=True, num_workers=16)
    return cur_dataloader
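To see what `batch['dna']` and `batch['prot']` look like inside `training_step`, the sketch below batches a toy stand-in for an encoded dataset (all sizes are hypothetical). It uses `shuffle=True`, which draws samples in random order much like passing a `RandomSampler` as in `get_dataloader`.

```python
import torch
from torch.utils.data import DataLoader

# Toy stand-in for the output of encode_dataset (hypothetical sizes): each DNA sample is a
# (toy_len, toy_codons) one-hot matrix and each protein sample a (toy_len,) vector of indices.
toy_len, toy_codons = 6, 5
toy_encoded = [
    {"dna": torch.zeros(toy_len, toy_codons), "prot": torch.zeros(toy_len)}
    for _ in range(10)
]

# The default collate function turns a list of dicts into a dict of stacked tensors,
# which only works because padding/truncation gave every sample the same shape.
toy_loader = DataLoader(toy_encoded, batch_size=4, shuffle=True, drop_last=True)
toy_batch = next(iter(toy_loader))
print(toy_batch["dna"].shape)   # torch.Size([4, 6, 5])
print(toy_batch["prot"].shape)  # torch.Size([4, 6])
print(len(toy_loader))          # 2: with drop_last=True the incomplete third batch is dropped
```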
    

In [None]:
########################
### step 2c: create a dataloader for the validation and training sequences
########################

# how many samples should be trained on simultaneously?
batch_size = ...

# define dataloader for training
train_loader = FUNCTION_CALL(..., ...)
val_loader = FUNCTION_CALL(..., ...)


# Step 3: Define model
We created languages for DNA and protein sequences, encoded all sequences with these languages, and created data loaders for the encoded sequences. As a final preparation, we define our [RNN](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html).
Create a class that builds an RNN model from the given input size, hidden size, and output size. You'll have to define the `__init__` and `forward` functions. Additionally, in the PyTorch Lightning class, you must choose a loss function and optimizer that are appropriate for the problem you are trying to solve. Once you have thought about your model, we will go over the architecture together with the class. Hint: for the RNN, you also need to define the hidden state.
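Before filling in `MyRNN`, it can help to check the tensor shapes `nn.RNN` expects and returns. The sketch below uses hypothetical toy sizes and `batch_first=True`, an assumption that matches the (batch, max_length, n_codons) batches from the data loader. It shows where the hidden state comes in; if you omit `h0`, PyTorch initializes it with zeros.

```python
import torch
import torch.nn as nn

# Toy sizes (hypothetical): 2 sequences of 6 codons, 5-dimensional one-hot input, hidden size 4.
n_batch, seq_len, n_in, n_hidden = 2, 6, 5, 4

# batch_first=True lets nn.RNN take (batch, seq_len, input_size) tensors,
# matching the (batch, max_length, n_codons) batches from the DataLoader.
toy_rnn = nn.RNN(input_size=n_in, hidden_size=n_hidden, batch_first=True)

toy_x = torch.randn(n_batch, seq_len, n_in)
# The initial hidden state is (num_layers, batch, hidden_size), even with batch_first=True.
toy_h0 = torch.zeros(1, n_batch, n_hidden)

rnn_out, toy_hn = toy_rnn(toy_x, toy_h0)
print(rnn_out.shape)  # torch.Size([2, 6, 4]): hidden state at every position
print(toy_hn.shape)   # torch.Size([1, 2, 4]): hidden state after the last position
```

The per-position output is what you would compare against the amino acid at each position. Whether you make the hidden size equal to the number of amino-acid tokens or add a separate `nn.Linear` layer on top is a design choice left to you.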

In [None]:

# Define the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
########################
### step 3a: define the model architecture
########################
class MyRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyRNN,self).__init__()

        ########################
        # define model layers (rnn), pseudocode:
        # look up in the documentation! https://pytorch.org/docs/stable/generated/torch.nn.RNN.html
        self.rnn = nn.RNN(...)
        ########################

    def forward(self, input_data):
        input_data = input_data.to(device)
        
        ########################
        # define model and hidden state
        ...
        
        return ...
        ########################

In [None]:
########################
### step 3b: define the lightning module to train the model
########################

# lightning module to train the sequence model
class SequenceModelLightning(L.LightningModule):
    def __init__(self, input_size, hidden_size, output_size, lr=0.1):
        super().__init__()
        self.model = MyRNN(input_size, hidden_size, output_size)
        self.lr = lr

        ########################
        # define loss function here, pseudocode:
        self.loss = ...
        ########################

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        input_tensor = batch['dna']
        target_tensor = batch['prot']
        
        output = self.model(input_tensor)
        loss = self.loss(input=output, target=target_tensor)
        self.log("train_loss", loss, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        input_tensor = batch['dna']
        target_tensor = batch['prot']

        output = self.model(input_tensor)
        loss = self.loss(input=output, target=target_tensor)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        # define the optimizer here (Adam is given as a starting point)
        return optim.Adam(self.model.parameters(), lr=self.lr)
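If you treat translation as classifying the amino acid at every position, cross-entropy is one natural choice for the `self.loss = ...` blank. The sketch below (toy sizes, hypothetical names) shows two details that are easy to miss: `nn.CrossEntropyLoss` expects the class dimension in position 1, and it needs integer class indices, whereas `Language.encode` returns float tensors. `ignore_index=0` is an optional extra that excludes `<PAD>` positions from the loss.

```python
import torch
import torch.nn as nn

# Toy sizes (hypothetical): 2 sequences, 4 positions, 3 amino-acid classes (index 0 = <PAD>).
toy_logits = torch.randn(2, 4, 3)  # (batch, seq_len, n_classes), e.g. a model output
toy_target = torch.tensor([[1., 2., 1., 0.],   # float indices, as returned by Language.encode
                           [2., 2., 0., 0.]])

# nn.CrossEntropyLoss expects class scores in dimension 1, i.e. (batch, n_classes, seq_len),
# and integer (long) class indices as target; ignore_index=0 skips <PAD> positions.
toy_loss_fn = nn.CrossEntropyLoss(ignore_index=0)
toy_loss = toy_loss_fn(toy_logits.permute(0, 2, 1), toy_target.long())
print(toy_loss)  # a scalar tensor
```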
    

In [None]:
########################
### step3c: define the input parameters for the training loop
########################

# define the model and training loop
# think of the dimensionality of your input data (dna sequences) and output data (protein sequence), and where these numbers are stored
lit_model = SequenceModelLightning(input_size = ...,
                                  hidden_size = ...,
                                  output_size = ...,
                                  lr = ...)

# define the trainer
trainer = L.Trainer(devices = 1, 
                    max_epochs = ...)

# learn the weights of the model
trainer.fit(lit_model, ..., ...)


# Step 4: Test the model on random test sequences

In [None]:

# show the model architecture
my_rnn = lit_model.model
my_rnn


In [None]:

# we encode the test data using the same dna and protein language encodings we defined before
# if you change the languages, you need to re-encode the test sequences as well!
test_set_encoded = encode_dataset(test_set, dna_lang, prot_lang, max_length)


In [None]:
########################
### step4: define the input tensor and get the prediction from your model
########################

# pick a random sequence from the test set
random_pair = np.random.randint(0,len(test_set))

# get the encoded dna sequence and its known protein translation
dna_sequence = np.array([test_set_encoded[random_pair]['dna']])
protein_translation = test_set[random_pair]['prot']
target_tensor = test_set_encoded[random_pair]['prot']

# send model and input sequence to device, compute translation of sequence
my_rnn.to(device)
########################
input_tensor = ...
output = ...
########################

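# compute the loss with the same loss function used for training (the target gets a batch dimension)
loss = lit_model.loss(input=output.cpu(), target=target_tensor.unsqueeze(0))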
print('loss', loss)

# convert output back to protein sequence by taking the most likely amino acid per position, print results
result = "".join([prot_lang.index2word[i] for i in output.cpu().topk(1)[1].view(-1).numpy()])
print(''+protein_translation)
print(result, end='\n\n')

# print accuracy (predicted <PAD> tokens are removed before comparing)
pad_idx = prot_lang.word2index["<PAD>"]
result = "".join([prot_lang.index2word[i] for i in output.cpu().topk(1)[1].view(-1).numpy() if i != pad_idx])
min_len = np.min([len(result),len(protein_translation)])
print('Accuracy of aa calling over the sequence: ', np.sum([protein_translation[i] == result[i] for i in range(min_len)])/min_len)
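The accuracy above removes predicted `<PAD>` tokens and then compares the two strings position by position, so a single `<PAD>` predicted in the middle of a sequence shifts every later amino acid. One alternative, sketched here as a hypothetical helper, compares predicted and target indices at the same position and only masks the positions where the target is padding.

```python
# Hypothetical helper: position-wise accuracy that skips positions where the target is <PAD>.
# pred and target are lists of class indices; pad_idx=0 mirrors the <PAD> index of Language.
def masked_accuracy(pred, target, pad_idx=0):
    pairs = [(p, t) for p, t in zip(pred, target) if t != pad_idx]
    if not pairs:
        return float("nan")  # nothing to score
    return sum(p == t for p, t in pairs) / len(pairs)


print(masked_accuracy([1, 2, 3, 0, 0], [1, 2, 2, 0, 0]))  # 2 of 3 real positions correct
```

Assuming your model output has shape (1, max_length, n_classes), one way to call it would be `masked_accuracy(output.cpu().topk(1)[1].view(-1).tolist(), target_tensor.long().tolist())`.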


# Step 5: Interpret the RNN hidden state
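The RNN you trained maps every codon onto its hidden state through learned input-to-hidden weights. With a hidden size equal to the number of amino acids, these weights form a matrix of size (# of codons) x (# of amino acids) (PyTorch stores its transpose). Load these weights and interpret the parameter values.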
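The hidden state itself is a vector that changes at every position; the learned matrices you can annotate are the RNN's weights. The sketch below, on a toy-sized `nn.RNN` with hypothetical sizes, lists the parameters a single-layer RNN stores and their shapes. With one-hot inputs, column j of `weight_ih_l0` is exactly the contribution of input token j to the hidden state, so its transpose has one row per input token.

```python
import torch.nn as nn

# Toy sizes (hypothetical): 5 one-hot codon tokens in, hidden state of size 3.
toy_rnn = nn.RNN(input_size=5, hidden_size=3, batch_first=True)

# A single-layer nn.RNN has four learnable parameters.
for name, param in toy_rnn.named_parameters():
    print(name, tuple(param.shape))
# weight_ih_l0 (3, 5)   input-to-hidden weights
# weight_hh_l0 (3, 3)   hidden-to-hidden weights
# bias_ih_l0 (3,)
# bias_hh_l0 (3,)

# Transpose to (input tokens x hidden units); detach() drops the autograd history.
codon_by_hidden = toy_rnn.weight_ih_l0.detach().T
print(codon_by_hidden.shape)  # torch.Size([5, 3])
```

To label such a matrix, you could wrap it in a `pandas.DataFrame` whose index and columns come from the word lists of your two languages.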

In [None]:
########################
### step5: interpret the hidden state of your RNN
########################

# load the hidden state from your RNN.
# annotate the hidden state with the matching codons (in 'dna_lang.index2word') and amino acids (in 'prot_lang.index2word')
# store the result in a dataframe called 'rnn_param'
rnn_param = ...


In [None]:

import seaborn as sns
import matplotlib.pyplot as plt

# plot the hidden state of the RNN. What is the interpretation?
use_map =  sns.color_palette("RdBu",as_cmap=True)
clus = sns.clustermap(
                        rnn_param,
                        center=0,
                        xticklabels=True, yticklabels=True, 
                        row_cluster=True, col_cluster=True, 
                        cmap=use_map,
                        figsize=(5,10),
                        method='complete', metric='cityblock'
                     )

plt.rcParams['pdf.fonttype'] = 42 
plt.tight_layout()
plt.show()


# Steps:
2a: create languages for the DNA and protein sequences <br>
2b: encode the training and validation sequences <br>
2c: create a dataloader for the validation and training sequences <br>
3a: define your RNN model <br>
3b: define the lightning module to train the model <br>
3c: define the input parameters for the training loop <br>
Train your model :) <br>
4: define the input tensor and get the prediction from your model <br>
5: interpret the hidden state of your RNN

# Questions:
- The model is trained on truncated and padded sequences. Change the setup to train your model on sequences of arbitrary length (i.e., their actual length). Before training your model, think about the number of samples to use for training, the batch size, and the number of epochs. <br>
- Change the codon length to train on nucleotides instead of codons. This is a bit tricky because of the dimensionality of your input/output data! Does your RNN still train well? <br>
- Add a different model that uses an LSTM instead of an RNN. Does the LSTM train better? <br>
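For the first and third questions, the sketch below combines two PyTorch utilities: `pad_sequence` and `pack_padded_sequence` let a recurrent layer process each sequence at its actual length, and `nn.LSTM` accepts the same constructor arguments used here for `nn.RNN` but returns a tuple `(h_n, c_n)` instead of a single hidden state. The toy sequences are hypothetical; to use this with the lab data you would also need a custom `collate_fn` in the `DataLoader`, because the default collate function cannot stack tensors of different lengths.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Three toy sequences of different lengths (hypothetical), each (seq_len, input_size=5).
seqs = [torch.randn(7, 5), torch.randn(4, 5), torch.randn(2, 5)]
lengths = torch.tensor([len(s) for s in seqs])

# Pad only to form a batch; packing tells the LSTM the real lengths,
# so padded positions never update the hidden state.
padded = pad_sequence(seqs, batch_first=True)  # (3, 7, 5)
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

# nn.LSTM takes the same arguments as nn.RNN but returns (h_n, c_n) as its second output.
toy_lstm = nn.LSTM(input_size=5, hidden_size=4, batch_first=True)
packed_out, (h_n, c_n) = toy_lstm(packed)

out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)     # torch.Size([3, 7, 4]); positions past each length are zero
print(out_lengths)   # tensor([7, 4, 2])
print(h_n.shape)     # torch.Size([1, 3, 4])
```

Note that `h_n` holds each sequence's state after its true last step, not after the padding.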