In [1]:
import os
import json
from dotenv import load_dotenv
from typing import List, Dict
from tqdm import tqdm
import time

# ReceiptIQ Model

This is the model behind the ReceiptIQ application. It takes a receipt/invoice image and a JSON schema and generates the requested output as JSON, with each leaf containing both the extracted value and its bounding-box coordinates in the format `[x, y, w, h]`.

# Dataset

The starting dataset is a composite of images only from:
- [Receipt or Invoice Computer Vision Model](https://universe.roboflow.com/jakob-awn1e/receipt-or-invoice)
- [OCR Receipts Text Detection - retail dataset](https://www.kaggle.com/datasets/trainingdatapro/ocr-receipts-text-detection)
- [Receipt Dataset for information extraction](https://www.kaggle.com/datasets/dhiaznaidi/receiptdatasetssd300v2)

Totaling `7,334 images` in `datasets/images` folder.

## Dataset Preparation

Because these datasets do not natively contain the extracted information or, importantly, the bounding boxes (some contain data, but it is incomplete and mostly lacks bounding information), the raw images were sent through a larger model to distill high-quality training data from the receipts. In this case it is `gpt4o` from OpenAI.

### Model distillation
Sample json file
```json
{
  "store_info": {
    "name": {
      "value": "POP TATES R-MALL",
      "bbox": [50, 20, 150, 20],
      "descr": "Store name"
    },
    "address": {
      "value": "MULUND WEST, MUMBAI - 421 004",
      "bbox": [50, 40, 200, 20],
      "descr": "Store address"
    },
    "phone": {
      "value": "TEL: 2591 2591",
      "bbox": [50, 60, 150, 20],
      "descr": "Store phone number"
    }
  },
  "transaction_info": {
    "bill_no": {
      "value": "2708/020612/3V",
      "bbox": [50, 100, 150, 20],
      "descr": "Bill number"
    }
    ...
```

A test run on 10 images took 204 s, i.e. roughly 20 s/image.
At that rate the full collection of 7,334 images would take `7,334 images * 20 s/image = 146,680 s`, approximately 40 hr 45 min, which is inefficient.
Coupled with the cheaper pricing for batch inference and the API rate limits, batch inference is preferred for model distillation.
Nonetheless, a batch of 500 was used to refine the rest of the preparation and training steps while the full batch was being prepared.

In [2]:
import shutil
import glob
from openai import OpenAI

# load_dotenv() (python-dotenv) runs before the key is read from the environment.
load_dotenv()
API_KEY = os.environ.get("OPENAI_API_KEY")  # None when the variable is not set
MODEL = "gpt-5-nano"  # model name placed in the body of every batch request

# Bearer-token HTTP headers built from the key; not referenced by the other visible cells.
HEADERS = dict([
    ("Authorization", "Bearer {}".format(API_KEY)),
    ("Content-Type", "application/json"),
])

# System prompt sent with every distillation request.
# Adjacent string literals are concatenated as-is, so every fragment must carry its
# own trailing space (the original produced "leaf nodesuse the format").
# NOTE(review): this is a plain string, not an f-string, so the doubled braces are
# sent to the model literally — confirm that is intended.
SYSTEM_PROMPT = (
    "You're a document parser. Given an image of a receipt, "
    "extract structured data in JSON including all leaf-node values and their bounding boxes (x,y,w,h). For the leaf nodes "
    "use the format `name: {{value: actual value, bbox: bounding box in the format (x,y,w,h), descr: description of the field}}` "
)

# Shared OpenAI client used by the helpers below for file upload, batch creation,
# status polling, and (later) listing batches and downloading their output.
client = OpenAI(api_key=API_KEY)

def get_image_files(input_dir: str):
    """Return the paths of the ``*.jpg`` and ``*.png`` files directly inside ``input_dir``.

    All ``.jpg`` matches come first, followed by all ``.png`` matches.
    """
    matches = []
    for pattern in ("*.jpg", "*.png"):
        matches.extend(glob.glob(os.path.join(input_dir, pattern)))
    return matches

def _write_batch_jsonl(batch_items, output_dir: str, batch_index: int) -> str:
    """Write one batch of request payloads as ``batch_<index>.jsonl`` and return its path."""
    batch_path = os.path.join(output_dir, f"batch_{batch_index}.jsonl")
    with open(batch_path, "w") as f:
        for item in batch_items:
            f.write(json.dumps(item) + "\n")
    return batch_path


def prepare_batch_files(image_files_list: List[str], output_dir: str, max_items_per_batch: int):
    """Write batch-request JSONL files covering the given images.

    Any existing ``output_dir`` is deleted and recreated. One chat-completions
    request is generated per image and the requests are written to
    ``batch_<n>.jsonl`` files holding at most ``max_items_per_batch`` lines each.

    Args:
        image_files_list: Local image paths. The full path becomes the request's
            ``custom_id``; only the file name is appended to the hosted image URL.
        output_dir: Directory that receives the JSONL files (wiped first).
        max_items_per_batch: Number of requests after which a file is flushed.

    Returns:
        List of the JSONL paths written, in creation order (empty when there
        are no images).
    """
    IMAGE_BASE_URL = "https://receiptiq-model-finetuning-receipts.t3.storageapi.dev"
    # On a first run the directory does not exist yet; rmtree would raise FileNotFoundError.
    if os.path.isdir(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    current_batch = []
    batch_files = []
    for img_file in tqdm(image_files_list, desc="Preparing batch files"):
        # basename() also copes with OS-specific separators, unlike split("/")
        file_name = os.path.basename(img_file)
        img_url = f"{IMAGE_BASE_URL}/{file_name}"
        payload = {
            "custom_id": img_file,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": MODEL,
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {
                        "role": "user",
                        "content": [
                            {"type": "image_url","image_url": {"url": img_url}},
                            {"type": "text","text": "Please extract the structured JSON for this receipt. Respond only in json format of your best approximation."}
                        ]
                    }
                ]
            }
        }
        current_batch.append(payload)
        if len(current_batch) >= max_items_per_batch:
            batch_files.append(_write_batch_jsonl(current_batch, output_dir, len(batch_files)))
            current_batch = []
    # Flush the final, partially filled batch
    if current_batch:
        batch_files.append(_write_batch_jsonl(current_batch, output_dir, len(batch_files)))
    return batch_files

def upload_batch_file(filepath: str):
    """Send one JSONL file to the Files API with purpose ``batch`` and return the new file id."""
    with open(filepath, "rb") as jsonl_fh:
        uploaded = client.files.create(file=jsonl_fh, purpose="batch")
    file_id = uploaded.id
    return file_id

def start_batch(file_id: str, batch_name: str, completion_window: str):
    """Create a chat-completions batch job for an uploaded file and return the batch id.

    ``batch_name`` is stored under the ``name`` key of the batch metadata.
    """
    request_kwargs = {
        "input_file_id": file_id,
        "endpoint": "/v1/chat/completions",
        "completion_window": completion_window,
        "metadata": {"name": batch_name},
    }
    created = client.batches.create(**request_kwargs)
    return created.id

def wait_for_completion(batch_id: str, poll_interval:int):
    """Block until the batch reaches a final state and return that state.

    The status is fetched and printed on every poll; between polls the function
    sleeps ``poll_interval`` seconds. Final states are ``completed``, ``failed``,
    ``expired`` and ``cancelled``.
    """
    finished_states = ("completed", "failed", "expired", "cancelled")
    while True:
        status = client.batches.retrieve(batch_id).status
        print(f"[{time.strftime('%H:%M:%S')}] Batch {batch_id} status: {status}")
        if status not in finished_states:
            time.sleep(poll_interval)
            continue
        return status

def load_unprocessed_images(processed_file_path, input_dir):
    """Return ``(pending_images, processed_set)`` for a resumable run.

    ``processed_set`` is read from the JSON tracking file when it exists
    (otherwise it is empty); ``pending_images`` are the images found in
    ``input_dir`` whose paths are not in that set.
    """
    candidates = get_image_files(input_dir)
    if os.path.exists(processed_file_path):
        with open(processed_file_path) as tracking_fh:
            already_done = set(json.load(tracking_fh))
    else:
        already_done = set()
    pending = [path for path in candidates if path not in already_done]
    return pending, already_done

def save_processed(processed_set, processed_file_path):
    """Overwrite the tracking file with the processed image paths as a JSON array."""
    with open(processed_file_path, "w") as tracking_fh:
        json.dump([*processed_set], tracking_fh)


In [3]:
# --- Distillation run configuration ---
MAX_ITEMS_PER_BATCH = 500
INPUT_DIR = "datasets/images"
BATCHES_OUTPUT_DIR = "datasets/batch_files"
COMPLETION_WINDOW = "24h"
POLL_INTERVAL = 60  # seconds between status polls
PROCESSED_FILE = "datasets/processed_images.json" # tracking file to avoid repetitions during restarts
# Pause between consecutive batches.
# NOTE(review): presumably there to stay within API rate limits — confirm.
BATCH_COOLDOWN_SECONDS = 300

remaining_images, processed_images = load_unprocessed_images(PROCESSED_FILE,INPUT_DIR)
batch_files = prepare_batch_files(remaining_images, output_dir=BATCHES_OUTPUT_DIR, max_items_per_batch=MAX_ITEMS_PER_BATCH)

# Batches are submitted one at a time: upload, start, wait, then record progress.
for idx, batch_file in enumerate(batch_files, start=1):
    batch_name = f"receiptiq_distill_batch_{idx}"

    file_id = upload_batch_file(batch_file)
    print(f"[✓] Uploaded file {batch_file} → File ID: {file_id}")

    batch_id = start_batch(file_id, batch_name, completion_window=COMPLETION_WINDOW)
    print(f"[✓] Started batch {batch_name} → Batch ID: {batch_id}")

    status = wait_for_completion(batch_id, poll_interval=POLL_INTERVAL)
    print(f"[!] Batch {batch_name} finished with status: {status}")

    if status != "completed":
        # Stop at the first unsuccessful batch; already-recorded progress is kept on disk.
        print(f"[!] Batch failed: {batch_name}")
        break

    # Mark every image of the finished batch as processed so a restart skips it.
    with open(batch_file) as bf:
        for line in bf:
            data = json.loads(line)
            processed_images.add(data["custom_id"])
    save_processed(processed_images, processed_file_path=PROCESSED_FILE)

    # No need to wait once the last batch is done.
    if idx < len(batch_files):
        time.sleep(BATCH_COOLDOWN_SECONDS)

Preparing batch files: 0it [00:00, ?it/s]


### Download Distillation Output
- Download the output (JSONL)
- Try to parse each output as JSON and, for now, accept `ONLY` the ones that are valid JSON

In [4]:
# Directory that receives one JSON file per successfully parsed receipt.
DISTILLED_DATA_DIR = "datasets/data"
# Create it up front: a missing directory would otherwise make every write fail
# and be misreported as a faulty extraction.
os.makedirs(DISTILLED_DATA_DIR, exist_ok=True)

receipts_and_data = []   # {"image_filename", "data"} for every receipt with valid JSON
faulty_extractions = []  # image file names whose model output could not be used
for batch in tqdm(client.batches.list()):
    if batch.status != "completed":
        continue
    try:
        output = client.files.content(batch.output_file_id)
    except Exception as e:
        # Best effort: report the batch whose output cannot be fetched and move on.
        print(e)
        continue
    for receipt in output.text.splitlines():
        if not receipt.strip():
            continue
        completion_response_json = json.loads(receipt)
        image_filename = os.path.basename(completion_response_json.get("custom_id"))
        try:
            receipt_data = completion_response_json["response"]["body"]["choices"][0]["message"]["content"]
            parsed_data = json.loads(receipt_data)
        except (KeyError, IndexError, TypeError, ValueError):
            # Missing/null response parts or content that is not valid JSON
            faulty_extractions.append(image_filename)
            continue
        try:
            # Explicit UTF-8 so non-ASCII receipt text does not depend on the platform default
            with open(os.path.join(DISTILLED_DATA_DIR, f"{image_filename}.json"), "w", encoding="utf-8") as df:
                df.write(receipt_data)
        except (OSError, ValueError) as e:
            # Keep going, but make the I/O problem visible instead of hiding it
            print(f"Could not write extraction for {image_filename}: {e}")
            faulty_extractions.append(image_filename)
            continue
        receipts_and_data.append({
            "image_filename": image_filename,
            "data": parsed_data
        })
print(f"{len(receipts_and_data)} receipts")
print(f"{len(faulty_extractions)} failed")

33it [00:45,  1.46it/s]

Error code: 404 - {'error': {'message': 'No such File object: file-RXx22D9VdHczqR8qZ5H4rj', 'type': 'invalid_request_error', 'param': 'id', 'code': None}}


35it [00:45,  1.90it/s]

Error code: 404 - {'error': {'message': 'No such File object: file-7qS7Be5NVspA9kBuLzX8KN', 'type': 'invalid_request_error', 'param': 'id', 'code': None}}


36it [00:46,  2.02it/s]

Error code: 404 - {'error': {'message': 'No such File object: file-M5BKp5Q5NYnYKXLvEgAg9a', 'type': 'invalid_request_error', 'param': 'id', 'code': None}}


45it [00:47,  1.05s/it]

Error code: 404 - {'error': {'message': 'No such File object: file-V1ofekdEcGvG18sMDrJ3tz', 'type': 'invalid_request_error', 'param': 'id', 'code': None}}
7170 receipts
160 failed





## Clean and format the data
- Try and extract the schema and ONLY accept those that are in the expected schema format for now
I expected and requested the schema to be `name: {{value: actual value, bbox: bounding box in the format (x,y,w,h), descr: description of the field}}`
with possible child dict and child lists as part of hierarchy.
- Remove descr from the data
- Convert the list to HF dataset

In [5]:
from datasets import Dataset, Features, Value, Image as HFImage

# Extract schema and only accept those that match the expected schema
def get_schema(data: dict):
    """Derive a schema skeleton from extracted receipt data.

    A dict holding a ``value`` key is a leaf and becomes ``"string//<descr>"``;
    any other dict is recursed into; a list becomes a one-element list with the
    schema of its first item (``[{}]`` when the list is empty). Entries of any
    other type are left out. Input that does not follow this shape (for example
    a list whose first item is not a dict) raises.
    """
    schema = {}
    for field, node in data.items():
        if isinstance(node, list):
            first_schema = get_schema(node[0]) if node else {}
            schema[field] = [first_schema]
        elif isinstance(node, dict):
            if "value" not in node:
                schema[field] = get_schema(node)
            else:
                schema[field] = "string//{}".format(node.get("descr", ""))
    return schema

# Keep only the receipts whose data get_schema can walk without raising.
# Rejected receipts are recorded with the reason instead of vanishing silently.
receipts_and_data_schema_checked = []
receipts_and_data_schema_rejected = []
for item in tqdm(receipts_and_data):
    try:
        get_schema(item['data'])
    except Exception as e:
        # Broad on purpose: any failure means the data is not in the expected shape.
        receipts_and_data_schema_rejected.append({
            "image_filename": item["image_filename"],
            "error": repr(e),
        })
    else:
        receipts_and_data_schema_checked.append(item)
print(f"{len(receipts_and_data_schema_rejected)} receipts rejected by the schema check")


# Remove the description (`descr`) from the data
def remove_descr(data: dict):
    """Return a copy of the receipt data with ``descr`` stripped from every leaf.

    Leaves are dicts holding a ``value`` key; other dicts are recursed into and
    every element of a list is processed. Entries that are neither dict nor list
    are omitted from the result. The input is not modified.
    """
    out_data = {}
    for field, node in data.items():
        if isinstance(node, list):
            out_data[field] = list(map(remove_descr, node))
            continue
        if not isinstance(node, dict):
            continue
        if "value" not in node:
            out_data[field] = remove_descr(node)
            continue
        out_data[field] = {key: val for key, val in node.items() if key != "descr"}
    return out_data

# One record per schema-checked receipt: the descr-free data serialised as a JSON
# string, plus the image path under datasets/images.
receipts_and_data_cleaned = [
    {
        "data": json.dumps(remove_descr(checked["data"]), ensure_ascii=False),
        "image_filename": f"datasets/images/{checked['image_filename']}",
    }
    for checked in receipts_and_data_schema_checked
]


# Build a Hugging Face dataset with two string columns from the cleaned records
features = Features({column: Value("string") for column in ("image_filename", "data")})
receipts_and_data_dataset = Dataset.from_list(receipts_and_data_cleaned, features=features)
# receipts_and_data_dataset = receipts_and_data_dataset.cast_column("image_filename", HFImage(decode=True))
# Show the dataset summary together with the first record's JSON string
(receipts_and_data_dataset, receipts_and_data_dataset[0]["data"])

100%|██████████| 7170/7170 [00:00<00:00, 410297.42it/s]


(Dataset({
     features: ['data', 'image_filename'],
     num_rows: 6876
 }),
 '{"company_name": {"value": "[unreadable]", "bbox": "(60,40,520,40)"}, "store_number": {"value": "[unreadable]", "bbox": "(520,20,120,20)"}, "date": {"value": "[unreadable]", "bbox": "(60,90,180,20)"}, "time": {"value": "[unreadable]", "bbox": "(250,90,120,20)"}, "receipt_id": {"value": "[unreadable]", "bbox": "(460,90,180,20)"}, "subtotal": {"value": "[unreadable]", "bbox": "(500,420,120,20)"}, "tax": {"value": "[unreadable]", "bbox": "(620,420,120,20)"}, "total": {"value": "[unreadable]", "bbox": "(500,450,140,20)"}, "currency": {"value": "USD", "bbox": "(700,450,50,20)"}, "items": [{"name": {"value": "[unreadable]", "bbox": "(60,260,320,20)"}, "qty": {"value": "[unreadable]", "bbox": "(390,260,40,20)"}, "price": {"value": "[unreadable]", "bbox": "(430,260,60,20)"}, "line_total": {"value": "[unreadable]", "bbox": "(500,260,60,20)"}}]}')

# Tokenization

## Tokenize

In [6]:
from transformers import AutoProcessor
from datasets import Sequence, Array4D
import numpy as np
from PIL import Image

# Hugging Face model id whose processor and tokenizer are used by tokenize() below.
hf_model_id = "meta-llama/Llama-3.2-11B-Vision"
# The processor is called with the receipt image and the prompt text in tokenize().
processor = AutoProcessor.from_pretrained(hf_model_id)
# Tokenizer taken from the processor; tokenize() applies it to the target JSON text.
tokenizer = processor.tokenizer

def tokenize(batch):
    """Turn a batch of (receipt JSON, image path) rows into model inputs and targets.

    For every row the prompt is ``<|image|><|begin_of_text|>`` followed by the
    schema derived from the receipt JSON with ``get_schema``; it is encoded
    together with the image by ``processor``. The target is the receipt JSON
    itself, encoded with ``tokenizer``.

    Args:
        batch: Mapping with parallel sequences under ``"data"`` (JSON strings)
            and ``"image_filename"`` (image paths).

    Returns:
        Dict of per-row lists keyed by ``input_ids``, ``attention_mask``,
        ``output_ids``, ``output_attention_mask``, ``pixel_values``,
        ``aspect_ratio_ids`` and ``aspect_ratio_mask``.
    """
    input_ids = []
    attention_mask = []
    output_ids = []
    output_attention_mask = []
    pixel_values = []
    aspect_ratio_ids = []
    aspect_ratio_mask = []

    for data_json, image_path in zip(batch["data"], batch["image_filename"]):
        data_dict = json.loads(data_json)
        schema_str = json.dumps(get_schema(data_dict), ensure_ascii=False)

        # The context manager closes the image's file handle once the processor has
        # consumed it; without it every mapped row leaves a file open.
        with Image.open(image_path) as image_data:
            input_tokens = processor(
                image_data, f"<|image|><|begin_of_text|>{schema_str}", return_tensors="pt"
            )

        output_tokens = tokenizer(
            json.dumps(data_dict, ensure_ascii=False), return_tensors="pt"
        )

        # Index 0 drops the leading batch dimension of each single-row encoding
        input_ids.append(input_tokens["input_ids"][0])
        attention_mask.append(input_tokens["attention_mask"][0])
        output_ids.append(output_tokens["input_ids"][0])
        output_attention_mask.append(output_tokens["attention_mask"][0])
        pixel_values.append(input_tokens["pixel_values"][0])
        aspect_ratio_ids.append(input_tokens["aspect_ratio_ids"][0])
        aspect_ratio_mask.append(input_tokens["aspect_ratio_mask"][0])

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "output_ids": output_ids,
        "output_attention_mask": output_attention_mask,
        "pixel_values": pixel_values,  # tensors (return_tensors="pt"), not Python lists
        "aspect_ratio_ids": aspect_ratio_ids,
        "aspect_ratio_mask": aspect_ratio_mask,
    }

# Declared column types for the tokenized columns produced by tokenize().
# NOTE(review): this rebinds the `features` name used earlier for the string-only
# dataset, and it is not passed to the `.map(...)` call in this cell — confirm
# whether it is meant to be applied (and whether the nesting depths match the
# tensors tokenize() returns).
features = Features({
    "input_ids": Sequence(Value("int32")),
    "attention_mask": Sequence(Value("int8")),
    "output_ids": Sequence(Value("int32")),
    "output_attention_mask": Sequence(Value("int8")),
    "pixel_values": Sequence(Sequence(Sequence(Sequence(Value("float32"))))),
    "aspect_ratio_ids": Sequence(Value("int32")),
    "aspect_ratio_mask": Sequence(Value("int8")),
})

# Tokenize only the first 200 rows, two rows per call, across two worker processes.
# The result replaces the untokenized dataset under the same name.
receipts_and_data_dataset = (
    receipts_and_data_dataset
    .select(range(200))
    .map(tokenize, batched=True, batch_size=2, num_proc=2)
)
receipts_and_data_dataset

Map (num_proc=2):   0%|          | 0/200 [00:00<?, ? examples/s]

Dataset({
    features: ['data', 'image_filename', 'input_ids', 'attention_mask', 'output_ids', 'output_attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask'],
    num_rows: 200
})

## Check Data Loader

In [None]:
from utils import ReceiptIQModelDataLoader
from torch.utils.data import DataLoader

# Longest combined prompt + target token count over the dataset; passed on as max_len
max_len = max(
    len(prompt_ids) + len(target_ids)
    for prompt_ids, target_ids in zip(
        receipts_and_data_dataset["input_ids"],
        receipts_and_data_dataset["output_ids"],
    )
)
test_dataset = ReceiptIQModelDataLoader(
    dataset=receipts_and_data_dataset,
    max_len=max_len,
    tokenizer=tokenizer,
)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1)
# Pull a single batch as a smoke test
next(iter(test_dataloader))

ImportError: attempted relative import with no known parent package