In [1]:
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

In [27]:
model_id = "mychen76/invoice-and-receipts_donut_v1"
processor = DonutProcessor.from_pretrained(model_id)
model = VisionEncoderDecoderModel.from_pretrained(model_id)
model.eval()

Loading weights:   0%|          | 0/483 [00:00<?, ?it/s]

VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0): DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )

In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0): DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )

In [29]:
image_path = "../data/model_data/1.png"
image = Image.open(image_path).convert("RGB")

In [30]:
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
task_prompt = "<s_invoice>"

In [31]:
output_ids = model.generate(
    pixel_values,
    decoder_start_token_id=processor.tokenizer.cls_token_id,
    max_length=1024,
    num_beams=5
)

In [32]:
# output_ids

In [33]:
predicted = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
print(predicted)

</s_store_name></s_store_name><s_store_addr> 987RenewableLn EcoCity,FL33134</s_store_addr><s_telephone></s_telephone><s_date> 03/12/2025</s_date><s_time></s_time><s_subtotal> 4,900.00</s_subtotal><s_tax> 196.00</s_tax><s_total> 5,096.00</s_total><s_ignore></s_ignore><s_tips></s_tips><s_line_items><s_item_key></s_item_key><s_item_name> UnitPriceOgyAmount</s_item_name><s_item_value> 4,000.00</s_item_value><s_item_quantity> 1</s_item_quantity><sep/><s_item_key></s_item_key><s_item_name> SolarPanels</s_item_name><s_item_value></s_item_value><s_item_quantity> 1</s_item_quantity><sep/><s_item_key></s_item_key><s_item_name></s_item_name><s_item_value></s_item_value><s_item_quantity> 1</s_item_quantity><sep/><s_item_key></s_item_key><s_item_name></s_item_name><s_item_value></s_item_value><s_item_quantity> 1</s_item_quantity><sep/><s_item_key> Description</s_item_key><s_item_name></s_item_name><s_item_value></s_item_value><s_item_quantity></s_item_quantity><sep/><s_item_key> Description</s_item

In [34]:
try:
    result_json = processor.token2json(predicted)
    print("JSON PARSED OUTPUT:")
    print(result_json)
except Exception as e:
    print("Could not convert to JSON:", e)

JSON PARSED OUTPUT:
{'store_addr': '987RenewableLn EcoCity,FL33134', 'telephone': '', 'date': '03/12/2025', 'time': '', 'subtotal': '4,900.00', 'tax': '196.00', 'total': '5,096.00', 'ignore': '', 'tips': '', 'line_items': [{'item_key': '', 'item_name': 'UnitPriceOgyAmount', 'item_value': '4,000.00', 'item_quantity': '1'}, {'item_key': '', 'item_name': 'SolarPanels', 'item_value': '', 'item_quantity': '1'}, {'item_key': '', 'item_name': '', 'item_value': '', 'item_quantity': '1'}, {'item_key': '', 'item_name': '', 'item_value': '', 'item_quantity': '1'}, {'item_key': 'Description', 'item_name': '', 'item_value': '', 'item_quantity': ''}, {'item_key': 'Description', 'item_name': '', 'item_value': '', 'item_quantity': ''}, {'item_key': 'Description', 'item_name': '', 'item_value': '', 'item_quantity': ''}, {'item_key': 'Description', 'item_name': '', 'item_value': '', 'item_quantity': ''}, {'item_key': 'Description', 'item_name': '', 'item_value': '', 'item_quantity': ''}, {'item_key': 'D

In [35]:
result_json

{'store_addr': '987RenewableLn EcoCity,FL33134',
 'telephone': '',
 'date': '03/12/2025',
 'time': '',
 'subtotal': '4,900.00',
 'tax': '196.00',
 'total': '5,096.00',
 'ignore': '',
 'tips': '',
 'line_items': [{'item_key': '',
   'item_name': 'UnitPriceOgyAmount',
   'item_value': '4,000.00',
   'item_quantity': '1'},
  {'item_key': '',
   'item_name': 'SolarPanels',
   'item_value': '',
   'item_quantity': '1'},
  {'item_key': '', 'item_name': '', 'item_value': '', 'item_quantity': '1'},
  {'item_key': '', 'item_name': '', 'item_value': '', 'item_quantity': '1'},
  {'item_key': 'Description',
   'item_name': '',
   'item_value': '',
   'item_quantity': ''},
  {'item_key': 'Description',
   'item_name': '',
   'item_value': '',
   'item_quantity': ''},
  {'item_key': 'Description',
   'item_name': '',
   'item_value': '',
   'item_quantity': ''},
  {'item_key': 'Description',
   'item_name': '',
   'item_value': '',
   'item_quantity': ''},
  {'item_key': 'Description',
   'item_name