# Load Module

In [1]:
import os

os.chdir("..")

In [2]:
os.environ["WANDB_PROJECT"] = "spell-correction"

In [3]:
import transformers
from datasets import load_dataset
import evaluate
from transformers import AutoTokenizer
import sentencepiece
import copy

In [4]:
import time

In [5]:
# wandb.login()

# Global Config

In [6]:
MODEL_CHECKPOINT = "csebuetnlp/banglat5_small"

In [7]:
MODEL_NAME = "spell-correction"

In [8]:
SPLIT_CONFIG = {
    "train": "./datasets/correction_train.jsonl",
    # "test": "./datasets/correction_test.jsonl",
    "val": "./datasets/correction_val.jsonl",
}

In [9]:
MAX_INPUT_LENGTH = 128
MAX_TARGET_LENGTH = 128

In [10]:
DATASET_NAME = "spell_correction_dataset"
BATCH_SIZE = 16
EPOCHS = 3
ACCUMULATION_STEPS = 1
LR = 2e-4

# Load Dataset

In [11]:
from utils.tokenizer import TokenizerPreprocessor

In [12]:
raw_datasets = load_dataset("json", data_files=SPLIT_CONFIG)

In [13]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, legacy=True)

In [14]:
tp = TokenizerPreprocessor(
    tokenizer=tokenizer, max_input_length=MAX_INPUT_LENGTH, max_target_length=MAX_TARGET_LENGTH
)

In [15]:
tokenized_datasets = raw_datasets.map(tp, batched=True)

# Configure Trainer

In [16]:
from transformers import EarlyStoppingCallback

from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

In [17]:
def model_init():
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
    return model

In [18]:
data_collator = DataCollatorForSeq2Seq(tokenizer)

# Tune Hyper params

In [19]:
from utils.metrics import compute_objective, CustomTrainer
import optuna

In [20]:
HP_TRAIN_BATCH = 5
HP_VAL_BATCH = 3
HP_BATCH_CHOICES = [16]
HP_LR_RANGE = [2e-5, 2e-4, 2e-3]
HP_EPOCH_RANGE = [3]
HP_ACCUMULATION_STEPS = [1, 2, 4]
HP_RUNS = 10

In [21]:
def objective(trial):
    # learning_rate = trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True)
    learning_rate = trial.suggest_categorical("learning_rate", HP_LR_RANGE)
    per_device_train_batch_size = trial.suggest_categorical("per_device_train_batch_size", HP_BATCH_CHOICES)
    accumulation_steps = trial.suggest_categorical("accumulation_steps", HP_ACCUMULATION_STEPS)
    epoch = trial.suggest_categorical("epoch", HP_EPOCH_RANGE)

    model = model_init()
    
    training_args = Seq2SeqTrainingArguments(
        f"./results/hparams_tuner_{time.time()}",
        evaluation_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        weight_decay=0.01,
        # save_total_limit=3,
        # save_strategy="epoch",
        num_train_epochs=EPOCHS,
        predict_with_generate=True,
        fp16=False,
        logging_steps=1,
        push_to_hub=False,
        # load_best_model_at_end=True,
        logging_strategy="steps",
        gradient_accumulation_steps=ACCUMULATION_STEPS,
        report_to=[],
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        dataloader_pin_memory=True
    )

    trainer = CustomTrainer(
        None,
        training_args,
        train_dataset=tokenized_datasets["train"]
        .select(range(BATCH_SIZE * HP_TRAIN_BATCH)),
        eval_dataset=tokenized_datasets["val"]
        .select(range(BATCH_SIZE * HP_VAL_BATCH)),
        data_collator=data_collator,
        tokenizer=tokenizer,
        model_init=model_init,
    )

    trainer.train()
    eval_results = trainer.evaluate()
    return eval_results["eval_loss"]


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=HP_RUNS)


print("Best trial:")
trial_ = study.best_trial

print(f" Value: {trial_.value}")
print(" Params: ")
for key, value in trial_.params.items():
    print(f"    {key}: {value}")

[I 2024-02-03 20:28:25,689] A new study created in memory with name: no-name-70e54cb2-1e80-4129-8222-44a72b56d0d4
You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,4.7596,3.658498,3.7275,10.7083
2,4.1274,3.568081,2.5166,9.1458
3,2.9362,3.540076,2.0561,8.6875


[I 2024-02-03 20:28:37,674] Trial 0 finished with value: 3.5400760173797607 and parameters: {'learning_rate': 0.0002, 'per_device_train_batch_size': 16, 'accumulation_steps': 4, 'epoch': 3}. Best is trial 0 with value: 3.5400760173797607.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,4.7596,3.658498,3.7275,10.7083
2,4.1274,3.568081,2.5166,9.1458
3,2.9362,3.540076,2.0561,8.6875


[I 2024-02-03 20:28:46,581] Trial 1 finished with value: 3.5400760173797607 and parameters: {'learning_rate': 0.0002, 'per_device_train_batch_size': 16, 'accumulation_steps': 2, 'epoch': 3}. Best is trial 0 with value: 3.5400760173797607.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,4.7596,3.658498,3.7275,10.7083
2,4.1274,3.568081,2.5166,9.1458
3,2.9362,3.540076,2.0561,8.6875


[I 2024-02-03 20:28:54,675] Trial 2 finished with value: 3.5400760173797607 and parameters: {'learning_rate': 0.0002, 'per_device_train_batch_size': 16, 'accumulation_steps': 2, 'epoch': 3}. Best is trial 0 with value: 3.5400760173797607.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,4.131,3.127286,3.2898,7.9375
2,2.8877,2.962231,14.6658,11.7917
3,1.9531,2.912153,17.3259,12.7292


[I 2024-02-03 20:29:02,698] Trial 3 finished with value: 2.912153482437134 and parameters: {'learning_rate': 0.002, 'per_device_train_batch_size': 16, 'accumulation_steps': 1, 'epoch': 3}. Best is trial 3 with value: 2.912153482437134.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,4.131,3.127286,3.2898,7.9375
2,2.8877,2.962231,14.6658,11.7917
3,1.9531,2.912153,17.3259,12.7292


[I 2024-02-03 20:29:11,225] Trial 4 finished with value: 2.912153482437134 and parameters: {'learning_rate': 0.002, 'per_device_train_batch_size': 16, 'accumulation_steps': 1, 'epoch': 3}. Best is trial 3 with value: 2.912153482437134.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,4.7596,3.658498,3.7275,10.7083
2,4.1274,3.568081,2.5166,9.1458
3,2.9362,3.540076,2.0561,8.6875


[I 2024-02-03 20:29:18,910] Trial 5 finished with value: 3.5400760173797607 and parameters: {'learning_rate': 0.0002, 'per_device_train_batch_size': 16, 'accumulation_steps': 1, 'epoch': 3}. Best is trial 3 with value: 2.912153482437134.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,4.7596,3.658498,3.7275,10.7083
2,4.1274,3.568081,2.5166,9.1458
3,2.9362,3.540076,2.0561,8.6875


[I 2024-02-03 20:29:27,220] Trial 6 finished with value: 3.5400760173797607 and parameters: {'learning_rate': 0.0002, 'per_device_train_batch_size': 16, 'accumulation_steps': 2, 'epoch': 3}. Best is trial 3 with value: 2.912153482437134.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,5.3558,3.97334,1.3769,13.1042
2,5.0499,3.938274,1.7383,12.9167
3,3.5103,3.926643,1.6025,12.7292


[I 2024-02-03 20:29:35,359] Trial 7 finished with value: 3.9266433715820312 and parameters: {'learning_rate': 2e-05, 'per_device_train_batch_size': 16, 'accumulation_steps': 2, 'epoch': 3}. Best is trial 3 with value: 2.912153482437134.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,5.3558,3.97334,1.3769,13.1042
2,5.0499,3.938274,1.7383,12.9167
3,3.5103,3.926643,1.6025,12.7292


[I 2024-02-03 20:29:43,720] Trial 8 finished with value: 3.9266433715820312 and parameters: {'learning_rate': 2e-05, 'per_device_train_batch_size': 16, 'accumulation_steps': 4, 'epoch': 3}. Best is trial 3 with value: 2.912153482437134.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,5.3558,3.97334,1.3769,13.1042
2,5.0499,3.938274,1.7383,12.9167
3,3.5103,3.926643,1.6025,12.7292


[I 2024-02-03 20:29:52,445] Trial 9 finished with value: 3.9266433715820312 and parameters: {'learning_rate': 2e-05, 'per_device_train_batch_size': 16, 'accumulation_steps': 2, 'epoch': 3}. Best is trial 3 with value: 2.912153482437134.


Best trial:
 Value: 2.912153482437134
 Params: 
    learning_rate: 0.002
    per_device_train_batch_size: 16
    accumulation_steps: 1
    epoch: 3


# End