In [21]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import random
from sklearn.metrics import accuracy_score

In [22]:
proteins = pd.read_csv("../Data/2018-06-06-ss.cleaned.csv")

In [23]:
proteins

Unnamed: 0,pdb_id,chain_code,seq,sst8,sst3,len,has_nonstd_aa
0,1A30,C,EDL,CBC,CEC,3,False
1,1B05,B,KCK,CBC,CEC,3,False
2,1B0H,B,KAK,CBC,CEC,3,False
3,1B1H,B,KFK,CBC,CEC,3,False
4,1B2H,B,KAK,CBC,CEC,3,False
...,...,...,...,...,...,...,...
393727,4UWE,D,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCBTTCEEEEEEEEEETTEEEEEEEECCCSSCCB...,CCCCCCCCCCCCCCECCCEEEEEEEEEECCEEEEEEEECCCCCCCE...,5037,True
393728,5J8V,A,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCCSSSCCEEEECSEETTEECCEECCEEETTEEE...,CCCCCCCCCCCCCCCCCCCCEEEECCEECCEECCEECCEEECCEEE...,5037,False
393729,5J8V,B,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCCSSSCCEEEECSEETTEECCEECCEEETTEEE...,CCCCCCCCCCCCCCCCCCCCEEEECCEECCEECCEECCEEECCEEE...,5037,False
393730,5J8V,C,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCCSSSCCEEEECSEETTEECCEECCEEETTEEE...,CCCCCCCCCCCCCCCCCCCCEEEECCEECCEECCEECCEEECCEEE...,5037,False


I started with sampling the dataset: 
- selected sequences no longer than 20 amino acids (in this architecture I used shorter sequences than in LSTM because of the resources needed to train on longer sequences)
- selected the important columns
- deduplicated dataset
- removed sequences that were only composed of "*" sign - which indicated nonstandard amino acids (B, O, U, X, or Z)

In [24]:
sample = proteins[
    (proteins["len"]>=1) &
    (proteins["len"]<=20)]

In [25]:
sample = sample[["seq","sst3","sst8"]]

In [26]:
sample = sample.drop_duplicates()

In [27]:
sample["len"] = sample["seq"].apply(len)
sample = sample.sample(frac=1)

Based on the seq2seq tutorial (https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) I created Lang class that will store the input "language" - in my case amino acid sequence and the output "language" - three-state (Q3) secondary structure

The sequences and structures are characterised by their own set of characters. 

In [28]:
SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self):
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  

    def addSentence(self, sentence):
        for word in list(sentence):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [29]:
def prepareData(lang1, lang2, reverse=False):

    input_lang = Lang()
    output_lang = Lang() 

    pairs = list(zip(lang1,lang2))

    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])

    print("Counted words:")
    print(f"Sequence: {input_lang.n_words}")
    print(f"Structure: {output_lang.n_words}")
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData(sample["seq"], sample["sst3"])

print(random.choice(pairs))

Counted words:
Sequence: 23
Structure: 5
('*YSPTSPSYSPTSPS', 'CCCCCCCECCCCCCC')


In [30]:
output_lang.word2count

{'C': 53524, 'E': 7831, 'H': 12860}

In [31]:
output_lang.word2index

{'C': 2, 'E': 3, 'H': 4}

In [32]:
MAX_LENGTH = sample["len"].max()+2

In [33]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in list(sentence)]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

n = len(pairs)
input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

for idx, (inp, tgt) in enumerate(pairs):
    inp_ids = indexesFromSentence(input_lang, inp)
    tgt_ids = indexesFromSentence(output_lang, tgt)
    inp_ids.append(EOS_token)
    tgt_ids.append(EOS_token)
    input_ids[idx, :len(inp_ids)] = inp_ids
    target_ids[idx, :len(tgt_ids)] = tgt_ids

In [34]:
train_size = int(len(input_ids)*0.6)
test_size = int(len(input_ids)*0.2)

X = input_ids
y = target_ids

X_train = torch.tensor(X[:train_size], dtype=torch.long)
y_train = torch.tensor(y[:train_size],dtype=torch.long)

X_test = torch.tensor(X[train_size:train_size+test_size],dtype=torch.long)
y_test = torch.tensor(y[train_size:train_size+test_size],dtype=torch.long)

X_val = torch.tensor(X[train_size+test_size:],dtype=torch.long)
y_val = torch.tensor(y[train_size+test_size:],dtype=torch.long)

I decided to calculate weights of the output characters, and use them later in the loss function. The frequencies of the characters in the structure vary a lot with C letter being the most frequent. I wanted to avoid the situation where model learns only to output the majority group, so I calculated the reverse probability of the word frequency. 

In [35]:
SOS_freq = (torch.tensor(y, dtype=torch.long).shape[0] * torch.tensor(y, dtype=torch.long).shape[1]) - torch.count_nonzero(torch.tensor(y, dtype=torch.long))

In [36]:
vocab = output_lang.word2index
word_freq = output_lang.word2count

vocab.update({"SOS":0,"EOS":1})
word_freq.update({"SOS":int(SOS_freq),"EOS":len(sample)})

vocab_size = len(vocab)

weights = torch.zeros(vocab_size)

for word, idx in vocab.items():
    weights[idx] = 1.0 / (word_freq[word]) 
    
weights = weights / weights.sum()
print(weights)

tensor([0.0468, 0.4058, 0.0456, 0.3119, 0.1899])


I based the model on seq2seq tutorial (https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html).

I tested the result with the presented architecture as well as using the decoder without attention mechanism. The results were better with attention. Then I used bidirectional encoder which led to better results. I also decided to add embedding size parameter in the decoder to be able to adjust its size independently from hidden size. 

In [37]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=1)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)

        hidden = torch.sum(hidden,dim=0).unsqueeze(dim=0)
        output = torch.chunk(output, 2 , dim = 2)[0] + torch.chunk(output, 2 , dim = 2)[1]

        return output, hidden

In [38]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights


class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding_dim, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, embedding_dim)

        self.attention = BahdanauAttention(hidden_size)

        self.gru = nn.GRU(embedding_dim+hidden_size, hidden_size, batch_first=True, num_layers=1)
        self.out = nn.Linear(hidden_size, output_size)
        

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                decoder_input = target_tensor[:, i].unsqueeze(1)
            else:
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

    def forward_step(self, decoder_input, hidden, encoder_outputs):
        output = self.embedding(decoder_input)

        query = hidden.permute(1, 0, 2)

        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((output, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)
        return output, hidden, attn_weights

I experimented with learning rate, decided to use large batch size to speed up learning process, because the output was not highly dimensional I decided to use small hidden size and embedding.

Finally in the training loop I implemented early stopping and saving best model, I wanted to be able to optimise the number of epochs and prevent overfitting, so if during the training the train loss is decreasing but test loss is rising I stop the training, and save the model that was having the lowest test loss.

In [39]:
learning_rate=0.001
batch_size = 128
hidden_size = 32
embedding_size = 32
n_epochs = 100

patience = 5

In [40]:
train_loader = DataLoader(list(zip(X_train,y_train)), batch_size=batch_size)
test_loader = DataLoader(list(zip(X_test,y_test)), batch_size=batch_size)

In [43]:
encoder = EncoderRNN(input_lang.n_words, hidden_size)
decoder = DecoderRNN(hidden_size, embedding_size, output_lang.n_words)

encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
loss_fn = nn.NLLLoss(weight=weights)

In [None]:
best_result = np.inf
test_loss_array = []

for epoch in range(n_epochs):

    total_loss = 0
    for X_batch, y_batch in train_loader:
        
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(X_batch)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, y_batch)

        loss = loss_fn(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            y_batch.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item() 

    test_loss = 0
    with torch.no_grad():  
            
            for X_batch, y_batch in test_loader:

                encoder_outputs, encoder_hidden = encoder(X_batch)
                decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, y_batch)
                t_loss = loss_fn(decoder_outputs.view(-1, decoder_outputs.size(-1)),y_batch.view(-1))

                test_loss+=t_loss

    loss = total_loss / (len(X_train) // batch_size)
    loss_test = test_loss / (len(y_test) // batch_size)

    test_loss_array.append(loss_test)

    if loss_test < best_result:
        torch.save(encoder.state_dict(), "./encoder.pth")
        torch.save(decoder.state_dict(), "./decoder.pth")

    print(f"Epoch: {epoch}, Train loss: {loss}, Test loss: {loss_test}")

    if len(test_loss_array)>patience+1:
        if not (any(x > test_loss_array[-1] for x in test_loss_array[len(test_loss_array)-patience-1:-1])):
            break

Epoch: 0, Train loss: 1.325981765985489, Test loss: 1.0632933378219604
Epoch: 1, Train loss: 0.7983005387442452, Test loss: 0.704858660697937
Epoch: 2, Train loss: 0.5554359619106565, Test loss: 0.5345368981361389
Epoch: 3, Train loss: 0.4426747815949576, Test loss: 0.4648110568523407
Epoch: 4, Train loss: 0.3920128984110696, Test loss: 0.41604119539260864
Epoch: 5, Train loss: 0.3603114549602781, Test loss: 0.38830944895744324
Epoch: 6, Train loss: 0.34307172362293514, Test loss: 0.37202373147010803
Epoch: 7, Train loss: 0.33369557027305874, Test loss: 0.36122068762779236
Epoch: 8, Train loss: 0.3210261049015181, Test loss: 0.3555908203125
Epoch: 9, Train loss: 0.3136281041162355, Test loss: 0.34400221705436707
Epoch: 10, Train loss: 0.30405750977141516, Test loss: 0.3403071165084839
Epoch: 11, Train loss: 0.3226709280695234, Test loss: 0.3445257246494293
Epoch: 12, Train loss: 0.3008771411010197, Test loss: 0.33709049224853516
Epoch: 13, Train loss: 0.29378327088696615, Test loss: 0.

I loaded the best model and used it for validation.

In [44]:
encoder.load_state_dict(torch.load("../Models/encoder.pth"))
decoder.load_state_dict(torch.load("../Models/decoder.pth"))

<All keys matched successfully>

In [46]:
with torch.no_grad():

    encoder_outputs, encoder_hidden = encoder(X_val)
    decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

    _, topi = decoder_outputs.topk(1)
    decoded_ids = topi.squeeze()

    pred = []
    for idx in decoded_ids:
        decoded_structure = []
        for id in idx:
            if id.item() == EOS_token or id.item() == SOS_token:
                break
            decoded_structure.append(output_lang.index2word[id.item()])
        pred.append("".join(decoded_structure))
    
    print(pred[:5])

['CCCCCCCCCCCCC', 'CHHHHHH', 'CHHHHHHHHH', 'CCCCCCCCCCCCCC', 'CCCCCCCCCC']


In [47]:
target=[]
for idx in y_val:
    decoded_structure = []
    for id in idx:
        if id.item() == EOS_token:
            break
        decoded_structure.append(output_lang.index2word[id.item()])
    target.append("".join(decoded_structure))

print(target[:5])

['CCCCCCCCCCCCCCCCCC', 'CCCCCCCC', 'CCCCCECCCCCC', 'CCCHHHHHHHHHHHCCCCC', 'CCCCCCCCCCCCC']


In the end I calculated the accuracy at the character and sentence level. However I think the character-level statistic is by far more important as single mistakes are inevitable with such long and repetitive sequences. 

In [None]:
def char_level_acc(predictions, targets):
    accuracy = 0
    
    for pred, target in zip(list(predictions), list(targets)):
        if len(pred)<len(target):
            pred = pred + ("$" * (len(target)-len(pred)))
        if len(pred)>len(target):
            target = target + ("$" * (len(pred)-len(target)))

        accuracy += accuracy_score(list(pred),list(target))

    return accuracy / len(predictions)

print(f'Character-level accuracy: {char_level_acc(pred, target)*100}%')
print(f'Exact match: {accuracy_score(pred,target)*100}%')

Character-level accuracy: 42.13788322906667%
Exact match: 0.0%
