In [1]:
!pip install -q "transformers>=4.25.1" "datasets[image]" "sentencepiece" "torch" "accelerate" "pillow" "wandb"

[0m

In [2]:
import json
import torch
from PIL import Image
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import (
    DonutProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from huggingface_hub import login
import wandb

In [4]:
login()
wandb.login()

ImportError: The `notebook_login` function can only be used in a notebook (Jupyter or Colab) and you need the `ipywidgets` module: `pip install ipywidgets`.

In [5]:
raw_dataset = load_dataset("naver-clova-ix/cord-v2", split="train")

Generating train split: 100%|██████████| 800/800 [00:03<00:00, 204.61 examples/s]
Generating validation split: 100%|██████████| 100/100 [00:00<00:00, 190.21 examples/s]
Generating test split: 100%|██████████| 100/100 [00:00<00:00, 203.90 examples/s]


In [6]:
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")



Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [7]:
def json_to_special_token_sequence(json_data):
    """Converts a JSON object to Donut's special token string format."""
    sequence = ""
    for key, value in json_data.items():
        if value is None: continue
        if isinstance(value, list):
            sequence += f"<s_{key}>"
            for item in value:
                if isinstance(item, dict): sequence += json_to_special_token_sequence(item)
            sequence += f"</s_{key}>"
        elif isinstance(value, dict):
            sequence += f"<s_{key}>"
            sequence += json_to_special_token_sequence(value)
            sequence += f"</s_{key}>"
        else:
            sequence += f"<s_{key}>{str(value)}</s_{key}>"
    return sequence

new_tokens = set()
for example in raw_dataset:
    try:
        ground_truth = json.loads(example['ground_truth'])
        json_data = ground_truth['gt_parse']
        
        def extract_keys(d):
            for k, v in d.items():
                new_tokens.add(f"<s_{k}>"); new_tokens.add(f"</s_{k}>")
                if isinstance(v, dict): extract_keys(v)
                elif isinstance(v, list):
                    for item in v:
                        if isinstance(item, dict): extract_keys(item)
        extract_keys(json_data)
    except (KeyError, json.JSONDecodeError):
        continue

processor.tokenizer.add_special_tokens({"additional_special_tokens": sorted(list(new_tokens))})
model.decoder.resize_token_embeddings(len(processor.tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


MBartScaledWordEmbedding(57579, 1024, padding_idx=1)

In [8]:
class ReceiptDataset(Dataset):
    """PyTorch Dataset for Donut, which processes data on-the-fly."""
    def __init__(self, dataset, processor, max_length=1536):
        self.dataset = dataset
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]

        image = example["image"].convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        
        ground_truth = json.loads(example['ground_truth'])
        json_data = ground_truth['gt_parse']
        target_sequence = json_to_special_token_sequence(json_data)
        
        input_ids = self.processor.tokenizer(
            target_sequence, 
            add_special_tokens=False, 
            max_length=self.max_length,
            padding="max_length", 
            truncation=True, 
            return_tensors="pt"
        ).input_ids.squeeze(0)
        
        labels = input_ids.clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        
        return {"pixel_values": pixel_values.squeeze(), "labels": labels}

train_dataset = ReceiptDataset(raw_dataset, processor)

In [9]:
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s>'])[0]

training_args = Seq2SeqTrainingArguments(
    output_dir="donut_finetuned_receipts_final_v2",
    num_train_epochs=12,          
    learning_rate=5e-5,
    per_device_train_batch_size=2, 
    gradient_accumulation_steps=8,  
    fp16=True,                       
    gradient_checkpointing=True,     
    weight_decay=0.01,
    logging_steps=20,
    save_strategy="epoch",
    report_to="wandb",
    run_name="donut-receipt-run-12-epochs",
)



In [10]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)


In [11]:

trainer.train()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /teamspace/studios/this_studio/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmanoharnimiditalli001[0m ([33mmanoharnimiditalli001-indian-institute-of-information-te[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`...
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
20,10.1392
40,7.1344
60,5.6561
80,3.7801
100,2.424
120,1.9656
140,1.8164
160,1.5156
180,1.2624
200,1.1105


TrainOutput(global_step=600, training_loss=1.5102253524462381, metrics={'train_runtime': 7583.6891, 'train_samples_per_second': 1.266, 'train_steps_per_second': 0.079, 'total_flos': 1.2007340739919872e+20, 'train_loss': 1.5102253524462381, 'epoch': 12.0})

In [12]:
final_model_path = "final_donut_receipt_model"
trainer.save_model(final_model_path)
processor.save_pretrained(final_model_path)

[]

In [14]:
import torch
import json
import re
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

MODEL_PATH = "./final_donut_receipt_model"
IMAGE_PATH = "/teamspace/studios/this_studio/test_image12.jpg"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

processor = DonutProcessor.from_pretrained(MODEL_PATH)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_PATH)
model.to(DEVICE)

try:
    image = Image.open(IMAGE_PATH).convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)
    
    task_prompt = "<s>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(DEVICE)

    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    parsed_json = processor.token2json(sequence)
    
    print(json.dumps(parsed_json, indent=2))

except FileNotFoundError:
    print(f"Error: Image not found at '{IMAGE_PATH}'. Please update the path.")

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


{
  "nm": "POTSISISI",
  "unitprice": ".",
  "price": "13<s_price>13<s_subtotal_price> 10 :P :<s_price>",
  "tax_price": {
    "nm": "GRE CLBC<s_nm> PO",
    "total": "14"
  },
  "total": "<s_total><s_total><s_total><s_total> 237"
}
