# Fine-tuning BERT model for news

## 1. Download Pre-trained Model
- Download the pre-trained BERT model from the Hugging Face model hub.
- We will fine-tune it for the downstream news classification task.

In [None]:
import torch.nn
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub id of the checkpoint, and the local folder used as the download cache.
model_name = "google-bert/bert-base-chinese"
cache_dir = "../local_models"

# The returned objects are not kept: these calls are made so that the model
# weights and the tokenizer files are fetched into cache_dir.
download_options = {"cache_dir": cache_dir}
AutoModelForCausalLM.from_pretrained(model_name, **download_options)
AutoTokenizer.from_pretrained(model_name, **download_options)

## 2. Define Dataset class, for loading our custom dataset for further fine-tuning
- The dataset should be prepared in advance and split by purpose: train, validation, test, etc.
- Most of the data should go to training, with smaller shares for validation and test, e.g. 80% / 10% / 10%.

In [None]:
from torch.utils.data import Dataset
from datasets import load_dataset

class MyDataset(Dataset):
    """Map-style dataset over one CSV split of the news corpus.

    Rows are read from ``../local_datasets/news/{split}.csv``; each row must
    provide a ``text`` and a ``label`` column.
    """

    def __init__(self, split):
        """Load the CSV file that backs the given split.

        Args:
            split: File stem of the CSV to load, e.g. ``"train"``,
                ``"validation"`` or ``"test"``.
        """
        # split="train" is not the logical split: a single CSV passed through
        # data_files is exposed by load_dataset under the "train" key.
        self.dataset = load_dataset(path="csv", data_files=f"../local_datasets/news/{split}.csv", split="train")

    def __len__(self):
        """Return the number of rows in the split."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return one ``(text, label)`` pair, or a list of pairs for a slice.

        Args:
            item: Integer row index, or a ``slice`` of row indices.

        Returns:
            ``(text, label)`` for an integer index; a list of such tuples for
            a slice, so that ``dataset[:5]`` iterates over five samples.
        """
        if isinstance(item, slice):
            # Slicing the underlying table yields a dict of column lists;
            # zip them back into per-sample tuples.
            columns = self.dataset[item]
            return list(zip(columns['text'], columns['label']))

        # Fetch the row once instead of once per column.
        row = self.dataset[item]
        return row['text'], row['label']

# Load the three splits; the names are kept at module level for later cells.
dataset_train = MyDataset("train")
dataset_validation = MyDataset("validation")
dataset_test = MyDataset("test")

# Preview up to five (text, label) samples per split. Samples are fetched by
# integer index so that each printed item is one sample, and the bound is
# clamped so a split with fewer than five rows does not raise.
for split_dataset in (dataset_train, dataset_validation, dataset_test):
    for index in range(min(5, len(split_dataset))):
        print(split_dataset[index])

## 3. Define downstream tasks model
- Extend the pretrained model for fine-tuning.
- Update the model config to support longer inputs.
- BERT's default maximum input length (`max_position_embeddings`) is 512 tokens; we increase it to 1024.

In [None]:
from transformers import BertModel,BertConfig
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")

# Local snapshot directory of the downloaded bert-base-chinese checkpoint.
model_path = r"../local_models/models--google-bert--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
# config to support larger input size
config = BertConfig.from_pretrained(model_path)
config.max_position_embeddings = 1024

# Load the pretrained weights under the enlarged config. BertModel(config)
# alone would create a randomly initialised network with nothing pretrained.
# The position-embedding table no longer matches the checkpoint shape
# (1024 vs 512 rows), so ignore_mismatched_sizes=True lets that tensor be
# freshly initialised while every other weight is restored.
pretrained_model = BertModel.from_pretrained(
    model_path,
    config=config,
    ignore_mismatched_sizes=True,
).to(DEVICE)
print(pretrained_model)

# Define the downstream task model, with more classification layers
class MyModel(torch.nn.Module):
    """BERT encoder followed by a linear classification head.

    The encoder is stored as a submodule, so it is included in
    ``parameters()`` (the optimizer trains it), in ``state_dict()`` (saved
    checkpoints are self-contained), and it follows ``.to()``, ``.train()``
    and ``.eval()`` calls made on this model.

    NOTE: checkpoints written by an earlier version of this class that only
    held the ``fc.*`` weights do not contain the ``bert.*`` keys.
    """

    def __init__(self, backbone=None, num_labels=10):
        """Build the classifier.

        Args:
            backbone: Encoder module whose output exposes
                ``last_hidden_state`` and whose ``config.hidden_size`` gives
                the feature width. Defaults to the module-level
                ``pretrained_model``.
            num_labels: Number of output classes.
        """
        super().__init__()
        # Resolved at call time so MyModel() keeps working without arguments.
        self.bert = pretrained_model if backbone is None else backbone
        self.fc = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the class logits for a batch of encoded inputs.

        Args:
            input_ids: Token ids produced by the tokenizer.
            attention_mask: Mask marking real tokens versus padding.
            token_type_ids: Segment ids produced by the tokenizer.

        Returns:
            Tensor of logits with one row per input and ``num_labels`` columns.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Classify from the hidden state of the first token of each sequence.
        return self.fc(encoded.last_hidden_state[:, 0])


## 4. Training


In [None]:
from torch.utils.data import DataLoader
from transformers import BertTokenizer,AdamW
import torch

# DEVICE was already defined in the previous step; it is defined again here so
# this cell can run on its own. The device name must be "cuda": any other
# spelling makes torch.device raise as soon as a GPU is detected.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Local snapshot directory of the downloaded bert-base-chinese checkpoint.
model_path = r"../local_models/models--google-bert--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
# Tokenizer that turns raw text into model inputs, loaded from that snapshot.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_path)

# to encode the data while loading process
def tokenize_batches(batch):
    """Collate ``(text, label)`` samples into model-ready tensors.

    Intended as a ``DataLoader`` ``collate_fn``: the texts are tokenized
    together with the module-level ``tokenizer`` and the labels are stacked
    into one tensor.

    Args:
        batch: Sequence of ``(text, label)`` pairs.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, labels)`` of
        tensors, in that order.
    """
    texts, labels = zip(*batch)
    encoded_data = tokenizer(
        list(texts),
        truncation=True,
        # Upper bound only: longer texts are cut to 1024 tokens.
        max_length=1024,
        # Pad to the longest text in this batch rather than always to 1024
        # tokens; padded positions are masked out either way, so this only
        # avoids computing attention over needless padding.
        padding="longest",
        return_tensors="pt",
    )
    tensor_labels = torch.tensor(labels)
    return encoded_data["input_ids"], encoded_data["attention_mask"], encoded_data["token_type_ids"], tensor_labels

# loading the training dataset
dataset_train = MyDataset("train")
train_data_loader = DataLoader(
    dataset_train,
    batch_size=2,
    shuffle=True,
    drop_last=True,
    collate_fn=tokenize_batches,
)

# loading the validation dataset
# Bind the validation split to its own name: assigning it to dataset_train
# would silently replace the training split and leave the loader below
# depending on a variable defined in an earlier cell.
dataset_validation = MyDataset("validation")
validation_data_loader = DataLoader(
    dataset_validation,
    batch_size=2,
    # Evaluation should see every sample exactly once, in a fixed order.
    shuffle=False,
    drop_last=False,
    collate_fn=tokenize_batches,
)

EPOCH = 2 # Number of passes over the training data. A real run would use many more; it is kept at 2 for the demo.
def _evaluate(model, loader, loss_func):
    """Return ``(mean loss per batch, accuracy over samples)`` on ``loader``."""
    total_loss = 0.0
    correct = 0
    seen = 0
    batches = 0
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in loader:
            input_ids, attention_mask, token_type_ids = input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE)
            labels = labels.to(DEVICE)
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            # .item() keeps the running total a plain float instead of a tensor.
            total_loss += loss_func(out, labels).item()
            correct += (out.argmax(dim=1) == labels).sum().item()
            seen += len(labels)
            batches += 1
    if batches == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0, 0.0
    return total_loss / batches, correct / seen


def run_training(data_loader=train_data_loader, validation_loader=None):
    """Train ``MyModel`` for ``EPOCH`` epochs, validating after each one.

    After every epoch the current weights are written to ``params/last.pth``;
    whenever the validation accuracy improves they are also written to
    ``params/best.pth``.

    Args:
        data_loader: Loader yielding ``(input_ids, attention_mask,
            token_type_ids, labels)`` training batches.
        validation_loader: Loader yielding validation batches in the same
            format. Defaults to the module-level ``validation_data_loader``.
    """
    import os

    if validation_loader is None:
        validation_loader = validation_data_loader

    print(DEVICE)
    model = MyModel().to(DEVICE)
    optimizer = AdamW(model.parameters(), lr=1e-5)
    loss_func = torch.nn.CrossEntropyLoss()

    # torch.save does not create missing directories.
    os.makedirs("params", exist_ok=True)

    # Start below any reachable accuracy so the first epoch always writes
    # best.pth, even if its accuracy is 0.
    best_validation_acc = float("-inf")
    for epoch in range(EPOCH):
        # Validation switches the model to eval mode; switch back each epoch.
        model.train()
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(data_loader):
            input_ids, attention_mask, token_type_ids = input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE)
            # Labels must be on the same device as the logits for the loss.
            labels = labels.to(DEVICE)
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            loss = loss_func(out, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 5 == 0:
                predictions = out.argmax(dim=1)
                acc = (predictions == labels).sum().item()/len(labels)
                print(f"epoch:{epoch},i:{i},loss:{loss.item()},acc:{acc}")

        model.eval()
        # Validation runs the full model forward pass, but without gradients.
        # Loss is averaged over validation batches and accuracy over
        # validation samples, so accuracy is a fraction between 0 and 1.
        validation_loss, validation_acc = _evaluate(model, validation_loader, loss_func)
        print(f"epoch:{epoch},validation_loss:{validation_loss},validation_acc:{validation_acc}")

        if validation_acc > best_validation_acc:
            best_validation_acc = validation_acc
            torch.save(model.state_dict(), "params/best.pth")
            print(f"epoch:{epoch},best model saved with acc:{best_validation_acc}")

        torch.save(model.state_dict(), "params/last.pth")
        print(f"epoch:{epoch},last model saved")

# Start training with the default arguments of run_training.
run_training()

## 5. Testing
- After training, evaluate the model on the test dataset.
- Load the saved model parameters and run the test.

In [None]:
from torch.utils.data import DataLoader
import torch

# loading the test dataset
dataset_test = MyDataset("test")
test_data_loader = DataLoader(
    dataset_test,
    batch_size=2,
    # Every test sample must be scored exactly once: no shuffling is needed,
    # and the final, possibly smaller, batch must not be dropped.
    shuffle=False,
    drop_last=False,
    collate_fn=tokenize_batches,
)

def run_testing(param_path="params/best.pth"):
    """Evaluate saved model parameters on the test loader.

    Prints the number of correct predictions for every batch, then the
    overall accuracy as ``test_acc``.

    Args:
        param_path: Path of the ``state_dict`` file to load into ``MyModel``.
    """
    model = MyModel().to(DEVICE)
    # map_location lets a checkpoint saved on one device be loaded on another.
    model.load_state_dict(torch.load(param_path, map_location=DEVICE))
    model.eval()

    correct = 0
    total = 0
    # Inference only: no gradients are needed.
    with torch.no_grad():
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(test_data_loader):
            input_ids, attention_mask, token_type_ids = input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE)
            # Labels must be on the same device as the predictions to compare.
            labels = labels.to(DEVICE)
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            batch_correct = (out.argmax(dim=1) == labels).sum().item()
            correct += batch_correct
            print(i, batch_correct)
            total += len(labels)

    if total == 0:
        # An empty loader has no accuracy to report; avoid dividing by zero.
        print("test_acc:n/a (no test samples)")
        return
    print(f"test_acc:{correct/total}")

# Evaluate the parameters saved at "params/best.pth" (the default path) ...
run_testing()
# ... and then the parameters saved at "params/last.pth".
run_testing("params/last.pth")