<a href="https://colab.research.google.com/github/erez-meoded/TrOCR-HTR/blob/master/paper_training_with_augmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the data and experiment paths used
# further down all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the third-party packages used below (-q keeps pip quiet).
# NOTE(review): versions are not pinned, so a fresh runtime may pull newer releases.
!pip install -q transformers evaluate jiwer datasets accelerate

In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from sklearn.model_selection import train_test_split as tts
from transformers import default_data_collator
from torch.utils.data import Dataset
from PIL import Image, ImageFilter
from pickle import load, dump
from copy import deepcopy
import pandas as pd
import numpy as np

import evaluate
import torch
import os
from torchvision.transforms import Resize, RandomChoice, Compose, Grayscale, GaussianBlur, ElasticTransform,RandomPerspective, RandomRotation, RandomAffine

In [None]:
# Unpack the image archive from Drive into the current working directory.
# NOTE(review): the datasets below read from /content/paper/ - presumably that is
# where this archive extracts to; confirm.
!unzip -qq "/content/drive/MyDrive/Theses/Data/Historical/paper.zip"

In [None]:
# Load the pickled dataframe and split it 80/10/10 into train/validation/test.
# NOTE: unpickling can execute arbitrary code - only load files you trust.
DF_PICKLE_PATH = '/content/drive/MyDrive/Theses/Data/Historical/df.pkl'
SPLIT_SEED = 81

# Context manager so the file handle is closed as soon as the frame is loaded.
with open(DF_PICKLE_PATH, 'rb') as pickle_file:
    df_all = load(pickle_file)

# First cut: 80% train / 20% hold-out; second cut: hold-out halved into val / test.
df_train, df_holdout = tts(df_all, test_size=0.2, random_state=SPLIT_SEED, shuffle=True)
df_val, df_test = tts(df_holdout, test_size=0.5, random_state=SPLIT_SEED, shuffle=True)

# Give every split a fresh 0..n-1 index (GwaltherDataset looks rows up by integer idx).
# reset_index returns a new frame, so re-running this cell never mutates df_all.
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
df_test = df_test.reset_index(drop=True)

df_train.shape, df_val.shape, df_test.shape

## Prepare the Gwalther data

In [None]:
# set up dataset class
class GwaltherDataset(Dataset):
    """Image / transcription pairs prepared for TrOCR fine-tuning.

    Args:
        root_dir: Prefix that is string-concatenated with each ``file_name``,
            so it has to end with a path separator.
        df: DataFrame with a ``file_name`` and a ``text`` column.
        processor: Callable image processor that also exposes ``.tokenizer``.
        max_target_length: Length the label sequence is padded / truncated to.
        transform: Optional callable applied to the PIL image before it is
            handed to ``processor``; ``None`` disables augmentation.
    """

    def __init__(self, root_dir, df, processor, max_target_length=128, transform=None):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length
        self.transform = transform

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"file", "pixel_values", "labels"}`` for the row at position ``idx``."""
        # Positional lookup: works even if the dataframe index is not 0..n-1.
        file_name = self.df['file_name'].iloc[idx]
        text = self.df['text'].iloc[idx]

        # Open inside a context manager so the file handle is released immediately;
        # convert() returns a separate, fully loaded image that outlives the handle.
        with Image.open(self.root_dir + file_name) as raw_image:
            image = raw_image.convert("RGB")

        # Apply the transform that was given to THIS dataset. The previous code
        # called a global named `transform`, which ignored the constructor argument.
        if self.transform is not None:
            image = self.transform(image)

        # prepare image (i.e. resize + normalize)
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # add labels (input_ids) by encoding the text; truncation keeps every label
        # at exactly max_target_length so the default collator can stack them
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length,
                                          truncation=True).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        pad_token_id = self.processor.tokenizer.pad_token_id
        labels = [label if label != pad_token_id else -100 for label in labels]

        return {
            "file": file_name,
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(labels),
        }

## Augmentation

In [None]:
class InterpolationMode:
    """Integer interpolation codes; ``NEAREST`` is what ``ImgResize`` passes to ``Resize``.

    NOTE(review): the numbers look like Pillow's resampling-filter constants - confirm.
    This class shadows any ``InterpolationMode`` imported under the same name.
    """

    # Listed in ascending numeric order.
    NEAREST = 0
    LANCZOS = 1
    BILINEAR = 2
    BICUBIC = 3
    BOX = 4
    HAMMING = 5

class Dilation(torch.nn.Module):
    """Augmentation that runs Pillow's ``ImageFilter.MaxFilter`` over an image.

    Args:
        kernel: Window size handed to ``MaxFilter``.
    """

    def __init__(self, kernel=3):
        super().__init__()
        self.kernel = kernel

    def forward(self, img):
        """Return ``img`` filtered with a max filter of size ``self.kernel``."""
        max_filter = ImageFilter.MaxFilter(self.kernel)
        return img.filter(max_filter)

    def __repr__(self):
        return f"{type(self).__name__}(kernel={self.kernel})"

class Erosion(torch.nn.Module):
    """Augmentation that runs Pillow's ``ImageFilter.MinFilter`` over an image.

    Args:
        kernel: Window size handed to ``MinFilter``.
    """

    def __init__(self, kernel=3):
        super().__init__()
        self.kernel = kernel

    def forward(self, img):
        """Return ``img`` filtered with a min filter of size ``self.kernel``."""
        min_filter = ImageFilter.MinFilter(self.kernel)
        return img.filter(min_filter)

    def __repr__(self):
        return f"{type(self).__name__}(kernel={self.kernel})"

class Underline(torch.nn.Module):
    """Augmentation that draws a 3-pixel-high dark bar under the dark pixels of an image.

    Pixels whose grayscale value is below 50 count as dark. The bar covers the
    columns from the left-most dark pixel up to (but not including) the
    right-most one, and the three rows ending at the lowest dark row.
    The input image is never modified; a copy is drawn on and returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, img):
        """Return an underlined copy of ``img``, or ``img`` itself if it has no dark pixel."""
        img_cp = deepcopy(img)
        img_np = np.array(img_cp.convert('L'))
        black_pixels = np.where(img_np < 50)  # (row indices, column indices)

        # Nothing dark enough to underline: explicit check instead of a bare except.
        if black_pixels[0].size == 0:
            return img

        # numpy reductions instead of the builtin max/min, which iterate in Python
        y1 = int(black_pixels[0].max())
        x0 = int(black_pixels[1].min())
        x1 = int(black_pixels[1].max())

        for x in range(x0, x1):
            for y in range(y1, y1 - 3, -1):
                try:
                    # A scalar 0 is written rather than the upstream (0, 0, 0) tuple;
                    # per the original author's note the tuple failed on single-channel images.
                    img_cp.putpixel((x, y), 0)
                except IndexError:
                    # Coordinate outside the image: skip that pixel. Any other error
                    # (and KeyboardInterrupt) now propagates instead of being swallowed.
                    continue
        return img_cp

class KeepOriginal(torch.nn.Module):
    """Identity augmentation: hands the image back unchanged."""

    def forward(self, img):
        """Return ``img`` as-is."""
        return img

class ReResize(torch.nn.Module):
    """Shrink an image with ``ImgResize(kernel)`` and enlarge it again with ``ImgResize(1/kernel)``.

    Each step truncates the target size to whole pixels, so the output size
    can differ slightly from the input size.

    Args:
        kernel: Scale divisor for the first (shrinking) step.
    """

    def __init__(self, kernel=3):
        super().__init__()
        self.kernel = kernel

    def forward(self, img):
        """Return ``img`` after the shrink / enlarge round trip."""
        enlarge = ImgResize(1/self.kernel)
        shrink = ImgResize(self.kernel)
        return enlarge(shrink(img))

    def __repr__(self):
        return f"{type(self).__name__}(kernel={self.kernel})"

class ImgResize(torch.nn.Module):
    """Resize an image to ``1/kernel`` of its size with ``InterpolationMode.NEAREST``.

    Args:
        kernel: Divisor applied to both sides; values below 1 enlarge the image.
    """

    def __init__(self, kernel=3):
        super().__init__()
        self.kernel = kernel

    def forward(self, img):
        """Return ``img`` resized to ``(int(height/kernel), int(width/kernel))``."""
        # img.size is (width, height); reversing it gives the (height, width) order
        # in which the target size is passed to Resize.
        height, width = np.array(img.size)[::-1]
        target = (int(height/self.kernel), int(width/self.kernel))
        resizer = Resize(target, interpolation=InterpolationMode.NEAREST)
        return resizer(img)

    def __repr__(self):
        return f"{type(self).__name__}(kernel={self.kernel})"


In [None]:
# Augmentation instances. Each name stays available at module level even when
# its experiment is switched off in `transforms_dict` below.

# -- custom augmentations defined in this notebook --
baseline = KeepOriginal()
underline = Underline()
resize = ImgResize()
re_resize = ReResize()
erosion = Erosion()
dilation = Dilation()

# -- torchvision augmentations --
gaussianblur = GaussianBlur(kernel_size=3)
rotation = RandomRotation(degrees=(-10, 10), expand=True, fill=255)
affine = RandomAffine(
    degrees=2.5,
    translate=(0, 0.25),
    shear=50,
    scale=(0.5, 1),
    fill=255,
)
perspective = RandomPerspective(p=1, fill=255)
elastic = ElasticTransform(alpha=10.0, sigma=5.0, fill=255)

# Experiment name -> augmentation. The training loop runs once per entry, in this
# order; commented-out entries are experiments that are currently disabled.
transforms_dict = dict(
    # BASELINE=baseline,
    # RANDOM_ROTATION=rotation,
    # GAUSSIAN_BLUR=gaussianblur,
    # DILATION=dilation,
    # EROSION=erosion,
    # RESIZE=resize,
    # UNDERLINE=underline,

    # RANDOM_AFFINE=affine,
    RANDOM_PERSPECTIVE=perspective,
    ELASTIC=elastic,
    RE_RESIZE=re_resize,
)

## Train a model

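Or rather, fine-tune the pretrained `microsoft/trocr-base-handwritten` checkpoint, once per augmentation in `transforms_dict`.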


In [None]:
def get_model(model='microsoft/trocr-base-handwritten'):
  """Load a pretrained VisionEncoderDecoderModel and configure it for fine-tuning.

  Reads the module-level `processor` for the special-token ids, so `processor`
  has to exist before this function is called.

  Args:
    model: Checkpoint name or path handed to `from_pretrained`.

  Returns:
    The loaded model with token ids, decoder flags and beam-search settings
    written onto its config.
  """
  net = VisionEncoderDecoderModel.from_pretrained(model)
  config = net.config
  tokenizer = processor.tokenizer

  # special tokens used for creating the decoder_input_ids from the labels
  config.decoder_start_token_id = tokenizer.cls_token_id
  config.pad_token_id = tokenizer.pad_token_id
  # keep the top-level vocab size in sync with the decoder's
  config.vocab_size = config.decoder.vocab_size

  # flags set so the decoder side is trainable
  config.decoder.is_decoder = True
  config.decoder.add_cross_attention = True

  # beam search parameters
  beam_search_settings = {
      "eos_token_id": tokenizer.sep_token_id,
      "max_length": 64,
      "early_stopping": True,
      "no_repeat_ngram_size": 3,
      "length_penalty": 2.0,
      "num_beams": 4,
  }
  for option, value in beam_search_settings.items():
    setattr(config, option, value)

  return net

In [None]:
# Parent folder on Drive for experiment outputs; the training loop appends the
# experiment name directly to this string, so the trailing slash is required.
root="/content/drive/MyDrive/Theses/Experiments/Historical/"

In [None]:
# "cer" metric from the `evaluate` library; compute_metrics below calls its .compute().
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    """Decode predictions and labels and score them with the module-level ``cer_metric``.

    Args:
        pred: Object with ``predictions`` and ``label_ids`` token-id arrays, where
            ``-100`` marks label positions that were masked out for the loss.

    Returns:
        ``{"cer": value}`` with whatever ``cer_metric.compute`` returns.
    """
    pred_ids = pred.predictions
    pad_token_id = processor.tokenizer.pad_token_id

    # Swap the -100 loss mask back to the pad id so the labels can be decoded.
    # np.where builds a new array, so pred.label_ids is left untouched for the caller.
    labels_ids = np.where(pred.label_ids == -100, pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
    # Prints every decoded prediction and reference of the evaluated set.
    print(pred_str)
    print(label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}


In [None]:
# Shared processor; get_model, compute_metrics and the datasets all read this
# module-level name, so this cell has to run before the training loop.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')

In [None]:
# Fine-tune one fresh model per entry of transforms_dict and report its test metrics.
for experiment, transform in transforms_dict.items():
  print(experiment)

  model = get_model()
  path = root + experiment

  # Pure-Python replacement for `!mkdir {path}`: no shell interpolation (safe for
  # paths with spaces), creates missing parents, and is a no-op if it already exists.
  os.makedirs(path, exist_ok=True)
  os.chdir(path)

  # Training images go through RandomChoice over this experiment's transform and the
  # identity `baseline`; validation and test images are never augmented.
  train_dataset = GwaltherDataset(root_dir='/content/paper/', df=df_train, processor=processor, transform=RandomChoice([transform,baseline]))
  eval_dataset  = GwaltherDataset(root_dir='/content/paper/', df=df_val,   processor=processor, transform=None)
  test_dataset  = GwaltherDataset(root_dir='/content/paper/', df=df_test,   processor=processor, transform=None)

  training_args = Seq2SeqTrainingArguments(
    predict_with_generate = True,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    # NOTE(review): no metric_for_best_model is set, so the Trainer's default decides
    # which checkpoint counts as "best" - confirm that is the intended criterion.
    load_best_model_at_end = True,
    save_total_limit = 1,
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    learning_rate = 2e-05,
#    warmup_steps = 500,
#    weight_decay = 0.0001,
#    lr_scheduler_type = "inverse_sqrt",

    fp16=True,
    output_dir=path,
#    logging_steps=2,
#    save_steps=450,
#    eval_steps=450,
    num_train_epochs=5,
  )

  trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    # callbacks=[EarlyStoppingCallback(3, 0.0)]
  )

  trainer.train()
  # Final score on the held-out test split; named so it does not shadow builtin eval().
  test_metrics = trainer.evaluate(test_dataset)
  print(test_metrics)

In [None]:
# Release the Colab runtime once every experiment above has finished.
# NOTE(review): anything not saved to Drive is presumably lost at this point - confirm.
from google.colab import runtime
runtime.unassign()