# Lab 3: Solutions

In [None]:
########################
### step2a: create languages for DNA and protein sequences
########################

# one language per alphabet: DNA is read in words of three nucleotides (codons),
# protein in words of a single amino acid
dna_lang = Language(name="dna", codon_len=3)
prot_lang = Language(name="prot", codon_len=1)

# randomly partition the sequence data ('seq_data', defined above) into training,
# validation and test sets of 10% each; the remaining 70% is left unused
# think about how much data would realistically be necessary to learn the problem of translating DNA sequences
train_set, val_set, test_set, _ = torch.utils.data.random_split(seq_data, [0.1, 0.1, 0.1, 0.7])

# memorize the dna and protein languages by parsing every training sequence
# (per record: DNA first, then protein)
for record in train_set:
    for lang, key in ((dna_lang, 'dna'), (prot_lang, 'prot')):
        lang.addSentence(record[key])

# create the one-hot encoding for all codons and for all amino acids
# by calling the encoding function of each of the two languages
for lang in (dna_lang, prot_lang):
    lang.as_one_hot()


In [None]:
########################
### step2b: encode your sequences here
########################

# maximum number of codons per sequence:
# longer sequences are truncated to this length, shorter ones are padded up to it
# think about a sensible length for the input sequences
max_length = None

# encode the training data first, then the validation data, with identical settings
train_set_encoded, val_set_encoded = (
    encode_dataset(split, dna_lang, prot_lang, max_length)
    for split in (train_set, val_set)
)


In [None]:
########################
### step 2c: create a dataloader for the validation and training sequences
########################

# number of samples that are trained on simultaneously
batch_size = 1

# one dataloader per encoded split (training first, then validation), same batch size
train_loader, val_loader = (
    get_dataloader(encoded_split, batch_size)
    for encoded_split in (train_set_encoded, val_set_encoded)
)


In [None]:
########################
### step 3a: define the model architecture
########################
class MyRNN(nn.Module):
    """Single-layer, bias-free recurrent network whose hidden state is the output.

    The recurrent layer is built with ``output_size`` hidden units, so the
    per-step hidden state is returned directly as the prediction; there is no
    separate projection layer.

    Args:
        input_size: number of features per sequence element.
        hidden_size: stored on the instance for reference only; the recurrent
            layer is sized by ``output_size``, not by this value.
        output_size: number of hidden units of the recurrent layer, which is
            also the feature dimension of the returned tensor.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        # input parameters
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # recurrent layer; LSTM alternative:
        #   self.rnn = nn.LSTM(self.input_size, self.output_size, num_layers=1, batch_first=True, bias=False)
        self.rnn = nn.RNN(self.input_size, self.output_size, num_layers=1, batch_first=True, bias=False)

    def forward(self, inp):
        """Run the recurrent layer over a batch of sequences.

        Args:
            inp: batched input with the batch on axis 0 (the layer is created
                with ``batch_first=True``), i.e. ``(batch, seq_len, input_size)``.

        Returns:
            The hidden state at every step, ``(batch, seq_len, output_size)``,
            on the same device as the model weights.
        """
        # Use the device the weights actually live on. Taking the input from a
        # global `device` while putting h0 on the *unmoved* input's device can
        # hand nn.RNN tensors on different devices.
        model_device = self.rnn.weight_ih_l0.device
        inp1 = inp.to(model_device)

        # initial hidden state: (num_layers=1, batch, hidden units)
        # NOTE(review): a random initial state makes every forward pass
        # non-deterministic, also at inference time - confirm this is intended.
        h0 = torch.randn(1, inp1.size(0), self.output_size, device=model_device)
        # LSTM alternative (needs an additional cell state):
        #   c0 = torch.randn(1, inp1.size(0), self.output_size, device=model_device)
        #   output_rnn, (hn, cn) = self.rnn(inp1, (h0, c0))

        # the final hidden state is not needed, only the per-step outputs
        output_rnn, _ = self.rnn(inp1, h0)

        # when using nucleotides instead of codons as input, keep only every
        # third step: output_rnn[:, 2::3]
        return output_rnn

In [None]:
########################
### step 3b: define the lightning module to train the model
########################

# lightning module to train the sequence model
class SequenceModelLightning(L.LightningModule):
    """Lightning wrapper that trains a ``MyRNN`` with a mean-squared-error loss.

    Batches are dictionaries holding the model input under ``'dna'`` and the
    regression target under ``'prot'``.

    Args:
        input_size: forwarded to ``MyRNN``.
        hidden_size: forwarded to ``MyRNN``.
        output_size: forwarded to ``MyRNN``.
        lr: learning rate handed to the Adam optimizer.
    """

    def __init__(self, input_size, hidden_size, output_size, lr=0.1):
        super().__init__()
        self.model = MyRNN(input_size, hidden_size, output_size)
        self.lr = lr

        # loss function: mean squared error between prediction and target
        self.loss = nn.MSELoss()

    def forward(self, x):
        """Delegate to the wrapped sequence model."""
        return self.model(x)

    def _batch_loss(self, batch):
        """Return the loss of the wrapped model on one batch."""
        dna, prot = batch['dna'], batch['prot']
        prediction = self.model(dna)
        return self.loss(input=prediction, target=prot)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        batch_loss = self._batch_loss(batch)
        self.log("train_loss", batch_loss, prog_bar=True)
        return batch_loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``, per step and per epoch) and return the validation loss."""
        batch_loss = self._batch_loss(batch)
        self.log("val_loss", batch_loss, on_step=True, on_epoch=True, prog_bar=True)
        return batch_loss

    def configure_optimizers(self):
        """Optimize the parameters of the wrapped model with Adam."""
        return optim.Adam(self.model.parameters(), lr=self.lr)
    
    

In [None]:
########################
### step3c: define the input parameters for the training loop
########################

# build the model to be trained
# the input dimensionality is the number of words of the dna language, the output
# dimensionality the number of words of the protein language
lit_model = SequenceModelLightning(
    input_size=dna_lang.n_words,
    hidden_size=0,  # alternative: prot_lang.n_words
    output_size=prot_lang.n_words,
    lr=0.05,
)

# trainer: a single device, ten passes over the training data
trainer = L.Trainer(devices=1, max_epochs=10)

# learn the weights of the model, validating on the validation loader
trainer.fit(lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)


In [None]:
########################
### step4: define the input tensor and get the prediction from your model
########################

# the trained network is the 'model' attribute of the lightning module
my_rnn = lit_model.model

# only the training and validation sets were encoded in step 2b, so encode the test set here
test_set_encoded = encode_dataset(test_set, dna_lang, prot_lang, max_length)

# pick a random sequence from the test set
random_pair = np.random.randint(0, len(test_set))

# get the encoded dna sequence (wrapped in a list to add a leading batch axis of size 1)
# and its known protein translation
dna_sequence = np.array([test_set_encoded[random_pair]['dna']])
protein_translation = test_set[random_pair]['prot']
target_tensor = test_set_encoded[random_pair]['prot']

# send model and input sequence to device, compute translation of sequence
# no gradients are needed for a prediction, so skip building the autograd graph
my_rnn.to(device)
input_tensor = torch.Tensor(dna_sequence).to(device)
with torch.no_grad():
    output = my_rnn(input_tensor).cpu()

# mean squared error between the encoded target and the prediction
loss = ((target_tensor - output)**2).mean()
print('loss', loss)

# convert output back to protein sequence by taking the most likely amino acid per position, print results
predicted_indices = output.argmax(dim=-1).view(-1).numpy()
result = "".join(prot_lang.index2word[i] for i in predicted_indices)
print(protein_translation)
print(result, end='\n\n')

# print accuracy
# indices that already exist in a freshly created language are not amino acids;
# collect them once instead of rebuilding a Language for every position
special_indices = set(Language('', 1).index2word)
result = "".join(prot_lang.index2word[i] for i in predicted_indices if i not in special_indices)
min_len = min(len(result), len(protein_translation))
if min_len:
    # zip stops at the shorter of the two sequences, i.e. after min_len positions
    n_matches = sum(known == called for known, called in zip(protein_translation, result))
    accuracy = n_matches / min_len
else:
    # nothing to compare (e.g. only special tokens were predicted)
    accuracy = float('nan')
print('Accuracy of aa calling over the sequence: ', accuracy)


In [None]:
########################
### step5: interpret the hidden state of your RNN
########################

# load the learned input-to-hidden weight matrix of the trained RNN
# (the first parameter of the recurrent layer, one row per hidden unit and one column per input feature).
# the trained network is the 'model' attribute of the lightning module
trained_rnn = lit_model.model

# copy the weights to the cpu before converting: numpy cannot read tensors that live on a gpu
input_weights = trained_rnn.rnn.weight_ih_l0.detach().cpu().numpy()

# annotate the transposed weights with the matching codons (in 'dna_lang.index2word') as rows
# and amino acids (in 'prot_lang.index2word') as columns
# store the result in a dataframe called 'rnn_param'
rnn_param = pd.DataFrame(
    input_weights.T,
    index=list(dna_lang.index2word.values()),
    columns=list(prot_lang.index2word.values()),
)
