<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Donut/DocVQA/Batched_generation_with_Donut.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Batched generation with Donut

This notebook shows how to do batched generation with the Donut 🍩 model.

For single-image inference examples, see the Donut documentation: https://huggingface.co/docs/transformers/model_doc/donut#inference-examples

## Set-up environment

In [None]:
%load_ext autoreload
%autoreload 2

## Load model and processor

In [None]:
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
# important: pad on the left for batched inference, so that every prompt
# ends right where generation starts (no pad tokens between prompt and answer)
processor.tokenizer.padding_side = 'left'
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
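
To see why the padding side matters, here is a minimal sketch in plain Python. It does not use the tokenizer: `pad_batch` is a hypothetical helper, the token ids are made up, and `0` merely stands in for the pad id.

```python
def pad_batch(sequences, pad_id, side="left"):
    """Pad lists of token ids to a common length on the given side."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    max_len = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        padding = [pad_id] * (max_len - len(seq))
        padded.append(padding + seq if side == "left" else seq + padding)
    return padded


# made-up token ids for two prompts of different length
prompts = [[11, 12, 13], [21, 22, 23, 24, 25]]

left = pad_batch(prompts, pad_id=0, side="left")
right = pad_batch(prompts, pad_id=0, side="right")

print(left)   # [[0, 0, 11, 12, 13], [21, 22, 23, 24, 25]]
print(right)  # [[11, 12, 13, 0, 0], [21, 22, 23, 24, 25]]
```

With left padding, every row ends on a real prompt token, so the decoder continues directly from the prompt. With right padding, the shorter prompt would be followed by pad tokens before the first generated token.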

## Load dataset

In [None]:
from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/example-documents", split="test")
dataset

In [None]:
image = dataset[1]["image"]
width, height = image.size
# display a downscaled copy of the document
image.resize((int(0.3 * width), int(0.3 * height)))

## Run batched inference

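As can be seen below, the `generate()` method supports batched generation: it takes a batch of images together with one prompt per image. Processing several documents per call is typically faster than generating one example at a time, especially on a GPU.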
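
When there are more documents than fit in a single call, the inputs can be split into fixed-size chunks first. The sketch below is only an illustration: `iter_batches` is a hypothetical helper, not part of `transformers`, and the question strings are placeholders.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# placeholder questions; in practice, images and questions are chunked with the same indices
all_questions = ["q0", "q1", "q2", "q3", "q4"]
batches = list(iter_batches(all_questions, batch_size=2))

print(batches)  # [['q0', 'q1'], ['q2', 'q3'], ['q4']]
```

Each chunk would then go through the same processor and `generate()` call as the two-example batch in this notebook.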

In [None]:
import re
import torch

# move model to GPU if it's available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# prepare encoder inputs: slice the rows first so only the first two images are decoded
pixel_values = processor(images=dataset[:2]["image"], return_tensors="pt").pixel_values
batch_size = pixel_values.shape[0]

# prepare decoder inputs
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
questions = ["when is the coffee break?", "which year was the report created?"]
prompts = [task_prompt.replace("{user_input}", question) for question in questions]
decoder_input_ids = processor.tokenizer(prompts, add_special_tokens=False, padding=True, return_tensors="pt").input_ids

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=1,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

sequences = processor.batch_decode(outputs.sequences)

for seq in sequences:
    # strip the end-of-sequence token and the (left) padding tokens
    sequence = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    # convert the remaining tagged string into a dictionary
    print(processor.token2json(sequence))
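
The clean-up steps above can be tried on a plain string, without the model. In the sketch below, the raw string is hand-written for illustration and is not actual model output. The token strings `<pad>` and `</s>` are assumptions, and `extract_fields` is a simplified, hypothetical stand-in for `processor.token2json` that only handles flat `<s_key>value</s_key>` pairs.

```python
import re

# hand-written example of a decoded, left-padded sequence (not real model output)
raw = (
    "<pad><pad><s_docvqa><s_question> when is the coffee break?</s_question>"
    "<s_answer> 10:30 a.m.</s_answer></s>"
)


def clean_sequence(text, eos_token="</s>", pad_token="<pad>"):
    """Drop EOS and pad tokens, then remove the first tag (the task start token)."""
    text = text.replace(eos_token, "").replace(pad_token, "")
    return re.sub(r"<.*?>", "", text, count=1).strip()


def extract_fields(text):
    """Collect flat <s_key>value</s_key> pairs into a dictionary."""
    pairs = re.findall(r"<s_(\w+)>(.*?)</s_\1>", text)
    return {key: value.strip() for key, value in pairs}


cleaned = clean_sequence(raw)
fields = extract_fields(cleaned)
print(fields)  # {'question': 'when is the coffee break?', 'answer': '10:30 a.m.'}
```

Replacing `</s>` does not touch closing tags such as `</s_question>`, because the character after `</s` there is `_` rather than `>`.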