In [86]:
import os
import sys

# Notebooks do not define ``__file__``; the *string* "__file__" is resolved
# against the current working directory, so its dirname is simply the
# directory the kernel was started in.
_notebook_dir = os.path.dirname(os.path.abspath("__file__"))

# Three levels up from the notebook directory is appended to ``sys.path`` so
# that ``src.*`` imports resolve (presumably the repository root -- confirm).
_import_root = os.path.abspath(os.path.join(_notebook_dir, "../../.."))
sys.path.append(_import_root)

In [87]:
from functools import partial
from textwrap import dedent
from typing import Dict, Literal

from datasets import DatasetDict, load_dataset, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments

from src.shared.config import settings

## Utils

In [88]:
# Identifier of the base GPT-1 checkpoint; tokenizer_init loads its tokenizer.
HF_GPT_ID = "openai-gpt"
# "<user>/<repo>" identifier of the fine-tuned sequence-classification
# checkpoint that model_init loads; both parts come from project settings.
FINETUNED_HF_GPT_ID = "{}/{}".format(
    settings.hf_user_name, settings.gpt_1_sequence_classification
)


def tokenizer_init():
    """Load the base GPT-1 tokenizer and register a padding token on it.

    Returns:
        The tokenizer loaded from ``HF_GPT_ID`` with ``"<pad>"`` set as its
        ``pad_token``.
    """
    # The gpt-1 tokenizer lacks a pad token by default, so one is added here.
    pad_token_spec = {"pad_token": "<pad>"}
    gpt_tokenizer = AutoTokenizer.from_pretrained(HF_GPT_ID)
    gpt_tokenizer.add_special_tokens(pad_token_spec)
    return gpt_tokenizer


def model_init():
    """Load the fine-tuned classifier and make it aware of the pad token.

    Returns:
        The sequence-classification model loaded from ``FINETUNED_HF_GPT_ID``,
        with its embeddings resized to the tokenizer length, its
        ``config.pad_token_id`` set, and moved to ``settings.device``.
    """
    pad_aware_tokenizer = tokenizer_init()
    # NOTE(review): 8 presumably equals the number of distinct account
    # labels in the dataset -- confirm it stays in sync with the data.
    classifier = AutoModelForSequenceClassification.from_pretrained(
        FINETUNED_HF_GPT_ID,
        num_labels=8,
    )
    # Extend the embedding layer so it can handle the added padding token.
    classifier.resize_token_embeddings(
        len(pad_aware_tokenizer),
        mean_resizing=False,
    )
    classifier.config.pad_token_id = pad_aware_tokenizer.pad_token_id
    return classifier.to(settings.device)


def format_dataset_row(
    row, label_mapping: Dict[str, int]
) -> Dict[Literal["text", "labels"], object]:
    """Render one transaction row as a prompt and attach its integer label.

    The prompt is assembled line by line rather than by dedenting an
    interpolated triple-quoted string: dedenting *after* interpolation breaks
    as soon as any field value contains a newline, because that unindented
    continuation line reduces the common margin to zero and leaves every
    other line indented.

    Args:
        row: Mapping with the keys ``from_account``, ``description``,
            ``amount``, ``category``, ``category_source``,
            ``transaction_date``, ``day_of_week`` and ``card``. It is not
            modified.
        label_mapping: Maps each ``from_account`` value to its class index.

    Returns:
        A dict with ``"text"`` (the formatted prompt, ``str``) and
        ``"labels"`` (the class index for the row's ``from_account``).

    Raises:
        KeyError: If a required key is missing from ``row`` or the row's
            ``from_account`` is not present in ``label_mapping``.
    """
    from_account = row["from_account"]
    prompt_lines = (
        "Transaction",
        "-----------",
        f"Description: {row['description']}",
        f"Amount: {row['amount']}",
        f"Category: {row['category']} (Source: {row['category_source']})",
        f"Transaction Date: {row['transaction_date']}",
        f"Day of Week: {row['day_of_week']}",
        f"Card: {row['card']}",
        "",
        "Question: Which account initiated this transaction?",
        "Answer:",
    )
    formatted = "\n".join(prompt_lines)
    return {"text": formatted, "labels": label_mapping[from_account]}

def test_dataset_init(tokenizer) -> DatasetDict:
    """Load the dataset and return its tokenized test split.

    The label mapping is derived from the ``from_account`` values of both the
    train and test splits (sorted, then enumerated) and printed. Only the
    test split is formatted with ``format_dataset_row``; the train split is
    dropped before tokenization.

    Args:
        tokenizer: Callable applied to each batch of ``"text"`` values.

    Returns:
        A ``DatasetDict`` holding only ``"test"``, with the raw transaction
        columns listed below removed and the output format set to ``"pt"``.
        ``from_account`` is not in the removal list, so it is kept.
    """
    repo_id = f"{settings.hf_user_name}/{settings.hf_dataset_repo_name}"
    splits = load_dataset(repo_id)

    # Collect every account seen in either split so label indices are stable.
    accounts = set(splits["train"]["from_account"])
    accounts.update(splits["test"]["from_account"])
    label_mapping = {name: index for index, name in enumerate(sorted(accounts))}
    print(f'label_mapping: {label_mapping}')

    splits["test"] = splits["test"].map(
        partial(format_dataset_row, label_mapping=label_mapping)
    )
    del splits["train"]

    raw_columns = [
        'transaction_date',
        'description',
        'amount',
        'category',
        'category_source',
        'card',
        'day_of_week',
    ]
    splits = splits.map(
        lambda batch: tokenizer(batch["text"]),
        batched=True,
        remove_columns=raw_columns,
    )
    splits.set_format("pt")
    return splits

## Score

In [91]:
# Build the tokenizer once and share it between the padding collator and the
# dataset tokenization instead of reloading it for each of them.
tokenizer = tokenizer_init()

# Evaluation-only configuration: no experiment reporting.
training_args = TrainingArguments(
    output_dir="/tmp/eval",
    per_device_eval_batch_size=64,
    report_to="none",
)

trainer = Trainer(
    model=model_init(),
    args=training_args,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)

dataset = test_dataset_init(tokenizer)

predictions = trainer.predict(dataset["test"])
actual = predictions.label_ids
predicted = predictions.predictions.argmax(axis=-1)

# Accuracy = fraction of test rows whose arg-max class equals the label.
accuracy = (predicted == actual).mean()
accuracy.item()

label_mapping: {'Assets:Discover:Furniture': 0, 'Assets:Discover:FutureWants': 1, 'Assets:Discover:Main:Needs:Gas': 2, 'Assets:Discover:Main:Needs:Groceries': 3, 'Assets:Discover:Main:Needs:Monthly': 4, 'Assets:Discover:Main:Needs:Other': 5, 'Assets:Discover:Main:Wants:Monthly': 6, 'Assets:Discover:Main:Wants:Other': 7}


0.927461139896373

## Example

In [92]:
instance = 0
example = dataset["test"][instance]

# Wrap the single example in its own Dataset so trainer.predict can consume it.
instance_dataset = Dataset.from_dict({
    'input_ids': [example['input_ids']],
    'attention_mask': [example['attention_mask']],
    'labels': [example['labels']]
})

# Rebuild the label mapping from BOTH splits, the same way test_dataset_init
# does when it encodes the labels. Deriving it from the test split alone would
# shift the indices whenever an account appears only in the train split, and
# the decoded account names would then be wrong (or raise KeyError).
raw_dataset = load_dataset(f"{settings.hf_user_name}/{settings.hf_dataset_repo_name}")
unique_from_accounts = sorted(
    set(raw_dataset["train"]["from_account"])
    | set(raw_dataset["test"]["from_account"])
)
label_mapping = {account: idx for idx, account in enumerate(unique_from_accounts)}
reverse_mapping = {idx: account for account, idx in label_mapping.items()}

# Keep the raw prediction output under its own name; ``prediction`` below is
# the decoded account name.
prediction_output = trainer.predict(instance_dataset)
predicted_class = prediction_output.predictions.argmax(axis=-1)[0]

actual = reverse_mapping[example['labels'].item()]
prediction = reverse_mapping[predicted_class]

print(example["text"], actual, prediction, sep="\n\n")

Transaction
-----------
Description: O DONELL ACE HARDWARE DES MOINES IA
Amount: 23.19
Category: Home Improvement (Source: Discover)
Transaction Date: 2023-03-22
Day of Week: Wednesday
Card: Discover It Chrome

Question: Which account initiated this transaction?
Answer:

Assets:Discover:Main:Needs:Other

Assets:Discover:Main:Needs:Other
