# Setup

In [10]:
import pandas as pd
import torch
from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)
from evaluate import load

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
data_dir = Path("data")

In [6]:
metadata_df = pd.read_csv(data_dir / "metadata.csv")
metadata_df.head(5)

Unnamed: 0,file_name,text
0,3d_unicode/3d_unicode_0000.png,ამათ
1,3d_unicode/3d_unicode_0001.png,პარტიანკაში
2,3d_unicode/3d_unicode_0002.png,კომენტარების
3,3d_unicode/3d_unicode_0003.png,ფრიდრიხ
4,3d_unicode/3d_unicode_0004.png,ცდწლოოწნწში


In [8]:
most_freq_word = metadata_df["text"].value_counts()
most_freq_word

text
და              4221
არ              1163
რომ              995
იყო              768
კი               604
                ... 
ფუფთწჟჯკდპბგ       1
ოოწლღდჩქთტტ        1
გითქვამს           1
სიკვდილია          1
ბგშჩდეკშჰ          1
Name: count, Length: 38003, dtype: int64

# Train-test splitting

In [11]:
# 95% for training, 5% testing
train_df, test_df = train_test_split(metadata_df, test_size=0.05, random_state=42)

# clear indexes
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

print(f"Train samples: {len(train_df)}")
print(f"Test samples: {len(test_df)}")

Train samples: 95475
Test samples: 5025


In [13]:
print(train_df["text"].value_counts())

text
და               4027
არ               1104
რომ               937
იყო               732
კი                577
                 ... 
დზშ                 1
პთჩჰოსჭჰჰ           1
+995565947364       1
ტწილშთთჰლრ          1
გიყურებენ           1
Name: count, Length: 36655, dtype: int64


In [14]:
print(test_df["text"].value_counts())

text
და            194
არ             59
რომ            58
იყო            36
რა             28
             ... 
ჩვილი           1
იბბ             1
პოლიციელნი      1
გთხოვ           1
მოდის           1
Name: count, Length: 3443, dtype: int64


# PyTorch dataset

In [12]:
class GeorgianOCRDataset(Dataset):
    def __init__(self, df, processor, max_target_length=32):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # read images
        file_path = self.df.iloc[idx]['file_name']
        image = Image.open(f"data/{file_path}").convert("RGB")
        
        # process images
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # text tokenization
        text = self.df.iloc[idx]['text']
        labels = self.processor.tokenizer(
            text, 
            padding="max_length", 
            max_length=self.max_target_length,
            truncation=True
        ).input_ids

        # replace with -100 to be ignored during loss calculation
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding
