In [6]:
import torch 
import torch.nn
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM, AdamW, WarmupLinearSchedule
import logging
import pandas as pd
from biopandas.pdb import PandasPdb
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [2]:
sample_sub = pd.read_csv("/srv01/technion/morant/Storage/sample_submission.csv")
test = pd.read_csv("/srv01/technion/morant/Storage/test.csv")
train_updates = pd.read_csv("/srv01/technion/morant/Storage/train_updates_20220929.csv")

In [7]:
class CustomProteinDataset(Dataset):
    def __init__(self, csv_file, wt_struc_pred):
        self.csv_file = pd.read_csv(csv_file)
        self.wt_struc_pred = PandasPdb().read_pdb(wt_struc_pred)
        
        # Tokenization of train
        aa2num = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10, 'L': 11, 'K': 12, 'M': 13,
                  'F': 14, 'P': 15, 'O': 16, 'S': 17, 'U': 18, 'T': 19, 'W': 20, 'Y': 21, 'V': 22, 'B': 23, 'Z': 24, 'X': 25, 'J': 26}

        # Tokenization!!
        self.csv_file['protein_sequence_tokenized'] = self.csv_file['protein_sequence'].apply(lambda s: [aa2num[x] for x in s])
        self.csv_file['len_Before_tokenization'] = self.csv_file['protein_sequence'].apply(len)
        self.csv_file = self.csv_file[self.csv_file['len_Before_tokenization']<=512].reset_index()
        max_len = self.csv_file['protein_sequence_tokenized'].apply(len).max()
        if max_len < 512:
            max_len = 512

        self.csv_file['protein_sequence_tokenized'] = self.csv_file['protein_sequence_tokenized'].apply(
            lambda x: np.pad(x, (0, max_len-len(x))))
        self.tokens_tensor = self.csv_file['protein_sequence_tokenized']
        self.tokens_tensor = torch.tensor(np.array([ x for x in self.tokens_tensor.values ]))[:, :512]
        self.tokens_mskd_tensor = self.csv_file['len_Before_tokenization']
        self.tokens_mskd_tensor = torch.tensor([np.pad(np.ones(x),(0,max_len-x))
                                                for x in self.tokens_mskd_tensor.values])[:, :512].float()
        
        if 'tm' in self.csv_file.columns:
            self.tm_tensor = torch.tensor(self.csv_file['tm'])
            
        else:
            self.tm_tensor = None

    def __len__(self):
        return len(self.csv_file)

    def __getitem__(self, idx):
        if self.tm_tensor is not None:
            return self.tokens_tensor[idx], self.tokens_mskd_tensor[idx],  self.tm_tensor[idx]
        else:
            return self.tokens_tensor[idx], self.tokens_mskd_tensor[idx]

In [4]:
training_data = CustomProteinDataset('/srv01/technion/morant/Storage/train.csv',
                                     '/srv01/technion/morant/Storage/wildtype_structure_prediction_af2.pdb')
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

  self.tokens_mskd_tensor = torch.tensor([np.pad(np.ones(x),(0,max_len-x))


In [8]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.linear = torch.nn.Sequential(torch.nn.Linear(in_features=393216, out_features=20000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=20000, out_features=10000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=10000, out_features=1000),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=1000, out_features=1))
        
    def forward(self, batch, attention_mask):
        result = self.model(batch, token_type_ids=torch.zeros_like(batch), attention_mask=attention_mask)[0]
        res_flat = torch.flatten(result, start_dim=1)
        lin = self.linear(res_flat)

        return lin        

In [6]:
model = Model()
loss = torch.nn.MSELoss()

In [8]:
# Training (when we'll get there)
# Parameters:
lr = 1e-7
max_grad_norm = 0.7
num_total_steps = 1000
num_warmup_steps = 500
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1

### In PyTorch-Transformers, optimizer and schedules are splitted and instantiated like this:
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler
### and used like this:
for i in range(num_total_steps):
    for batch,attention_mask, train_tm in train_dataloader:
#     for batch,attention_mask, train_tm in zip(batched_tok_ten, batched_tok_mskd, batched_train_tm):
        loss_new = loss(model(batch, attention_mask),train_tm.float()[:,None])
        loss_new.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
        optimizer.step()
        scheduler.step()
        print(f"loss_new={loss_new}")


loss_new=326.58544921875
loss_new=349.9559326171875
loss_new=289.32427978515625
loss_new=243.9564666748047
loss_new=499.64599609375
loss_new=340.3167419433594
loss_new=390.4734191894531
loss_new=339.09759521484375
loss_new=309.1086730957031
loss_new=326.2749938964844
loss_new=289.9901123046875
loss_new=272.30377197265625
loss_new=253.0203857421875
loss_new=381.80572509765625
loss_new=397.52154541015625
loss_new=370.40472412109375
loss_new=305.77520751953125
loss_new=191.43597412109375
loss_new=369.45086669921875
loss_new=448.6501159667969
loss_new=419.50274658203125
loss_new=239.0080108642578
loss_new=338.41046142578125
loss_new=457.9380798339844
loss_new=516.4689331054688
loss_new=346.7664794921875
loss_new=309.12725830078125
loss_new=323.1715087890625
loss_new=325.1849670410156
loss_new=253.74864196777344
loss_new=316.572021484375
loss_new=314.2746887207031
loss_new=325.5556945800781
loss_new=428.58306884765625
loss_new=292.0516357421875
loss_new=350.1146545410156
loss_new=280.992004

loss_new=315.3497619628906
loss_new=274.94500732421875
loss_new=278.5273132324219
loss_new=202.3355712890625
loss_new=269.4154052734375
loss_new=191.351318359375
loss_new=208.30316162109375
loss_new=333.724609375
loss_new=279.8355712890625
loss_new=270.4465637207031
loss_new=199.1775360107422
loss_new=192.83511352539062
loss_new=241.1897735595703
loss_new=201.81417846679688
loss_new=141.26068115234375
loss_new=147.46449279785156
loss_new=225.9024658203125
loss_new=120.71314239501953
loss_new=180.33358764648438
loss_new=119.72064971923828
loss_new=227.5009002685547
loss_new=264.8646240234375
loss_new=273.8433837890625
loss_new=222.6355438232422
loss_new=262.9760437011719
loss_new=127.4345703125
loss_new=289.230712890625
loss_new=222.5192413330078
loss_new=103.15982055664062
loss_new=317.38214111328125
loss_new=203.0672149658203
loss_new=238.67633056640625
loss_new=258.5260009765625
loss_new=291.3746643066406
loss_new=243.1175537109375
loss_new=249.37588500976562
loss_new=254.93206787109

KeyboardInterrupt: 

In [11]:
# torch.save(model, '/srv01/technion/morant/Storage/enzyme-stability/model_trained.pt')

In [9]:
model = torch.load('/srv01/technion/morant/Storage/enzyme-stability/model_trained.pt')

In [None]:
# Running test

In [9]:
test

Unnamed: 0,seq_id,protein_sequence,pH,data_source
0,31390,VPVNPEPDATSVENVAEKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
1,31391,VPVNPEPDATSVENVAKKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
2,31392,VPVNPEPDATSVENVAKTGSGDSQSDPIKADLEVKGQSALPFDVDC...,8,Novozymes
3,31393,VPVNPEPDATSVENVALCTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
4,31394,VPVNPEPDATSVENVALFTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
...,...,...,...,...
2408,33798,VPVNPEPDATSVENVILKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
2409,33799,VPVNPEPDATSVENVLLKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
2410,33800,VPVNPEPDATSVENVNLKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
2411,33801,VPVNPEPDATSVENVPLKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes


In [11]:
test_data = CustomProteinDataset('/srv01/technion/morant/Storage/test.csv',
                                     '/srv01/technion/morant/Storage/wildtype_structure_prediction_af2.pdb')
test_dataloader = DataLoader(test_data, batch_size=10)


  self.tokens_mskd_tensor = torch.tensor([np.pad(np.ones(x),(0,max_len-x))


In [13]:
import os

In [14]:
# Training (when we'll get there)
res = []
### and used like this:
for idx, (batch, attention_mask) in enumerate(test_dataloader):
#     for batch,attention_mask, train_tm in zip(batched_tok_ten, batched_tok_mskd, batched_train_tm):
    loss_new = model(batch, attention_mask)
#     res.append(loss_new)
    to_save = os.path.join('/srv01/technion/morant/Storage/enzyme-stability/res_test', f"{idx}.pt")
    torch.save(loss_new, to_save)
    print(f"idx={idx}")
    print(f"loss_new={loss_new}")


idx=0
loss_new=tensor([[48.4121],
        [48.4058],
        [48.4140],
        [48.4403],
        [48.4347],
        [48.4465],
        [48.4240],
        [48.4407],
        [48.4186],
        [48.4337]], grad_fn=<AddmmBackward0>)
idx=1
loss_new=tensor([[48.4199],
        [48.4217],
        [48.4408],
        [48.4219],
        [48.4171],
        [48.4341],
        [48.4148],
        [48.4023],
        [48.4189],
        [48.4279]], grad_fn=<AddmmBackward0>)
idx=2
loss_new=tensor([[48.4142],
        [48.4294],
        [48.4068],
        [48.4053],
        [48.4054],
        [48.3968],
        [48.4061],
        [48.4052],
        [48.4041],
        [48.4231]], grad_fn=<AddmmBackward0>)
idx=3
loss_new=tensor([[48.4020],
        [48.4185],
        [48.3978],
        [48.4168],
        [48.4014],
        [48.4157],
        [48.3995],
        [48.4166],
        [48.3982],
        [48.4156]], grad_fn=<AddmmBackward0>)
idx=4
loss_new=tensor([[48.4073],
        [48.4057],
        [48.4036],


idx=36
loss_new=tensor([[48.4067],
        [48.4339],
        [48.4344],
        [48.4305],
        [48.4289],
        [48.4246],
        [48.4349],
        [48.4086],
        [48.4070],
        [48.4101]], grad_fn=<AddmmBackward0>)
idx=37
loss_new=tensor([[48.4088],
        [48.3957],
        [48.3904],
        [48.3953],
        [48.4153],
        [48.4176],
        [48.4199],
        [48.4163],
        [48.4160],
        [48.4101]], grad_fn=<AddmmBackward0>)
idx=38
loss_new=tensor([[48.4166],
        [48.4041],
        [48.4180],
        [48.4410],
        [48.4207],
        [48.4287],
        [48.4260],
        [48.4372],
        [48.4364],
        [48.4162]], grad_fn=<AddmmBackward0>)
idx=39
loss_new=tensor([[48.4090],
        [48.4118],
        [48.4015],
        [48.4136],
        [48.4275],
        [48.4191],
        [48.4179],
        [48.4246],
        [48.4050],
        [48.4201]], grad_fn=<AddmmBackward0>)
idx=40
loss_new=tensor([[48.4165],
        [48.4164],
        [48.41

idx=72
loss_new=tensor([[48.4176],
        [48.4223],
        [48.4210],
        [48.4299],
        [48.4075],
        [48.4345],
        [48.4146],
        [48.4402],
        [48.4343],
        [48.4486]], grad_fn=<AddmmBackward0>)
idx=73
loss_new=tensor([[48.4163],
        [48.4243],
        [48.4323],
        [48.4116],
        [48.4233],
        [48.4437],
        [48.4344],
        [48.4189],
        [48.4413],
        [48.4229]], grad_fn=<AddmmBackward0>)
idx=74
loss_new=tensor([[48.4286],
        [48.4282],
        [48.4370],
        [48.4240],
        [48.4279],
        [48.4294],
        [48.4167],
        [48.4214],
        [48.4164],
        [48.4200]], grad_fn=<AddmmBackward0>)
idx=75
loss_new=tensor([[48.4346],
        [48.4291],
        [48.4091],
        [48.4308],
        [48.4179],
        [48.4163],
        [48.4147],
        [48.4112],
        [48.4175],
        [48.4251]], grad_fn=<AddmmBackward0>)
idx=76
loss_new=tensor([[48.3948],
        [48.3974],
        [48.41

idx=108
loss_new=tensor([[48.4312],
        [48.4350],
        [48.4132],
        [48.4218],
        [48.4040],
        [48.4063],
        [48.4013],
        [48.4114],
        [48.4208],
        [48.4058]], grad_fn=<AddmmBackward0>)
idx=109
loss_new=tensor([[48.4014],
        [48.4187],
        [48.4159],
        [48.4234],
        [48.4254],
        [48.4268],
        [48.4303],
        [48.4234],
        [48.4371],
        [48.4293]], grad_fn=<AddmmBackward0>)
idx=111
loss_new=tensor([[48.4361],
        [48.4136],
        [48.4219],
        [48.4232],
        [48.4222],
        [48.4348],
        [48.4220],
        [48.4270],
        [48.4257],
        [48.4111]], grad_fn=<AddmmBackward0>)
idx=112
loss_new=tensor([[48.4339],
        [48.4356],
        [48.4341],
        [48.4393],
        [48.4397],
        [48.4252],
        [48.4220],
        [48.4205],
        [48.4178],
        [48.4250]], grad_fn=<AddmmBackward0>)
idx=113
loss_new=tensor([[48.4343],
        [48.4117],
        [

idx=145
loss_new=tensor([[48.4287],
        [48.4296],
        [48.4428],
        [48.4336],
        [48.4119],
        [48.4403],
        [48.4200],
        [48.4205],
        [48.4240],
        [48.4417]], grad_fn=<AddmmBackward0>)
idx=146
loss_new=tensor([[48.4270],
        [48.4242],
        [48.4536],
        [48.4393],
        [48.4416],
        [48.4016],
        [48.3841],
        [48.4000],
        [48.4012],
        [48.4138]], grad_fn=<AddmmBackward0>)
idx=147
loss_new=tensor([[48.4403],
        [48.4394],
        [48.4359],
        [48.4219],
        [48.4160],
        [48.4238],
        [48.4487],
        [48.4344],
        [48.4375],
        [48.4289]], grad_fn=<AddmmBackward0>)
idx=148
loss_new=tensor([[48.4444],
        [48.4417],
        [48.4294],
        [48.4516],
        [48.4183],
        [48.4109],
        [48.3993],
        [48.4144],
        [48.3878],
        [48.4057]], grad_fn=<AddmmBackward0>)
idx=149
loss_new=tensor([[48.4251],
        [48.4376],
        [

idx=181
loss_new=tensor([[48.4355],
        [48.4232],
        [48.4167],
        [48.4138],
        [48.4095],
        [48.4205],
        [48.4242],
        [48.4103],
        [48.4081],
        [48.4215]], grad_fn=<AddmmBackward0>)
idx=182
loss_new=tensor([[48.4239],
        [48.4200],
        [48.4362],
        [48.4394],
        [48.4436],
        [48.4306],
        [48.4443],
        [48.4224],
        [48.4348],
        [48.4375]], grad_fn=<AddmmBackward0>)
idx=183
loss_new=tensor([[48.4502],
        [48.4340],
        [48.4472],
        [48.4499],
        [48.4482],
        [48.4235],
        [48.3942],
        [48.4280],
        [48.4234],
        [48.4035]], grad_fn=<AddmmBackward0>)
idx=184
loss_new=tensor([[48.4229],
        [48.4338],
        [48.4205],
        [48.4327],
        [48.4077],
        [48.4306],
        [48.4206],
        [48.4300],
        [48.4103],
        [48.4064]], grad_fn=<AddmmBackward0>)
idx=185
loss_new=tensor([[48.4260],
        [48.4248],
        [

idx=217
loss_new=tensor([[48.4341],
        [48.4341],
        [48.4398],
        [48.4270],
        [48.4394],
        [48.4114],
        [48.4176],
        [48.4023],
        [48.4183],
        [48.4248]], grad_fn=<AddmmBackward0>)
idx=218
loss_new=tensor([[48.4081],
        [48.4172],
        [48.4213],
        [48.4301],
        [48.4418],
        [48.4330],
        [48.4376],
        [48.4514],
        [48.4375],
        [48.4497]], grad_fn=<AddmmBackward0>)
idx=219
loss_new=tensor([[48.3963],
        [48.4044],
        [48.4147],
        [48.4193],
        [48.4002],
        [48.4088],
        [48.4116],
        [48.4034],
        [48.4279],
        [48.4046]], grad_fn=<AddmmBackward0>)
idx=220
loss_new=tensor([[48.4221],
        [48.4091],
        [48.4190],
        [48.4241],
        [48.4361],
        [48.4283],
        [48.4302],
        [48.4283],
        [48.4402],
        [48.4333]], grad_fn=<AddmmBackward0>)
idx=221
loss_new=tensor([[48.4180],
        [48.4391],
        [

In [None]:
for f in range():
    d=torch.load()

In [16]:
d = [torch.load(f"/srv01/technion/morant/Storage/enzyme-stability/res_test/{i}.pt") for i in range(242)]
res =  torch.cat(d)

In [25]:
res=res.squeeze()

In [26]:
df = pd.DataFrame(data={'seq_id': test['seq_id'], 'tm': res.detach()})
df

Unnamed: 0,seq_id,tm
0,31390,48.412102
1,31391,48.405800
2,31392,48.413971
3,31393,48.440281
4,31394,48.434673
...,...,...
2408,33798,48.432423
2409,33799,48.431812
2410,33800,48.420433
2411,33801,48.431179


In [27]:
df.to_csv("/srv01/technion/morant/Storage/enzyme-stability/submission.csv")

In [19]:
test

Unnamed: 0,seq_id,protein_sequence,pH,data_source
0,31390,VPVNPEPDATSVENVAEKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
1,31391,VPVNPEPDATSVENVAKKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
2,31392,VPVNPEPDATSVENVAKTGSGDSQSDPIKADLEVKGQSALPFDVDC...,8,Novozymes
3,31393,VPVNPEPDATSVENVALCTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
4,31394,VPVNPEPDATSVENVALFTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
...,...,...,...,...
2408,33798,VPVNPEPDATSVENVILKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
2409,33799,VPVNPEPDATSVENVLLKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
2410,33800,VPVNPEPDATSVENVNLKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
2411,33801,VPVNPEPDATSVENVPLKTGSGDSQSDPIKADLEVKGQSALPFDVD...,8,Novozymes
