# Setup

In [33]:
import os
from pathlib import Path

BASE_DIR = Path(".")


def check_env() -> str:
    """Report whether the notebook runs on Kaggle or on a local machine.

    Kaggle is detected through a non-empty ``KAGGLE_KERNEL_RUN_TYPE``
    environment variable. The detected environment is printed as a side
    effect.

    Returns:
        ``"kaggle"`` when the variable is set to a non-empty value,
        otherwise ``"local"``.
    """
    on_kaggle = bool(os.environ.get('KAGGLE_KERNEL_RUN_TYPE'))
    env_name, message = (
        ("kaggle", "Running on Kaggle") if on_kaggle else ("local", "Running locally")
    )
    print(message)
    return env_name


ENV = check_env()

# On Kaggle the dataset is read from /kaggle/input/ka-ocr; otherwise from
# the "data" folder under BASE_DIR.
data_dir = Path("/kaggle/input/ka-ocr") if ENV == "kaggle" else BASE_DIR / "data"

# List the top-level entries of the dataset folder as a quick sanity check.
print(f"\nDataset contents in {data_dir}:")
for entry in data_dir.iterdir():
    print(entry.name)

Running locally

Dataset contents in data:
.cache
3d_unicode
alkroundedmtav-medium
alkroundednusx-medium
ar-archy-regular
arial_geo
arial_geo-bold
arial_geo-bold-italic
arial_geo-italic
bpg_algeti
bpg_algeti_compact
bpg_arial_2009
bpg_boxo
bpg_boxo-boxo
bpg_classic_medium
bpg_dedaena
bpg_dedaena_nonblock
bpg_excelsior_caps_dejavu_2010
bpg_excelsior_dejavu_2010
bpg_extrasquare_2009
bpg_extrasquare_mtavruli_2009
bpg_glaho
bpg_glaho_2008
bpg_glaho_arial
bpg_glaho_bold
bpg_glaho_sylfaen
bpg_glaho_traditional
bpg_ingiri_2008
bpg_irubaqidze
bpg_mrgvlovani_caps_2010
bpg_nino_elite_exp
bpg_nino_elite_ultra
bpg_nino_elite_ultra_caps
bpg_nino_medium_caps
bpg_nino_mtavruli_bold
bpg_nino_mtavruli_book
bpg_nino_mtavruli_normal
bpg_no9
bpg_nostalgia
bpg_paata
bpg_paata_caps
bpg_paata_cond
bpg_paata_cond_caps
bpg_paata_exp
bpg_phone_sans_bold
bpg_phone_sans_bold_italic
bpg_phone_sans_italic
bpg_quadrosquare_2009
bpg_rioni
bpg_rioni_contrast
bpg_rioni_vera
bpg_sans_2008
bpg_serif_2008
bpg_square_2009


# Explore data

In [34]:
import pandas as pd

In [35]:
# Load the label table (columns: file_name, text) and show both ends of it.
df = pd.read_csv(data_dir / "metadata.csv")
for preview in (df.head(), df.tail()):
    print(preview)

                        file_name      text
0  3d_unicode/3d_unicode_0000.png        რა
1  3d_unicode/3d_unicode_0001.png  წყდებოდა
2  3d_unicode/3d_unicode_0002.png   ჭჯნწფუქ
3  3d_unicode/3d_unicode_0003.png    წმინდა
4  3d_unicode/3d_unicode_0004.png     ჯილდო
                                         file_name         text
100495  NotoSansGeorgian/NotoSansGeorgian_1495.png     ხთვჰჟშხშ
100496  NotoSansGeorgian/NotoSansGeorgian_1496.png       რომლის
100497  NotoSansGeorgian/NotoSansGeorgian_1497.png  შემორჩენილი
100498  NotoSansGeorgian/NotoSansGeorgian_1498.png    ნეოლითური
100499  NotoSansGeorgian/NotoSansGeorgian_1499.png         უცხო


In [36]:
# How often each distinct label occurs.
text_counts = df["text"].value_counts()
print(text_counts)

text
და              4205
არ              1155
რომ             1044
იყო              785
კი               612
                ... 
ფოდჩ               1
ვჯნც               1
ტემპი              1
ჰიბიწნჟცჭჟქი       1
ჯპჭოჯგ             1
Name: count, Length: 37994, dtype: int64


In [37]:
# Character count per label, stored as a new column and then summarised.
label_lengths = df["text"].str.len()
df["text_len"] = label_lengths
print(df["text_len"].describe())

count    100500.000000
mean          6.396020
std           2.979581
min           2.000000
25%           4.000000
50%           6.000000
75%           8.000000
max          30.000000
Name: text_len, dtype: float64


# Prepare dataset and tokenizer

Check whether TrOCR's pretrained tokenizer already supports Georgian.

In [38]:
from transformers import TrOCRProcessor
# Load the processor that ships with the pretrained checkpoint; it is reused
# later for image preprocessing.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")

# Test Georgian tokenization: run one Georgian word through the pretrained
# tokenizer and inspect the pieces it produces.
test_text = "გამარჯობა"
tokens = processor.tokenizer.tokenize(test_text)
# If the printed pieces are mostly <unk> or unreadable fragments rather than
# Georgian letters, a custom tokenizer is needed.
print(tokens)

['á', 'ĥ', 'Ĵ', 'á', 'ĥ', 'Ĳ', 'á', 'ĥ', 'Ľ', 'á', 'ĥ', 'Ĳ', 'á', 'ĥ', 'ł', 'á', 'ĥ', '¯', 'á', 'ĥ', 'Ŀ', 'á', 'ĥ', 'ĳ', 'á', 'ĥ', 'Ĳ']


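That's not what we need: the pretrained tokenizer breaks each of the nine Georgian letters into three byte-level pieces (27 tokens in total) instead of readable characters. So we'll create a custom, character-based tokenizer.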

The model predicts the next token; in this case each token represents a single character, not a word.

In [39]:
class GeorgianTokenizer:
    """Character-level tokenizer for Georgian text.

    The vocabulary is four special tokens (``<pad>``, ``<s>``, ``</s>``,
    ``<unk>``) followed by the 33 letters of the Georgian alphabet, so every
    token id stands for exactly one character. Any character outside the
    alphabet (digits, punctuation, Latin letters, ...) is encoded as
    ``<unk>``.

    Args:
        max_length: Maximum (and, when padding, exact) length of an encoded
            sequence, including the BOS and EOS tokens.

    Raises:
        ValueError: If ``max_length`` is smaller than 2, i.e. too small to
            hold both the BOS and the EOS token.
    """

    def __init__(self, max_length: int = 32):
        if max_length < 2:
            raise ValueError(
                f"max_length must be at least 2 to fit BOS and EOS, got {max_length}"
            )

        # Special tokens
        self.pad_token = "<pad>"
        self.bos_token = "<s>"      # beginning of sequence
        self.eos_token = "</s>"     # end of sequence
        self.unk_token = "<unk>"    # unknown character

        # Georgian alphabet (33 letters)
        self.georgian_chars = "აბგდევზთიკლმნოპჟრსტუფქღყშჩცძწჭხჯჰ"

        # Build vocabulary: special tokens + Georgian characters
        self.vocab = [self.pad_token, self.bos_token, self.eos_token, self.unk_token]
        self.vocab.extend(self.georgian_chars)

        # Create mappings
        self.char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
        self.id_to_char = dict(enumerate(self.vocab))

        # Token IDs for special tokens, derived from the vocabulary so they
        # cannot drift from it (they resolve to 0, 1, 2 and 3).
        self.pad_token_id = self.char_to_id[self.pad_token]
        self.bos_token_id = self.char_to_id[self.bos_token]
        self.eos_token_id = self.char_to_id[self.eos_token]
        self.unk_token_id = self.char_to_id[self.unk_token]

        self.max_length = max_length

    def encode(self, text: str, padding: bool = True) -> list[int]:
        """Convert Georgian text to token IDs.

        The result is ``[BOS, <one id per character>, EOS]``. Sequences longer
        than ``max_length`` are cut so that the last kept position is EOS.

        Args:
            text: Text to encode; unknown characters become ``<unk>``.
            padding: When true, right-pad with ``<pad>`` up to ``max_length``.

        Returns:
            The list of token ids (length ``max_length`` when padded).
        """
        ids = [self.bos_token_id]
        ids.extend(self.char_to_id.get(char, self.unk_token_id) for char in text)
        ids.append(self.eos_token_id)

        # Truncate if too long, always keeping a closing EOS
        if len(ids) > self.max_length:
            ids = ids[:self.max_length - 1] + [self.eos_token_id]

        # Pad if needed (a non-positive repeat count adds nothing)
        if padding:
            ids.extend([self.pad_token_id] * (self.max_length - len(ids)))

        return ids

    def decode(self, ids: list[int]) -> str:
        """Convert token IDs back to text.

        PAD, BOS and EOS are dropped; ``<unk>`` is kept as the literal
        ``"<unk>"``; ids that are not in the vocabulary (for example the
        ``-100`` loss-masking value) contribute nothing.

        Args:
            ids: Iterable of token ids. Elements may be plain ints or
                int-convertible scalars such as NumPy integers or 0-d
                tensors; each one is normalised with ``int()`` so the
                vocabulary lookup does not silently miss.

        Returns:
            The decoded string.
        """
        skipped = {self.pad_token_id, self.bos_token_id, self.eos_token_id}
        chars = []
        for raw_id in ids:
            # Tensor scalars hash by identity, so look up a real int instead.
            token_id = int(raw_id)
            if token_id in skipped:
                continue
            chars.append(self.id_to_char.get(token_id, ""))
        return "".join(chars)

    def __len__(self):
        """Return the vocabulary size (special tokens + letters)."""
        return len(self.vocab)

Test tokenization again

In [40]:
tokenizer = GeorgianTokenizer(max_length=32)

# Test encoding
text = "გამარჯობა"
ids = tokenizer.encode(text)
print(f"Text: {text}")
print(f"IDs: {ids[:15]}...")  # First 15 tokens
print(f"Length: {len(ids)}")
# encode() pads by default, so the result must be exactly max_length long
assert len(ids) == tokenizer.max_length, "encode() did not pad to max_length"

# Test decoding
decoded = tokenizer.decode(ids)
print(f"Decoded: {decoded}")
# Fail loudly instead of relying on a visual comparison of the printout
assert decoded == text, "encode/decode round trip changed the text"

# Verify vocab size
print(f"Vocab size: {len(tokenizer)}")  # Should be 37 (4 special + 33 Georgian)
assert len(tokenizer) == 37, "expected 4 special tokens + 33 Georgian letters"

Text: გამარჯობა
IDs: [1, 6, 4, 15, 4, 20, 35, 17, 5, 4, 2, 0, 0, 0, 0]...
Length: 32
Decoded: გამარჯობა
Vocab size: 37


Works as expected.

Now we prepare the dataset using this tokenizer.

In [41]:
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageOps


class GeorgianOCRDataset(Dataset):
    """Torch dataset pairing word images with character-level label ids.

    Each item is built from one row of ``df``: the image named by the
    ``file_name`` column is loaded, letterboxed onto a white square and passed
    through ``processor``; the ``text`` column is encoded with ``tokenizer``.

    Args:
        df: Frame with at least the columns ``file_name`` and ``text``.
        root_dir: Directory that the ``file_name`` values are relative to
            (a ``str`` or ``pathlib.Path``; it is only formatted into a path).
        processor: Callable image processor; it must accept a PIL image with
            ``return_tensors="pt"`` and expose ``.pixel_values`` on its result.
        tokenizer: Tokenizer used to turn the label text into ids.
        target_size: Side length in pixels of the white square the image is
            fitted into before it is handed to ``processor``.
            NOTE(review): the processor may apply its own resize afterwards -
            confirm before changing the default.
    """

    def __init__(self, df: pd.DataFrame, root_dir: str, processor, tokenizer: GeorgianTokenizer,
                 target_size: int = 384):
        self.df = df.reset_index(drop=True)
        self.root_dir = root_dir
        self.processor = processor
        self.tokenizer = tokenizer  # custom tokenizer
        self.target_size = target_size

    def __len__(self) -> int:
        """Return the number of samples (rows of the frame)."""
        return len(self.df)

    def _load_square_image(self, file_path: str) -> Image.Image:
        """Load an image as RGB and centre it on a white target_size square."""
        # convert() reads the pixel data, so the file can be closed right away
        # instead of staying open until garbage collection.
        with Image.open(file_path) as source:
            img = source.convert("RGB")

        w, h = img.size
        size = self.target_size

        # Scale the longest side to `size`, keeping the aspect ratio. Clamp to
        # 1 px so very thin images do not end up with a zero-sized dimension.
        scale = size / max(w, h)
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        img = img.resize((new_w, new_h), Image.Resampling.BILINEAR)

        # Pad to square by pasting the resized image centred on white
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - new_w) // 2, (size - new_h) // 2))
        return canvas

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return the processed image and the loss-ready labels for one row.

        Returns:
            A dict with ``pixel_values`` (the processor output without its
            leading batch dimension) and ``labels`` (token ids in which
            padding is replaced by -100).
        """
        row = self.df.iloc[idx]
        text = row['text']
        file_path = f"{self.root_dir}/{row['file_name']}"

        square_img = self._load_square_image(file_path)

        # Use the processor to turn the image into model-ready pixel values
        pixel_values = self.processor(square_img, return_tensors="pt").pixel_values

        # Tokenize the Georgian text; replace the padding id with -100 so
        # those positions are ignored by the loss function.
        pad_id = self.tokenizer.pad_token_id
        labels = [-100 if token == pad_id else token for token in self.tokenizer.encode(text)]

        return {
            # Drop only the batch dimension the processor added
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels)
        }


Prepare model...

We give it our 37 tokens so that the model predicts only 37 possible outputs instead of the original ~50k.

In [42]:
from transformers import VisionEncoderDecoderModel

# Create your tokenizer
tokenizer = GeorgianTokenizer(max_length=32)

# Load model and resize token embeddings
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
model.decoder.resize_token_embeddings(len(tokenizer))  # Resize to 37

# Configure special tokens (used when the model builds decoder inputs from labels)
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id

# generate() takes its defaults from generation_config, which was loaded from
# the checkpoint. Mirror the custom ids there so inference starts, pads and
# stops with the same tokens used during training.
# The hasattr guard keeps this cell working on versions without the attribute.
if getattr(model, "generation_config", None) is not None:
    model.generation_config.decoder_start_token_id = tokenizer.bos_token_id
    model.generation_config.bos_token_id = tokenizer.bos_token_id
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.generation_config.eos_token_id = tokenizer.eos_token_id

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Train test split

In [43]:
from sklearn.model_selection import train_test_split


# Hold out 10% of the rows for testing; the fixed random_state makes the
# shuffled split repeatable across runs.
# NOTE(review): the split is row-wise. The earlier value_counts output shows
# the same text occurring in many rows, so identical words presumably end up
# in both train and test - confirm this is acceptable for evaluation.
train_df, test_df = train_test_split(
    df, 
    test_size=0.10, 
    random_state=42, 
    shuffle=True
)

# Label frequencies within the training split.
# NOTE(review): the printed counts include labels with characters outside the
# Georgian alphabet (e.g. digits and '+'); GeorgianTokenizer.encode maps each
# of those characters to <unk>.
print(train_df["text"].value_counts())

text
და               3803
არ               1044
რომ               947
იყო               705
კი                546
                 ... 
სლგწრნუძწასდ        1
ლბევწდც             1
+995549397648       1
სიამოვნებასაც       1
ზელენსკის           1
Name: count, Length: 35172, dtype: int64


# Dataloaders

In [44]:
from torch.utils.data import DataLoader

# Wrap each split in a dataset; both share the same image folder, processor
# and tokenizer.
train_dataset, test_dataset = (
    GeorgianOCRDataset(split_df, data_dir, processor, tokenizer)
    for split_df in (train_df, test_df)
)

# Only the training batches are shuffled; the test order stays fixed.
loader_options = {"batch_size": 8}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

Train batches: 11307, Test batches: 1257


# Set up training

In [45]:
from torch.optim import AdamW

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model.to(device)

# Optimise every model parameter with AdamW at a fixed learning rate.
optimizer = AdamW(params=model.parameters(), lr=5e-5)

print(f"Training on: {device}")

Training on: cpu


# Training loop

In [None]:
num_epochs = 3

for epoch in range(num_epochs):
    epoch_label = epoch + 1  # 1-based epoch number for log messages
    model.train()
    total_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        # Move both tensors of the batch onto the training device
        pixel_values, labels = (
            batch[key].to(device) for key in ("pixel_values", "labels")
        )

        # Clear stale gradients, then forward / backward / update
        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        # Log every 100th batch, starting with the first
        if not batch_idx % 100:
            print(f"Epoch {epoch_label}, Batch {batch_idx}, Loss: {batch_loss:.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch_label} complete. Avg Loss: {avg_loss:.4f}")

# Evaluation

# Saving the fine-tuned model