In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F     
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
    

In [62]:
class CRNN(nn.Module):
    """Small CRNN (CNN + bidirectional LSTM) producing per-column class scores for CTC.

    Three conv -> ReLU -> 2x2 max-pool stages turn a (batch, input_channels, H, W)
    image into 32 feature maps of size (H // 8, W // 8). Each of the W // 8 columns
    is flattened into one feature vector of 32 * (H // 8) values, giving a
    left-to-right sequence that a 2-layer bidirectional LSTM reads. A linear layer
    then maps every time step to ``num_classes`` scores.

    Args:
        input_channels: number of channels of the input images.
        num_classes: output classes per time step (in this notebook: vocabulary
            size + 1 for the CTC blank).
        img_height: input image height. It fixes the LSTM input size, so every input
            passed to ``forward`` must have exactly this height.
        img_width: not used to build any layer (the output sequence length simply
            follows the input width: W // 8). Kept for interface compatibility.
    """

    def __init__(self, input_channels, num_classes, img_height, img_width):
        super().__init__()

        # Convolutional feature extractor (spatial features).
        self.conv1 = nn.Conv2d(input_channels, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Three 2x2 poolings divide the height by 8 (floor), so each column feature
        # vector has 32 * (img_height // 8) entries.
        cnn_output_height = img_height // 8
        rnn_input_size = 32 * cnn_output_height

        self.bi_lstm = nn.LSTM(input_size=rnn_input_size, hidden_size=128, num_layers=2, bidirectional=True)

        # Final classification layer (transcription layer); 2 * hidden for both directions.
        self.fc = nn.Linear(128 * 2, num_classes)

    def forward(self, x):
        """Map images of shape (batch, input_channels, img_height, W) to raw scores.

        Returns:
            Tensor of shape (W // 8, batch, num_classes). These are unnormalised
            scores; apply ``log_softmax(2)`` before feeding them to ``nn.CTCLoss``.

        Raises:
            ValueError: if the input height differs from the ``img_height`` the model
                was built with (the LSTM input size would not match).
        """
        in_height = x.size(2)
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))

        # As in the CRNN paper, the feature sequence is read from left to right:
        # the i-th feature vector is the concatenation of the i-th column of all maps.
        b, c, h, w = x.size()
        if c * h != self.bi_lstm.input_size:
            raise ValueError(
                f"input height {in_height} gives column features of size {c * h}, "
                f"but the LSTM expects {self.bi_lstm.input_size}; resize inputs to the "
                f"img_height the model was built with"
            )
        # (batch, channels * height, width): every column becomes one feature vector.
        x = x.view(b, c * h, w)
        # (seq_len = width, batch, features): the LSTM's default (not batch_first) layout.
        x = x.permute(2, 0, 1)

        x, _ = self.bi_lstm(x)
        return self.fc(x)

In [63]:
# Prefer the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [64]:
# Display which device was selected in the previous cell.
device

device(type='cuda')

In [65]:
# Character vocabulary. Index 0 is reserved for the CTC blank (CTCLoss uses blank=0),
# so the characters are numbered from 1 in vocab order.
vocab = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
idx_to_char = dict(enumerate(vocab, start=1))
char_to_idx = {char: index for index, char in idx_to_char.items()}

 

In [66]:
# Model / input configuration shared by the CRNN and the preprocessing transform.
input_channels = 3  # channels fed to the first conv layer (transform normalises 3 channels)
num_classes = len(vocab) + 1  # one class per vocabulary character, plus the CTC blank
# NOTE(review): torchvision's Resize takes (height, width), so these values produce images
# 200 px tall and 80 px wide. The CRNN emits one time step per 8 px of width, i.e.
# 80 // 8 = 10 steps. That caps the decodable label at 10 characters, and fewer when
# letters repeat, since CTC needs a blank between repeats. Captcha images are usually
# wider than tall -- confirm the two values were not meant the other way round. Note that
# img_height also fixes the LSTM input size, so changing it breaks loading
# smolCRNN_model.pth.
img_height, img_width = 200, 80

In [67]:
# Fresh, randomly initialised CRNN placed on the selected device
# (arguments in signature order: input_channels, num_classes, img_height, img_width).
model = CRNN(input_channels, num_classes, img_height, img_width).to(device)

In [68]:
class CustomImageDataset(Dataset):
    """Image/label dataset indexed by a CSV file.

    Column 0 of the CSV holds image file names (joined onto ``img_dir``) and
    column 1 holds the text label. Each item is ``(image, label_indices)`` where
    ``label_indices`` is a 1-D long tensor of ``char_to_idx`` values.

    Args:
        csv_file: path of the CSV index (read with pandas; first row is the header).
        img_dir: directory the image file names are relative to.
        transforms: optional callable applied to the PIL image.
        char_to_idx: mapping from label character to class index; required
            before items can be fetched.
    """

    def __init__(self, csv_file, img_dir, transforms = None, char_to_idx = None):
        # Keep every cell as the literal CSV text. By default pandas would turn labels
        # such as "null", "None", "NA" or "nan" into NaN (crashing __getitem__) and an
        # all-digit label column into integers (dropping leading zeros).
        self.labels_df = pd.read_csv(csv_file, dtype=str, keep_default_na=False)
        self.img_dir = img_dir
        self.transforms = transforms
        self.char_to_idx = char_to_idx

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        img_name = self.labels_df.iloc[idx, 0]
        img_path = os.path.join(self.img_dir, img_name)

        # Image.open is lazy and keeps the file open; load inside a context manager so
        # the handle is released immediately. convert("RGB") guarantees 3 channels,
        # matching the 3-channel Normalize and the model's input_channels.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        label = self.labels_df.iloc[idx, 1]
        try:
            y_label = [self.char_to_idx[c] for c in label]
        except KeyError as err:
            raise KeyError(
                f"label {label!r} of {img_path} contains a character outside the vocabulary: {err}"
            ) from None

        if self.transforms:
            image = self.transforms(image)

        # CTCLoss needs integer targets; torch.tensor([]) would default to float32.
        return image, torch.tensor(y_label, dtype=torch.long)

In [69]:
# Preprocessing: resize to the model's configured size, convert to a [0, 1] tensor,
# then normalise each of the 3 channels to [-1, 1].
# Resize takes (height, width). Reusing img_height / img_width keeps the image size in
# sync with the CRNN, whose LSTM input size is derived from img_height.
transform = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


In [87]:
# Build the easy/hard splits of the captcha_images data and merge each pair.
# NOTE(review): by execution count this cell (In [87]) ran *after* the wordlist cell
# further down (In [84]). The recorded accuracy numbers were therefore computed on these
# captcha_images datasets. On a fresh top-to-bottom run, the wordlist cell replaces them.
train_dataset_easy, train_dataset_hard = [
    CustomImageDataset(
        csv_file=labels_csv,
        img_dir=image_folder,
        transforms=transform,
        char_to_idx=char_to_idx,
    )
    for labels_csv, image_folder in (
        ('captcha_images/train/labels_easy.csv', 'captcha_images/train/easy'),
        ('captcha_images/train/labels_hard.csv', 'captcha_images/train/hard'),
    )
]
train_dataset = ConcatDataset([train_dataset_easy, train_dataset_hard])

test_dataset_easy, test_dataset_hard = [
    CustomImageDataset(
        csv_file=labels_csv,
        img_dir=image_folder,
        transforms=transform,
        char_to_idx=char_to_idx,
    )
    for labels_csv, image_folder in (
        ('captcha_images/test/labels_easy.csv', 'captcha_images/test/easy'),
        ('captcha_images/test/labels_hard.csv', 'captcha_images/test/hard'),
    )
]
test_dataset = ConcatDataset([test_dataset_easy, test_dataset_hard])


In [71]:
def ctc_collate_fn(batch):
    """Collate (image, label) samples into the three tensors CTC training needs.

    Args:
        batch: sequence of ``(image_tensor, label_tensor)`` pairs.

    Returns:
        images: the images stacked along a new leading batch dimension.
        targets: every label concatenated into one 1-D tensor.
        target_lengths: long tensor with each label's length, in batch order, so
            ``targets`` can be split back into per-sample labels.
    """
    image_seq, label_seq = zip(*batch)
    stacked_images = torch.stack(image_seq)
    flat_targets = torch.cat(label_seq)
    lengths = torch.tensor(list(map(len, label_seq)), dtype=torch.long)
    return stacked_images, flat_targets, lengths

In [88]:

# Only the training data is shuffled. Both loaders use ctc_collate_fn, so each batch is
# (images, flat targets, target lengths) as expected by CTCLoss.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    collate_fn=ctc_collate_fn,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    collate_fn=ctc_collate_fn,
)

In [73]:
# Adam over every CRNN parameter, learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [74]:
# CTC loss handles variable-length targets; CrossEntropyLoss (tried earlier) cannot.
# blank=0 matches the index reserved for the blank in char_to_idx.
# zero_infinity=True zeroes infinite losses (e.g. a label too long to align with the
# model's output sequence) and their gradients, instead of letting them poison training.
criterion = nn.CTCLoss(zero_infinity=True, blank=0)

In [75]:
# Number of passes over the training set when training.
num_epochs = 15

In [None]:
# Training is opt-in: leave RUN_TRAINING = False so "Restart & Run All" skips the expensive
# loop and the evaluation cells use the checkpoint loaded from smolCRNN_model.pth.
# Set it to True to retrain `model` from scratch and overwrite that checkpoint.
RUN_TRAINING = False


def train_crnn(model, loader, optimizer, criterion, num_epochs, device):
    """Train a CTC model and return the mean loss of each epoch.

    Args:
        model: module mapping an image batch to (seq_len, batch, num_classes) scores.
        loader: iterable of ``(images, targets, target_lengths)`` batches, as
            produced by ``ctc_collate_fn``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: CTC-style loss called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``.
        num_epochs: number of passes over ``loader``.
        device: device the batches are moved to.

    Returns:
        list of per-epoch average losses (one float per epoch).
    """
    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        for data, targets, target_lengths in loader:
            data = data.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            # CTCLoss expects log-probabilities over the class dimension.
            log_probs = model(data).log_softmax(2)

            # Every sample uses the full output sequence.
            seq_len, batch_size, _ = log_probs.size()
            input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)

            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / max(num_batches, 1)
        epoch_losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
    return epoch_losses


if RUN_TRAINING:
    train_crnn(model, train_dataloader, optimizer, criterion, num_epochs, device)
    torch.save(model.state_dict(), 'smolCRNN_model.pth')
    print("Model saved as smolCRNN_model.pth")


Epoch [1/15], Loss: 1.2173
Epoch [2/15], Loss: 0.3358
Epoch [2/15], Loss: 0.3358
Epoch [3/15], Loss: 0.2108
Epoch [3/15], Loss: 0.2108
Epoch [4/15], Loss: 0.1523
Epoch [4/15], Loss: 0.1523
Epoch [5/15], Loss: 0.1189
Epoch [5/15], Loss: 0.1189
Epoch [6/15], Loss: 0.0956
Epoch [6/15], Loss: 0.0956
Epoch [7/15], Loss: 0.0809
Epoch [7/15], Loss: 0.0809
Epoch [8/15], Loss: 0.0698
Epoch [8/15], Loss: 0.0698
Epoch [9/15], Loss: 0.0621
Epoch [9/15], Loss: 0.0621
Epoch [10/15], Loss: 0.0571
Epoch [10/15], Loss: 0.0571
Epoch [11/15], Loss: 0.0529
Epoch [11/15], Loss: 0.0529
Epoch [12/15], Loss: 0.0516
Epoch [12/15], Loss: 0.0516
Epoch [13/15], Loss: 0.0487
Epoch [13/15], Loss: 0.0487
Epoch [14/15], Loss: 0.0448
Epoch [14/15], Loss: 0.0448
Epoch [15/15], Loss: 0.0442
Epoch [15/15], Loss: 0.0442


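- With 10 epochs: 1% accuracy.
- With 100 epochs: 100% accuracy.
- At 100 epochs the model is just memorizing the training data; it needs more samples.

- I then reduced the model from ~6M to ~500K parameters, and it now works fine. (Note: the parameter count printed at the end of this notebook for the current configuration is 1,369,807, i.e. ~1.4M, not ~500K.)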

In [None]:
# torch.save(model.state_dict(), 'smolCRNN_model.pth')
# print("Model saved as smolCRNN_model.pth")

Model saved as smolCRNN_model.pth


In [80]:
# Load the saved model for evaluation.
# map_location remaps the stored tensors onto the current device, so a checkpoint written
# on a GPU machine also loads on a CPU-only one.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
loaded_model = CRNN(input_channels=input_channels, num_classes=num_classes, img_height=img_height, img_width=img_width).to(device)
loaded_model.load_state_dict(torch.load('smolCRNN_model.pth', map_location=device))
loaded_model.eval()
print("Model loaded successfully for evaluation")

Model loaded successfully for evaluation


In [None]:
def ctc_decode(preds, idx_to_char):
    """Greedy (best-path) CTC decoding of per-time-step class predictions.

    Within each batch column, consecutive repeated indices are collapsed and the
    blank index 0 is dropped. The remaining indices are mapped through
    ``idx_to_char`` and joined.

    Args:
        preds: 2-D tensor of class indices shaped (seq_len, batch).
        idx_to_char: mapping from class index to character.

    Returns:
        list with one decoded string per batch column.
    """
    columns = preds.cpu().numpy()
    _, n_batch = columns.shape

    words = []
    for j in range(n_batch):
        steps = columns[:, j].tolist()
        # Pair every step with the one before it (None before the first step).
        previous_steps = [None] + steps[:-1]
        chars = [
            idx_to_char[current]
            for current, before in zip(steps, previous_steps)
            if current != before and current != 0
        ]
        words.append("".join(chars))
    return words


In [84]:
# Build the easy/hard splits of the wordlist_captcha_images data and merge each pair.
# NOTE(review): by execution count this cell (In [84]) ran *before* the captcha_images cell
# above (In [87]). The accuracy printed below was therefore not computed on these wordlist
# datasets. On a fresh top-to-bottom run, however, this cell overrides the captcha_images ones.
train_dataset_easy, train_dataset_hard = [
    CustomImageDataset(
        csv_file=labels_csv,
        img_dir=image_folder,
        transforms=transform,
        char_to_idx=char_to_idx,
    )
    for labels_csv, image_folder in (
        ('wordlist_captcha_images/train/labels_easy.csv', 'wordlist_captcha_images/train/easy'),
        ('wordlist_captcha_images/train/labels_hard.csv', 'wordlist_captcha_images/train/hard'),
    )
]
train_dataset = ConcatDataset([train_dataset_easy, train_dataset_hard])

test_dataset_easy, test_dataset_hard = [
    CustomImageDataset(
        csv_file=labels_csv,
        img_dir=image_folder,
        transforms=transform,
        char_to_idx=char_to_idx,
    )
    for labels_csv, image_folder in (
        ('wordlist_captcha_images/test/labels_easy.csv', 'wordlist_captcha_images/test/easy'),
        ('wordlist_captcha_images/test/labels_hard.csv', 'wordlist_captcha_images/test/hard'),
    )
]
test_dataset = ConcatDataset([test_dataset_easy, test_dataset_hard])


In [85]:

# Only the training data is shuffled. Both loaders use ctc_collate_fn, so each batch is
# (images, flat targets, target lengths) as expected by CTCLoss.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    collate_fn=ctc_collate_fn,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    collate_fn=ctc_collate_fn,
)

In [89]:
def check_crnn_accuracy(loader, model, idx_to_char, device=None):
    """Compute exact-match word accuracy of a CTC model over a data loader.

    Each batch's scores are greedily decoded with ``ctc_decode``. The ground-truth
    words are rebuilt from the flat ``targets`` tensor using ``target_lengths``.
    A prediction counts as correct only if it equals the whole true word. The
    model is switched to eval mode and left there.

    Args:
        loader: iterable of ``(images, targets, target_lengths)`` batches, as
            produced by ``ctc_collate_fn``.
        model: module returning (seq_len, batch, num_classes) scores.
        idx_to_char: mapping from class index to character.
        device: device to run inference on. Defaults to the device of the
            model's parameters (CPU if it has none).

    Returns:
        accuracy as a fraction in [0, 1]; 0.0 when the loader yields no samples.
    """
    if device is None:
        # Follow the model's weights instead of relying on a notebook-global `device`.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for images, targets, target_lengths in loader:
            scores = model(images.to(device))
            preds = scores.argmax(2)
            decoded_words = ctc_decode(preds, idx_to_char)

            # Rebuild the true words by slicing the concatenated targets with plain ints.
            flat_targets = targets.cpu().tolist()
            true_words = []
            ptr = 0
            for length in (int(n) for n in target_lengths):
                word_indices = flat_targets[ptr:ptr + length]
                true_words.append("".join(idx_to_char[i] for i in word_indices))
                ptr += length

            for pred_word, true_word in zip(decoded_words, true_words):
                if pred_word == true_word:
                    num_correct += 1
                num_samples += 1

    # Guard against an empty loader instead of raising ZeroDivisionError.
    acc = num_correct / num_samples if num_samples else 0.0
    print(f'Accuracy: {acc*100:.2f}% ({num_correct}/{num_samples})')
    return acc

# Evaluate the weights restored from smolCRNN_model.pth. `model` only holds trained weights
# if the training cell was executed in this kernel session. Otherwise it is still randomly
# initialised, and its accuracy would be meaningless.
check_crnn_accuracy(train_dataloader, loaded_model, idx_to_char)
check_crnn_accuracy(test_dataloader, loaded_model, idx_to_char)


Accuracy: 93.19% (46595/50000)
Accuracy: 84.04% (12606/15000)
Accuracy: 84.04% (12606/15000)


0.8404

In [90]:
# Total number of elements across all parameter tensors (trainable or not).
total_params = sum(map(torch.numel, model.parameters()))
total_params

1369807