In [27]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import random
from sklearn.metrics import accuracy_score,f1_score

### Preprocessing

In [28]:
proteins = pd.read_csv("../Data/2018-06-06-ss.cleaned.csv")

In [29]:
proteins

Unnamed: 0,pdb_id,chain_code,seq,sst8,sst3,len,has_nonstd_aa
0,1A30,C,EDL,CBC,CEC,3,False
1,1B05,B,KCK,CBC,CEC,3,False
2,1B0H,B,KAK,CBC,CEC,3,False
3,1B1H,B,KFK,CBC,CEC,3,False
4,1B2H,B,KAK,CBC,CEC,3,False
...,...,...,...,...,...,...,...
393727,4UWE,D,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCBTTCEEEEEEEEEETTEEEEEEEECCCSSCCB...,CCCCCCCCCCCCCCECCCEEEEEEEEEECCEEEEEEEECCCCCCCE...,5037,True
393728,5J8V,A,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCCSSSCCEEEECSEETTEECCEECCEEETTEEE...,CCCCCCCCCCCCCCCCCCCCEEEECCEECCEECCEECCEEECCEEE...,5037,False
393729,5J8V,B,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCCSSSCCEEEECSEETTEECCEECCEEETTEEE...,CCCCCCCCCCCCCCCCCCCCEEEECCEECCEECCEECCEEECCEEE...,5037,False
393730,5J8V,C,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCCSSSCCEEEECSEETTEECCEECCEEETTEEE...,CCCCCCCCCCCCCCCCCCCCEEEECCEECCEECCEECCEEECCEEE...,5037,False


In [30]:
def remove_empty(sequence):
    s = set(list(sequence))
    if len(s)==1:
        letter = s.pop()
        if letter == "*": 
            return 1
        else: 
            return 0
    return 0 

In [31]:
sample = proteins[
    (proteins["len"]>=1) &
    (proteins["len"]<=100)]

In [32]:
sample = sample[["seq","sst3","sst8"]]
sample = sample.drop_duplicates()

In [33]:
sample["remove"] = sample["seq"].apply(remove_empty) 
sample = sample[sample["remove"]==0].copy()

In [34]:
sample["len"] = sample["seq"].apply(len)
sample

Unnamed: 0,seq,sst3,sst8,remove,len
0,EDL,CEC,CBC,0,3
1,KCK,CEC,CBC,0,3
2,KAK,CEC,CBC,0,3
3,KFK,CEC,CBC,0,3
5,KMK,CEC,CBC,0,3
...,...,...,...,...,...
61918,MAVKTGIAIGLNKGKKVTQMTPAPKISYKKGAASNRTKFVRSLVRE...,CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCHHHHHHHHHHHH...,CCCCCCCCCCCCCCCCCCCCCCCCCCCCSCCCCCHHHHHHHHHHHH...,0,100
61920,MAVKTGIAIGLNKGKKVTQMTPAPKISYKKGAASNRTKFVRSLVRE...,CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCHHHHHHHHHHHH...,CCCCCCCCCCCCCCCCCCCCCCCCCCCCCTTCCCHHHHHHHHHHHH...,0,100
61921,RYNDYKLDFRRQQMQDFFLAHKDEEWFRSKYHPDEVGKRRQEARGA...,CCCCHHHHHHHHHHHHHHHHCCCCHHHHHHHCHHHHHHHHHHHHHH...,CCCCHHHHHHHHHHHHHHHHTSSCHHHHHHHCHHHHHHHHHHHHHH...,0,100
61922,RYNDYKLDFRRQQMQDFFLAHKDEEWFRSKYHPDEVGKRRQEARGA...,CCCCCHHHHHHHHHHHHHHHCCCCHHHHHHHCHHHHHHHHHHHHHH...,CCCCCHHHHHHHHHHHHHHHTSSCHHHHHHHCHHHHHHHHHHHHHH...,0,100


In [35]:
sample = sample.sample(frac=1)

In [36]:
SOS_token = 0

class Lang:
    def __init__(self):
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS"}
        self.n_words = len(self.index2word)

    def addSentence(self, sentence):
        for word in list(sentence):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [37]:
def prepareData(lang1, lang2):

    input_lang = Lang()
    output_lang = Lang() 

    pairs = list(zip(lang1,lang2))

    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])

    print("Counted words:")
    print(f"Sequence: {input_lang.n_words}")
    print(f"Structure: {output_lang.n_words}")
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData(sample["seq"], sample["sst8"])

print(random.choice(pairs))

Counted words:
Sequence: 22
Structure: 9
('MIQRTPKIQVYSRHPAENGKSNFLNCYVSGFHPSDIEVDLLKNGERIEKVEHSDLSFSKDWSFYLLYYTEFTPTEKDEYACRVNHVTLSQPKIVKWDRDM', 'CCCBCCEEEEEESSCCCTTSCEEEEEEEEEEBSSCCEEEEEETTEEESCCEECCCEECTTSCEEEEEEEEECCCSSCCEEEEEECTTCSSCEEEECCTTC')


In [38]:
MAX_LENGTH = sample["len"].max()+1

In [39]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in list(sentence)]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    return torch.tensor(indexes, dtype=torch.long).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

n = len(pairs)
input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

for idx, (inp, tgt) in enumerate(pairs):
    inp_ids = indexesFromSentence(input_lang, inp)
    tgt_ids = indexesFromSentence(output_lang, tgt)
    input_ids[idx, :len(inp_ids)] = inp_ids
    target_ids[idx, :len(tgt_ids)] = tgt_ids

In [40]:
train_size = int(len(input_ids)*0.6)
test_size = int(len(input_ids)*0.2)

X = input_ids
y = target_ids

X_train = torch.tensor(X[:train_size], dtype=torch.long)
y_train = torch.tensor(y[:train_size],dtype=torch.long)

X_test = torch.tensor(X[train_size:train_size+test_size],dtype=torch.long)
y_test = torch.tensor(y[train_size:train_size+test_size],dtype=torch.long)

X_val = torch.tensor(X[train_size+test_size:],dtype=torch.long)
y_val = torch.tensor(y[train_size+test_size:],dtype=torch.long)

In [41]:
SOS_freq = (torch.tensor(y, dtype=torch.long).shape[0] * torch.tensor(y, dtype=torch.long).shape[1]) - torch.count_nonzero(torch.tensor(y, dtype=torch.long))

In [42]:
vocab = output_lang.word2index
word_freq = output_lang.word2count

vocab.update({"SOS":0})
word_freq.update({"SOS":int(SOS_freq)})

vocab_size = len(vocab)

weights = torch.zeros(vocab_size)

for word, idx in vocab.items():
    weights[idx] = 1.0 / (word_freq[word]) 
    
weights = weights / weights.sum()
print(weights)

tensor([2.6514e-04, 5.4273e-04, 1.8337e-03, 1.6556e-03, 7.0960e-03, 5.4944e-04,
        9.1033e-04, 1.4653e-02, 9.7249e-01])


### Model

In [43]:
class LSTM(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, output_size, num_layers=1):
        super().__init__()

        self.inpit_size = input_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.embed = nn.Embedding(input_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*hidden_size, output_size)

    def forward(self, x):
        x = self.embed(x)
        x = F.relu(x)
        x , hidden = self.lstm(x)
        x = self.fc(x)

        x = F.log_softmax(x, dim = -1)
        return x


In [44]:
random_num = 20

In [45]:
X_train[random_num]

tensor([ 6,  9, 10,  3,  1, 19, 19,  4,  9,  2, 12,  2,  8,  8,  4,  9,  9, 11,
        12,  4, 10,  4,  9, 15,  4, 10,  1,  4, 10,  9,  2,  1, 15,  1,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])

In [46]:
y_train[random_num]

tensor([1, 1, 7, 1, 2, 2, 2, 7, 1, 1, 3, 3, 3, 1, 1, 7, 1, 2, 2, 2, 1, 6, 6, 6,
        3, 3, 3, 3, 6, 6, 6, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0])

In [70]:
learning_rate=0.001
batch_size = 128
hidden_size = 256
n_epochs = 100

patience = 5
best_result = np.inf

In [71]:
train_loader = DataLoader(list(zip(X_train,y_train)), batch_size=batch_size)
test_loader = DataLoader(list(zip(X_test,y_test)), batch_size=batch_size)

In [72]:
model = LSTM(input_lang.n_words, 64, hidden_size, output_lang.n_words)

In [73]:
optimizer = torch.optim.Adam(params = model.parameters(), lr = learning_rate)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)

In [74]:
test_loss_array = []

for epoch in range(n_epochs):

    total_loss = 0
    for X_batch, y_batch in train_loader:
        
        optimizer.zero_grad()

        output = model(X_batch)

        loss = loss_fn(output.permute(1,2,0), y_batch.permute(1,0))
        
        loss.backward()
        optimizer.step()

        total_loss += loss.item() 


    test_loss = 0
    with torch.no_grad():  

        for X_batch, y_batch in test_loader:

            output = model(X_batch)
            loss = loss_fn(output.permute(1,2,0), y_batch.permute(1,0))

            test_loss+=loss

    loss = total_loss / (len(X_train) // batch_size)
    loss_test = test_loss / (len(y_test) // batch_size)

    test_loss_array.append(loss_test)

    if loss_test < best_result:
        torch.save(model.state_dict(), "./lstm_8.pth")

    print(f"Epoch: {epoch}, Train loss: {loss}, Test loss: {loss_test}")

    if len(test_loss_array)>patience+1:
        if not (any(x > (test_loss_array[-1]+0.015) for x in test_loss_array[len(test_loss_array)-patience-1:-1])):
            break
        

Epoch: 0, Train loss: 1.7076015960086475, Test loss: 1.5743149518966675
Epoch: 1, Train loss: 1.5364264592979893, Test loss: 1.4986915588378906
Epoch: 2, Train loss: 1.4531455690210515, Test loss: 1.4474087953567505
Epoch: 3, Train loss: 1.3635973849079825, Test loss: 1.4328032732009888
Epoch: 4, Train loss: 1.2803210543863701, Test loss: 1.4291855096817017
Epoch: 5, Train loss: 1.2039777820158486, Test loss: 1.3789594173431396
Epoch: 6, Train loss: 1.1106386618180708, Test loss: 1.3418521881103516
Epoch: 7, Train loss: 1.057589865995176, Test loss: 1.331078290939331
Epoch: 8, Train loss: 1.0016103293558565, Test loss: 1.2592906951904297
Epoch: 9, Train loss: 0.9421465243353988, Test loss: 1.249661922454834
Epoch: 10, Train loss: 0.8877375381763535, Test loss: 1.2112178802490234
Epoch: 11, Train loss: 0.823029675567993, Test loss: 1.1888582706451416
Epoch: 12, Train loss: 0.7727780019996142, Test loss: 1.1840921640396118
Epoch: 13, Train loss: 0.7432653459936681, Test loss: 1.177272081

In [75]:
model.load_state_dict(torch.load("./lstm_8.pth"))

<All keys matched successfully>

In [76]:
with torch.no_grad():
    
    outputs_pred = model(X_val)

    _, topi = outputs_pred.topk(1)
    decoded_ids = topi.squeeze()

    pred = []
    for idx in decoded_ids:
        decoded_structure = []
        for id in idx:
            if id.item() == SOS_token:
                break
            decoded_structure.append(output_lang.index2word[id.item()])
        pred.append("".join(decoded_structure))
    
    print(pred)

['CCBCCBSCC', 'CCSSSEEEEEEETSHCESEBCCCTCEECHHHHHHHHHSSHHTCCCSCCCCCCCCSCCCCCCCHHHHHHHHHCCC', 'CCCBCCCTHHHHHHHHHHHGGGCBECCCCC', 'CCCCCCCTTGGGCCCCEEEEEEEEEECSSBHHHHHHHHHHTTSTTEEEEBBC', 'CCCCCCBCHHHHTGCBSSSSBEEEETTEEEEETTTTTTCTTCSHHCCTTTSSBCCHHHHCBTCCSHHHHHGGGTBEEEBCGGGGTTBCCCCSCC', 'CCSSCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHSCCC', 'CCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHCC', 'CCCSEEEEEBCCCBGGGGTTCGGCTTTTBBSSCCCCCCCCCCCCC', 'CCCCCCCCCCCCCCCCSCGGGTTCCIIIIIGTCCCTSHHHHHHHHHHCC', 'CCBBEEEEEETTBEBBGTTCEEECTTCBEEEEETTSCCCCBC', 'CCSSBCCBBBSBBSSBCBTTTTTSSBSESSSSSCTTTSBCSSBSSSSSCCCTTSBESSBBEBSSBBSSSSTSSSSSCTSSSSSSEHHHHHTTTTC', 'CCSCCCCCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHCCCCC', 'CBBEECCBBSTSCEEEEEESSCECCSTTTTBTSSSCSTSSSSSSGGGEEEBSSTBBBCTCCCEHHHHHHHHHHHTEEEEEBSSSCEECCCGGGGGCCC', 'CCCSSSCCC', 'CCCCCCCCCCCBCCSCSCCCSCCCCSTHHGHTTTTTCCSCCCC', 'CCCCCBCBSSCCSSSSSSCBCGGGEEEEETTTEBBSSSSSEEEEEESSEEECHTTHSHHHHEHECCBBCSSCBC', 'CCCCCCCCCCCCCCBBBBTCEEEE

In [77]:
target=[]
for idx in y_val:
    decoded_structure = []
    for id in idx:
        if id.item() == SOS_token:
            break
        decoded_structure.append(output_lang.index2word[id.item()])
    target.append("".join(decoded_structure))

print(target)

['CCCSSCCCC', 'CCTTSEEEEEEESCCCCCCCCCCCCCCCCHHHHHHHHHHHHHCCHHHHHHHSSCHHHHHHHHHHHHHHHHHHHC', 'CCEECCCTHHHHHHHHHHHGGGCEEECTTC', 'CCCCCCCCCCCCCCCCEEEEEEEEEECSSHHHHHHHHHHHHTCTTSEEEEBC', 'CCSSCEECHHHHTTSEETTEEEEECSSEEEESSTTTTTCTTCSHHHHHTTTSBCHHHHHTTTCCHHHHHHTTTTEEEEECHHHHTTSSCSSCCC', 'CCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHCCC', 'CCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTCCC', 'CCCCCCHHHHSCCBGGGCTTCHHHHHHCBCCCCCCCCCCCCCCCC', 'CCCCCCCCCCCCCCCTTCCSSTTCCTTTTTTTCCCSSHHHHHHHHHTTC', 'CCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHCCCCCCC', 'CCSCSCCCSSSSCSSTTHHHHHHHHHHHHHTTTCTTHHHHHHHHHHHHCHHHHHHHHHHHHHHHHHHHHHSCSSSSCSSCSSCHHHHHHHHHSCC', 'CCCCCCCCSSCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHCC', 'CEEEEEEEETTEEEEEEEETTSCHHHHHHHTTCCCCCSSCSSSSSTTEEEEEESCEECTTCCSSCHHHHHTTEEEGGGCEESSCEEEECCCHHHHHHC', 'CCSSCCCCC', 'CCCCCCCCCCCCCCCCTTCCCCCCCHHHHHHHHHHHHHHCCCC', 'CCCCCBCCEEECSBCCCSCCCGGGEEEEEECCGGGCTTEEEEEESTTCEEEEETTSHHHHHHHHTSCBCCCCCC', 'CCCCCCCCCCCTTCCCCEEEESTT

The results are significantly lower than with SST3 model. They would definitely require improvements in further steps but for now this model will be used with the app with and users will be informed that this model is less accurate. 

In [78]:
def char_level_metrics(predictions, targets):
    accuracy = 0
    f1 = 0
    
    for pred, target in zip(list(predictions), list(targets)):
        if len(pred)<len(target):
            pred = pred + ("$" * (len(target)-len(pred)))
        if len(pred)>len(target):
            target = target + ("$" * (len(pred)-len(target)))

        accuracy += accuracy_score(list(pred),list(target))
        f1 += f1_score(list(pred),list(target), average="macro")

    return accuracy/len(predictions), f1/len(predictions)

ac, f1 = char_level_metrics(pred, target)

print(f'Character-level accuracy: {ac*100}%')
print(f'Character-level f1: {f1*100}%')
print(f'Exact match: {accuracy_score(pred,target)*100}%')

Character-level accuracy: 66.84674287881774%
Character-level f1: 47.953699335047794%
Exact match: 0.3184713375796179%
