In [1]:
import torch 
import torch.nn
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM, AdamW, WarmupLinearSchedule
import logging
import pandas as pd
from biopandas.pdb import PandasPdb
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [2]:
sample_sub = pd.read_csv("/srv01/technion/morant/Storage/sample_submission.csv")
test = pd.read_csv("/srv01/technion/morant/Storage/test.csv")
train_updates = pd.read_csv("/srv01/technion/morant/Storage/train_updates_20220929.csv")

In [3]:
class CustomProteinDataset(Dataset):
    def __init__(self, csv_file, wt_struc_pred):
        self.csv_file = pd.read_csv(csv_file)
        self.wt_struc_pred = PandasPdb().read_pdb(wt_struc_pred)
        
        # Tokenization of train
        aa2num = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10, 'L': 11, 'K': 12, 'M': 13,
                  'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19, 'W': 20, 'Y': 21, 'V': 22, 'B': 23, 'Z': 24, 'X': 25, 'J': 26}

        # Tokenization!!
        self.csv_file['protein_sequence_tokenized'] = self.csv_file['protein_sequence'].apply(lambda s: [aa2num[x] for x in s])
        self.csv_file['len_Before_tokenization'] = self.csv_file['protein_sequence'].apply(len)
        self.csv_file = self.csv_file[self.csv_file['len_Before_tokenization']<=512].reset_index()
        max_len = self.csv_file['protein_sequence_tokenized'].apply(len).max()

        self.csv_file['protein_sequence_tokenized'] = self.csv_file['protein_sequence_tokenized'].apply(
            lambda x: np.pad(x, (0, max_len-len(x))))
        self.tokens_tensor = self.csv_file['protein_sequence_tokenized']
        self.tokens_tensor = torch.tensor(np.array([ x for x in self.tokens_tensor.values ]))[:, :512]
        self.tokens_mskd_tensor = self.csv_file['len_Before_tokenization']
        self.tokens_mskd_tensor = torch.tensor([np.pad(np.ones(x),(0,max_len-x))
                                                for x in self.tokens_mskd_tensor.values])[:, :512].float()
        self.tm_tensor = torch.tensor(self.csv_file['tm'])
        

    def __len__(self):
        return len(self.csv_file)

    def __getitem__(self, idx): 
        return self.tokens_tensor[idx], self.tokens_mskd_tensor[idx],  self.tm_tensor[idx]

In [4]:
training_data = CustomProteinDataset('/srv01/technion/morant/Storage/train.csv',
                                     '/srv01/technion/morant/Storage/wildtype_structure_prediction_af2.pdb')
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

  self.tokens_mskd_tensor = torch.tensor([np.pad(np.ones(x),(0,max_len-x))


In [5]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.linear = torch.nn.Sequential(torch.nn.Linear(in_features=393216, out_features=20000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=20000, out_features=10000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=10000, out_features=1000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=1000, out_features=1))
        
    def forward(self, batch, attention_mask):
        result = self.model(batch, token_type_ids=torch.zeros_like(batch), attention_mask=attention_mask)[0]
        res_flat = torch.flatten(result, start_dim=1)
        lin = self.linear(res_flat)

        return lin        

In [6]:
model = Model()
loss = torch.nn.MSELoss()

In [8]:
# Training (when we'll get there)
# Parameters:
lr = 1e-7
max_grad_norm = 0.7
num_total_steps = 1000
num_warmup_steps = 500
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1

### In PyTorch-Transformers, optimizer and schedules are splitted and instantiated like this:
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler
### and used like this:
for i in range(num_total_steps):
    for batch,attention_mask, train_tm in train_dataloader:
#     for batch,attention_mask, train_tm in zip(batched_tok_ten, batched_tok_mskd, batched_train_tm):
        loss_new = loss(model(batch, attention_mask),train_tm.float()[:,None])
        loss_new.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
        optimizer.step()
        scheduler.step()
        print(f"loss_new={loss_new}")


loss_new=326.58544921875
loss_new=349.9559326171875
loss_new=289.32427978515625
loss_new=243.9564666748047
loss_new=499.64599609375
loss_new=340.3167419433594
loss_new=390.4734191894531
loss_new=339.09759521484375
loss_new=309.1086730957031
loss_new=326.2749938964844
loss_new=289.9901123046875
loss_new=272.30377197265625
loss_new=253.0203857421875
loss_new=381.80572509765625
loss_new=397.52154541015625
loss_new=370.40472412109375
loss_new=305.77520751953125
loss_new=191.43597412109375
loss_new=369.45086669921875
loss_new=448.6501159667969
loss_new=419.50274658203125
loss_new=239.0080108642578
loss_new=338.41046142578125
loss_new=457.9380798339844
loss_new=516.4689331054688
loss_new=346.7664794921875
loss_new=309.12725830078125
loss_new=323.1715087890625
loss_new=325.1849670410156
loss_new=253.74864196777344
loss_new=316.572021484375
loss_new=314.2746887207031
loss_new=325.5556945800781
loss_new=428.58306884765625
loss_new=292.0516357421875
loss_new=350.1146545410156
loss_new=280.992004

loss_new=315.3497619628906
loss_new=274.94500732421875
loss_new=278.5273132324219
loss_new=202.3355712890625
loss_new=269.4154052734375
loss_new=191.351318359375
loss_new=208.30316162109375
loss_new=333.724609375
loss_new=279.8355712890625
loss_new=270.4465637207031
loss_new=199.1775360107422
loss_new=192.83511352539062
loss_new=241.1897735595703
loss_new=201.81417846679688
loss_new=141.26068115234375
loss_new=147.46449279785156
loss_new=225.9024658203125
loss_new=120.71314239501953
loss_new=180.33358764648438
loss_new=119.72064971923828
loss_new=227.5009002685547
loss_new=264.8646240234375
loss_new=273.8433837890625
loss_new=222.6355438232422
loss_new=262.9760437011719
loss_new=127.4345703125
loss_new=289.230712890625
loss_new=222.5192413330078
loss_new=103.15982055664062
loss_new=317.38214111328125
loss_new=203.0672149658203
loss_new=238.67633056640625
loss_new=258.5260009765625
loss_new=291.3746643066406
loss_new=243.1175537109375
loss_new=249.37588500976562
loss_new=254.93206787109

KeyboardInterrupt: 

In [None]:
# Running test

In [None]:
import os

In [26]:
os.environ

environ{'SHELL': '/bin/bash',
        'JPY_API_TOKEN': 'b2c0e785b69842beb3fd9231d0290cb6',
        'USER': 'morant',
        'JUPYTERHUB_BASE_URL': '/',
        'JUPYTERHUB_CLIENT_ID': 'jupyterhub-user-morant',
        'JUPYTERHUB_API_TOKEN': 'b2c0e785b69842beb3fd9231d0290cb6',
        'PATH': '/Local/md_kaplan/anaconda3/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin',
        'MKL_NUM_THREADS': '3',
        'PWD': '/srv01/technion/morant',
        'JUPYTERHUB_SERVER_NAME': '',
        'LANG': 'en_US.UTF-8',
        'JUPYTERHUB_API_URL': 'http://127.0.0.1:8002/hub/api',
        'SHLVL': '0',
        'HOME': '/srv01/technion/morant',
        'JUPYTERHUB_USER': 'morant',
        'JUPYTERHUB_ACTIVITY_URL': 'http://127.0.0.1:8002/hub/api/users/morant/activity',
        'JUPYTERHUB_OAUTH_CALLBACK_URL': '/user/morant/oauth_callback',
        'JUPYTERHUB_HOST': '',
        'JUPYTERHUB_SERVICE_PREFIX': '/user/morant/',
        'PYDEVD_USE_FRAME_EVAL': 'NO',
        'JPY_PARENT_PID': '37