In [1]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
import cv2
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import numpy as np

In [2]:
def preprocess(image_path):
    """Load a captcha image, whiten pure-black pixels and crop it to its dark content.

    Args:
        image_path: Path of the image file, read with ``cv2.imread`` (BGR).

    Returns:
        The BGR image, with exactly-black pixels replaced by white, cropped to
        the bounding box of every pixel the adaptive threshold marks as dark.
        If no pixel is marked dark, the whole (uncropped) image is returned.

    Raises:
        OSError: If OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path!r}")

    # Turn exactly-black pixels white before thresholding
    # (presumably black is padding/background in this data set -- confirm).
    black_pixels = np.all(img == [0, 0, 0], axis=-1)
    img[black_pixels] = [255, 255, 255]

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    binary = cv2.adaptiveThreshold(
        gray, 255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY,
        blockSize=15,
        C=3
    )

    # THRESH_BINARY leaves dark pixels at 0; invert so they become the
    # non-zero points whose bounding box we crop to.
    dark_points = cv2.findNonZero(cv2.bitwise_not(binary))
    if dark_points is None:
        # Nothing was classified as dark, so there is no box to crop to.
        return img

    x, y, w, h = cv2.boundingRect(dark_points)
    return img[y:y + h, x:x + w]

In [3]:
class CaptchaDataset(Dataset):
    """Dataset of captcha images whose label is encoded in the file name.

    The label of a file is the part of its name before the first ``-``.
    Characters are mapped to indices starting at 1; index 0 is reserved for
    the CTC blank.

    Args:
        folder_path: Directory containing the image files.
        vocab: Iterable of characters; position ``i`` maps to index ``i + 1``.
    """

    def __init__(self, folder_path, vocab):
        self.folder_path = folder_path
        # os.listdir returns entries in arbitrary order and also lists hidden
        # entries and sub-directories (e.g. ``.ipynb_checkpoints``), which
        # cannot be parsed as samples. Keep regular, non-hidden files only and
        # sort them so that sample indices are reproducible across runs.
        self.image_files = sorted(
            name
            for name in os.listdir(folder_path)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(folder_path, name))
        )
        self.vocab = vocab
        self.char_to_idx = {c: i + 1 for i, c in enumerate(vocab)}  # 0 is CTC blank
        self.idx_to_char = {i + 1: c for i, c in enumerate(vocab)}

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label, label_length)`` for sample ``idx``.

        ``image`` is a float32 tensor of shape (1, 50, 200) with values in
        [0, 1]; ``label`` is a 1-D long tensor of vocabulary indices.

        Raises:
            KeyError: If the label contains a character that is not in ``vocab``.
        """
        filename = self.image_files[idx]
        label_str = filename.split("-")[0]
        label = torch.tensor([self.char_to_idx[c] for c in label_str], dtype=torch.long)

        image_path = os.path.join(self.folder_path, filename)
        image = preprocess(image_path)

        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Resize to model input size
        image = cv2.resize(image, (200, 50))  # (W, H)

        # Scale pixel values to [0, 1]
        image = image.astype("float32")
        image /= 255.0

        # Convert to tensor with shape (1, H, W)
        image = torch.from_numpy(image).unsqueeze(0)  # shape: (1, 50, 200)

        return image, label, len(label)


# Basic model from scratch

In [4]:
class CNNLSTMCTC(nn.Module):
    """Two-stage CNN followed by a 2-layer bidirectional LSTM and a linear head.

    The output is one logit vector per image column of the CNN feature map,
    laid out as (T, B, vocab_size + 1) so it can be fed to a CTC loss.
    The LSTM input size (64 * 12) matches a 50-pixel-high input image.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Shape comments assume a (1, 50, 200) input image.
        conv_stages = [
            nn.Conv2d(1, 32, 3, padding=1),   # (32, 50, 200)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               # (32, 25, 100)
            nn.Conv2d(32, 64, 3, padding=1),  # (64, 25, 100)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               # (64, 12, 50)
        ]
        self.cnn = nn.Sequential(*conv_stages)
        self.lstm = nn.LSTM(
            input_size=64 * 12,
            hidden_size=128,
            num_layers=2,
            bidirectional=True,
        )
        # Both LSTM directions are concatenated; the extra class is the blank.
        self.fc = nn.Linear(128 * 2, vocab_size + 1)

    def forward(self, x):
        """Map images (B, 1, H, W) to logits of shape (T, B, vocab_size + 1)."""
        features = self.cnn(x)  # (B, C, H, W)
        batch, channels, height, width = features.size()
        # Width becomes the time axis; channels and height merge into features.
        sequence = features.permute(3, 0, 1, 2).reshape(width, batch, channels * height)
        hidden, _ = self.lstm(sequence)  # (T, B, 256)
        return self.fc(hidden)           # (T, B, vocab+1)


In [5]:
def collate_fn(batch):
    """Merge ``(image, label, label_length)`` samples into batch tensors.

    Returns a tuple of:
      * images stacked along a new leading batch dimension,
      * all labels concatenated into a single 1-D tensor,
      * the per-sample label lengths as a long tensor.
    """
    image_seq, label_seq, length_seq = zip(*batch)
    return (
        torch.stack(image_seq),
        torch.cat(label_seq),
        torch.tensor(length_seq, dtype=torch.long),
    )


def train(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one epoch and report the mean batch loss.

    Args:
        model: Module returning logits of shape (T, B, C) for a batch of images.
        dataloader: Iterable yielding ``(images, labels, label_lengths)``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as
            ``criterion(log_probs, labels, input_lengths, label_lengths)``.
        device: Device the batch tensors are moved to.

    Returns:
        The mean loss over all batches as a float, or ``nan`` when the
        dataloader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for images, labels, label_lengths in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)

        logits = model(images)  # (T, B, C)
        # Every sample uses the full time dimension of the model output.
        input_lengths = torch.full(
            (images.size(0),),
            logits.size(0),
            dtype=torch.long,
            device=device,
        )

        log_probs = F.log_softmax(logits, dim=2)
        loss = criterion(log_probs, labels, input_lengths, label_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Avoid a ZeroDivisionError when the dataloader is empty.
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Average Epoch Loss: {avg_loss:.4f}")
    return avg_loss


In [6]:
def greedy_decoder(logits, idx_to_char):
    """Best-path CTC decoding of logits shaped (T, B, C).

    For every sample the most likely class is taken at each time step,
    consecutive repeats are collapsed, and index 0 (the blank) is dropped.
    Returns one decoded string per sample.
    """
    best_path = logits.argmax(dim=2).transpose(0, 1)  # (B, T)

    decoded = []
    for row in best_path.tolist():
        chars = []
        last = -1
        for idx in row:
            is_blank = idx == 0
            is_repeat = idx == last
            if not (is_blank or is_repeat):
                chars.append(idx_to_char[idx])
            last = idx
        decoded.append("".join(chars))
    return decoded


In [7]:
# Characters a captcha label may contain: lowercase letters followed by digits.
vocab = "abcdefghijklmnopqrstuvwxyz" + "0123456789"

In [8]:
# Training data: shuffled batches of 16 captchas from the "train" folder
dataset = CaptchaDataset("train", vocab)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)

# Model, optimizer and CTC loss (class index 0 is the blank)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNNLSTMCTC(vocab_size=len(vocab))
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

# Train for 10 epochs only when no saved weights exist; otherwise reuse them.
checkpoint_path = "image_input_scratch.pt"
if not os.path.exists(checkpoint_path):
    print("🚀 No checkpoint found, training from scratch")
    for epoch in range(10):
        print(f"\nEpoch {epoch+1}")
        train(model, dataloader, optimizer, criterion, device)
    torch.save(model.state_dict(), checkpoint_path)
else:
    saved_weights = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(saved_weights)
    print("✅ Loaded pretrained model from", checkpoint_path)


🚀 No checkpoint found, training from scratch

Epoch 1
Average Epoch Loss: 4.1389

Epoch 2
Average Epoch Loss: 3.9185

Epoch 3
Average Epoch Loss: 3.8611

Epoch 4
Average Epoch Loss: 3.0780

Epoch 5
Average Epoch Loss: 1.9143

Epoch 6
Average Epoch Loss: 1.3952

Epoch 7
Average Epoch Loss: 1.1443

Epoch 8
Average Epoch Loss: 0.9731

Epoch 9
Average Epoch Loss: 0.8385

Epoch 10
Average Epoch Loss: 0.7386


## Result
Word Accuracy: 28.00%
Character Accuracy: 70.67%

In [9]:
# Captchas from the "test" folder, served one image per batch without shuffling
test_dataset = CaptchaDataset("test", vocab)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

In [12]:
# Evaluate the model on the test set: exact-match word accuracy and
# position-wise character accuracy.
model.eval()

total_words = 0
correct_words = 0

total_chars = 0
correct_chars = 0

with torch.no_grad():
    for images, labels, label_lengths in test_loader:
        images = images.to(device)
        logits = model(images)  # (T, B, C)
        predictions = greedy_decoder(logits.cpu(), test_dataset.idx_to_char)

        # collate_fn concatenates the targets of the whole batch into one 1-D
        # tensor; split it back per sample so this also works for batch_size > 1.
        per_sample_labels = torch.split(labels, label_lengths.tolist())

        for target, pred_label in zip(per_sample_labels, predictions):
            true_label = "".join(test_dataset.idx_to_char[idx] for idx in target.tolist())

            total_words += 1
            if pred_label == true_label:
                correct_words += 1

            # Characters are compared position by position, so one inserted or
            # dropped character also counts every later character as wrong.
            char_matches = sum(1 for a, b in zip(true_label, pred_label) if a == b)
            correct_chars += char_matches
            total_chars += len(true_label)

            print(f"True: {true_label.ljust(10)} | Pred: {pred_label.ljust(10)} | Char Acc: {char_matches}/{len(true_label)}")

# Guard the ratios so an empty test set reports 0% instead of raising.
word_acc = correct_words / total_words * 100 if total_words else 0.0
char_acc = correct_chars / total_chars * 100 if total_chars else 0.0
print(f"Word Accuracy: {word_acc:.2f}%")
print(f"Character Accuracy: {char_acc:.2f}%")


True: 002e23     | Pred: 002e23     | Char Acc: 6/6
True: 03yl9s     | Pred: d3yi9s     | Char Acc: 4/6
True: 03yuav5    | Pred: 03yuqv5    | Char Acc: 6/7
True: 03zl9o     | Pred: 03zl90     | Char Acc: 5/6
True: 04zqohgi   | Pred: 04zqoh9i   | Char Acc: 7/8
True: 05htm      | Pred: 05htm      | Char Acc: 5/5
True: 05pb       | Pred: 03pb       | Char Acc: 3/4
True: 07oj       | Pred: 070j       | Char Acc: 3/4
True: 07z0       | Pred: 07z0       | Char Acc: 4/4
True: 08ft2e2z   | Pred: 08i12e22   | Char Acc: 5/8
True: 08nxd77    | Pred: 08nxd77    | Char Acc: 7/7
True: 08otejfi   | Pred: 080tejfz   | Char Acc: 6/8
True: 08sba      | Pred: 08sba      | Char Acc: 5/5
True: 09popm     | Pred: 09popm     | Char Acc: 6/6
True: 0a7sh2wp   | Pred: qc7jh2np   | Char Acc: 4/8
True: 0al4pl     | Pred: 0al49l     | Char Acc: 5/6
True: 0c3dp      | Pred: 0c3dp      | Char Acc: 5/5
True: 0c7mgdmd   | Pred: 007m9dmd   | Char Acc: 6/8
True: 0chnm34    | Pred: 048nm34    | Char Acc: 5/7
True: 0col7w