In [3]:
import torch 
import torch.nn
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM, AdamW, WarmupLinearSchedule
import logging
import pandas as pd
from biopandas.pdb import PandasPdb
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [None]:
sample_sub = pd.read_csv("/srv01/technion/morant/Storage/sample_submission.csv")
test = pd.read_csv("/srv01/technion/morant/Storage/test.csv")
train_updates = pd.read_csv("/srv01/technion/morant/Storage/train_updates_20220929.csv")

In [4]:
class CustomProteinDataset(Dataset):
    def __init__(self, csv_file, wt_struc_pred):
        self.csv_file = pd.read_csv(csv_file)
        self.wt_struc_pred = PandasPdb().read_pdb(wt_struc_pred)
        
        # Tokenization of train
        aa2num = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10, 'L': 11, 'K': 12, 'M': 13,
                  'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19, 'W': 20, 'Y': 21, 'V': 22, 'B': 23, 'Z': 24, 'X': 25, 'J': 26}

        # Tokenization!!
        self.csv_file['protein_sequence_tokenized'] = self.csv_file['protein_sequence'].apply(lambda s: [aa2num[x] for x in s])
        self.csv_file['len_Before_tokenization'] = self.csv_file['protein_sequence'].apply(len)
        max_len = self.csv_file['protein_sequence_tokenized'].apply(len).max()

        self.csv_file['protein_sequence_tokenized'] = self.csv_file['protein_sequence_tokenized'].apply(
            lambda x: np.pad(x, (0, max_len-len(x))))
        self.tokens_tensor = self.csv_file['protein_sequence_tokenized']
        self.tokens_tensor = torch.tensor(np.array([ x for x in self.tokens_tensor.values ]))[:, :512]
        self.tokens_mskd_tensor = self.csv_file['len_Before_tokenization']
        self.tokens_mskd_tensor = torch.tensor([np.pad(np.ones(x),(0,max_len-x))
                                                for x in self.tokens_mskd_tensor.values])[:, :512].float()
        self.tm_tensor = torch.tensor(self.csv_file['tm'])
        

    def __len__(self):
        return len(self.csv_file)

    def __getitem__(self, idx): 
        return self.tokens_tensor[idx], self.tokens_mskd_tensor[idx],  self.tm_tensor[idx]

In [5]:
training_data = CustomProteinDataset('/srv01/technion/morant/Storage/train.csv',
                                     '/srv01/technion/morant/Storage/wildtype_structure_prediction_af2.pdb')
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

  self.tokens_mskd_tensor = torch.tensor([np.pad(np.ones(x),(0,max_len-x))


In [13]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.linear = torch.nn.Sequential(torch.nn.Linear(in_features=393216, out_features=10000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=10000, out_features=1000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=1000, out_features=1))
        
    def forward(self, batch, attention_mask):
        result = self.model(batch, token_type_ids=torch.zeros_like(batch), attention_mask=attention_mask)[0]
        res_flat = torch.flatten(result, start_dim=1)
        lin = self.linear(res_flat)

        return lin        

In [14]:
model = Model()
loss = torch.nn.MSELoss()

In [None]:
# Training (when we'll get there)
# Parameters:
lr = 1e-5
max_grad_norm = 0.7
num_total_steps = 1000
num_warmup_steps = 500
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1

### In PyTorch-Transformers, optimizer and schedules are splitted and instantiated like this:
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler
### and used like this:
for i in range(num_total_steps):
    for batch,attention_mask, train_tm in train_dataloader:
#     for batch,attention_mask, train_tm in zip(batched_tok_ten, batched_tok_mskd, batched_train_tm):
        loss_new = loss(model(batch, attention_mask),train_tm.float()[:,None])
        loss_new.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
        optimizer.step()
        scheduler.step()
        print(f"loss_new={loss_new}")
