# BERT Fine-Tuning with PyTorch

In [4]:
import time
import random
import datetime
from collections import defaultdict
from IPython.display import clear_output

from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import pickle


import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AdamW, BertModel , BertModel, BertForMaskedLM, AdamW, RobertaTokenizer, RobertaModel

Matplotlib created a temporary config/cache directory at /var/tmp/matplotlib-chspwwfn because the default path (/home/vm_admin/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


There are 1 GPU(s) available.
We will use the GPU: Tesla K80


In [None]:
# Pick the compute device once; later cells move tensors to it with `.to(device)`.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

# Seed every RNG the notebook draws from. Seeding only the CUDA generator
# leaves the CPU torch generator unseeded, so data shuffling and the
# initialisation of the classification head would differ between runs.
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # ignored when CUDA is unavailable
torch.backends.cudnn.enabled = False
torch.backends.cudnn.deterministic = True

## Make data for model

### Map each classifier number to a contiguous label index (0 is reserved for the negative class)

In [None]:
CLF_NUMBERs = [66, 37, 39, 40, 8, 43, 77, 81, 83, 84, 57, 91, 189, 195]
data_path = 'clfs_data/'

# Label 0 maps to itself; every classifier number then receives the next
# index (1, 2, ...) in the order it appears in CLF_NUMBERs.
lbl_d = {0: 0}
for position, clf_number in enumerate(CLF_NUMBERs, start=1):
    lbl_d[clf_number] = position

lbl_d

### The data is pre-split because this notebook compares the old classifier with the new one on identical train, validation and test sets

In [None]:
def _read_split(clf_number, split):
    """Read one classifier's Excel split and return (sentences, multi-class labels).

    Rows whose 'label' column equals 1 receive the classifier's index from
    lbl_d; every other row is labelled 0.
    """
    # The classifier number is an int, so it has to be formatted into the path
    # (plain `str + int` concatenation raises TypeError).
    frame = pd.read_excel(f'{data_path}{clf_number}_{split}.xlsx')
    split_sentences = frame['sentence'].to_list()
    split_labels = [lbl_d[clf_number] if flag == 1 else 0 for flag in frame['label'].to_list()]
    return split_sentences, split_labels


dd_ = defaultdict(list)
sentences, test_sentences, labels, test_labels, val_sentences, val_labels = [], [], [], [], [], []
for CLF in CLF_NUMBERs:
    if CLF == 0:
        continue
    # Same read order as before: train, then test, then val.
    for split, sentence_acc, label_acc in (
        ('train', sentences, labels),
        ('test', test_sentences, test_labels),
        ('val', val_sentences, val_labels),
    ):
        split_sentences, split_labels = _read_split(CLF, split)
        sentence_acc.extend(split_sentences)
        label_acc.extend(split_labels)


In [None]:
from collections import Counter

# Print the label distribution of each split, in train / test / val order.
for split_name, split_labels in (('train', labels), ('test', test_labels), ('val', val_labels)):
    print(split_name)
    print(Counter(split_labels))

# Number of distinct labels that occur in the validation split.
labels_count = len(set(val_labels))

## Tokenization & Input Formatting

### The notebook can use one of several pretrained BERT-family checkpoints

In [None]:
from transformers import BertTokenizer

# Checkpoint to fine-tune; the commented lines are the alternatives.
model_name = "sberbank-ai/ruBert-large"
# model_name = "biencoder_mlm/best-val-model_mlm_desc.pt"
# model_name = 'sberbank-ai/ruRoberta-large'

# Choose the tokenizer class that matches the checkpoint family, then load it.
if 'ruRoberta' in model_name:
    tokenizer_cls, banner = RobertaTokenizer, 'Loading RobertaTokenizer tokenizer...'
else:
    tokenizer_cls, banner = BertTokenizer, 'Loading BertTokenizer tokenizer...'

print(banner)
tokenizer = tokenizer_cls.from_pretrained(model_name, do_lower_case=False)

### Tokenize Dataset &  Training & Validation Split

In [None]:
def manual_tokenizer(tokenizer, sentences, labels, max_length=64):
    """Encode sentences into fixed-width tensors ready for a TensorDataset.

    Args:
        tokenizer: tokenizer object that is callable on a list of strings.
        sentences: iterable of input strings.
        labels: sequence of labels, one per sentence.
        max_length: width of every encoded row; longer inputs are truncated
            and shorter ones padded. Defaults to 64.

    Returns:
        Tuple ``(input_ids, attention_masks, labels)`` of tensors; the first
        two come from the tokenizer, the third is ``torch.tensor(labels)``.

    Raises:
        ValueError: if ``sentences`` is empty or its length differs from
            ``labels``.
    """
    sentences = list(sentences)
    if not sentences:
        raise ValueError('manual_tokenizer() needs at least one sentence')
    if len(sentences) != len(labels):
        raise ValueError(
            f'got {len(sentences)} sentences but {len(labels)} labels'
        )

    # One batched call with explicit padding/truncation replaces the
    # per-sentence encode_plus loop and its deprecated `pad_to_max_length`
    # flag (which also relied on truncation being enabled implicitly).
    encoded = tokenizer(
        sentences,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return encoded['input_ids'], encoded['attention_mask'], torch.tensor(labels)

In [None]:
# Use split-specific names so the raw `labels` list is not overwritten by a
# tensor; this keeps the cell (and the label-count cell) safe to re-run.
train_input_ids, train_attention_masks, train_label_ids = manual_tokenizer(tokenizer, sentences, labels)
train_dataset = TensorDataset(train_input_ids, train_attention_masks, train_label_ids)

val_input_ids, val_attention_masks, val_label_ids = manual_tokenizer(tokenizer, val_sentences, val_labels)
val_dataset = TensorDataset(val_input_ids, val_attention_masks, val_label_ids)

### DataLoader

In [None]:
batch_size = 126

# Training batches are drawn in random order (shuffle=True makes the
# DataLoader build a RandomSampler over the dataset).
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation batches are read in dataset order (shuffle=False makes the
# DataLoader build a SequentialSampler).
validation_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

## Multi Classification Model

In [None]:
# NOTE(review): this flag is not read anywhere in the visible code, so the
# encoder is fine-tuned regardless of its value - confirm the intent.
freeze_bert = True


class BertClassifier(nn.Module):
    """BERT encoder followed by a single linear classification layer.

    Args:
        model_name: checkpoint name or path passed to ``from_pretrained``.
        num_labels: size of the output layer.
        mlm_model_path: optional path to a ``BertForMaskedLM`` state dict;
            when given, the encoder is taken from that masked-LM model
            instead of the plain ``BertModel`` checkpoint.
    """

    def __init__(self, model_name, num_labels, mlm_model_path=None):
        super().__init__()

        if mlm_model_path:
            print('MLM model load state_dict')
            mlm_model = BertForMaskedLM.from_pretrained(model_name)
            # map_location='cpu' lets a checkpoint saved on a GPU be restored
            # on a CPU-only host; the caller moves the model afterwards.
            mlm_model.load_state_dict(torch.load(mlm_model_path, map_location='cpu'))
            self.bert = mlm_model.bert
        else:
            self.bert = BertModel.from_pretrained(model_name)

        self.hidden_size = self.bert.config.hidden_size
        self.linear = nn.Linear(self.hidden_size, num_labels)

    def forward(self, input_id, mask):
        """Return the linear layer's output for the first token of each sequence."""
        outputs_bert = self.bert(input_ids=input_id, attention_mask=mask)
        # outputs_bert[0] is indexed as (batch, position, feature); position 0
        # is the first token of every sequence.
        first_token_state = outputs_bert[0][:, 0, :]
        return self.linear(first_token_state)
    
class RobertaClass(torch.nn.Module):
    """RoBERTa encoder with a two-layer classification head.

    Args:
        model_name: checkpoint name or path passed to ``from_pretrained``.
        num_labels: size of the output layer.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.l1 = RobertaModel.from_pretrained(model_name)
        # Read the encoder width from its config instead of hard-coding 1024,
        # so checkpoints of other sizes work as well.
        self.pre_classifier = torch.nn.Linear(self.l1.config.hidden_size, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the classifier output for the first token of each sequence."""
        encoder_output = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoder_output[0][:, 0]
        hidden = torch.relu(self.pre_classifier(first_token_state))
        hidden = self.dropout(hidden)
        return self.classifier(hidden)


### mlm_path - path to the retrained BertForMaskedLM on the source domain (wach )

In [None]:
mlm_path = None # '../model_mlm.pt'

if 'ruRoberta' in model_name:
    print('start ruRoberta')
    model = RobertaClass(model_name, len(lbl_d))
else:
    print('start sberbank-ai/ruBert-large')
    if mlm_path:
        print('with MLM')
    # NOTE(review): the checkpoint name is hard-coded here rather than taken
    # from model_name - confirm that this is intended.
    model = BertClassifier('sberbank-ai/ruBert-large', len(lbl_d), mlm_path)

# Move to the device selected earlier instead of calling .cuda(), which
# raises on a machine without a GPU.
model.to(device)
clear_output(False)

### Optimizer & Learning Rate Scheduler & epochs

In [None]:
# The scheduler factory is not part of the notebook's top-level imports, so
# bring it into scope here.
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(),
                  lr = 2e-5,
                  eps = 1e-8
                )

# criterion = nn.BCEWithLogitsLoss()
criterion = nn.CrossEntropyLoss()
# Follow the selected device rather than assuming CUDA is present.
criterion = criterion.to(device)

epochs = 2
# The schedule length is counted in optimizer steps (batches), not epochs.
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)

## Training Loop

### Utils

In [None]:
def flat_accuracy(preds, y):
    """Return the share of entries where round(sigmoid(preds)) equals y.

    The two arguments are walked in lockstep along their first dimension and
    the result is returned as a Python number.
    """
    hits = []
    for guess, target in zip(torch.sigmoid(preds).round(), y):
        hits.append(guess == target)
    return (sum(hits) / float(len(hits))).item()

def format_time(elapsed):
    """Render a duration in seconds as the string form of a timedelta.

    The value is rounded to whole seconds first.
    """
    whole_seconds = int(round(elapsed))
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)

def compute_metrics(pred, average='binary'):
    """Compute accuracy, F1, precision and recall for a prediction object.

    Args:
        pred: object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class scores; the arg-max over the last axis is used as the
            predicted label).
        average: averaging mode forwarded to
            ``precision_recall_fscore_support``. The default ``'binary'``
            keeps the original behaviour; sklearn only accepts it for
            two-class targets, so pass e.g. ``'macro'`` or ``'weighted'``
            for multi-class labels.

    Returns:
        Dict with the keys ``'accuracy'``, ``'f1'``, ``'precision'`` and
        ``'recall'``.
    """
    # These helpers are not among the notebook's top-level imports; importing
    # them here keeps the function usable on its own.
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=average)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }


In [None]:
# Fine-tune for `epochs` passes over the training set, measuring the
# validation loss after each pass and collecting per-epoch statistics.
training_stats = []
total_t0 = time.time()

for epoch_i in range(epochs):
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()
    total_train_loss = 0
    model.train()

    for step, batch in enumerate(train_dataloader):
        # Progress line every 40 batches (skipping the very first one).
        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # Each batch holds, in order: input ids, attention masks, labels.
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        model.zero_grad()

        logits = model(b_input_ids, b_input_mask).squeeze(1)
        loss = criterion(logits, b_labels)
        total_train_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        # The schedule was sized as len(train_dataloader) * epochs steps, so
        # it has to advance once per batch; stepping once per epoch would
        # leave the learning rate almost unchanged.
        scheduler.step()

    avg_train_loss = total_train_loss / len(train_dataloader)
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epcoh took: {:}".format(training_time))
    print("")
    print("Running Validation...")

    t0 = time.time()
    model.eval()
    total_eval_loss = 0

    # No gradients are needed while measuring the validation loss.
    with torch.no_grad():
        for batch in validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)

            logits = model(b_input_ids, b_input_mask).squeeze(1)
            loss = criterion(logits, b_labels)
            total_eval_loss += loss.item()

    avg_val_loss = total_eval_loss / len(validation_dataloader)
    validation_time = format_time(time.time() - t0)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

### Plot epoch stats

In [None]:
# Use the fully qualified option name: the bare 'precision' key is ambiguous
# or unavailable in recent pandas releases.
pd.set_option('display.precision', 2)

# One row per epoch, indexed by epoch number.
df_stats = pd.DataFrame(data=training_stats)
df_stats = df_stats.set_index('epoch')
df_stats

In [None]:

%matplotlib inline

sns.set(style='darkgrid')
sns.set(font_scale=1.5)
plt.rcParams["figure.figsize"] = (12,6)

plt.plot(df_stats['Training Loss'], 'b-o', label="Training")
plt.plot(df_stats['Valid. Loss'], 'g-o', label="Validation")

plt.title("Training & Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.xticks([1, 2, 3, 4])

plt.show()

### Save model

In [None]:
# Write the model's state dict (not the whole module object) to disk.
state = model.state_dict()
torch.save(state, 'mulliclass_mlm.pt')

### Save for torch serve

In [None]:
# NOTE(review): the original cell saved `traced_model` without defining it
# anywhere in the visible notebook. The model is traced here with a single
# validation example - confirm that this matches the intended export.
model.eval()
example_input_ids, example_attention_mask = (t.to(device) for t in val_dataset[:1][:2])
with torch.no_grad():
    traced_model = torch.jit.trace(model, (example_input_ids, example_attention_mask))
torch.jit.save(traced_model, 'traced_mulliclass_mlm_v3.pt')

## Precision-Recall Curve Compare

In [None]:
# Decision threshold of each existing binary classifier, keyed by classifier
# number; insertion order is kept as listed.
CLF_CONFIG = {
    clf_number: {"thresh": thresh}
    for clf_number, thresh in (
        (8, 0.5),
        (40, 0.53),
        (43, 0.6),
        (77, 0.5),
        (66, 0.6),
        (81, 0.6),
        (83, 0.7),
        (37, 0.65),
        (39, 0.5),
        (84, 0.72),
        (57, 0.5),
        (91, 0.6),
        (189, 0.6),
    )
}

def old_new_precision(OLD_THR, y_elmo, preds_elmo, class_true_label, pred_class):
    """Compare the precision of two classifiers at a matched recall level.

    The old classifier is evaluated at the first point of its
    precision-recall curve whose threshold exceeds ``OLD_THR``. The new
    classifier is evaluated at the first point of its own curve whose recall
    does not exceed the old recall at that point.

    Args:
        OLD_THR: decision threshold of the old classifier.
        y_elmo: true binary labels for the old classifier's predictions.
        preds_elmo: scores produced by the old classifier.
        class_true_label: true binary labels for the new classifier.
        pred_class: scores produced by the new classifier.

    Returns:
        Tuple ``(new_precision, old_precision)``.
    """
    old_prec, old_rec, old_thr = precision_recall_curve(y_elmo, preds_elmo)
    new_prec, new_rec, _ = precision_recall_curve(class_true_label, pred_class)

    above_thr = np.flatnonzero(old_thr > OLD_THR)
    # sklearn's precision/recall arrays have one more entry than the
    # thresholds: a closing (precision=1, recall=0) point. When no threshold
    # exceeds OLD_THR the old classifier predicts no positives, which is that
    # closing point - use it instead of failing on an empty index array.
    old_idx = above_thr[0] if above_thr.size else len(old_thr)

    old_rec_at_thr = old_rec[old_idx]
    old_prec_at_thr = old_prec[old_idx]

    # Never empty: the closing recall value of 0 always satisfies the test.
    new_idx = np.flatnonzero(new_rec <= old_rec_at_thr)[0]

    return new_prec[new_idx], old_prec_at_thr

### Multi-class BertClassifier vs. a set of binary classifiers (same training data): per-class precision-recall curve comparison

In [None]:
# For every classifier: score its test split with the multi-class model, then
# plot the resulting precision-recall curve against the stored curve of the
# old binary classifier.
model.eval()
# Kept under its own name so the training `batch_size` is not overwritten.
eval_batch_size = 32

for k, v in lbl_d.items():
    if k == 0:
        continue
    df = pd.read_excel(f'{data_path}/{k}_test.xlsx')
    clf_sentences = df['sentence'].to_list()
    clf_labels = [v if x == 1 else 0 for x in df['label'].to_list()]

    # Reuse the shared tokenisation helper instead of repeating its loop.
    input_ids, attention_masks, label_ids = manual_tokenizer(tokenizer, clf_sentences, clf_labels)

    prediction_data = TensorDataset(input_ids, attention_masks, label_ids)
    prediction_dataloader = DataLoader(
        prediction_data,
        sampler=SequentialSampler(prediction_data),
        batch_size=eval_batch_size,
    )

    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in prediction_dataloader:
            b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)
            probs = torch.softmax(model(b_input_ids, b_input_mask), dim=1)
            predictions.extend(probs.cpu().numpy())
            true_labels.extend(b_labels.cpu().numpy())

    # One-vs-rest view: the probability column of this classifier's label.
    pred_class = np.vstack(predictions)[:, v]
    class_true_label = [1 if x == v else 0 for x in true_labels]

    # NOTE: unpickling runs arbitrary code - only load files produced by this
    # project.
    with open(f'binary_clfs/{k}_precision_recall_curve.pkl', 'rb') as f:
        y_elmo, preds_elmo = pickle.load(f)

    with open(f'main_model/Bert_{k}_precision_recall_curve.pkl', 'wb') as f:
        pickle.dump((class_true_label, pred_class), f)

    precision_elmo, recall_elmo, _ = precision_recall_curve(y_elmo, preds_elmo)
    precision, recall, _ = precision_recall_curve(class_true_label, pred_class)

    fig, ax = plt.subplots()
    ax.plot(recall, precision, color='red')
    ax.plot(recall_elmo, precision_elmo, color='black')

    # Counts of positive and negative examples in the old classifier's labels.
    tp_count = sum(y_elmo)
    tn_count = len(y_elmo) - tp_count
    ax.set_title(f'CLF_NUMBER: {k}\n red:  {model_name}\n black:  elmo\n tp/tn {tp_count}/{tn_count}')
    ax.set_ylabel('Precision')
    ax.set_xlabel('Recall')

    plt.show()

    # Not every classifier in lbl_d has a configured threshold; skip the
    # comparison for those instead of raising KeyError.
    clf_cfg = CLF_CONFIG.get(k)
    if clf_cfg is None:
        print(f'old_new_precision skipped: no threshold configured for classifier {k}')
    else:
        print('old_new_precision')
        print(old_new_precision(clf_cfg['thresh'], y_elmo, preds_elmo, class_true_label, pred_class))
