In [9]:
import os
import platform

from google.colab import drive
# Mount Google Drive: the dataset zip and the checkpoint folder used by later
# cells both live under /content/drive/MyDrive.
MOUNT_POINT = "/content/drive"
drive.mount(MOUNT_POINT)

In [None]:
!unzip /content/drive/MyDrive/datasets/ocr_data.zip -d . 1> /dev/null

In [None]:
import os
import glob
import torch
import mimetypes
from PIL import Image
from collections import deque
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# SECURITY: never hard-code access tokens in a notebook. Notebooks are shared and
# committed, so the token that used to be written here must be treated as leaked
# and revoked. Supply HF_TOKEN through the runtime environment instead.
if not os.environ.get("HF_TOKEN"):
    print("HF_TOKEN is not set; Hugging Face Hub requests will be unauthenticated.")
# NOTE(review): `folder` is not referenced by the cells shown here (the functions
# below take their own `folder` parameter); kept in case another cell uses it.
folder = ".cache/"
# Checkpoints are written to the Google Drive mounted at /content/drive.
param_dir = "/content/drive/MyDrive/trocr"
os.makedirs(param_dir, exist_ok=True)

In [10]:
# Dataset folders (relative to the working directory).
TRAIN_DATA = "data/train"
VAL_DATA = "data/test"

# Optimisation.
LEARNING_RATE = 1e-4
EPOCHS = 10
BATCH_SIZE = 2
ACCUMULATION_STEPS = 2  # optimizer steps once every N batches (gradient accumulation)

# Moving-average / checkpointing heuristics used by the training loop.
MEAN_WINDOW = 10      # how many recent loss measurements each moving average covers
MIN_PROGRESS = 0.1    # required drop of the current val loss below its moving average
MAX_LOSS_DIFF = 0.25  # max allowed gap between train and val moving averages

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [11]:
# Sliding windows holding the last MEAN_WINDOW loss measurements per split, plus
# their running sums (kept up to date by update_loss_stat) for O(1) mean updates.
t_losses, v_losses = deque(maxlen=MEAN_WINDOW), deque(maxlen=MEAN_WINDOW)
t_loss_sum = v_loss_sum = 0

In [12]:
class OCRDataset(Dataset):
    """Image/label pairs for OCR fine-tuning, read from a single folder.

    Every file directly inside `folder` whose MIME type (guessed from its
    extension) starts with ``image`` becomes a sample. Labels are read from
    ``<folder>/meta/labels.txt``: one ``<file name><sep><text>`` entry per line,
    split on the first `sep` only, so the text itself may contain `sep`.
    Blank lines in the labels file are ignored.

    Items are ``(image_path, label_text)``; the image file itself is not opened
    here.
    """

    def __init__(self, folder, sep=":"):
        # Sorted so sample indices are stable across runs (glob order is arbitrary).
        self.images = sorted(
            path
            for path in glob.glob(os.path.join(folder, "*"))
            if (mimetypes.guess_type(path)[0] or "").startswith("image")
        )
        self.labels = {}
        lbl_path = os.path.join(folder, "meta/labels.txt")
        # `with` closes the labels file deterministically instead of leaking the handle.
        with open(lbl_path, encoding="utf-8") as fh:
            lines = fh.read().splitlines()
        for line_no, line in enumerate(lines, start=1):
            if not line.strip():
                continue  # tolerate blank lines instead of crashing on them
            splits = line.split(sep, maxsplit=1)
            if len(splits) != 2:
                raise ValueError(
                    "%s:%d: expected '<file name>%s<text>', got %r"
                    % (lbl_path, line_no, sep, line)
                )
            self.labels[splits[0]] = splits[1]

    def __len__(self):
        """Number of image files found in the folder."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_path, label_text)``; raises KeyError if the image has no label."""
        image = self.images[idx]
        return image, self.labels[os.path.basename(image)]

    @staticmethod
    def collate(batch):
        """Turn ``[(path, text), ...]`` into ``([paths...], [texts...])`` for a DataLoader."""
        paths, texts = zip(*batch)
        return list(paths), list(texts)

In [None]:
# Pretrained TrOCR checkpoint: the processor turns images into pixel tensors and
# text into token ids; the model is moved to DEVICE for fine-tuning.
TROCR_CHECKPOINT = "microsoft/trocr-large-printed"
processor = TrOCRProcessor.from_pretrained(TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(TROCR_CHECKPOINT).to(DEVICE)

In [14]:
# Align the encoder-decoder config with the tokenizer: decoding starts from the
# tokenizer's EOS id, padding uses its pad id, and the top-level vocab size
# mirrors the decoder's.
# NOTE(review): EOS as decoder start matches how TrOCR checkpoints are usually
# configured — confirm against the checkpoint's own config.
model.config.update({
    "decoder_start_token_id": processor.tokenizer.eos_token_id,
    "pad_token_id": processor.tokenizer.pad_token_id,
    "vocab_size": model.config.decoder.vocab_size,
})

In [83]:
# Both loaders shuffle. get_loss() only evaluates the first few batches, so
# shuffling the validation split means each call scores a different random
# subset (presumably the intent).
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=OCRDataset.collate, shuffle=True)
train_dataloader = DataLoader(OCRDataset(TRAIN_DATA), **_loader_kwargs)
val_dataloader = DataLoader(OCRDataset(VAL_DATA), **_loader_kwargs)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

In [84]:
@torch.no_grad()
def get_loss(split, batches=ACCUMULATION_STEPS):
    """Estimate the loss on the first `batches` batches of one data split.

    Gradients are disabled and the model is switched to eval mode for the
    measurement. On exit (including on error) the model is put back into
    train mode, as before; the training loop relies on that.

    Args:
        split: "train" reads `train_dataloader`; any other value reads
            `val_dataloader`.
        batches: maximum number of batches to evaluate.

    Returns:
        Mean per-batch loss over the batches actually evaluated, so a loader
        shorter than `batches` is not under-reported; 0.0 if it yields nothing.
    """
    model.eval()
    dataloader = train_dataloader if split == "train" else val_dataloader
    pad_id = processor.tokenizer.pad_token_id
    acc_loss, n_seen = 0.0, 0
    try:
        for images, labels in dataloader:
            images = [Image.open(img).convert("RGB") for img in images]
            pixel_values = processor(images, return_tensors="pt").pixel_values.to(DEVICE)
            labels = processor.tokenizer(labels, return_tensors="pt", padding=True).input_ids
            # Padding positions must not count toward the loss: -100 is the index
            # the model's cross-entropy loss ignores.
            labels[labels == pad_id] = -100
            outputs = model(pixel_values=pixel_values, labels=labels.to(DEVICE))
            acc_loss += outputs.loss.item()
            n_seen += 1
            if n_seen == batches:
                break
    finally:
        model.train()
    return acc_loss / n_seen if n_seen else 0.0


def update_loss_stat(split, loss):
    """Push `loss` into the split's moving window and return the window mean.

    The module-level running sums (`t_loss_sum` / `v_loss_sum`) are adjusted in
    O(1): once the window is full, its oldest value is evicted and subtracted.

    Args:
        split: "train" updates the training window; anything else the validation one.
        loss: the newest loss measurement.

    Returns:
        Mean of the values currently in the window (including `loss`).
    """
    global t_loss_sum, v_loss_sum
    is_train = split == "train"
    window = t_losses if is_train else v_losses
    evicted = window.popleft() if len(window) == window.maxlen else 0
    delta = loss - evicted
    if is_train:
        t_loss_sum += delta
        running = t_loss_sum
    else:
        v_loss_sum += delta
        running = v_loss_sum
    window.append(loss)
    return running / len(window)


def save_model(mt_loss, mv_loss, folder=param_dir):
    """Write the model's state_dict to `folder`, named after the given losses.

    The file name is ``trocr_<mt_loss>_<mv_loss>.pth`` with both values
    formatted to 4 decimals, so checkpoints can be told apart by their losses.
    """
    checkpoint_path = os.path.join(folder, "trocr_%.4f_%.4f.pth" % (mt_loss, mv_loss))
    torch.save(model.state_dict(), checkpoint_path)

In [85]:
# Fine-tuning with gradient accumulation: the optimizer steps once every
# ACCUMULATION_STEPS batches. After each step, the train/val losses are
# re-measured, folded into the moving averages, and used to decide whether to
# checkpoint.
#
# `from_pretrained` returns the model in eval mode. Switch to train mode so that
# dropout is active from the very first batch, not only after get_loss() runs.
model.train()
label_pad_id = processor.tokenizer.pad_token_id
for e_no in range(1, EPOCHS + 1):
    for b_no, (images, labels) in enumerate(train_dataloader, start=1):
        images = [Image.open(img).convert("RGB") for img in images]
        pixel_values = processor(images, return_tensors="pt").pixel_values.to(DEVICE)
        labels = processor.tokenizer(labels, return_tensors="pt", padding=True).input_ids
        # Mask padding with -100 so the loss ignores pad positions.
        labels[labels == label_pad_id] = -100
        outputs = model(pixel_values=pixel_values, labels=labels.to(DEVICE))
        # Scale so the accumulated gradient is the mean over the accumulated batches.
        loss = outputs.loss / ACCUMULATION_STEPS
        loss.backward()
        if b_no % ACCUMULATION_STEPS != 0:
            continue

        optimizer.step()
        optimizer.zero_grad()
        t_loss, v_loss = get_loss("train"), get_loss("val")
        mt_loss = update_loss_stat("train", t_loss)
        mv_loss = update_loss_stat("val", v_loss)
        print("train: (%.4f | %.4f), val: (%.4f | %.4f)" % (t_loss, mt_loss, v_loss, mv_loss))

        # Judge progress only right after a fresh measurement. Checking on every
        # batch re-used stale values and saved/printed the same result twice.
        # Act only once the window is warm (or after the first epoch) and when the
        # current val loss beats its moving average by at least MIN_PROGRESS.
        if (e_no > 1 or len(t_losses) >= MEAN_WINDOW) and mv_loss - v_loss >= MIN_PROGRESS:
            if abs(mt_loss - mv_loss) > MAX_LOSS_DIFF:
                # Overfitting = validation loss well ABOVE training loss. The
                # previous test (mt_loss > mv_loss) had this inverted.
                overfitting = mv_loss > mt_loss
                word = "overfitting" if overfitting else "underfitting"
                print("the model seems to be %s" % (word, ))
            else:
                save_model(mt_loss, mv_loss)

In [15]:
def infer(image_path):
    """Transcribe the text in one image with the (fine-tuned) model.

    The model is switched to eval mode for generation and then restored to the
    mode it was in before. Without this, dropout stays active after training,
    because get_loss() leaves the model in train mode.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        The decoded text for the image, with special tokens removed.
    """
    # Close the file handle as soon as the RGB copy has been made.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(DEVICE)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)
    finally:
        model.train(was_training)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]