In [1]:
from lr.models.transformers.util import load_and_cache_examples
from lr.models.transformers.util import train, set_seed
from torch.utils.data import TensorDataset
import logging
import os
import shutil
import torch
import numpy as np
import pandas as pd
from transformers import BertTokenizer
from transformers import BertForSequenceClassification


## Params

In [2]:
# Run configuration handed to the project's data-loading and training helpers
# (load_and_cache_examples / train). The meaning of each key is defined by
# lr.models.transformers.util, not by this notebook.
# NOTE(review): "no_cuda" is True and "device" is "cpu" while "n_gpu" is 1 --
# confirm this combination is what the helpers expect.
hyperparams = {
    "local_rank": -1,
    "max_seq_length": 128,
    "overwrite_cache": False,
    "cached_path": "data/toy/",
    "train_path": "data/toy/train.csv",
    "dev_path": "data/toy/dev.csv",
    "num_train_epochs": 3.0,
    "per_gpu_train_batch_size": 8,
    "per_gpu_eval_batch_size": 8,
    "gradient_accumulation_steps": 1,
    "learning_rate": 5e-5,
    "weight_decay": 0.0,
    "adam_epsilon": 1e-8,
    "max_grad_norm": 1.0,
    "max_steps": 10,
    "warmup_steps": 0,
    "save_steps": 5,
    "no_cuda": True,
    "n_gpu": 1,
    "model_name_or_path": "bert",
    "output_dir": "bert",
    "random_state": 42,
    "fp16": False,
    # Mixed-precision levels are spelled with a capital letter O ("O0".."O3");
    # the previous value "01" (digit zero) is not a valid level.
    # Presumably forwarded to NVIDIA Apex amp.initialize when "fp16" is True --
    # it is inert while "fp16" stays False.
    "fp16_opt_level": "O1",
    "device": "cpu",
    "verbose": False,
    "model_type": "bert",
}

# Seed before the model is built.
# Presumably seeds the Python/NumPy/torch RNGs -- confirm in set_seed.
set_seed(hyperparams["random_state"], hyperparams["n_gpu"])

# Pretrained checkpoint used for both the tokenizer and the 3-label classifier.
pretrained_weights = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
model = BertForSequenceClassification.from_pretrained(pretrained_weights, num_labels=3)


## Creating features

In [3]:
# Build the training dataset from the run configuration and the BERT tokenizer.
# Presumably reads hyperparams["train_path"] and caches features under
# hyperparams["cached_path"] -- confirm in lr.models.transformers.util.
train_dataset = load_and_cache_examples(hyperparams, tokenizer)

In [4]:
# Same helper as for the training set, called with evaluate=True.
# Presumably this switches the input to hyperparams["dev_path"] -- confirm in
# lr.models.transformers.util.
dev_dataset = load_and_cache_examples(hyperparams, tokenizer, evaluate=True)

## Train

In [5]:
# Run the project's training loop on the training set; it returns a pair that
# is unpacked into global_step and tr_loss.
# NOTE(review): the run appears to write bert/log.csv (read by the next cell),
# presumably under hyperparams["output_dir"] -- confirm in train().
global_step, tr_loss = train(train_dataset, model, tokenizer, hyperparams)

In [6]:
# Regression check on the training run: the smoothed loss must go down and the
# recorded values must match the reference numbers from a known-good run.
training_logs = pd.read_csv("bert/log.csv")

# 3-row rolling mean of the logged loss, computed once and sampled at the
# fourth row (position 3) and at the last row.
smoothed_loss = training_logs.loss.rolling(3).mean()
a1 = smoothed_loss.iloc[3]
a2 = smoothed_loss.iloc[-1]

assert a1 > a2, f"smoothed loss did not decrease: {a1} -> {a2}"

# Compare floats with a tight relative tolerance instead of `==`: bit-exact
# equality breaks on harmless last-digit differences between platforms or
# library versions, while a real regression is still caught.
np.testing.assert_allclose(a1, 1.3554697434107463, rtol=1e-6)
np.testing.assert_allclose(a2, 1.131113092104594, rtol=1e-6)
np.testing.assert_allclose(tr_loss, 1.195960673418912, rtol=1e-6)

In [7]:
# Remove artifacts left on disk so the notebook can be re-run from a clean
# state: the "example.log" file and the "bert" directory (which holds log.csv).
# Each path is deleted only if it exists, using the matching removal function.
for leftover, discard in (("example.log", os.remove), ("bert", shutil.rmtree)):
    if os.path.exists(leftover):
        discard(leftover)