In [None]:
import os
import json
from dotenv import load_dotenv
from typing import List
from tqdm import tqdm
import time

# ReceiptIQ Model

This is the model behind the ReceiptIQ application. It takes a receipt/invoice image and a JSON schema and generates the desired output in JSON format, with each JSON leaf containing both the extracted value and the bounding-box coordinates in the format `[x, y, w, h]`.

# Dataset

The starting dataset is a composite of images only from:
- [Receipt or Invoice Computer Vision Model](https://universe.roboflow.com/jakob-awn1e/receipt-or-invoice)
- [OCR Receipts Text Detection - retail dataset](https://www.kaggle.com/datasets/trainingdatapro/ocr-receipts-text-detection)
- [Receipt Dataset for information extraction](https://www.kaggle.com/datasets/dhiaznaidi/receiptdatasetssd300v2)

Totaling `7,334 images` in `datasets/images` folder.

## Dataset Preparation

Because this dataset does not natively contain the extracted information or, importantly, the bounding boxes (some sources include data, but it is incomplete and mostly lacks the bounding information), the raw images were sent through a larger model to distill high-quality training data from the receipts. In this case it is `gpt4o` from OpenAI.

### Model distillation

The script below was tested on an initial batch of 10 images in order to fine-tune the prompt and settings.

In [None]:
import base64
import glob
from openai import OpenAI

# Pull OPENAI_API_KEY (and any other settings) from a local .env file into the environment.
load_dotenv()
API_KEY = os.getenv("OPENAI_API_KEY")
MODEL = "gpt-5-nano"

# Headers for calling the REST API directly. Not referenced elsewhere in the visible
# notebook (the SDK client below handles authentication itself).
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}

# System prompt shared by the single-image call and the batch requests.
# The adjacent string literals are concatenated, so every piece must carry its own
# trailing space; braces are single because the string is never passed through
# str.format / an f-string.
SYSTEM_PROMPT = (
    "You're a document parser. Given an image of a receipt, "
    "extract structured data in JSON including all leaf-node values and their bounding boxes (x,y,w,h). For the leaf nodes "
    "use the format `name: {value: actual value, bbox: bounding box in the format (x,y,w,h), descr: description of the field}`"
)

client = OpenAI(api_key=API_KEY)

def encode_image(file_path: str) -> str:
    """Return the raw bytes of ``file_path`` as a base64-encoded text string."""
    with open(file_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read())
    return encoded.decode("utf-8")

def call_openai(image_b64: str) -> str:
    """Ask the chat-completions API to extract structured data from one receipt image.

    Args:
        image_b64: Base64 text of the image bytes. It is embedded in a
            ``data:image/jpeg;base64,...`` URL in the user message.

    Returns:
        The assistant's reply as text, with surrounding whitespace and Markdown
        code-fence markers (```` ```json ```` / ```` ``` ````) removed. The text is
        NOT parsed here, so it may still be invalid JSON.

    Raises:
        ValueError: If the reply contains no text content.
        Exception: Any error raised by the OpenAI client propagates unchanged.
    """
    payload = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
                    {"type": "text", "text": "Please extract the structured JSON for this receipt."}
                ]
            }
        ]
    }
    # NOTE(review): confirm that MODEL accepts `temperature` and `max_tokens`; some newer
    # OpenAI models only accept the default temperature and `max_completion_tokens`.
    response = client.chat.completions.create(
        **payload,
        temperature=0.1,  # Low temperature for consistent extraction
        max_tokens=2000
    )
    content = response.choices[0].message.content
    if content is None:
        # Fail with a clear message instead of an AttributeError on None.strip().
        raise ValueError("OpenAI response contained no text content")
    return content.strip().replace("```json\n", "").replace("```", "")

def get_image_files(input_dir: str):
    """List the ``.jpg`` and ``.png`` files located directly inside ``input_dir``.

    Args:
        input_dir: Directory to scan (not recursive).

    Returns:
        A list of paths: all ``*.jpg`` matches followed by all ``*.png`` matches.
        Each group is sorted, because ``glob`` returns matches in arbitrary order and
        callers slice this list (e.g. ``[:10]``) or resume from it across runs.
    """
    image_paths = []
    for pattern in ("*.jpg", "*.png"):
        image_paths.extend(sorted(glob.glob(os.path.join(input_dir, pattern))))
    return image_paths


In [None]:
input_dir = "datasets/images"
output_dir = "datasets/data"
os.makedirs(output_dir, exist_ok=True)
# Pilot run: only the first 10 images are sent, to tune the prompt and settings.
file_list = get_image_files(input_dir)[:10]
print(f"Found {len(file_list)} images. Starting...")

start = time.time()
for file in tqdm(file_list):
    # One JSON file per image, named "<image file name>.json".
    json_path = f"{output_dir}/{os.path.basename(file)}.json"
    if os.path.exists(json_path):
        # Already extracted in a previous run; skip to avoid paying for it again.
        continue
    image_enconded = encode_image(file)
    receipt_data = call_openai(image_enconded)
    # Explicit UTF-8: receipt text can contain non-ASCII characters and the platform
    # default encoding is not guaranteed to handle them.
    with open(json_path, "w", encoding="utf-8") as f:
        f.write(receipt_data)
print(f"Completed in {time.time() - start:.2f}s")

Sample json file
```json
{
  "store_info": {
    "name": {
      "value": "POP TATES R-MALL",
      "bbox": [50, 20, 150, 20],
      "descr": "Store name"
    },
    "address": {
      "value": "MULUND WEST, MUMBAI - 421 004",
      "bbox": [50, 40, 200, 20],
      "descr": "Store address"
    },
    "phone": {
      "value": "TEL: 2591 2591",
      "bbox": [50, 60, 150, 20],
      "descr": "Store phone number"
    }
  },
  "transaction_info": {
    "bill_no": {
      "value": "2708/020612/3V",
      "bbox": [50, 100, 150, 20],
      "descr": "Bill number"
    }
    ...
```

Testing on 10 images took 204 s, i.e. about 20 s/image.
At that rate the full collection of 7,334 images would take `7,334 images * 20 s/image = 146,680 s`, approximately 40 hr 45 min, which is inefficient.
Coupled with the cheaper pricing for batch inference and the API rate limits, batch inference is preferred for model distillation.
Nonetheless, a batch of 500 was used to refine the rest of the prep and training while the full batch was being prepared.
 

### Batch model distillation

Use OpenAI's Batch API to send batches of prompts in JSONL format (the maximum accepted file size is 200 MB).

In [None]:
import json
import base64
import shutil
from typing import List

# Maximum number of requests written to a single JSONL batch file.
MAX_ITEMS_PER_BATCH = 500
# Directory for the generated JSONL batch files; created below if it is missing.
OUTPUT_DIR = "datasets/batch_files"
# Base URL that each image file name is appended to when building the image URL
# placed in a batch request.
BASE_URL = "https://receiptiq-model-finetuning-receipts.t3.storageapi.dev"
# Completion window passed to the batch API when a batch job is created.
COMPLETION_WINDOW = "24h"

os.makedirs(OUTPUT_DIR, exist_ok=True)

def encode_image(path: str) -> str:
    """Base64-encode the unmodified bytes of the file at ``path`` and return them as text."""
    with open(path, "rb") as img_file:
        raw_bytes = img_file.read()
    return base64.b64encode(raw_bytes).decode("utf-8")

def _write_batch_file(batch_path: str, items: List[dict]) -> None:
    """Write ``items`` to ``batch_path`` as JSON Lines (one request object per line)."""
    with open(batch_path, "w") as f:
        for item in items:
            f.write(json.dumps(item) + "\n")


def prepare_batch_files(image_files_list: List[str], output_dir: str):
    """Build the JSONL request files for the batch API, one request per image.

    ``output_dir`` is emptied first, so files from a previous run are discarded.
    Each request's ``custom_id`` is the image path exactly as given, and the image
    is referenced by URL (``BASE_URL`` + file name) rather than embedded.

    Args:
        image_files_list: Paths of the images to include.
        output_dir: Directory that receives ``batch_<n>.jsonl`` files, each holding
            at most ``MAX_ITEMS_PER_BATCH`` requests.

    Returns:
        The list of batch file paths, in creation order. Empty if no images were given.
    """
    # Start from a clean directory. rmtree raises on a missing path, so guard it.
    if os.path.isdir(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    batch_files = []
    current_batch = []

    for img_file in tqdm(image_files_list, desc="Preparing batch files"):
        # basename (not split("/")) so the file name is found on any OS path style.
        file_name = os.path.basename(img_file)
        img_url = f"{BASE_URL}/{file_name}"
        payload = {
            "custom_id": img_file,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": MODEL,
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {"url": img_url}
                            },
                            {
                                "type": "text",
                                "text": "Please extract the structured JSON for this receipt. Respond only in json format of your best approximation."
                            }
                        ]
                    }
                ]
            }
        }
        current_batch.append(payload)

        if len(current_batch) >= MAX_ITEMS_PER_BATCH:
            # The next batch index is simply the number of files written so far.
            batch_path = os.path.join(output_dir, f"batch_{len(batch_files)}.jsonl")
            _write_batch_file(batch_path, current_batch)
            batch_files.append(batch_path)
            current_batch = []

    # Write remaining items
    if current_batch:
        batch_path = os.path.join(output_dir, f"batch_{len(batch_files)}.jsonl")
        _write_batch_file(batch_path, current_batch)
        batch_files.append(batch_path)

    print(f"[✓] Created {len(batch_files)} batch files in {output_dir}")
    return batch_files

POLL_INTERVAL = 60

def wait_for_completion(batch_id: str):
    """Block until the batch reaches a terminal status, then return that status.

    The batch is re-fetched every ``POLL_INTERVAL`` seconds and a timestamped
    status line is printed on each poll. Terminal statuses are "completed",
    "failed", "expired" and "cancelled".
    """
    terminal_statuses = ("completed", "failed", "expired", "cancelled")
    while True:
        current = client.batches.retrieve(batch_id).status
        stamp = time.strftime('%H:%M:%S')
        print(f"[{stamp}] Batch {batch_id} status: {current}")
        if current not in terminal_statuses:
            time.sleep(POLL_INTERVAL)
            continue
        return current

def upload_batch_file(filepath: str):
    """Upload the JSONL file at ``filepath`` with purpose "batch" and return the new file's ID."""
    with open(filepath, "rb") as jsonl_file:
        uploaded = client.files.create(file=jsonl_file, purpose="batch")
    file_id = uploaded.id
    return file_id

def start_batch(file_id: str, batch_name: str):
    """Create a chat-completions batch job for an uploaded file and return the job's ID.

    ``batch_name`` is stored in the job metadata under the key "name".
    """
    job_options = {
        "input_file_id": file_id,
        "endpoint": "/v1/chat/completions",
        "completion_window": COMPLETION_WINDOW,
        "metadata": {"name": batch_name},
    }
    created = client.batches.create(**job_options)
    return created.id

PROCESSED_FILE = "processed_images.json"

def load_processed():
    """Return the set of IDs stored in ``PROCESSED_FILE``, or an empty set if the file is absent."""
    if not os.path.exists(PROCESSED_FILE):
        return set()
    with open(PROCESSED_FILE) as handle:
        stored_ids = json.load(handle)
    return set(stored_ids)

def save_processed(processed_set):
    """Overwrite ``PROCESSED_FILE`` with the members of ``processed_set`` as a JSON array."""
    with open(PROCESSED_FILE, "w") as out_file:
        json.dump([*processed_set], out_file)

input_dir = "datasets/images"
output_dir = "datasets/batch_files"
INPUT_DIR = "datasets/images"
BATCH_NAME_PREFIX = "receiptiq_model_batch"
# Pause between consecutive batches.
# NOTE(review): presumably here to stay under API rate/queue limits — confirm the value.
BATCH_COOLDOWN_SECONDS = 300

images_files_list = get_image_files(INPUT_DIR)
# Skip images whose batch already completed in an earlier run.
processed_images = load_processed()
remaining_images = [img for img in images_files_list if img not in processed_images]
batch_files = prepare_batch_files(remaining_images, output_dir=output_dir)

# Upload & create batches, one at a time: each batch must finish before the next starts.
for idx, batch_file in enumerate(batch_files, start=1):
    print(f"\n=== Processing batch {idx}/{len(batch_files)}: {batch_file} ===")
    file_id = upload_batch_file(batch_file)
    print(f"[✓] Uploaded file {batch_file} → File ID: {file_id}")

    batch_name = f"{BATCH_NAME_PREFIX}{idx}"
    batch_id = start_batch(file_id, batch_name)
    print(f"[✓] Started batch {batch_name} → Batch ID: {batch_id}")

    status = wait_for_completion(batch_id)
    print(f"[!] Batch {batch_name} finished with status: {status}")

    if status != "completed":
        # Stop at the first unsuccessful batch; nothing from it is marked as processed.
        print(f"[!] Batch failed: {batch_name}")
        break

    # Record every image of the finished batch so a re-run does not resubmit it.
    with open(batch_file) as bf:
        for line in bf:
            data = json.loads(line)
            processed_images.add(data["custom_id"])
    save_processed(processed_images)

    # Only cool down when another batch follows; sleeping after the last one is wasted time.
    if idx < len(batch_files):
        time.sleep(BATCH_COOLDOWN_SECONDS)

In [None]:
# Download the output of every completed batch, keep the replies that parse as JSON
# (saving each to datasets/data/<image file name>.json) and record the rest as faulty.
receipts_and_data = []
faulty_extractions = []
os.makedirs("datasets/data", exist_ok=True)
for batch in client.batches.list():
    if batch.status != "completed":
        continue
    try:
        output = client.files.content(batch.output_file_id)
    except Exception as e:
        # Best effort: report the problem and move on to the next batch.
        print(e)
        continue
    for receipt in output.text.splitlines():
        if not receipt.strip():
            continue
        completion_response_json = json.loads(receipt)
        image_filename = completion_response_json.get("custom_id").split("/")[-1]
        try:
            # Any missing/null level (e.g. a request that produced no response body)
            # or non-JSON content marks this receipt as faulty instead of aborting the cell.
            receipt_data = completion_response_json["response"]["body"]["choices"][0]["message"]["content"]
            parsed_data = json.loads(receipt_data)
        except (KeyError, IndexError, TypeError, json.JSONDecodeError):
            faulty_extractions.append(image_filename)
            continue
        with open(f"datasets/data/{image_filename}.json", "w", encoding="utf-8") as df:
            df.write(receipt_data)
        receipts_and_data.append({
            "image_filename": image_filename,
            "data": parsed_data
        })
print(f"{len(receipts_and_data)} receipts")
print(f"{len(faulty_extractions)} failed")

### Tokenize the dataset

In [None]:

from dataclasses import asdict
from transformers import AutoProcessor
from PIL import Image
from datasets import Dataset, Features, Value, Image as HFImage

# Hugging Face model ID whose processor is loaded for tokenization.
hf_model_id = "meta-llama/Llama-3.2-11B-Vision"

# `processor` is called with an image plus prompt text; `tokenizer` is its text-only
# tokenizer, used for the expected output.
processor = AutoProcessor.from_pretrained(hf_model_id)
tokenizer = processor.tokenizer

def get_schema(data: dict):
    """Derive a schema that mirrors the structure of extracted receipt data.

    Rules, applied per key:
      * a dict containing a "value" key is a leaf and becomes
        ``"string//<descr>"`` (empty description if "descr" is missing);
      * any other dict is descended into recursively;
      * a list becomes a one-element list holding the schema of its first
        element, or ``[{}]`` when the list is empty;
      * every other value type is left out of the schema.

    A list whose first element is not a dict makes the recursive call fail
    (no ``.items()``), so malformed data raises instead of returning a schema.
    """
    schema = {}
    for key, node in data.items():
        if isinstance(node, list):
            schema[key] = [get_schema(node[0])] if len(node) > 0 else [{}]
        elif isinstance(node, dict):
            if "value" not in node:
                schema[key] = get_schema(node)
            else:
                description = node.get('descr', '')
                schema[key] = f"string//{description}"
    return schema

# Keep only receipts whose data can be turned into a schema; get_schema raises on
# structures it does not understand. New dicts are built instead of editing the items
# of `receipts_and_data` in place, so this cell gives the same result when re-run.
receipts_and_data_schema_checked = []
schema_failures = []
for item in tqdm(receipts_and_data):
    try:
        get_schema(item['data'])
    except Exception as e:
        # Deliberately best-effort: drop the receipt, but remember why.
        schema_failures.append((item["image_filename"], repr(e)))
        continue
    receipts_and_data_schema_checked.append({
        "image_filename": f"datasets/images/{item['image_filename']}",
        "data": json.dumps(item["data"], ensure_ascii=False),
    })
print(f"{len(schema_failures)} receipts dropped by the schema check")

if receipts_and_data_schema_checked:
    print(receipts_and_data_schema_checked[0])
features = Features({
    "image_filename": Value("string"),
    "data": Value("string")  # we'll store as JSON string for Arrow compatibility
})
receipts_and_data_dataset = Dataset.from_list(receipts_and_data_schema_checked, features=features)
# Turn the path column into an image column so rows yield decoded images.
receipts_and_data_dataset = receipts_and_data_dataset.cast_column("image_filename", HFImage(decode=True))

In [None]:
def tokenize(batch):
    """Tokenize one dataset row into model input and target token IDs.

    Despite its name, ``batch`` is a single row here: the function reads
    ``batch["data"]`` as one JSON string and ``batch["image_filename"]`` as one image.

    The prompt is the image followed by the JSON schema derived from the row's
    data; the target is the row's data serialized as JSON.

    Returns:
        A dict with ``input_ids`` and ``attention_mask`` for the prompt and
        ``output_ids`` for the target, each with the leading batch dimension removed.
    """
    # Convert JSON string back to dict for schema extraction
    data_dict = json.loads(batch["data"])
    schema_str = json.dumps(get_schema(data_dict), ensure_ascii=False)

    # HF Image column is already decoded as PIL.Image
    image = batch["image_filename"]

    # Tokenize input
    input_tokens = processor(
        image, f"<|image|><|begin_of_text|>{schema_str}", return_tensors="pt"
    )

    # Tokenize output
    output_tokens = tokenizer(json.dumps(data_dict, ensure_ascii=False), return_tensors="pt")

    return {
        "input_ids": input_tokens["input_ids"][0],   # remove batch dim
        "attention_mask": input_tokens["attention_mask"][0],
        "output_ids": output_tokens["input_ids"][0]
    }

In [None]:
# Apply `tokenize` to the dataset (no batched=True, so it is called per row) and keep the result.
receipts_and_data_dataset =receipts_and_data_dataset.map(tokenize)

In [None]:
# Show the dataset summary as the cell output.
receipts_and_data_dataset

# Model

In [None]:
import os
import pprint
from typing import Dict
import torch
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import MllamaForConditionalGeneration, AutoProcessor
from transformers.utils.quantization_config import BitsAndBytesConfig
from utils import load_model_quantized

In [None]:
# Hugging Face model ID of the base model.
hf_model_id = "meta-llama/Llama-3.2-11B-Vision"

# Local path passed to load_model_quantized together with the model ID.
base_model_path: str = "models/llama/Llama-3.2-11B-Vision-base"

# Root folder of the fine-tuning dataset read by the data-preparation cells.
finetune_dataset_path = "dataset"

# Output folder for the fine-tuned adapters. exist_ok avoids the check-then-create
# race and matches how the other cells create directories.
finetuned_model_path = "models/llama/Llama-3.2-11B-Vision-ReceiptIQ-tuned"
os.makedirs(finetuned_model_path, exist_ok=True)

# NOTE(review): load_model_quantized comes from the project's utils module; it is
# presumably loading a quantized model and its processor — confirm there.
model, processor = load_model_quantized(base_model_path, hf_model_id)
tokenizer = processor.tokenizer

model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# LoRA adapters on the attention query/value projections only.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

finetuned_model = get_peft_model(model, lora_config)
finetuned_model.print_trainable_parameters()

# Prepare finetuning dataset of receipts
The dataset comes from Kaggle (`https://www.kaggle.com/datasets/dhiaznaidi/receiptdatasetssd300v2`) and contains:
- Images under `dataset/images` of receipts
- Extracted data under `dataset/gdt` containing `company`, `total`, `date` and `address`
- Metadata about the extracted data under `dataset/info_data`, containing the important words and their coordinates

The expected data should contain `receipt_path`, `schema`(for now just the company, date, total and address), and the `output`

## Load Dataset

In [None]:
import json
from typing import List
from dataclasses import dataclass

@dataclass
class ReceiptData:
    """One fine-tuning example: a receipt image, the schema to fill, and the expected output."""
    # Path of the receipt image file.
    receipt_path: str
    # Schema describing the fields to extract.
    schema: Dict
    # Expected extraction result, keyed by field name.
    output: Dict

# Schema attached to every receipt in this dataset. Each entry maps a field name
# to "<type>//<description>".
fixed_schema = {
    "total": "number//total amount of the invoice",
    "company": "string//the name of the company or person doing the supply",
    "date": "date//the date of the invoice",
    "address": "string//address of the person or company doing the supply",
}

def prepare_dataset() -> List[ReceiptData]:
    """Pair each receipt's ground-truth values with the coordinates of those values.

    For every file in ``<finetune_dataset_path>/info_data`` the matching ground-truth
    file ``<finetune_dataset_path>/gdt/<receipt id>.json`` is read, where the receipt
    id is the image file name without ``.jpg``. Each ground-truth field becomes
    ``{"value": ..., "coordinates": ...}``; the coordinates come from the info file
    when it has an entry for that field, otherwise an all-zero box is used.

    Returns:
        One ``ReceiptData`` per info file, in sorted file-name order, with
        ``receipt_path`` rewritten from the original ``/content/Dataset/train/``
        prefix to the local dataset folder.

    Raises:
        FileNotFoundError: If the ground-truth file for a receipt does not exist.
    """
    dataset: List[ReceiptData] = []
    # sorted(): os.listdir order is arbitrary; sorting keeps the dataset order reproducible.
    for receipt in sorted(os.listdir(f"{finetune_dataset_path}/info_data")):
        with open(f"{finetune_dataset_path}/info_data/{receipt}", "r") as f:
            receipt_info = json.load(f)

        image_path = receipt_info.get("image_path", "")
        # Take the file name itself rather than a fixed path-component index, so the
        # lookup does not depend on how deep the original image path was.
        receipt_id = os.path.basename(image_path).replace(".jpg", "")

        with open(f"{finetune_dataset_path}/gdt/{receipt_id}.json", "r") as df:
            ground_truth = json.load(df)

        # Build a new dict instead of reassigning entries of the one being iterated.
        extracted_data = {}
        for k, v in ground_truth.items():
            if k in receipt_info:
                coordinates = receipt_info[k]
            else:
                # No location information for this field: use an all-zero box.
                coordinates = {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0}
            extracted_data[k] = {"value": v, "coordinates": coordinates}

        dataset.append(ReceiptData(
            receipt_path=image_path.replace("/content/Dataset/train/", f"{finetune_dataset_path}/"),
            schema=fixed_schema,
            output=extracted_data
        ))
    return dataset
# Build the examples, then print how many there are and show the first one as a sanity check.
data_list = prepare_dataset()
print(len(data_list))
pprint.pprint(data_list[0])

## Tokenize the dataset

In [None]:
from typing import Dict
from PIL import Image
from dataclasses import asdict
from datasets import Dataset

def convert_to_dataset(data_list: List[ReceiptData]) -> Dataset:
    """Turn a list of ``ReceiptData`` records into a two-column ``Dataset``.

    The "prompt" column holds each record converted to a dict (all of its fields);
    the "completion" column holds ``{"output": <record output>}``.
    """
    prompts = []
    completions = []
    for record in data_list:
        prompts.append(asdict(record))
        completions.append({"output": record.output})
    return Dataset.from_dict({"prompt": prompts, "completion": completions})

# Dataset with "prompt" and "completion" columns built from the prepared examples.
dataset = convert_to_dataset(data_list=data_list)

def tokenize_batched(examples: Dict[str, List]) -> Dict[str, torch.Tensor]:
    """Tokenize a batch of receipts into prompt+completion training tensors.

    For every example the prompt is the receipt image followed by its schema, and
    the completion is the expected output. Schema and output are serialized with
    ``json.dumps`` (the same way ``tokenize`` builds its prompt and target), so the
    model is trained on JSON text rather than on Python ``repr`` output.

    Args:
        examples: A batch with a "prompt" column (dicts holding ``receipt_path`` and
            ``schema``) and a "completion" column (dicts holding ``output``).

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels`` for the
        concatenated prompt + completion sequences, plus the ``pixel_values``,
        ``aspect_ratio_ids`` and ``aspect_ratio_mask`` produced by the processor.
        ``labels`` is -100 on every prompt position and on every padding position,
        so only real completion tokens contribute to the loss.
    """
    receipt_paths = [item.get("receipt_path") for item in examples["prompt"]]
    schemas = [json.dumps(item.get("schema"), ensure_ascii=False) for item in examples["prompt"]]
    outputs = [json.dumps(item.get("output"), ensure_ascii=False) for item in examples["completion"]]

    # Load each image fully and close its file right away; copy() forces the pixel
    # data to be read, so the returned image no longer needs the open file.
    images = []
    for path in receipt_paths:
        with Image.open(path) as receipt_image:
            images.append([receipt_image.copy()])
    input_prompts = [f'<|image|><|begin_of_text|>{schema}' for schema in schemas]

    model_inputs = processor(
        images=images,
        text=input_prompts,
        return_tensors="pt",
        padding="longest",
        truncation=True
    )

    # NOTE(review): this call uses the tokenizer's default special-token handling —
    # confirm that no extra begin-of-text token ends up in the middle of the sequence.
    completion_tokens = tokenizer(
        outputs,
        return_tensors="pt",
        padding="longest",
        truncation=True
    )

    # Both tokenizations are padded to the longest item of the batch, so they are
    # rectangular (batch, length) tensors and every combined row already has the same
    # length: concatenate along the sequence axis, no further padding needed.
    prompt_length = model_inputs["input_ids"].shape[1]
    input_ids = torch.cat((model_inputs["input_ids"], completion_tokens["input_ids"]), dim=1)
    attention_mask = torch.cat(
        (model_inputs["attention_mask"], completion_tokens["attention_mask"]), dim=1
    )

    labels = input_ids.clone()
    # Never train on the prompt ...
    labels[:, :prompt_length] = -100
    # ... nor on padding (attention mask 0), which includes the padded tail of
    # shorter completions.
    labels[attention_mask == 0] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "pixel_values": model_inputs["pixel_values"],
        "aspect_ratio_ids": model_inputs["aspect_ratio_ids"],
        "aspect_ratio_mask": model_inputs['aspect_ratio_mask'],
    }

# Tokenize in batches of 4 and drop the original columns, so only the tensors
# returned by tokenize_batched remain.
tokenized_dataset = dataset.map(
    tokenize_batched,
    batched=True,
    batch_size=4,
    remove_columns=dataset.column_names
)

# LoRA Finetune the model

In [None]:
import wandb
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
from transformers.data.data_collator import default_data_collator

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 1e-4
num_epochs = 1
batch_size = 1
# Number of optimizer steps to run per epoch. 1 is a smoke test of the whole
# forward/backward/step pipeline; set to None to train on the full dataloader.
max_steps_per_epoch = 1

optimizer = torch.optim.AdamW(finetuned_model.parameters(), lr=lr)
train_dataloader = DataLoader(
    tokenized_dataset, 
    shuffle=True, 
    collate_fn=default_data_collator, 
    batch_size=batch_size, 
    pin_memory=True if device == "cuda" else False,
    num_workers=2
)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

print(f"Training on device: {device}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB" if device == "cuda" else "Using CPU")
config = {
    "lr": lr,
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "max_steps_per_epoch": max_steps_per_epoch,
    "device": device,
    "lr_scheduler": "linear_schedule_with_warmup",
    "optimizer": "AdamW"
}
with wandb.init(config=config) as run:
    run.watch(finetuned_model)
    finetuned_model.train()
    for epoch in range(num_epochs):
        for step, batch in enumerate(tqdm(train_dataloader)):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = finetuned_model(**batch)
            print("forward pass done")
            loss = outputs.loss
            print(f"loss: {loss}")
            loss.backward()
            print("back pass done")
            optimizer.step()
            print("optimizer step done")
            lr_scheduler.step()
            optimizer.zero_grad()
            print(f"{step=}: {loss=}")
            # Stop early when a step limit is configured (see max_steps_per_epoch).
            if max_steps_per_epoch is not None and step + 1 >= max_steps_per_epoch:
                break

# Save Model

In [None]:
import time

# Save under a sub-folder whose name ends in the current Unix timestamp, so each
# run is written to its own folder.
time_now = time.time()
finetuned_model.save_pretrained(os.path.join(finetuned_model_path, f"receiptiq_model_{time_now}"))