In [6]:
import os
import numpy as np

import torch
from torch import nn
import wandb

from datasets import load_from_disk
from transformers import BertTokenizerFast, AutoModelForSequenceClassification
from transformers import Trainer, EvalPrediction, default_data_collator, TrainingArguments, EarlyStoppingCallback

from transformers.integrations import WandbCallback

import evaluate 
from evaluate import evaluator

In [5]:
# True -> run a small W&B learning-rate sweep on a data subset; False -> single final training run.
TUNE_HYPERPARAMETERS = False

In [6]:
# W&B identifiers shared by both modes.
WANDB_PROJECT = 'finetune_bert'
WANDB_NOTEBOOK_NAME = 'finetune_bert.ipynb'

# Mode-specific identifiers: a fixed run name for the final model, or a sweep name for LR tuning.
# Only the name relevant to the active mode is defined.
if not TUNE_HYPERPARAMETERS:
    WANDB_NAME = 'final_bert_lr_5e5'
    WANDB_NOTES = 'finetune bert'
else:
    SWEEP_NAME = 'bert_sweep_8'
    WANDB_NOTES = 'inital hp sweep on 10,000 train samples, tune lr, no early stop'

In [7]:
# Subset sizes. The final run uses 600k of the ~840k training rows (the full set would take ~15 h);
# the sweep uses a small subset. Validation during training is only for tracking overfitting —
# a full validation pass can be computed at the end.
train_size = 600000 if not TUNE_HYPERPARAMETERS else 10000
val_size = 120000 if not TUNE_HYPERPARAMETERS else 1000

batch_size = 16  # keep equal to per_device_train_batch_size in set_train_args
# Ceil division: the training dataloader keeps the last partial batch (dataloader_drop_last is not
# set), so a partial batch still counts as one optimizer step.
num_steps_per_epoch = -(-train_size // batch_size)
# NOTE(review): set_train_args trains for num_train_epochs=2; this value only sizes the
# logging/eval intervals below, so logging happens ~200 times over the real run, not ~100.
num_epochs = 1

total_steps = num_steps_per_epoch * num_epochs
# ~100 log points over total_steps; max(1, ...) because TrainingArguments rejects
# logging_steps=0 with a 'steps' logging strategy (possible for very small subsets).
logging_steps = max(1, round(total_steps / 100))
eval_steps = logging_steps * 20

print(f'logging_steps: {logging_steps}')
print(f'eval_steps: {eval_steps}')

logging_steps: 375
eval_steps: 7500


In [10]:
CWD = os.getcwd()
DATASET_DIR = os.path.join(CWD, 'data')

# Drop the raw text (only token ids are kept) and rename the target column to 'labels',
# the name the Trainer / model forward pass expects for the target.
dataset = load_from_disk(DATASET_DIR).remove_columns('text').rename_column('label', 'labels')
dataset.set_format('torch')

# Reproducible subsets: each split is shuffled with the same seed, then truncated.
# The test split is cut to val_size as well.
for _split, _n_rows in (('train', train_size), ('val', val_size), ('test', val_size)):
    dataset[_split] = dataset[_split].shuffle(seed=100).select(range(_n_rows))

print(dataset)

Loading cached shuffled indices for dataset at /home/ubuntu/partisan_bias_detection/data/train/cache-2f99bb8df05f4ee8.arrow
Loading cached shuffled indices for dataset at /home/ubuntu/partisan_bias_detection/data/val/cache-e3d648fff5c3e86c.arrow
Loading cached shuffled indices for dataset at /home/ubuntu/partisan_bias_detection/data/test/cache-2d06e247be8c9023.arrow


DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids'],
        num_rows: 600000
    })
    val: Dataset({
        features: ['labels', 'input_ids'],
        num_rows: 180000
    })
    test: Dataset({
        features: ['labels', 'input_ids'],
        num_rows: 180000
    })
})


In [None]:
# Release cached-but-unused CUDA memory (e.g. left over from an earlier run in this kernel).
# The guard only skips the call on machines without CUDA, where it would be a no-op anyway.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# wandb logging — configured through environment variables read by wandb and the
# transformers W&B integration.

# SECURITY: never hardcode the API key in the notebook. The key previously written here is
# exposed in the notebook (and any VCS history) and should be revoked/regenerated in the W&B
# settings. Supply it via the WANDB_API_KEY environment variable; otherwise wandb.login()
# uses stored credentials (from `wandb login`) or prompts for a key.
if not os.environ.get('WANDB_API_KEY'):
    wandb.login()

# set the wandb project where this run will be logged
os.environ['WANDB_PROJECT'] = WANDB_PROJECT

# upload the trained model checkpoint to wandb at the end of training
os.environ['WANDB_LOG_MODEL'] = 'end'

# turn off gradient/parameter watching to log faster
os.environ['WANDB_WATCH'] = 'false'

os.environ['WANDB_NOTEBOOK_NAME'] = WANDB_NOTEBOOK_NAME
os.environ['WANDB_NOTES'] = WANDB_NOTES

In [None]:
# tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
def model_init():
    """Return a fresh `bert-base-uncased` sequence classifier with a 5-label head.

    The model is placed on the first CUDA device when one is available, otherwise on CPU.
    Takes no arguments so it stays compatible with `Trainer(model_init=...)`.
    """
    target_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    classifier = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=5)
    return classifier.to(target_device)

In [None]:
# metrics
# evaluate.load may re-resolve (and possibly re-download) the metric module on every call, so
# each metric is loaded once and reused instead of reloading four modules at every evaluation.
_METRIC_CACHE = {}


def _get_metric(name):
    """Return the `evaluate` metric called `name`, loading it only on first use."""
    if name not in _METRIC_CACHE:
        _METRIC_CACHE[name] = evaluate.load(name)
    return _METRIC_CACHE[name]


def compute_metrics(eval_pred):
    """Called at the end of each validation pass by the Trainer.

    Args:
        eval_pred: a ``(logits, labels)`` pair (e.g. an ``EvalPrediction``); predictions
            are taken as the argmax over the last axis of ``logits``.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall' and 'f1'; the last three are
        macro-averaged over classes.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    # zero_division=0 reports 0 (instead of warning) for classes that are never predicted/present.
    return {
        'accuracy': _get_metric('accuracy').compute(predictions=preds, references=labels)['accuracy'],
        'precision': _get_metric('precision').compute(
            predictions=preds, references=labels, average='macro', zero_division=0)['precision'],
        'recall': _get_metric('recall').compute(
            predictions=preds, references=labels, average='macro', zero_division=0)['recall'],
        'f1': _get_metric('f1').compute(predictions=preds, references=labels, average='macro')['f1'],
    }

In [None]:
def set_train_args(lr, save_path, run_name, num_train_epochs=2,
                   train_batch_size=16, eval_batch_size=256):
    """Build the TrainingArguments shared by the sweep trials and the final run.

    Args:
        lr: learning rate for the optimizer.
        save_path: relative sub-path used under both 'logs/' and 'models/'.
        run_name: descriptor for the run (used by the W&B integration).
        num_train_epochs: number of training epochs (default 2, unchanged).
        train_batch_size: per-device train batch size (default 16). Keep it equal to the
            notebook's `batch_size`, which is used to size `logging_steps`/`eval_steps`.
        eval_batch_size: per-device eval batch size (default 256).

    Logging, evaluation and checkpoint intervals come from the module-level `logging_steps`
    and `eval_steps`. Checkpoints are saved on the same step schedule as evaluation, which
    `load_best_model_at_end=True` requires. At most 3 checkpoints are kept on disk.
    """
    return TrainingArguments(
        learning_rate=lr,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        num_train_epochs=num_train_epochs,
        logging_strategy='steps',
        logging_steps=logging_steps,
        evaluation_strategy='steps',
        save_strategy='steps',
        eval_steps=eval_steps,
        save_steps=eval_steps,
        load_best_model_at_end=True,
        report_to='wandb',
        logging_dir=os.path.join('logs', save_path),
        output_dir=os.path.join('models', save_path),
        save_total_limit=3,
        run_name=run_name,
    )
    
def train():
    """Run one W&B sweep trial: fine-tune with the sweep's learning rate and save the best model.

    Used as the `function` for `wandb.agent`. Reads `wandb.config.learning_rate` and the
    module-level `dataset`, `WANDB_PROJECT` and `SWEEP_NAME`. The best model is written to
    best_models/<WANDB_PROJECT>/<SWEEP_NAME>/lr_<lr>.
    """
    # Using the run as a context manager guarantees the run is finished even if training raises,
    # so a crashed trial does not leave its run open while the agent starts the next one.
    with wandb.init():
        model = model_init()

        # set hyperparams
        lr = wandb.config.learning_rate
        run_name = f'lr_{lr}'
        save_path = os.path.join(WANDB_PROJECT, SWEEP_NAME, run_name)
        args = set_train_args(lr, save_path, run_name)

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=dataset['train'],
            eval_dataset=dataset['val'],
            compute_metrics=compute_metrics,
        )
        trainer.train()

        # save best model (load_best_model_at_end=True in set_train_args)
        trainer.save_model(os.path.join('best_models', save_path))
        trainer.save_state()

        # evaluate() returns a dict of all eval metrics (loss, accuracy, ...), not just the
        # loss; it is still logged under the 'val_loss' key so existing dashboards keep working.
        val_metrics = trainer.evaluate(dataset['val'])
        wandb.log({'val_loss': val_metrics})

In [None]:
if TUNE_HYPERPARAMETERS:
    # Grid sweep over learning rate; every trial runs train() and is ranked by eval loss.
    sweep_config = {
        'name': SWEEP_NAME,
        'method': 'grid',
        'metric': {
            'goal': 'minimize',
            'name': 'eval/loss'
        },
        'parameters': {
            'learning_rate': {
                'values': [1e-5, 2e-5, 5e-5, 1e-4]
            },
        }
    }

    # project passed explicitly rather than relying only on the WANDB_PROJECT env variable
    sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECT)
    wandb.agent(sweep_id, function=train)

else:
    # Final run. name=WANDB_NAME so the W&B run itself carries the intended name, and the
    # context manager finishes the run even if training raises.
    with wandb.init(name=WANDB_NAME):
        model = model_init()
        save_path = os.path.join(WANDB_PROJECT, WANDB_NAME)
        run_name = WANDB_NAME

        # set hyperparams — keep lr in sync with the 'lr_5e5' in WANDB_NAME.
        # The epoch count is set inside set_train_args (num_train_epochs=2).
        lr = 5e-5
        args = set_train_args(lr, save_path, run_name)

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=dataset['train'],
            eval_dataset=dataset['val'],
            compute_metrics=compute_metrics,
        )

        # Stop after 3 evaluations without improvement in the best-model metric
        # (metric_for_best_model is not set in set_train_args, so transformers' default applies).
        trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.0))
        trainer.train()

        # save best model
        trainer.save_model(os.path.join('best_models', save_path))
        trainer.save_state()