# Finetuning a BERT Model

In this exercise, you will create a BERT model and then train it on the CoLA dataset, which contains sentences that are either grammatically correct or incorrect. The data loading, model training, and testing logic are already included in your code. You will need to fetch the data and use a pretrained BERT model to solve a classification task.

**In this workspace you have a GPU to help train the model, but it is best practice to DISABLE it while writing code and only ENABLE it when you are training.**

Here are the steps you need to do to complete this exercise:

1. Data has already been downloaded from [here](https://nyu-mll.github.io/CoLA/) into the `cola_public` directory.
2. Create a tokenizer for the BERT Model
3. Create the BERT Model and the optimizer for the model
4. Save all your work and then **ENABLE** the GPU
5. Run the package installations.
6. Run the file to make sure that the model is training properly.
7. If it works, remember to **DISABLE** the GPU before moving to the next page.

In case you get stuck, you can look at the solution by clicking the Jupyter symbol at the top left and navigating to `finetune_a_bert_solution.py`.

## Try It Out!
- Can you train a different BERT architecture and compare the results? 
- Can you finetune the same model on a different dataset or task and compare the result?

In [1]:
!pip install pip --upgrade
!pip install datasets==1.8.0
!pip install transformers==4.6.1
!pip install ipywidgets
#!pip list

Collecting pip
  Downloading pip-24.0-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-24.0-py3-none-any.whl (2.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 75.4 MB/s eta 0:00:00
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.3.2
    Uninstalling pip-23.3.2:
      Successfully uninstalled pip-23.3.2
Successfully installed pip-24.0
Collecting datasets==1.8.0
  Downloading datasets-1.8.0-py3-none-any.whl.metadata (9.3 kB)
Collecting pyarrow<4.0.0,>=1.0.0 (from datasets==1.8.0)
  Downloading pyarrow-3.0.0.tar.gz (682 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 682.2/682.2 kB 10.3 MB/s eta 0:00:00
  Installing build dependencies ... error
  error: subprocess-exited-with-error
  
  × pip subprocess to install build dependencies did not run successfully.
  │ exit c

In [2]:
# Emit a small script that asks the notebook front end to restart the kernel
# (presumably so the packages installed in the previous cell become importable).
from IPython.display import HTML

HTML("<script>Jupyter.notebook.kernel.restart()</script>")

In [None]:
import json
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from transformers import AdamW, BertForSequenceClassification, BertTokenizer

# Fixed sequence length: the data loaders zero-pad every encoded sentence that
# is shorter than MAX_LEN token ids up to this length.
MAX_LEN = 64  # this is the max length of the sentence
# Batch size that train() passes to both the train and the test data loader.
batch_size=64
# Number of passes train() makes over the training loader.
epochs=1

# Module-level tokenizer; the data loaders call tokenizer.encode(...) on every
# sentence, so it must exist before they run.
tokenizer = #TODO: Create the BERT tokenizer (BertTokenizer is imported above)

def flat_accuracy(preds, labels):
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        preds: Array of per-class scores; the predicted class of each row is
            the index of its largest value along axis 1.
        labels: Array of expected class ids; it is flattened before comparison.

    Returns:
        Number of matching predictions divided by the number of labels.
    """
    predicted = np.argmax(preds, axis=1).flatten()
    expected = labels.flatten()
    n_correct = np.sum(predicted == expected)
    return n_correct / len(expected)


def _build_data_loader(csv_path, batch_size, shuffle):
    """Tokenize the sentences in ``csv_path`` and wrap them in a DataLoader.

    The CSV must provide a ``sentence`` and a ``label`` column. Every sentence
    is encoded with the module-level ``tokenizer``, truncated to ``MAX_LEN``
    token ids and zero-padded up to ``MAX_LEN`` so that all rows have the same
    length and can be stacked into one tensor.

    Args:
        csv_path: Path of the CSV file to read.
        batch_size: Number of samples per batch.
        shuffle: Draw samples in random order when true, in file order
            otherwise.

    Returns:
        A ``DataLoader`` yielding ``(input_ids, attention_mask, label)``
        batches.
    """
    dataset = pd.read_csv(csv_path)
    sentences = dataset.sentence.values
    labels = dataset.label.values

    input_ids = []
    attention_masks = []
    for sent in sentences:
        # Without truncation a sentence longer than MAX_LEN would produce a
        # ragged list and torch.tensor() below would raise.
        token_ids = tokenizer.encode(
            sent,
            add_special_tokens=True,
            max_length=MAX_LEN,
            truncation=True,
        )
        padded = token_ids + [0] * (MAX_LEN - len(token_ids))
        input_ids.append(padded)
        # mask; 0: padding that was added, 1: otherwise
        attention_masks.append([int(token_id > 0) for token_id in padded])

    # convert to PyTorch data types.
    data = TensorDataset(
        torch.tensor(input_ids),
        torch.tensor(attention_masks),
        torch.tensor(labels),
    )
    if shuffle:
        sampler = RandomSampler(data)
    else:
        sampler = torch.utils.data.SequentialSampler(data)
    return DataLoader(data, sampler=sampler, batch_size=batch_size)


def get_train_data_loader(batch_size):
    """Return a shuffled DataLoader over ``data/train.csv``.

    Args:
        batch_size: Number of samples per training batch.

    Returns:
        A ``DataLoader`` yielding ``(input_ids, attention_mask, label)``
        batches in random order.
    """
    return _build_data_loader(
        os.path.join("data", "train.csv"), batch_size, shuffle=True
    )


def get_test_data_loader(test_batch_size):
    """Return a DataLoader over ``data/test.csv`` in file order.

    Evaluation does not depend on sample order, so the loader is not shuffled.

    Args:
        test_batch_size: Number of samples per evaluation batch.

    Returns:
        A ``DataLoader`` yielding ``(input_ids, attention_mask, label)``
        batches.
    """
    return _build_data_loader(
        os.path.join("data", "test.csv"), test_batch_size, shuffle=False
    )

def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and print the overall accuracy.

    Args:
        model: Model to evaluate. It is called with the input ids, with
            ``token_type_ids=None`` and with the attention mask; the first
            element of its output is used as the logits.
        test_loader: Iterable of batches whose first three items are the
            input ids, the attention mask and the labels.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        The fraction of evaluated samples that were classified correctly, or
        ``0.0`` when the loader yields no samples.
    """
    model.eval()
    correct = 0.0
    seen = 0

    with torch.no_grad():
        for batch in test_loader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)

            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
            logits = outputs[0].detach().cpu().numpy()
            label_ids = b_labels.to("cpu").numpy()

            # flat_accuracy returns a per-batch fraction; weight it by the
            # batch size so that the final ratio is correct / total samples
            # (summing fractions and dividing by the sample count is not).
            n_samples = len(label_ids)
            correct += flat_accuracy(logits, label_ids) * n_samples
            seen += n_samples

    accuracy = float(correct / seen) if seen else 0.0
    print("Test set: Accuracy: ", accuracy)
    return accuracy

def train(device):
    """Fine-tune the model for ``epochs`` epochs, then evaluate it once.

    Builds the train and test loaders with the module-level ``batch_size``,
    moves the model to ``device``, runs the optimisation loop and finally
    calls ``test`` with the test loader.

    Args:
        device: Target passed to ``.to(device)`` for the model and for every
            batch tensor.
    """
    train_loader = get_train_data_loader(batch_size)
    test_loader = get_test_data_loader(batch_size)

    model = #TODO: Create the BERT Model (BertForSequenceClassification is imported above)
    model=model.to(device)
    optimizer = #TODO: Create the optimizer (AdamW is imported above)

    for epoch in range(1, epochs + 1):
        # Accumulated over the epoch but not read anywhere in this function.
        total_loss = 0
        model.train()
        for step, batch in enumerate(train_loader):
            # Each batch holds (input ids, attention mask, labels).
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)
            # Reset the gradients before this step's backward pass.
            model.zero_grad()

            # The labels are passed to the model; the first output element is
            # used as the loss.
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
            loss = outputs[0]

            total_loss += loss.item()
            loss.backward()
            # Clip the gradient norm to 1.0 before the parameter update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # Update the parameters based on their gradients, the learning rate, etc.
            optimizer.step()
            # Report progress every 10 steps.
            if step % 10  == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}".format(
                        epoch,
                        step * len(batch[0]),
                        len(train_loader.sampler),
                        100.0 * step / len(train_loader),
                        loss.item(),
                    )
                )

    # Evaluate once, after all epochs have finished.
    test(model, test_loader, device)


if __name__ == "__main__":
    from sklearn.model_selection import train_test_split

    # The raw file has no header row; columns 1 and 3 are read as the label
    # and the sentence.
    df = pd.read_csv(
        "./cola_public/raw/in_domain_train.tsv",
        sep="\t",
        header=None,
        usecols=[1, 3],
        names=["label", "sentence"],
    )

    # The data loaders read data/train.csv and data/test.csv, so make sure the
    # directory exists before writing the two splits into it.
    os.makedirs("./data", exist_ok=True)
    train_df, test_df = train_test_split(df)
    train_df.to_csv("./data/train.csv", index=False)
    test_df.to_csv("./data/test.csv", index=False)

    # Use the first GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train(device)