# Fine-Tuning BERT

> Let's try fine-tuning a BERT model for Reddit rule violation prediction!

## Idea

- replace the model's top layer with a binary classification head
- fine-tune the language model on our training data
    - this also allows us to quintuple our training data by including the positive and negative example comments as additional examples
    - we could even include the positive and negative examples from the test data
- we need to create a custom Hugging Face dataset from the CSV file or a dictionary

In [1]:
import kagglehub
import numpy as np
import pandas as pd
import transformers
import torch
import seaborn as sns

# in Kaggle, add evaluate to the notebook's dependencies!
# however, installing additional dependencies requires internet access...
# import evaluate

from datasets import Dataset, load_dataset
from pathlib import Path
from torch.nn.functional import softmax
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer, DataCollatorWithPadding
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


COMPETITION_HANDLE = "jigsaw-agile-community-rules"

BERT_HANDLE = "google-bert/bert-base-cased"
BERT_PATH = "/kaggle/input/bert-base-cased/transformers/default/1"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


### Load model

- we need a model for sequence classification
- we would like to use a BERT model for sequence classification
    - unfortunately, `kagglehub` does not offer `BERT` for the `transformers` framework
    - we could use the `transformers` `TFAutoModel` class to instantiate `tensorflow` models
    - but we cannot use `tensorflow` in our local environment, as it requires CUDA 12.3 and we only have CUDA 12.2
    - we could use `transformers` to download `bert-base-cased` from the Hugging Face Hub
    - but if we want to submit our notebook to the competition, it has to run without internet access!
    - so downloads via `transformers` will fail...
- alternative: stick with `gemma-3` for sequence classification
    - however, `transformers` does not seem to support `gemma-3-1b` for sequence classification out of the box...
    - it does support `gemma-3-4b` for sequence classification, though
    - but `gemma-3-4b` is too big for us to fine-tune!
        - for inference, `gemma-3-4b` only needs about `2 bytes * 4B params + KV cache ≈ 10 GB` of VRAM, but for fine-tuning we additionally need to keep the gradients, the optimizer states, and all activations on the GPU for backpropagation...
        - quantization doesn't help, because the quantized parameters are dequantized to 16 bit again for the computation
- conclusion: use BERT, download it beforehand, and upload it to Kaggle
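The memory argument above can be made concrete with a back-of-the-envelope calculation. The sketch assumes the common rule of thumb for full fine-tuning with Adam in mixed precision: 16-bit weights and gradients (2 + 2 bytes) plus 32-bit master weights and two 32-bit Adam moments (4 + 4 + 4 bytes), i.e. 16 bytes per parameter *before* activations. The parameter counts (4B for `gemma-3-4b`, roughly 110M for BERT-base) are approximations, and the helper `memory_gb` is hypothetical.

```python
def memory_gb(n_params: float, bytes_per_param: float) -> float:
    """Memory in GB (10^9 bytes) needed to hold n_params values of the given size."""
    return n_params * bytes_per_param / 1e9

# assumption: 16-bit weights (2) + 16-bit gradients (2) + fp32 master weights (4)
# + fp32 Adam first and second moments (4 + 4) = 16 bytes per parameter
FULL_FINETUNE_BYTES = 2 + 2 + 4 + 4 + 4

GEMMA_PARAMS = 4e9    # approximate size of gemma-3-4b
BERT_PARAMS = 110e6   # approximate size of BERT-base

gemma_inference = memory_gb(GEMMA_PARAMS, 2)                   # weights only, no KV cache
gemma_finetune = memory_gb(GEMMA_PARAMS, FULL_FINETUNE_BYTES)  # excludes activations
bert_finetune = memory_gb(BERT_PARAMS, FULL_FINETUNE_BYTES)    # excludes activations

print(f"gemma-3-4b inference weights: {gemma_inference:.1f} GB")
print(f"gemma-3-4b fine-tuning states: {gemma_finetune:.1f} GB")
print(f"BERT-base fine-tuning states: {bert_finetune:.2f} GB")
```

Under these assumptions this prints 8.0 GB, 64.0 GB and 1.76 GB: the 16-bit weights of `gemma-3-4b` fit on a consumer GPU, but its training states are eight times larger, with activations on top, whereas BERT-base stays in the low single-digit GB range.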

In [2]:
try:
    # try to download BERT from huggingface
    # add classification head to model
    # instantiate BERT, but use 2 output classification head (num_labels=2)
    bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_HANDLE, num_labels=2)
    bert_model.to(DEVICE)
    bert_tokenizer = AutoTokenizer.from_pretrained(BERT_HANDLE)
except OSError:
    # if we have no internet connection, load bert from local path instead
    bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_PATH, num_labels=2)
    bert_model.to(DEVICE)
    bert_tokenizer = AutoTokenizer.from_pretrained(BERT_PATH)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


When instantiating the sequence classification model, you will probably see a warning like: 

> Some weights were not initialized from the model checkpoint and are newly initialized: [...]
> You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

The pretrained `bert-base-cased` checkpoint only contains the BERT encoder plus the heads used during pretraining (masked language modeling and next sentence prediction).
`AutoModelForSequenceClassification` does not use those heads and instead puts a new classification layer on top of the pooled `[CLS]` representation; with `num_labels=2` it has just 2 output neurons, one for "no violation" and one for "violation".
However, the parameters of this new layer (`classifier.weight` and `classifier.bias` in the warning) are randomly initialized and not adjusted to the task at all.
If we used the model as is, its rule violation predictions would be essentially random.

Therefore, we need to *fine-tune* the model on the downstream task using our dataset of Reddit comments.
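To see what the new head looks like without downloading anything, one can build a tiny, randomly initialized BERT from a `BertConfig`. The sizes below are arbitrary toy values, not those of `bert-base-cased` (which uses `hidden_size=768`); the point is only that `BertForSequenceClassification` ends in a `classifier` layer mapping `hidden_size` to `num_labels` logits, and that an untrained head produces uninformative probabilities.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# toy configuration (assumed values, much smaller than bert-base-cased)
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=2,
)
torch.manual_seed(0)
model = BertForSequenceClassification(config)
model.eval()  # disable dropout for a deterministic forward pass

# the freshly initialized head: a linear layer from the pooled [CLS] vector to 2 logits
print(model.classifier)

input_ids = torch.tensor([[2, 15, 27, 3]])  # arbitrary token ids below vocab_size
with torch.no_grad():
    logits = model(input_ids=input_ids).logits
probas = torch.softmax(logits, dim=-1)
print(probas)  # untrained head: these probabilities carry no information yet
```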

### Data Preparation

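Since we are now fine-tuning the model instead of doing few-shot in-context learning, we don't need the positive and negative examples in the prompt anymore.
However, instead of just discarding them, we can use them as additional training examples, effectively quintupling the size of our training dataset (each row contributes its body plus four example comments, so 2,029 rows become 10,145 training examples).
In the same way, the example comments of the test set become a small labeled validation set (10 rows × 4 examples = 40 examples).
We will save these preprocessed datasets as CSV files.

We need to wrap our data in a Hugging Face `datasets` `Dataset` object, which the `Trainer` can consume directly.
Then, we need to construct the right input prompts from our features.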

In [3]:
DATA_PATH = Path(kagglehub.competition_download(COMPETITION_HANDLE))

print(DATA_PATH)

/root/.cache/kagglehub/competitions/jigsaw-agile-community-rules


In [4]:
TRAIN_PATH = DATA_PATH / "train.csv"
TEST_PATH = DATA_PATH / "test.csv"
SAMPLE_PATH = DATA_PATH / "sample_submission.csv"


train_df = pd.read_csv(TRAIN_PATH)
test_df = pd.read_csv(TEST_PATH)
sample_df = pd.read_csv(SAMPLE_PATH)

In [5]:
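def label_examples(df: pd.DataFrame):
    positive_mask = df["variable"].str.contains("positive")
    negative_mask = df["variable"].str.contains("negative")

    # example comments carry their label in the column name:
    # positive examples violate the rule (1), negative examples don't (0)
    df.loc[negative_mask, "labels"] = 0
    df.loc[positive_mask, "labels"] = 1

    # make sure the labels are integer class ids for the cross-entropy loss;
    # the column is created as float when it does not exist yet (as in preprocess_val)
    df["labels"] = df["labels"].astype(int)
    return df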


def preprocess_train(df: pd.DataFrame):
    df = df.rename(columns={"rule_violation": "labels"})
    df = df.melt(id_vars=["row_id", "rule", "subreddit", "labels"])
    df = label_examples(df)
    return df

def preprocess_test(df: pd.DataFrame):
    example_cols = [c for c in df.columns if "example" in c]
    df = df.drop(columns=example_cols)
    df = df.melt(id_vars=["row_id", "rule", "subreddit"])
    return df

def preprocess_val(df: pd.DataFrame):
    df = df.drop(columns="body")
    df = df.melt(id_vars=["row_id", "rule", "subreddit"])
    df = label_examples(df)
    return df

In [6]:
TUNE_PATH = Path("data/tune/")
TUNE_PATH.mkdir(exist_ok=True, parents=True)

TUNE_TRAIN_PATH = TUNE_PATH / "train.csv"
TUNE_VAL_PATH = TUNE_PATH / "val.csv"
TUNE_TEST_PATH = TUNE_PATH / "test.csv"


tune_train_df = preprocess_train(train_df)
tune_train_df.to_csv(TUNE_TRAIN_PATH, index=False)

tune_val_df = preprocess_val(test_df)
tune_val_df.to_csv(TUNE_VAL_PATH, index=False)


tune_test_df = preprocess_test(test_df)
tune_test_df.to_csv(TUNE_TEST_PATH, index=False)

tune_train_df.head()


|   | row_id | rule | subreddit | labels | variable | value |
|---|---|---|---|---|---|---|
| 0 | 0 | No Advertising: Spam, referral links, unsolici... | Futurology | 0 | body | Banks don't want you to know this! Click here ... |
| 1 | 1 | No Advertising: Spam, referral links, unsolici... | soccerstreams | 0 | body | SD Stream [ ENG Link 1] (http://www.sportsstre... |
| 2 | 2 | No legal advice: Do not offer or request legal... | pcmasterrace | 1 | body | Lol. Try appealing the ban and say you won't d... |
| 3 | 3 | No Advertising: Spam, referral links, unsolici... | sex | 1 | body | she will come your home open her legs with an... |
| 4 | 4 | No Advertising: Spam, referral links, unsolici... | hearthstone | 1 | body | code free tyrande --->>> [Imgur](http://i.imgu... |


In [7]:
tune_ds = load_dataset("csv", data_files={
    "train": str(TUNE_TRAIN_PATH),
    "val": str(TUNE_VAL_PATH),
    "test": str(TUNE_TEST_PATH)
})


Great, now we need to apply our prompt template and the tokenizer to each example:

In [8]:
def apply_template(example):
    template = """\
    [SUBREDDIT]
    {subreddit}
    [RULE]
    {rule}
    [CONTENT]
    {value}
    """
    example["prompt"] = template.format(subreddit=example["subreddit"], rule=example["rule"], value=example["value"])
    return example

tune_ds = tune_ds.map(apply_template)

print(tune_ds["train"][0]["prompt"])

Map:   0%|          | 0/10145 [00:00<?, ? examples/s]

Map:   0%|          | 0/40 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

    [SUBREDDIT]
    Futurology
    [RULE]
    No Advertising: Spam, referral links, unsolicited advertising, and promotional content are not allowed.
    [CONTENT]
    Banks don't want you to know this! Click here to know more!
    


In [9]:
def apply_tokenizer(examples):
    return bert_tokenizer(examples["prompt"], padding="max_length", truncation=True)

tune_ds = tune_ds.map(apply_tokenizer, batched=True)

print(tune_ds["train"][0]["input_ids"])


[101, 164, 156, 2591, 26166, 10069, 17243, 1942, 166, 14763, 20362, 4807, 164, 155, 2591, 17516, 166, 1302, 25010, 131, 23665, 1306, 117, 5991, 4412, 6743, 117, 8362, 24313, 22308, 1174, 6437, 117, 1105, 10626, 3438, 1132, 1136, 2148, 119, 164, 18732, 15681, 11680, 1942, 166, 10117, 1274, 112, 189, 1328, 1128, 1106, 1221, 1142, 106, 140, 13299, 1303, 1106, 1221, 1167, 106, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
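With `padding="max_length"`, every prompt is padded to the tokenizer's maximum length (512 tokens for `bert-base-cased`), which is why the printed `input_ids` end in a long run of zeros. The commented-out `DataCollatorWithPadding` in the trainer cell would instead pad each batch only to its longest sequence. The sketch below shows that behaviour offline, using a hypothetical ten-word vocabulary instead of the real BERT vocabulary; to use it in the notebook, one would tokenize without `padding="max_length"` and pass the collator to the `Trainer`.

```python
import os
import tempfile

from transformers import BertTokenizer, DataCollatorWithPadding

# hypothetical toy vocabulary so the example runs without downloading anything
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "no", "spam", "click", "here", "hello"]
with tempfile.TemporaryDirectory() as tmp:
    vocab_path = os.path.join(tmp, "vocab.txt")
    with open(vocab_path, "w") as f:
        f.write("\n".join(vocab) + "\n")
    tokenizer = BertTokenizer(vocab_file=vocab_path)

# tokenize WITHOUT padding: each example keeps its own length
features = [tokenizer(text) for text in ["click here", "no spam click here hello"]]

# the collator pads the batch to its longest sequence (7 tokens here), not to 512
collator = DataCollatorWithPadding(tokenizer=tokenizer)
batch = collator(features)
print(batch["input_ids"])
print(batch["attention_mask"])
```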

### Fine-tuning the classifier

In [None]:
# metric = evaluate.load("accuracy")

training_args = TrainingArguments(
    output_dir="models/reddit_classifier",
    eval_strategy="epoch",
    # in kaggle, wandb is installed and will raise an error without API token...
    report_to="none"
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    probas = softmax(torch.from_numpy(logits), dim=1).numpy()
    predictions = np.argmax(logits, axis=-1)

    # compute via evaluate
    # return metric.compute(predictions=predictions, references=labels)

    metrics = {
        "accuracy": accuracy_score(labels, predictions),
        "f1": f1_score(labels, predictions),
        "roc_auc": roc_auc_score(labels, probas[:, 1])
    }
    return metrics


# pads each batch dynamically to the length of its longest sequence
# (not needed as long as apply_tokenizer already pads everything to max_length)
# data_collator = DataCollatorWithPadding(tokenizer=bert_tokenizer, padding=True)


trainer = Trainer(
    model=bert_model,
    args=training_args,
    train_dataset=tune_ds["train"],
    eval_dataset=tune_ds["val"],
    compute_metrics=compute_metrics,
#    data_collator=data_collator
)
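`compute_metrics` receives the raw logits and labels of the whole evaluation set. One way to sanity-check it is to feed it hand-made logits; the numpy-only sketch below uses made-up values and a hypothetical `softmax_np` helper in place of `torch.nn.functional.softmax`. It also illustrates that accuracy and F1 depend only on the argmax, while ROC AUC uses the class-1 probabilities, and that for two classes this probability equals the sigmoid of the logit difference.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def softmax_np(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row maximum keeps exp() from overflowing."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# made-up logits for four comments (column 0 = no violation, column 1 = violation)
logits = np.array([[2.0, -1.0], [0.5, 1.5], [-0.3, 0.2], [1.0, 0.0]])
labels = np.array([0, 1, 0, 0])

probas = softmax_np(logits)
# softmax is monotonic, so the argmax of the logits equals the argmax of the probabilities
predictions = np.argmax(logits, axis=-1)

metrics = {
    "accuracy": accuracy_score(labels, predictions),
    "f1": f1_score(labels, predictions),
    "roc_auc": roc_auc_score(labels, probas[:, 1]),
}
print(metrics)

# for two classes, P(class 1) is the sigmoid of the logit difference
sigmoid_of_diff = 1.0 / (1.0 + np.exp(-(logits[:, 1] - logits[:, 0])))
print(np.allclose(sigmoid_of_diff, probas[:, 1]))
```

Here one false positive lowers accuracy and F1, but ROC AUC stays perfect because the only violating comment still has the highest violation probability.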

In [11]:
print("start fine-tuning")

trainer.train()

start fine-tuning


| Epoch | Training Loss | Validation Loss | Accuracy | F1 | ROC AUC |
|---|---|---|---|---|---|
| 1 | No log | 0.689457 | 0.5 | 0.666667 | 0.6725 |
| 2 | No log | 0.684196 | 0.525 | 0.344828 | 0.6975 |
| 3 | No log | 0.668402 | 0.65 | 0.708333 | 0.7075 |


TrainOutput(global_step=39, training_loss=0.6830650231777093, metrics={'train_runtime': 9.5962, 'train_samples_per_second': 31.263, 'train_steps_per_second': 4.064, 'total_flos': 78933316608000.0, 'train_loss': 0.6830650231777093, 'epoch': 3.0})
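The logged step count can be checked against the dataset size. Assuming the `TrainingArguments` defaults (`per_device_train_batch_size=8`, `num_train_epochs=3`), a single GPU and no gradient accumulation, the hypothetical helper below computes the expected number of optimizer steps. For all 10,145 training rows it gives 3,807 steps, whereas the run above finished after 39 steps, which corresponds to roughly 100 training examples (consistent with about 31 samples/s over 9.6 s); the shown output therefore seems to come from a run on a subset of the training split. The "No log" entries fit this picture, since the training loss is only logged every 500 steps by default.

```python
import math


def num_training_steps(n_examples: int, batch_size: int = 8, epochs: int = 3) -> int:
    """Optimizer steps for single-device training without gradient accumulation."""
    steps_per_epoch = math.ceil(n_examples / batch_size)  # the last, smaller batch is kept
    return steps_per_epoch * epochs


print(num_training_steps(10145))  # 3807: full quintupled training split
print(num_training_steps(100))    # 39: a ~100-example subset
```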

### Predict on Test examples

In [12]:
pred = trainer.predict(tune_ds["test"])
pred_labels = np.argmax(pred.predictions, axis=1)

pred_proba = softmax(torch.from_numpy(pred.predictions), dim=1).numpy()[:, 1]
pred_proba

array([0.5096333 , 0.5220407 , 0.5062994 , 0.47678885, 0.52223575,
       0.6111362 , 0.57424563, 0.43309432, 0.5075619 , 0.5197084 ],
      dtype=float32)

In [13]:
submission_df = pd.DataFrame({"row_id": test_df["row_id"], "rule_violation": pred_proba})
submission_df.to_csv("submission.csv", index=False)
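Before submitting, it can be worth checking that the file lines up with `sample_submission.csv`. The sketch below uses a hypothetical helper `check_submission` and small made-up frames; in the notebook one would call it as `check_submission(submission_df, sample_df)`.

```python
import pandas as pd


def check_submission(submission: pd.DataFrame, sample: pd.DataFrame) -> list[str]:
    """Return a list of problems; an empty list means the basic checks passed."""
    if list(submission.columns) != list(sample.columns):
        return [f"columns {list(submission.columns)} != {list(sample.columns)}"]
    problems = []
    if submission["row_id"].tolist() != sample["row_id"].tolist():
        problems.append("row_id values or their order differ from the sample submission")
    proba = submission["rule_violation"]
    if proba.isna().any() or not proba.between(0, 1).all():
        problems.append("rule_violation must contain probabilities in [0, 1]")
    return problems


# made-up frames for illustration
sample = pd.DataFrame({"row_id": [10, 11, 12], "rule_violation": [0.5, 0.5, 0.5]})
good = pd.DataFrame({"row_id": [10, 11, 12], "rule_violation": [0.51, 0.43, 0.61]})
bad = pd.DataFrame({"row_id": [11, 10, 12], "rule_violation": [0.51, 1.3, 0.61]})

print(check_submission(good, sample))
print(check_submission(bad, sample))
```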

## Bookmarks

- kaggle:
    - https://github.com/Kaggle/kagglehub
    - https://www.kaggle.com/models
- huggingface:
    - https://huggingface.co/
    - https://huggingface.co/docs/transformers/en/training
    - https://huggingface.co/google-bert/bert-base-uncased