In [1]:
PROJECT_NAME = "XProtBert"
LEARNING_RATE = 3e-5
PROT_MAX_LEN = 1024

import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss

import transformers
from transformers import BertTokenizer, AutoModel, BertConfig, BertModel, BertForMaskedLM
from transformers import DataCollatorForLanguageModeling, DataCollatorWithPadding

from sklearn.model_selection import train_test_split

from torchmetrics.functional.classification import accuracy
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(name=f'{PROJECT_NAME}_lr-{LEARNING_RATE}_prot_{PROT_MAX_LEN}',
                           project='DistilledProtBert')

# prot_seq = pd.read_csv("data/mol_trans/protein_sequences.csv")
with open("data/fasta_list.pkl", "rb") as f:
    fasta_list = pickle.load(f)

print(len(fasta_list))
    
train_data, test_data = train_test_split(fasta_list, test_size=0.1, random_state=42, shuffle=True)
train_data, valid_data = train_test_split(train_data, test_size=5000, random_state=42, shuffle=True)

tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
teacher_model = BertModel.from_pretrained("Rostlab/prot_bert")

config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=128,
    num_hidden_layers=8,
    num_attention_heads=8,
    intermediate_size=512,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=PROT_MAX_LEN + 2,
    type_vocab_size=1,
    pad_token_id=0,
    position_embedding_type="absolute"
)

class DistilledProtBert(nn.Module):
    def __init__(self, model, hidden_dim, target_dim, vocab_size):
        super().__init__()
        self.model = model
        self.vocab_size = vocab_size
        self.proj = nn.Linear(hidden_dim, target_dim, bias=False)
        
        
    def forward(self, **batch):
        x = self.model.base_model(input_ids=batch["input_ids"], 
                                  token_type_ids=batch['token_type_ids'], 
                                  attention_mask=batch['attention_mask'])
        x = x['last_hidden_state']
        logits = self.proj(x)

        mlm_out = self.model.cls(x)
        pred = torch.argmax(F.softmax(mlm_out, dim=-1), dim=-1)
        label = batch['labels']
        masked_index = label.gt(0)
        
        mlm_loss = F.cross_entropy(mlm_out.reshape(-1, self.vocab_size), label.reshape(-1))
        acc = accuracy(torch.masked_select(pred, masked_index), torch.masked_select(label, masked_index))
        
        return logits, mlm_loss, acc

student_base = BertForMaskedLM(config)
student_model = DistilledProtBert(student_base, hidden_dim=128, target_dim=1024, vocab_size=tokenizer.vocab_size)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjonghyunlee1993[0m. Use [1m`wandb login --relogin`[0m to force relogin


568363


Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
student_model.load_state_dict(torch.load("weights/DistilledProtBert/test.pt"))

<All keys matched successfully>

In [None]:
student_model.

In [5]:
class CustomDataset(Dataset):
    def __init__(self, data, tokenizer, max_len=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def encode(self, seq):
        return self.tokenizer(" ".join(seq), max_length=self.max_len, truncation=True)
        
        
    def __len__(self):
        return len(self.data)

    
    def __getitem__(self, idx):
        return self.encode(self.data[idx])
    
    
def collate_batch(batch):
    out = []
    for b in batch:
        out.append(b)
        
    return tokenizer.pad(out, return_tensors="pt")

train_dataset = CustomDataset(train_data, tokenizer, max_len=PROT_MAX_LEN)
data_sampler = RandomSampler(train_data, replacement=True, num_samples=100000)
mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.3)

train_dataloader = DataLoader(train_dataset, batch_size=128, collate_fn=mlm_collator,
                              num_workers=16, pin_memory=True, prefetch_factor=2, 
                              drop_last=True, sampler=data_sampler)


valid_dataset = CustomDataset(valid_data, tokenizer, max_len=PROT_MAX_LEN)
valid_dataloader = DataLoader(valid_dataset, batch_size=128, num_workers=16, 
                              pin_memory=True, prefetch_factor=2, collate_fn=mlm_collator)

test_dataset = CustomDataset(test_data, tokenizer, max_len=PROT_MAX_LEN)
test_dataloader = DataLoader(test_dataset, batch_size=128, num_workers=16, 
                             pin_memory=True, prefetch_factor=2, collate_fn=mlm_collator)

In [40]:
student_model.eval()

In [None]:
sample_input = tokenizer(" ".join(train_data[0]), return_tensors="pt")
student_model(**sample_input)

In [24]:
mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.3)
test_dataset = CustomDataset(test_data, tokenizer, max_len=PROT_MAX_LEN)
test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=16, 
                             pin_memory=False, prefetch_factor=2, collate_fn=mlm_collator)

In [30]:
for batch in test_dataloader:
    x = model.bert(input_ids=batch["input_ids"], 
                      token_type_ids=batch['token_type_ids'], 
                      attention_mask=batch['attention_mask'])
    out = model.cls(x['last_hidden_state'])
    print(out)
    break

tensor([[[-17.7812,  -7.3981,  -7.3770,  ...,  -6.8732,  -6.9766,  -7.1213],
         [ -9.2509,  -3.1022,  -3.3210,  ...,  -2.8256,  -2.9352,  -3.1158],
         [-12.7206,  -5.3841,  -5.3178,  ...,  -4.8874,  -4.9733,  -5.2469],
         ...,
         [-15.9112,  -6.2360,  -5.7352,  ...,  -5.7123,  -5.9387,  -5.8965],
         [-15.8572,  -6.2203,  -5.7444,  ...,  -5.6790,  -5.9199,  -5.8770],
         [-17.6861,  -7.3433,  -7.3187,  ...,  -6.7905,  -6.8923,  -7.0513]]],
       grad_fn=<ViewBackward0>)


In [37]:
logits = model(**batch)['logits']
pred = torch.argmax(F.softmax(logits, dim=-1), dim=-1)

label = batch['labels']
valid_index = label.gt(0)

from sklearn.metrics import accuracy_score
accuracy_score(torch.masked_select(label, valid_index), torch.masked_select(pred, valid_index))

0.14

In [38]:
student_model(**batch)

(tensor([[[-3.0133e-04, -4.2184e-04,  2.0127e-04,  ..., -1.8263e-04,
            5.2682e-04, -2.7157e-04],
          [-1.5056e-04, -8.5584e-04,  4.0688e-04,  ..., -2.0433e-04,
            2.3460e-04,  1.3156e-04],
          [-3.1432e-04, -5.3086e-04,  2.5640e-04,  ..., -2.3224e-04,
            3.3249e-04, -9.1479e-05],
          ...,
          [-4.3508e-05, -5.2341e-04,  2.4741e-04,  ..., -1.7559e-04,
            3.7043e-04, -1.3200e-04],
          [-3.6844e-05, -5.4157e-04,  2.5454e-04,  ..., -1.7415e-04,
            3.5932e-04, -1.1321e-04],
          [-2.7331e-04, -4.4654e-04,  1.6115e-04,  ..., -1.6344e-04,
            5.3620e-04, -2.0818e-04]]], grad_fn=<UnsafeViewBackward0>),
 tensor(2.9803, grad_fn=<NllLossBackward0>),
 tensor(0.1400))