In [None]:
import copy

import numpy as np
import optuna
from sklearn.metrics import accuracy_score
from transformers import Trainer, TrainingArguments

def train_and_evaluate(trial):
    """Optuna objective: fine-tune a fresh copy of ``model`` with sampled hyperparameters.

    Samples a learning rate (log scale, 1e-5..5e-5), a per-device batch size
    (8, 16 or 32) and an epoch count (3..5), trains with ``transformers.Trainer``
    and returns the ``eval_accuracy`` produced via ``compute_metrics``.

    Reads the notebook globals ``model``, ``train_dataset``, ``test_dataset``
    and ``compute_metrics``.

    Args:
        trial: the Optuna trial passed in by ``study.optimize``.

    Returns:
        The evaluation accuracy of this trial (the value the study maximizes).
    """
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 5e-5, log=True)
    batch_size = trial.suggest_categorical('batch_size', [8, 16, 32])
    num_epochs = trial.suggest_int('num_epochs', 3, 5)

    training_args = TrainingArguments(
        # One directory per trial so checkpoints from different trials never
        # overwrite or intermix with each other.
        output_dir=f'./results/trial_{trial.number}',
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy`; update it if your installed version warns or errors.
        evaluation_strategy="epoch",
        logging_strategy="steps",
        logging_steps=100,
        save_strategy="epoch",
        # Keep only the latest checkpoint per trial to bound disk usage.
        save_total_limit=1,
        fp16=True
    )

    # Train a copy so every trial starts from the same initial weights. Training
    # the global `model` directly would make each trial continue from the weights
    # left by the previous one (and leave `model` already fine-tuned afterwards).
    trial_model = copy.deepcopy(model)

    trainer = Trainer(
        model=trial_model,
        args=training_args,
        train_dataset=train_dataset,
        # NOTE(review): selecting hyperparameters on `test_dataset` leaks test
        # information into model selection; a held-out validation split is preferable.
        eval_dataset=test_dataset,
        compute_metrics=compute_metrics
    )

    trainer.train()
    metrics = trainer.evaluate()
    # Trainer prefixes metric names returned by compute_metrics with "eval_".
    return metrics["eval_accuracy"]

def compute_metrics(p):
    """Compute classification accuracy for ``Trainer`` evaluation.

    Args:
        p: an ``EvalPrediction`` from ``Trainer``, or a plain
            ``(predictions, labels)`` pair. Predictions are per-class scores
            with the class axis last.

    Returns:
        ``{"accuracy": acc}`` where ``acc`` is the fraction of argmax
        predictions that equal the labels.
    """
    if hasattr(p, "predictions"):
        # Read the fields by name: unpacking an EvalPrediction can yield more
        # than two items (e.g. when inputs are included for metrics), which
        # would break a two-name unpack.
        logits, labels = p.predictions, p.label_ids
    else:
        logits, labels = p
    # Models that return extra outputs hand back a tuple; assumes the
    # class scores come first — TODO confirm for the model in use.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc}

# Run the hyperparameter search. Seeding the (default) TPE sampler makes the
# sequence of sampled hyperparameters reproducible across kernel restarts.
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=42),
)
study.optimize(train_and_evaluate, n_trials=20)

# best_value is the objective (eval accuracy), not the trial itself — label it so.
print(f"Best trial: #{study.best_trial.number}, eval_accuracy={study.best_value:.4f}")
print(f"Best params: {study.best_trial.params}")


In [None]:
# Retrain with the best hyperparameters found by the Optuna study.
# NOTE(review): this trains the global `model` directly. If `model` was already
# fine-tuned earlier (e.g. by search trials that trained it in place), reload it
# from the pretrained checkpoint first so this run starts from fresh weights.
best_params = study.best_trial.params

training_args = TrainingArguments(
    # Dedicated directory so the final run's checkpoints are not mixed with
    # checkpoints written during the hyperparameter search.
    output_dir='./results/final',
    num_train_epochs=best_params['num_epochs'],
    per_device_train_batch_size=best_params['batch_size'],
    per_device_eval_batch_size=best_params['batch_size'],
    learning_rate=best_params['learning_rate'],
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; update it if your installed version warns or errors.
    evaluation_strategy="epoch",
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="epoch",
    fp16=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics
)

trainer.train()
final_metrics = trainer.evaluate()
final_metrics
