In [17]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from datasets import load_dataset
from taiwan_license_plate_recognition.helper import get_num_of_workers

In [18]:
# Worker count reused as `num_proc` by the dataset loading and `map` calls in this notebook.
num_workers: int = get_num_of_workers()

In [24]:
# Load the `gagan3012/IAM` dataset, drop its `label` column, and keep only the
# `train` split; the image/text pairs are re-split in the cells that follow.
dataset = (
	load_dataset(
		"gagan3012/IAM",
		keep_in_memory=True,
		num_proc=num_workers,
	)
	.remove_columns(["label"])
)["train"]

In [25]:
# Hold out 20% of the rows as the test split; the fixed seed keeps the shuffle reproducible.
# Note: this rebinds `dataset` from a single split to a DatasetDict with `train`/`test` keys.
dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=37710)

In [30]:
# Carve a validation subset out of the training split: 0.125 of the remaining 80%
# is roughly 10% of all rows, giving an approximate 70/20/10 train/test/validation split.
train_dataset = dataset["train"].train_test_split(test_size=0.125, shuffle=True, seed=37710)

In [31]:
# Display the row counts of the intermediate train/validation split.
train_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'text'],
        num_rows: 7940
    })
    test: Dataset({
        features: ['image', 'text'],
        num_rows: 1135
    })
})

In [32]:
# Fold the second split back into `dataset`: the reduced training rows replace `train`,
# and the held-out part (named `test` by train_test_split) is stored as `validation`.
dataset["train"] = train_dataset["train"]
dataset["validation"] = train_dataset["test"]

In [33]:
# Display the final train/test/validation row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'text'],
        num_rows: 7940
    })
    test: Dataset({
        features: ['image', 'text'],
        num_rows: 2269
    })
    validation: Dataset({
        features: ['image', 'text'],
        num_rows: 1135
    })
})

In [5]:
# Processor whose `image_processor` and `tokenizer` are used by the preprocessing cells below.
# NOTE(review): it is loaded from the `-handwritten` checkpoint while the model weights
# come from `-stage1` — confirm this pairing is intended.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

In [6]:
# Load the stage-1 TrOCR checkpoint as the model that is configured in the next cell.
base_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.46.0"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder

In [7]:
# set special tokens used for creating the decoder_input_ids from the labels
base_model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
base_model.config.pad_token_id = processor.tokenizer.pad_token_id
base_model.config.eos_token_id = processor.tokenizer.sep_token_id
# make sure vocab size is set correctly
base_model.config.vocab_size = base_model.config.decoder.vocab_size

# Decoding options go on `generation_config` rather than on the model config:
# storing generation parameters on `model.config` is deprecated in recent
# transformers releases, and `generate()` reads its defaults from
# `generation_config`. The special token ids are mirrored here so generation
# starts, pads and stops with the same tokens used for training.
# NOTE(review): if other code reads `base_model.config.max_length` /
# `num_beams`, switch it to `base_model.generation_config`.
base_model.generation_config.decoder_start_token_id = processor.tokenizer.cls_token_id
base_model.generation_config.pad_token_id = processor.tokenizer.pad_token_id
base_model.generation_config.eos_token_id = processor.tokenizer.sep_token_id

# set beam search parameters
base_model.generation_config.max_length = 64
base_model.generation_config.early_stopping = True
base_model.generation_config.no_repeat_ngram_size = 3
base_model.generation_config.length_penalty = 2.0
base_model.generation_config.num_beams = 4

In [None]:
def _images_to_pixel_values(images):
	"""Turn a batch of `image` column values into the `data` column.

	Each image is converted to RGB and passed through the processor's image
	processor on its own; the resulting `pixel_values` are collected in order.
	NOTE(review): every entry is the output of a single-image call, so it
	presumably keeps a leading batch axis of size 1 — confirm the consumer
	of `data` expects that.
	"""
	pixel_values = []
	for image in images:
		encoded = processor.image_processor(image.convert("RGB"), return_tensors="pt")
		pixel_values.append(encoded.pixel_values)
	return {"data": pixel_values}


# Replace the raw `image` column with preprocessed pixel values in every split.
dataset = dataset.map(
	_images_to_pixel_values,
	input_columns=["image"],
	remove_columns=["image"],
	batched=True,
	num_proc=num_workers,
)

Map (num_proc=8):   0%|          | 0/9075 [00:00<?, ? examples/s]

In [13]:
# Replace the `text` column with fixed-width token id sequences in `label`.
# Without an explicit `max_length`, padding="max_length" pads every row to the
# tokenizer's implicit model maximum; pinning it to 64 keeps the labels at the
# 64-token decoding limit configured for the model, and truncation=True
# guarantees that every row has the same width.
# NOTE(review): 64 assumes no transcription needs more tokens than that — raise
# it if longer lines must be kept intact. Pad positions are left as pad ids here;
# confirm they are masked (e.g. set to -100) before the loss is computed.
dataset = dataset.map(
	lambda samples: {
		# One batched tokenizer call per map batch instead of one call per string.
		"label": processor.tokenizer(
			list(samples),
			padding="max_length",
			max_length=64,
			truncation=True,
		).input_ids
	},
	input_columns=["text"],
	remove_columns=["text"],
	batched=True,
	num_proc=num_workers,
)

Map (num_proc=8):   0%|          | 0/9075 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/2269 [00:00<?, ? examples/s]