In [1]:
# TODO: replace this sys.path workaround with `pip install stabilizer`
import sys
sys.path.insert(0, '..')

In [2]:
import torch
import logging
import numpy as np
import pandas as pd

from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import matthews_corrcoef

from stabilizer.model import PoolerClassifier
from stabilizer.dataset import TextLabelDataset
from stabilizer.trainer import train_step, evaluate_step

from transformers import get_scheduler, AdamW, AutoModel, AutoTokenizer

In [3]:
# Configure the root logger once for the whole notebook: INFO and above,
# each record prefixed with a timestamp, its level and the emitting logger's name.
# (basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Logger used by the cells below; named after the current module.
logger = logging.getLogger(__name__)

In [4]:
def post_process_targets(targets):
    """Convert a tensor of targets into a flat integer NumPy array.

    Args:
        targets: torch tensor of labels (any shape).

    Returns:
        1-D NumPy array holding the targets cast to ``torch.int``.
    """
    as_int = targets.to(torch.int)
    # Drop autograd tracking and move to host memory before handing over to NumPy.
    return as_int.detach().cpu().numpy().ravel()


def post_process_predictions(predictions):
    """Turn raw logits into flat hard 0/1 predictions.

    Args:
        predictions: torch tensor of logits (any shape).

    Returns:
        1-D integer NumPy array: 1 where ``sigmoid(logit) >= 0.5``, else 0.
    """
    probabilities = predictions.sigmoid()
    hard_labels = probabilities.ge(0.5).to(torch.int)
    # Drop autograd tracking and move to host memory before handing over to NumPy.
    return hard_labels.detach().cpu().numpy().ravel()


def compute_matthews_corrcoef(targets, predictions):
    """Matthews correlation coefficient with an explicit degenerate-case guard.

    Args:
        targets: array-like of ground-truth labels.
        predictions: array-like of predicted labels.

    Returns:
        ``0.0`` when either ``predictions`` or ``targets`` has at most one
        distinct value; otherwise the value returned by ``matthews_corrcoef``.
    """
    # Predictions are checked first, mirroring the short-circuit order of the guard.
    is_degenerate = len(np.unique(predictions)) <= 1 or len(np.unique(targets)) <= 1
    if is_degenerate:
        return 0.0
    return matthews_corrcoef(y_true=targets, y_pred=predictions)

In [5]:
# Every tunable setting of the notebook lives in this single mapping.
config = dict(
    # Data: JSON Lines files for the train / validation splits
    train_data_path='../data/glue/cola/train.jsonl',
    valid_data_path='../data/glue/cola/valid.jsonl',
    batch_size=32,
    # Pretrained tokenizer / encoder checkpoints
    pretrained_tokenizer_name_or_path='../models/bert-base-uncased/',
    pretrained_model_name_or_path='../models/bert-base-uncased/',
    # Model and device
    device_name='cpu',
    dropout_prob=0.1,
    num_classes=1,
    # Optimisation schedule
    lr=2e-5,
    num_epochs=3,
    validate_every_n_iteration=10,
)

In [6]:
# Load the train and validation splits (JSON Lines files), indexing each frame by its `idx` column
train_data, valid_data = (
    pd.read_json(path_or_buf=config[path_key], lines=True).set_index('idx')
    for path_key in ('train_data_path', 'valid_data_path')
)

In [7]:
# TODO: show a small snippet of the data and give a short explanation of it

In [8]:
# Prepare the raw inputs for the datasets: texts as plain Python lists and
# labels as float32 tensors reshaped to a single column, i.e. shape (N, 1)
train_text_excerpts = train_data['text'].tolist()
train_labels = torch.from_numpy(train_data['label'].to_numpy().reshape(-1, 1)).to(torch.float32)

valid_text_excerpts = valid_data['text'].tolist()
valid_labels = torch.from_numpy(valid_data['label'].to_numpy().reshape(-1, 1)).to(torch.float32)

In [9]:
# Wrap each split in a TextLabelDataset and batch it with a DataLoader.
# Only the training batches are shuffled; validation keeps the file order.
train_dataset = TextLabelDataset(text_excerpts=train_text_excerpts, labels=train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)

valid_dataset = TextLabelDataset(text_excerpts=valid_text_excerpts, labels=valid_labels)
valid_dataloader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=False)

In [10]:
# Device the model is placed on (training batches are moved to the same device later)
device = torch.device(config['device_name'])

# Tokenizer and encoder are loaded from the configured checkpoints; the same
# dropout probability is used for the hidden-state and attention dropout.
tokenizer = AutoTokenizer.from_pretrained(config['pretrained_tokenizer_name_or_path'])
transformer = AutoModel.from_pretrained(
    config['pretrained_model_name_or_path'],
    hidden_dropout_prob=config['dropout_prob'],
    attention_probs_dropout_prob=config['dropout_prob'],
)

# Classification head on top of the encoder; its input width is the encoder's hidden size
model = PoolerClassifier(
    transformer=transformer,
    transformer_output_size=transformer.config.hidden_size,
    transformer_output_dropout_prob=config['dropout_prob'],
    num_classes=config['num_classes'],
)

# The result is bound to `_` so the cell does not display whatever `.to()` returns
_ = model.to(device)

In [11]:
# Loss: binary cross-entropy that takes raw logits
loss_fn = nn.BCEWithLogitsLoss()

# Optimizer over all model parameters (the learning-rate scheduler is created separately)
# NOTE(review): `transformers.AdamW` is deprecated upstream in favour of
# `torch.optim.AdamW` — consider migrating; confirm defaults (eps, weight_decay) first.
model_parameters = model.parameters()
optimizer = AdamW(model_parameters, lr=config['lr'])

In [12]:
# Total number of optimizer steps = batches per epoch x epochs;
# the first tenth of them (integer division) is used as warmup.
num_training_steps = len(train_dataloader) * config['num_epochs']
num_warmup_steps = num_training_steps // 10

logger.info(f'Number of training steps: {num_training_steps}')
logger.info(f'Number of warmup steps: {num_warmup_steps}')

# 'linear' schedule from transformers, driven by the step counts computed above
scheduler = get_scheduler(
    name='linear',
    num_training_steps=num_training_steps,
    num_warmup_steps=num_warmup_steps,
    optimizer=optimizer,
)

09/18/2021 13:58:04 - INFO - __main__ - Number of training steps: 804
09/18/2021 13:58:04 - INFO - __main__ - Number of warmup steps: 80


In [13]:
# Fine-tuning loop: one `train_step` per training batch, plus a full pass over the
# validation set every `validate_every_n_iteration` steps (step 0 included).
#
# The validation pass uses its own variable names so it never rebinds the training
# loop's `batch`, `batch_inputs` or `batch_targets`.
iteration_num = 0
for epoch in range(config['num_epochs']):
    for batch in train_dataloader:
        batch_inputs = tokenizer(
            text=batch['text_excerpt'], padding=True, truncation=True, return_tensors='pt'
        ).to(device)
        batch_targets = batch['label'].to(device)
        train_outputs = train_step(
            model=model,
            inputs=batch_inputs,
            targets=batch_targets,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=scheduler,
        )

        if iteration_num % config['validate_every_n_iteration'] == 0:
            target_chunks, prediction_chunks = [], []
            # no_grad ensures the stacked predictions and the aggregate loss never
            # build an autograd graph; it is a no-op if `evaluate_step` already
            # disables gradient tracking (not visible from this notebook).
            with torch.no_grad():
                for valid_batch in valid_dataloader:
                    valid_batch_inputs = tokenizer(
                        text=valid_batch['text_excerpt'], padding=True, truncation=True, return_tensors='pt'
                    ).to(device)
                    valid_batch_targets = valid_batch['label'].to(device)
                    valid_outputs = evaluate_step(
                        model=model,
                        inputs=valid_batch_inputs,
                        targets=valid_batch_targets,
                        loss_fn=loss_fn,
                    )
                    target_chunks.extend(valid_outputs['targets'])
                    prediction_chunks.extend(valid_outputs['predictions'])

                # Stack everything once so loss and score cover the whole validation
                # set rather than an average of per-batch values.
                stacked_targets = torch.vstack(target_chunks)
                stacked_predictions = torch.vstack(prediction_chunks)
                valid_loss = loss_fn(stacked_predictions, stacked_targets)

            # Convert to flat integer arrays for the metric
            valid_targets = post_process_targets(stacked_targets)
            valid_predictions = post_process_predictions(stacked_predictions)
            valid_score = compute_matthews_corrcoef(targets=valid_targets, predictions=valid_predictions)

            logger.info(f"Iteration num: {iteration_num}, Train loss: {train_outputs['loss']}")
            logger.info(f"Iteration num: {iteration_num}, Valid loss: {valid_loss}, Valid score: {valid_score}")

        iteration_num += 1

09/18/2021 13:58:07 - INFO - __main__ - Iteration num: 0, Train loss: 0.8618794083595276
09/18/2021 13:58:07 - INFO - __main__ - Iteration num: 0, Valid loss: 0.8374536633491516, Valid score: 0.0
09/18/2021 13:58:13 - INFO - __main__ - Iteration num: 10, Train loss: 0.8465286493301392
09/18/2021 13:58:13 - INFO - __main__ - Iteration num: 10, Valid loss: 0.812882125377655, Valid score: 0.022538632076706366
09/18/2021 13:58:18 - INFO - __main__ - Iteration num: 20, Train loss: 0.7404846549034119
09/18/2021 13:58:18 - INFO - __main__ - Iteration num: 20, Valid loss: 0.7380532026290894, Valid score: -0.020862608741005987
09/18/2021 13:58:24 - INFO - __main__ - Iteration num: 30, Train loss: 0.648025631904602
09/18/2021 13:58:24 - INFO - __main__ - Iteration num: 30, Valid loss: 0.6546173095703125, Valid score: -0.008820066505818383
09/18/2021 13:58:29 - INFO - __main__ - Iteration num: 40, Train loss: 0.616409182548523
09/18/2021 13:58:29 - INFO - __main__ - Iteration num: 40, Valid loss:

KeyboardInterrupt: 