In [None]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torchvision.transforms.functional import to_tensor
from torch.nn import CTCLoss
import string
from tqdm import tqdm
import editdistance

In [97]:
BATCH_SIZE = 16
IMG_HEIGHT = 32
IMG_WIDTH = 128
EPOCHS = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHARS = string.ascii_uppercase + string.digits+string.ascii_lowercase+"-_' "

In [None]:
class LabelConverter:
    def __init__(self, characters):
        """
        characters: string of all characters in the dataset (excluding CTC blank)
        """
        self.characters = characters
        self.char2idx = {char: i + 1 for i, char in enumerate(characters)}  # 0 is reserved for blank
        self.idx2char = {i + 1: char for i, char in enumerate(characters)}
        self.blank = 0  # CTC requires blank token at index 0
        self.idx2char[self.blank] = ''  # for decoding

    def encode(self, text):
        """
        Convert a text string to a tensor of indices.
        Returns: Tensor of shape [len(text)]
        """
        return torch.tensor([self.char2idx[char] for char in text], dtype=torch.long)

    def decode(self, preds, merge_repeated=True):
        """
        Decode model output probabilities/logits into strings.
        Args:
            preds: Tensor of shape [T, B, C] (logits before/after softmax)
            merge_repeated: whether to remove repeated characters

        Returns:
            List of decoded strings, length = B
        """
        # Take argmax over class dimension
        pred_indices = preds.argmax(2).permute(1, 0)  # [B, T]
        results = []

        for indices in pred_indices:
            decoded = []
            prev_idx = None
            for idx in indices:
                idx = idx.item()
                if idx == self.blank:
                    prev_idx = None  # reset on blank
                    continue
                if merge_repeated and idx == prev_idx:
                    continue
                decoded.append(self.idx2char.get(idx, ''))
                prev_idx = idx
            results.append(''.join(decoded))
        return results
converter=LabelConverter(CHARS)

In [None]:
class CRNNDataset(Dataset):
    def __init__(self, csv_path, transform=None):
        df = pd.read_csv(csv_path)
        self.paths = df['FILENAME'].values
        self.labels = df['IDENTITY'].values
        self.transform = transform or transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transform(img)
        label = converter.encode(self.labels[idx])
        return {'image': img, 'label': label}

In [108]:
def collate_fn(batch):
    images = [item['image'] for item in batch]
    labels = [item['label'] for item in batch]
    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    images = torch.stack(images)
    labels_concat = torch.cat(labels)
    return images, labels_concat, label_lengths

In [109]:
class CRNN(nn.Module):
    def __init__(self, num_classes):
        super(CRNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(), nn.MaxPool2d((2, 1)),
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(), nn.MaxPool2d((2, 1)),
            nn.Conv2d(512, 512, 2, 1, 0), nn.ReLU(),
        )
        self.rnn1 = nn.LSTM(512, 256, bidirectional=True, batch_first=True)
        self.rnn2 = nn.LSTM(512, 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, num_classes + 1)  # +1 for CTC blank

    def forward(self, x):
        x = self.cnn(x)  # [B, 512, 1, W]
        x = x.squeeze(2).permute(0, 2, 1)  # [B, W, 512]
        x, _ = self.rnn1(x)  # [B, W, 512]
        x, _ = self.rnn2(x)  # [B, W, 512]
        x = self.fc(x)  # [B, W, C]
        return x.permute(1, 0, 2)  # [W, B, C] for CTC

In [110]:
def calculate_cer(pred, target):
    """Character Error Rate = (Edit Distance) / (Length of GT)"""
    if len(target) == 0:
        return 1.0 if len(pred) > 0 else 0.0
    return editdistance.eval(pred, target) / len(target)

def calculate_word_accuracy(preds, targets):
    """Word Accuracy = exact match rate"""
    correct = sum([p == t for p, t in zip(preds, targets)])
    return correct / len(targets) if targets else 0.0

In [114]:
def validate(model, val_loader, criterion, converter, print_samples=False):
    model.eval()
    val_loss = 0
    total_cer = 0
    total_words = 0
    correct_words = 0

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for images, labels, label_lengths in val_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            preds = model(images)
            input_lengths = torch.full(size=(images.size(0),), fill_value=preds.size(0), dtype=torch.long)

            loss = criterion(preds, labels, input_lengths, label_lengths)
            val_loss += loss.item()

            # Decode predictions and labels
            decoded_preds = converter.decode(preds,merge_repeated=True)
            split_labels = []
            i = 0
            for length in label_lengths:
                text = ''.join([converter.idx2char[int(c.item())] for c in labels[i:i+length]])
                split_labels.append(text)
                i += length

            all_preds.extend(decoded_preds)
            all_targets.extend(split_labels)

            # Print sample predictions
            if print_samples:
                for pred, target in zip(decoded_preds[:3], split_labels[:3]):
                    print(f"GT: {target} | Pred: {pred}")

    # Compute metrics
    total_cer = sum(calculate_cer(p, t) for p, t in zip(all_preds, all_targets)) / len(all_targets)
    word_acc = calculate_word_accuracy(all_preds, all_targets)
    avg_loss = val_loss / len(val_loader)

    return avg_loss, total_cer, word_acc


In [None]:
def train():
    # Datasets
    train_dataset = CRNNDataset(r'C:\Users\Raihan\OneDrive\Desktop\DPIIT HACKATHON\cvsi_10k_subset.csv')
    val_dataset = CRNNDataset(r'C:\Users\Raihan\OneDrive\Desktop\DPIIT HACKATHON\cvsi_val.csv')

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    model = CRNN(num_classes=len(CHARS)).to(DEVICE)
    criterion = CTCLoss(blank=0, zero_infinity=True)
    optimizer = optim.AdamW(model arameters(), lr=1e-3)

    best_val_loss = float('inf')

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0

        for images, labels, label_lengths in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            preds = model(images)
            input_lengths = torch.full(size=(images.size(0),), fill_value=preds.size(0), dtype=torch.long)

            loss = criterion(preds, labels, input_lengths, label_lengths)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)

        # ---- VALIDATION ----
        val_loss, val_cer,val_acc = validate(model, val_loader, criterion, converter)

        print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}|Val CER: {val_cer:.2f}")

        

    print("Training complete.")
    return model


In [116]:
model=train()

Epoch 1: 100%|██████████| 625/625 [01:16<00:00,  8.21it/s]


[Epoch 1] Train Loss: 3.2414 | Val Loss: 3.2285 | Val Acc: 0.00|Val CER: 1.00


Epoch 2: 100%|██████████| 625/625 [01:22<00:00,  7.60it/s]


[Epoch 2] Train Loss: 3.1557 | Val Loss: 3.3380 | Val Acc: 0.00|Val CER: 1.00


Epoch 3: 100%|██████████| 625/625 [01:19<00:00,  7.90it/s]


[Epoch 3] Train Loss: 3.1008 | Val Loss: 3.3822 | Val Acc: 0.00|Val CER: 0.94


Epoch 4: 100%|██████████| 625/625 [01:14<00:00,  8.36it/s]


[Epoch 4] Train Loss: 3.0316 | Val Loss: 2.8856 | Val Acc: 0.00|Val CER: 0.82


Epoch 5: 100%|██████████| 625/625 [01:36<00:00,  6.51it/s]


[Epoch 5] Train Loss: 2.9690 | Val Loss: 3.0616 | Val Acc: 0.00|Val CER: 0.80


Epoch 6: 100%|██████████| 625/625 [01:35<00:00,  6.55it/s]


[Epoch 6] Train Loss: 2.9421 | Val Loss: 3.3025 | Val Acc: 0.00|Val CER: 0.81


Epoch 7: 100%|██████████| 625/625 [01:20<00:00,  7.73it/s]


[Epoch 7] Train Loss: 2.9249 | Val Loss: 2.7122 | Val Acc: 0.00|Val CER: 0.82


Epoch 8: 100%|██████████| 625/625 [01:01<00:00, 10.14it/s]


[Epoch 8] Train Loss: 2.8982 | Val Loss: 2.8984 | Val Acc: 0.00|Val CER: 0.79


Epoch 9: 100%|██████████| 625/625 [00:44<00:00, 13.90it/s]


[Epoch 9] Train Loss: 2.8292 | Val Loss: 2.7697 | Val Acc: 0.00|Val CER: 0.78


Epoch 10: 100%|██████████| 625/625 [00:43<00:00, 14.33it/s]


[Epoch 10] Train Loss: 2.7397 | Val Loss: 2.6993 | Val Acc: 0.00|Val CER: 0.74


Epoch 11: 100%|██████████| 625/625 [00:44<00:00, 14.16it/s]


[Epoch 11] Train Loss: 2.5056 | Val Loss: 2.3861 | Val Acc: 0.00|Val CER: 0.71


Epoch 12: 100%|██████████| 625/625 [00:44<00:00, 14.09it/s]


[Epoch 12] Train Loss: 2.0956 | Val Loss: 1.9063 | Val Acc: 0.01|Val CER: 0.58


Epoch 13: 100%|██████████| 625/625 [00:43<00:00, 14.32it/s]


[Epoch 13] Train Loss: 1.7059 | Val Loss: 1.6415 | Val Acc: 0.04|Val CER: 0.46


Epoch 14: 100%|██████████| 625/625 [00:49<00:00, 12.75it/s]


[Epoch 14] Train Loss: 1.3521 | Val Loss: 1.2363 | Val Acc: 0.13|Val CER: 0.34


Epoch 15: 100%|██████████| 625/625 [00:43<00:00, 14.26it/s]


[Epoch 15] Train Loss: 1.0322 | Val Loss: 1.0976 | Val Acc: 0.24|Val CER: 0.26


Epoch 16: 100%|██████████| 625/625 [00:44<00:00, 13.89it/s]


[Epoch 16] Train Loss: 0.8381 | Val Loss: 0.8628 | Val Acc: 0.30|Val CER: 0.22


Epoch 17: 100%|██████████| 625/625 [00:48<00:00, 12.95it/s]


[Epoch 17] Train Loss: 0.7095 | Val Loss: 0.8595 | Val Acc: 0.34|Val CER: 0.20


Epoch 18: 100%|██████████| 625/625 [00:40<00:00, 15.38it/s]


[Epoch 18] Train Loss: 0.5982 | Val Loss: 0.7409 | Val Acc: 0.40|Val CER: 0.18


Epoch 19: 100%|██████████| 625/625 [00:56<00:00, 11.10it/s]


[Epoch 19] Train Loss: 0.5103 | Val Loss: 0.6524 | Val Acc: 0.41|Val CER: 0.17


Epoch 20: 100%|██████████| 625/625 [00:48<00:00, 12.91it/s]


[Epoch 20] Train Loss: 0.4490 | Val Loss: 0.6025 | Val Acc: 0.41|Val CER: 0.17


Epoch 21: 100%|██████████| 625/625 [00:45<00:00, 13.68it/s]


[Epoch 21] Train Loss: 0.3900 | Val Loss: 0.6862 | Val Acc: 0.45|Val CER: 0.16


Epoch 22: 100%|██████████| 625/625 [01:01<00:00, 10.24it/s]


[Epoch 22] Train Loss: 0.3458 | Val Loss: 0.5917 | Val Acc: 0.48|Val CER: 0.15


Epoch 23: 100%|██████████| 625/625 [01:11<00:00,  8.73it/s]


[Epoch 23] Train Loss: 0.2942 | Val Loss: 0.5367 | Val Acc: 0.50|Val CER: 0.14


Epoch 24: 100%|██████████| 625/625 [00:40<00:00, 15.29it/s]


[Epoch 24] Train Loss: 0.2639 | Val Loss: 0.6386 | Val Acc: 0.51|Val CER: 0.14


Epoch 25: 100%|██████████| 625/625 [00:40<00:00, 15.45it/s]


[Epoch 25] Train Loss: 0.2331 | Val Loss: 0.5931 | Val Acc: 0.53|Val CER: 0.14


Epoch 26: 100%|██████████| 625/625 [00:43<00:00, 14.49it/s]


[Epoch 26] Train Loss: 0.2072 | Val Loss: 0.5867 | Val Acc: 0.53|Val CER: 0.13


Epoch 27: 100%|██████████| 625/625 [01:11<00:00,  8.77it/s]


[Epoch 27] Train Loss: 0.1863 | Val Loss: 0.6942 | Val Acc: 0.53|Val CER: 0.13


Epoch 28: 100%|██████████| 625/625 [00:42<00:00, 14.73it/s]


[Epoch 28] Train Loss: 0.1695 | Val Loss: 0.7608 | Val Acc: 0.51|Val CER: 0.14


Epoch 29: 100%|██████████| 625/625 [00:47<00:00, 13.23it/s]


[Epoch 29] Train Loss: 0.1597 | Val Loss: 0.7672 | Val Acc: 0.52|Val CER: 0.14


Epoch 30: 100%|██████████| 625/625 [00:44<00:00, 14.14it/s]


[Epoch 30] Train Loss: 0.1400 | Val Loss: 0.7878 | Val Acc: 0.55|Val CER: 0.12


Epoch 31: 100%|██████████| 625/625 [00:43<00:00, 14.23it/s]


[Epoch 31] Train Loss: 0.1377 | Val Loss: 0.6360 | Val Acc: 0.57|Val CER: 0.12


Epoch 32: 100%|██████████| 625/625 [00:59<00:00, 10.48it/s]


[Epoch 32] Train Loss: 0.1133 | Val Loss: 0.6577 | Val Acc: 0.54|Val CER: 0.13


Epoch 33: 100%|██████████| 625/625 [00:42<00:00, 14.74it/s]


[Epoch 33] Train Loss: 0.1115 | Val Loss: 0.7233 | Val Acc: 0.55|Val CER: 0.12


Epoch 34: 100%|██████████| 625/625 [00:42<00:00, 14.67it/s]


[Epoch 34] Train Loss: 0.0926 | Val Loss: 0.6731 | Val Acc: 0.54|Val CER: 0.13


Epoch 35: 100%|██████████| 625/625 [00:42<00:00, 14.80it/s]


[Epoch 35] Train Loss: 0.1010 | Val Loss: 0.7754 | Val Acc: 0.55|Val CER: 0.13


Epoch 36: 100%|██████████| 625/625 [00:43<00:00, 14.46it/s]


[Epoch 36] Train Loss: 0.0994 | Val Loss: 0.7292 | Val Acc: 0.56|Val CER: 0.12


Epoch 37: 100%|██████████| 625/625 [01:07<00:00,  9.29it/s]


[Epoch 37] Train Loss: 0.0782 | Val Loss: 0.6609 | Val Acc: 0.56|Val CER: 0.13


Epoch 38: 100%|██████████| 625/625 [00:42<00:00, 14.81it/s]


[Epoch 38] Train Loss: 0.0809 | Val Loss: 0.7061 | Val Acc: 0.58|Val CER: 0.12


Epoch 39: 100%|██████████| 625/625 [00:41<00:00, 15.22it/s]


[Epoch 39] Train Loss: 0.0720 | Val Loss: 0.7956 | Val Acc: 0.52|Val CER: 0.14


Epoch 40: 100%|██████████| 625/625 [00:41<00:00, 15.24it/s]


[Epoch 40] Train Loss: 0.0789 | Val Loss: 0.7515 | Val Acc: 0.56|Val CER: 0.13


Epoch 41: 100%|██████████| 625/625 [00:41<00:00, 15.03it/s]


[Epoch 41] Train Loss: 0.0649 | Val Loss: 0.6873 | Val Acc: 0.58|Val CER: 0.12


Epoch 42: 100%|██████████| 625/625 [00:40<00:00, 15.25it/s]


[Epoch 42] Train Loss: 0.0577 | Val Loss: 0.7558 | Val Acc: 0.58|Val CER: 0.11


Epoch 43: 100%|██████████| 625/625 [00:40<00:00, 15.35it/s]


[Epoch 43] Train Loss: 0.0651 | Val Loss: 0.7653 | Val Acc: 0.56|Val CER: 0.12


Epoch 44: 100%|██████████| 625/625 [00:41<00:00, 15.24it/s]


[Epoch 44] Train Loss: 0.0663 | Val Loss: 0.7494 | Val Acc: 0.56|Val CER: 0.12


Epoch 45: 100%|██████████| 625/625 [00:57<00:00, 10.87it/s]


[Epoch 45] Train Loss: 0.0531 | Val Loss: 0.6892 | Val Acc: 0.56|Val CER: 0.12


Epoch 46: 100%|██████████| 625/625 [00:55<00:00, 11.26it/s]


[Epoch 46] Train Loss: 0.0513 | Val Loss: 0.7470 | Val Acc: 0.56|Val CER: 0.12


Epoch 47: 100%|██████████| 625/625 [00:44<00:00, 14.00it/s]


[Epoch 47] Train Loss: 0.0622 | Val Loss: 0.7226 | Val Acc: 0.56|Val CER: 0.13


Epoch 48: 100%|██████████| 625/625 [00:45<00:00, 13.65it/s]


[Epoch 48] Train Loss: 0.0466 | Val Loss: 0.8017 | Val Acc: 0.58|Val CER: 0.12


Epoch 49: 100%|██████████| 625/625 [00:42<00:00, 14.63it/s]


[Epoch 49] Train Loss: 0.0440 | Val Loss: 0.6097 | Val Acc: 0.59|Val CER: 0.12


Epoch 50: 100%|██████████| 625/625 [00:42<00:00, 14.72it/s]


[Epoch 50] Train Loss: 0.0487 | Val Loss: 0.7357 | Val Acc: 0.58|Val CER: 0.12
Training complete.


In [None]:
optimizer=optim.AdamW
torch.save({
    'epoch': 50,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, 'crnn_ctc_checkpoint.pth')

SyntaxError: invalid syntax. Perhaps you forgot a comma? (3153723720.py, line 2)