In [26]:
import os
import sys

sys.path.append(
    os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath("__file__")), "../../..")
    )
)

In [27]:
import json
from functools import partial
from textwrap import dedent
from typing import Dict, Literal

from datasets import DatasetDict, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline, pipeline
from transformers.generation.stopping_criteria import (
    StoppingCriteria,
    StoppingCriteriaList,
)

from src.shared.config import settings

## Utils

In [28]:
HF_GEMMA_ID = "google/gemma-3-1b-pt"

def tokenizer_init():
    tokenizer = AutoTokenizer.from_pretrained(HF_GEMMA_ID)
    return tokenizer


def model_init():
    model = AutoModelForCausalLM.from_pretrained(HF_GEMMA_ID)
    return model.to(settings.device)


def format_dataset_row(
    row, for_test: bool, unique_from_accounts: set, eos_token: str
) -> Dict[Literal["text"], str]:
    row = dict(row)
    from_account = row.pop("from_account")
    formatted = dedent(
        f"""
        <possible accounts>{', '.join(unique_from_accounts)}</possible accounts>
        <transaction>{json.dumps(row)}</transaction>
        which account did this transaction come from?
        answer: {'' if for_test else f"{from_account}{eos_token}"}
    """
    ).strip()
    return {"text": formatted}


def test_dataset_init(tokenizer) -> DatasetDict:
    dataset = load_dataset(f"{settings.hf_user_name}/{settings.hf_dataset_repo_name}")

    unique_from_accounts = set(
        dataset["train"]["from_account"] + dataset["test"]["from_account"]
    )
    eos_token = tokenizer.eos_token
    format_for_test = partial(
        format_dataset_row,
        for_test=True,
        unique_from_accounts=unique_from_accounts,
        eos_token=eos_token,
    )

    dataset["test"] = dataset["test"].map(format_for_test)
    del dataset["train"]

    return dataset

## Pipeline

In [29]:
class StopOnEOSToken(StoppingCriteria):
    """Stop generation when the EOS token is generated."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        super().__init__()

    def __call__(self, input_ids, *args, **kwargs):
        if input_ids[0, -1] == self.tokenizer.eos_token_id:
            return True
        return False


class RawTextGenerationPipeline(TextGenerationPipeline):
    def _extract_from_account(self, text: str) -> str:
        """Extract the from_account found between 'answer:' and the first occurrence of '<eos>' or '<pad>'."""
        # Find start of account name after 'answer:'
        start_idx = text.find('answer:') + len('answer:')
        if start_idx == -1:
            return ''
        
        # Get substring after 'answer:'
        text_after_answer = text[start_idx:].strip()
        
        # Find first occurrence of eos or pad token
        eos_idx = text_after_answer.find('<eos>')
        pad_idx = text_after_answer.find('<pad>')
        
        # Get earliest token occurrence, or use full string if neither found
        end_idx = min(idx for idx in [eos_idx, pad_idx] if idx != -1) if eos_idx != -1 or pad_idx != -1 else len(text_after_answer)
            
        return text_after_answer[:end_idx].strip()

    def postprocess(self, model_outputs, *args, **kwargs):
        decoded = self.tokenizer.decode(model_outputs["generated_sequence"][0][0])
        account = self._extract_from_account(decoded)
        return account

## Score

In [31]:
def predict_and_score(pipeline: TextGenerationPipeline, dataset: DatasetDict) -> float:
    predictions = pipeline(dataset["test"]["text"])
    actual = dataset["test"]["from_account"]
    return sum(1 for pred, act in zip(predictions, actual) if pred == act) / len(
        predictions
    )


tokenizer = tokenizer_init()
model = AutoModelForCausalLM.from_pretrained(
    f"{settings.hf_user_name}/{settings.gemma_3_1b_causal_finetune}"
)
pipeline = RawTextGenerationPipeline(
    task="text-generation",
    model=model,
    device=settings.device,
    tokenizer=tokenizer,
    stopping_criteria=StoppingCriteriaList([StopOnEOSToken(tokenizer)]),
    max_new_tokens=50,
)

dataset = test_dataset_init(tokenizer)

predict_and_score(pipeline, dataset)

Device set to use mps


0.44919786096256686

In [33]:
example = dataset["test"]["text"][-1]
actual = dataset["test"]["from_account"][-1]
prediction = pipeline(example)

print(example, actual, prediction, sep="\n\n")

<possible accounts>Assets:Discover:Travel, Assets:Discover:Main:Needs:Monthly, Assets:Discover:FutureWants, Assets:Discover:Main:Needs:Groceries, Assets:Discover:Furniture, Assets:Discover:Main:Wants:Monthly, Assets:Discover:Main:Wants:Other, Assets:Discover:Main:Needs:Gas, Assets:Discover:Main:Needs:Other</possible accounts>
<transaction>{"amount": 31.29, "month": 1, "day": 4, "year": 2025, "vendor": "CASEY'S"}</transaction>
which account did this transaction come from?
answer:

Assets:Discover:Main:Needs:Gas

Assets:Discover:Main:Wants:Other
