In [1]:
import os
import sys

# Notebooks have no usable `__file__`, so resolve from the kernel's current working
# directory and put the directory three levels up on `sys.path` (presumably the
# project root, so that `src.shared.config` below is importable — confirm layout).
sys.path.append(
    os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir, os.pardir))
)

In [2]:
import json
from textwrap import dedent
from typing import Dict, Literal

from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextGenerationPipeline,
)

from src.shared.config import settings

## Utils

In [3]:
HF_GEMMA_ID = "google/gemma-3-1b-pt"


def tokenizer_init(model_id: str = HF_GEMMA_ID):
    """Load the tokenizer used to build prompts and decode generations.

    Args:
        model_id: Hugging Face Hub repo id (or local path) to load the tokenizer
            from. Defaults to ``HF_GEMMA_ID``, so existing zero-argument calls are
            unchanged; pass the fine-tuned repo id if its tokenizer differs.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained(model_id)``.
    """
    return AutoTokenizer.from_pretrained(model_id)


def format_dataset_row_for_test(row) -> Dict[Literal["text"], str]:
    """Build the inference prompt for one dataset row with its label hidden.

    The row is copied (the caller's mapping is not mutated), the
    ``"from_account"`` label is removed, and the remaining fields are serialized
    as JSON after an ``input:`` prefix. The prompt ends with a bare ``label:``
    line for the model to complete.

    Args:
        row: Mapping of column name to value for a single example. Must contain
            ``"from_account"``; a ``KeyError`` is raised otherwise.

    Returns:
        ``{"text": prompt}``, used as the return value of ``Dataset.map``.
    """
    features = dict(row)
    del features["from_account"]

    # json.dumps escapes control characters, so the payload stays on one line.
    prompt = "input: " + json.dumps(features) + "\nlabel:"
    return {"text": prompt}


def test_dataset_init() -> DatasetDict:
    """Load the project dataset from the Hub and prepare its test split for inference.

    Every row of the ``"test"`` split gets a ``"text"`` prompt column built by
    ``format_dataset_row_for_test``; the existing columns (including the
    ``"from_account"`` label used for scoring) are kept. The ``"train"`` split
    is dropped since only evaluation data is needed.

    Returns:
        The loaded ``DatasetDict`` without a ``"train"`` split.
    """
    # NOTE(review): the `test_` prefix would make pytest collect this as a test
    # if it is ever moved into a module — consider renaming at that point.
    dataset = load_dataset(f"{settings.hf_user_name}/{settings.hf_dataset_repo_name}")

    dataset["test"] = dataset["test"].map(format_dataset_row_for_test)
    # pop with a default: a dataset repo without a train split must not raise KeyError.
    dataset.pop("train", None)

    return dataset

## Pipeline

In [4]:
class AccountClassifcationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that returns only the predicted ``from_account`` label.

    Prompts are expected to end with ``"label:"`` (as built by
    ``format_dataset_row_for_test``). ``postprocess`` decodes the generated
    sequence and returns the text after that marker, up to ``"<eos>"``, as a
    plain string.
    """

    # Delimiters of the label inside the decoded prompt + generation.
    _LABEL_MARKER = "label:"
    _EOS_TOKEN = "<eos>"

    def _extract_from_account(self, text: str) -> str:
        """Extract the from_account found between 'label:' and '<eos>'.

        Args:
            text: Decoded prompt + generation, with special tokens kept.

        Returns:
            The stripped text after the first ``"label:"`` up to the next
            ``"<eos>"``. If generation stopped before emitting ``"<eos>"``
            (e.g. it hit ``max_new_tokens``), everything to the end of ``text``
            is used. Returns ``""`` when ``"label:"`` is absent.
        """
        marker_idx = text.find(self._LABEL_MARKER)
        if marker_idx == -1:
            # No marker means there is no label to extract; previously this
            # silently sliced from an arbitrary offset.
            return ""
        start_idx = marker_idx + len(self._LABEL_MARKER)

        # Look for EOS only after the marker. str.find returns -1 when EOS is
        # missing, and slicing to -1 would drop the label's last character.
        eos_idx = text.find(self._EOS_TOKEN, start_idx)
        if eos_idx == -1:
            eos_idx = len(text)

        return text[start_idx:eos_idx].strip()

    def postprocess(self, model_outputs, *args, **kwargs):
        """Decode the first generated sequence and return its extracted label.

        Extra positional/keyword arguments passed by the parent pipeline are
        accepted and ignored.
        """
        # NOTE(review): [0][0] presumably selects the first input / first returned
        # sequence; only one sequence per prompt is used — verify against the
        # installed transformers version.
        sequence = model_outputs["generated_sequence"][0][0]
        # Keep special tokens so the "<eos>" delimiter survives decoding.
        decoded = self.tokenizer.decode(sequence, skip_special_tokens=False)
        return self._extract_from_account(decoded)

## Score

In [5]:
def predict_and_score(pipeline: TextGenerationPipeline, dataset: DatasetDict) -> Dict:
    """Run ``pipeline`` over the test split and score exact-match accuracy.

    Args:
        pipeline: Callable mapping a list of prompt strings to a list of
            predicted labels (e.g. ``AccountClassifcationPipeline``).
        dataset: Mapping whose ``"test"`` split exposes a ``"text"`` column
            (prompts) and a ``"from_account"`` column (gold labels).

    Returns:
        ``{"accuracy": float, "prediction_actual_pairs": list of (prediction, actual)}``.
        Accuracy is the fraction of exact matches, and ``0.0`` for an empty split
        (instead of raising ``ZeroDivisionError``).

    Raises:
        ValueError: If the pipeline returns a different number of predictions
            than there are labels (``zip`` would otherwise silently truncate).
    """
    test_split = dataset["test"]
    # Materialize as lists: the pipeline receives a plain list of prompts, and
    # len()/zip below work even if a lazy column or a generator is returned.
    texts = list(test_split["text"])
    actual = list(test_split["from_account"])
    predictions = list(pipeline(texts)) if texts else []

    if len(predictions) != len(actual):
        raise ValueError(
            f"pipeline returned {len(predictions)} predictions for {len(actual)} labels"
        )

    prediction_actual_pairs = list(zip(predictions, actual))
    n_correct = sum(pred == act for pred, act in prediction_actual_pairs)
    accuracy = (
        n_correct / len(prediction_actual_pairs) if prediction_actual_pairs else 0.0
    )
    return {"accuracy": accuracy, "prediction_actual_pairs": prediction_actual_pairs}


# Generation budget for the label continuation.
# NOTE(review): labels longer than this many tokens get truncated — confirm it
# covers the longest account name.
MAX_NEW_TOKENS = 20

# NOTE(review): the tokenizer comes from the base checkpoint (HF_GEMMA_ID), not the
# fine-tuned repo — confirm the fine-tune did not change the vocabulary.
tokenizer = tokenizer_init()
model = AutoModelForCausalLM.from_pretrained(
    f"{settings.hf_user_name}/{settings.gemma_3_1b_causal_finetune_lora}"
)
pipeline = AccountClassifcationPipeline(
    task="text-generation",
    model=model,
    device=settings.device,
    tokenizer=tokenizer,
    max_new_tokens=MAX_NEW_TOKENS,
    # Greedy decoding keeps the reported accuracy reproducible, independent of
    # any sampling settings in the checkpoint's generation config.
    do_sample=False,
)

dataset = test_dataset_init()
output = predict_and_score(pipeline, dataset)
output["accuracy"]

Device set to use mps


0.6683937823834197

In [7]:
# Display every (prediction, actual) pair from the scoring run to see which labels get confused.
output["prediction_actual_pairs"]

[('Assets:Discover:Main:HomeImprovement:Other',
  'Assets:Discover:Main:Needs:Other'),
 ('Assets:Discover:Main:Needs:Other', 'Assets:Discover:FutureWants'),
 ('Assets:Discover:Main:Wants:Other', 'Assets:Discover:Main:Wants:Other'),
 ('Assets:Discover:Main:Needs:Groceries',
  'Assets:Discover:Main:Needs:Groceries'),
 ('Assets:Discover:Main:Wants:Other', 'Assets:Discover:Main:Wants:Monthly'),
 ('Assets:Discover:Main:Needs:Groceries', 'Assets:Discover:Main:Needs:Other'),
 ('Assets:Discover:Main:Wants:Other', 'Assets:Discover:Main:Wants:Other'),
 ('Assets:Discover:Main:Needs:Groceries',
  'Assets:Discover:Main:Needs:Groceries'),
 ('Assets:Discover:Main:Wants:Other', 'Assets:Discover:Main:Wants:Other'),
 ('Assets:Discover:Main:Wants:Groceries',
  'Assets:Discover:Main:Needs:Groceries'),
 ('Assets:Discover:Main:Wants:Other', 'Assets:Discover:Main:Wants:Other'),
 ('Assets:Discover:FutureWants', 'Assets:Discover:Main:Wants:Other'),
 ('Assets:Discover:Main:Needs:Gas', 'Assets:Discover:Main:Need

It is outputting some categories that don't actually exist, e.g. 'Assets:Discover:Main:HomeImprovement:Other'. I think better prompting could improve this model.

In [None]:
# Spot-check one example end to end: the prompt, its gold label, and the prediction.
dataset = test_dataset_init()

last_row = dataset["test"][-1]
example = last_row["text"]
actual = last_row["from_account"]
prediction = pipeline(example)

print(example, actual, prediction, sep="\n\n")

input: {"transaction_date": "2025-03-18", "description": "MCDONALD'S F33943", "amount": 7.9, "category": "Food & Drink", "category_source": "Chase", "card": "Chase Sapphire Preferred", "day_of_week": "Tuesday"}
label:

Assets:Discover:Main:Wants:Other

Assets:Discover:Main:Wants:Other
