# Load Tokenizer

In [1]:
from transformers import AutoTokenizer

# Load the previously trained tokenizer from its local path.
fast_tokenizer = AutoTokenizer.from_pretrained("data/target/tokenizer_model")

# Size of the token -> id mapping returned by get_vocab().
vocab_size = len(fast_tokenizer.get_vocab())

# Short summary so the cell output documents what was loaded.
print(f"load tokenizer\nvocab size: {vocab_size}\nspecial tokens: {fast_tokenizer.all_special_tokens}")

load tokenizer
vocab size: 10261
special tokens: ['<s>', '</s>', '<unk>', '<pad>', '<mask>']


In [2]:
import os
import pickle

# Reuse the cached train/valid/test split when it exists; otherwise build it
# from the raw pickle and cache it for the next run.
# NOTE: pickle executes arbitrary code on load -- only use with trusted files.
if os.path.exists("data/target/X.pkl"):
    with open("data/target/X.pkl", "rb") as f:
        X_train, X_valid, X_test = pickle.load(f)
else:
    from sklearn.model_selection import train_test_split

    # The raw file holds a dict; only its values are used as samples.
    with open("data/target/protein_sub_1.pickle", "rb") as f:
        data = list(pickle.load(f).values())

    print(f"load dataset ... # of data: {len(data)}")

    # Hold out 1% for test, then 1% of the remainder for validation.
    # The fixed random_state keeps both splits reproducible.
    X_train, X_test = train_test_split(data, test_size=0.01, random_state=42, shuffle=True)
    X_train, X_valid = train_test_split(X_train, test_size=0.01, random_state=42, shuffle=True)

    with open("data/target/X.pkl", "wb") as f:
        pickle.dump([X_train, X_valid, X_test], f)

print(f"load dataset\nX_train: {len(X_train)}\nX_valid: {len(X_valid)}\nX_test: {len(X_test)}")

load dataset
X_train: 49005000
X_valid: 495000
X_test: 500000


In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers import DataCollatorForLanguageModeling

class MaskedLMDataset(Dataset):
    """Map-style dataset that tokenizes raw items lazily, on access.

    Args:
        data: Indexable, sized collection; each item is handed to
            ``tokenizer.encode``.
        tokenizer: Object exposing ``encode(item, max_length=..., truncation=...)``.
        max_length: Truncation limit forwarded to the tokenizer.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def encode(self, data):
        """Return the tokenizer's ids for one item, truncated to ``max_length``."""
        token_ids = self.tokenizer.encode(
            data,
            truncation=True,
            max_length=self.max_length,
        )
        return token_ids

    def __len__(self):
        """Number of items in the underlying collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize item ``idx`` and return its ids as a ``torch.long`` tensor."""
        sample = self.data[idx]
        return torch.tensor(self.encode(sample), dtype=torch.long)
    
    
# Masked-LM collator (mlm=True) that selects tokens with probability 0.2.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=fast_tokenizer,
    mlm=True,
    mlm_probability=0.2,
)

# DataLoader settings shared by the train / valid / test loaders.
loader_kwargs = dict(
    batch_size=32,
    collate_fn=data_collator,
    num_workers=16,
    pin_memory=True,
    prefetch_factor=10,
)

# Training: every pass draws 16000 items with replacement; the incomplete
# final batch (if any) is dropped.
train_dataset = MaskedLMDataset(X_train, fast_tokenizer)
train_sampler = RandomSampler(X_train, replacement=True, num_samples=16000)
train_loader = DataLoader(train_dataset, sampler=train_sampler, drop_last=True, **loader_kwargs)

# Validation: every pass draws 1600 items with replacement, so the evaluated
# subset differs from one pass to the next.
valid_dataset = MaskedLMDataset(X_valid, fast_tokenizer)
valid_sampler = RandomSampler(X_valid, replacement=True, num_samples=1600)
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, **loader_kwargs)

# Test: no sampler is given, so the whole split is iterated.
test_dataset = MaskedLMDataset(X_test, fast_tokenizer)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
import torchmetrics
import pytorch_lightning as pl
from transformers import LongformerConfig, LongformerForMaskedLM
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Architecture hyper-parameters for the Longformer masked-LM.
#
# The special-token ids are taken from the tokenizer rather than left at the
# LongformerConfig defaults: the tokenizer was trained separately, and the
# model uses config.pad_token_id as its padding index, so it must agree with
# the id the collator actually pads with.
config = LongformerConfig(
    vocab_size=len(fast_tokenizer.vocab),
    hidden_size=768,
    intermediate_size=3072,
    max_position_embeddings=1024 + 2,
    num_hidden_layers=12,
    num_attention_heads=12,
    pad_token_id=fast_tokenizer.pad_token_id,
    bos_token_id=fast_tokenizer.bos_token_id,
    eos_token_id=fast_tokenizer.eos_token_id,
)


class Bert(pl.LightningModule):
    """LightningModule wrapping ``LongformerForMaskedLM`` for masked-LM training.

    Args:
        config: Configuration passed to ``LongformerForMaskedLM``. It is also
            recorded through ``save_hyperparameters`` so it is stored with
            checkpoints.
    """

    # Label value marking positions without a masked-LM target. This is the
    # ignore index used by the Hugging Face masked-LM collator and loss.
    IGNORE_INDEX = -100

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.model = LongformerForMaskedLM(config)

        # One metric object per stage so their running states never mix.
        self.train_accuracy = torchmetrics.Accuracy()
        self.valid_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()

    def forward(self, input_ids, labels, attention_mask=None):
        """Run the masked-LM model and return its output (loss and logits).

        Args:
            input_ids: Batch of token ids.
            labels: Masked-LM targets; ``IGNORE_INDEX`` marks unscored positions.
            attention_mask: Optional mask (1 = attend, 0 = ignore). When
                omitted it is derived from the padding id so that padded
                positions are not attended to.
        """
        if attention_mask is None:
            # Assumes config.pad_token_id is the id the collator pads with;
            # it is the same id the model itself treats as padding.
            pad_token_id = self.model.config.pad_token_id
            if pad_token_id is not None:
                attention_mask = (input_ids != pad_token_id).long()

        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def _shared_step(self, batch, stage, loss_on_step):
        """Forward one batch, log ``{stage}_loss`` / ``{stage}_accuracy``, return the loss."""
        input_ids = batch['input_ids']
        labels = batch['labels']

        output = self(input_ids, labels, attention_mask=batch.get('attention_mask'))
        loss = output.loss
        preds = output.logits.argmax(dim=-1)

        # Score only positions that carry a target. Comparing against the
        # ignore index (rather than `labels > 0`) keeps token id 0 scorable.
        scored = labels != self.IGNORE_INDEX

        # Log a detached tensor instead of float(loss) to avoid forcing a
        # host/device synchronisation on every step.
        self.log(f'{stage}_loss', loss.detach(), on_step=loss_on_step, on_epoch=True, prog_bar=True)

        # A batch without any masked position has nothing to score.
        if scored.any():
            accuracy = getattr(self, f'{stage}_accuracy')
            self.log(f'{stage}_accuracy', accuracy(preds[scored], labels[scored]), on_step=False, on_epoch=True, prog_bar=True, logger=True)

        # NOTE(review): releasing the CUDA cache every step is expensive; kept
        # to preserve the original memory behaviour -- confirm it is needed.
        torch.cuda.empty_cache()

        return loss

    def training_step(self, batch, batch_idx):
        """Training step: logs per-step and per-epoch loss, returns the loss."""
        return self._shared_step(batch, 'train', loss_on_step=True)

    def validation_step(self, batch, batch_idx):
        """Validation step: logs epoch-level ``valid_loss`` / ``valid_accuracy``."""
        self._shared_step(batch, 'valid', loss_on_step=False)

    def test_step(self, batch, batch_idx):
        """Test step: logs epoch-level ``test_loss`` / ``test_accuracy``."""
        self._shared_step(batch, 'test', loss_on_step=False)

    def configure_optimizers(self):
        """AdamW (lr=1e-4) with a cosine-annealing warm-restart schedule (T_0=10)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    
    
model = Bert(config)

# Checkpointing is driven by the logged `valid_loss`; up to ten checkpoints
# are kept (save_top_k=10) under weights/protein_bert.
checkpoint_callback = ModelCheckpoint(
    monitor='valid_loss',
    save_top_k=10,
    dirpath='weights/protein_bert',
    filename='protein_bert-{epoch:02d}-{valid_loss:.4f}',
)
callbacks = [
    checkpoint_callback,
    # Early stopping is currently disabled:
    # EarlyStopping('valid_loss', patience=20)
]

# Single-GPU run with 16-bit precision for at most 1000 epochs.
trainer = pl.Trainer(
    max_epochs=1000,
    gpus=1,
    enable_progress_bar=True,
    callbacks=callbacks,
    precision=16,
)
trainer.fit(model, train_loader, valid_loader)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                  | Params
---------------------------------------------------------
0 | model          | LongformerForMaskedLM | 115 M 
1 | train_accuracy | Accuracy              | 0     
2 | valid_accuracy | Accuracy              | 0     
3 | test_accuracy  | Accuracy              | 0     
---------------------------------------------------------
115 M     Trainable params
0         Non-trainable params
115 M     Total params
231.179   Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

In [7]:
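# Evaluation on the held-out test split (disabled; uncomment to run).
# `load_from_checkpoint` is a classmethod, so it is called on the class itself.
# model = Bert.load_from_checkpoint("weights/protein_bert/protein_bert-epoch=896-valid_loss=0.0995.ckpt")
# trainer.test(model, test_loader)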