In [None]:
import json
import os
from pathlib import Path

import yaml
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from tqdm.auto import tqdm

from tax_form.dataset import Dataset, Sample
from tax_form.steps.continuation_clf import classify_continuation_page
from tax_form.steps.first_page_clf import classify_form_start

# Load variables from a local .env file (if present) into the process
# environment; the model config below looks its API key up from there.
load_dotenv()

# Relative to the notebook's working directory (presumably a folder that sits
# next to data/ and configs/ — TODO confirm if running from elsewhere).
DATA_PROCESSED_DIR = Path("..", "data", "processed")
MODEL_CONFIG_PATH = Path("..", "configs", "model_openai.yaml")

In [None]:
# Two views of the same processed data. The base64 variant (`dataset`) is the
# one fed to the classification pipeline below; the PIL variant
# (`dataset_pil`) is kept only for viewing page images while debugging.
dataset, dataset_pil = (
    Dataset(data_dir=DATA_PROCESSED_DIR, img_mode=mode).load()
    for mode in ("base64", "pil")
)

In [None]:
# Model settings come from YAML. The `api_key` entry holds the *name* of an
# environment variable, not the secret itself; it is resolved here so the
# key never has to live in the config file.
with open(MODEL_CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

api_key_env_var = config["api_key"]
config["api_key"] = os.getenv(api_key_env_var)
if config["api_key"] is None:
    # Fail fast with an actionable message instead of handing `None` to the
    # chat model and getting an opaque auth error on the first request.
    raise RuntimeError(
        f"Environment variable {api_key_env_var!r} (named by 'api_key' in "
        f"{MODEL_CONFIG_PATH}) is not set; export it or add it to your .env file."
    )

llm = init_chat_model(**config)

In [None]:
# TODO: async, proper restart, log to db instead of file


def write_jsonl(data, filepath):
    """Append ``data`` to ``filepath`` as one JSON-encoded line.

    Used as a simple progress log so results are not lost if a run is
    interrupted. The file is opened in append mode on every call, so records
    from earlier calls (and earlier runs) are preserved.
    """
    with open(filepath, "a+") as sink:
        # print() terminates the record with "\n", giving one object per line.
        print(json.dumps(data), file=sink)


def classification_pipeline(
    sample: Sample, output_dir: str = "../data/output", model=None
):
    """Find all specified tax forms in the sample pages.

    Pages are scanned in order. Each page is first checked as a possible form
    start; once a start is found, the following pages are checked as
    continuations of that form until one fails the check. The failing page is
    then re-examined as a potential start of the next form.

    Every classification is appended to ``<output_dir>/<sample_name>.jsonl``.
    The file is never truncated, so re-running on the same sample appends to
    the existing log.

    Args:
        sample: Sample whose ``pages`` are classified.
        output_dir: Directory (str or Path) for the per-sample JSONL log;
            created if missing.
        model: Chat model passed to the classifiers. Defaults to the
            notebook-level ``llm`` when omitted.

    Returns:
        List of dicts ``{"form_type", "start_page", "end_page"}``, where the
        page values are 0-based, inclusive indices into ``sample.pages``.
    """
    model = llm if model is None else model

    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    log_path = output_dir / f"{sample.sample_name}.jsonl"

    pages = sample.pages
    n_pages = len(pages)
    forms = []
    with tqdm(total=n_pages, desc="Processing") as pbar:
        i = 0
        while i < n_pages:
            scan_start = i
            res = classify_form_start(pages[i], model)
            write_jsonl({"page_index": i, "classification": res}, filepath=log_path)

            if res == "other":
                i += 1
            else:
                form = {"form_type": res, "start_page": i, "end_page": i}
                i += 1
                while i < n_pages:
                    cont_res = classify_continuation_page(
                        current_page=pages[i],
                        previous_page=pages[i - 1],
                        start_page=pages[form["start_page"]],
                        llm=model,
                    )
                    write_jsonl(
                        {"page_index": i, "continuation_classification": cont_res},
                        filepath=log_path,
                    )
                    if not cont_res:
                        # Not a continuation: leave `i` here so this page is
                        # re-checked as a new form start on the next pass.
                        break
                    form["end_page"] = i
                    i += 1
                forms.append(form)
                write_jsonl({"classified_form": form}, filepath=log_path)

            # A form may consume several pages in one outer iteration; advance
            # the bar by the pages actually consumed so it reaches `total`.
            pbar.update(i - scan_start)
    return forms

In [None]:
# Classify every sample and keep the results of all of them in memory
# (the pipeline also appends them to each sample's JSONL log on disk).
# `res_forms` still holds the last sample's forms, as before.
forms_by_sample = {}
for x in dataset:
    print(x.sample_name)
    res_forms = classification_pipeline(x)
    forms_by_sample[x.sample_name] = res_forms