In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import random
from sklearn.metrics import accuracy_score,f1_score
import json

### Preprocessing

In [3]:
proteins = pd.read_csv("../Data/2018-06-06-ss.cleaned.csv")

In [4]:
proteins

Unnamed: 0,pdb_id,chain_code,seq,sst8,sst3,len,has_nonstd_aa
0,1A30,C,EDL,CBC,CEC,3,False
1,1B05,B,KCK,CBC,CEC,3,False
2,1B0H,B,KAK,CBC,CEC,3,False
3,1B1H,B,KFK,CBC,CEC,3,False
4,1B2H,B,KAK,CBC,CEC,3,False
...,...,...,...,...,...,...,...
393727,4UWE,D,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCBTTCEEEEEEEEEETTEEEEEEEECCCSSCCB...,CCCCCCCCCCCCCCECCCEEEEEEEEEECCEEEEEEEECCCCCCCE...,5037,True
393728,5J8V,A,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCCSSSCCEEEECSEETTEECCEECCEEETTEEE...,CCCCCCCCCCCCCCCCCCCCEEEECCEECCEECCEECCEEECCEEE...,5037,False
393729,5J8V,B,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCCSSSCCEEEECSEETTEECCEECCEEETTEEE...,CCCCCCCCCCCCCCCCCCCCEEEECCEECCEECCEECCEEECCEEE...,5037,False
393730,5J8V,C,MGDGGEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRL...,CCCCCCCCCCCCCCCSSSCCEEEECSEETTEECCEECCEEETTEEE...,CCCCCCCCCCCCCCCCCCCCEEEECCEECCEECCEECCEEECCEEE...,5037,False


In [5]:
def remove_empty(sequence):
    s = set(list(sequence))
    if len(s)==1:
        letter = s.pop()
        if letter == "*": 
            return 1
        else: 
            return 0
    return 0 

In [6]:
sample = proteins[
    (proteins["len"]>=1) &
    (proteins["len"]<=100)]

In [7]:
sample = sample[["seq","sst3","sst8"]]
sample = sample.drop_duplicates()

In [8]:
sample["remove"] = sample["seq"].apply(remove_empty) 
sample = sample[sample["remove"]==0].copy()

In [9]:
sample["len"] = sample["seq"].apply(len)
sample

Unnamed: 0,seq,sst3,sst8,remove,len
0,EDL,CEC,CBC,0,3
1,KCK,CEC,CBC,0,3
2,KAK,CEC,CBC,0,3
3,KFK,CEC,CBC,0,3
5,KMK,CEC,CBC,0,3
...,...,...,...,...,...
61918,MAVKTGIAIGLNKGKKVTQMTPAPKISYKKGAASNRTKFVRSLVRE...,CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCHHHHHHHHHHHH...,CCCCCCCCCCCCCCCCCCCCCCCCCCCCSCCCCCHHHHHHHHHHHH...,0,100
61920,MAVKTGIAIGLNKGKKVTQMTPAPKISYKKGAASNRTKFVRSLVRE...,CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCHHHHHHHHHHHH...,CCCCCCCCCCCCCCCCCCCCCCCCCCCCCTTCCCHHHHHHHHHHHH...,0,100
61921,RYNDYKLDFRRQQMQDFFLAHKDEEWFRSKYHPDEVGKRRQEARGA...,CCCCHHHHHHHHHHHHHHHHCCCCHHHHHHHCHHHHHHHHHHHHHH...,CCCCHHHHHHHHHHHHHHHHTSSCHHHHHHHCHHHHHHHHHHHHHH...,0,100
61922,RYNDYKLDFRRQQMQDFFLAHKDEEWFRSKYHPDEVGKRRQEARGA...,CCCCCHHHHHHHHHHHHHHHCCCCHHHHHHHCHHHHHHHHHHHHHH...,CCCCCHHHHHHHHHHHHHHHTSSCHHHHHHHCHHHHHHHHHHHHHH...,0,100


In [10]:
sample = sample.sample(frac=1)

In [34]:
with open('../static/input_char2index.json') as f:
    input_word2index = json.load(f)

with open('../static/input_index2char.json') as f:
    input_index2word = json.load(f)


with open('../static/output_char2index8.json') as f:
    output_word2index = json.load(f)

with open('../static/output_index2char8.json') as f:
    output_index2word = json.load(f)

In [35]:
SOS_token = 0

class Lang:
    def __init__(self, word2index={}, index2word= {0: "SOS"}):
        self.word2index = word2index
        self.word2count = {}
        self.index2word = index2word
        self.n_words = len(self.index2word)

    def addSentence(self, sentence):
        for word in list(sentence):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            if word not in self.word2count:
                self.word2count[word] = 1
            else:
                self.word2count[word] += 1

In [36]:
def prepareData(lang1, lang2, input_word2index, input_index2word,output_word2index,output_index2word):

    input_lang = Lang(input_word2index, input_index2word)
    output_lang = Lang(output_word2index,output_index2word) 

    pairs = list(zip(lang1,lang2))

    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])

    print("Counted words:")
    print(f"Sequence: {input_lang.n_words}")
    print(f"Structure: {output_lang.n_words}")
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData(sample["seq"], sample["sst8"],input_word2index, input_index2word,output_word2index,output_index2word)

print(random.choice(pairs))

Counted words:
Sequence: 22
Structure: 9
('GPSQPKVPEWVNTPSTCCLKYYEKVLPRRLVVGYRKALNCHLPAIIFVTKRNREVCTNPNDDWVQEYIKDPNLPLLPTRNLSTVKIITAKNGQPQLLNSQ', 'CCCCCCCCCCCCSCEEECSSCCSSCCCGGGEEEEEEETTSSSCEEEEEETTSCEEEECTTSHHHHHHHTCTTCCBCCCCCCCCCCCCCCCCCCCCCCCCC')


In [37]:
MAX_LENGTH = sample["len"].max()+1

In [38]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in list(sentence)]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    return torch.tensor(indexes, dtype=torch.long).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

n = len(pairs)
input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

for idx, (inp, tgt) in enumerate(pairs):
    inp_ids = indexesFromSentence(input_lang, inp)
    tgt_ids = indexesFromSentence(output_lang, tgt)
    input_ids[idx, :len(inp_ids)] = inp_ids
    target_ids[idx, :len(tgt_ids)] = tgt_ids

In [39]:
train_size = int(len(input_ids)*0.6)
test_size = int(len(input_ids)*0.2)

X = input_ids
y = target_ids

X_train = torch.tensor(X[:train_size], dtype=torch.long)
y_train = torch.tensor(y[:train_size],dtype=torch.long)

X_test = torch.tensor(X[train_size:train_size+test_size],dtype=torch.long)
y_test = torch.tensor(y[train_size:train_size+test_size],dtype=torch.long)

X_val = torch.tensor(X[train_size+test_size:],dtype=torch.long)
y_val = torch.tensor(y[train_size+test_size:],dtype=torch.long)

In [40]:
SOS_freq = (torch.tensor(y, dtype=torch.long).shape[0] * torch.tensor(y, dtype=torch.long).shape[1]) - torch.count_nonzero(torch.tensor(y, dtype=torch.long))

In [41]:
vocab = output_lang.word2index
word_freq = output_lang.word2count

vocab.update({"SOS":0})
word_freq.update({"SOS":int(SOS_freq)})

vocab_size = len(vocab)

weights = torch.zeros(vocab_size)

for word, idx in vocab.items():
    weights[idx] = 1.0 / (word_freq[word]) 
    
weights = weights / weights.sum()
print(weights)

tensor([2.6514e-04, 5.4273e-04, 1.4653e-02, 9.1033e-04, 1.6556e-03, 1.8337e-03,
        7.0960e-03, 5.4944e-04, 9.7249e-01])


### Model

In [42]:
class LSTM(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, output_size, num_layers=1):
        super().__init__()

        self.inpit_size = input_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.embed = nn.Embedding(input_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*hidden_size, output_size)

    def forward(self, x):
        x = self.embed(x)
        x = F.relu(x)
        x , hidden = self.lstm(x)
        x = self.fc(x)

        x = F.log_softmax(x, dim = -1)
        return x


In [43]:
random_num = 20

In [44]:
X_train[random_num]

tensor([11, 13, 18, 13, 13, 18, 12, 14,  4,  3,  1, 18, 18,  6,  6, 14, 12, 14,
        13,  3,  3, 10, 13, 21,  2,  6, 12,  6, 18, 14, 18, 18,  9, 15, 18, 10,
        14, 15, 11,  1, 14, 11, 11, 16, 13, 12, 18, 17,  1,  7, 14, 18, 12, 11,
        13,  4, 13, 14,  6, 14, 10, 13, 11,  3,  4, 12, 11, 18,  2, 15, 14, 10,
        14, 18, 15, 14,  8, 15, 15, 13, 15, 13,  2,  3, 15, 13, 15, 13, 13, 12,
        10, 13, 10, 16, 15, 19, 14,  0,  0,  0,  0])

In [45]:
y_train[random_num]

tensor([1, 1, 1, 1, 1, 5, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 3, 3, 3, 3, 3, 3,
        3, 1, 1, 5, 5, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 5, 5, 5, 1, 1, 1, 1, 3,
        3, 3, 3, 3, 3, 4, 4, 1, 1, 3, 3, 3, 3, 1, 5, 1, 1, 4, 4, 1, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3, 3, 3, 3,
        1, 0, 0, 0, 0])

In [46]:
learning_rate=0.001
batch_size = 128
hidden_size = 256
n_epochs = 100

patience = 5
best_result = np.inf

In [47]:
train_loader = DataLoader(list(zip(X_train,y_train)), batch_size=batch_size)
test_loader = DataLoader(list(zip(X_test,y_test)), batch_size=batch_size)

In [48]:
model = LSTM(input_lang.n_words, 64, hidden_size, output_lang.n_words)

In [49]:
optimizer = torch.optim.Adam(params = model.parameters(), lr = learning_rate)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)

In [50]:
test_loss_array = []

for epoch in range(n_epochs):

    total_loss = 0
    for X_batch, y_batch in train_loader:
        
        optimizer.zero_grad()

        output = model(X_batch)

        loss = loss_fn(output.permute(1,2,0), y_batch.permute(1,0))
        
        loss.backward()
        optimizer.step()

        total_loss += loss.item() 


    test_loss = 0
    with torch.no_grad():  

        for X_batch, y_batch in test_loader:

            output = model(X_batch)
            loss = loss_fn(output.permute(1,2,0), y_batch.permute(1,0))

            test_loss+=loss

    loss = total_loss / (len(X_train) // batch_size)
    loss_test = test_loss / (len(y_test) // batch_size)

    test_loss_array.append(loss_test)

    if loss_test < best_result:
        torch.save(model.state_dict(), "../Models/lstm_8.pth")

    print(f"Epoch: {epoch}, Train loss: {loss}, Test loss: {loss_test}")

    if len(test_loss_array)>patience+1:
        if not (any(x > (test_loss_array[-1]+0.015) for x in test_loss_array[len(test_loss_array)-patience-1:-1])):
            break
        

Epoch: 0, Train loss: 1.6969677504867013, Test loss: 1.6217074394226074
Epoch: 1, Train loss: 1.539151581850919, Test loss: 1.5294550657272339
Epoch: 2, Train loss: 1.4470167948742105, Test loss: 1.4659918546676636
Epoch: 3, Train loss: 1.3644175165229373, Test loss: 1.44328773021698
Epoch: 4, Train loss: 1.2911029497180322, Test loss: 1.3245487213134766
Epoch: 5, Train loss: 1.2092470466488539, Test loss: 1.2775790691375732
Epoch: 6, Train loss: 1.13335273753513, Test loss: 1.2493984699249268
Epoch: 7, Train loss: 1.0429179873129335, Test loss: 1.2482428550720215
Epoch: 8, Train loss: 1.0294994323542623, Test loss: 1.1932414770126343
Epoch: 9, Train loss: 0.9872657710855658, Test loss: 1.152303695678711
Epoch: 10, Train loss: 0.8999971294342869, Test loss: 1.1191556453704834
Epoch: 11, Train loss: 0.8424947809691381, Test loss: 1.0925854444503784
Epoch: 12, Train loss: 0.7937890075974994, Test loss: 1.0788254737854004
Epoch: 13, Train loss: 0.7522346306629856, Test loss: 1.06333374977

In [56]:
model.load_state_dict(torch.load("../Models/lstm_8.pth"))
model.eval()

LSTM(
  (embed): Embedding(22, 64)
  (lstm): LSTM(64, 256, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=512, out_features=9, bias=True)
)

In [57]:
with torch.no_grad():
    
    outputs_pred = model(X_val)

    _, topi = outputs_pred.topk(1)
    decoded_ids = topi.squeeze()

In [58]:
pred = []
for idx in decoded_ids:
    decoded_structure = []
    for id in idx:
        if id.item() == SOS_token:
            break
        decoded_structure.append(output_lang.index2word[str(id.item())])
    pred.append("".join(decoded_structure))

print(pred[:5])

['CCCEEEEBCTTSSSSEEEEESSCSCECSSEEECCSCCCTCEEEETSEEEECCCCBTSSEEEEEEBCTTBBCCCCSSCCCCGC', 'CEEEEEEBTSBSSEEEEEBBSSBBBSTBSEEEEEEEEETSTCTTTSEBETTSBBCCTTCBBESSBBCHHHHHHHHTSCSEEEEEEESTCCCCCCCCC', 'CCCHHHHHHHHCCSTSSSCBBHHHHHHHHHHTTSSSHHHHHHHHTHSTSSSSSSBHHHHHHHHTTSSTTBBHHGGGGC', 'CCSSCCEESSSGSTSTTCBEEHHHHHTTTTSSSSSSBSTTSSBCTTTBSTGGGTSSCBEHHETHBTTTTGHTHTTTEEEEECCSSCC', 'CEEEESSCBCCSSSTBCCBCTTCEEEEEECSSSSEEEEEETTTCCEEEEEGGGEEEC']


In [59]:
target=[]
for idx in y_val:
    decoded_structure = []
    for id in idx:
        if id.item() == SOS_token:
            break
        decoded_structure.append(output_lang.index2word[str(id.item())])
    target.append("".join(decoded_structure))

print(target[:5])

['CEEEEEEEEEEEBSSCCEEEEGGGTCCSSEEECCCCCSSSSEEEEEEEEEETTHHHHSCCEEEEEETTTTEEEEEECCCCCC', 'CEEEECCBTTBTSEEEEESTTCCSBTTBCSEEEEEECTTCHHHHTCCCCTTCEEEEETTEECTTCCHHHHHHHHHSCCSCEEEEEECCSSSCCCCCC', 'CCCHHHHHHHHHCTTCSSEECHHHHHHHHHHHHTCCHHHHHHHHHHHCTTCSSSEEHHHHHHHHHHCHHHHHHHHTTC', 'CCCBCCCTTCTTCSSCCSHHHHHHHHHHHHCGGGSCCCHHHHHHHCCBSCSCTTCCCBCGGGCCCCHHHHHHHHHHHTTSCCSCBCC', 'CCCEESSCBCCCSTTBCCBCTTCBCCEEECTTSSEEEEECTTTCCEEEEEGGGEECC']


The results are significantly lower than with SST3 model. They would definitely require improvements in further steps but for now this model will be used with the app with and users will be informed that this model is less accurate. 

In [60]:
def char_level_metrics(predictions, targets):
    accuracy = 0
    f1 = 0
    
    for pred, target in zip(list(predictions), list(targets)):
        if len(pred)<len(target):
            pred = pred + ("$" * (len(target)-len(pred)))
        if len(pred)>len(target):
            target = target + ("$" * (len(pred)-len(target)))

        accuracy += accuracy_score(list(pred),list(target))
        f1 += f1_score(list(pred),list(target), average="macro")

    return accuracy/len(predictions), f1/len(predictions)

ac, f1 = char_level_metrics(pred, target)

print(f'Character-level accuracy: {ac*100}%')
print(f'Character-level f1: {f1*100}%')
print(f'Exact match: {accuracy_score(pred,target)*100}%')

Character-level accuracy: 64.07956678825676%
Character-level f1: 45.40226884967352%
Exact match: 0.21231422505307856%
