In [18]:
import torch 
import torch.nn
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM, AdamW, WarmupLinearSchedule
import logging
import pandas as pd
from biopandas.pdb import PandasPdb
import numpy as np


In [19]:
sample_sub = pd.read_csv("/srv01/technion/morant/Storage/sample_submission.csv")
test = pd.read_csv("/srv01/technion/morant/Storage/test.csv")
train = pd.read_csv("/srv01/technion/morant/Storage/train.csv")
train_updates = pd.read_csv("/srv01/technion/morant/Storage/train_updates_20220929.csv")
ppdb=PandasPdb()
wt_structure_pred = ppdb.read_pdb("/srv01/technion/morant/Storage/wildtype_structure_prediction_af2.pdb")

In [25]:
# Tokenization of train
aa2num = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10, 'L': 11, 'K': 12, 'M': 13,
          'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19, 'W': 20, 'Y': 21, 'V': 22, 'B': 23, 'Z': 24, 'X': 25, 'J': 26}

# Tokenization!!
train['protein_sequence_tokenized'] = train['protein_sequence'].apply(lambda s: [aa2num[x] for x in s])
len_Before_tokenization = train['protein_sequence'].apply(len)

In [26]:
len_Before_tokenization

0         341
1         286
2         497
3         265
4        1451
         ... 
31385     549
31386     469
31387     128
31388     593
31389     537
Name: protein_sequence, Length: 31390, dtype: int64

In [28]:
train['len_Before_tokenization'] = len_Before_tokenization

In [27]:
max_len = train['protein_sequence_tokenized'].apply(len).max()

train['protein_sequence_tokenized'] = train['protein_sequence_tokenized'].apply(lambda x: np.pad(x, (0, max_len-len(x))))

In [33]:
train['len_Before_tokenization'][:2]

0    341
1    286
Name: len_Before_tokenization, dtype: int64

In [50]:
tokens_tensor = train['protein_sequence_tokenized'][:2]
tokens_tensor = torch.tensor(np.array([ x for x in tokens_tensor.values ]))[:, :512]
tokens_mskd_tensor = train['len_Before_tokenization'][:2]
tokens_mskd_tensor = torch.tensor([np.pad(np.ones(x),(0,max_len-x)) for x in tokens_mskd_tensor.values])[:, :512].float()

In [51]:
tokens_mskd_tensor

tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.]])

In [52]:
# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')

# Set the model in evaluation mode to desactivate the DropOut modules
# This is IMPORTANT to have reproductible results during evaluation!
# model.eval()


# See the models docstrings for the detail of the inputs
outputs = model(tokens_tensor, token_type_ids=torch.zeros_like(tokens_tensor))
# PyTorch-Transformers models always output tuples.
# See the models docstrings for the detail of all the outputs
# In our case, the first element is the hidden state of the last layer of the Bert model
encoded_layers = outputs[0]
# We have encoded our input sequence in a FloatTensor of shape (batch size, sequence length, model hidden dimension)
# assert tuple(encoded_layers.shape) == (1, len(indexed_tokens), model.config.hidden_size)

In [53]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.linear = torch.nn.Sequential(torch.nn.Linear(in_features=393216, out_features=1000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=1000, out_features=1))
        
    def forward(self, batch, attention_mask):
        result = self.model(batch, token_type_ids=torch.zeros_like(batch), attention_mask=attention_mask)[0]
        res_flat = torch.flatten(result, start_dim=1)
        lin = self.linear(res_flat)

        return lin        

In [54]:
model = Model()

In [55]:
loss = torch.nn.MSELoss()

In [56]:
attention_mask.shape

torch.Size([2, 32767])

In [None]:
# Training (when we'll get there)
# Parameters:
lr = 1e-7
max_grad_norm = 0.7
num_total_steps = 50
num_warmup_steps = 20
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1

### In PyTorch-Transformers, optimizer and schedules are splitted and instantiated like this:
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler
### and used like this:
for i in range(num_total_steps):
    for batch,attention_mask in zip(tokens_tensor[None], tokens_mskd_tensor[None]):
        loss_new = loss(model(batch, attention_mask),torch.tensor(train['tm'].iloc[:2]).float()[:,None])
        loss_new.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
        optimizer.step()
        scheduler.step()
        print(f"loss_new={loss_new}")


loss_new=158.76083374023438
loss_new=158.76083374023438
loss_new=158.76068115234375
loss_new=158.7605438232422
loss_new=158.76034545898438
loss_new=158.76019287109375
loss_new=158.75997924804688
loss_new=158.76004028320312


In [207]:
loss_new

tensor(6734.7754, grad_fn=<MseLossBackward0>)

In [193]:
torch.tensor(train['tm'])

tensor([75.7000, 50.5000, 40.5000,  ..., 64.6000, 50.7000, 37.6000],
       dtype=torch.float64)

In [192]:
train.head()

Unnamed: 0,seq_id,protein_sequence,pH,data_source,tm,protein_sequence_tokenized
0,0,AAAAKAAALALLGEAPEVVDIWLPAGWRQPFRVFRLERKGDGVLVG...,7.0,doi.org/10.1038/s41592-020-0801-4,75.7,"[1, 1, 1, 1, 12, 1, 1, 1, 11, 1, 11, 11, 8, 7,..."
1,1,AAADGEPLHNEEERAGAGQVGRSLPQESEEQRTGSRPRRRRDLGSR...,7.0,doi.org/10.1038/s41592-020-0801-4,50.5,"[1, 1, 1, 4, 8, 7, 15, 11, 9, 3, 7, 7, 7, 2, 1..."
2,2,AAAFSTPRATSYRILSSAGSGSTRADAPQVRRLHTTRDLLAKDYYA...,7.0,doi.org/10.1038/s41592-020-0801-4,40.5,"[1, 1, 1, 14, 17, 19, 15, 2, 1, 19, 17, 21, 2,..."
3,3,AAASGLRTAIPAQPLRHLLQPAPRPCLRPFGLLSVRAGSARRSGLL...,7.0,doi.org/10.1038/s41592-020-0801-4,47.2,"[1, 1, 1, 17, 8, 11, 2, 19, 1, 10, 15, 1, 6, 1..."
4,4,AAATKSGPRRQSQGASVRTFTPFYFLVEPVDTLSVRGSSVILNCSA...,7.0,doi.org/10.1038/s41592-020-0801-4,49.5,"[1, 1, 1, 19, 12, 17, 8, 15, 2, 2, 6, 17, 6, 8..."


In [44]:
train.head()

Unnamed: 0,seq_id,protein_sequence,pH,data_source,tm
0,0,"[1, 1, 1, 1, 12, 1, 1, 1, 11, 1, 11, 11, 8, 7,...",7.0,doi.org/10.1038/s41592-020-0801-4,75.7
1,1,"[1, 1, 1, 4, 8, 7, 15, 11, 9, 3, 7, 7, 7, 2, 1...",7.0,doi.org/10.1038/s41592-020-0801-4,50.5
2,2,"[1, 1, 1, 14, 17, 19, 15, 2, 1, 19, 17, 21, 2,...",7.0,doi.org/10.1038/s41592-020-0801-4,40.5
3,3,"[1, 1, 1, 17, 8, 11, 2, 19, 1, 10, 15, 1, 6, 1...",7.0,doi.org/10.1038/s41592-020-0801-4,47.2
4,4,"[1, 1, 1, 19, 12, 17, 8, 15, 2, 2, 6, 17, 6, 8...",7.0,doi.org/10.1038/s41592-020-0801-4,49.5


In [32]:
train

Unnamed: 0,seq_id,protein_sequence,pH,data_source,tm
0,0,AAAAKAAALALLGEAPEVVDIWLPAGWRQPFRVFRLERKGDGVLVG...,7.0,doi.org/10.1038/s41592-020-0801-4,75.7
1,1,AAADGEPLHNEEERAGAGQVGRSLPQESEEQRTGSRPRRRRDLGSR...,7.0,doi.org/10.1038/s41592-020-0801-4,50.5
2,2,AAAFSTPRATSYRILSSAGSGSTRADAPQVRRLHTTRDLLAKDYYA...,7.0,doi.org/10.1038/s41592-020-0801-4,40.5
3,3,AAASGLRTAIPAQPLRHLLQPAPRPCLRPFGLLSVRAGSARRSGLL...,7.0,doi.org/10.1038/s41592-020-0801-4,47.2
4,4,AAATKSGPRRQSQGASVRTFTPFYFLVEPVDTLSVRGSSVILNCSA...,7.0,doi.org/10.1038/s41592-020-0801-4,49.5
...,...,...,...,...,...
31385,31385,YYMYSGGGSALAAGGGGAGRKGDWNDIDSIKKKDLHHSRGDEKAQG...,7.0,doi.org/10.1038/s41592-020-0801-4,51.8
31386,31386,YYNDQHRLSSYSVETAMFLSWERAIVKPGAMFKKAVIGFNCNVDLI...,7.0,doi.org/10.1038/s41592-020-0801-4,37.2
31387,31387,YYQRTLGAELLYKISFGEMPKSAQDSAENCPSGMQFPDTAIAHANV...,7.0,doi.org/10.1038/s41592-020-0801-4,64.6
31388,31388,YYSFSDNITTVFLSRQAIDDDHSLSLGTISDVVESENGVVAADDAR...,7.0,doi.org/10.1038/s41592-020-0801-4,50.7


In [5]:
test.head()

Unnamed: 0,seq_id,protein_sequence,pH,data_source
0,31390,VPVNPEPDATSVENVAEKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
1,31391,VPVNPEPDATSVENVAKKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
2,31392,VPVNPEPDATSVENVAKTGSGDSQSDPIKADLEVKGQSALPFDVDC...,8,Novozymes
3,31393,VPVNPEPDATSVENVALCTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
4,31394,VPVNPEPDATSVENVALFTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes


In [6]:
train.head()

Unnamed: 0,seq_id,protein_sequence,pH,data_source,tm
0,0,AAAAKAAALALLGEAPEVVDIWLPAGWRQPFRVFRLERKGDGVLVG...,7.0,doi.org/10.1038/s41592-020-0801-4,75.7
1,1,AAADGEPLHNEEERAGAGQVGRSLPQESEEQRTGSRPRRRRDLGSR...,7.0,doi.org/10.1038/s41592-020-0801-4,50.5
2,2,AAAFSTPRATSYRILSSAGSGSTRADAPQVRRLHTTRDLLAKDYYA...,7.0,doi.org/10.1038/s41592-020-0801-4,40.5
3,3,AAASGLRTAIPAQPLRHLLQPAPRPCLRPFGLLSVRAGSARRSGLL...,7.0,doi.org/10.1038/s41592-020-0801-4,47.2
4,4,AAATKSGPRRQSQGASVRTFTPFYFLVEPVDTLSVRGSSVILNCSA...,7.0,doi.org/10.1038/s41592-020-0801-4,49.5


In [26]:
train_updates

Unnamed: 0,seq_id,protein_sequence,pH,data_source,tm
0,69,,,,
1,70,,,,
2,71,,,,
3,72,,,,
4,73,,,,
...,...,...,...,...,...
2429,30738,,,,
2430,30739,,,,
2431,30740,,,,
2432,30741,,,,
