In [1]:
# Change into the parent directory so the relative paths used below (e.g. "cache/")
# resolve from there. NOTE: not idempotent - re-running this cell moves up another level.
%cd ..

/mnt/SSD_Data/active_projects/transformer_to_lstm


In [2]:
import dataclasses
from pathlib import Path

import nlp
import torch
# import joblib
import numpy as np
from transformers import BertTokenizerFast
from transformers import BertForSequenceClassification
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import train_test_split
from tqdm.autonotebook import tqdm

from pytorch_helper_bot.bot import batch_to_device

# apex is an optional dependency: record whether it could be imported so that
# the mixed-precision setup can be skipped when it is not installed.
try:
    from apex import amp
    APEX_AVAILABLE = True
except ModuleNotFoundError:
    APEX_AVAILABLE = False

In [3]:
# Directory for cached artifacts, relative to the current working directory.
CACHE_DIR = Path("cache/")
# exist_ok=True keeps this cell re-runnable; missing parent directories are not created.
CACHE_DIR.mkdir(exist_ok=True)

In [4]:
dataset = nlp.load_dataset('glue', "sst2")

In [5]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [6]:
# Tokenize a batch of SST-2 sentences
def convert_to_features(example_batch):
    """Encode the ``sentence`` entry of a batch with the module-level tokenizer.

    Args:
        example_batch: Mapping with a ``"sentence"`` entry holding the raw
            texts of the batch.

    Returns:
        Whatever ``tokenizer.batch_encode_plus`` returns for those sentences
        when called with ``pad_to_max_length=True`` and ``max_length=64``.
    """
    # Single sentences are encoded here, not pairs of inputs.
    sentences = example_batch['sentence']
    return tokenizer.batch_encode_plus(
        sentences,
        pad_to_max_length=True,
        max_length=64,
    )

In [7]:
# Tokenize both splits, then have them return torch tensors for the listed columns.
columns = ['input_ids', 'token_type_ids', 'attention_mask', "label"]
for subset in ("train", "validation"):
    tokenized = dataset[subset].map(convert_to_features, batched=True)
    tokenized.set_format(type='torch', columns=columns)
    dataset[subset] = tokenized

In [8]:
# Load the sequence classifier stored under CACHE_DIR and move it to the GPU.
checkpoint_dir = CACHE_DIR / "sst2_bert_uncased"
model = BertForSequenceClassification.from_pretrained(str(checkpoint_dir)).cuda()

In [9]:
# Wrap the model with apex mixed precision (opt level O1) only when apex was importable.
if APEX_AVAILABLE:
    model = amp.initialize(model, opt_level="O1")

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [10]:
class SST2Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of parallel, index-aligned columns.

    ``entries_dict`` must provide the keys ``"input_ids"``,
    ``"attention_mask"``, ``"token_type_ids"`` and ``"label"``. Every value
    must support indexing, and ``"label"`` must also support ``len()``.
    """

    # Order in which the columns are emitted by ``__getitem__``.
    _ITEM_KEYS = ("input_ids", "attention_mask", "token_type_ids", "label")

    def __init__(self, entries_dict):
        """Keep a reference to ``entries_dict``; nothing is copied."""
        super().__init__()
        self.entries_dict = entries_dict

    def __len__(self):
        """Return the number of examples, taken from the ``"label"`` column."""
        labels = self.entries_dict["label"]
        return len(labels)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, token_type_ids, label)`` at ``idx``."""
        columns = self.entries_dict
        return tuple(columns[key][idx] for key in self._ITEM_KEYS)

In [11]:
# Split the row indices of the "validation" subset into two halves (fixed seed):
# one half is used as the validation set, the other as the test set.
validation_size = len(dataset["validation"])
valid_idx, test_idx = train_test_split(
    list(range(validation_size)),
    test_size=0.5,
    random_state=42,
)

In [12]:
# Columns copied into every split dict, in this order.
split_fields = ("input_ids", "attention_mask", "token_type_ids", "label")

# The training subset is used as a whole.
train_dict = {field: dataset['train'][field] for field in split_fields}
# The "validation" subset is divided by row index into a validation and a test part.
valid_dict = {field: dataset['validation'][field][valid_idx] for field in split_fields}
test_dict = {field: dataset['validation'][field][test_idx] for field in split_fields}

In [13]:
# One DataLoader per split. Nothing is shuffled or dropped, so batches come out
# in dataset order and cover every example.
train_loader, valid_loader, test_loader = (
    torch.utils.data.DataLoader(
        SST2Dataset(entries),
        batch_size=32,
        shuffle=False,
        drop_last=False,
    )
    for entries in (train_dict, valid_dict, test_dict)
)

In [14]:
# Run the loaded classifier over every split and collect the first element of
# its output (stored below as the logits) on the CPU.
model.eval()  # make evaluation mode explicit, independent of earlier kernel state
logits = {}
with torch.no_grad():  # inference only: do not build an autograd graph per batch
    for subset, dataloader in (("train", train_loader), ("valid", valid_loader), ("test", test_loader)):
        results = []
        # The last item of every batch is the label, which the forward pass does not need.
        for *batch, _labels in tqdm(dataloader, desc=subset):
            results.append(model(*batch_to_device(batch, "cuda"))[0].cpu())
        logits[subset] = torch.cat(results, dim=0)
        # Every example must have exactly one row of logits.
        assert logits[subset].shape[0] == len(dataloader.dataset), subset

100%|██████████| 2105/2105 [01:01<00:00, 34.14it/s]
100%|██████████| 14/14 [00:00<00:00, 36.15it/s]
100%|██████████| 14/14 [00:00<00:00, 35.80it/s]


In [15]:
# Sanity check: the first dimension should equal the number of training examples.
logits["train"].shape

torch.Size([67349, 2])

In [16]:
# Store each split's logits in its entries dict under the "logits" key.
for entries, split_name in ((train_dict, "train"), (valid_dict, "valid"), (test_dict, "test")):
    entries["logits"] = logits[split_name]

In [17]:
# Persist the [train, valid, test] dicts in a single file. It is written with
# torch.save despite the ".jbl" suffix, so it should be read back with torch.load.
distill_path = CACHE_DIR / "distill-dicts.jbl"
torch.save([train_dict, valid_dict, test_dict], str(distill_path))