# Mystery box - 3 - Train DONUT-based distance detector
### Dennis Bakhuis - 10th November 2022
### https://linkedin.com/in/dennisbakhuis/

Great Donut article by Philipp Schmid: https://www.philschmid.de/fine-tuning-donut

In [None]:
import json
from PIL import Image
import random
from pathlib import Path

import datasets

import torch

from transformers import (
    DonutProcessor,
    VisionEncoderDecoderModel, 
    Seq2SeqTrainingArguments, 
    Seq2SeqTrainer,
)

import numpy as np

from huggingface_hub import HfFolder

In [None]:
# Load the "bakhuisdennis/mystery_box" dataset through the `datasets` library.
dataset = datasets.load_dataset("bakhuisdennis/mystery_box")

In [None]:
# Show the loaded dataset object as the cell output.
dataset

## Change json to tokens

In [None]:
# Tags discovered while serialising annotations; json2token appends to this
# list so the tags can later be registered with the tokenizer.
new_special_tokens = []
# Tokens wrapped around every serialised target sequence.
task_token_start = "<s>"
task_token_end = "</s>"


def json2token(
    obj,
    update_special_tokens_for_json_key: bool = True,
    sort_json_key: bool = True,
):
    """Serialise a JSON-like object into a Donut-style tagged token string.

    Based on https://github.com/clovaai/donut/blob/master/donut/model.py#L497.

    * A mapping becomes ``<s_key>value</s_key>`` for each key, concatenated.
      A mapping whose only key is ``"text_sequence"`` is returned as that
      value, unwrapped.
    * A list becomes its serialised items joined with ``<sep/>``.
    * Anything else is converted with ``str``; if ``<value/>`` is a known
      special token, that token is returned instead of the plain text.

    Args:
        obj: The dict / list / scalar to serialise. Subclasses of ``dict`` and
            ``list`` are serialised like their base types.
        update_special_tokens_for_json_key: When true, the opening and closing
            tag of every key are appended to the module-level
            ``new_special_tokens`` list (once each, in first-seen order).
        sort_json_key: When true, keys are emitted in reverse-sorted order;
            otherwise in the mapping's own iteration order.

    Returns:
        The serialised string (or the raw ``"text_sequence"`` value).
    """
    if isinstance(obj, dict):
        if len(obj) == 1 and "text_sequence" in obj:
            return obj["text_sequence"]

        keys = sorted(obj, reverse=True) if sort_json_key else list(obj)
        pieces = []
        for k in keys:
            open_tag = f"<s_{k}>"
            close_tag = f"</s_{k}>"
            if update_special_tokens_for_json_key:
                # Register the tags before recursing so the list keeps
                # outer-key-first order.
                for tag in (open_tag, close_tag):
                    if tag not in new_special_tokens:
                        new_special_tokens.append(tag)
            pieces.append(
                open_tag
                + json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                + close_tag
            )
        return "".join(pieces)

    if isinstance(obj, list):
        return "<sep/>".join(
            json2token(item, update_special_tokens_for_json_key, sort_json_key)
            for item in obj
        )

    text = str(obj)
    categorical_token = f"<{text}/>"
    # A leaf value that matches a registered special token is emitted as that token.
    if categorical_token in new_special_tokens:
        return categorical_token
    return text


def preprocess_documents_for_donut(sample):
    """Convert one raw dataset row into the image/target pair used for Donut.

    The JSON string under ``sample["text"]`` is parsed, serialised with
    ``json2token`` and wrapped in the task start/end tokens. The image under
    ``sample["image"]`` is converted to RGB.

    Returns:
        A dict with the converted ``"image"`` and the target string ``"text"``.
    """
    annotation = json.loads(sample["text"])
    target_sequence = task_token_start + json2token(annotation) + task_token_end

    # The source images are monochrome, but the model wants RGB input.
    rgb_image = sample["image"].convert("RGB")

    return {"image": rgb_image, "text": target_sequence}


# Apply the Donut preprocessing (JSON -> token string, image -> RGB) to every example.
proc_dataset = dataset.map(preprocess_documents_for_donut)

# Print the target string of one processed train example as a sanity check.
print(f"Sample: {proc_dataset['train'][45]['text']}")

### Tokenize data

In [None]:
# Start from the processor (tokenizer + image feature extractor) of the
# pretrained "naver-clova-ix/donut-base" checkpoint.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

# Register the field tags collected in new_special_tokens together with the
# task start/end tokens.
processor.tokenizer.add_special_tokens(
    {
        "additional_special_tokens": [
            *new_special_tokens,
            task_token_start,
            task_token_end,
        ],
    },
)

# Use the size of the first training image as the feature-extractor size.
# NOTE(review): PIL reports size as (width, height) - confirm this is the
# order the feature extractor expects.
processor.feature_extractor.size = list(dataset["train"][0]["image"].size)
processor.feature_extractor.do_align_long_axis = False

In [None]:
def transform_and_tokenize(
    sample,
    processor=processor,
    split="train",
    max_length=512,
    ignore_id=-100,
):
    """Turn one preprocessed example into model inputs and training labels.

    Args:
        sample: Mapping with an ``"image"`` and its target string under ``"text"``.
        processor: Processor used for both the image and the tokenizer.
        split: Random padding is requested from the processor only when this
            equals ``"train"``.
        max_length: Label sequences are padded and truncated to this length.
        ignore_id: Value written over padding positions in the labels.

    Returns:
        A dict with ``"pixel_values"``, ``"labels"`` and the untouched target
        string under ``"target_sequence"``.
    """
    # squeeze(0) removes only the leading dimension, matching the label path
    # below; a bare squeeze() would also drop any other size-1 dimension.
    # Assumes the processor returns a batch of one for a single image.
    pixel_values = processor(
        sample["image"],
        random_padding=(split == "train"),
        return_tensors="pt",
    ).pixel_values.squeeze(0)

    input_ids = processor.tokenizer(
        sample["text"],
        add_special_tokens=False,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"].squeeze(0)

    # Overwrite padding positions in the labels with ignore_id.
    labels = input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = ignore_id

    return {
        "pixel_values": pixel_values,
        "labels": labels,
        "target_sequence": sample["text"],
    }


# Transform and tokenize every example; the raw "image" and "text" columns are
# removed and replaced by the fields transform_and_tokenize returns.
# NOTE(review): the default split="train" is used for every example, so random
# padding is also requested for rows that are later held out for testing.
processed_dataset = proc_dataset.map(
    transform_and_tokenize,
    remove_columns=["image","text"],
)

### Train model

In [None]:
# Hold out 15% of the examples for testing. The seed is fixed so every run
# produces the same split and the reported accuracy stays comparable.
processed_dataset = processed_dataset["train"].train_test_split(
    test_size=0.15,
    seed=42,
)
print(processed_dataset)

In [None]:
# Load the pretrained "naver-clova-ix/donut-base" checkpoint to fine-tune.
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

# Resize the decoder token embeddings to the current tokenizer length.
new_emb = model.decoder.resize_token_embeddings(len(processor.tokenizer))

# The encoder image size is the feature-extractor size in reversed order.
model.config.encoder.image_size = processor.feature_extractor.size[::-1]

# Decoder max length = length of the longest label sequence in the train split.
model.config.decoder.max_length = max(
    map(len, processed_dataset["train"]["labels"]),
)

# Padding and decoder start tokens are taken from the tokenizer.
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<s>"])[0]

I'll train this on the CPU because my GPU has insufficient memory. The dataset is small enough that training is still fairly quick even on the CPU (~38 min).

In [None]:
training_args = Seq2SeqTrainingArguments(
    # device: CUDA disabled
    no_cuda=True,
    # optimisation
    num_train_epochs=3,
    per_device_train_batch_size=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    predict_with_generate=True,
    # logging and checkpoints
    output_dir="output",
    logging_steps=40,
    report_to="tensorboard",
    save_strategy="epoch",
    save_total_limit=2,
    # push to hub parameters
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id="donut-base-mysterybox",
    hub_token=HfFolder.get_token(),
)

# Trainer wired to the train/test splits of the processed dataset.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["test"],
)

In [None]:
# Run training with the trainer; the call's return value is the cell output.
trainer.train()

In [None]:
# Evaluate with the trainer (it was given processed_dataset["test"] as eval_dataset).
trainer.evaluate()

### Quick test

In [None]:
def run_prediction(sample, model=model, processor=processor):
    """Generate a prediction for one processed example and decode it.

    Args:
        sample: Processed example providing ``"pixel_values"`` and the
            reference string under ``"target_sequence"``.
        model: Model used for generation.
        processor: Processor supplying the tokenizer, ``batch_decode`` and
            ``token2json``.

    Returns:
        A ``(prediction, target)`` tuple, both converted with
        ``processor.token2json``.
    """
    # Put the inputs on the same device as the model; otherwise generation
    # fails with a device mismatch when the model is not on the CPU.
    device = model.device

    # prepare inputs
    pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0).to(device)
    task_prompt = "<s>"
    decoder_input_ids = processor.tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids.to(device)

    # run inference
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # process output: decode the single generated sequence, then parse it
    prediction = processor.batch_decode(outputs.sequences)[0]
    prediction = processor.token2json(prediction)

    # load reference target
    target = processor.token2json(sample["target_sequence"])
    return prediction, target

# Count exact matches between the parsed prediction and the parsed reference
# over the held-out test split.
same = 0
for example in processed_dataset["test"]:
    prediction, target = run_prediction(example)
    # `target` is the reference annotation, `prediction` is the model output.
    print(f"Reference: {target}   --> Prediction: {prediction}")
    if prediction == target:
        same += 1

print(f"Total of {same}/{len(processed_dataset['test'])} correct")
print(f"Accuracy: {same / len(processed_dataset['test']):.2%}")

### Save model

In [None]:
# model_path = Path("../data/donut-base-mysterybox")

# if not model_path.exists():
#     model_path.mkdir()
    
# trainer.save_model(model_path / "model")
# processor.save_pretrained(model_path / "processor")

#### push to hub

In [None]:
# Push via the trainer first, then push the processor to
# "bakhuisdennis/donut-base-mysterybox".
# NOTE(review): presumably the same repository the trainer's hub_model_id
# ("donut-base-mysterybox") resolves to - confirm.
trainer.push_to_hub()
processor.push_to_hub("bakhuisdennis/donut-base-mysterybox")