Donut is an OCR-free Visual Document Understanding (VDU) model, designed primarily for building solutions around digital documents. It solves the task in a single step without any OCR engine, unlike approaches such as LayoutLMv2, which take OCR results as input.

Donut performs well on document classification, DocVQA, and information extraction. Compared with OCR-based pipelines, it is faster, more scalable (because it is built on the Transformer architecture), and more cost-effective (because it does not require an OCR engine).

*Donut = Visual Transformer Encoder + Textual Transformer Decoder*

![image.png](attachment:2181838f-ba0e-427f-8253-01ba296c1ea6.png)

[Image Source](https://arxiv.org/pdf/2111.15664.pdf)


*This notebook uses code samples from the [official Hugging Face documentation](https://huggingface.co/docs/transformers/model_doc/donut). Its objective is to inspect the values of variables and parameters at each stage and to add descriptive comments and analysis.*

In [None]:
# Install transformers together with its sentencepiece extra
!pip install -qU transformers[sentencepiece]

import re

from transformers import DonutProcessor, VisionEncoderDecoderModel

In [None]:
import torch

# Checking if GPU accelerator is accessible
device = "cuda" if torch.cuda.is_available() else "cpu"
device

# Donut Input 
The Donut model takes two inputs: an image and a textual prompt.

### Prompt: 

The textual prompt is tokenized by the Donut processor and passed to the decoder (as `decoder_input_ids` in the code below), where it conditions the output the model generates. The resulting output can then be used for downstream tasks such as document classification or text summarization.

Because the model output depends on the input prompt, prompts can also be chained, with the output of one step feeding into the prompt of the next.

Donut uses a BART-based model as its textual decoder, and this decoder generates its output conditioned on a prompt. The prompt can be an instruction, a partial sentence, or a specific context that the model uses to generate an output. In this notebook, each task starts from a task-specific start token, such as `<s_rvlcdip>` for classification.
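The task prompts used later in this notebook can be collected in one place. The following pure-Python sketch uses a hypothetical `TASK_PROMPTS` dictionary and `build_prompt` helper (our own names, not part of transformers) to build the string that is then passed to the tokenizer: the classification and information-extraction prompts are just a task start token, while the DocVQA prompt embeds the user's question.

```python
from typing import Optional

# Hypothetical mapping from a short task name to the prompt strings used in this notebook
TASK_PROMPTS = {
    "classification": "<s_rvlcdip>",
    "docvqa": "<s_docvqa><s_question>{user_input}</s_question><s_answer>",
    "information_extraction": "<s_cord-v2>",
}


def build_prompt(task: str, question: Optional[str] = None) -> str:
    """Return the decoder prompt for a task, inserting the question when the template needs one."""
    template = TASK_PROMPTS[task]
    if "{user_input}" in template:
        if question is None:
            raise ValueError(f"task {task!r} requires a question")
        return template.replace("{user_input}", question)
    # Classification and information extraction only need the task start token
    return template


print(build_prompt("docvqa", "When is the exhibits open?"))
```

For prompt chaining, the decoded result of one call (for example, the predicted document class) could be used to choose the `task` or `question` for the next call; that control flow is left to your own code.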

### Image:

The image is preprocessed by the Donut processor, which resizes and normalizes it.
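A rough NumPy sketch of this kind of preprocessing follows. It is a simplified stand-in, not the Donut image processor: it converts a single-channel (grayscale) image to RGB by repeating the channel, resizes with nearest-neighbour sampling to a target size you choose, and normalizes with mean 0.5 and standard deviation 0.5 (an assumption about the processor's defaults), which maps pixel values from [0, 255] to [-1, 1]. The real processor also preserves the aspect ratio (thumbnail resizing plus padding), and its target size depends on the checkpoint; the function name `preprocess_sketch` is hypothetical.

```python
import numpy as np


def preprocess_sketch(image: np.ndarray, height: int, width: int) -> np.ndarray:
    """Hypothetical, simplified Donut-style preprocessing (not the real processor)."""
    if image.ndim == 2:
        # Grayscale (one channel) -> RGB by repeating the channel three times
        image = np.repeat(image[:, :, None], 3, axis=2)
    # Nearest-neighbour resize by sampling row and column indices
    rows = np.arange(height) * image.shape[0] // height
    cols = np.arange(width) * image.shape[1] // width
    resized = image[rows][:, cols]
    # Scale to [0, 1], then normalize with mean=0.5 and std=0.5 -> [-1, 1]
    normalized = (resized.astype(np.float32) / 255.0 - 0.5) / 0.5
    # HWC -> CHW, then add a batch dimension: (1, 3, height, width)
    return normalized.transpose(2, 0, 1)[None, ...]


gray_page = np.full((100, 80), 255, dtype=np.uint8)  # a blank white single-channel "page"
print(preprocess_sketch(gray_page, 64, 48).shape)
```

The resulting `(batch, channels, height, width)` layout is the same layout as the `pixel_values` tensors inspected later in this notebook, although the actual sizes there come from the processor's configuration.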

# Testing Donut
1. Document Classification
2. DocVQA
3. Information Extraction


### 1. Document Classification

In [None]:
# Loading the pretrained model and processor for classification task
classification_model_ckpt = "naver-clova-ix/donut-base-finetuned-rvlcdip"

classification_model = VisionEncoderDecoderModel.from_pretrained(classification_model_ckpt)
classification_processor = DonutProcessor.from_pretrained(classification_model_ckpt)  # restart the session if it is causing any error
# Move the model to the same device as the inputs passed to generate()
classification_model.to(device)

In [None]:
!pip install -qU datasets
from datasets import load_dataset 

classification_dataset = load_dataset("nielsr/rvl-cdip-demo", split="train")
classification_dataset

In [None]:
classification_image = classification_dataset[1]['image']
classification_image

In [None]:
classification_image.height, classification_image.width

In [None]:
classification_image.getbands()  # original image has only one channel

The Donut processor handles all image preprocessing, such as resizing and normalization. The output of the next cell shows that the image has been transformed into the tensor shape the model expects.

In [None]:
# The Donut processor expects an RGB image, so convert the single-channel image first.
classification_pixel_values = classification_processor(classification_image.convert("RGB"), return_tensors="pt").pixel_values
classification_pixel_values.size()

In [None]:
# Prepare decoder inputs
classification_task_prompt = "<s_rvlcdip>"
classification_decoder_input_ids = classification_processor.tokenizer(classification_task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
classification_decoder_input_ids

In [None]:
classification_processor.tokenizer.convert_ids_to_tokens(classification_decoder_input_ids[0])

In [None]:
# Document classification using the pretrained model
classification_outputs = classification_model.generate(
    classification_pixel_values.to(device),
    decoder_input_ids=classification_decoder_input_ids.to(device),
    max_length=classification_model.decoder.config.max_position_embeddings,
    pad_token_id=classification_processor.tokenizer.pad_token_id,
    eos_token_id=classification_processor.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[classification_processor.tokenizer.unk_token_id]],  # never generate the unknown token
    return_dict_in_generate=True
)

# Decode the generated IDs and convert the tag sequence into JSON to get the predicted class
classification_sequence = classification_processor.batch_decode(classification_outputs.sequences)[0]
print(classification_processor.token2json(classification_sequence))

In [None]:
classification_sequence

In [None]:
# Getting label names from dataset information
ground_truth_names = classification_dataset.info.features['label'].names

ground_truth = classification_dataset[1]['label']
ground_truth_names[ground_truth]

The prediction matches the ground truth: the model has correctly identified the class of this image.

### 2. DocVQA

In [None]:
# Loading the pretrained model and processor for the DocVQA task
vqa_model_ckpt = "naver-clova-ix/donut-base-finetuned-docvqa"

vqa_model = VisionEncoderDecoderModel.from_pretrained(vqa_model_ckpt)
vqa_processor = DonutProcessor.from_pretrained(vqa_model_ckpt)  # restart the session if it is causing any error
# Move the model to the same device as the inputs passed to generate()
vqa_model.to(device)

In [None]:
vqa_dataset = load_dataset("hf-internal-testing/example-documents", split="test")
vqa_dataset

In [None]:
vqa_image = vqa_dataset[0]['image']
vqa_image

In [None]:
# The Donut processor expects an RGB image, so convert the image before feeding it in.
vqa_pixel_values = vqa_processor(vqa_image.convert("RGB"), return_tensors="pt").pixel_values
vqa_pixel_values.size()

In [None]:
vqa_task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
question = "When is the exhibits open?"
vqa_prompt = vqa_task_prompt.replace("{user_input}", question)
vqa_decoder_input_ids = vqa_processor.tokenizer(vqa_prompt, add_special_tokens=False, return_tensors="pt").input_ids
vqa_decoder_input_ids

In [None]:
vqa_outputs = vqa_model.generate(
    vqa_pixel_values.to(device),
    decoder_input_ids=vqa_decoder_input_ids.to(device),
    max_length=vqa_model.decoder.config.max_position_embeddings,
    pad_token_id=vqa_processor.tokenizer.pad_token_id,
    eos_token_id=vqa_processor.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[vqa_processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

vqa_sequence = vqa_processor.batch_decode(vqa_outputs.sequences)[0]
vqa_sequence = vqa_sequence.replace(vqa_processor.tokenizer.eos_token, "").replace(vqa_processor.tokenizer.pad_token, "")
vqa_sequence = re.sub(r"<.*?>", "", vqa_sequence, count=1).strip()  # remove first task start token
print(vqa_processor.token2json(vqa_sequence))
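To make the post-processing steps above concrete, here is a minimal pure-Python sketch that applies the same string operations to a hand-written sequence. The sequence, including the placeholder answer `ANSWER`, is hypothetical and is not real model output. It also assumes the tokenizer's EOS and pad tokens are `</s>` and `<pad>`; the notebook code reads them from `vqa_processor.tokenizer` instead of hard-coding them.

```python
import re

# Hypothetical raw decoder output, written by hand for illustration (not real model output)
raw = "<s_docvqa><s_question>When is the exhibits open?</s_question><s_answer>ANSWER</s_answer></s><pad><pad>"

# Assumed special tokens; the notebook reads these from the tokenizer
eos_token, pad_token = "</s>", "<pad>"

# Step 1: drop the EOS and padding tokens
cleaned = raw.replace(eos_token, "").replace(pad_token, "")
# Step 2: count=1 removes only the first tag, i.e. the task start token "<s_docvqa>"
cleaned = re.sub(r"<.*?>", "", cleaned, count=1).strip()
print(cleaned)
```

What remains is a sequence of `<s_question>` and `<s_answer>` fields, which is the form that `token2json` turns into a dictionary.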

### 3. Information Extraction

In [None]:
# Loading the pretrained model and processor for Information Extraction task
ie_model_ckpt = "naver-clova-ix/donut-base-finetuned-cord-v2"

ie_model = VisionEncoderDecoderModel.from_pretrained(ie_model_ckpt)
ie_processor = DonutProcessor.from_pretrained(ie_model_ckpt)  # restart the session if it is causing any error
# Move the model to the same device as the inputs passed to generate()
ie_model.to(device)

In [None]:
ie_dataset = load_dataset("hf-internal-testing/example-documents", split='test')
ie_dataset

In [None]:
ie_image = ie_dataset[2]['image']
ie_image

In [None]:
# The Donut processor expects an RGB image, so convert the image before feeding it in.
ie_pixel_values = ie_processor(ie_image.convert("RGB"), return_tensors="pt").pixel_values
ie_pixel_values.size()

In [None]:
# Prepare decoder inputs
ie_task_prompt = "<s_cord-v2>"
ie_decoder_input_ids = ie_processor.tokenizer(ie_task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
ie_decoder_input_ids

In [None]:
ie_processor.tokenizer.convert_ids_to_tokens(ie_decoder_input_ids[0])

In [None]:
# Information extraction using the pretrained model
ie_outputs = ie_model.generate(
    ie_pixel_values.to(device),
    decoder_input_ids=ie_decoder_input_ids.to(device),
    max_length=ie_model.decoder.config.max_position_embeddings,
    pad_token_id=ie_processor.tokenizer.pad_token_id,
    eos_token_id=ie_processor.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[ie_processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

ie_sequence = ie_processor.batch_decode(ie_outputs.sequences)[0]
ie_sequence = ie_sequence.replace(ie_processor.tokenizer.eos_token, "").replace(ie_processor.tokenizer.pad_token, "")
ie_sequence = re.sub(r"<.*?>", "", ie_sequence, count=1).strip()  # remove first task start token
print(ie_processor.token2json(ie_sequence))
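`token2json` turns Donut's tag sequence into a (possibly nested) dictionary. As a rough illustration of the idea, not the library's actual implementation, here is a simplified, hypothetical parser `simple_token2json`: each `<s_key>...</s_key>` pair becomes a dictionary entry, nested tags are parsed recursively, `<sep/>` separates repeated items into a list, a leaf such as `<advertisement/>` becomes the plain string `advertisement`, and a start tag with no closing tag (such as a task token) is dropped. It assumes at most one level of `<sep/>` per field and no field nested inside a field of the same name.

```python
import re
from typing import Any, Dict


def simple_token2json(seq: str) -> Dict[str, Any]:
    """Hypothetical, simplified parser for Donut-style tag sequences (not the transformers code)."""
    output: Dict[str, Any] = {}
    while True:
        start = re.search(r"<s_(.*?)>", seq)
        if start is None:
            break  # no more fields; leftovers such as "</s>" are ignored
        key = start.group(1)
        end_tag = f"</s_{key}>"
        end_idx = seq.find(end_tag, start.end())
        if end_idx == -1:
            # Start tag without a closing tag (e.g. the task token "<s_rvlcdip>"): drop it
            seq = seq[:start.start()] + seq[start.end():]
            continue
        content = seq[start.end():end_idx].strip()
        if "<s_" in content:
            # Nested fields: parse each <sep/>-separated group recursively
            groups = [simple_token2json(g) for g in content.split("<sep/>")]
            output[key] = groups[0] if len(groups) == 1 else groups
        else:
            # Leaf values: split on <sep/> and unwrap category tokens like "<advertisement/>"
            values = [re.sub(r"^<(.+)/>$", r"\1", v.strip())
                      for v in content.split("<sep/>") if v.strip()]
            output[key] = values[0] if len(values) == 1 else values
        seq = seq[end_idx + len(end_tag):]
    return output


# Hand-written, receipt-like example (hypothetical field names and values)
print(simple_token2json(
    "<s_menu><s_nm>Tea</s_nm><s_price>3</s_price><sep/><s_nm>Cake</s_nm><s_price>5</s_price></s_menu>"
))
```

In this notebook, `ie_processor.token2json(ie_sequence)` is the call that actually produces the JSON; the sketch only mirrors its general idea.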

# References & other reading resources
- Research paper - https://arxiv.org/pdf/2111.15664.pdf
- Dataset for testing - https://huggingface.co/datasets/nielsr/rvl-cdip-demo
- Hugging Face official documentation - https://huggingface.co/docs/transformers/model_doc/donut
- Prompt Chaining - https://docs.anthropic.com/claude/docs/prompt-chaining
- Text generation prompts for BART - https://flowgpt.gitbook.io/prompt-engineering-guide/group-4/prominent-prompt-engineering-models/t5-and-bart-models
- BART Architecture - https://www.projectpro.io/article/transformers-bart-model-explained/553

*Want to read the annotated research paper? Visit https://github.com/yesdeepakmittal/model-training-scripts*
