In [None]:
import random
from pathlib import Path

import torch
import wandb
import yaml
from torch.utils.data import DataLoader

from classifier.file_reader import read_files_from_folder
from classifier.dataset import BertPandasDataset, collate_fn, create_bert_datasets, preprocess_dataframe
from classifier.model import ContinualMultilabelBERTClassifier, MultilabelBERTClassifier

# Directory the kernel is running in, used as the notebook's folder.
# `Path("train_classifier.ipynb").parent` is always "." (the filename is never
# consulted), so the previous expression was already just the CWD; say so directly.
# NOTE(review): assumes the kernel was started in the notebook's own folder
# (Jupyter's usual default) — other launchers may use a different CWD.
FOLDER_PATH = Path.cwd()
print(FOLDER_PATH)


In [None]:
SEED = 42
DATASET = "boolq"
MODEL_NAME = "answerdotai/ModernBERT-base"
# NOTE(review): MINIBATCH_SIZE, N_EPOCHS and TEST_VAL_SET_SIZE are not read by the
# training cells below, which hard-code their own batch size / epochs / val ratio —
# reconcile before relying on these values.
MINIBATCH_SIZE = 64
N_EPOCHS = 50
TEST_VAL_SET_SIZE = 0.15

# Apply the seed so DataLoader shuffling and weight initialisation are reproducible
# on Restart & Run All (SEED was previously declared but never used).
random.seed(SEED)
torch.manual_seed(SEED)

# All benchmark-specific paths are derived from DATASET, so switching benchmarks
# only requires changing the constant above.
PROJECT_ROOT = FOLDER_PATH.parent
benchmark_config_path = PROJECT_ROOT / "config" / "messplus" / f"{DATASET}.yaml"

# Read only the classifier section of the benchmark YAML; the `with` block closes the file.
with benchmark_config_path.open("r") as f:
    classifier_config = yaml.safe_load(f)["classifier_model"]

df = read_files_from_folder(str(PROJECT_ROOT / "data" / "inference_outputs" / DATASET), file_ext=".csv")
display(df.head())

In [None]:
# Number of rows loaded from the inference-output CSV files.
display(df["input_text"].size)

In [None]:
text_col = ["input_text"]
label_cols = ["label_small", "label_medium"]

# Keep only the text and label columns, then let the project helper clean them up.
dataset = preprocess_dataframe(df[text_col + label_cols], label_cols=label_cols)

# Split into train/validation BERT datasets; the helper also returns the tokenizer.
# NOTE(review): tokenization here uses max_length=1024 and val_ratio=0.10, whereas the
# classifier below is built with max_length=128 — confirm which length is intended.
train_dataset, val_dataset, tokenizer = create_bert_datasets(
    dataset,
    text_col,
    label_cols,
    model_name=MODEL_NAME,
    max_length=1024,
    val_ratio=0.10,
)

# Both loaders share the batch size and the project's custom collate function;
# only the training loader shuffles.
loader_kwargs = dict(batch_size=16, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

display(
    f"Training dataset size: {len(train_dataset)}",
    f"Validation dataset size: {len(val_dataset)}",
)

## Full model training
Training the full model yields strong results, but it starts to overfit very quickly.
We also observe local batch instabilities, which show up as spikes in the loss.
We tried to adjust the classifier architecture to account for these instabilities.
Some form of regularization may still be needed to counter both the overfitting and the loss spikes.

In [None]:
# Collect the hyperparameters once so the same values drive the model, the W&B run
# config and the run name. The previous hard-coded name ("mom-0.9") did not match the
# momentum actually used (0.85).
classifier_hparams = dict(
    learning_rate=1e-3,
    momentum=0.85,
    weight_decay=0.01,
    batch_size=16,
    # NOTE(review): the datasets were tokenized with max_length=1024 — confirm which
    # length the classifier actually applies.
    max_length=128,
    warmup_ratio=0.05,
    threshold=0.5,
    freeze_bert_layers=True,
)

classifier = MultilabelBERTClassifier(
    model_name=MODEL_NAME,
    num_labels=len(label_cols),
    config=classifier_config,
    **classifier_hparams,
)

run_name = f"minibatch_size-{classifier_hparams['batch_size']}-mom-{classifier_hparams['momentum']}"

# Used as a context manager, wandb.init finishes the run on exit (also when fit
# raises), so no separate wandb.finish() call is needed.
with wandb.init(
    entity="tum-i13",
    project="mess-plus-classifier-training-offline",
    name=run_name,
    config={"model_name": MODEL_NAME, **classifier_hparams},
):
    # NOTE(review): with epochs=1, early_stopping_patience=2 presumably cannot trigger
    # (assuming patience is counted in epochs) — confirm against classifier.fit.
    classifier.fit(train_dataset, val_dataset, epochs=1, early_stopping_patience=2)


In [None]:
# Quick smoke test of the trained classifier on two sample questions.
sample_questions = [
    "does ethanol take more energy make that produces",
    "is the liver part of the excretory system",
]
classifier.predict(texts=sample_questions)

## Continual learning approach
Experimental and currently disabled: the cell below is commented out.

In [None]:
# Disabled experiment: train a ContinualMultilabelBERTClassifier incrementally,
# one row at a time, validating against val_dataset after each step; the loop
# stops after idx 50.
# NOTE(review): before re-enabling, check the following:
#   - `df.loc[idx]` selects a single row (a pandas Series), not a DataFrame; confirm
#     BertPandasDataset accepts that (a one-row frame would be `df.loc[[idx]]`).
#   - the loop runs over `range(len(dataset))` (the preprocessed frame) but reads
#     rows from the raw `df` by label; these only line up if preprocessing keeps
#     all rows and `df` has a default RangeIndex — TODO confirm.
#   - the max length passed here (128) differs from the 1024 used to build the
#     train/validation datasets.
# cont_model = ContinualMultilabelBERTClassifier(
#     model_name=MODEL_NAME,  # Replace with your preferred BERT variant
#     num_labels=len(label_cols),
#     learning_rate=8e-7,
#     weight_decay=0.01,
#     batch_size=16,
#     max_length=128,
#     warmup_ratio=0.1,
#     threshold=0.5,
#     freeze_bert_layers=True,
#     memory_size=0
# )
#
#
# for idx in range(len(dataset)):
#     print(f"Fetching sample {idx}/{len(dataset)}...")
#     sample = BertPandasDataset(df.loc[idx], text_col, label_cols, tokenizer, 128)
#     cont_model.incremental_fit(
#         new_train_dataset=sample,
#         new_val_dataset=val_dataset,
#     )
#
#     if idx % 50 == 0 and idx != 0:
#         display(f"Done.")
#         break
