In [1]:
import torch 
import torch.nn
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM, AdamW, WarmupLinearSchedule
import logging
import pandas as pd
from biopandas.pdb import PandasPdb
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [2]:
sample_sub = pd.read_csv("/srv01/technion/morant/Storage/sample_submission.csv")
test = pd.read_csv("/srv01/technion/morant/Storage/test.csv")
train_updates = pd.read_csv("/srv01/technion/morant/Storage/train_updates_20220929.csv")

In [3]:
class CustomProteinDataset(Dataset):
    def __init__(self, csv_file, wt_struc_pred):
        self.csv_file = pd.read_csv(csv_file)
        self.wt_struc_pred = PandasPdb().read_pdb(wt_struc_pred)
        
        # Tokenization of train
        aa2num = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10, 'L': 11, 'K': 12, 'M': 13,
                  'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19, 'W': 20, 'Y': 21, 'V': 22, 'B': 23, 'Z': 24, 'X': 25, 'J': 26}

        # Tokenization!!
        self.csv_file['protein_sequence_tokenized'] = self.csv_file['protein_sequence'].apply(lambda s: [aa2num[x] for x in s])
        self.csv_file['len_Before_tokenization'] = self.csv_file['protein_sequence'].apply(len)
        self.csv_file = self.csv_file[self.csv_file['len_Before_tokenization']<=512].reset_index()
        max_len = self.csv_file['protein_sequence_tokenized'].apply(len).max()

        self.csv_file['protein_sequence_tokenized'] = self.csv_file['protein_sequence_tokenized'].apply(
            lambda x: np.pad(x, (0, max_len-len(x))))
        self.tokens_tensor = self.csv_file['protein_sequence_tokenized']
        self.tokens_tensor = torch.tensor(np.array([ x for x in self.tokens_tensor.values ]))[:, :512]
        self.tokens_mskd_tensor = self.csv_file['len_Before_tokenization']
        self.tokens_mskd_tensor = torch.tensor([np.pad(np.ones(x),(0,max_len-x))
                                                for x in self.tokens_mskd_tensor.values])[:, :512].float()
        self.tm_tensor = torch.tensor(self.csv_file['tm'])
        

    def __len__(self):
        return len(self.csv_file)

    def __getitem__(self, idx): 
        return self.tokens_tensor[idx], self.tokens_mskd_tensor[idx],  self.tm_tensor[idx]

In [None]:
training_data = CustomProteinDataset('/srv01/technion/morant/Storage/train.csv',
                                     '/srv01/technion/morant/Storage/wildtype_structure_prediction_af2.pdb')
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [35]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.linear = torch.nn.Sequential(torch.nn.Linear(in_features=393216, out_features=100000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=100000, out_features=50000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=50000, out_features=10000),
                                          torch.nn.ReLU()
                                          torch.nn.Linear(in_features=10000, out_features=1000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=1000, out_features=1))
        
    def forward(self, batch, attention_mask):
        result = self.model(batch, token_type_ids=torch.zeros_like(batch), attention_mask=attention_mask)[0]
        res_flat = torch.flatten(result, start_dim=1)
        lin = self.linear(res_flat)

        return lin        

In [36]:
model = Model()
loss = torch.nn.MSELoss()

In [None]:
# Training (when we'll get there)
# Parameters:
lr = 1e-7
max_grad_norm = 0.7
num_total_steps = 1000
num_warmup_steps = 500
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1

### In PyTorch-Transformers, optimizer and schedules are splitted and instantiated like this:
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler
### and used like this:
for i in range(num_total_steps):
    for batch,attention_mask, train_tm in train_dataloader:
#     for batch,attention_mask, train_tm in zip(batched_tok_ten, batched_tok_mskd, batched_train_tm):
        loss_new = loss(model(batch, attention_mask),train_tm.float()[:,None])
        loss_new.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
        optimizer.step()
        scheduler.step()
        print(f"loss_new={loss_new}")


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1055.)
  exp_avg.mul_(beta1).add_(1.0 - beta1, grad)


loss_new=2588.7138671875
loss_new=2933.031494140625
loss_new=2891.791015625
loss_new=2683.242919921875
loss_new=2435.296142578125
loss_new=2711.696533203125
loss_new=2823.892578125
loss_new=2985.628173828125
loss_new=2453.697265625
loss_new=2693.365234375
loss_new=2943.279052734375
loss_new=2900.676513671875
loss_new=2617.969482421875
loss_new=2676.50048828125
loss_new=2595.653076171875
loss_new=2529.22216796875
loss_new=2519.93017578125
loss_new=2374.563720703125
loss_new=2490.346923828125
loss_new=2787.250244140625
loss_new=2651.0361328125
loss_new=2442.63232421875
loss_new=2740.577392578125
loss_new=2673.361572265625
loss_new=2604.515625
loss_new=2568.933837890625
loss_new=2724.67822265625
loss_new=2898.152587890625
loss_new=2314.645751953125
loss_new=2579.6044921875
loss_new=2588.68798828125
loss_new=2644.273193359375
loss_new=2401.4755859375
loss_new=2223.34326171875
loss_new=3197.2607421875
loss_new=2608.49072265625
loss_new=2520.092041015625
loss_new=2709.92724609375
loss_new=22

loss_new=778.5361938476562
loss_new=737.8702392578125
loss_new=707.481689453125
loss_new=706.6082763671875
loss_new=779.3256225585938
loss_new=684.3375854492188
loss_new=857.5392456054688
loss_new=587.1343994140625
loss_new=708.4258422851562
loss_new=694.8218383789062
loss_new=631.8123168945312
loss_new=683.4513549804688
loss_new=667.30224609375
loss_new=679.6611328125
loss_new=617.989990234375
loss_new=691.072509765625
loss_new=978.2987670898438
loss_new=519.8538208007812
loss_new=547.8507080078125
loss_new=575.575927734375
loss_new=661.50146484375
loss_new=617.6160278320312
loss_new=613.4464721679688
loss_new=729.0999145507812
loss_new=808.7354125976562
loss_new=621.0847778320312
loss_new=685.1561889648438
loss_new=561.875732421875
loss_new=650.8305053710938
loss_new=766.6705322265625
loss_new=555.0289916992188
loss_new=559.6068115234375
loss_new=591.193115234375
loss_new=736.8978881835938
loss_new=645.6614379882812
loss_new=548.2635498046875
loss_new=465.0250244140625
loss_new=526.6

loss_new=236.6387939453125
loss_new=117.91062927246094
loss_new=246.644287109375
loss_new=251.3880157470703
loss_new=228.7705841064453
loss_new=306.841796875
loss_new=185.88922119140625
loss_new=234.2857208251953
loss_new=158.60491943359375
loss_new=190.93568420410156
loss_new=253.20596313476562
loss_new=163.13760375976562
loss_new=138.79116821289062
loss_new=203.09246826171875
loss_new=219.00100708007812
loss_new=237.66677856445312
loss_new=147.00636291503906
loss_new=196.34210205078125
loss_new=192.3108673095703
loss_new=216.99034118652344
loss_new=240.80108642578125
loss_new=195.67413330078125
loss_new=289.9035339355469
loss_new=154.1140594482422
loss_new=193.65225219726562
loss_new=164.03086853027344
loss_new=177.50485229492188
loss_new=268.57037353515625
loss_new=312.608642578125
loss_new=209.63189697265625
loss_new=161.48268127441406
loss_new=227.79776000976562
loss_new=234.4615478515625
loss_new=234.31781005859375
loss_new=274.20880126953125
loss_new=228.6632080078125
loss_new=2

loss_new=312.4903564453125
loss_new=216.77613830566406
loss_new=329.6341247558594
loss_new=168.81150817871094
loss_new=239.7261962890625
loss_new=204.96885681152344
loss_new=273.86883544921875
loss_new=208.2882843017578
loss_new=172.4314727783203
loss_new=181.66526794433594
loss_new=237.61398315429688
loss_new=176.69561767578125
loss_new=228.0795440673828
loss_new=302.0676574707031
loss_new=185.98797607421875
loss_new=235.43423461914062
loss_new=200.8206787109375
loss_new=209.12924194335938
loss_new=258.1180725097656
loss_new=262.2910461425781
loss_new=219.30645751953125
loss_new=138.5529022216797
loss_new=258.8049621582031
loss_new=252.60903930664062
loss_new=199.4834442138672
loss_new=237.1172332763672
loss_new=253.67877197265625
loss_new=327.0286865234375
loss_new=237.0081787109375
loss_new=141.36309814453125
loss_new=195.89512634277344
loss_new=217.35035705566406
loss_new=200.92245483398438
loss_new=304.2408142089844
loss_new=268.5927429199219
loss_new=235.24050903320312
loss_new=2

loss_new=156.173583984375
loss_new=328.2646484375
loss_new=200.0712890625
loss_new=172.72462463378906
loss_new=228.51087951660156
loss_new=211.23019409179688
loss_new=213.54534912109375
loss_new=217.52980041503906
loss_new=147.5924835205078
loss_new=172.11703491210938
loss_new=213.00819396972656
loss_new=232.43722534179688
loss_new=179.5148468017578
loss_new=172.65003967285156
loss_new=127.56724548339844
loss_new=127.94097137451172
loss_new=160.63978576660156
loss_new=202.98899841308594
loss_new=177.85084533691406
loss_new=191.36923217773438
loss_new=138.68653869628906
loss_new=187.2476348876953
loss_new=248.67562866210938
loss_new=263.37298583984375
loss_new=194.6785888671875
loss_new=196.3746337890625
loss_new=248.82876586914062
loss_new=174.37603759765625
loss_new=238.77169799804688
loss_new=198.99708557128906
loss_new=300.4881286621094
loss_new=163.6175994873047
loss_new=216.17666625976562
loss_new=272.7503967285156
loss_new=213.98260498046875
loss_new=222.27850341796875
loss_new=1

In [25]:
import os

In [26]:
os.environ

environ{'SHELL': '/bin/bash',
        'JPY_API_TOKEN': 'b2c0e785b69842beb3fd9231d0290cb6',
        'USER': 'morant',
        'JUPYTERHUB_BASE_URL': '/',
        'JUPYTERHUB_CLIENT_ID': 'jupyterhub-user-morant',
        'JUPYTERHUB_API_TOKEN': 'b2c0e785b69842beb3fd9231d0290cb6',
        'PATH': '/Local/md_kaplan/anaconda3/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin',
        'MKL_NUM_THREADS': '3',
        'PWD': '/srv01/technion/morant',
        'JUPYTERHUB_SERVER_NAME': '',
        'LANG': 'en_US.UTF-8',
        'JUPYTERHUB_API_URL': 'http://127.0.0.1:8002/hub/api',
        'SHLVL': '0',
        'HOME': '/srv01/technion/morant',
        'JUPYTERHUB_USER': 'morant',
        'JUPYTERHUB_ACTIVITY_URL': 'http://127.0.0.1:8002/hub/api/users/morant/activity',
        'JUPYTERHUB_OAUTH_CALLBACK_URL': '/user/morant/oauth_callback',
        'JUPYTERHUB_HOST': '',
        'JUPYTERHUB_SERVICE_PREFIX': '/user/morant/',
        'PYDEVD_USE_FRAME_EVAL': 'NO',
        'JPY_PARENT_PID': '37