# Span Detection hyperparameter search

1. Import dependencies

In [None]:
import random
import time

import numpy as np
import torch
import os

from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
)

from src.util.torch_device import resolve_torch_device
from src.data.span_detection_ds import ManipulationDetectionDataset
from src.definitions import (
    MODELS_FOLDER,
    RAW_DATA_FOLDER,
    PROCESSED_DATA_FOLDER,
)
from src.model.span_detection_metrics import compute_metrics

2. Prepare Env

In [None]:
# One seed shared by every RNG below and reused by the dataset and trainer cells.
random_seed = 42

# Seed NumPy, PyTorch and the stdlib generator; the three calls are independent.
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)

# Pretrained checkpoint identifier handed to the tokenizer / config / model loaders.
model_checkpoint = "FacebookAI/xlm-roberta-large"

device = resolve_torch_device()

# Wall-clock timestamp in whole seconds, captured when this cell runs.
epoch_time = int(time.time())

# Per the PyTorch MPS documentation a ratio of 0.0 disables the high-watermark
# memory limit on Apple-silicon GPUs.
# NOTE(review): this is set after `torch` is imported and after the device is
# resolved — confirm it still takes effect before the first MPS allocation.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

3. Load dataset

In [None]:
# Tokenizer that matches the pretrained checkpoint; also used by the collator
# and the Trainer further down.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Dataset helper bound to the raw parquet file and the processed-data folder.
# It also exposes the `label2id` / `id2label` maps consumed by the model config.
dataset_blueprint = ManipulationDetectionDataset(
    seed=random_seed,
    processed_path=PROCESSED_DATA_FOLDER / "span-detection",
    raw_path=RAW_DATA_FOLDER / "span-detection.parquet",
    tokenizer=tokenizer,
)

# The returned object is indexed by split name ("train" / "test") later on.
dataset = dataset_blueprint.read()

4. Prepare model

In [None]:
# Baseline training configuration. `learning_rate` and `weight_decay` are only
# starting values here: the hyperparameter search overrides them per trial.
training_args = TrainingArguments(
    # --- where checkpoints go / reproducibility ---
    output_dir=MODELS_FOLDER / "span-detection-checkpoint",
    seed=random_seed,
    # --- optimisation ---
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.8,
    # --- batching (auto_find_batch_size may shrink the batch on OOM) ---
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    auto_find_batch_size=True,
    # --- evaluate and checkpoint once per epoch ---
    eval_strategy="epoch",
    save_strategy="epoch",
    # NOTE(review): no `metric_for_best_model` is given, so "best" is judged by
    # the library default rather than by the F1 score the search maximises.
    load_best_model_at_end=True,
    # --- housekeeping ---
    logging_steps=200,
    torch_empty_cache_steps=1000,
)

# Model configuration: label maps come from the dataset helper.
#
# XLM-RoBERTa configs have no `dropout` field; their dropout knobs are
# `hidden_dropout_prob`, `attention_probs_dropout_prob` and
# `classifier_dropout`. A plain `dropout=...` kwarg is silently ignored, so the
# value is passed as `classifier_dropout`, which applies to the
# token-classification head.
# NOTE(review): use `hidden_dropout_prob` instead if encoder-wide dropout was
# the intent.
config = AutoConfig.from_pretrained(
    model_checkpoint,
    num_labels=len(dataset_blueprint.label2id),
    id2label=dataset_blueprint.id2label,
    label2id=dataset_blueprint.label2id,
    classifier_dropout=0.6,
)


def optuna_hp_space(trial):
    """Sample the training-argument hyperparameters for one Optuna trial.

    The returned keys are `TrainingArguments` field names; the Trainer applies
    them to its arguments before the trial's training run starts. Model-config
    hyperparameters (dropout) are sampled in `model_init` instead, because they
    are not `TrainingArguments` fields.

    Args:
        trial: Optuna trial used to draw the values.

    Returns:
        dict with the sampled `learning_rate` and `weight_decay`.
    """
    return {
        "learning_rate": trial.suggest_float(
            "learning_rate", 1e-6, 1e-4, log=True
        ),
        "weight_decay": trial.suggest_float(
            "weight_decay", 0.01, 0.9, log=True
        ),
    }


def model_init(trial):
    """Build a fresh token-classification model for the Trainer.

    Args:
        trial: Optuna trial for the current run, or None when the model is
            built outside a hyperparameter search.

    Returns:
        A model loaded from `model_checkpoint` with the shared `config`.
    """
    # Explicit None check: only the absence of a trial should skip sampling.
    if trial is not None:
        # The Optuna parameter keeps the name "dropout", but the value is
        # written to `classifier_dropout`: XLM-RoBERTa never reads a config
        # key called `dropout`, so setting that key would not change the model.
        config.update(
            {
                "classifier_dropout": trial.suggest_float(
                    "dropout", 0.1, 0.8, log=True
                ),
            }
        )

    return AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        config=config,
    )


# Collator for token-classification batches, built from the shared tokenizer.
data_collator = DataCollatorForTokenClassification(tokenizer)


# No `model=` is passed: the Trainer obtains its model from `model_init`, which
# lets the hyperparameter search rebuild the model for every trial.
# NOTE(review): trials are scored on the "test" split — confirm that tuning on
# this split is intended rather than on a separate validation split.
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    data_collator=data_collator,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics(dataset_blueprint),
)

5. Run HP search

In [None]:
def compute_objective(metrics):
    """Return the value the search optimises: the evaluation F1 score.

    Args:
        metrics: mapping of evaluation metric names to values; must contain
            the key "eval_f1".

    Returns:
        The value stored under "eval_f1".

    Raises:
        KeyError: if "eval_f1" is missing from `metrics`.
    """
    f1_score = metrics["eval_f1"]
    return f1_score


# Run 30 Optuna trials and keep the one with the highest objective
# (the evaluation F1 returned by `compute_objective`).
best_trial = trainer.hyperparameter_search(
    backend="optuna",
    n_trials=30,
    direction="maximize",
    compute_objective=compute_objective,
    hp_space=optuna_hp_space,
)

In [None]:
# Show the result of the hyperparameter search as the cell output.
best_trial