In [39]:
#FINAL CODE

import os
import scipy.io
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import string

# ================= Constants =================
CHARS = string.ascii_lowercase + string.digits
char_to_idx = {c: i for i, c in enumerate(CHARS)}
idx_to_char = {i: c for c, i in char_to_idx.items()}
NUM_CLASSES = len(CHARS) + 1  # +1 for CTC blank
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============= IIIT5K Dataset =============
class IIIT5KDataset(Dataset):
    def __init__(self, root_dir, mat_path, transform=None):
        self.root_dir = root_dir
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((32, 100)),
            transforms.RandomRotation(2),
            transforms.RandomAffine(degrees=0, translate=(0.02, 0.02)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        data = scipy.io.loadmat(mat_path)
        traindata = data['traindata'][0]  # Flatten the MATLAB struct array

        self.samples = []
        for item in traindata:
            img_name = str(item['ImgName'][0])
            word = str(item['GroundTruth'][0]).lower()
            self.samples.append((img_name, word))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, word = self.samples[idx]
        img_path = os.path.join(self.root_dir, img_name)
        image = Image.open(img_path).convert('L')
        image = self.transform(image)
        label = torch.tensor([char_to_idx[c] for c in word if c in char_to_idx], dtype=torch.long)
        return image, label


# ============= CRNN Model =============
class CRNN(nn.Module):
    def __init__(self):
        super(CRNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.MaxPool2d((2, 1)),
            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.MaxPool2d((2, 1))
        )

        self.rnn = nn.LSTM(256 * 2, 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, NUM_CLASSES)

    def forward(self, x):
        x = self.cnn(x)
        b, c, h, w = x.size()
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(b, w, c * h)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x.permute(1, 0, 2)


# ============= Utils =============
def collate_fn(batch):
    images, labels = zip(*batch)
    image_batch = torch.stack(images)
    label_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)
    labels_concat = torch.cat(labels)
    return image_batch, labels_concat, label_lengths

def decode(logits):
    probs = torch.softmax(logits, dim=2)
    _, preds = probs.max(2)
    preds = preds.permute(1, 0)

    results = []
    for pred in preds:
        string = ''
        last = -1
        for p in pred:
            p = p.item()
            if p != last and p != len(CHARS):
                string += idx_to_char[p]
            last = p
        results.append(string)
    return results

# ============= Quick Test =============
def test_on_training_data(model, dataset, num_samples=30):
    model.eval()
    print("\n\U0001f9ea Testing on training samples...")
    for i in range(num_samples):
        image, label = dataset[i]
        input_tensor = image.unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            output = model(input_tensor)
            pred_text = decode(output)[0]
        actual_text = ''.join([idx_to_char[c.item()] for c in label])
        print(f"[{i+1}] \U0001f524 Predicted: {pred_text} |  Actual: {actual_text}")

# ============= Training =============
def train_model():
    dataset = IIIT5KDataset("./IIIT5K", "./IIIT5K/trainData.mat")
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

    model = CRNN().to(DEVICE)
    criterion = nn.CTCLoss(blank=len(CHARS))
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(50):
        total_loss = 0
        for images, labels, label_lengths in dataloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            output_len = torch.full((images.size(0),), outputs.size(0), dtype=torch.long).to(DEVICE)

            optimizer.zero_grad()
            loss = criterion(outputs, labels, output_len, label_lengths)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}: Loss {total_loss / len(dataloader):.4f}")

    torch.save(model.state_dict(), "crnn_model.pth")
    print("\u2705 Model saved as crnn_model.pth")

    test_on_training_data(model, dataset)

if __name__ == "__main__":
    train_model()


Epoch 1: Loss 4.3052
Epoch 2: Loss 3.8556
Epoch 3: Loss 3.8045
Epoch 4: Loss 3.7925
Epoch 5: Loss 3.7946
Epoch 6: Loss 3.7785
Epoch 7: Loss 3.7317
Epoch 8: Loss 3.6892
Epoch 9: Loss 3.6200
Epoch 10: Loss 3.5861
Epoch 11: Loss 3.5424
Epoch 12: Loss 3.5257
Epoch 13: Loss 3.4991
Epoch 14: Loss 3.4667
Epoch 15: Loss 3.4360
Epoch 16: Loss 3.4060
Epoch 17: Loss 3.3537
Epoch 18: Loss 3.3169
Epoch 19: Loss 3.2296
Epoch 20: Loss 3.1431
Epoch 21: Loss 3.0465
Epoch 22: Loss 2.9276
Epoch 23: Loss 2.7285
Epoch 24: Loss 2.5736
Epoch 25: Loss 2.3573
Epoch 26: Loss 2.1596
Epoch 27: Loss 1.8814
Epoch 28: Loss 1.6412
Epoch 29: Loss 1.4480
Epoch 30: Loss 1.2340
Epoch 31: Loss 1.1275
Epoch 32: Loss 0.9939
Epoch 33: Loss 0.8926
Epoch 34: Loss 0.8058
Epoch 35: Loss 0.7154
Epoch 36: Loss 0.6423
Epoch 37: Loss 0.6001
Epoch 38: Loss 0.5566
Epoch 39: Loss 0.5202
Epoch 40: Loss 0.4592
Epoch 41: Loss 0.4388
Epoch 42: Loss 0.3984
Epoch 43: Loss 0.3900
Epoch 44: Loss 0.3431
Epoch 45: Loss 0.3282
Epoch 46: Loss 0.32