In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from collections import Counter
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AdamW
from PIL import Image
from tqdm.notebook import tqdm
import evaluate

In [2]:
TRAIN_CSV = "D:\\Projects\\Priorbank\\Payment-logos\\For testing\\train_data.csv"
IMAGE_DIR = "D:\\Projects\\Priorbank\\Payment-logos\\For testing\\Train\\"
PROCESSOR_PATH = "microsoft/trocr-base-handwritten"
BATCH_SIZE = 10

In [3]:
df = pd.read_csv(TRAIN_CSV)
df.head()

Unnamed: 0,file_name,text
0,Belkart_0.jpg,Belkart
1,Belkart_1.jpg,Belkart
2,Belkart_1_0.png,Belkart
3,Belkart_1_1.png,Belkart
4,Belkart_1_10.png,Belkart


In [14]:
train_df = df[~df.file_name.str.startswith("Validation")]
train_df.reset_index(drop=True,inplace=True)

val_df = df[df.file_name.str.startswith("Validation")]
val_df.reset_index(drop=True,inplace=True)

print("Unique values for Train: " + str({**Counter(train_df.text)}))
print("Unique values for Validation: " + str({**Counter(val_df.text)}))

Unique values for Train: {'Belkart': 216, 'Mastercard': 102, 'Mir': 101, 'Visa Mastercard Mir': 226, 'Visa': 123, 'Visa Mastercard Belkart': 202}
Unique values for Validation: {'Visa Mastercard Mir': 101, 'Visa Mastercard': 100, 'Mir Belkart': 100}


In [4]:
""" np.random.seed(11)

train_df, val_df = train_test_split(df, test_size=0.2, shuffle=True)
train_df.reset_index(drop=True,inplace=True)
val_df.reset_index(drop=True,inplace=True)

print("Unique values for Train: " + str({**Counter(train_df.text)}))
print("Unique values for Validation: " + str({**Counter(val_df.text)})) """

Unique values for Train: {'Visa Mastercard Mir': 240, 'Visa': 99, 'Mir': 79, 'Mastercard': 83}
Unique values for Validation: {'Visa Mastercard Mir': 61, 'Mir': 22, 'Mastercard': 19, 'Visa': 24}


In [6]:
class Loader(Dataset):
    def __init__(self, data_dir, df, processor, max_length=10):
        super().__init__()
        self.data_dir = data_dir
        self.df = df
        self.max_length = max_length
        self.processor = processor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        file_name = self.df["file_name"][index]
        text = self.df["text"][index]
        image = Image.open(self.data_dir + file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_length).input_ids
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        
        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

In [7]:
processor = TrOCRProcessor.from_pretrained(PROCESSOR_PATH)

train_dataset = Loader(data_dir=IMAGE_DIR, df=train_df,
                       processor=processor)
val_dataset = Loader(data_dir=IMAGE_DIR, df=val_df,
                      processor=processor)



In [8]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained(PROCESSOR_PATH).to(device)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
model.config.encoder.encoder_stride = 30
model.config.encoder.patch_size = 30
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [37]:
""" state = torch.load("D:\\Projects\\Priorbank\\Payment-logos\\For testing\\model_state.pt")
model.load_state_dict(state["state_dict"]) """

  state = torch.load("D:\\Projects\\Priorbank\\Payment-logos\\For testing\\model_state.pt")


<All keys matched successfully>

In [11]:
cer_metric = evaluate.load("cer")

def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return cer

In [18]:
optimizer = AdamW(model.parameters(), lr=1e-4)
#optimizer.load_state_dict(state["optimizer"])

for epoch in range(2):
   model.train()
   for batch in tqdm(train_dataloader):
      for k,v in batch.items():
        batch[k] = v.to(device)

      outputs = model(**batch)
      loss = outputs.loss
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

   model.eval()
   valid_cer = 0.0
   with torch.no_grad():
     for batch in tqdm(eval_dataloader):
       outputs = model.generate(batch["pixel_values"].to(device))
       cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])
       valid_cer += cer 

   print("Validation CER:", valid_cer / len(eval_dataloader))

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/31 [00:00<?, ?it/s]

Validation CER: 0.6160680293108137


  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/31 [00:00<?, ?it/s]

Validation CER: 0.6160680293108137


In [19]:
model.eval()
with torch.no_grad():
    for batch in tqdm(eval_dataloader):
        outputs = model.generate(batch["pixel_values"].to(device))
        pred_str = processor.batch_decode(outputs, skip_special_tokens=True)
        batch["labels"][batch["labels"] == -100] = processor.tokenizer.pad_token_id
        label_str = processor.batch_decode(batch["labels"], skip_special_tokens=True)
        print("Prediction: " + str(pred_str) + "\n" + "True: " + str(label_str))
        print(compute_cer(pred_ids=outputs, label_ids=batch["labels"]))

  0%|          | 0/31 [00:00<?, ?it/s]

Prediction: ['Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart']
True: ['Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir']
0.8421052631578947
Prediction: ['Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart']
True: ['Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir']
0.8421052631578947
Prediction: ['Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart', 'Belkart']
True: ['Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir', 'Visa Mastercard Mir

In [48]:
""" state = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(state, "D:\\Projects\\Priorbank\\Payment-logos\\For testing\\model_state.pt") """