## Fine-tuning BERT for classification

This notebook fine-tunes a BERT model for sentence-pair classification (predicting the `is_duplicate` label). Expected results after fine-tuning are:
* 0.90–0.95 accuracy
* 0.90–0.95 F1

Before fine-tuning, both metrics are around 0.6.

In [5]:
import os
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import f1_score, accuracy_score
from tensorboardX import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from transformers.transformers import AdamW, WarmupLinearSchedule
from transformers.transformers import BertTokenizer, BertModel, BertConfig, BertForSequenceClassification

from utils import cache_ds, evaluate, process_test

### Training

In [None]:
# TensorBoard writer used by the training loop to log eval metrics, learning rate and loss.
# No logdir is given, so tensorboardX picks its default log directory.
writer = SummaryWriter()

In [None]:
# Load the raw CSVs. Training rows with any missing value are dropped up front; the test
# set is kept whole so every test_id ends up in the submission (rows with NaN are given
# a default prediction later).
train = pd.read_csv('./data/train.csv', index_col='id').dropna(axis=0)
test = pd.read_csv('data/test.csv', index_col='test_id')

In [None]:
# Pretrained checkpoint name: cased BERT-base that has already been fine-tuned on MRPC
# (a sentence-pair paraphrase task); used for both the tokenizer and the model.
model_weights = 'bert-base-cased-finetuned-mrpc'

In [None]:
# The checkpoint is cased, so lowercasing is disabled to match its vocabulary.
tokenizer = BertTokenizer.from_pretrained(model_weights, do_lower_case=False)

In [None]:
# Sequence-classification BERT initialised from the pretrained checkpoint and moved to
# the GPU. output_hidden_states=True makes the forward pass also return hidden states;
# the training loop only uses the first output (the loss).
model = BertForSequenceClassification.from_pretrained(
    model_weights,
    output_hidden_states=True,
).cuda()
model.eval()

In [71]:
# Encode the training frame once and cache the result on disk, then load it.
# cache_ds presumably tokenizes `train` with `tokenizer` and torch.save()s the dataset to
# `save` (the original cell loaded it back with torch.load) — confirm against utils.
# The cache is only rebuilt when the file is missing; delete it after changing the
# tokenizer or the data to force re-encoding.
train_cache_path = './data/train_ds_CASED_cached'
if not os.path.exists(train_cache_path):
    cache_ds(train, tokenizer, save=train_cache_path)
train_tensor_data = torch.load(train_cache_path)

In [8]:
# Hold out a fixed-size validation split. random_split draws from torch's global RNG, so
# seeding it first makes the split identical across kernel restarts. This matters when a
# saved checkpoint is reloaded and evaluated on dl_val: an unseeded re-split would put
# some of its training rows into the validation set and inflate the metrics.
val_size = 10000
torch.manual_seed(42)
train_ds, val_ds = torch.utils.data.random_split(
    train_tensor_data, [len(train_tensor_data) - val_size, val_size]
)

In [10]:
# Shuffle the training split each epoch; iterate the validation split in a fixed order
# so evaluation results are comparable between log points.
train_sampler = torch.utils.data.RandomSampler(train_ds)
val_sampler = torch.utils.data.SequentialSampler(val_ds)

dl_train = DataLoader(train_ds, batch_size=32, sampler=train_sampler)
dl_val = DataLoader(val_ds, batch_size=8, sampler=val_sampler)

In [12]:
# Number of passes over dl_train; also used to size the LR schedule (t_total).
n_epochs = 3

In [None]:
# Weight decay for all parameters except biases and LayerNorm weights, which are put in
# a separate group with no decay. 0.0 disables decay for the first group as well.
weight_decay = 0.0
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': weight_decay},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]

# Total optimizer steps if every epoch runs to completion; WarmupLinearSchedule decays
# the LR linearly over t_total steps (no warmup here).
# NOTE(review): the training loop stops early at max_steps, so if t_total > max_steps
# the LR never reaches the end of the schedule.
t_total = n_epochs * len(dl_train)
optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=0, t_total=t_total)

In [None]:
# Training hyper-parameters and running counters.
max_grad_norm = 1        # gradient-norm clipping threshold
global_step = 0          # optimizer steps taken so far (across epochs)
acc_loss = 0.0           # running sum of per-step training losses
logging_loss = acc_loss  # value of acc_loss at the previous log point
model.zero_grad()
log_step = 500           # evaluate and log every log_step optimizer steps
max_steps = 15000        # hard cap on optimizer steps across all epochs

# torch.save does not create directories, so make sure the checkpoint folder exists.
os.makedirs('models', exist_ok=True)

# NOTE(review): `trange` and `tqdm` are used here but are not imported anywhere in this
# notebook; add e.g. `from tqdm import tqdm, trange` to the imports cell before running.
epoch_range = trange(n_epochs, desc='Epoch', position=0, leave=True)

for epoch_num in epoch_range:
    epoch_iter = tqdm(dl_train, desc='Inside epoch {}'.format(epoch_num), position=1, leave=True)
    for step, batch in enumerate(epoch_iter):
        # evaluate() may leave the model in eval mode, so switch back before every step.
        model.train()

        # Only batch[3:] is fed to the model, zipped in order with the names below;
        # assumes cache_ds stores these four tensors last in each item — TODO confirm.
        tup = tuple(item.cuda() for item in batch[3:])
        model_input = dict(zip(['input_ids', 'attention_mask', 'token_type_ids', 'labels'], tup))
        # With labels supplied, the first model output is the training loss.
        logloss = model(**model_input)[0]

        logloss.backward()
        acc_loss += logloss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        scheduler.step()
        model.zero_grad()
        global_step += 1

        if global_step % log_step == 0:
            # Evaluate explicitly on the held-out split, as the post-restore sanity check does.
            eval_res = evaluate(model, dl_val)
            for key, value in eval_res:
                writer.add_scalar('eval_{}'.format(key), value, global_step=global_step)
            writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
            # Mean training loss over the last log_step steps.
            writer.add_scalar('loss', (acc_loss - logging_loss) / log_step, global_step)
            logging_loss = acc_loss

        if global_step >= max_steps:
            break

    # One checkpoint per epoch (or when max_steps is reached), named by step and time.
    torch.save(
        {
            'epoch': epoch_num,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'val_metric': list(evaluate(model, dl_val))
        }, 'models/checkpoint_iter_{}_{}'.format(global_step, datetime.now()))

    if global_step >= max_steps:
        break

writer.export_scalars_to_json('./scalars_{}.json'.format(datetime.now()))
writer.close()

### Predicting on the test set

In [29]:
# Restore model weights from a checkpoint written by the training loop; only the
# 'model_state_dict' entry is used, the rest of the loaded dict is discarded.
checkpoint_path = './models/checkpoint_iter_15000_2019-11-04 14:43:10.972407'
model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

<All keys matched successfully>

In [30]:
# Sanity check after restoring: validation metrics should be close to those seen during
# training (only meaningful if dl_val is the same split the checkpoint was trained against).
val_metrics = list(evaluate(model, dl_val))
val_metrics

[('logloss', tensor(2.8814e-05, device='cuda:0')),
 ('accuracy', 0.9354),
 ('f1', 0.9105263157894737)]

In [16]:
# Encode the test rows that have no missing values (rows with NaN get a default
# prediction later) and cache them on disk, mirroring the training cache. The training
# tokenizer is passed explicitly so train and test are encoded the same way.
# cache_ds presumably torch.save()s the dataset to `save` — confirm against utils.
test_cache_path = './data/test_ds_CASED_cached'
if not os.path.exists(test_cache_path):
    cache_ds(test.dropna(), tokenizer, save=test_cache_path, train=False)
test_ds = torch.load(test_cache_path)

In [56]:
# Iterate the test set in stored order so predictions can be matched back to
# test.dropna().index (assumes cache_ds preserves row order).
test_sampler = torch.utils.data.SequentialSampler(test_ds)
dl_test = DataLoader(test_ds, sampler=test_sampler, batch_size=100)

In [57]:
# process_test presumably returns one prediction array per batch (see utils); join them
# end-to-end along the first axis into a single array of predictions.
test_predictions = process_test(model, dl_test)
answers = np.concatenate(test_predictions, axis=0)

In [40]:
# Every test row defaults to 0; rows without missing values are then overwritten with
# the model's predictions (answers must have one entry per such row, in order).
test['is_duplicate'] = 0
predicted_index = test.dropna().index
test.loc[predicted_index, 'is_duplicate'] = answers

In [41]:
# Output CSV path for the submission. The original cell left the right-hand side empty,
# which is a SyntaxError; a placeholder name is used instead.
your_name = 'submission.csv'  # INSERT SUBMISSION NAME HERE
# Written with the test_id index (set when test.csv was loaded) plus the is_duplicate column.
test[['is_duplicate']].to_csv(your_name)