In [1]:
# from Bio import SeqIO

# fasta_file = SeqIO.parse(open("data/uniprot_sprot.fasta"), 'fasta')
# fasta_list = []

# for fasta in fasta_file:
#     fasta_list.append(str(fasta.seq))
    
# import pickle 

# with open("data/fasta_list.pkl", "wb") as f:
#     pickle.dump(fasta_list, f)

In [2]:
PROJECT_NAME = "XProtBert"
LEARNING_RATE = 1e-4
PROT_MAX_LEN = 1024

import pickle
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss

import transformers
from transformers import BertTokenizer, AutoModel, BertConfig, BertModel, BertForMaskedLM
from transformers import DataCollatorForLanguageModeling, DataCollatorWithPadding

from sklearn.model_selection import train_test_split

from torchmetrics.functional.classification import accuracy
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(name=f'{PROJECT_NAME}_lr-{LEARNING_RATE}_prot_{PROT_MAX_LEN}',
                           project='DistilledProtBert')

# prot_seq = pd.read_csv("data/mol_trans/protein_sequences.csv")
with open("data/fasta_list.pkl", "rb") as f:
    fasta_list = pickle.load(f)

print(len(fasta_list))
    
train_data, test_data = train_test_split(fasta_list, test_size=0.1, random_state=42, shuffle=True)
train_data, valid_data = train_test_split(train_data, test_size=5000, random_state=42, shuffle=True)

tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
teacher_model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")

config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,
    num_hidden_layers=12,
    num_attention_heads=8,
    intermediate_size=1024,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=PROT_MAX_LEN + 2,
    type_vocab_size=1,
    pad_token_id=0,
    position_embedding_type="absolute"
)

student_model = BertForMaskedLM(config)

  rank_zero_warn(


568363


Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
class CustomDataset(Dataset):
    def __init__(self, data, tokenizer, max_len=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def encode(self, seq):
        return self.tokenizer(" ".join(seq), max_length=self.max_len, truncation=True)
        
        
    def __len__(self):
        return len(self.data)

    
    def __getitem__(self, idx):
        return self.encode(self.data[idx])
    
    
def collate_batch(batch):
    out = []
    for b in batch:
        out.append(b)
        
    return tokenizer.pad(out, return_tensors="pt")

train_dataset = CustomDataset(train_data, tokenizer, max_len=PROT_MAX_LEN)
data_sampler = RandomSampler(train_data, replacement=True, num_samples=100000)
mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.3)

train_dataloader = DataLoader(train_dataset, batch_size=32, collate_fn=mlm_collator,
                              num_workers=16, pin_memory=True, prefetch_factor=5, 
                              drop_last=True, sampler=data_sampler)

valid_dataset = CustomDataset(valid_data, tokenizer, max_len=PROT_MAX_LEN)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, num_workers=16, 
                              pin_memory=True, prefetch_factor=5, collate_fn=mlm_collator)

test_dataset = CustomDataset(test_data, tokenizer, max_len=PROT_MAX_LEN)
test_dataloader = DataLoader(test_dataset, batch_size=32, num_workers=16, 
                             pin_memory=True, prefetch_factor=5, collate_fn=mlm_collator)

In [4]:
class DistilledBERT(pl.LightningModule):
    def __init__(self, teacher_model, student_model):
        super().__init__()
        self.teacher_model = teacher_model
        for param in self.teacher_model.parameters():
            param.requires_grad = False
        self.teacher_model.eval()
        
        self.student_model = student_model
        self.proj = nn.Linear(512, 1024, bias=False)
        
    
    def step(self, batch):
        teacher_hidden = teacher_model.bert(input_ids=batch['input_ids'], 
                   attention_mask=batch['attention_mask'],
                   token_type_ids=batch['token_type_ids'])['last_hidden_state']
        
        student_hidden = student_model.bert(input_ids=batch['input_ids'], 
                   attention_mask=batch['attention_mask'],
                   token_type_ids=batch['token_type_ids'])['last_hidden_state']
        student_hidden_proj = self.proj(student_hidden)
        
        student_logits = student_model.cls(student_hidden)
        student_loss = F.cross_entropy(student_logits.reshape(-1, 30), batch['labels'].reshape(-1))
        
        logit_loss = F.mse_loss(student_hidden_proj, teacher_hidden.detach())
        
        total_loss = (student_loss + logit_loss)/2
        
        pred = torch.argmax(F.softmax(student_logits, dim=-1), dim=-1)
        acc = accuracy(torch.masked_select(pred, batch['labels'].gt(0)), torch.masked_select(batch['labels'], batch['labels'].gt(0)))
        
        return total_loss, acc
    
    
    def training_step(self, batch, batch_idx):
        loss, acc = self.step(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        student_out = student_model(**batch)
        student_loss, student_logits = student_out['loss'], student_out['logits']
        
        pred = torch.argmax(F.softmax(student_out['logits'], dim=-1), dim=-1)
        label = batch['labels']
        acc = accuracy(torch.masked_select(pred, label.gt(0)), torch.masked_select(label, label.gt(0)))
        
        self.log('valid_loss', student_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

    
    def test_step(self, batch, batch_idx):
        student_out = student_model(**batch)
        student_loss, student_logits = student_out['loss'], student_out['logits']
        
        pred = torch.argmax(F.softmax(student_out['logits'], dim=-1), dim=-1)
        label = batch['labels']
        acc = accuracy(torch.masked_select(pred, label.gt(0)), torch.masked_select(label, label.gt(0)))
        
        self.log('test_loss', student_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
    
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.student_model.parameters(), lr=LEARNING_RATE)
        
        return optimizer
    
    
callbacks = [
    ModelCheckpoint(monitor='valid_loss', save_top_k=3, dirpath=f'weights/{PROJECT_NAME}', filename='DTI-{epoch:03d}-{valid_loss:.4f}-{valid_acc:.4f}'),
]

predictor = DistilledBERT(teacher_model, student_model)

trainer = pl.Trainer(max_epochs=10, gpus=[0], enable_progress_bar=True, 
                     callbacks=callbacks, logger=wandb_logger, precision=16)

  rank_zero_deprecation(
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(predictor, train_dataloader, valid_dataloader)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type            | Params
--------------------------------------------------
0 | teacher_model | BertForMaskedLM | 419 M 
1 | student_model | BertForMaskedLM | 26.0 M
2 | proj          | Linear          | 524 K 
--------------------------------------------------
26.6 M    Trainable params
419 M     Non-trainable params
446 M     Total params
892.994   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

In [5]:
# predictor.load_from_checkpoint("weights/DistilledProtBert/DTI-epoch=002-valid_loss=0.2429-valid_acc=0.9101.ckpt",
#     teacher_model=teacher_model, student_model=student_model)

In [5]:
trainer.test(predictor, train_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.31578874588012695
        test_loss           1.6354262828826904
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.6354262828826904, 'test_acc': 0.31578874588012695}]

In [15]:
torch.save(predictor.student_model.state_dict(), "weights/XProtBert/XProtBERT_150K_update.pt")

In [18]:
predictor.student_model.model.save_pretrained("weights/XProtBert/XProtBERT_mlm.pt")