In [1]:
import os
import sys

sys.path.append(
    os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath("__file__")), "../../..")
    )
)

In [2]:
import json
from functools import partial
from textwrap import dedent
from typing import Dict, Literal

from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

from src.shared.config import settings

## Utils

In [3]:
HF_GEMMA_ID = "google/gemma-3-1b-pt"


def tokenizer_init():
    tokenizer = AutoTokenizer.from_pretrained(HF_GEMMA_ID)
    return tokenizer


def model_init():
    model = AutoModelForCausalLM.from_pretrained(HF_GEMMA_ID)
    return model.to(settings.device)


def format_dataset_row(
    row, for_test: bool, unique_from_accounts: set, eos_token: str
) -> Dict[Literal["text"], str]:
    row = dict(row)
    from_account = row.pop("from_account")
    formatted = dedent(
        f"""
        <possible accounts>{', '.join(unique_from_accounts)}</possible accounts>
        <transaction>{json.dumps(row)}</transaction>
        which account did this transaction come from?
        answer: {'' if for_test else f"{from_account}{eos_token}"}
    """
    ).strip()
    return {"text": formatted}


def training_dataset_init(tokenizer) -> DatasetDict:
    dataset = load_dataset(f"{settings.hf_user_name}/{settings.hf_dataset_repo_name}")

    unique_from_accounts = set(
        dataset["train"]["from_account"] + dataset["test"]["from_account"]
    )
    eos_token = tokenizer.eos_token
    format_for_train = partial(
        format_dataset_row,
        for_test=False,
        unique_from_accounts=unique_from_accounts,
        eos_token=eos_token,
    )

    dataset["train"] = dataset["train"].map(format_for_train)
    dataset["validation"] = dataset["test"].map(format_for_train)
    del dataset["test"]

    remove_columns = [
        "amount",
        "month",
        "day",
        "year",
        "vendor",
        "from_account",
        "text",
    ]
    dataset = dataset.map(
        lambda batch: tokenizer(batch["text"]),
        batched=True,
        remove_columns=remove_columns,
    )
    dataset.set_format("pt")
    return dataset

In [4]:
dataset = training_dataset_init(tokenizer_init())

## Identify Best Learning Rate

In [None]:
def identify_best_learning_rate():
    learning_rate_train_dataset = dataset["train"].select(range(64))
    learning_rate_validation_dataset = dataset["validation"].select(range(64))

    training_args = TrainingArguments(
        output_dir="/tmp/lr_search",
        num_train_epochs=1,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        weight_decay=0.01,
        eval_strategy="epoch",
        logging_strategy="epoch",
        disable_tqdm=False,
        push_to_hub=False,
        log_level="error",
        save_strategy="no",
        report_to="none",
    )

    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        data_collator=DataCollatorForLanguageModeling(
            tokenizer=tokenizer_init(), mlm=False
        ),
        train_dataset=learning_rate_train_dataset,
        eval_dataset=learning_rate_validation_dataset,
    )

    best_run = trainer.hyperparameter_search(
        hp_space=lambda trial: {
            "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
        },
        n_trials=5,
        direction="minimize",
    )

    return best_run


identify_best_learning_rate()

## Training Run

In [5]:
n_epochs = 1
best_learning_rate = 4.9437983806413086e-05

tokenizer = tokenizer_init()
dataset = training_dataset_init(tokenizer)

training_args = TrainingArguments(
    output_dir="/tmp/gemma_3_1b_causal_finetune",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=n_epochs,
    learning_rate=best_learning_rate,
    weight_decay=0.01,
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    push_to_hub=True,
    hub_model_id=f"{settings.hf_user_name}/{settings.gemma_3_1b_causal_finetune}",
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

trainer.train()

It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.


Epoch,Training Loss,Validation Loss
1,0.4075,0.165474


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


TrainOutput(global_step=24, training_loss=0.40745850404103595, metrics={'train_runtime': 420.5645, 'train_samples_per_second': 1.767, 'train_steps_per_second': 0.057, 'total_flos': 517877961667584.0, 'train_loss': 0.40745850404103595, 'epoch': 1.0})

This was bumping up to around 120 GB of ram.

In [6]:
trainer.push_to_hub()

CommitInfo(commit_url='https://huggingface.co/jacob-danner/gemma_3_1b_causal_finetune/commit/d57dbe313855fe90e612cf39f5b4185a570e88b8', commit_message='End of training', commit_description='', oid='d57dbe313855fe90e612cf39f5b4185a570e88b8', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jacob-danner/gemma_3_1b_causal_finetune', endpoint='https://huggingface.co', repo_type='model', repo_id='jacob-danner/gemma_3_1b_causal_finetune'), pr_revision=None, pr_num=None)