In [1]:
%pip install --upgrade pip
%pip install torch torchvision torchaudio
%pip install -q datasets jiwer
%pip install sklearn

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os, sys, itertools
os.environ['TOKENIZERS_PARALLELISM']='false'

import pandas as pd
from sklearn.model_selection import train_test_split

from PIL import Image

import torch
from torch.utils.data import Dataset

import datasets
from datasets import load_dataset

import transformers
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import VisionEncoderDecoderModel, TrOCRProcessor, default_data_collator

import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Record library versions so the run can be reproduced.
# sys.version starts with the full version followed by build info; a fixed-width
# slice such as [0:6] would truncate e.g. "3.10.12" to "3.10.1".
print("Python:".rjust(15), sys.version.split()[0])
print("Pandas:".rjust(15), pd.__version__)
print("Datasets:".rjust(15), datasets.__version__)
print("Transformers:".rjust(15), transformers.__version__)
print("Torch:".rjust(15), torch.__version__)

        Python: 3.10.1
        Pandas: 2.1.2
      Datasets: 2.14.6
  Transformers: 4.35.0
         Torch: 2.1.0+cu118


In [4]:
# Annotation CSV: one row per image with a `file_name` and its label `text`.
path = "plates.csv"

dataset = pd.read_csv(path)
assert {"file_name", "text"} <= set(dataset.columns), "plates.csv must have file_name and text columns"

# 80/20 train/test split; the fixed random_state keeps the split reproducible.
train_dataset, test_dataset = train_test_split(dataset, train_size=0.80, random_state=42)

# Reset to a fresh 0..n-1 index so rows can be addressed by position 0..len-1.
train_dataset = train_dataset.reset_index(drop=True)
test_dataset = test_dataset.reset_index(drop=True)
assert len(train_dataset) + len(test_dataset) == len(dataset)

# DataFrame.info() prints its report itself and returns None; wrapping it in
# print() would emit a stray "None" line after each report.
train_dataset.info()
test_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 83 entries, 0 to 82
Data columns (total 2 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   file_name  83 non-null     object
 1   text       83 non-null     object
dtypes: object(2)
memory usage: 1.4+ KB
None
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 21 entries, 0 to 20
Data columns (total 2 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   file_name  21 non-null     object
 1   text       21 non-null     object
dtypes: object(2)
memory usage: 464.0+ bytes
None


In [None]:
# Peek at the first dozen training annotations.
train_dataset.iloc[:12]

In [6]:
class License_Plates_OCR_Dataset(Dataset):
    """Torch dataset pairing license-plate images with their text labels for TrOCR.

    Each item is a dict with:
      * ``pixel_values`` -- the processor's image tensor with its leading batch
        dimension removed.
      * ``labels`` -- token ids of the text, padded/truncated to exactly
        ``max_target_length`` ids, with PAD ids replaced by -100.

    Args:
        root_dir: Directory that the ``file_name`` column is relative to
            (with or without a trailing separator).
        df: DataFrame with ``file_name`` and ``text`` columns. Rows are read by
            position, so the index does not need to be 0..n-1.
        processor: Processor callable as ``processor(image, return_tensors="pt")``
            and exposing a ``tokenizer`` attribute (e.g. a TrOCRProcessor).
        max_target_length: Fixed length of every ``labels`` tensor.
    """

    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional lookup: works even if the DataFrame index was not reset.
        row = self.df.iloc[idx]
        file_name = row['file_name']
        # str() guards against pandas parsing all-digit plates as numbers.
        text = str(row['text'])

        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(os.path.join(self.root_dir, file_name)) as img:
            image = img.convert("RGB")
        # Resize + normalize via the processor; it returns a batch of one.
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # truncation=True guarantees exactly max_target_length ids per sample,
        # so default_data_collator can stack labels into one tensor.
        labels = self.processor.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
        ).input_ids
        # Make sure PAD tokens are ignored by the loss: -100 is the default
        # ignore_index of torch's cross-entropy loss.
        pad_id = self.processor.tokenizer.pad_token_id
        labels = [label if label != pad_id else -100 for label in labels]

        # squeeze(0) drops only the batch dimension added by the processor.
        return {"pixel_values": pixel_values.squeeze(0), "labels": torch.tensor(labels)}

In [7]:
# Pretrained checkpoint to fine-tune.
MODEL_CKPT = "microsoft/trocr-base-printed"
# Output directory name: checkpoint name without its org prefix, plus a task suffix.
MODEL_NAME = f"{MODEL_CKPT.rsplit('/', 1)[-1]}_license_plates_ocr"
# Number of passes over the training split.
NUM_OF_EPOCHS = 2

In [8]:
processor = TrOCRProcessor.from_pretrained(MODEL_CKPT)

# Image file names are resolved relative to the current working directory;
# the trailing "/" is kept so root_dir + file_name forms a valid path.
cwd = os.getcwd() + "/"

# Build the train and test datasets (in that order) from their DataFrames.
train_ds, test_ds = (
    License_Plates_OCR_Dataset(root_dir=cwd, df=split_df, processor=processor)
    for split_df in (train_dataset, test_dataset)
)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [9]:
# Report how many samples ended up in each split.
for split_name, split_ds in (("training", train_ds), ("testing", test_ds)):
    print(f"The {split_name} dataset has {len(split_ds)} samples in it.")

The training dataset has 83 samples in it.
The testing dataset has 21 samples in it.


In [10]:
encoding = train_ds[0]

# Sanity-check one example: print the shape of each tensor it contains.
for field, tensor in encoding.items():
    print(field, " : ", tensor.shape)

pixel_values  :  torch.Size([3, 384, 384])
labels  :  torch.Size([128])


In [None]:
# Display the raw image behind train_ds[0] to compare it with its label.
# The context manager closes the file; convert() returns a loaded copy.
with Image.open(os.path.join(train_ds.root_dir, train_dataset['file_name'].iloc[0])) as img:
    image = img.convert("RGB")

image

In [None]:
# Decode the label ids of the sample back to text as a sanity check.
# clone() so the -100 -> PAD replacement does not modify encoding['labels'] in place.
labels = encoding['labels'].clone()
labels[labels == -100] = processor.tokenizer.pad_token_id
label_str = processor.decode(labels, skip_special_tokens=True)
print(label_str)

In [13]:
# Pretrained TrOCR encoder-decoder whose weights the Trainer fine-tunes.
model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path=MODEL_CKPT)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Special-token ids on the model config. Per the VisionEncoderDecoder docs, decoder
# inputs are derived from `labels` using decoder_start_token_id and pad_token_id,
# so these must match the tokenizer that produced the labels.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Keep the top-level vocab size consistent with the decoder's.
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Decoding (beam search) settings belong on model.generation_config, which
# generate() reads; controlling generation via model.config is a deprecated path.
gen_config = model.generation_config
gen_config.decoder_start_token_id = model.config.decoder_start_token_id
gen_config.pad_token_id = model.config.pad_token_id
gen_config.eos_token_id = model.config.eos_token_id
gen_config.max_length = 64
gen_config.early_stopping = True
gen_config.no_repeat_ngram_size = 3
gen_config.length_penalty = 2.0
gen_config.num_beams = 4

In [15]:
# Character error rate: lower is better, 0.0 means every plate decoded exactly.
cer_metric = evaluate.load(path="cer")

def compute_metrics(pred):
    """Compute the character error rate (CER) for a Seq2SeqTrainer evaluation.

    Args:
        pred: The trainer's evaluation prediction. ``pred.predictions`` holds
            generated token ids (``predict_with_generate=True``) and
            ``pred.label_ids`` the reference ids, with -100 at ignored positions.

    Returns:
        ``{"cer": value}`` where value is what the module-level ``cer_metric``
        returns for the decoded predictions vs. references.
    """
    pad_id = processor.tokenizer.pad_token_id

    # Work on copies so the trainer's prediction arrays are not mutated in place.
    pred_ids = pred.predictions.copy()
    label_ids = pred.label_ids.copy()

    # -100 is not a valid token id and cannot be decoded. Labels use it for ignored
    # positions; predictions may presumably carry it as cross-batch padding, and
    # replacing it is harmless otherwise.
    pred_ids[pred_ids == -100] = pad_id
    label_ids[label_ids == -100] = pad_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

In [18]:
args = Seq2SeqTrainingArguments(
    # Checkpoints, metrics and the final model are written here.
    output_dir=MODEL_NAME,
    num_train_epochs=NUM_OF_EPOCHS,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate and checkpoint once per epoch; evaluation decodes with generate()
    # so compute_metrics receives generated token ids.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    predict_with_generate=True,
    logging_first_step=True,
    # Keep everything local.
    push_to_hub=False,
    hub_private_repo=False,
)

In [19]:
# The image processor is passed as `tokenizer` so the Trainer saves it with the
# model. `processor.feature_extractor` is a deprecated alias of `image_processor`.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.image_processor,
    args=args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    # Samples are already fixed-size tensors, so plain stacking is enough.
    data_collator=default_data_collator,
)



In [20]:
# Fine-tune from the pretrained weights (no earlier checkpoint is resumed).
train_results = trainer.train(resume_from_checkpoint=None)

  return F.conv2d(input, weight, bias, self.stride,
  8%|▊         | 1/12 [00:02<00:22,  2.02s/it]

{'loss': 11.3434, 'learning_rate': 4.5833333333333334e-05, 'epoch': 0.17}


                                              


{'eval_loss': 0.7706347703933716, 'eval_cer': 0.19310344827586207, 'eval_runtime': 2.8412, 'eval_samples_per_second': 7.391, 'eval_steps_per_second': 0.704, 'epoch': 1.0}


                                               
100%|██████████| 12/12 [00:25<00:00,  1.58s/it]

{'eval_loss': 0.13573069870471954, 'eval_cer': 0.041379310344827586, 'eval_runtime': 2.9428, 'eval_samples_per_second': 7.136, 'eval_steps_per_second': 0.68, 'epoch': 2.0}


100%|██████████| 12/12 [00:28<00:00,  2.34s/it]

{'train_runtime': 28.1456, 'train_samples_per_second': 5.898, 'train_steps_per_second': 0.426, 'train_loss': 2.677683115005493, 'epoch': 2.0}





In [21]:
# Persist the fine-tuned model to the trainer's output directory.
trainer.save_model()
# The trainer was handed only the image processor, so also save the full
# processor (tokenizer + image processor). This lets
# TrOCRProcessor.from_pretrained(output_dir) reload it alongside the model.
processor.save_pretrained(trainer.args.output_dir)

# Print and store the training metrics, then the trainer state.
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

***** train metrics *****
  epoch                    =        2.0
  train_loss               =     2.6777
  train_runtime            = 0:00:28.14
  train_samples_per_second =      5.898
  train_steps_per_second   =      0.426


In [22]:
# Final evaluation on the held-out split (the trainer's eval_dataset).
metrics = trainer.evaluate(eval_dataset=test_ds)
# Print the eval metrics, then save them to the output directory.
for report in (trainer.log_metrics, trainer.save_metrics):
    report("eval", metrics)

100%|██████████| 2/2 [00:00<00:00,  2.15it/s]

***** eval metrics *****
  epoch                   =        2.0
  eval_cer                =     0.0414
  eval_loss               =     0.1357
  eval_runtime            = 0:00:02.97
  eval_samples_per_second =       7.07
  eval_steps_per_second   =      0.673



