In [6]:
import itertools
import os, json, random
from pathlib import Path
from collections import defaultdict

from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
import torch.optim as optim
import matplotlib.pyplot as plt           # optional for sanity-check
import regex  # pip install regex  (needed for \X = Unicode grapheme)
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from matplotlib import font_manager
from PIL import Image, ImageOps

In [9]:
"""
make_line_crops.py
────────────────────────────────────────────────────────────────────────
• Extracts line-level crops from PAGE-level images using word-box JSON.
• Normalises every crop to identical size (H=32 px, W=512 px).
• Builds the Sinhala grapheme-cluster charset and writes:
      ├─ <CROP_DIR>/labels.json   (crop_name  →  full line text)
      └─ <CROP_DIR>/chars.txt     (one cluster per line)
"""

import os, json, regex, math, random
from pathlib import Path
from collections import defaultdict

from PIL import Image, ImageOps, ImageFont, ImageDraw
import matplotlib.pyplot as plt
from matplotlib import font_manager

# ─────────────────── 0.  EDIT THESE PATHS  ─────────────────────────
PAGE_DIR = Path(r"D:\python\data\images")          # page-level images
ANN_FILE = Path(r"D:\python\data\json\labels.json")
CROP_DIR = Path(r"D:\python\data\line_crops")      # new crops + ckpts
CKPT_DIR = CROP_DIR / "checkpoints"

font_path = (
    r"C:/Users/ASUS/Downloads/Noto_Sans_Sinhala,Yuji_Mai/"
    r"Noto_Sans_Sinhala/NotoSansSinhala-VariableFont_wdth,wght.ttf"
)

# ─────────────────── 1.  ONE-OFF SET-UP  ───────────────────────────
font_manager.fontManager.addfont(font_path)
plt.rcParams["font.family"] = "Noto Sans Sinhala"

CROP_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR.mkdir(exist_ok=True)

# ─────────────────── 2.  LOAD WORD-BOX JSON  ───────────────────────
with ANN_FILE.open(encoding="utf-8") as f:
    raw = json.load(f)           # list[dict] – {uid: [word,…]} per page

label_map   = {}                # crop_name → GT string
line_counts = defaultdict(int)

# ─────────────────── 3.  HELPER: GRAPHEME CLUSTER SPLIT ────────────
cluster_re = regex.compile(r"\X", flags=regex.UNICODE)
def get_clusters(s: str) -> list[str]:
    """Return extended grapheme clusters sans whitespace."""
    return [c for c in cluster_re.findall(s) if not c.isspace()]

# ─────────────────── 4.  HELPER: RESIZE + PAD 32×512  ──────────────
def resize_and_pad(img: Image.Image,
                   target_size=(32, 512),
                   fill=255) -> Image.Image:
    """
    Keep aspect ratio, resize by **height**, then pad on the right
    to `target_size` with colour `fill` (white for mode "L").
    """
    tgt_h, tgt_w = target_size
    w, h = img.size
    scale = tgt_h / h
    new_w = int(w * scale)
    img_rs = img.resize((new_w, tgt_h), Image.BILINEAR)

    # If width still too big crop centre-left; otherwise pad right.
    if new_w >= tgt_w:
        img_rs = img_rs.crop((0, 0, tgt_w, tgt_h))
        return img_rs
    else:
        pad_w = tgt_w - new_w
        padding = (0, 0, pad_w, 0)                # (left, top, right, bottom)
        return ImageOps.expand(img_rs, padding, fill=fill)

# ─────────────────── 5.  MAIN LOOP: PAGES → LINE CROPS  ────────────
pad_vert = 4            # vertical padding in original crop
pad_horz = 4            # horizontal padding

all_clusters: set[str] = set()

for entry in raw:                       # entry = {uid: [word,…]}
    for uid, words in entry.items():
        page_path = PAGE_DIR / f"image_{uid}.png"
        if not page_path.exists():
            print(f"⚠️  missing page {page_path}")
            continue

        # ── group words roughly by y-centre (‘lines’) ────────────
        lines = []
        for w in sorted(words, key=lambda w: w["y"]):
            cy = w["y"]
            for line in lines:
                if abs(cy - line["cy"]) < 15:           # ≈ one text line
                    line["words"].append(w)
                    line["cy"] = (line["cy"] + cy) / 2
                    break
            else:
                lines.append({"cy": cy, "words": [w]})

        # ── open page and iterate each detected line ─────────────
        page = Image.open(page_path).convert("L")
        W_page, H_page = page.size

        for idx, line in enumerate(lines):
            bxs, bys, bxe, bye = [], [], [], []
            for w in line["words"]:
                x, y = w["x"], w["y"]
                w_ = w.get("w", w.get("width",  0))
                h_ = w.get("h", w.get("height", 0))
                x1 = w.get("x1", x + w_)
                y1 = w.get("y1", y + h_)
                bxs.append(x); bys.append(y)
                bxe.append(x1); bye.append(y1)

            # Fallback if heights are missing
            if max(bye) == min(bys):
                est_line_h = int(H_page * 0.05)
                bye = [y + est_line_h for y in bys]

            x0 = max(min(bxs) - pad_horz, 0)
            y0 = max(min(bys) - pad_vert, 0)
            x1 = min(max(bxe) + pad_horz, W_page)
            y1 = min(max(bye) + pad_vert, H_page)

            raw_crop = page.crop((x0, y0, x1, y1))

            # NORMALISE CROP SIZE 32×512
            crop = resize_and_pad(raw_crop, target_size=(32, 512), fill=255)

            crop_name = f"image_{uid}_line{idx:02d}.png"
            crop.save(CROP_DIR / crop_name)

            sorted_words = sorted(line["words"], key=lambda w: w["x"])
            text = " ".join(w["text"] for w in sorted_words).strip()
            label_map[crop_name] = text
            line_counts[uid] += 1

            # Gather clusters for charset
            for w in sorted_words:
                all_clusters.update(get_clusters(w["text"]))

# ─────────────────── 6.  BUILD CHARSET + TABLES  ───────────────────
charset = sorted(all_clusters)
if " " not in charset:             # ensure space present
    charset.append(" ")

stoi = {c: i + 1 for i, c in enumerate(charset)}   # 0 = blank for CTC
itos = {i: c for c, i in stoi.items()}

# ─────────────────── 7.  STATS + SAVE TO DISK  ─────────────────────
print("Unique grapheme clusters:", len(charset))
print("✅ Saved", len(label_map), "line crops.")
print("Avg lines per page:",
      sum(line_counts.values()) / max(len(line_counts), 1))

# 7a.  chars.txt
vocab_path = CROP_DIR / "chars.txt"
with vocab_path.open("w", encoding="utf-8") as f:
    for cluster in charset:
        f.write(cluster + "\n")

# 7b.  labels.json
out_path = CROP_DIR / "labels.json"
with out_path.open("w", encoding="utf-8") as f:
    json.dump(label_map, f, ensure_ascii=False, indent=2)

print(f"chars.txt  → {vocab_path}")
print(f"labels.json→ {out_path}")


⚠️  missing page D:\python\data\images\image_670194b5-4593-40fe-abd2-ae0e5fb80700.png
⚠️  missing page D:\python\data\images\image_e5b610ed-3182-4f13-806f-6703b7d70a77.png
⚠️  missing page D:\python\data\images\image_8b37e49a-e51e-4895-a66b-611353893213.png
⚠️  missing page D:\python\data\images\image_3c9951c2-a6f3-4bc7-82b7-45c2ec74a6d4.png
⚠️  missing page D:\python\data\images\image_20547e05-0bae-4f79-b7b4-b29080eadc74.png
⚠️  missing page D:\python\data\images\image_53096bc1-0c0f-4418-b417-3167bfebe962.png
⚠️  missing page D:\python\data\images\image_699bca2a-ca36-40a1-b0c1-107f5f01d519.png
⚠️  missing page D:\python\data\images\image_d96a0a39-0243-4582-b583-358b1b78d2ce.png
⚠️  missing page D:\python\data\images\image_b4dfe0f0-73a1-450e-a595-7cac023756a2.png
⚠️  missing page D:\python\data\images\image_7fa3004a-98cb-46f8-82de-7e46c81c655c.png
⚠️  missing page D:\python\data\images\image_dfc7563b-a265-40b0-b2f5-387beb63f620.png
⚠️  missing page D:\python\data\images\image_d59a13d7-

In [11]:
# ----------------------------------------------------------------------
# Load charset from a plain-text file saved as one character per line
# ----------------------------------------------------------------------
from pathlib import Path

charset_path = Path(r"d:/python/data/line_crops/chars.txt")   # <- point to your file
with charset_path.open("r", encoding="utf-8") as f:
    # keep original order, strip new-lines and ignore blank lines
    charset = [line.rstrip("\n") for line in f if line.strip()]

# Optionally verify there are no duplicates
dupes = [c for c, n in __import__("collections").Counter(charset).items() if n > 1]
if dupes:
    raise ValueError(f"Duplicate entries in charset.txt: {dupes}")

# Build look-up tables
char2idx  = {c: i for i, c in enumerate(charset)}
idx2char  = {i: c for c, i in char2idx.items()}
blank_id  = len(charset)          # CTC blank token index

print(f"Loaded {len(charset)} symbols. First 10 → {charset[:10]}")


Loaded 145 symbols. First 10 → ['"', ',', '.', '0', '1', '2', '3', '4', '5', '7']


In [17]:
# ─────────────────────────────────────────────────────────────────────
# Sinhala OCR – ResNet-Transformer + CTC
# ---------------------------------------------------------------------
# ❶  Image crops: 32 × 512  (RGB or gray, any aspect OK)
# ❷  labels.json  –  { "img_000123.png": "ඉතාලි …", ... }
# ❸  chars.txt    –  one grapheme-cluster per line (order = class id-1)
# ---------------------------------------------------------------------
#    python train_sinhala_ocr.py
# ─────────────────────────────────────────────────────────────────────
import json, random, regex, os
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CTCLoss, TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
# ---------------------------------------------------------------------
# FILE LOCATIONS – edit to match your drive
LABELS_JSON = Path(r"d:/python/data/line_crops/labels.json")
CHARS_TXT   = Path(r"d:/python/data/line_crops/chars.txt")
IMG_DIR     = Path(r"d:/python/data/line_crops")
CKPT_DIR    = Path(r"d:/python/checkpoints"); CKPT_DIR.mkdir(exist_ok=True)
# ---------------------------------------------------------------------
# CHARSET & LABEL MAP
with LABELS_JSON.open(encoding="utf-8") as f:
    line_map: dict[str, str] = json.load(f)

with CHARS_TXT.open(encoding="utf-8") as f:
    charset = [ln.rstrip("\n") for ln in f if ln.strip()]

char2idx = {c: i + 1 for i, c in enumerate(charset)}   # 0 = blank
idx2char = {i: c for c, i in char2idx.items()}
blank_idx = 0

cluster_re = regex.compile(r"\X", regex.UNICODE)
def split_clusters(s):       # drop whitespace clusters
    return [c for c in cluster_re.findall(s) if not c.isspace()]

img_paths, labels = [], []
for fname, text in line_map.items():
    img_file = IMG_DIR / fname
    if not img_file.exists():
        continue
    clust = split_clusters(text)
    if any(c not in char2idx for c in clust):
        print("⚠️  unseen cluster in", fname); continue
    img_paths.append(str(img_file))
    labels.append(clust)

print(f"Loaded {len(img_paths)} lines   charset = {len(charset)} symbols")
# ---------------------------------------------------------------------
# DATASET
class SinhalaOCRDataset(Dataset):
    def __init__(self, paths, lbls, c2i, size=(32,512), augment=False):
        self.paths, self.lbls, self.c2i, self.augment = paths, lbls, c2i, augment
        self.base_tf = T.Compose([
            T.Grayscale(num_output_channels=3),
            T.Resize(size, antialias=True),
            T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406],
                        std =[0.229,0.224,0.225]),
        ])
        self.aug_tf = T.Compose([
            T.RandomRotation(2),
            T.ColorJitter(0.1,0.1)
        ])

    def __len__(self): return len(self.paths)
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.augment and random.random() < 0.5:
            img = self.aug_tf(img)
        img = self.base_tf(img)
        tgt = torch.tensor([self.c2i[c] for c in self.lbls[idx]],
                           dtype=torch.long)
        return img, tgt

def collate(batch):
    imgs, lbls = zip(*batch)
    return torch.stack(imgs, 0), list(lbls)   # imgs -> (B,3,32,512)

train_ds = SinhalaOCRDataset(img_paths, labels, char2idx,
                             size=(32,512), augment=True)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,
                          num_workers=0, collate_fn=collate)
# ---------------------------------------------------------------------
# MODEL
class OCRModel(nn.Module):
    def __init__(self, n_chars):
        super().__init__()
        base = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        base.conv1.stride = (1,1)
        base.maxpool = nn.Identity()
        self.cnn = nn.Sequential(*list(base.children())[:-2])      # (B,512,H',W')
        self.tcn = nn.Conv1d(512,256,kernel_size=3,padding=1,groups=4)
        enc = TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=512)
        self.tr  = TransformerEncoder(enc, num_layers=2)
        self.head = nn.Linear(256, n_chars+1)   # + blank
        nn.init.constant_(self.head.bias, -2.0) # discourage blanks
        self.head.bias.data[blank_idx] = 0.0    # but not too much

    def forward(self, x):                       # x (B,3,32,512)
        f = self.cnn(x)                         # (B,512,H',W')
        f = F.adaptive_avg_pool2d(f, (1,64)).squeeze(2)   # (B,512,64)
        f = self.tcn(f)                         # (B,256,64)
        seq = f.permute(2,0,1)                  # (64,B,256)
        h   = self.tr(seq)                      # (64,B,256)
        return F.log_softmax(self.head(h), dim=2)  # (T,B,C)
# ---------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = OCRModel(len(charset)).to(device)
opt    = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
sched  = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50)
ctc    = CTCLoss(blank=blank_idx, zero_infinity=True)
# ---------------------------------------------------------------------
epochs, accum = 50, 4
for epoch in range(1, epochs+1):
    model.train(); opt.zero_grad(); total, steps = 0.0, 0
    for imgs, lbls in train_loader:
        imgs = imgs.to(device)
        logp = model(imgs)                     # (T,B,C)

        T_lens = torch.full((logp.size(1),), logp.size(0),
                            dtype=torch.long, device=device)
        L_lens = torch.tensor([t.size(0) for t in lbls],
                             dtype=torch.long, device=device)
        targets = torch.cat([t.to(device) for t in lbls])

        loss = ctc(logp, targets, T_lens, L_lens) / accum
        loss.backward(); total += loss.item(); steps += 1
        if steps % accum == 0:
            opt.step(); opt.zero_grad()

    if steps % accum:           # flush leftover grads
        opt.step(); opt.zero_grad()
    sched.step()

    print(f"Epoch {epoch:02d}  loss {total/steps:.4f}  lr {sched.get_last_lr()[0]:.2e}")

    # quick sanity-check decode on one random training line
    model.eval()
    with torch.no_grad():
        img, gt = train_ds[random.randint(0, len(train_ds)-1)]
        pred = model(img.unsqueeze(0).to(device)).squeeze(1)   # (T,C)
        arg  = torch.argmax(pred, 1).cpu().tolist()
        txt, prev = [], blank_idx
        for a in arg:
            if a != prev and a != blank_idx:
                txt.append(idx2char[a])
            prev = a
        print(" GT:", "".join(idx2char[i.item()] for i in gt))
        print(" PR:", "".join(txt))
        print("-"*50)

    # save checkpoint each epoch
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": opt.state_dict(),
        "charset": charset,
    }, CKPT_DIR / f"sinhala_ocr_epoch{epoch:02d}.pt")


Loaded 28 lines   charset = 145 symbols
Epoch 01  loss 2.3501  lr 3.00e-04
 GT: බලපෑම්කරතිබේ.
 PR: 
--------------------------------------------------
Epoch 02  loss 1.2888  lr 2.99e-04
 GT: අනුවදෙවියන්වහන්සේවිශ්වයමවාඇත්තේමිනිසාගේ
 PR: 
--------------------------------------------------
Epoch 03  loss 1.3053  lr 2.97e-04
 GT: යෝජනාවක්‌ඉදිරිපත්කරනලදි.ඉන්ඔහුවඉතාලි
 PR: 
--------------------------------------------------
Epoch 04  loss 1.2135  lr 2.95e-04
 GT: අනුවදෙවියන්වහන්සේවිශ්වයමවාඇත්තේමිනිසාගේ
 PR: 
--------------------------------------------------
Epoch 05  loss 1.1467  lr 2.93e-04
 GT: ගැනදන්නාඅයකුඒමගහැරයෑමටහෝතමාටවිෂය
 PR: න්ර
--------------------------------------------------
Epoch 06  loss 1.1443  lr 2.89e-04
 GT: බටහිරමතවාදයකි.අමරතුංගමහතාගේඅවුල්අදහස්
 PR: රරරන්
--------------------------------------------------
Epoch 07  loss 1.0877  lr 2.86e-04
 GT: ග්රෑන්ඩ්කවුන්සිලයේනිගමනයන්ටඅනුව1943දීජුලි
 PR: න්
--------------------------------------------------
Epoch 08  loss 1.0420  lr 2

In [19]:
# ───────── Sinhala OCR – **Inference / Evaluation notebook cell** ──────────
#
# 1.  Point CHECKPOINT, CHARS_TXT, and IMG_DIR / LABELS_JSON to your files.
# 2.  Run this single cell – it loads the model, runs greedy CTC decoding
#     on every image in the JSON, prints the first 10, and reports CER.
# --------------------------------------------------------------------------
import json, regex, random
from pathlib import Path
from PIL import Image
import torch, torch.nn as nn, torch.nn.functional as F
from torch.nn import CTCLoss, TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torchvision.models as models
try:
    import editdistance                       # pip install editdistance
except ImportError:
    raise RuntimeError("pip install editdistance")

# ─── 1. FILES ────────────────────────────────────────────────────────────────
CHECKPOINT  = Path(r"d:/python/checkpoints/sinhala_ocr_epoch50.pt")
CHARS_TXT   = Path(r"d:/python/data/line_crops/chars.txt")
LABELS_JSON = Path(r"d:/python/data/line_crops/labels.json")
IMG_DIR     = Path(r"d:/python/data/line_crops")
# ─── 2. REBUILD CHARSET MAPPINGS ─────────────────────────────────────────────
charset   = [ln.rstrip("\n") for ln in CHARS_TXT.open(encoding="utf-8") if ln.strip()]
char2idx  = {c:i+1 for i,c in enumerate(charset)}    # 0 = blank
idx2char  = {i:c for c,i in char2idx.items()}
blank_idx = 0
# grapheme-cluster splitter
cluster_re = regex.compile(r"\X", regex.UNICODE)
def clusters(s): return [c for c in cluster_re.findall(s) if not c.isspace()]

# ─── 3. LOAD TEST SET (here we just reuse the same JSON) ────────────────────
with LABELS_JSON.open(encoding="utf-8") as f:
    line_map = json.load(f)

t_paths, t_labels = [], []
for fname, txt in line_map.items():
    fp = IMG_DIR / fname
    if not fp.exists(): continue
    cl = clusters(txt)
    if any(c not in char2idx for c in cl): continue
    t_paths.append(str(fp)); t_labels.append(cl)

# ─── 4. DATASET / DATALOADER (*) no augmentation ────────────────────────────
class TestDS(Dataset):
    tfm = T.Compose([
        T.Grayscale(num_output_channels=3),
        T.Resize((32,512), antialias=True),
        T.ToTensor(),
        T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])
    def __init__(self, paths, lbls): self.p,self.l=paths,lbls
    def __len__(self): return len(self.p)
    def __getitem__(self,i):
        img = self.tfm(Image.open(self.p[i]).convert("RGB"))
        tgt = torch.tensor([char2idx[c] for c in self.l[i]], dtype=torch.long)
        return img, tgt
def collate(b): imgs,lbls=zip(*b); return torch.stack(imgs,0), list(lbls)
loader = DataLoader(TestDS(t_paths,t_labels), batch_size=4,
                    shuffle=False, num_workers=0, collate_fn=collate)

# ─── 5. MODEL  (must mirror training definition) ────────────────────────────
class OCRModel(nn.Module):
    def __init__(self, n):
        super().__init__()
        base = models.resnet18(weights=None)              # weights None: we’ll load
        base.conv1.stride, base.maxpool = (1,1), nn.Identity()
        self.cnn = nn.Sequential(*list(base.children())[:-2])
        self.tcn = nn.Conv1d(512,256,3,1,1,groups=4)
        enc = TransformerEncoderLayer(256,4,512)
        self.tr = TransformerEncoder(enc, 2)
        self.head = nn.Linear(256, n+1)
    def forward(self,x):
        f = self.cnn(x)
        f = F.adaptive_avg_pool2d(f,(1,64)).squeeze(2)
        f = self.tcn(f)
        h = self.tr(f.permute(2,0,1))
        return F.log_softmax(self.head(h),2)              # (T,B,C)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = OCRModel(len(charset)).to(device)
ckpt   = torch.load(CHECKPOINT, map_location=device)
model.load_state_dict(ckpt["model_state"]); model.eval()

# ─── 6. GREEDY DECODER ───────────────────────────────────────────────────────
def greedy_decode(logp):                   # logp (T,C)
    arg = torch.argmax(logp,1).cpu().tolist()
    out, prev = [], blank_idx
    for a in arg:
        if a!=prev and a!=blank_idx: out.append(idx2char[a])
        prev = a
    return "".join(out)

# ─── 7. EVALUATION ───────────────────────────────────────────────────────────
tot_cer, seen = 0.0, 0
for imgs, lbls in loader:
    imgs = imgs.to(device)
    with torch.no_grad():
        logp = model(imgs)                 # (T,B,C)
    for b in range(imgs.size(0)):
        pred = greedy_decode(logp[:,b,:])
        gt   = "".join(idx2char[i.item()] for i in lbls[b])
        cer  = editdistance.eval(pred, gt) / max(1,len(gt))
        tot_cer += cer; seen += 1
        if seen <= 10:                     # show first 10 examples
            print(f"GT:   {gt}\nPR:   {pred}\nCER:  {cer:.3f}\n{'-'*50}")

print(f"\nAverage CER over {seen} samples: {tot_cer/seen:.4f}")


  ckpt   = torch.load(CHECKPOINT, map_location=device)


GT:   අපපෘතුවිවායුගෝලයේඉහලස්තරකීපයක්
PR:   අපපෘතුවිවායුගෝලයේඉහලස්තරකීපයක්
CER:  0.000
--------------------------------------------------
GT:   තිබේ.පොලවටකිලෝමීටර්40කටවඩාඑපිටින්ඇති
PR:   තිබේ.පොලවටකිලෝමීටර්40කටවඩාඑපිටින්ඇති
CER:  0.000
--------------------------------------------------
GT:   මෙමස්තරයඅයනගෝලයලෙසහැදින්වේ.
PR:   මෙමස්තරයඅයනගෝලයලෙසහැදින්වේ.
CER:  0.000
--------------------------------------------------
GT:   අමරතුංගමහතාගේපෙබරවාරි10වැනිදාලිපිය
PR:   අමරතුංගමහතාගේපෙබරවාරි10වැනිදාලිපිය
CER:  0.000
--------------------------------------------------
GT:   අනුවදෙවියන්වහන්සේවිශ්වයමවාඇත්තේමිනිසාගේ
PR:   අනුවදෙවියන්වහන්සේවිශ්වයමවාඇත්තේමිනිසාගේ
CER:  0.000
--------------------------------------------------
GT:   ප්රයෝජනයපිණිසයයන්නබටහිරදර්ශනයය.
PR:   ප්රයෝජනයපිණිසයයන්නබටහිරදර්ශනයය.
CER:  0.000
--------------------------------------------------
GT:   එහෙත්පෙබරවාරි17වැනිදාලිපියටඅනුවඒ
PR:   එහෙත්පෙබරවාරි17වැනිදාලිපියටඅනුව
CER:  0.031
--------------------------------------------------
GT: 

2nd way which is more suitable for a research


In [None]:
import os
import json
import random
import regex                                # pip install regex
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CTCLoss, TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
import editdistance                          # pip install editdistance

# -----------------------------------------------------------------------------
# 1) Grapheme cluster utilities & label loading
cluster_re = regex.compile(r"\X", flags=regex.UNICODE)
def get_clusters(s: str) -> list[str]:
    return [c for c in cluster_re.findall(s) if not c.isspace()]

# load JSON: filename → sentence
json_path = Path("d:/python/data/line_crops/labels.json")
with open(json_path, encoding="utf-8") as f:
    line_map = json.load(f)

# build charset
all_clusters = set()
for txt in line_map.values():
    all_clusters.update(get_clusters(txt))
charset  = sorted(all_clusters)
char2idx = {c:i for i,c in enumerate(charset)}
idx2char = {i:c for c,i in char2idx.items()}
blank_idx = len(char2idx)

# collect image paths & labels
img_dir   = Path("d:/python/data/line_crops")
img_paths, labels = [], []
for fname, txt in line_map.items():
    p = img_dir/fname
    if not p.exists(): continue
    cls = get_clusters(txt)
    if cls:
        img_paths.append(str(p))
        labels.append(cls)

# -----------------------------------------------------------------------------
# 2) Dataset classes
class LineDataset(Dataset):
    """For RotNet self-supervision (no labels)."""
    def __init__(self, img_paths, img_size=(512,512)):
        self.paths = img_paths
        self.tf    = T.Compose([
            T.Grayscale(num_output_channels=3),
            T.Resize(img_size),
            T.ToTensor(),
            T.Normalize([0.5]*3, [0.5]*3)
        ])
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, i):
        img = Image.open(self.paths[i]).convert("RGB")
        return self.tf(img)

class OCRDataset(Dataset):
    """Line-level OCR data: image + grapheme-cluster labels."""
    def __init__(self, img_paths, labels, char2idx, img_size=(512,512), augment=False):
        self.paths = img_paths
        self.lbls  = labels
        self.c2i   = char2idx
        self.aug   = augment
        self.base  = T.Compose([
            T.Grayscale(num_output_channels=3),
            T.Resize(img_size),
            T.ToTensor(),
            T.Normalize([0.5]*3, [0.5]*3)
        ])
        self.aug_tf = T.Compose([
            T.RandomRotation(2),
            T.ColorJitter(brightness=0.1, contrast=0.1)
        ])
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, i):
        img = Image.open(self.paths[i]).convert("RGB")
        if self.aug and random.random()<0.5:
            img = self.aug_tf(img)
        img = self.base(img)
        lbl = torch.tensor([self.c2i[c] for c in self.lbls[i]], dtype=torch.long)
        return img, lbl

def ocr_collate(batch):
    return batch  # keep variable-length labels

# -----------------------------------------------------------------------------
# 3) Model definitions
class RotNet(nn.Module):
    def __init__(self):
        super().__init__()
        base = models.resnet18(pretrained=False)
        # make grayscale→3 channel conv, no downsample in conv1/pool
        base.conv1 = nn.Conv2d(3,64,7,1,3,bias=False)
        base.maxpool = nn.Identity()
        self.features = nn.Sequential(*list(base.children())[:-2])
        self.pool     = nn.AdaptiveAvgPool2d(1)
        self.head     = nn.Linear(base.fc.in_features, 4)
    def forward(self, x):
        f = self.features(x)
        f = self.pool(f).view(x.size(0), -1)
        return self.head(f)

class OCRModel(nn.Module):
    def __init__(self, num_chars):
        super().__init__()
        base = models.resnet18(pretrained=False)
        base.conv1   = nn.Conv2d(3,64,7,1,3,bias=False)
        base.maxpool = nn.Identity()
        self.cnn     = nn.Sequential(*list(base.children())[:-2])
        self.tcn     = nn.Conv1d(512,256,3,padding=1,groups=4)
        enc_layer     = TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=512)
        self.transformer = TransformerEncoder(enc_layer, num_layers=2)
        self.fc      = nn.Linear(256, num_chars+1)  # +1 blank
    def forward(self, x):
        f   = self.cnn(x)             # (B,512,H',W')
        f   = f.mean(2)               # (B,512,W')
        f   = self.tcn(f)             # (B,256,W')
        seq = f.permute(2,0,1)        # (W',B,256)
        out = self.transformer(seq)   # (W',B,256)
        return F.log_softmax(self.fc(out), dim=2)

# -----------------------------------------------------------------------------
# 4) Pretraining RotNet
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rotnet = RotNet().to(device)
optim_r = torch.optim.Adam(rotnet.parameters(), lr=2e-4)
crit_r  = nn.CrossEntropyLoss()
ds_r    = LineDataset(img_paths); loader_r = DataLoader(ds_r, batch_size=16, shuffle=True)

for epoch in range(5):
    rotnet.train(); total=0
    for imgs in loader_r:
        imgs = imgs.to(device)
        # random rotate
        R = torch.randint(0,4,(imgs.size(0),), device=device)
        imgs_r = torch.rot90(imgs, k=R, dims=(2,3))
        logits = rotnet(imgs_r)
        loss   = crit_r(logits, R)
        optim_r.zero_grad(); loss.backward(); optim_r.step()
        total += loss.item()
    print(f"RotNet Epoch {epoch+1} loss: {total/len(loader_r):.4f}")

# -----------------------------------------------------------------------------
# 5) OCR training
ocr = OCRModel(num_chars=len(char2idx)).to(device)
# transfer pretrained features
ocr.cnn.load_state_dict(rotnet.features.state_dict(), strict=False)

optim_o = torch.optim.Adam(ocr.parameters(), lr=1e-4)
ctc     = CTCLoss(blank=blank_idx, zero_infinity=False)
ds_o    = OCRDataset(img_paths, labels, char2idx, augment=True)
loader_o= DataLoader(ds_o, batch_size=2, shuffle=True, collate_fn=ocr_collate)
epochs  = 20

for ep in range(1, epochs+1):
    ocr.train(); tot=0
    for batch in loader_o:
        img, lbl = batch[0]
        img = img.unsqueeze(0).to(device)
        L   = lbl.size(0)
        logp= ocr(img)                     # (T,1,C)
        T_l = logp.size(0)
        inp = torch.full((1,), T_l, dtype=torch.long)
        tgt = torch.tensor([L], dtype=torch.long)
        loss= ctc(logp, lbl.unsqueeze(0).to(device), inp, tgt)
        optim_o.zero_grad(); loss.backward(); optim_o.step()
        tot += loss.item()
    print(f"OCR Epoch {ep} avg loss: {tot/len(loader_o):.4f}")

# -----------------------------------------------------------------------------
# 6) Evaluation (greedy CTC + CER)
def greedy_decode(lp: torch.Tensor) -> str:
    arg = lp.argmax(1).cpu().tolist()
    res, prev = [], blank_idx
    for a in arg:
        if a!=prev and a!=blank_idx:
            res.append(idx2char[a])
        prev = a
    return "".join(res)

ocr.eval(); total_cer=0
with torch.no_grad():
    for i,(img,lbl) in enumerate(ds_o):
        lp = ocr(img.unsqueeze(0).to(device)).squeeze(1)
        pred = greedy_decode(lp)
        truth= "".join(lbl.tolist() and [idx2char[x] for x in lbl.tolist()])
        cer = editdistance.eval(pred, truth)/max(1,len(truth))
        total_cer += cer
        if i<5:
            print(f"GT:   {truth}")
            print(f"PRED: {pred}, CER={cer:.3f}\n")
print(f"Overall CER: {total_cer/len(ds_o):.3f}")
