# CAPTCHA CRNN training pipeline (including pretraining)

This notebook includes two stages:
1. Single-letter classification pretraining (SimpleCNN).
2. Load the pretrained CNN weights into a CRNN and train it with CTC loss for CAPTCHA recognition.

**Assumes CAPTCHAs contain only lowercase English letters, no digits.**

In [None]:
import os

def get_unique_filename(base_name):
    """
    Return a path derived from ``base_name`` at which nothing exists yet.

    ``base_name`` is returned unchanged when no file or directory exists at
    that path. Otherwise a counter is inserted in front of the extension
    (``name_1.ext``, ``name_2.ext``, ...) and the first candidate that does
    not exist is returned. Nothing is created on disk.
    """
    stem, suffix = os.path.splitext(base_name)
    candidate = base_name
    counter = 0
    while os.path.exists(candidate):
        counter += 1
        candidate = f"{stem}_{counter}{suffix}"
    return candidate

In [None]:
# Character set defined as lowercase letters plus a blank token
import string
BLANK = "-"
CHARS = string.ascii_lowercase

# Symbol table: the blank symbol takes index 0 and the letters follow as
# 1..len(CHARS), so both mappings are built from one ordered symbol list.
char2idx = {symbol: index for index, symbol in enumerate([BLANK, *CHARS])}
idx2char = dict(enumerate([BLANK, *CHARS]))
NUM_CLASSES = len(char2idx)
print(f"Classes: {NUM_CLASSES}, Characters: {CHARS}")

# Index handed to the CTC loss as its blank class; equals char2idx[BLANK].
BLANK_INDEX = 0


In [None]:
# SimpleCNN: Single-Character Classifier
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    """Small convolutional classifier for single-channel character images.

    Three ``Conv2d -> BatchNorm2d -> ReLU`` stages widen the feature maps to
    64, 128 and 256 channels. The first two stages end in a 2x2 max pool; the
    third ends in a global average pool, so each image is reduced to one
    256-dimensional vector that feeds the linear classifier ``fc``.

    Args:
        n_classes: Number of output scores produced by the final linear layer.
    """

    def __init__(self, n_classes):
        super().__init__()
        # The layers are unpacked into one flat Sequential so that their
        # positions (0..11), and therefore the state-dict keys, stay flat.
        first_stage = [
            nn.Conv2d(1, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
        ]
        second_stage = [
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
        ]
        third_stage = [
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
        ]
        self.conv = nn.Sequential(*first_stage, *second_stage, *third_stage)
        self.fc = nn.Linear(256, n_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, n_classes)``."""
        features = self.conv(x)
        # Global pooling leaves (batch, 256, 1, 1); collapse it to (batch, 256).
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class SingleLetterDataset(Dataset):
    """Single-character images laid out as ``root_dir/<label>/<image file>``.

    The name of each sub-directory of ``root_dir`` is used as the label of the
    ``.png`` / ``.jpg`` files directly inside it. Entries of ``root_dir`` that
    are not directories, and files with other extensions, are ignored.

    Args:
        root_dir: Directory containing one sub-directory per label.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.samples = []
        self.transform = transform
        for class_name in os.listdir(root_dir):
            class_dir = os.path.join(root_dir, class_name)
            if os.path.isdir(class_dir):
                self.samples.extend(
                    (os.path.join(class_dir, entry), class_name)
                    for entry in os.listdir(class_dir)
                    if entry.endswith(('.png', '.jpg'))
                )

    def __len__(self):
        """Return the number of collected ``(path, label)`` pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``.

        The image is opened and converted to PIL mode "L", then passed through
        ``transform`` when one was given. The label is looked up in the
        module-level ``char2idx`` table, so a directory name that is missing
        from that table raises ``KeyError`` here.
        """
        image_path, class_name = self.samples[idx]
        picture = Image.open(image_path).convert("L")
        if self.transform:
            picture = self.transform(picture)
        return picture, char2idx[class_name]

In [None]:
def train_single_letter(model, dataset, device, epochs=10):
    """Train ``model`` as a classifier on ``dataset`` with cross-entropy loss.

    The model is moved to ``device`` in place and optimised with Adam
    (learning rate 1e-3) on shuffled mini-batches of 64 samples. After every
    epoch the mean of the per-batch losses is printed. Nothing is returned.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        dataset: Dataset yielding ``(input, class_index)`` pairs.
        device: Device on which the model and each batch are placed.
        epochs: Number of passes over ``dataset``.
    """
    from torch.optim import Adam
    from torch.utils.data import DataLoader

    model.to(device)
    batches = DataLoader(dataset, batch_size=64, shuffle=True)
    loss_fn = nn.CrossEntropyLoss()
    opt = Adam(model.parameters(), lr=1e-3)

    for epoch_idx in range(epochs):
        model.train()
        running = 0
        for inputs, targets in batches:
            inputs, targets = inputs.to(device), targets.to(device)
            opt.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            opt.step()
            running += batch_loss.item()
        print(f"Epoch {epoch_idx+1}, Loss: {running/len(batches):.4f}")

In [None]:
def evaluate_single_letter_accuracy(model, dataset, device="cpu"):
    """Compute and print the classification accuracy of ``model`` on ``dataset``.

    The model is switched to evaluation mode and moved to ``device`` in place.
    Predictions are the arg-max over dimension 1 of the model output and are
    compared with the labels in batches of 64, without tracking gradients.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        dataset: Dataset yielding ``(input, class_index)`` pairs.
        device: Device on which the model and each batch are placed.

    Returns:
        Fraction of correctly classified samples, or 0 when ``dataset`` is
        empty (the same convention ``test_crnn`` uses).
    """
    from torch.utils.data import DataLoader

    model.eval()
    model.to(device)

    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    # An empty dataset yields no batches; report 0 instead of dividing by zero.
    acc = correct / total if total > 0 else 0
    print(f"✅ Single letter acc: {acc:.2%}")
    return acc


In [None]:
# CRNN + CTC
class CRNN(nn.Module):
    """Convolutional-recurrent network emitting per-column log-probabilities.

    The CNN applies three ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d``
    stages (64, 128, 256 channels). The first two pools halve both height and
    width; the third halves the height only. Every remaining image column is
    flattened into one feature vector, the columns are fed as a sequence to a
    2-layer bidirectional LSTM, and a linear layer maps each step to
    ``n_classes`` scores.

    Args:
        img_h: Input image height; sizes the LSTM input as ``256 * (img_h // 8)``.
        n_channels: Number of channels of the input images.
        n_classes: Number of scores produced per time step.
        rnn_hidden: Hidden size of the LSTM in each direction.
    """

    def __init__(self, img_h, n_channels, n_classes, rnn_hidden=256):
        super().__init__()
        widths = [n_channels, 64, 128, 256]
        pools = [
            nn.MaxPool2d(2, 2),
            nn.MaxPool2d(2, 2),
            nn.MaxPool2d((2, 1), (2, 1)),  # shrink height only, keep the width
        ]
        layers = []
        for c_in, c_out, pool in zip(widths, widths[1:], pools):
            layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(True), pool]
        self.cnn = nn.Sequential(*layers)

        # Height is halved three times, so a column holds 256 * (img_h // 8) values.
        self.rnn_input = 256 * (img_h // 8)
        self.rnn = nn.LSTM(self.rnn_input, rnn_hidden, num_layers=2, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(rnn_hidden * 2, n_classes)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, steps, n_classes)``.

        ``steps`` is the width of the CNN feature map.
        """
        feats = self.cnn(x)
        batch, channels, height, width = feats.size()
        # (B, C, H, W) -> (B, W, C, H) -> (B, W, C*H): one feature vector per column.
        sequence = feats.permute(0, 3, 1, 2).view(batch, width, channels * height)
        hidden, _ = self.rnn(sequence)
        return self.fc(hidden).log_softmax(2)

In [None]:
def load_cnn_weights_into_crnn(crnn_model, cnn_model):
    """
    Transfers convolutional feature weights from SimpleCNN to the CRNN's CNN module.

    Every tensor in ``cnn_model.conv.state_dict()`` whose key also exists in
    ``crnn_model.cnn.state_dict()`` with the same shape is copied into the
    CRNN. Tensors without a same-shaped counterpart (for example the first
    convolution when the two models use a different number of input channels)
    are left untouched and reported, instead of making ``load_state_dict``
    fail. The success message is printed only when at least one tensor was
    transferred.

    Args:
        crnn_model: Model exposing the target layers as ``.cnn``.
        cnn_model: Model exposing the pretrained layers as ``.conv``.
    """
    source_state = cnn_model.conv.state_dict()
    target_state = crnn_model.cnn.state_dict()

    transferable = {}
    skipped = []
    for key, tensor in source_state.items():
        # A matching key is not enough: load_state_dict rejects shape mismatches.
        if key in target_state and target_state[key].shape == tensor.shape:
            transferable[key] = tensor
        else:
            skipped.append(key)

    if skipped:
        print(f"⚠️ Skipped {len(skipped)} tensor(s) without a same-shaped counterpart: {skipped}")
    if not transferable:
        print("⚠️ No convolutional weights were transferred; the CRNN keeps its current weights.")
        return

    target_state.update(transferable)
    crnn_model.cnn.load_state_dict(target_state)
    print("✅ Successfully loaded convolutional weights from SimpleCNN into the CRNN!")


In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class CaptchaDataset(Dataset):
    """CAPTCHA images whose file name carries the ground-truth text.

    Every ``.png`` file directly inside ``root_dir`` becomes one sample; its
    label is the part of the file name before the first dot (``abcd.png`` is
    labelled ``abcd``).

    Args:
        root_dir: The image folder, e.g. "captchas".
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = [
            (os.path.join(root_dir, fname), fname.partition(".")[0])
            for fname in os.listdir(root_dir)
            if fname.endswith(".png")
        ]

    def __len__(self):
        """Return the number of collected ``(path, label)`` pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label_indices, label_length)`` for sample ``idx``.

        The image is opened and converted to PIL mode "L", then passed through
        ``transform`` when one was given. ``label_indices`` is a 1-D long
        tensor with one ``char2idx`` entry per label character and
        ``label_length`` is a 0-d tensor holding its length. A character that
        is missing from ``char2idx`` raises ``KeyError``.
        """
        image_path, text = self.samples[idx]
        picture = Image.open(image_path).convert("L")

        if self.transform:
            picture = self.transform(picture)

        encoded = torch.tensor([char2idx[ch] for ch in text], dtype=torch.long)
        return picture, encoded, torch.tensor(len(encoded))


In [None]:
def collate_fn(batch):
    """Merge ``(image, label, label_length)`` samples into CTC-ready tensors.

    Returns:
        A tuple ``(images, targets, target_lengths)``: the images stacked
        along a new leading batch dimension, all labels concatenated into one
        1-D tensor, and a 1-D long tensor with one label length per sample.
    """
    pictures, targets, lengths = zip(*batch)
    stacked = torch.stack(pictures, 0)  # (B, C, H, W)
    lengths = torch.tensor(lengths, dtype=torch.long)  # one entry per sample
    # Labels may differ in length, so they are concatenated rather than stacked.
    return stacked, torch.cat(targets), lengths


In [None]:
def train_crnn(model, dataset, device="cuda", epochs=20, freeze_cnn_epochs=5):
    """Train a CRNN with CTC loss, keeping its CNN frozen for the first epochs.

    The model is moved to ``device`` before the optimizer is created (the
    order PyTorch recommends, and the one ``train_single_letter`` uses), then
    optimised with Adam (learning rate 1e-3) on shuffled batches of 32 built
    by ``collate_fn``. The mean per-batch loss is printed after every epoch.
    Nothing is returned.

    Args:
        model: Network with a ``cnn`` sub-module whose forward pass returns
            per-step class scores shaped ``(batch, steps, classes)``.
        dataset: Dataset yielding ``(image, label_indices, label_length)``.
        device: Device on which the model and each batch are placed.
        epochs: Number of passes over ``dataset``.
        freeze_cnn_epochs: Number of initial epochs during which the
            parameters of ``model.cnn`` receive no gradient updates.
    """
    from torch.utils.data import DataLoader
    import torch.optim as optim

    # Place the parameters on their final device before the optimizer sees them.
    model.to(device)

    loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    ctc_loss = nn.CTCLoss(blank=BLANK_INDEX, zero_infinity=True)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        model.train()

        # Freeze the CNN for the first epochs, then release it.
        # NOTE(review): model.train() keeps any BatchNorm layers inside
        # model.cnn in training mode, so their running statistics still change
        # while the weights are frozen. Confirm this is intended.
        cnn_frozen = epoch < freeze_cnn_epochs
        for param in model.cnn.parameters():
            param.requires_grad = not cnn_frozen

        total_loss = 0
        for imgs, labels, label_lens in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(imgs)  # current order (B, T, C)
            logits = logits.permute(1, 0, 2)  # CTCLoss expects (T, B, C)
            # Harmless if the model already returns log-probabilities:
            # log_softmax leaves normalised log-probabilities unchanged.
            log_probs = nn.functional.log_softmax(logits, dim=2)

            # Every sample in the batch uses the full sequence of T steps.
            input_lens = torch.full(size=(logits.size(1),), fill_value=logits.size(0), dtype=torch.long)

            loss = ctc_loss(log_probs, labels, input_lens, label_lens)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}",
              "(CNN frozen)" if cnn_frozen else "")


In [None]:
import torch

def decode_prediction(pred_indices, idx2char, blank_idx=0):
    """
    Greedy CTC decoding: collapse runs of repeated indices, then drop blanks.

    pred_indices: List[int], the index sequence predicted by the model
    idx2char: mapping from a class index to its character
    blank_idx: index of the CTC blank symbol

    Returns the decoded string. A blank between two equal indices separates
    them, so both occurrences are kept.
    """
    sequence = list(pred_indices)
    # Pair each index with its predecessor; the first one has none.
    with_previous = zip([None, *sequence], sequence)
    kept = [
        current
        for previous, current in with_previous
        if current != blank_idx and current != previous
    ]
    return ''.join(idx2char[index] for index in kept)

def test_crnn(model, dataset, device, idx2char, batch_size=32, verbose=False):
    """Measure whole-sequence accuracy of a CRNN with greedy CTC decoding.

    The model is switched to evaluation mode and moved to ``device`` (as
    ``evaluate_single_letter_accuracy`` does), so it always sits on the same
    device as the batches it is given. A sample counts as correct only when
    the decoded string equals the ground-truth string exactly.

    Args:
        model: Network returning per-step class scores; the output is treated
            as batch-first when its first dimension equals the batch size and
            as time-first otherwise.
        dataset: Dataset yielding ``(image, label_indices, label_length)``.
        device: Device on which the model and each batch are placed.
        idx2char: Mapping from a class index to its character.
        batch_size: Number of samples per evaluation batch.
        verbose: When true, print every prediction next to its ground truth,
            coloured green (match) or red (mismatch) with ANSI escape codes.

    Returns:
        Fraction of exactly matching samples, or 0 for an empty dataset.
    """
    from torch.utils.data import DataLoader

    GREEN = "\033[92m"
    RED = "\033[91m"
    RESET = "\033[0m"

    model.eval()
    model.to(device)  # inputs are moved to `device`, so the model must be too
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, labels, label_lens in loader:
            imgs = imgs.to(device)
            logits = model(imgs)

            # Bring the scores into (T, B, C) order when they are batch-first.
            if logits.shape[0] == imgs.size(0):
                logits = logits.permute(1, 0, 2)

            # The arg-max is taken on the scores directly: normalising them
            # with log_softmax first would not change which class is largest.
            pred_indices = logits.argmax(dim=2).cpu().numpy().transpose(1, 0)  # (B, T)
            labels = labels.cpu().numpy()
            label_lens = label_lens.cpu().numpy()

            # Labels arrive concatenated; slice them apart using the lengths.
            label_ptr = 0
            for sample_pred, length in zip(pred_indices, label_lens):
                true_label_idx = labels[label_ptr:label_ptr + length]
                label_ptr += length

                pred_text = decode_prediction(sample_pred.tolist(), idx2char)
                true_text = ''.join(idx2char[idx] for idx in true_label_idx)

                matched = pred_text == true_text
                if matched:
                    correct += 1

                if verbose:
                    color = GREEN if matched else RED
                    print(f"{color}Predict: {pred_text} | GT: {true_text}{RESET}")

                total += 1

    acc = correct / total if total > 0 else 0
    print(f"\n✅ Test Accuracy: {acc*100:.2f}% ({correct}/{total})")
    return acc


In [None]:
from IPython.display import clear_output
# Extract the single-character training images under /content, then clear
# the cell output so the unzip listing is not kept in the notebook.
!unzip /content/single_char_images.zip -d /content/single_char_images
clear_output()

In [None]:
import torchvision.transforms as T

# Preprocessing for single-character images: resize to 30 x 100
# (height x width) and convert to a tensor.
transform = T.Compose([
    T.Resize((30, 100)),
    T.ToTensor()
])

from torch.utils.data import DataLoader
# Training set for the single-letter pretraining stage.
# NOTE(review): the path is relative, so it points at the unzip target
# /content/single_char_images only when the working directory is /content.
single_dataset = SingleLetterDataset("single_char_images", transform)


In [None]:
from google.colab import files

# Stage 1: pretrain the single-letter classifier with one output per entry
# of char2idx (the blank symbol included).
cnn = SimpleCNN(n_classes=len(char2idx))
train_single_letter(cnn, single_dataset, device="cuda", epochs=30)

# Save the weights under a file name that does not exist yet, then pass the
# file to Colab's download helper.
cnn_filename = get_unique_filename("cnn_pretrained.pth")
torch.save(cnn.state_dict(), cnn_filename)
print(f"✅ Model saved as: {cnn_filename}")
files.download(cnn_filename)


In [None]:
from IPython.display import clear_output
# Extract the single-character test images under /content, then clear the
# cell output so the unzip listing is not kept in the notebook.
!unzip /content/single_char_test_images.zip -d /content/single_char_test_images
clear_output()

In [None]:
# Same preprocessing as for the single-letter training set: 30 x 100
# (height x width), converted to a tensor.
transform = T.Compose([
    T.Resize((30, 100)),
    T.ToTensor()
])

# Test set for the single-letter classifier (relative path, see unzip target).
test_single_dataset = SingleLetterDataset("single_char_test_images", transform)


In [None]:
from IPython.display import clear_output
# Rebuild the classifier and restore the weights stored in cnn_filename,
# mapping the saved tensors onto the GPU.
cnn = SimpleCNN(n_classes=len(char2idx))
cnn.load_state_dict(torch.load(cnn_filename, map_location="cuda"))
cnn.to("cuda")
cnn.eval()
# Clear whatever output this cell produced.
clear_output()

In [None]:
# Single-letter accuracy on the training set first, then on the test set;
# each call prints its own result.
evaluate_single_letter_accuracy(cnn, single_dataset, device="cuda")
evaluate_single_letter_accuracy(cnn, test_single_dataset, device="cuda")


In [None]:
from IPython.display import clear_output
# Extract the CAPTCHA training images under /content, then clear the cell
# output so the unzip listing is not kept in the notebook.
!unzip /content/train_captchas.zip -d /content/train_captchas
clear_output()

In [None]:
# Preprocessing for whole CAPTCHAs: resize to 30 x 120 (height x width) and
# convert to a tensor. The height matches the img_h=30 passed to CRNN.
transform = T.Compose([
    T.Resize((30, 120)),  # Size of 4-letter CAPTCHA images
    T.ToTensor()
])

# Training set for the CRNN stage; labels come from the image file names.
captcha_dataset = CaptchaDataset("train_captchas", transform)


In [None]:
# Load CNN weights into CRNN
# A fresh SimpleCNN receives the weights saved by the pretraining stage.
cnn = SimpleCNN(n_classes=len(char2idx))
cnn.load_state_dict(torch.load(cnn_filename))

# img_h=30 matches the height produced by the Resize transforms;
# n_channels=1 matches the single input channel of SimpleCNN's first conv.
crnn = CRNN(img_h=30, n_channels=1, n_classes=NUM_CLASSES)
load_cnn_weights_into_crnn(crnn, cnn)


In [None]:
from google.colab import files

# Stage 2: train the CRNN with CTC loss; the CNN stays frozen for the first
# 5 of the 40 epochs.
train_crnn(crnn, captcha_dataset, device="cuda", epochs=40, freeze_cnn_epochs=5)

# Save the weights under a file name that does not exist yet, then pass the
# file to Colab's download helper.
crnn_filename = get_unique_filename("crnn.pth")
torch.save(crnn.state_dict(), crnn_filename)
print(f"✅ Model saved as: {crnn_filename}")
files.download(crnn_filename)


In [None]:
from IPython.display import clear_output
# Extract the CAPTCHA test images under /content, then clear the cell output
# so the unzip listing is not kept in the notebook.
!unzip /content/test_captchas.zip -d /content/test_captchas
clear_output()

In [None]:
# Same preprocessing as for the CAPTCHA training set: 30 x 120
# (height x width), converted to a tensor.
transform = T.Compose([
    T.Resize((30, 120)),
    T.ToTensor()
])

# Test set for the CRNN; labels come from the image file names.
test_captcha_dataset = CaptchaDataset("test_captchas", transform)

In [None]:
from IPython.display import clear_output
# Rebuild the CRNN and restore the weights stored in crnn_filename, mapping
# the saved tensors onto the GPU.
crnn = CRNN(img_h=30, n_channels=1, n_classes=NUM_CLASSES)
crnn.load_state_dict(torch.load(crnn_filename, map_location="cuda"))
crnn.to("cuda")
crnn.eval()
# Clear whatever output this cell produced.
clear_output()

In [None]:
# Accuracy on the training CAPTCHAs (summary line only). It is stored under
# its own name so the test result below does not overwrite it.
train_accuracy = test_crnn(crnn, captcha_dataset, device="cuda", idx2char=idx2char)
# Accuracy on the test CAPTCHAs, printing every prediction next to its label.
test_accuracy = test_crnn(crnn, test_captcha_dataset, device="cuda", idx2char=idx2char, verbose=True)
