In [1]:
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
import json


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 1️⃣ Model & Processor
model_name = "mychen76/invoice-and-receipts_donut_v1"
processor = DonutProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)

The image processor of type `DonutImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 
Loading weights: 100%|██████████| 483/483 [00:00<00:00, 2313.23it/s, Materializing param=encoder.encoder.layers.3.blocks.1.output.dense.weight]                         


In [3]:
# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0): DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )

In [4]:
def donut_to_json(image: Image.Image, task_prompt: str = ""):
    pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values.to(device)

    # ✅ Make sure decoder_input_ids is NOT empty
    if task_prompt and task_prompt.strip():
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)
    else:
        start_id = model.config.decoder_start_token_id
        if start_id is None:
            # fallback to BOS token if decoder_start_token_id isn't set
            start_id = processor.tokenizer.bos_token_id
        decoder_input_ids = torch.tensor([[start_id]], device=device)

    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        num_beams=1,
        use_cache=True,
    )

    seq = processor.batch_decode(outputs, skip_special_tokens=False)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "").strip()

    data = processor.token2json(seq)
    return data, json.dumps(data, ensure_ascii=False, indent=2)

In [5]:
img = Image.open("./data/sample_hard.jpg")
data, json_text = donut_to_json(img, task_prompt="")  # or try "<s_receipt>" if your model expects it
print(json_text)

{
  "store_name": "Walmart",
  "store_addr": "PLNOTX75024",
  "telephone": "(972)731-9576",
  "date": "07/29/14 07/29/14",
  "time": "13:57:55 13:57:53",
  "subtotal": "90.32",
  "tax": "6.16",
  "total": "90.32",
  "ignore": "",
  "tips": "",
  "line_items": [
    {
      "item_key": "",
      "item_name": "COM3PCSET",
      "item_value": "9.44",
      "item_quantity": "1"
    },
    {
      "item_key": "",
      "item_name": "COM3PKB",
      "item_value": "7.24",
      "item_quantity": "1"
    },
    {
      "item_key": "",
      "item_name": "COMBOS",
      "item_value": "7.24",
      "item_quantity": "1"
    },
    {
      "item_key": "",
      "item_name": "DRSHITHS",
      "item_value": "7.97",
      "item_quantity": "1"
    },
    {
      "item_key": "",
      "item_name": "COM2PKHAIR",
      "item_value": "4.97",
      "item_quantity": "1"
    },
    {
      "item_key": "",
      "item_name": "BABYWIPS",
      "item_value": "1.97",
      "item_quantity": "1"
    },
    {
      