https://docs.wandb.ai/guides/sweeps/define-sweep-configuration

https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch/Organizing_Hyperparameter_Sweeps_in_PyTorch_with_W%26B.ipynb

https://huggingface.co/docs/transformers/hpo_train

In [None]:
# imports
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
from datasets import load_dataset
import evaluate
import numpy as np

In [None]:
# init vars
# Scorer used by compute_metrics to report BLEU.
metric = evaluate.load("sacrebleu")
# Train/eval corpora are read from local files through the 'pandas' builder.
# No split is requested here, so later cells select the data with ['train'].
train_dataset = load_dataset('pandas', data_files='/home/j/Documents/Projects/MLotsawa/data/size-selection-data/1M-train.p')
eval_dataset = load_dataset('pandas', data_files='/home/j/Documents/Projects/MLotsawa/data/size-selection-data/100k-eval.p')
# Name of the pretrained checkpoint the tokenizer is loaded from.
checkpoint = "google-t5/t5-large"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# NOTE(review): `model` receives the checkpoint *name* (a str), not a loaded
# model object — confirm this is intended.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [None]:
# preproc

# Keys used to pull the source and target sentence out of each
# 'translation' record in preprocess_function.
source_lang = 'bo'
target_lang = 'en'
# Task prefix prepended to every source sentence before tokenization.
prefix = "translate Tibetan to English: "

def preprocess_function(examples):
    """Tokenize a batch of translation records.

    Every record in ``examples['translation']`` is indexed by ``source_lang``
    and ``target_lang``. The source side gets ``prefix`` prepended; the target
    side is passed to the tokenizer as ``text_target``. Both sides are
    truncated to at most 128 tokens.

    Returns whatever the tokenizer call returns for the batch.
    """
    sources = []
    references = []
    for record in examples['translation']:
        sources.append(prefix + record[source_lang])
        references.append(record[target_lang])

    return tokenizer(
        sources,
        text_target=references,
        max_length=128,
        truncation=True,
    )

def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded predictions and references.

    Returns a ``(preds, labels)`` tuple where ``preds`` is a flat list of
    strings and each label is wrapped in its own single-element list.
    """
    return (
        [hypothesis.strip() for hypothesis in preds],
        [[reference.strip()] for reference in labels],
    )

def compute_metrics(eval_preds):
    """Compute BLEU and mean generated length for an evaluation pass.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays. If
            ``predictions`` is a tuple, its first element is used.

    Returns:
        dict with ``"bleu"`` (the sacreBLEU ``score``) and ``"gen_len"`` (mean
        number of non-pad tokens per prediction), both rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id

    # -100 is an ignore/padding marker, not a vocabulary id, so it cannot be
    # decoded. It can show up in predictions as well as in labels, so both
    # arrays are mapped to the pad token before decoding.
    preds = np.where(np.asarray(preds) != -100, preds, pad_id)
    labels = np.where(np.asarray(labels) != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Length is counted after the -100 replacement so padding never counts
    # as generated tokens.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

# Tokenize both corpora; batched=True means preprocess_function receives a
# batch of examples per call rather than a single example.
tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True)
tokenized_eval_dataset = eval_dataset.map(preprocess_function, batched=True)

In [None]:
# Baseline training configuration handed to the trainer.
training_args = Seq2SeqTrainingArguments(
    output_dir=".",  # outputs are written to the current working directory
    learning_rate=2e-5,
    auto_find_batch_size=True,
    weight_decay=0.01,
    num_train_epochs=3,
    predict_with_generate=True,  # compute_metrics decodes generated token ids
    fp16=False,
    push_to_hub=False,
    eval_strategy='epoch', # CHANGE THESE !!!
    save_strategy='epoch',  # kept identical to eval_strategy
    load_best_model_at_end=True,
    optim="adafactor"
)

# No concrete model is passed: the trainer builds one through `model_init`,
# which is also what `hyperparameter_search` requires so that every trial
# starts from freshly loaded weights. Passing `model=None` without
# `model_init` makes the constructor raise.
# NOTE: the cell that defines `model_init` must be executed before this one.
trainer = Seq2SeqTrainer(
    model=None,
    model_init=model_init,
    args=training_args,
    # The datasets were loaded without a split, so the data sits under 'train'.
    train_dataset=tokenized_train_dataset['train'],
    eval_dataset=tokenized_eval_dataset['train'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback()]
)

In [None]:
def model_init(trial):
    """Load and return a fresh pretrained ``google-t5/t5-large`` seq2seq model.

    ``trial`` is accepted but not used by this function. The model is loaded
    with ``device_map="auto"``.
    """
    pretrained_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google-t5/t5-large",
        device_map="auto",
    )
    return pretrained_model

In [None]:
# Example sweep definition: Bayesian search minimizing "validation_loss" over
# learning rate, batch size, epoch count and optimizer choice.
sweep_configuration = dict(
    name="sweepdemo",
    method="bayes",
    metric=dict(goal="minimize", name="validation_loss"),
    parameters=dict(
        learning_rate=dict(min=0.0001, max=0.1),
        batch_size=dict(values=[16, 32, 64]),
        epochs=dict(values=[5, 10, 15]),
        optimizer=dict(values=["adam", "sgd"]),
    ),
)

In [None]:
def wandb_hp_space(trial):
    """Return the sweep configuration used as the hyperparameter search space.

    ``trial`` is accepted but not used. A new dictionary is built on every
    call: random search, a metric named "objective" to minimize, a uniform
    learning-rate range and a discrete set of per-device train batch sizes.
    """
    objective_metric = {"name": "objective", "goal": "minimize"}
    search_parameters = {
        "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
        "per_device_train_batch_size": {"values": [16, 32, 64, 128]},
    }
    return {
        "method": "random",
        "metric": objective_metric,
        "parameters": search_parameters,
    }

In [None]:
# Run the hyperparameter search. `wandb_hp_space` returns a W&B sweep
# configuration ("method" / "metric" / "parameters"), so the search has to use
# the "wandb" backend; the optuna backend expects a flat mapping of
# TrainingArguments field names to sampled values and rejects this layout.
# NOTE(review): `compute_objective` is not defined in the cells shown here —
# it must be defined before this cell is executed.
best_trial = trainer.hyperparameter_search(
    direction="maximize",
    backend="wandb",
    hp_space=wandb_hp_space,
    n_trials=20,
    compute_objective=compute_objective,
)