In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import log_softmax
import numpy as np

In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


Define Model

In [None]:
class CRNN(nn.Module):
    def __init__(self, num_classes):
        super(CRNN, self).__init__()

        # Convolutional layers
        self.conv_1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)  # (16, 64)

        self.conv_2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)  # (8, 32)

        self.conv_3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool_4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))  # (4, 32)

        self.conv_5 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.batch_norm_5 = nn.BatchNorm2d(512)

        self.conv_6 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.batch_norm_6 = nn.BatchNorm2d(512)
        self.pool_6 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))  # (2, 32)

        # LSTM layers
        self.lstm_1 = nn.LSTM(input_size=1024, hidden_size=128, num_layers=1, bidirectional=True, batch_first=True)
        self.lstm_2 = nn.LSTM(input_size=256, hidden_size=128, num_layers=1, bidirectional=True, batch_first=True)

        # Final output layer (dense layer)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = F.relu(self.conv_1(x))
        x = self.pool_1(x)

        x = F.relu(self.conv_2(x))
        x = self.pool_2(x)

        x = F.relu(self.conv_3(x))
        x = F.relu(self.conv_4(x))
        x = self.pool_4(x)

        x = F.relu(self.conv_5(x))
        x = self.batch_norm_5(x)

        x = F.relu(self.conv_6(x))
        x = self.batch_norm_6(x)
        x = self.pool_6(x)

        # Reshape for LSTM input
        batch_size, channels, height, width = x.size()

        # Permute and reshape to match LSTM input requirements
        x = x.permute(0, 3, 1, 2)  
        x = x.contiguous().view(batch_size, width, height * channels) 

        if x.size(-1) != 1024:
            raise ValueError(f"Expected input size of 1024, but got {x.size(-1)}")

        x, _ = self.lstm_1(x)
        x, _ = self.lstm_2(x)
        x = self.fc(x)

        return x

CTC Loss and Dataloader

In [None]:
# Define the loss function (CTC Loss)
def ctc_loss(pred, target, input_lengths, target_lengths):
    # CTC Loss expects input in shape (seq_len, batch, num_classes)
    return nn.CTCLoss(blank=0, zero_infinity=True)(pred, target, input_lengths, target_lengths)

# Define the Dataset class to load your PNG images and their corresponding text
class OCRDataset(Dataset):
    def __init__(self, image_folder, char_list, transform=None):
        self.image_folder = image_folder
        self.char_list = char_list
        self.transform = transform
        self.image_paths = [os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.endswith('.png')]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert('L')  # Convert to grayscale

        # Convert the text label (image name without extension)
        label = os.path.basename(img_path).split('.')[0]

        # Create target tensor for CTC loss (mapping characters to indices)
        target = torch.tensor([self.char_list.index(c) for c in label], dtype=torch.int)

        if self.transform:
            img = self.transform(img)

        img_width, img_height = img.shape[1], img.shape[2]
        return img, target, torch.tensor([img_width // 4], dtype=torch.int), torch.tensor([len(label)], dtype=torch.int)

In [None]:
def collate_fn(batch):
    images = []
    targets = []
    input_lengths = []
    target_lengths = []

    for img, target, input_len, target_len in batch:
        images.append(img)
        targets.append(target)
        input_lengths.append(input_len)
        target_lengths.append(target_len)

    targets = pad_sequence(targets, batch_first=True, padding_value=0)  

    images = torch.stack(images, 0)

    input_lengths = torch.tensor(input_lengths, dtype=torch.int)
    target_lengths = torch.tensor(target_lengths, dtype=torch.int)

    return images, targets, input_lengths, target_lengths

Initialisation

In [None]:
# Parameters
image_folder = '/content/dataset_1/dataset_1'  
char_list = [''] + list('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')  
num_classes = len(char_list) 

# Data Transformations
transform = transforms.Compose([
    transforms.Resize((32, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Create Dataset and DataLoader
dataset = OCRDataset(image_folder, char_list, transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=20, shuffle=False, collate_fn=collate_fn)

model = CRNN(num_classes).to(device)  
optimizer = Adam(model.parameters(), lr=0.001)

Greedy Decoding

In [None]:
def decode_predictions(output, char_list, blank_index=0):  
    decoded_preds = []
    for i in range(output.size(1)): 
        pred_indices = torch.argmax(output[:, i, :], dim=1).tolist()  
        pred_text = []
        previous_char = None
        for index in pred_indices:
            if index != previous_char and index != blank_index:  
                pred_text.append(char_list[index])
            previous_char = index
        decoded_preds.append(''.join(pred_text))
    return decoded_preds

Training Loop

In [None]:
for epoch in range(30):  
    model.train()
    train_loss = 0
    for images, targets, _, target_lengths in train_loader:
        images, targets, target_lengths = images.to(device), targets.to(device), target_lengths.to(device)  # Move to device

        optimizer.zero_grad()
        output = model(images)
        batch_size, seq_len, num_classes = output.size()
        input_lengths = torch.full((batch_size,), seq_len, dtype=torch.int32).to(device)

        # Compute CTC loss
        loss = ctc_loss(output.permute(1, 0, 2), targets, input_lengths, target_lengths)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")

    # Validation Loop
    model.eval()
    val_loss = 0
    total_characters = 0
    correct_characters = 0
    print_count = 0 

    with torch.no_grad():
        for images, targets, _, target_lengths in val_loader:
            images, targets, target_lengths = images.to(device), targets.to(device), target_lengths.to(device)  # Move to device

            output = model(images)
            batch_size, seq_len, num_classes = output.size()
            input_lengths = torch.full((batch_size,), seq_len, dtype=torch.int32).to(device)

            # Compute CTC loss
            loss = ctc_loss(output.permute(1, 0, 2), targets, input_lengths, target_lengths)
            val_loss += loss.item()

            # Use decode_predictions function
            decoded_preds = decode_predictions(output.permute(1, 0, 2), char_list, blank_index=num_classes-1)
            for pred_text, target in zip(decoded_preds, targets):
                target_text = ''.join([char_list[c.item()] for c in target if c.item() != 0])  # Remove padding

                if print_count < 3:  # Print only first 3 predictions
                    print(f"Target: {target_text}, Predicted: {pred_text}")
                    print_count += 1

                # Calculate character-level accuracy
                correct_characters += sum(pc == tc for pc, tc in zip(pred_text, target_text))
                total_characters += len(target_text)

    avg_val_loss = val_loss / len(val_loader)
    character_accuracy = correct_characters / total_characters if total_characters > 0 else 0
    print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}, Character Accuracy: {character_accuracy:.4f}")

Epoch 1, Training Loss: 4.2398
Target: oLem, Predicted: 
Target: YhrSejfZ, Predicted: 
Target: qSHMvZ, Predicted: 
Epoch 1, Validation Loss: 4.6754, Character Accuracy: 0.0000
Epoch 2, Training Loss: 4.0724
Target: oLem, Predicted: lc
Target: YhrSejfZ, Predicted: NcgG
Target: qSHMvZ, Predicted: dG
Epoch 2, Validation Loss: 3.2918, Character Accuracy: 0.0383
Epoch 3, Training Loss: 1.3585
Target: oLem, Predicted: oLem
Target: YhrSejfZ, Predicted: YhrSejf
Target: qSHMvZ, Predicted: dSHMv
Epoch 3, Validation Loss: 0.4771, Character Accuracy: 0.8476
Epoch 4, Training Loss: 0.1217
Target: oLem, Predicted: oLem
Target: YhrSejfZ, Predicted: YhrSejf
Target: qSHMvZ, Predicted: qSHMv
Epoch 4, Validation Loss: 0.0295, Character Accuracy: 0.9234
Epoch 5, Training Loss: 0.0522
Target: oLem, Predicted: oLem
Target: YhrSejfZ, Predicted: YhrSejf
Target: qSHMvZ, Predicted: qSHMv
Epoch 5, Validation Loss: 0.1194, Character Accuracy: 0.8907
Epoch 6, Training Loss: 0.0367
Target: oLem, Predicted: oLem
Tar

Save Model

In [None]:
torch.save(model.state_dict(), 'task-2_data_1.pth')
print("Model saved after all epochs.")

Model saved after all epochs.
