# Fine-tune Donut 🍩 on DocVQA

In this notebook, we'll fine-tune Donut (which is an instance of [`VisionEncoderDecoderModel`](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder)) on a DocVQA dataset, which is a dataset consisting of (document, question, answer(s)) triplets. This way, the model will learn to look at an image, and answer a question related to the document. Pretty cool, isn't it?

## Set-up environment

First, let's install the relevant libraries:
* 🤗 Transformers, for the model
* 🤗 Datasets, for loading + processing the data
* PyTorch Lightning, for training the model
* Weights and Biases, for logging metrics during training
* Sentencepiece, used for tokenization.

In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git

In [2]:
!pip install -q datasets sentencepiece

In [3]:
!pip install -q pytorch-lightning wandb

## Load dataset

Next, let's load the dataset from the [hub](https://huggingface.co/datasets/naver-clova-ix/cord-v2). We've prepared a minimal dataset for DocVQA; the notebook for that can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/DocVQA/Creating_a_toy_DocVQA_dataset_for_Donut.ipynb).

Important here is that we've added a "ground_truth" column, containing the ground truth JSON which the model will learn to generate.
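
To make the shape of that column concrete, here is a minimal sketch of parsing one `ground_truth` string. The sample value is hypothetical (its field names are borrowed from the decoded outputs shown later in this notebook), and `extract_gt_jsons` is an illustrative helper, not part of any library. The `gt_parse` and `gt_parses` keys are the ones the dataset class further down looks for.

```python
import json

# Hypothetical ground_truth string; real rows may contain different fields.
ground_truth = json.dumps(
    {"gt_parse": {"ner": [{"Fish": "Raie"}, {"ZoneCode": "27"}]}},
    ensure_ascii=False,
)

def extract_gt_jsons(raw: str) -> list:
    """Return the list of target JSON objects stored in a ground_truth string."""
    parsed = json.loads(raw)
    if "gt_parses" in parsed:      # several valid targets (e.g. DocVQA)
        return list(parsed["gt_parses"])
    return [parsed["gt_parse"]]    # a single target

gt_jsons = extract_gt_jsons(ground_truth)
print(gt_jsons)
```

Wrapping the single-target case in a list keeps the downstream code uniform: it can always pick one target out of a list, whether one or several are available.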

In [6]:
from huggingface_hub import login
# grab a token from https://huggingface.co/settings/tokens
tkn = input("Huggingface access_token [https://huggingface.co/settings/tokens]")
login(tkn)



Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid.
Your token has been saved to /home/nathang/.huggingface/token
Login successful


In [8]:
from datasets import load_dataset

dataset = load_dataset("hublot/fish-label")

Using custom data configuration hublot--fish-label-54c937f822e7c2af


Downloading and preparing dataset None/None to /home/nathang/.cache/huggingface/datasets/hublot___parquet/hublot--fish-label-54c937f822e7c2af/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...



Downloading data: 100%|██████████| 432M/432M [00:12<00:00, 34.7MB/s]


Dataset parquet downloaded and prepared to /home/nathang/.cache/huggingface/datasets/hublot___parquet/hublot--fish-label-54c937f822e7c2af/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:00<00:00, 36.36it/s]


As can be seen below, the dataset contains training, validation and test splits, and each example consists of an image and a "ground_truth" string.

In [9]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'ground_truth'],
        num_rows: 1929
    })
    validation: Dataset({
        features: ['image', 'ground_truth'],
        num_rows: 65
    })
    test: Dataset({
        features: ['image', 'ground_truth'],
        num_rows: 65
    })
})

## Load model and processor

Next, we load the model (which is an instance of [VisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder)) and the processor, which is the object that can be used to prepare inputs for the model.

In [None]:
from transformers import VisionEncoderDecoderConfig

max_length = 256
image_size = [960, 640]

# update image_size of the encoder
# during pre-training, a larger image size was used
config = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base")
config.encoder.image_size = image_size # (height, width)
# update max_length of the decoder (for generation)
config.decoder.max_length = max_length
# TODO we should actually update max_position_embeddings and interpolate the pre-trained ones:
# https://github.com/clovaai/donut/blob/0acc65a85d140852b8d9928565f0f6b2d98dc088/donut/model.py#L602

In [None]:
from transformers import DonutProcessor, VisionEncoderDecoderModel, BartConfig

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base", config=config)

Downloading:   0%|          | 0.00/362 [00:00<?, ?B/s]

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


Downloading:   0%|          | 0.00/518 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.30M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.01M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/71.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/355 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/809M [00:00<?, ?B/s]

## Add special tokens

For DocVQA, we add special tokens for \<yes/> and \<no/>, to make sure that the model (actually the decoder) learns embedding vectors for those explicitly. The list of additional tokens is left empty here, so the call below adds nothing.

In [None]:
from typing import List

def add_tokens(list_of_tokens: List[str]):
    """
    Add tokens to tokenizer and resize the token embeddings
    """
    newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
    if newly_added_num > 0:
        model.decoder.resize_token_embeddings(len(processor.tokenizer))

In [None]:
additional_tokens = []

add_tokens(additional_tokens)

## Create PyTorch dataset

Here we create a regular PyTorch dataset.

The model doesn't directly take the (image, JSON) pairs as input and labels. Rather, we create `pixel_values`, `decoder_input_ids` and `labels`. These are all PyTorch tensors. The `pixel_values` are the input images (resized, padded and normalized), the `decoder_input_ids` are the decoder inputs, and the `labels` are the decoder targets.
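
As a rough illustration of the first step (turning the JSON target into a flat string that can be tokenized), the sketch below mirrors the idea behind the `json2token` method of the dataset class defined further down. It is deliberately simplified: it skips key sorting, the `text_sequence` shortcut and the registration of new tokens with the tokenizer, and the name `json_to_token_string` is hypothetical.

```python
from typing import Any

def json_to_token_string(obj: Any) -> str:
    """Flatten nested JSON into a Donut-style tagged string."""
    if isinstance(obj, dict):
        # each key becomes an opening/closing tag pair around its value
        return "".join(
            f"<s_{key}>{json_to_token_string(value)}</s_{key}>"
            for key, value in obj.items()
        )
    if isinstance(obj, list):
        # list items are joined by a separator token
        return "<sep/>".join(json_to_token_string(item) for item in obj)
    return str(obj)

# Hypothetical target, shaped like the fields seen in this dataset
example = {"ner": [{"Fish": "Raie"}, {"ZoneCode": 27}]}
print(json_to_token_string(example))
```

The resulting string is what gets tokenized into `decoder_input_ids`; the tags act as ordinary tokens once they have been added to the tokenizer's vocabulary.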

We don't want the model to learn to generate the entire prompt, which includes the question; it should only learn to generate the answer. Hence, we set the labels of the prompt tokens to -100. This is also why we create the `decoder_input_ids` explicitly: otherwise, the model would derive them automatically from the `labels` (by prepending the decoder start token ID and replacing -100 entries with padding tokens), and the decoder would see padding where the prompt should be.
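
The masking itself can be sketched on plain Python lists. This is only an illustration with made-up token IDs; the dataset class below does the same thing on PyTorch tensors. The helper name `mask_labels` is hypothetical, and the sketch assumes the prompt end marker occurs at most once in the sequence.

```python
IGNORE_ID = -100  # value ignored by torch.nn.CrossEntropyLoss by default

def mask_labels(input_ids, pad_token_id, prompt_end_token_id, ignore_id=IGNORE_ID):
    """Copy input_ids and hide pad and prompt positions from the loss."""
    labels = [ignore_id if tok == pad_token_id else tok for tok in input_ids]
    if prompt_end_token_id in input_ids:
        prompt_end = input_ids.index(prompt_end_token_id)
    else:
        # assumption: without a prompt marker, only the first token is masked
        prompt_end = 0
    for i in range(prompt_end + 1):
        labels[i] = ignore_id
    return labels

# Made-up token IDs: 10 = prompt token, 11 = prompt end marker, 1 = pad
print(mask_labels([10, 11, 42, 43, 2, 1, 1], pad_token_id=1, prompt_end_token_id=11))
```

The decoder inputs keep every original token, while the labels only carry the positions the loss should be computed on.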



In [None]:
import json
import random
from typing import Any, List, Tuple

import torch
from torch.utils.data import Dataset

added_tokens = []

class DonutDataset(Dataset):
    """
    DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
    Each row, consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt),
    and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string).
    Args:
        dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
        max_length: the max number of tokens for the target sequences
        split: whether to load "train", "validation" or "test" split
        ignore_id: ignore_index for torch.nn.CrossEntropyLoss
        task_start_token: the special token to be fed to the decoder to conduct the target task
        prompt_end_token: the special token at the end of the sequences
        sort_json_key: whether or not to sort the JSON keys
    """

    def __init__(
        self,
        dataset_name_or_path: str,
        max_length: int,
        split: str = "train",
        ignore_id: int = -100,
        task_start_token: str = "<s>",
        prompt_end_token: str = None,
        sort_json_key: bool = True,
    ):
        super().__init__()

        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        self.task_start_token = task_start_token
        self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
        self.sort_json_key = sort_json_key

        self.dataset = load_dataset(dataset_name_or_path, split=self.split)
        self.dataset_length = len(self.dataset)

        self.gt_token_sequences = []
        for sample in self.dataset:
            ground_truth = json.loads(sample["ground_truth"])
            if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
                assert isinstance(ground_truth["gt_parses"], list)
                gt_jsons = ground_truth["gt_parses"]
            else:
                assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
                gt_jsons = [ground_truth["gt_parse"]]

            self.gt_token_sequences.append(
                [
                    self.json2token(
                        gt_json,
                        update_special_tokens_for_json_key=self.split == "train",
                        sort_json_key=self.sort_json_key,
                    )
                    + processor.tokenizer.eos_token
                    for gt_json in gt_jsons  # load json from list of json
                ]
            )

        self.add_tokens([self.task_start_token, self.prompt_end_token])
        self.prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)

    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    if update_special_tokens_for_json_key:
                        self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
                    output += (
                        fr"<s_{k}>"
                        + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                        + fr"</s_{k}>"
                    )
                return output
        elif type(obj) == list:
            return r"<sep/>".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj
    
    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder
        """
        newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            model.decoder.resize_token_embeddings(len(processor.tokenizer))
            added_tokens.extend(list_of_tokens)
    
    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Load image from image_path of given dataset_path and convert into input_tensor and labels
        Convert gt data into input_ids (tokenized string)
        Returns:
            input_tensor : preprocessed image
            input_ids : tokenized gt_data
            labels : masked labels (model doesn't need to predict prompt and pad token)
        """
        sample = self.dataset[idx]

        # input_tensor
        pixel_values = processor(sample["image"].convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values
        input_tensor = pixel_values.squeeze()

        # input_ids
        processed_parse = random.choice(self.gt_token_sequences[idx])  # can be more than one, e.g., DocVQA Task 1
        input_ids = processor.tokenizer(
            processed_parse,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        if self.split == "train":
            labels = input_ids.clone()
            labels[
                labels == processor.tokenizer.pad_token_id
            ] = self.ignore_id  # model doesn't need to predict pad token
            labels[
                : torch.nonzero(labels == self.prompt_end_token_id).sum() + 1
            ] = self.ignore_id  # model doesn't need to predict prompt (for VQA)
            return input_tensor, input_ids, labels
        else:
            prompt_end_index = torch.nonzero(
                input_ids == self.prompt_end_token_id
            ).sum()  # return prompt end index instead of target output labels
            return input_tensor, input_ids, prompt_end_index, processed_parse

In [None]:
dataset

DatasetDict({
    test: Dataset({
        features: ['image', 'ground_truth'],
        num_rows: 65
    })
    train: Dataset({
        features: ['image', 'ground_truth'],
        num_rows: 1929
    })
    validation: Dataset({
        features: ['image', 'ground_truth'],
        num_rows: 65
    })
})

In [None]:
# we update some settings which differ from pretraining; namely the size of the images + no rotation required
# source: https://github.com/clovaai/donut/blob/master/config/train_cord.yaml
processor.feature_extractor.size = image_size[::-1] # should be (width, height)
processor.feature_extractor.do_align_long_axis = False

train_dataset = DonutDataset("hublot/fish-label", max_length=max_length,
                             split="train",
                             sort_json_key=False, # keep the key order stored in the dataset
                             )

# note: the "test" split is used for validation here
val_dataset = DonutDataset("hublot/fish-label", max_length=max_length,
                           split="test",
                           sort_json_key=False, # keep the key order stored in the dataset
                           )



In [None]:
pixel_values, decoder_input_ids, labels = train_dataset[0]

In [None]:
pixel_values.shape

torch.Size([3, 2560, 1920])

In [None]:
print(labels)

tensor([ -100, 57527, 40769, 48942,  6626,  7690, 54630, 11938, 57528, 57522,
        57529, 10179, 50623, 53538,  1153, 57530, 57522, 57531, 34899, 38934,
        50708, 45508, 35816,  7031, 57532, 57522, 57533, 10558, 57534, 57522,
        57535, 34118, 57536, 57522, 57537, 52760,  3827, 50360,  6433, 57538,
        57522, 57539, 13482, 10382, 52476, 46300, 46192,   209, 38397, 57540,
        57526,     2,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 

In [None]:
for decoder_input_id, label in zip(decoder_input_ids.tolist()[:-1], labels.tolist()[1:]):
  if label != -100:
    print(processor.decode([decoder_input_id]), processor.decode([label]))
  else:
    print(processor.decode([decoder_input_id]), label)

<s_ner> <s_Peche>
<s_Peche> P
P ê
ê ché
ché (
( es
es )
) </s_Peche>
</s_Peche> <sep/>
<sep/> <s_Fish>
<s_Fish> Big
Big or
or ne
ne aux
aux </s_Fish>
</s_Fish> <sep/>
<sep/> <s_Latin>
<s_Latin> Lit
Lit tori
tori na
na litt
litt o
o rea
rea </s_Latin>
</s_Latin> <sep/>
<sep/> <s_ZoneCode>
<s_ZoneCode> 27
27 </s_ZoneCode>
</s_ZoneCode> <sep/>
<sep/> <s_SzoneCode>
<s_SzoneCode> VII
VII </s_SzoneCode>
</s_SzoneCode> <sep/>
<sep/> <s_Zone>
<s_Zone> Atlant
Atlant ique
ique Nord
Nord Est
Est </s_Zone>
</s_Zone> <sep/>
<sep/> <s_Szone>
<s_Szone> Man
Man che
che et
et Mer
Mer s
s Cel
Cel tiques
tiques </s_Szone>
</s_Szone> </s_ner>
</s_ner> </s>
</s> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad> -100
<pad

In [None]:
pixel_values, decoder_input_ids, prompt_end_index, answer = val_dataset[0]

In [None]:
pixel_values.shape

torch.Size([3, 2560, 1920])

In [None]:
prompt_end_index

tensor(0)

In [None]:
answer

'<s_ner><s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo marmorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner></s>'
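
A target string like the one above can be turned back into key/value pairs. The sketch below is a simplified, regex-based illustration that only extracts the innermost fields; the helper name `extract_leaf_fields` is hypothetical. For the general nested case, `DonutProcessor` provides a `token2json` method.

```python
import re

def extract_leaf_fields(sequence: str) -> list:
    """Return (key, value) pairs for every innermost <s_key>value</s_key> span."""
    # [^<]* restricts matches to tags that contain plain text only
    return re.findall(r"<s_(\w+)>([^<]*)</s_\1>", sequence)

# Shortened, hypothetical sequence in the same format as the answer above
sample = "<s_ner><s_Fish>Raie</s_Fish><sep/><s_ZoneCode>27</s_ZoneCode></s_ner></s>"
print(extract_leaf_fields(sample))
```

Outer tags such as `<s_ner>` are skipped because their content contains further tags, so only the leaf fields are returned.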

## Create PyTorch DataLoaders

Next, we create corresponding PyTorch DataLoaders.

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=2)

Let's verify a batch:

In [None]:
batch = next(iter(train_dataloader))
pixel_values, decoder_input_ids, labels = batch
print(pixel_values.shape)

torch.Size([1, 3, 2560, 1920])


In [None]:
decoder_input_ids.shape

torch.Size([1, 256])

We can see that the labels of the prompt tokens are set to -100, to make sure the model doesn't learn to generate them. In DocVQA the prompt includes the question, and labels only start from the \<s_answer> decoder input token. In this dataset there is no question, so labels start right after the first decoder input token, \<s_ner>.

In [None]:
for decoder_input_id, label in zip(decoder_input_ids[0].tolist()[:-1][:30], labels[0].tolist()[1:][:30]):
  if label != -100:
    print(processor.decode([decoder_input_id]), processor.decode([label]))
  else:
    print(processor.decode([decoder_input_id]), label)

<s_ner> <s_Fish>
<s_Fish> E
E per
per lan
lan </s_Fish>
</s_Fish> <sep/>
<sep/> <s_Peche>
<s_Peche> P
P ê
ê ché
ché (
( es
es )
) </s_Peche>
</s_Peche> <sep/>
<sep/> <s_ZoneCode>
<s_ZoneCode> 27
27 </s_ZoneCode>
</s_ZoneCode> <sep/>
<sep/> <s_SzoneCode>
<s_SzoneCode> VIII
VIII </s_SzoneCode>
</s_SzoneCode> <sep/>
<sep/> <s_Latin>
<s_Latin> Os
Os mer
mer us
us e
e per
per lan


## Define LightningModule

We'll fine-tune the model using [PyTorch Lightning](https://www.pytorchlightning.ai/) here, but note that you can of course also just fine-tune with regular PyTorch, HuggingFace [Accelerate](https://github.com/huggingface/accelerate), the HuggingFace [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer), etc.

PyTorch Lightning conveniently handles things like device placement, mixed precision and logging for you.

In [None]:
from pathlib import Path
import re
from nltk import edit_distance
import numpy as np
import math

from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only


class DonutModelPLModule(pl.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

    def training_step(self, batch, batch_idx):
        pixel_values, decoder_input_ids, labels = batch
        
        outputs = self.model(pixel_values,
                             decoder_input_ids=decoder_input_ids[:, :-1],
                             labels=labels[:, 1:])
        loss = outputs.loss
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values, decoder_input_ids, prompt_end_idxs, answers = batch
        decoder_prompts = pad_sequence(
            [input_id[: end_idx + 1] for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs)],
            batch_first=True,
        )
        
        outputs = self.model.generate(pixel_values,
                                   decoder_input_ids=decoder_prompts,
                                   max_length=max_length,
                                   early_stopping=True,
                                   pad_token_id=self.processor.tokenizer.pad_token_id,
                                   eos_token_id=self.processor.tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=1,
                                   bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)
    
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = list()
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

    def validation_epoch_end(self, validation_step_outputs):
        # I set this to 1 manually
        # (previously set to len(self.config.dataset_name_or_paths))
        num_of_loaders = 1
        if num_of_loaders == 1:
            validation_step_outputs = [validation_step_outputs]
        assert len(validation_step_outputs) == num_of_loaders
        cnt = [0] * num_of_loaders
        total_metric = [0] * num_of_loaders
        val_metric = [0] * num_of_loaders
        for i, results in enumerate(validation_step_outputs):
            for scores in results:
                cnt[i] += len(scores)
                total_metric[i] += np.sum(scores)
            val_metric[i] = total_metric[i] / cnt[i]
            val_metric_name = f"val_metric_{i}th_dataset"
            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
        self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)

    def configure_optimizers(self):
        # TODO add scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
    
        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader
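
The score computed in `validation_step` is a normalized edit distance: the number of character edits needed to turn the prediction into the answer, divided by the length of the longer string. The sketch below reimplements it in pure Python as a stand-in for `nltk.edit_distance`, purely for illustration. The function names are hypothetical, and the guard against two empty strings is an addition that the module above does not have.

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance (insert, delete, substitute)."""
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            cost = 0 if char_a == char_b else 1
            current.append(min(
                previous[j] + 1,         # deletion
                current[j - 1] + 1,      # insertion
                previous[j - 1] + cost,  # substitution or match
            ))
        previous = current
    return previous[-1]

def normalized_edit_distance(pred: str, answer: str) -> float:
    """Edit distance divided by the longer length; 0.0 means an exact match."""
    longest = max(len(pred), len(answer))
    return levenshtein(pred, answer) / longest if longest else 0.0

print(normalized_edit_distance("Torredo morhorata", "Torpedo marmorata"))
```

Lower is better. As a sanity check, the first value logged in the training output further down (about 0.00955) is consistent with 3 character edits over a 314-character answer.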

Next, we instantiate the module:

In [None]:
config = {"max_epochs":4,
          "val_check_interval":0.5, # how many times we want to validate during an epoch
          "check_val_every_n_epoch":1,
          "gradient_clip_val":1.0,
          "num_training_samples_per_epoch": 2000,
          "lr":3e-5,
          "train_batch_sizes": [4],
          "val_batch_sizes": [1],
          # "seed":2022,
          "num_nodes": 1,
          "warmup_steps": 1425, # 800/8*30/10, 10%
          "result_path": "./result",
          "verbose": True,
          }

model_module = DonutModelPLModule(config, processor, model)
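
The config carries a `warmup_steps` entry, but the module above does not use it yet (the scheduler is still a TODO in `configure_optimizers`). As a hedged sketch of one possible way to use it, the hypothetical function below computes a linear warmup multiplier for the learning rate. It is an assumption about how warmup could be wired in, not what this notebook does.

```python
def linear_warmup_factor(step: int, warmup_steps: int) -> float:
    """Learning-rate multiplier: ramps linearly from 0 to 1, then stays at 1."""
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, step / warmup_steps)

# With warmup_steps=1425 (the value in the config), the multiplier at a few steps:
for step in (0, 285, 1425, 5000):
    print(step, linear_warmup_factor(step, 1425))
```

A function like this could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the base learning rate by the returned factor at each scheduler step.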

## Train!

In [None]:
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="Donut-fish-label")

trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        max_epochs=config.get("max_epochs"),
        val_check_interval=config.get("val_check_interval"),
        check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
        gradient_clip_val=config.get("gradient_clip_val"),
        precision=16, # we'll use mixed precision
        num_sanity_val_steps=0,
        logger=wandb_logger,
        # callbacks=[lr_callback, checkpoint_callback],
)

trainer.fit(model_module)

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit None Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:lightning_fabric.accelerators.cuda:You are using a CUDA device ('A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                      | Params
-----------

Training: 0it [00:00, ?it/s]

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Validation: 0it [00:00, ?it/s]

Prediction: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torredo morhorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
    Answer: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo marmorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
 Normed ED: 0.009554140127388535
Prediction: <s_Packing>Ftoueue</s_Packing><sep/><s_Fish>Cabillaud</s_Fish><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>V</s_SzoneCode><sep/><s_Latin>Gadus macrocephalus</s_Latin><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Szone>Islande et Féroé</s_Szone><sep/><s_Gear>Ligne et hameçons</s_Gear></s_ner>
    

Validation: 0it [00:00, ?it/s]

Prediction: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpeao barbata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
    Answer: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo marmorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
 Normed ED: 0.01592356687898089
Prediction: <s_Packing>Ftiqueue</s_Packing><sep/><s_Fish>Cabillaud</s_Fish><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>V</s_SzoneCode><sep/><s_Latin>Gadus macrocephalus</s_Latin><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Szone>Islande et Féroé</s_Szone><sep/><s_Gear>Ligne et hameçons</s_Gear></s_ner>
    An

Validation: 0it [00:00, ?it/s]

Prediction: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo morhorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
    Answer: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo marmorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
 Normed ED: 0.006369426751592357
Prediction: <s_Packing>Ftouche</s_Packing><sep/><s_Fish>Cabillaud</s_Fish><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>V</s_SzoneCode><sep/><s_Latin>Gadus macrocephalus</s_Latin><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Szone>Islande et Féroé</s_Szone><sep/><s_Gear>Ligne et hameçons</s_Gear></s_ner>
    

Validation: 0it [00:00, ?it/s]

Prediction: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo barbata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
    Answer: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo marmorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
 Normed ED: 0.012738853503184714
Prediction: <s_Packing>Ftoneuse</s_Packing><sep/><s_Fish>Cabillaud</s_Fish><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>V</s_SzoneCode><sep/><s_Latin>Gadus macrocephalus</s_Latin><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Szone>Islande et Féroé</s_Szone><sep/><s_Gear>Ligne et hameçons</s_Gear></s_ner>
    A

Validation: 0it [00:00, ?it/s]

Prediction: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torycodo barbata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
    Answer: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo marmorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
 Normed ED: 0.022292993630573247
Prediction: <s_Fish>Ftoneuse</s_Fish><sep/><s_Latin>Gadus macrocephalus</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>V</s_SzoneCode><sep/><s_Latin>Gadus macrocephalus</s_Latin><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Szone>Islande et Féroé</s_Szone><sep/><s_Gear>Ligne et hameçons</s_Gear></s_ner

Validation: 0it [00:00, ?it/s]

Prediction: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpeao macroata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
    Answer: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo marmorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
 Normed ED: 0.012738853503184714
Prediction: <s_Packing>Ftoneuse</s_Packing><sep/><s_Fish>Cabillaud</s_Fish><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>V</s_SzoneCode><sep/><s_Latin>Gadus macrocephalus</s_Latin><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Szone>Islande et Féroé</s_Szone><sep/><s_Gear>Ligne et hameçons</s_Gear></s_ner>
    

Validation: 0it [00:00, ?it/s]

Prediction: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torycolarmarmoratu</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
    Answer: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo marmorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
 Normed ED: 0.022222222222222223
Prediction: <s_Fish>Ftiqueue</s_Packing><sep/><s_Fish>Cabillaud</s_Fish><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>V</s_SzoneCode><sep/><s_Latin>Gadus macrocephalus</s_Latin><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Szone>Islande et Féroé</s_Szone><sep/><s_Gear>Ligne et hameçons</s_Gear></s_ner>
    A

Validation: 0it [00:00, ?it/s]

Prediction: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo mambrata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
    Answer: <s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Latin>Torpedo marmorata</s_Latin><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>VI</s_SzoneCode><sep/><s_Zone>Atlantique Nord Est</s_Zone><sep/><s_Szone>Ouest Ecosse</s_Szone><sep/><s_Gear>Chalut</s_Gear></s_ner>
 Normed ED: 0.006369426751592357
Prediction: <s_Fish>Fléqueue</s_Fish><sep/><s_Fish>Cabillaud</s_Fish><sep/><s_ZoneCode>27</s_ZoneCode><sep/><s_SzoneCode>V</s_SzoneCode><sep/><s_Latin>Gadus macrocephalus</s_Latin><sep/><s_Peche>Pêché(es)</s_Peche><sep/><s_Szone>Islande et Féroé</s_Szone><sep/><s_Gear>Ligne et hameçons</s_Gear></s_ner>
    Answer

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=4` reached.
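
The `Normed ED` values in the validation log above look like normalized edit distances between each prediction and its answer. Here's a minimal, dependency-free sketch of such a metric, assuming it is the Levenshtein distance divided by the length of the longer string. The helper names `levenshtein` and `normalized_edit_distance` are ours for illustration; the metric actually used during training may differ in its pre-processing.

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance (insertions, deletions, substitutions) via dynamic programming."""
    if len(a) < len(b):
        a, b = b, a  # keep the shorter string in the inner loop
    previous = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        current = [i]
        for j, ch_b in enumerate(b, start=1):
            cost = 0 if ch_a == ch_b else 1
            current.append(
                min(
                    previous[j] + 1,         # deletion
                    current[j - 1] + 1,      # insertion
                    previous[j - 1] + cost,  # substitution (or match)
                )
            )
        previous = current
    return previous[-1]


def normalized_edit_distance(pred: str, answer: str) -> float:
    """Assumed normalization: divide by the length of the longer string."""
    longest = max(len(pred), len(answer))
    return levenshtein(pred, answer) / longest if longest else 0.0


# one field taken from the validation log above
pred = "<s_Latin>Torpedo mambrata</s_Latin>"
answer = "<s_Latin>Torpedo marmorata</s_Latin>"
print(levenshtein(pred, answer), normalized_edit_distance(pred, answer))
```

A score of 0 means the prediction matches the answer exactly; a single wrong field in a long sequence only moves the score slightly, which is why the values in the log are so small.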


## Push to hub and reuse

The Hugging Face [hub](https://huggingface.co/) is a convenient place to host, version and share machine learning models (as well as datasets, and demos in the form of [Spaces](https://huggingface.co/spaces)).

We first provide our authentication token.

In [None]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` now requires a token generated from https://huggingface.co/settings/tokens .
    
Token: 
Add token as git credential? (Y/n) y
Token is valid.
Cannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credenti

Pushing to the hub after training is as easy as:

In [None]:
repo_name = "hublot/doner"

# Push the processor and the model to the same repo on the hub.
# Pass `private=True` to create the repo as private, so that it is only
# visible to you and the members of your organization.
model_module.processor.push_to_hub(repo_name)
model_module.model.push_to_hub(repo_name)

CommitInfo(commit_url='https://huggingface.co/hublot/doner/commit/107238c84a1b285cfa03b525ed6a691c5f733cd6', commit_message='Upload model', commit_description='', oid='107238c84a1b285cfa03b525ed6a691c5f733cd6', pr_url=None, pr_revision=None, pr_num=None)

Reloading can then be done as:

In [None]:
# quote the requirement, otherwise the shell treats `>=0.12.0` as an output redirection
!pip install "accelerate>=0.12.0"

In [None]:
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

# `use_auth_token=True` reuses the token stored by the login step above
processor = DonutProcessor.from_pretrained("hublot/doner", use_auth_token=True)
# `torch_dtype` only applies to the model weights: load them in half precision to save memory
model = VisionEncoderDecoderModel.from_pretrained("hublot/doner", use_auth_token=True, torch_dtype=torch.float16)


## Inference

For inference, we refer to the [docs](https://huggingface.co/docs/transformers/main/en/model_doc/donut#inference) of Donut, or the corresponding [notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/DocVQA/Quick_inference_with_DONUT_for_DocVQA.ipynb).

In [None]:
!pip install gradio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gradio
  Downloading gradio-3.16.2-py3-none-any.whl (14.2 MB)
Collecting httpx
  Downloading httpx-0.23.3-py3-none-any.whl (71 kB)
Collecting python-multipart
  Downloading python-multipart-0.0.5.tar.gz (32 kB)
  Preparing metadata (setup.py) ... done
Collecting uvicorn
  Downloading uvicorn-0.20.0-py3-none-any.whl (56 kB)
Collecting orjson
  Downloading orjson-3.8.5-cp38-cp38-manylinux_2_28_x86_64.whl (140 kB)

In [None]:
import re

import gradio as gr
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# half precision is mainly useful on GPU; fall back to float32 on CPU
dtype = torch.float16 if device == "cuda" else torch.float32
model = model.to(device=device, dtype=dtype)
model.eval()


def predict(image):
    # task start token that opens the generated sequence
    task_prompt = "<s_ner>"
    pixel_values = processor(image, return_tensors="pt").pixel_values
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    # move the inputs to the same device (and dtype) as the model
    pixel_values = pixel_values.to(device, dtype)
    decoder_input_ids = decoder_input_ids.to(device)

    # greedy decoding (num_beams=1), never generating the unknown token
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    # turn the generated token ids back into text and strip the special tokens
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    # convert the tagged sequence into a JSON-like dictionary
    return processor.token2json(sequence)


demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="json")

demo.launch(share=True, debug=False)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://b282a486-c327-4d9e.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces
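
To get a feel for what the final conversion step does, here's a simplified stand-in that turns a flat tagged sequence, like the ones printed during validation, into a dictionary. This is only a sketch: `flat_tags_to_dict` is a hypothetical helper of ours, not the processor's `token2json` method, and it assumes flat `<s_key>value</s_key>` fields with no nesting. Malformed fields are skipped, and when a key appears twice the first value is kept.

```python
import re


def flat_tags_to_dict(sequence: str) -> dict:
    """Parse flat `<s_key>value</s_key>` fields (hypothetical helper, no nesting)."""
    fields = {}
    # `[^<]*` keeps a value from spilling over into the next tag
    for match in re.finditer(r"<s_(\w+)>([^<]*)</s_\1>", sequence):
        key, value = match.group(1), match.group(2).strip()
        fields.setdefault(key, value)  # keep the first occurrence of a key
    return fields


# shortened example in the style of the validation log
sample = "<s_Packing>Aile</s_Packing><sep/><s_Fish>Raie</s_Fish><sep/><s_Gear>Chalut</s_Gear></s_ner>"
print(flat_tags_to_dict(sample))
```

Skipping fields whose opening and closing tags don't match is a design choice of this sketch: a prediction that opens with `<s_Fish>` but closes with `</s_Packing>` simply doesn't produce an entry.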


