In [1]:
PROJECT_NAME = "ProtBert"
LEARNING_RATE = 3e-5
PROT_MAX_LEN = 1024

import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss

import transformers
from transformers import BertTokenizer, AutoModel, BertConfig, BertModel, BertForMaskedLM
from transformers import DataCollatorForLanguageModeling, DataCollatorWithPadding

from sklearn.model_selection import train_test_split

from torchmetrics.functional.classification import accuracy
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(name=f'{PROJECT_NAME}_lr-{LEARNING_RATE}_prot_{PROT_MAX_LEN}',
                           project='DistilledProtBert')

# prot_seq = pd.read_csv("data/mol_trans/protein_sequences.csv")
with open("data/fasta_list.pkl", "rb") as f:
    fasta_list = pickle.load(f)

print(len(fasta_list))
    
train_data, test_data = train_test_split(fasta_list, test_size=0.1, random_state=42, shuffle=True)
train_data, valid_data = train_test_split(train_data, test_size=5000, random_state=42, shuffle=True)

tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
# model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
model = BertForMaskedLM.from_pretrained("yarongef/DistilProtBert")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjonghyunlee1993[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666924808329592, max=1.0)…

568363


In [2]:
class CustomDataset(Dataset):
    def __init__(self, data, tokenizer, max_len=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def encode(self, seq):
        return self.tokenizer(" ".join(seq), max_length=self.max_len, truncation=True)
        
        
    def __len__(self):
        return len(self.data)

    
    def __getitem__(self, idx):
        return self.encode(self.data[idx])
    
    
def collate_batch(batch):
    out = []
    for b in batch:
        out.append(b)
        
    return tokenizer.pad(out, return_tensors="pt")

train_dataset = CustomDataset(train_data, tokenizer, max_len=PROT_MAX_LEN)
data_sampler = RandomSampler(train_data, replacement=True, num_samples=100000)
mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.3)

train_dataloader = DataLoader(train_dataset, batch_size=150, collate_fn=mlm_collator,
                              num_workers=16, pin_memory=True, prefetch_factor=2, 
                              drop_last=True, sampler=data_sampler)

valid_dataset = CustomDataset(valid_data, tokenizer, max_len=PROT_MAX_LEN)
valid_dataloader = DataLoader(valid_dataset, batch_size=150, num_workers=16, 
                              pin_memory=True, prefetch_factor=2, collate_fn=mlm_collator)

test_dataset = CustomDataset(test_data, tokenizer, max_len=PROT_MAX_LEN)
test_dataloader = DataLoader(test_dataset, batch_size=150, num_workers=16, 
                             pin_memory=True, prefetch_factor=2, collate_fn=mlm_collator)

In [3]:
class DistilledBERT(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        for param in self.model.bert.parameters():
            param.requires_grad = False
        
    
    def step(self, batch):
        out = self.model(**batch)
        logits, loss = out['logits'], out['loss']

        pred = torch.argmax(F.softmax(logits, dim=-1), dim=-1)
        label = batch['labels']
        masked_index = label.gt(0)
        
        acc = accuracy(torch.masked_select(pred, masked_index), torch.masked_select(label, masked_index))
        
        return loss, acc
    
    
    def training_step(self, batch, batch_idx):
        loss, acc = self.step(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        loss, acc = self.step(batch)
        self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

    
    def test_step(self, batch, batch_idx):
        loss, acc = self.step(batch)
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_acc', acc, on_step=True, on_epoch=True, prog_bar=True)

    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=LEARNING_RATE)
        
        return optimizer
    
    
callbacks = [
    ModelCheckpoint(monitor='valid_loss', save_top_k=3, dirpath=f'weights/{PROJECT_NAME}', filename='DTI-{epoch:03d}-{valid_loss:.4f}-{valid_acc:.4f}'),
]

predictor = DistilledBERT(model)

trainer = pl.Trainer(max_epochs=3, gpus=[1], enable_progress_bar=True, 
                     callbacks=callbacks, logger=wandb_logger, precision=16)

  rank_zero_deprecation(
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [4]:
trainer.test(predictor, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.34653547406196594
        test_loss            2.154139995574951
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 2.154139995574951, 'test_acc_epoch': 0.34653547406196594}]