# DrishT: CRNN-Light Text Recognition Training

**Model**: CRNN-Light — LightCNN encoder (depthwise-separable / inverted-residual convs + SE) + BiLSTM + CTC (~9.3M params; the exact count is printed when the model is built)  
**Dataset**: 188,879 train / 23,601 val word crops (722 chars, 13 scripts)  
**GPU**: Kaggle T4 (free, 30h/week)  

## Setup
1. Add dataset `drisht-recognition` to this notebook
2. Enable GPU: Settings → Accelerator → GPU T4 x2 (training itself runs on a single GPU)
3. Run all cells

In [None]:
import os, sys, csv, time, json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
import torchvision.transforms as T
from PIL import Image
from tqdm.auto import tqdm

# Report the runtime environment once, then pick the training device.
cuda_ok = torch.cuda.is_available()
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {cuda_ok}')
if cuda_ok:
    print(f'GPU: {torch.cuda.get_device_name(0)}')

DEVICE = torch.device('cuda' if cuda_ok else 'cpu')

## Configuration

In [None]:
# --- Adjust this path based on your Kaggle dataset name ---
DATA_ROOT = Path('/kaggle/input/drisht-recognition')
OUTPUT_DIR = Path('/kaggle/working/recognition_output')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Expected layout under DATA_ROOT: {train,val}/labels.csv, {train,val}/images/, charset.txt
TRAIN_CSV = DATA_ROOT / 'train' / 'labels.csv'
TRAIN_IMAGES = DATA_ROOT / 'train' / 'images'
VAL_CSV = DATA_ROOT / 'val' / 'labels.csv'
VAL_IMAGES = DATA_ROOT / 'val' / 'images'
CHARSET_FILE = DATA_ROOT / 'charset.txt'

# Model
IMG_HEIGHT = 32  # the CNN encoder is designed for 32-px-high inputs
IMG_WIDTH = 128  # crops are resized to IMG_HEIGHT, then right-padded to this width
NUM_CHANNELS = 1  # Grayscale
HIDDEN_SIZE = 256
NUM_LSTM_LAYERS = 2
DROPOUT = 0.3
CTC_BLANK = 0  # class index reserved for the CTC blank; real characters start at 1

# Training
BATCH_SIZE = 128  # T4 has 16GB, can handle large batches
NUM_WORKERS = 2
EPOCHS = 80
LR = 1e-3
WEIGHT_DECAY = 1e-5
LR_MIN = 1e-6  # floor of the cosine LR schedule
PATIENCE = 12  # epochs without val-loss improvement before early stopping
USE_AMP = True  # only takes effect when training on CUDA

# Verify data before spending GPU time. An explicit raise (instead of `assert`)
# keeps the check active even under `python -O`, and the image directories are
# checked too rather than failing later inside iterdir().
for p in [TRAIN_CSV, VAL_CSV, CHARSET_FILE, TRAIN_IMAGES, VAL_IMAGES]:
    if not p.exists():
        raise FileNotFoundError(f'Missing: {p}')
# Count lazily; no need to materialise a list of ~190k paths just to take its length.
print(f'Train images: {sum(1 for _ in TRAIN_IMAGES.iterdir())}')
print(f'Val images: {sum(1 for _ in VAL_IMAGES.iterdir())}')

## Character Codec

In [None]:
class CharCodec:
    """Map characters to CTC class indices and back.

    Index 0 is reserved for the CTC blank. The symbols listed in the charset
    file (one per non-empty line) are therefore numbered from 1 in file order.
    """

    def __init__(self, charset_path):
        # Each line is stripped of surrounding whitespace and empty lines are skipped.
        # NOTE(review): because of the strip, a literal space (or other whitespace)
        # entry can never be part of the charset — confirm this is intended.
        with open(charset_path, 'r', encoding='utf-8') as handle:
            symbols = [entry for entry in (raw.strip() for raw in handle) if entry]
        self.char_to_idx = {}
        self.idx_to_char = {}
        for index, symbol in enumerate(symbols, start=1):
            self.char_to_idx[symbol] = index
            self.idx_to_char[index] = symbol
        self.num_classes = len(symbols) + 1  # +1 for CTC blank

    def encode(self, text):
        """Return one index per character of *text*; unknown characters become 0 (blank).

        NOTE(review): lookup is per single character, so a multi-character charset
        entry could never be matched — verify the charset holds only single chars.
        """
        lookup = self.char_to_idx
        return [lookup[ch] if ch in lookup else 0 for ch in text]

    def decode(self, indices):
        """Join the characters for *indices*, silently skipping 0 and unknown indices."""
        lookup = self.idx_to_char
        pieces = [lookup[i] for i in indices if i in lookup]
        return ''.join(pieces)


codec = CharCodec(CHARSET_FILE)
n_chars = codec.num_classes - 1  # class 0 is the CTC blank
print(f'Charset: {codec.num_classes} classes ({n_chars} chars + blank)')

## Model Architecture

In [None]:
class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv followed by a 1x1 pointwise conv, BatchNorm and ReLU6."""

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # groups=in_ch gives each input channel its own spatial filter.
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, stride=stride, padding=padding,
                                   groups=in_ch, bias=False)
        # The 1x1 conv mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        y = self.depthwise(x)
        y = self.pointwise(y)
        return self.relu(self.bn(y))


class SEBlock(nn.Module):
    """Squeeze-and-Excitation: rescale each channel by a sigmoid gate computed
    from that channel's global spatial average."""

    def __init__(self, channels, reduction=4):
        super().__init__()
        hidden = max(channels // reduction, 8)  # bottleneck never narrower than 8
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        pooled = self.squeeze(x).flatten(1)  # (B, C, 1, 1) -> (B, C)
        gates = self.excitation(pooled)
        return x * gates[:, :, None, None]


class InvertedResidual(nn.Module):
    """Inverted-residual block: 1x1 expand -> 3x3 depthwise -> optional SE -> 1x1 projection.

    The projection has no activation. A skip connection is added only when the
    block keeps both the resolution (stride 1) and the channel count.
    """

    def __init__(self, in_ch, out_ch, stride=1, expand_ratio=4, use_se=True):
        super().__init__()
        wide = in_ch * expand_ratio
        self.use_residual = stride == 1 and in_ch == out_ch
        expand = ([nn.Conv2d(in_ch, wide, 1, bias=False), nn.BatchNorm2d(wide), nn.ReLU6(inplace=True)]
                  if expand_ratio != 1 else [])
        depthwise = [nn.Conv2d(wide, wide, 3, stride, 1, groups=wide, bias=False),
                     nn.BatchNorm2d(wide), nn.ReLU6(inplace=True)]
        self.conv = nn.Sequential(*expand, *depthwise)
        self.se = SEBlock(wide) if use_se else nn.Identity()
        self.project = nn.Sequential(nn.Conv2d(wide, out_ch, 1, bias=False), nn.BatchNorm2d(out_ch))

    def forward(self, x):
        y = self.conv(x)
        y = self.se(y)
        y = self.project(y)
        if self.use_residual:
            return y + x
        return y


class LightCNNEncoder(nn.Module):
    """(B, C, 32, W) -> (B, W/2, 512)

    Height shrinks 32 -> 16 -> 8 -> 4 -> 2 through the pooling stages, and the
    final (2, 1) conv collapses it to 1. Width is halved once, in stage 3.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # H 32 -> 16
        self.stage1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
            DepthwiseSeparableConv(32, 64),
            nn.MaxPool2d((2, 1)),
        )
        # H 16 -> 8
        self.stage2 = nn.Sequential(
            InvertedResidual(64, 128),
            InvertedResidual(128, 128),
            nn.MaxPool2d((2, 1)),
        )
        # H 8 -> 4, W -> W/2
        self.stage3 = nn.Sequential(
            InvertedResidual(128, 256),
            InvertedResidual(256, 256),
            nn.MaxPool2d((2, 2)),
        )
        # H 4 -> 2
        self.stage4 = nn.Sequential(
            InvertedResidual(256, 384),
            InvertedResidual(384, 384),
            nn.MaxPool2d((2, 1)),
            nn.Dropout2d(0.1),
        )
        # H 2 -> 1 via an unpadded (2, 1) conv
        self.stage5 = nn.Sequential(
            nn.Conv2d(384, 512, (2, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU6(inplace=True),
            SEBlock(512),
            nn.Dropout2d(0.1),
        )

    def forward(self, x):
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4, self.stage5):
            x = stage(x)
        # Drop the unit height axis and put width (time) before features.
        return x.squeeze(2).permute(0, 2, 1)  # (B, W/2, 512)


class BiLSTMSequenceModel(nn.Module):
    """Bidirectional LSTM stack whose concatenated outputs are projected back to hidden_size."""

    def __init__(self, input_size, hidden_size, num_layers=2, dropout=0.3):
        super().__init__()
        # nn.LSTM applies dropout only between stacked layers, so disable it for one layer.
        inter_layer_dropout = 0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True,
                            bidirectional=True, dropout=inter_layer_dropout)
        self.linear = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, x):
        outputs, _ = self.lstm(x)  # (B, T, 2 * hidden_size)
        return self.linear(outputs)


class CRNN(nn.Module):
    """LightCNN encoder + BiLSTM + per-timestep classifier, trained with CTC.

    Args:
        num_classes: number of output classes, including the CTC blank.
        img_h: recorded on the instance only; the encoder itself is built for
            32-px-high inputs regardless of this value.
        num_channels: input image channels.
        hidden_size, num_lstm_layers, dropout: sequence-model hyper-parameters.
            ``None`` (the default) falls back to the notebook-level HIDDEN_SIZE,
            NUM_LSTM_LAYERS and DROPOUT, so existing calls behave as before.
    """

    def __init__(self, num_classes, img_h=32, num_channels=1,
                 hidden_size=None, num_lstm_layers=None, dropout=None):
        super().__init__()
        # Resolve defaults at call time (not definition time) so the class can be
        # defined before, or independently of, the configuration constants.
        hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size
        num_lstm_layers = NUM_LSTM_LAYERS if num_lstm_layers is None else num_lstm_layers
        dropout = DROPOUT if dropout is None else dropout
        self.num_classes = num_classes
        self.img_h = img_h
        self.num_channels = num_channels
        self.cnn = LightCNNEncoder(in_channels=num_channels)
        self.rnn = BiLSTMSequenceModel(512, hidden_size, num_lstm_layers, dropout)
        self.output = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map (B, C, 32, W) images to (B, W/2, num_classes) log-probabilities."""
        return F.log_softmax(self.output(self.rnn(self.cnn(x))), dim=2)

    def decode_greedy(self, log_probs, blank=None):
        """Best-path CTC decoding of (B, T, C) log-probs into per-sample index lists.

        Takes the arg-max class per timestep, collapses consecutive repeats, then
        drops *blank* (default: CTC_BLANK). A blank between two equal symbols
        keeps both, e.g. [a, blank, a] -> [a, a].
        """
        blank = CTC_BLANK if blank is None else blank
        best_path = log_probs.argmax(dim=2)
        results = []
        for row in best_path:
            collapsed = torch.unique_consecutive(row)
            results.append(collapsed[collapsed != blank].tolist())
        return results


model = CRNN(num_classes=codec.num_classes, img_h=IMG_HEIGHT, num_channels=NUM_CHANNELS).to(DEVICE)
# Parameter count and in-memory parameter size (buffers such as BN stats excluded).
params = list(model.parameters())
total = sum(p.numel() for p in params)
size_mb = sum(p.numel() * p.element_size() for p in params) / 1024**2
print(f'CRNN-Light: {total:,} params, {size_mb:.1f} MB')

## Dataset

In [None]:
class RecognitionDataset(Dataset):
    """Word-crop dataset for CTC recognition.

    Reads ``image``/``label`` columns from *csv_path*. Rows whose label is
    empty, ``'UNK'``, or contains a character the codec does not know are
    dropped. Each item is ``(image_tensor, encoded_label, label_length)``.
    """

    def __init__(self, csv_path, img_dir, codec, img_h=32, img_w=128, num_channels=1, augment=False):
        self.img_dir = Path(img_dir)
        self.codec = codec
        self.img_h, self.img_w = img_h, img_w
        self.num_channels = num_channels
        self.augment = augment
        self.samples = []
        # newline='' lets the csv module handle line endings / quoted newlines itself.
        with open(csv_path, 'r', encoding='utf-8', newline='') as f:
            for row in csv.DictReader(f):
                label = row['label'].strip()
                if not label or label == 'UNK':
                    continue
                # encode() maps out-of-charset characters to 0 (the blank index).
                if 0 in self.codec.encode(label):
                    continue
                self.samples.append((row['image'], label))
        print(f'  Loaded {len(self.samples)} from {csv_path}')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, label = self.samples[idx]
        mode = 'L' if self.num_channels == 1 else 'RGB'
        try:
            # The context manager closes the underlying file handle promptly.
            with Image.open(self.img_dir / img_name) as src:
                image = src.convert(mode)
        except Exception as err:
            # Best effort: an unreadable crop becomes a grey image with an empty
            # target instead of killing a DataLoader worker, but it is reported.
            import warnings
            warnings.warn(f'Could not read {img_name}: {err}')
            image = Image.new(mode, (self.img_w, self.img_h), 128)
            label = ''

        # Aspect-ratio-preserving resize + right-pad. Clamp the width to >= 1 px so
        # very narrow/tall crops never request a zero-width resize.
        w, h = image.size
        new_w = max(1, min(int(w * self.img_h / h), self.img_w))
        image = image.resize((new_w, self.img_h), Image.BILINEAR)
        padded = Image.new(mode, (self.img_w, self.img_h), 0)
        padded.paste(image, (0, 0))

        if self.augment:
            import random
            if random.random() < 0.3:
                padded = T.functional.adjust_brightness(padded, random.uniform(0.7, 1.3))
            if random.random() < 0.2:
                padded = T.functional.adjust_contrast(padded, random.uniform(0.8, 1.2))

        tensor = T.ToTensor()(padded)
        if self.num_channels == 1:
            tensor = (tensor - 0.5) / 0.5  # scale [0, 1] to [-1, 1]
        else:
            tensor = T.Normalize([.485, .456, .406], [.229, .224, .225])(tensor)

        encoded = self.codec.encode(label)
        return tensor, torch.tensor(encoded, dtype=torch.long), torch.tensor(len(encoded), dtype=torch.long)


def collate_fn(batch):
    """Stack images and right-pad variable-length targets with 0 (the CTC blank index).

    Returns ``(images, padded_targets, target_lengths)``. The padded target
    matrix always has at least one column, so an all-empty batch still yields
    a valid 2-D tensor.
    """
    images, targets, lengths = zip(*batch)
    images = torch.stack(images)
    # Size by the longest target in the whole batch. Looking only at the first
    # target made an empty first target shrink the buffer to width 1, and the
    # copy below then failed for any longer target.
    max_len = max(1, max((t.size(0) for t in targets), default=0))
    padded = torch.zeros(len(targets), max_len, dtype=torch.long)
    for i, t in enumerate(targets):
        if t.size(0) > 0:
            padded[i, :t.size(0)] = t
    return images, padded, torch.stack(lengths)


train_ds = RecognitionDataset(TRAIN_CSV, TRAIN_IMAGES, codec, IMG_HEIGHT, IMG_WIDTH, NUM_CHANNELS, augment=True)
val_ds = RecognitionDataset(VAL_CSV, VAL_IMAGES, codec, IMG_HEIGHT, IMG_WIDTH, NUM_CHANNELS, augment=False)

# Both loaders share everything except shuffling (training only).
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                     collate_fn=collate_fn, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

print(f'Train: {len(train_ds)}, Val: {len(val_ds)}')

## Training Utilities

In [None]:
def compute_metrics(model, loader, codec, device, max_batches=50):
    """Greedy-decode up to *max_batches* batches and score them against the targets.

    Returns ``(char_acc, word_acc)`` as percentages. Character accuracy is
    positional: pred[i] is compared with gt[i], and every character of length
    difference counts as an error, so each word contributes
    max(len(pred), len(gt)) to the denominator. Word accuracy is the exact-match
    rate. Pass ``max_batches=None`` to score the whole loader. With an
    unshuffled loader the default always scores the same leading batches.
    """
    model.eval()
    total_chars = correct_chars = total_words = correct_words = 0
    with torch.no_grad():
        for i, (images, targets, lengths) in enumerate(loader):
            if max_batches is not None and i >= max_batches:
                break
            decoded = model.decode_greedy(model(images.to(device)))
            for b, pred_ids in enumerate(decoded):
                pred = codec.decode(pred_ids)
                gt = codec.decode(targets[b, :lengths[b]].tolist())
                total_words += 1
                correct_words += int(pred == gt)
                correct_chars += sum(p == g for p, g in zip(pred, gt))
                total_chars += max(len(pred), len(gt))
    return correct_chars / max(total_chars, 1) * 100, correct_words / max(total_words, 1) * 100


def train_one_epoch(model, loader, optimizer, scaler, device, epoch):
    """Run one optimisation pass over *loader* and return the mean per-batch CTC loss.

    When *scaler* is given, forward + loss run under CUDA autocast and the
    GradScaler drives backward/step. Otherwise plain fp32 is used. Gradients are
    clipped to an L2 norm of 5.0 in both paths.
    """
    model.train()
    criterion = nn.CTCLoss(blank=CTC_BLANK, zero_infinity=True)

    def batch_loss(images, targets, target_lengths):
        # CTCLoss wants (T, B, C); every sample uses the full output length.
        log_probs = model(images).permute(1, 0, 2)
        input_lengths = torch.full((images.size(0),), log_probs.size(0), dtype=torch.long, device=device)
        return criterion(log_probs, targets, input_lengths, target_lengths)

    running, steps = 0.0, 0
    progress = tqdm(loader, desc=f'Epoch {epoch}')
    for images, targets, lengths in progress:
        images = images.to(device)
        targets = targets.to(device)
        lengths = lengths.to(device)
        optimizer.zero_grad()
        if not scaler:
            loss = batch_loss(images, targets, lengths)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
        else:
            with autocast('cuda'):
                loss = batch_loss(images, targets, lengths)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            scaler.step(optimizer)
            scaler.update()
        running += loss.item()
        steps += 1
        progress.set_postfix(loss=f'{loss.item():.4f}')
    return running / max(steps, 1)


@torch.no_grad()
def validate(model, loader, device):
    """Return the mean per-batch CTC loss over *loader* (eval mode, no gradients)."""
    model.eval()
    criterion = nn.CTCLoss(blank=CTC_BLANK, zero_infinity=True)
    batch_losses = []
    for images, targets, lengths in tqdm(loader, desc='Val'):
        images = images.to(device)
        log_probs = model(images).permute(1, 0, 2)  # (T, B, C) for CTCLoss
        steps = log_probs.size(0)
        input_lengths = torch.full((images.size(0),), steps, dtype=torch.long, device=device)
        loss = criterion(log_probs, targets.to(device), input_lengths, lengths.to(device))
        batch_losses.append(loss.item())
    return sum(batch_losses) / max(len(batch_losses), 1)

# Cell marker: reaching this line means the helper definitions in this cell ran without error.
print('Utilities defined.')

## Train

In [None]:
optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=WEIGHT_DECAY)
# A single cosine period over the full epoch budget, decaying LR towards LR_MIN.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LR_MIN)
# Mixed precision is only enabled on CUDA; None selects the plain fp32 training path.
use_scaler = USE_AMP and DEVICE.type == 'cuda'
scaler = GradScaler('cuda') if use_scaler else None

# Run-level bookkeeping updated by the training loop.
best_val, best_word_acc = float('inf'), 0.0
patience_ctr = 0
history = []

print(f'Training CRNN-Light for {EPOCHS} epochs on {DEVICE}')
print(f'AMP: {scaler is not None}, Batch: {BATCH_SIZE}')
print()

In [None]:
def _checkpoint(**extra):
    """Build the checkpoint fields shared by every saved file, plus per-file *extra* keys."""
    return {'state_dict': model.state_dict(), 'num_classes': codec.num_classes,
            'img_h': IMG_HEIGHT, 'num_channels': NUM_CHANNELS, **extra}


for epoch in range(1, EPOCHS + 1):
    t0 = time.time()

    tloss = train_one_epoch(model, train_loader, optimizer, scaler, DEVICE, epoch)
    vloss = validate(model, val_loader, DEVICE)

    char_acc, word_acc = compute_metrics(model, val_loader, codec, DEVICE)

    # Read the LR this epoch actually trained with *before* stepping the
    # scheduler. Reading it afterwards logs the next epoch's LR.
    lr = optimizer.param_groups[0]['lr']
    scheduler.step()
    elapsed = time.time() - t0

    print(f'Epoch {epoch:3d} | Train: {tloss:.4f} | Val: {vloss:.4f} | '
          f'Char: {char_acc:.1f}% | Word: {word_acc:.1f}% | LR: {lr:.6f} | {elapsed:.1f}s')
    history.append({'epoch': epoch, 'train_loss': tloss, 'val_loss': vloss,
                    'char_acc': char_acc, 'word_acc': word_acc, 'lr': lr})

    # Save best by val loss (also drives early stopping)
    if vloss < best_val:
        best_val = vloss
        patience_ctr = 0
        torch.save(_checkpoint(epoch=epoch, best_val=best_val), OUTPUT_DIR / 'best.pth')
        print(f'  -> Saved best model (val={vloss:.4f})')
    else:
        patience_ctr += 1

    # Save best by word accuracy
    if word_acc > best_word_acc:
        best_word_acc = word_acc
        torch.save(_checkpoint(epoch=epoch, best_word_acc=best_word_acc), OUTPUT_DIR / 'best_acc.pth')
        print(f'  -> Saved best accuracy model (word={word_acc:.1f}%)')

    # Periodic snapshots also carry optimizer / scheduler / scaler state so an
    # interrupted run can be resumed instead of restarted.
    if epoch % 10 == 0:
        torch.save(_checkpoint(epoch=epoch,
                               optimizer=optimizer.state_dict(),
                               scheduler=scheduler.state_dict(),
                               scaler=scaler.state_dict() if scaler is not None else None),
                   OUTPUT_DIR / f'epoch_{epoch}.pth')

    if patience_ctr >= PATIENCE:
        print(f'\nEarly stopping at epoch {epoch}')
        break

print(f'\nDone! Best val: {best_val:.4f}, Best word acc: {best_word_acc:.1f}%')

## Training Curves

In [None]:
import matplotlib.pyplot as plt

# One row of three panels: losses, character accuracy, word accuracy.
eps = [h['epoch'] for h in history]
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
loss_ax, char_ax, word_ax = axes

for key, label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    loss_ax.plot(eps, [h[key] for h in history], label=label)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('CTC Loss')
loss_ax.legend()
loss_ax.set_title('Loss')

for ax, key, fmt, title in ((char_ax, 'char_acc', 'b-', 'Character Accuracy'),
                            (word_ax, 'word_acc', 'g-', 'Word Accuracy')):
    ax.plot(eps, [h[key] for h in history], fmt)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('%')
    ax.set_title(title)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'training_curves.png', dpi=150)
plt.show()

## Export ONNX

In [None]:
# Load best model for export
ckpt = torch.load(OUTPUT_DIR / 'best_acc.pth', map_location='cpu')
model_export = CRNN(ckpt['num_classes'], ckpt.get('img_h', 32), ckpt.get('num_channels', 1))
model_export.load_state_dict(ckpt['state_dict'])
model_export.eval()

dummy = torch.randn(1, NUM_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
torch.onnx.export(
    model_export, dummy, str(OUTPUT_DIR / 'crnn_recognition.onnx'),
    opset_version=17,
    input_names=['image'],
    output_names=['log_probs'],
    dynamic_axes={'image': {0: 'batch', 3: 'width'}, 'log_probs': {0: 'batch', 1: 'timesteps'}},
)
onnx_size = (OUTPUT_DIR / 'crnn_recognition.onnx').stat().st_size / 1024**2
print(f'ONNX exported: {onnx_size:.1f} MB')

## Save & Download
Download from `/kaggle/working/recognition_output/`

In [None]:
# Persist the per-epoch metrics next to the checkpoints.
(OUTPUT_DIR / 'history.json').write_text(json.dumps(history, indent=2))
# Copy charset for inference
import shutil
shutil.copy2(CHARSET_FILE, OUTPUT_DIR / 'charset.txt')

print('Output files:')
for out_path in sorted(OUTPUT_DIR.iterdir()):
    size_mb = out_path.stat().st_size / 1024**2
    print(f'  {out_path.name}: {size_mb:.1f} MB')