# 写在前面
TrOCR 原论文在第一阶段使用了 6.8 亿条合成数据进行预训练；受算力所限无法复现(试过，没有大量数据训不起来)，因此这里只做第二阶段的微调。

Hugging Face 上提供了第一阶段的权重，而且第一阶段的效果本身就已经很好。因此这里改为使用 TrOCR 的印刷体版本(microsoft/trocr-small-printed)，将其微调成手写体版本。

这里数据加载似乎太慢，导致训练时间偏长；可以在 datasets.map 中设置 num_proc，并在 TrainingArguments 中设置 dataloader_num_workers 来加速。

Jupyter 在 Windows 上存在多进程 bug，所以这里没有实现多进程加载；在 macOS 和 Linux 环境下则可以直接在 Jupyter 中使用。

看了很多实现，都没有实现动态padding，这里的代码实现了动态padding。


In [None]:
from datasets import load_dataset, concatenate_datasets

In [None]:
# Load the dataset identified by "priyank-m/IAM_words_text_recognition".
raw_dataset = load_dataset("priyank-m/IAM_words_text_recognition")
# Bare expression so the notebook displays the loaded object.
raw_dataset

In [None]:
# Take the "val" split out of the dataset dict and append its rows to "train",
# so the validation rows are used as additional training data.
val_raw_dataset = raw_dataset.pop("val")
raw_dataset["train"] = concatenate_datasets([raw_dataset["train"], val_raw_dataset])
raw_dataset

In [None]:
# Keep the first 10 training rows as a small sample; later cells use it to
# try out the preprocessing function and the data collator.
sample_dataset = raw_dataset["train"].select(range(10))
sample_dataset

In [None]:
# First row of the sample, inspected by hand.
test_data = sample_dataset[0]
test_data

In [None]:
# The row's "image" entry; it is reused for the inference check at the end of
# the notebook.
test_image = test_data["image"]
test_image

In [7]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

In [None]:
# Start from the printed-text TrOCR checkpoint; it is fine-tuned further below.
checkpoint = "microsoft/trocr-small-printed"
processor = TrOCRProcessor.from_pretrained(checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

# Fill these in on the top-level config: they appear to be missing from the
# model's own config and present only in the decoder's config.
# Note that the decoder start token is set to the tokenizer's EOS id.
model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

In [9]:
def map_function(examples):
    """Preprocess one batch of rows for `datasets.map(..., batched=True)`.

    Args:
        examples: Batch mapping with an "image" list and a "text" list.

    Returns:
        Whatever the module-level `processor` returns when called with the
        RGB-converted images and their texts.
    """
    # Pair images with texts first so both lists stay aligned.
    pairs = list(zip(examples["image"], examples["text"]))
    rgb_images = [picture.convert("RGB") for picture, _ in pairs]
    captions = [caption for _, caption in pairs]
    # NOTE: processor.tokenizer.bos_token differs from the decoder start token;
    # that mismatch is dealt with in the data collator, not here.
    return processor(images=rgb_images, text=captions)

In [None]:
# Run the preprocessing on the 10-row sample in batched mode. Unlike the
# full-dataset call further down, no `remove_columns` is passed here.
tokenizer_sample_dataset = sample_dataset.map(map_function, batched=True)
tokenizer_sample_dataset

In [11]:
import torch
from dataclasses import dataclass
from transformers import default_data_collator
from typing import Any, List, Dict, Union

@dataclass
class DataCollatorForOCR:
    """
    Collate preprocessed OCR rows into a batch, padding the labels dynamically.

    Args:
        processor ([`TrOCRProcessor`])
            The processor used for processing the data. Its first
            `model_input_names` entry selects the model input field, and its
            tokenizer pads the labels.
        decoder_start_token_id (`int`)
            The start token id of the decoder. It is stored on the collator
            but `__call__` does not read it.

    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build one batch holding the model input and padded `labels`.

        Args:
            features: Rows that each carry the model input field and "labels".

        Returns:
            The collated model-input batch with an added "labels" tensor in
            which padded positions are set to -100.
        """
        input_key = self.processor.model_input_names[0]

        image_rows = []
        token_rows = []
        for row in features:
            image_rows.append({input_key: row[input_key]})
            # Drop the first label token (the tokenizer's BOS id). Per the
            # original author, the model adds decoder_start_token_id itself,
            # so adding it here is possible but unnecessary.
            token_rows.append({"input_ids": row["labels"][1:]})

        batch = default_data_collator(image_rows, return_tensors="pt")

        # Pad labels only to the longest sequence in this batch, then mask
        # every padded position with -100.
        padded = self.processor.tokenizer.pad(token_rows, return_tensors="pt")
        is_padding = padded.attention_mask.ne(1)
        batch["labels"] = padded["input_ids"].masked_fill(is_padding, -100)

        return batch

In [None]:
from torch.utils.data import DataLoader

# Try the collator on the preprocessed sample with batches of 2 rows.
# `collate_fn` is also the collator handed to the trainer later on.
collate_fn = DataCollatorForOCR(processor, model.decoder.config.decoder_start_token_id)
dataloader = DataLoader(tokenizer_sample_dataset, batch_size=2, collate_fn=collate_fn)

# Print the tensor sizes and the label ids of the first batch only (note the
# `break`).
for item in dataloader:
    pixel_values = item["pixel_values"]
    labels = item["labels"]
    print(pixel_values.size())
    print(labels.size())
    print(labels)
    break

In [None]:
# Preprocess every split in batched mode and remove the original columns
# (the column names are taken from the "train" split).
tokenizer_datasets = raw_dataset.map(map_function, batched=True, remove_columns=raw_dataset["train"].column_names)
tokenizer_datasets

In [None]:
import numpy as np
import wandb
import evaluate
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, EvalPrediction

wandb.init(project="trocr_stage2")
accuracy = evaluate.load("accuracy")


def eval_function(eval_prediction: EvalPrediction):
    """Compute exact-match accuracy between generated and reference texts.

    Relies on `predict_with_generate=True` in the training arguments below, so
    that `predictions` holds generated token ids rather than logits.

    Args:
        eval_prediction: Predictions and label ids collected by the trainer.

    Returns:
        The dict produced by the `accuracy` metric, i.e. the fraction of
        samples whose decoded prediction equals the decoded label.
    """
    predictions = eval_prediction.predictions
    if isinstance(predictions, tuple):
        # Some models return extra outputs; the token ids come first.
        predictions = predictions[0]
    label_ids = eval_prediction.label_ids

    # -100 marks ignored/padded positions and is not a valid token id, so it
    # must be swapped for the pad id before decoding.
    pad_token_id = processor.tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_token_id)
    label_ids = np.where(label_ids != -100, label_ids, pad_token_id)

    predicted_texts = processor.batch_decode(predictions, skip_special_tokens=True)
    reference_texts = processor.batch_decode(label_ids, skip_special_tokens=True)

    # The accuracy metric compares flat integer labels, so feed it one
    # "matched" flag per sample against an all-ones reference.
    matches = [
        int(predicted.strip() == reference.strip())
        for predicted, reference in zip(predicted_texts, reference_texts)
    ]
    return accuracy.compute(references=[1] * len(matches), predictions=matches)


# Because of dynamic padding, training may well run out of GPU memory.
train_args = Seq2SeqTrainingArguments(
    output_dir="output/trocr_stage2",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    warmup_steps=1000,
    num_train_epochs=3,
    lr_scheduler_type="cosine",
    logging_steps=10,
    logging_first_step=True,
    logging_strategy="steps",
    eval_steps=500,
    eval_strategy="steps",
    save_strategy="epoch",
    save_safetensors=True,
    bf16=True,
    # Evaluate with generate() so compute_metrics receives token ids that can
    # be decoded, instead of raw logits.
    predict_with_generate=True,
    report_to="wandb",  # switch to "tensorboard" if you do not want wandb
)
trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    data_collator=collate_fn,
    train_dataset=tokenizer_datasets["train"],
    eval_dataset=tokenizer_datasets["test"],
    compute_metrics=eval_function,
    processing_class=processor,
)

In [None]:
# Run the fine-tuning with the trainer configured above.
trainer.train()

In [None]:
# Display the sample image again before running inference on it.
test_image

In [None]:
# Convert the sample image to RGB and run it through the processor, asking
# for "pt" tensors; only the "pixel_values" entry is kept.
test_image = test_image.convert("RGB")
pixel_values = processor(images=test_image, return_tensors="pt")["pixel_values"]
pixel_values

In [None]:
# Generate token ids for the image; the input is moved to the model's device
# first.
answer_token = model.generate(pixel_values.to(model.device))
answer_token

In [None]:
# Decode the generated ids back into text. `skip_special_tokens` is not passed
# here.
processor.batch_decode(answer_token)