# Setup

In [2]:
import os
from pathlib import Path

BASE_DIR = Path(".")


def check_env() -> str:
    """Report where the notebook is executing.

    The decision is based solely on the ``KAGGLE_KERNEL_RUN_TYPE`` environment
    variable: any non-empty value counts as "running on Kaggle".

    Returns:
        ``"kaggle"`` when the variable is set to a non-empty value,
        ``"local"`` otherwise. The outcome is also printed.
    """
    on_kaggle = bool(os.environ.get('KAGGLE_KERNEL_RUN_TYPE'))
    print("Running on Kaggle" if on_kaggle else "Running locally")
    return "kaggle" if on_kaggle else "local"


ENV = check_env()

# Pick the dataset root for the detected environment.
if ENV == "kaggle":
    data_dir = Path("/kaggle/input/ka-ocr")
else:
    data_dir = BASE_DIR / "data"

# Fail early with an explicit message; every later cell reads from data_dir.
if not data_dir.is_dir():
    raise FileNotFoundError(f"Dataset directory not found: {data_dir}")

print(f"\nDataset contents in {data_dir}:")
# iterdir() yields entries in arbitrary order; sort so the listing is stable
# across runs and machines.
for item in sorted(data_dir.iterdir()):
    print(f"{item.name}")

Running on Kaggle

Dataset contents in /kaggle/input/ka-ocr:
bpg_phone_sans_bold_italic
bpg_glaho_bold
bpg_irubaqidze
bpg_boxo-boxo
bpg_glaho_2008
bpg_paata_cond
mg_bitneon_chaos
literaturulitt
version.txt
ar-archy-regular
bpg_paata_caps
mg_bitneon
fixedsys_excelsior
ka_literaturuli
gugeshashvili_slfn_2
bpg_glaho
bpg_dedaena
NotoSansGeorgian
gf_aisi_nus-bold-italic
arial_geo-bold-italic
arial_geo-italic
bpg_sans_2008
bpg_nostalgia
bpg_supersquare_2009
arial_geo
bpg_phone_sans_bold
alkroundednusx-medium
gf_aisi_nus_medium-medium-italic
bpg_phone_sans_italic
bpg_excelsior_caps_dejavu_2010
bpg_nino_elite_ultra
bpg_rioni_contrast
arial_geo-bold
bpg_no9
3d_unicode
bpg_dedaena_nonblock
bpg_extrasquare_2009
bpg_algeti_compact
ka_lortkipanidze
alkroundedmtav-medium
bpg_nino_mtavruli_bold
bpg_glaho_sylfaen
bpg_mrgvlovani_caps_2010
bpg_nino_elite_exp
bpg_ucnobi
bpg_arial_2009
bpg_boxo
bpg_glaho_arial
bpg_rioni_vera
bpg_paata_cond_caps
bpg_nino_mtavruli_book
bpg_paata_exp
bpg_classic_medium
bpg_r

# Explore data

In [3]:
import pandas as pd

In [4]:
# Load the label table that ships with the dataset.
metadata_path = data_dir / "metadata.csv"
df = pd.read_csv(metadata_path)

# Preview both ends of the table.
for preview in (df.head(), df.tail()):
    print(preview)

                        file_name      text
0  3d_unicode/3d_unicode_0000.png        ·É†·Éê
1  3d_unicode/3d_unicode_0001.png  ·É¨·Éß·Éì·Éî·Éë·Éù·Éì·Éê
2  3d_unicode/3d_unicode_0002.png   ·É≠·ÉØ·Éú·É¨·É§·É£·É•
3  3d_unicode/3d_unicode_0003.png    ·É¨·Éõ·Éò·Éú·Éì·Éê
4  3d_unicode/3d_unicode_0004.png     ·ÉØ·Éò·Éö·Éì·Éù
                                         file_name         text
100495  NotoSansGeorgian/NotoSansGeorgian_1495.png     ·ÉÆ·Éó·Éï·É∞·Éü·É®·ÉÆ·É®
100496  NotoSansGeorgian/NotoSansGeorgian_1496.png       ·É†·Éù·Éõ·Éö·Éò·É°
100497  NotoSansGeorgian/NotoSansGeorgian_1497.png  ·É®·Éî·Éõ·Éù·É†·É©·Éî·Éú·Éò·Éö·Éò
100498  NotoSansGeorgian/NotoSansGeorgian_1498.png    ·Éú·Éî·Éù·Éö·Éò·Éó·É£·É†·Éò
100499  NotoSansGeorgian/NotoSansGeorgian_1499.png         ·É£·É™·ÉÆ·Éù


In [5]:
# How often each distinct label string occurs in the dataset.
label_counts = df["text"].value_counts()
print(label_counts)

text
·Éì·Éê              4205
·Éê·É†              1155
·É†·Éù·Éõ             1044
·Éò·Éß·Éù              785
·Éô·Éò               612
                ... 
·Éò·É§·É®·Éï·Éú·Éî·É¢·É°           1
·Éì·Éê·Éí·Éï·Éî·ÉÆ·Éê·É†·ÉØ·Éù·É°        1
·É•·É†·Éù·Éõ·Éê·É¢·Éò·Éì·Éî·Éë·Éò·É°       1
·Éõ·Éê·Éí·É†·Éê·Éó·Éê            1
·É®·É®·É•·É¢·Éî·É®             1
Name: count, Length: 37994, dtype: int64


In [6]:
# Store the number of characters in each label as a new column, then
# summarise its distribution.
df["text_len"] = df["text"].str.len()
length_summary = df["text_len"].describe()
print(length_summary)

count    100500.000000
mean          6.396020
std           2.979581
min           2.000000
25%           4.000000
50%           6.000000
75%           8.000000
max          30.000000
Name: text_len, dtype: float64


# Prepare dataset and tokenizer

Check whether TrOCR's pretrained tokenizer already supports Georgian.

In [7]:
from transformers import TrOCRProcessor
# Stock TrOCR processor from the pretrained checkpoint.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
pretrained_tokenizer = processor.tokenizer

# Probe the pretrained tokenizer with a Georgian word.
test_text = "·Éí·Éê·Éõ·Éê·É†·ÉØ·Éù·Éë·Éê"
tokens = pretrained_tokenizer.tokenize(test_text)
# Lots of <unk> or odd byte-level pieces here means a custom tokenizer is needed.
print(tokens)

2026-01-24 06:08:59.891142: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769234940.097086      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769234940.161816      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769234940.653243      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769234940.653306      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769234940.653309      55 computation_placer.cc:177] computation placer alr

preprocessor_config.json:   0%|          | 0.00/224 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

['√°', 'ƒ•', 'ƒ¥', '√°', 'ƒ•', 'ƒ≤', '√°', 'ƒ•', 'ƒΩ', '√°', 'ƒ•', 'ƒ≤', '√°', 'ƒ•', '≈Ç', '√°', 'ƒ•', '¬Ø', '√°', 'ƒ•', 'ƒø', '√°', 'ƒ•', 'ƒ≥', '√°', 'ƒ•', 'ƒ≤']


That's not what we need, so we'll create a custom, character-based tokenizer.

The model predicts the next token; in this case each token represents a single character, not a word.

In [8]:
class GeorgianTokenizer:
    """Character-level tokenizer for modern Georgian text.

    The vocabulary is four special tokens (``<pad>``, ``<s>``, ``</s>``,
    ``<unk>``) followed by the 33 letters at code points U+10D0..U+10F0, so
    ``len(tokenizer) == 37``. Any character outside that set maps to
    ``<unk>``.

    Args:
        max_length: Fixed length of padded/truncated sequences, including the
            BOS and EOS tokens. Must be at least 2.

    Raises:
        ValueError: If ``max_length`` is smaller than 2.
    """

    # Inclusive code point range of the 33-letter alphabet.
    _FIRST_LETTER = 0x10D0
    _LAST_LETTER = 0x10F0

    def __init__(self, max_length: int = 32):
        if max_length < 2:
            raise ValueError("max_length must be >= 2 to hold the BOS and EOS tokens")

        # Special tokens
        self.pad_token = "<pad>"
        self.bos_token = "<s>"      # beginning of sequence
        self.eos_token = "</s>"     # end of sequence
        self.unk_token = "<unk>"    # unknown character

        # Georgian alphabet (33 letters). Built from code points instead of a
        # literal so the vocabulary cannot be corrupted by the file's encoding.
        self.georgian_chars = "".join(
            chr(code_point)
            for code_point in range(self._FIRST_LETTER, self._LAST_LETTER + 1)
        )

        # Build vocabulary: special tokens + Georgian characters
        self.vocab = [self.pad_token, self.bos_token, self.eos_token, self.unk_token]
        self.vocab.extend(self.georgian_chars)

        # Create mappings
        self.char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
        self.id_to_char = {idx: char for idx, char in enumerate(self.vocab)}

        # Derive the special-token ids from the vocabulary so they can never
        # drift from the vocab order (resolves to 0, 1, 2, 3).
        self.pad_token_id = self.char_to_id[self.pad_token]
        self.bos_token_id = self.char_to_id[self.bos_token]
        self.eos_token_id = self.char_to_id[self.eos_token]
        self.unk_token_id = self.char_to_id[self.unk_token]

        self.max_length = max_length

    def encode(self, text: str, padding: bool = True) -> list[int]:
        """Convert text to token ids.

        The result is ``[bos, *chars, eos]``. If that exceeds ``max_length``
        it is cut to ``max_length`` with the last position forced to EOS.

        Args:
            text: String to encode, one token per character.
            padding: When true, right-pad with the pad id up to ``max_length``.

        Returns:
            List of token ids (exactly ``max_length`` long when padded).
        """
        ids = [self.bos_token_id]
        ids.extend(self.char_to_id.get(char, self.unk_token_id) for char in text)
        ids.append(self.eos_token_id)

        # Truncate if too long, keeping a terminating EOS.
        if len(ids) > self.max_length:
            ids = ids[:self.max_length - 1] + [self.eos_token_id]

        # Pad if needed
        if padding:
            ids.extend([self.pad_token_id] * (self.max_length - len(ids)))

        return ids

    def decode(self, ids: list[int]) -> str:
        """Convert token ids back to text.

        PAD, BOS and EOS ids are skipped. Ids outside the vocabulary (for
        example the ``-100`` loss mask) contribute nothing. Elements may be
        plain ints or any int-convertible scalar (e.g. tensor elements); they
        are coerced with ``int()`` so dictionary lookup works either way.

        Args:
            ids: Iterable of token ids.

        Returns:
            The decoded string.
        """
        skipped = {self.pad_token_id, self.bos_token_id, self.eos_token_id}
        chars = []
        for token in ids:
            token_id = int(token)
            if token_id in skipped:
                continue
            chars.append(self.id_to_char.get(token_id, ""))
        return "".join(chars)

    def __len__(self):
        """Return the vocabulary size."""
        return len(self.vocab)

Test tokenization again

In [9]:
tokenizer = GeorgianTokenizer(max_length=32)

# Round-trip a sample word: text -> ids -> text.
text = "·Éí·Éê·Éõ·Éê·É†·ÉØ·Éù·Éë·Éê"
ids = tokenizer.encode(text)
decoded = tokenizer.decode(ids)

# Encoding report (only the first 15 ids are shown).
print(f"Text: {text}")
print(f"IDs: {ids[:15]}...")
print(f"Length: {len(ids)}")

# Decoding report.
print(f"Decoded: {decoded}")

# Expected vocabulary size: 37 (4 special + 33 Georgian).
print(f"Vocab size: {len(tokenizer)}")

Text: ·Éí·Éê·Éõ·Éê·É†·ÉØ·Éù·Éë·Éê
IDs: [1, 6, 4, 15, 4, 20, 35, 17, 5, 4, 2, 0, 0, 0, 0]...
Length: 32
Decoded: ·Éí·Éê·Éõ·Éê·É†·ÉØ·Éù·Éë·Éê
Vocab size: 37


Works as expected.

Now we prepare the dataset using this tokenizer.

In [10]:
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageOps


class GeorgianOCRDataset(Dataset):
    """Image/label pairs for fine-tuning an OCR model on Georgian words.

    Each item is an image letterboxed onto a white square canvas, passed
    through ``processor``, together with the label encoded by ``tokenizer``.

    Args:
        df: Table with a ``file_name`` column (path relative to ``root_dir``)
            and a ``text`` column (the label string).
        root_dir: Directory that the ``file_name`` entries are relative to.
        processor: Callable image processor; invoked as
            ``processor(image, return_tensors="pt")`` and expected to expose
            ``.pixel_values`` on the result.
        tokenizer: Tokenizer providing ``encode(text)`` and ``pad_token_id``.
        image_size: Side length in pixels of the square canvas the image is
            fitted into before it is handed to ``processor``.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        root_dir: str,
        processor,
        tokenizer: "GeorgianTokenizer",
        image_size: int = 384,
    ):
        self.df = df.reset_index(drop=True)
        self.root_dir = root_dir
        self.processor = processor
        self.tokenizer = tokenizer  # custom tokenizer
        self.image_size = image_size

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.df)

    def _fit_size(self, width: int, height: int) -> tuple[int, int]:
        """Scale ``(width, height)`` so the longer side equals ``image_size``.

        Each side is clamped to at least 1 pixel so very elongated images
        cannot collapse to a zero-sized resize target.
        """
        scale = self.image_size / max(width, height)
        return max(1, int(width * scale)), max(1, int(height * scale))

    def _load_image(self, file_path: str):
        """Open an image and letterbox it onto a white square canvas."""
        # The context manager releases the file handle as soon as the pixels
        # have been copied by convert().
        with Image.open(file_path) as raw:
            img = raw.convert("RGB")

        new_w, new_h = self._fit_size(*img.size)
        img = img.resize((new_w, new_h), Image.Resampling.BILINEAR)

        # Centre the resized image on a white square.
        canvas = Image.new("RGB", (self.image_size, self.image_size), (255, 255, 255))
        offset = ((self.image_size - new_w) // 2, (self.image_size - new_h) // 2)
        canvas.paste(img, offset)
        return canvas

    def _encode_labels(self, text: str) -> torch.Tensor:
        """Tokenize ``text`` and mask padding positions with -100."""
        pad_id = self.tokenizer.pad_token_id
        # -100 marks positions that should be ignored by the loss function.
        masked = [-100 if token == pad_id else token for token in self.tokenizer.encode(text)]
        return torch.tensor(masked, dtype=torch.long)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return ``{"pixel_values": ..., "labels": ...}`` for sample ``idx``."""
        row = self.df.iloc[idx]
        file_path = os.path.join(self.root_dir, row['file_name'])

        canvas = self._load_image(file_path)

        # Use the processor for normalization; drop only the leading batch axis
        # it adds (a bare squeeze() would also drop any other size-1 axis).
        pixel_values = self.processor(canvas, return_tensors="pt").pixel_values

        return {
            "pixel_values": pixel_values.squeeze(0),
            "labels": self._encode_labels(row['text'])
        }


Prepare model...

We give it our 37 tokens so the model predicts only 37 possible outputs instead of the original ~50k.

In [11]:
from transformers import VisionEncoderDecoderModel

# Character-level tokenizer that defines the decoder's output space.
tokenizer = GeorgianTokenizer(max_length=32)

# Load the pretrained checkpoint and resize the decoder's token embeddings
# to the custom vocabulary (37 entries).
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
model.decoder.resize_token_embeddings(len(tokenizer))

# Special tokens on the model config.
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id

# The checkpoint ships its own generation_config.json (see the download log),
# so its special-token ids refer to the pretrained vocabulary. Mirror the
# custom ids and the label length there as well.
# NOTE(review): this assumes generate() takes its defaults from
# model.generation_config, as in recent transformers releases - confirm
# against the installed version.
generation_config = model.generation_config
generation_config.decoder_start_token_id = tokenizer.bos_token_id
generation_config.bos_token_id = tokenizer.bos_token_id
generation_config.pad_token_id = tokenizer.pad_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
generation_config.max_length = tokenizer.max_length

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


generation_config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

# Train test split

In [12]:
from sklearn.model_selection import train_test_split


# Hold out 10% of the rows; the fixed seed keeps the shuffled split reproducible.
split_options = {"test_size": 0.10, "random_state": 42, "shuffle": True}
train_df, test_df = train_test_split(df, **split_options)

# Label frequencies inside the training portion.
train_label_counts = train_df["text"].value_counts()
print(train_label_counts)

text
·Éì·Éê          3803
·Éê·É†          1044
·É†·Éù·Éõ          947
·Éò·Éß·Éù          705
·Éô·Éò           546
            ... 
·Éò·É¢·Éê·Éö·Éò·É£·É†·Éò       1
·É•·É£·É©·Éò·É°·Éê         1
·É®·Éó·Éù·Éô           1
·É¶·ÉÆ·Éù·Éõ·Éô·É¢·É™·Éê       1
·É®·Éõ·Éù·Éò           1
Name: count, Length: 35172, dtype: int64


# Dataloaders

In [13]:
from torch.utils.data import DataLoader

# Both splits share the same data root, image processor and tokenizer.
dataset_args = (data_dir, processor, tokenizer)
train_dataset = GeorgianOCRDataset(train_df, *dataset_args)
test_dataset = GeorgianOCRDataset(test_df, *dataset_args)

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=8)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

Train batches: 11307, Test batches: 1257


# Set up training

In [14]:
from torch.optim import AdamW

# Use the GPU when one is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model.to(device)

# AdamW over all model parameters.
optimizer = AdamW(params=model.parameters(), lr=5e-5)

print(f"Training on: {device}")

Training on: cuda


# Validation function

In [None]:
from evaluate import load
import torch

# Load the Character Error Rate (CER) metric by name through `evaluate.load`.
# validate_model() reads this module-level object when scoring predictions.
cer_metric = load("cer")

def validate_model(
    model: torch.nn.Module,
    val_loader: torch.utils.data.DataLoader,
    processor: object,
    device: torch.device,
    label_tokenizer: "GeorgianTokenizer | None" = None,
) -> float:
    """Compute the Character Error Rate of ``model`` over ``val_loader``.

    Predictions and references are decoded with the character-level tokenizer
    that produced the labels. The pretrained processor's tokenizer uses a
    different id space and would turn the model's output ids into unrelated
    text.

    Args:
        model: Model exposing ``generate(pixel_values, ...)``.
        val_loader: Iterable of batches with ``"pixel_values"`` and
            ``"labels"`` tensors; label padding is marked with -100.
        processor: Unused; kept so existing call sites keep working.
        device: Device the pixel values are moved to before generation.
        label_tokenizer: Tokenizer used to decode ids. Defaults to the
            notebook-level ``tokenizer``.

    Returns:
        The score returned by ``cer_metric.compute``.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    # Fall back to the notebook-level tokenizer instance when none is passed.
    text_tokenizer = label_tokenizer if label_tokenizer is not None else tokenizer

    model.eval()
    predictions: list[str] = []
    references: list[str] = []

    with torch.no_grad():
        for batch in val_loader:
            pixel_values = batch["pixel_values"].to(device)

            # Pass the special ids and length limit explicitly so generation
            # does not depend on defaults stored with the pretrained checkpoint.
            generated = model.generate(
                pixel_values,
                decoder_start_token_id=text_tokenizer.bos_token_id,
                eos_token_id=text_tokenizer.eos_token_id,
                pad_token_id=text_tokenizer.pad_token_id,
                max_length=text_tokenizer.max_length,
            )

            # tolist() yields plain Python ints, which the tokenizer's
            # id -> character lookup expects.
            for sequence in generated.tolist():
                predictions.append(text_tokenizer.decode(sequence))

            # Drop the -100 loss mask before decoding the references.
            for sequence in batch["labels"].tolist():
                kept = [token for token in sequence if token != -100]
                references.append(text_tokenizer.decode(kept))

    if not references:
        raise ValueError("val_loader produced no samples to evaluate")

    # Calculate Character Error Rate
    cer_score: float = cer_metric.compute(predictions=predictions, references=references)
    return cer_score

# Training loop

In [None]:
def train_model(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epochs: int = 3,
    save_every: int = 1000,
    checkpoint_dir: str = "/kaggle/working/",
    val_loader: "torch.utils.data.DataLoader | None" = None,
) -> None:
    """Fine-tune ``model`` and report the validation CER after every epoch.

    Each batch gets one forward/backward pass and one optimizer step. The loss
    is printed every 100 batches. Every ``save_every`` batches a checkpoint
    holding the model and optimizer state dicts is written to
    ``checkpoint_dir``. A batch that triggers a CUDA out-of-memory
    ``RuntimeError`` is skipped after freeing what can be freed; any other
    error propagates unchanged.

    Args:
        model: Model called as ``model(pixel_values=..., labels=...)`` whose
            output exposes ``.loss``.
        train_loader: Batches with ``"pixel_values"`` and ``"labels"`` tensors.
        optimizer: Optimizer over the model's parameters.
        device: Device the batch tensors are moved to.
        epochs: Number of passes over ``train_loader``.
        save_every: Checkpoint interval in batches.
        checkpoint_dir: Directory for checkpoints; created on first save.
        val_loader: Loader used for the end-of-epoch validation. Defaults to
            the notebook-level ``test_loader``.
    """
    # Imported here because the notebook never imports gc at module level,
    # which made the out-of-memory handler itself fail with a NameError.
    import gc

    for epoch in range(epochs):
        model.train()
        print(f"\n--- Starting Epoch {epoch} ---")

        for batch_idx, batch in enumerate(train_loader):
            # Bind the names up front so the OOM handler can always release
            # them, even when the failure happens before they are assigned.
            pixel_values = labels = outputs = loss = None
            try:
                # Prepare data
                pixel_values = batch["pixel_values"].to(device)
                labels = batch["labels"].to(device)

                # Forward pass
                outputs = model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

                # Backward pass
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                # Logging
                if batch_idx % 100 == 0:
                    print(f"Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")

                # Periodic checkpoint (never at batch 0).
                if batch_idx > 0 and batch_idx % save_every == 0:
                    checkpoint_name: str = f"trocr_georgian_e{epoch}_s{batch_idx}.pt"
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    checkpoint_path: str = os.path.join(checkpoint_dir, checkpoint_name)

                    torch.save({
                        'epoch': epoch,
                        'batch_idx': batch_idx,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss.item(),
                    }, checkpoint_path)
                    print(f"üíæ Saved checkpoint to {checkpoint_path}")

            except RuntimeError as e:
                # Only out-of-memory errors are recoverable; re-raise anything
                # else with its original traceback.
                if "out of memory" not in str(e).lower():
                    raise

                print(f"GPU OOM detected at batch {batch_idx}. Cleaning memory and skipping...")

                # Drop every reference that could be holding GPU memory.
                pixel_values = labels = outputs = loss = None

                optimizer.zero_grad(set_to_none=True)  # release gradient tensors
                gc.collect()                           # Python garbage collection
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()           # release cached CUDA blocks
                continue

        # At the end of the epoch, check accuracy on the held-out loader.
        print(f"üìä Running Validation for Epoch {epoch}...")
        eval_loader = val_loader if val_loader is not None else test_loader
        current_cer: float = validate_model(model, eval_loader, processor, device)

        print(f"‚úÖ Epoch {epoch} Results:")
        print(f"   Character Error Rate (CER): {current_cer:.4f}")
        print(f"   (Translation: {100 - (current_cer*100):.2f}% character accuracy)")

# Saving the fine-tuned model