# Setup

In [1]:
import os
import zipfile
from pathlib import Path
from dotenv import load_dotenv


def check_env() -> str:
    """Report whether the notebook is running on Kaggle or on a local machine.

    The run is treated as Kaggle when the ``KAGGLE_KERNEL_RUN_TYPE``
    environment variable is set to a non-empty value. The decision is also
    printed so it shows up in the cell output.

    Returns:
        ``"kaggle"`` or ``"local"``.
    """
    on_kaggle = bool(os.environ.get('KAGGLE_KERNEL_RUN_TYPE'))
    env_name, message = (
        ("kaggle", "Running on Kaggle") if on_kaggle else ("local", "Running locally")
    )
    print(message)
    return env_name


# Decide where the dataset lives; `data_dir` is what the rest of the notebook reads from.
ENV = check_env()

if ENV != "kaggle":
    # Local run: pull settings from a .env file, then fetch and unpack the archive.
    load_dotenv()

    # Imported only on this branch, so the Kaggle path never touches huggingface_hub.
    from huggingface_hub import hf_hub_download

    data_dir = Path("./data")
    data_dir.mkdir(parents=True, exist_ok=True)

    hf_repo = os.getenv("HF_DATASET_REPO")
    # No check on the token: it is passed through even when unset (None).
    hf_token = os.getenv("HF_TOKEN")

    if not hf_repo:
        raise ValueError("HF_DATASET_REPO not set in .env")

    # Download with automatic caching - skips if local matches remote (etag-based)
    zip_path = hf_hub_download(
        repo_id=hf_repo,
        repo_type="dataset",
        filename="ka-ocr.zip",
        local_dir=data_dir,
        token=hf_token,
    )

    # The marker file records the last successful extraction. Re-extract when it
    # is missing or when the archive has been modified more recently than it.
    extract_marker = data_dir / ".extracted"
    zip_file = Path(zip_path)
    if extract_marker.exists():
        needs_extract = zip_file.stat().st_mtime > extract_marker.stat().st_mtime
    else:
        needs_extract = True

    if not needs_extract:
        print("Dataset already extracted, skipping")
    else:
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, mode='r') as zip_ref:
            zip_ref.extractall(data_dir)
        # Touched only after extractall returns, so a failed run is retried next time.
        extract_marker.touch()
        print("Extraction complete")
else:
    # Kaggle run: the dataset is read from the attached input directory.
    data_dir = Path("/kaggle/input/ka-ocr")

# List what is available, hiding dotfiles (such as the marker) and the archive itself.
print(f"\nDataset contents in {data_dir}:")
for item in data_dir.iterdir():
    if item.name.startswith('.') or item.name == "ka-ocr.zip":
        continue
    print(f"  {item.name}")

Running locally


  from .autonotebook import tqdm as notebook_tqdm


Dataset already extracted, skipping

Dataset contents in data:
  3d_unicode
  alkroundedmtav-medium
  alkroundednusx-medium
  ar-archy-regular
  arial_geo
  arial_geo-bold
  arial_geo-bold-italic
  arial_geo-italic
  bpg_algeti
  bpg_algeti_compact
  bpg_arial_2009
  bpg_boxo
  bpg_boxo-boxo
  bpg_classic_medium
  bpg_dedaena
  bpg_dedaena_nonblock
  bpg_excelsior_caps_dejavu_2010
  bpg_excelsior_dejavu_2010
  bpg_extrasquare_2009
  bpg_extrasquare_mtavruli_2009
  bpg_glaho
  bpg_glaho_2008
  bpg_glaho_arial
  bpg_glaho_bold
  bpg_glaho_sylfaen
  bpg_glaho_traditional
  bpg_ingiri_2008
  bpg_irubaqidze
  bpg_mrgvlovani_caps_2010
  bpg_nino_elite_exp
  bpg_nino_elite_ultra
  bpg_nino_elite_ultra_caps
  bpg_nino_medium_caps
  bpg_nino_mtavruli_bold
  bpg_nino_mtavruli_book
  bpg_nino_mtavruli_normal
  bpg_no9
  bpg_nostalgia
  bpg_paata
  bpg_paata_caps
  bpg_paata_cond
  bpg_paata_cond_caps
  bpg_paata_exp
  bpg_phone_sans_bold
  bpg_phone_sans_bold_italic
  bpg_phone_sans_italic
  bpg_

# Explore data

In [2]:
import pandas as pd

In [6]:
# Load the label table (columns `file_name` and `text`) and preview both ends of it.
metadata_path = data_dir / "metadata.csv"
df = pd.read_csv(metadata_path)
for preview in (df.head(), df.tail()):
    print(preview)

                        file_name          text
0  3d_unicode/3d_unicode_0000.png          ამათ
1  3d_unicode/3d_unicode_0001.png   პარტიანკაში
2  3d_unicode/3d_unicode_0002.png  კომენტარების
3  3d_unicode/3d_unicode_0003.png       ფრიდრიხ
4  3d_unicode/3d_unicode_0004.png   ცდწლოოწნწში
                                         file_name      text
100495  NotoSansGeorgian/NotoSansGeorgian_1495.png     რიგში
100496  NotoSansGeorgian/NotoSansGeorgian_1496.png     ალიკა
100497  NotoSansGeorgian/NotoSansGeorgian_1497.png      ტარს
100498  NotoSansGeorgian/NotoSansGeorgian_1498.png     კარგი
100499  NotoSansGeorgian/NotoSansGeorgian_1499.png  სასახლეს


In [4]:
# How often each distinct label occurs, most frequent first.
text_counts = df["text"].value_counts()
print(text_counts)

text
და              4221
არ              1163
რომ              995
იყო              768
კი               604
                ... 
ფუფთწჟჯკდპბგ       1
ოოწლღდჩქთტტ        1
გითქვამს           1
სიკვდილია          1
ბგშჩდეკშჰ          1
Name: count, Length: 38003, dtype: int64


In [8]:
# Store each label's length (as counted by `.str.len()`) and summarise the distribution.
text_lengths = df["text"].str.len()
df["text_len"] = text_lengths
print(df["text_len"].describe())

count    100500.000000
mean          6.390886
std           2.970310
min           2.000000
25%           4.000000
50%           6.000000
75%           8.000000
max          24.000000
Name: text_len, dtype: float64


# Prepare images

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageOps


class GeorgianOCRDataset(Dataset):
    """Map-style dataset that turns metadata rows into OCR training examples.

    Each item is built from one row of ``df``: the image named in the
    ``file_name`` column is loaded from under ``root_dir``, fitted onto a
    square ``target_size`` x ``target_size`` image and passed through
    ``processor``; the ``text`` column is tokenized with
    ``processor.tokenizer`` into a fixed-length label sequence.

    Args:
        df: Table with ``file_name`` and ``text`` columns. Rows are addressed
            by position, so a frame with a non-default index (for example one
            produced by a train/validation split) works without
            ``reset_index``.
        root_dir: Directory the ``file_name`` entries are relative to.
        processor: Callable applied to the prepared image; must return an
            object with ``pixel_values`` and expose a ``tokenizer`` attribute
            (presumably a Hugging Face TrOCR-style processor -- confirm
            against the caller).
        max_target_length: Length of every returned label sequence. Shorter
            texts are padded, longer ones are truncated.
        target_size: Side length, in pixels, of the square image handed to
            ``processor``.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        root_dir: str,
        processor: object,
        max_target_length: int = 32,
        target_size: int = 384,
    ):
        self.df = df
        self.root_dir = root_dir
        self.processor = processor
        self.max_target_length = max_target_length
        self.target_size = target_size

    def __len__(self) -> int:
        """Return the number of rows (examples) in the metadata table."""
        return len(self.df)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Load, resize and encode the example at position ``idx``.

        Returns:
            A dict with ``"pixel_values"`` (the processor output with its
            leading batch dimension removed) and ``"labels"`` (a 1-D integer
            tensor of length ``max_target_length`` in which padding positions
            are set to -100).
        """
        # Positional lookup. `self.df[col][idx]` is label-based and raises
        # KeyError (or silently returns another row) when the index is not 0..n-1.
        file_name = self.df["file_name"].iloc[idx]
        text = self.df["text"].iloc[idx]
        file_path = f"{self.root_dir}/{file_name}"

        # Decode inside a context manager so the file handle is released
        # immediately instead of whenever the garbage collector gets to it.
        with Image.open(file_path) as source:
            img = source.convert("RGB")

        # Smart Resize (Letterbox): scale the height to target_size and the
        # width proportionally. The clamp keeps extremely tall, narrow images
        # from collapsing to a zero-width resize.
        w, h = img.size
        target_size = self.target_size
        new_w = max(1, int(w * (target_size / h)))
        img = img.resize((new_w, target_size), Image.Resampling.BILINEAR)

        if new_w <= target_size:
            # Narrow enough: centre horizontally on a white square canvas.
            new_img = Image.new("RGB", (target_size, target_size), (255, 255, 255))
            offset = ((target_size - new_w) // 2, 0)
            new_img.paste(img, offset)
        else:
            # Still too wide: force a resize to square (slight horizontal squish).
            new_img = img.resize((target_size, target_size), Image.Resampling.BILINEAR)

        # The processor handles normalization and tensor conversion.
        pixel_values = self.processor(new_img, return_tensors="pt").pixel_values

        # Pad AND truncate so every label sequence has exactly max_target_length
        # ids; without truncation one long text would break batch collation.
        tokenizer = self.processor.tokenizer
        token_ids = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_target_length,
        ).input_ids

        # Replace padding ids with -100 (the default ignore_index of PyTorch's
        # cross-entropy loss) so padded positions do not contribute to the loss.
        pad_id = tokenizer.pad_token_id
        labels = [tok if tok != pad_id else -100 for tok in token_ids]

        return {
            # squeeze(0) drops only the batch dimension added by the processor.
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels),
        }
