In [1]:
import os
import sys

sys.path.append(
    os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath("__file__")), "../../..")
    )
)

In [2]:
import json
from functools import partial
from textwrap import dedent
from typing import Dict, Literal

from datasets import DatasetDict, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
from transformers.generation.stopping_criteria import (
    StoppingCriteria,
    StoppingCriteriaList,
)

from src.shared.config import settings

## Utils

In [3]:
HF_GPT_ID = "openai-gpt"


def tokenizer_init():
    tokenizer = AutoTokenizer.from_pretrained(HF_GPT_ID)
    tokenizer.add_special_tokens(
        {"pad_token": "<pad>", "eos_token": "<eos>"}
    )  # gpt-1 tokenizer lacks these by default
    return tokenizer


def model_init():
    tokenizer = tokenizer_init()
    model = AutoModelForCausalLM.from_pretrained(HF_GPT_ID)
    model.resize_token_embeddings(
        len(tokenizer), mean_resizing=False
    )  # extend the embedding layer to handle padding and eos tokens
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    return model.to(settings.device)


def format_dataset_row(
    row, for_test: bool, unique_from_accounts: set, eos_token: str
) -> Dict[Literal["text"], str]:
    row = dict(row)
    from_account = row.pop("from_account")
    formatted = dedent(
        f"""
        <possible accounts>{', '.join(unique_from_accounts)}</possible accounts>
        <transaction>{json.dumps(row)}</transaction>
        which account did this transaction come from?
        answer: {'' if for_test else f"{from_account}{eos_token}" }
    """
    ).strip()
    return {"text": formatted}


def test_dataset_init(tokenizer) -> DatasetDict:
    dataset = load_dataset(f"{settings.hf_user_name}/{settings.hf_dataset_repo_name}")

    unique_from_accounts = set(
        dataset["train"]["from_account"] + dataset["test"]["from_account"]
    )
    eos_token = tokenizer.eos_token
    format_for_test = partial(
        format_dataset_row,
        for_test=True,
        unique_from_accounts=unique_from_accounts,
        eos_token=eos_token,
    )

    dataset["test"] = dataset["test"].map(format_for_test)
    del dataset["train"]

    return dataset

## Pipeline

In [4]:
class StopOnEOSToken(StoppingCriteria):
    """Stop generation when the EOS token is generated."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        super().__init__()

    def __call__(self, input_ids, *args, **kwargs):
        if input_ids[0, -1] == self.tokenizer.eos_token_id:
            return True
        return False


class RawTextGenerationPipeline(TextGenerationPipeline):
    def _extract_from_account(self, text: str) -> str:
        """Extract the from_account found between 'answer:' and the first occurrence of '<eos>' or '<pad>'."""
        answer_search = "answer :"
        answer_start = text.find(answer_search)
        if answer_start == -1:
            return ""
        answer_end = answer_start + len(answer_search)

        # Find indices of <eos> and <pad> after answer_end
        eos_index = text.find("<eos>", answer_end)
        pad_index = text.find("<pad>", answer_end)

        # Default to end of text if not found
        eos_index = eos_index if eos_index != -1 else len(text)
        pad_index = pad_index if pad_index != -1 else len(text)

        end_index = min(eos_index, pad_index)
        account = text[answer_end:end_index].replace(" ", "").strip()

        def title_case_account(account_str: str) -> str:
            mapping = {
                "assets:discover:furniture": "Assets:Discover:Furniture",
                "assets:discover:main:needs:other": "Assets:Discover:Main:Needs:Other",
                "assets:discover:main:wants:monthly": "Assets:Discover:Main:Wants:Monthly",
                "assets:discover:main:wants:other": "Assets:Discover:Main:Wants:Other",
                "assets:discover:main:needs:groceries": "Assets:Discover:Main:Needs:Groceries",
                "assets:discover:main:needs:gas": "Assets:Discover:Main:Needs:Gas",
                "assets:discover:futurewants": "Assets:Discover:FutureWants",
                "assets:discover:travel": "Assets:Discover:Travel",
                "assets:discover:main:needs:monthly": "Assets:Discover:Main:Needs:Monthly",
            }
            return mapping.get(
                account_str,
                ":".join(word.capitalize() for word in account_str.split(":")),
            )

        return title_case_account(account)

    def postprocess(self, model_outputs, *args, **kwargs):
        decoded = self.tokenizer.decode(model_outputs["generated_sequence"][0][0])
        account = self._extract_from_account(decoded)
        return account

## Score

In [5]:
def predict_and_score(pipeline: TextGenerationPipeline, dataset: DatasetDict) -> float:
    predictions = pipeline(dataset["test"]["text"])
    actual = dataset["test"]["from_account"]
    return sum(1 for pred, act in zip(predictions, actual) if pred == act) / len(
        predictions
    )


tokenizer = tokenizer_init()
model = AutoModelForCausalLM.from_pretrained(
    f"{settings.hf_user_name}/{settings.gpt_1_causal_finetune}"
)
pipeline = RawTextGenerationPipeline(
    task="text-generation",
    model=model,
    device=settings.device,
    tokenizer=tokenizer,
    stopping_criteria=StoppingCriteriaList([StopOnEOSToken(tokenizer)]),
    max_new_tokens=50,
)

dataset = test_dataset_init(tokenizer)

predict_and_score(pipeline, dataset)

Device set to use mps


0.7433155080213903

In [6]:
example = dataset["test"]["text"][21]
actual = dataset["test"]["from_account"][21]
prediction = pipeline(example)

print(example, actual, prediction, sep="\n\n")

<possible accounts>Assets:Discover:Main:Needs:Groceries, Assets:Discover:Main:Needs:Monthly, Assets:Discover:Main:Wants:Monthly, Assets:Discover:Main:Needs:Other, Assets:Discover:Travel, Assets:Discover:FutureWants, Assets:Discover:Main:Wants:Other, Assets:Discover:Main:Needs:Gas, Assets:Discover:Furniture</possible accounts>
<transaction>{"amount": 89.99, "month": 6, "day": 4, "year": 2023, "vendor": "WALMART"}</transaction>
which account did this transaction come from?
answer:

Assets:Discover:Main:Needs:Groceries

Assets:Discover:Main:Needs:Groceries
