In [None]:
import pathlib
import zipfile
from pathlib import Path

import requests
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer

In [None]:
# Run on the GPU when PyTorch reports one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Root directory of the competition data; every other input path is derived from it.
data_path = pathlib.Path("/kaggle/input/ai-of-god-3/Public_data")

In [None]:
# Locations of the image folders and the label file, all relative to data_path.
train_path = data_path.joinpath("train_images")
test_path = data_path.joinpath("test_images")
train_csv = data_path.joinpath("train.csv")

In [None]:
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets, transforms
from PIL import Image
import csv

In [None]:
import os, csv, pathlib
from PIL import Image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    """Dataset of ``*.png`` images under ``root``, optionally paired with transcriptions.

    In ``mode="train"`` the CSV is read once and each item is
    ``(image, input_ids)``; in any other mode each item is just ``image``.

    Args:
        root: Directory that is searched (non-recursively) for ``*.png`` files.
        csv_file: CSV with the columns ``unique Id`` and ``transcription``.
            Only opened when ``mode == "train"``.
        tokenizer: Callable applied to the transcription in train mode; its
            result must be indexable with ``"input_ids"``.
        transform: Optional callable applied to the RGB PIL image.
        mode: ``"train"`` to return labels, anything else for images only.
        max_length: Forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, root, csv_file, tokenizer=None, transform=None, mode="train", max_length=32):
        self.root = root
        # Sorted so that index -> image is the same on every run; the order in
        # which glob yields entries is not guaranteed.
        self.paths = sorted(pathlib.Path(self.root).glob("*.png"))
        self.transform = transform
        self.tokenizer = tokenizer
        self.mode = mode
        self.max_length = max_length
        # Kept in its original column-wise form for any code that reads it.
        self.labels = {'id': [], 'transcription': []}
        # id -> transcription, so a lookup does not scan the whole id list.
        self._transcription_by_id = {}

        if mode == "train":
            # newline='' is what the csv module asks for so quoted fields that
            # contain line breaks are parsed correctly.
            with open(csv_file, mode='r', newline='', encoding='utf-8') as file:
                for row in csv.DictReader(file):
                    image_id = row['unique Id']
                    transcription = row['transcription']
                    self.labels['id'].append(image_id)
                    self.labels['transcription'].append(transcription)
                    # setdefault keeps the first row for a repeated id, which
                    # is what a list ``.index`` lookup would have returned.
                    self._transcription_by_id.setdefault(image_id, transcription)

    def __len__(self):
        """Number of PNG files found under ``root``."""
        return len(self.paths)

    def _transcription_for(self, image_id):
        """Return the transcription stored for ``image_id``.

        Raises:
            ValueError: If the CSV had no row with that id.
        """
        try:
            return self._transcription_by_id[image_id]
        except KeyError:
            raise ValueError(f"no transcription found for image id {image_id!r}") from None

    def __getitem__(self, index):
        """Load item ``index``: ``(image, input_ids)`` in train mode, else ``image``."""
        image_path = self.paths[index]
        # The context manager releases the file handle as soon as the pixels
        # have been copied into the converted image.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.mode != "train":
            return image

        # The file name without its extension is the key used in the CSV.
        transcription = self._transcription_for(image_path.stem)
        encoding = self.tokenizer(
            transcription,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Drops the leading dimension of the tokenizer output
        # (presumably [1, max_len] -> [max_len]; confirm for custom tokenizers).
        input_ids = encoding["input_ids"].squeeze(0)
        return image, input_ids


In [None]:
# Tokenizer that turns transcriptions into target token ids.
vocab_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # or custom tokenizer
pad_token_id = vocab_tokenizer.pad_token_id
# BERT-style tokenizers define no BOS token, so bos_token_id comes back as None.
# Their encoded sequences start with [CLS]; use that as the start-of-sequence id.
bos_token_id = vocab_tokenizer.bos_token_id
if bos_token_id is None:
    bos_token_id = vocab_tokenizer.cls_token_id

# Dataset
# Every image is resized to a fixed 224x224 and converted to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


In [None]:
# Build the labelled dataset from the competition paths defined earlier
# (train_path / train_csv) rather than from relative placeholder paths.
data = CustomImageDataset(root=train_path, csv_file=train_csv, tokenizer=vocab_tokenizer, transform=transform)

In [None]:
from torch.utils.data import random_split

# 80 % of the labelled images are used for training; the remainder is held out.
train_size = int(0.8 * len(data))
test_size = len(data) - train_size

# A seeded generator keeps the train / held-out membership identical across
# kernel restarts, so metrics from different runs are comparable.
SPLIT_SEED = 42
train_data, test_data = random_split(
    data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
from torch.utils.data import DataLoader
import os

# os.cpu_count() returns None when the CPU count cannot be determined, which
# DataLoader does not accept; 0 means "load in the main process".
NUM_WORKERS = os.cpu_count() or 0
BATCH_SIZE = 32
train_dataloader = DataLoader(dataset=train_data, # use custom created train Dataset
                                     batch_size=BATCH_SIZE, # how many samples per batch?
                                     num_workers=NUM_WORKERS, # how many subprocesses to use for data loading? (higher = more)
                                     shuffle=True) # shuffle the data?

test_dataloader = DataLoader(dataset=test_data, # use custom created test Dataset
                                    batch_size=BATCH_SIZE,
                                    num_workers=NUM_WORKERS,
                                    shuffle=False) # don't usually need to shuffle testing data

train_dataloader, test_dataloader

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import math

# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to a batch-first sequence.

    The table has shape ``[1, max_len, d_model]`` and ``forward`` slices it
    along dim 1, so the input must be laid out as ``[batch, time, d_model]``.

    Args:
        d_model: Size of the last (feature) dimension; may be even or odd.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))  # [ceil(d_model/2)]
        angles = position * div_term  # [max_len, ceil(d_model/2)]

        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # the cosine half must be trimmed to d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A non-persistent buffer moves with the module across devices but is
        # left out of state_dict, so saved checkpoints keep the same keys.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, max_len, d_model]

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        # .to() is a no-op once the module lives on x's device; it is kept so
        # the layer also works when only the input was moved.
        return x + self.pe[:, :x.size(1)].to(x.device)

# --- CNN Encoder using ResNet ---
class CNNEncoder(nn.Module):
    """ResNet-18 backbone that turns an image batch into a sequence of visual tokens.

    Args:
        output_dim: Channel size of every returned token (the 1x1 projection width).
    """

    def __init__(self, output_dim=512):
        super().__init__()
        # NOTE(review): `pretrained=` is the legacy torchvision switch; newer
        # releases prefer `weights=` - confirm against the installed version.
        backbone = models.resnet18(pretrained=True)
        # Keep everything except the last two children (avgpool and fc) so the
        # spatial feature map survives.
        conv_stages = list(backbone.children())[:-2]
        self.features = nn.Sequential(*conv_stages)
        # 1x1 convolution mapping the backbone's 512 channels to the model width.
        self.project = nn.Conv2d(512, output_dim, kernel_size=1)

    def forward(self, x):  # x: [B, 3, H, W]
        """Return visual tokens of shape ``[B, H'*W', output_dim]``."""
        feature_map = self.project(self.features(x))  # [B, output_dim, H', W']
        # Channels last, then merge the two spatial axes row by row into one
        # sequence axis.
        channels_last = feature_map.permute(0, 2, 3, 1)  # [B, H', W', C]
        return channels_last.flatten(start_dim=1, end_dim=2)  # [B, H'*W', C]

# --- Transformer Decoder ---
class TransformerDecoder(nn.Module):
    """Token embedding + positional encoding + Transformer decoder stack + vocabulary head.

    All sequence tensors are sequence-first, because the decoder layers are
    built without ``batch_first``.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        model_dim: Width of embeddings and decoder layers.
        nhead: Attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        max_len: Longest target sequence the positional table covers.
    """

    def __init__(self, vocab_size, model_dim=512, nhead=8, num_layers=6, max_len=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, model_dim)
        self.pos_encoder = PositionalEncoding(model_dim, max_len)
        decoder_layer = nn.TransformerDecoderLayer(d_model=model_dim, nhead=nhead)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(model_dim, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Compute vocabulary logits for every target position.

        Args:
            tgt: Token ids, ``[T, B]``.
            memory: Encoder output, ``[S, B, D]``.
            tgt_mask: Optional attention mask over target positions.
            memory_mask: Optional attention mask over memory positions.

        Returns:
            Logits of shape ``[T, B, vocab_size]``.
        """
        tgt_emb = self.embedding(tgt)  # [T, B, D]
        # PositionalEncoding slices its table along dim 1, i.e. it expects
        # [B, T, D]. Feeding [T, B, D] directly would index positions by batch
        # element instead of by time step, so swap to batch-first for the
        # addition and swap back for the decoder stack.
        tgt_emb = self.pos_encoder(tgt_emb.transpose(0, 1)).transpose(0, 1)  # [T, B, D]
        output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
        return self.fc_out(output)

# --- Full OCR Model ---
class CNNTransformerOCR(nn.Module):
    """Image-to-text model: CNN encoder feeding a Transformer decoder.

    Args:
        vocab_size: Number of target token ids.
        model_dim: Shared width of visual tokens and decoder layers.
        max_len: Longest target sequence supported by the decoder.
    """

    def __init__(self, vocab_size, model_dim=512, max_len=256):
        super().__init__()
        self.encoder = CNNEncoder(output_dim=model_dim)
        self.decoder = TransformerDecoder(vocab_size, model_dim=model_dim, max_len=max_len)
        self.model_dim = model_dim

    def forward(self, images, tgt_seq, tgt_mask=None):
        """Return logits ``[T, B, vocab_size]`` for batch-first inputs.

        ``tgt_seq`` must be ``[B, T]``; the switch to the sequence-first
        layout used by the decoder happens here, so callers must not
        transpose it themselves.
        """
        visual_tokens = self.encoder(images)  # [B, S, D]
        memory = visual_tokens.permute(1, 0, 2)  # [S, B, D] for Transformer
        token_ids = tgt_seq.permute(1, 0)  # [T, B]
        return self.decoder(token_ids, memory, tgt_mask)  # [T, B, vocab_size]


In [None]:
# One output logit per tokenizer id; the model lives on the selected device.
model = CNNTransformerOCR(vocab_size=len(vocab_tokenizer)).to(device)
# Padding positions in the targets do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
# Referenced through the already-imported `torch` package; a bare `optim`
# name is never imported in this notebook.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
def wer(reference, hypothesis):
    """Word error rate of ``hypothesis`` against ``reference``.

    Both strings are split on whitespace and the word-level edit distance
    (substitutions, insertions, deletions, each costing 1) is divided by the
    number of reference words.

    Args:
        reference: Ground-truth text.
        hypothesis: Predicted text.

    Returns:
        The error rate as a float. For an empty reference the divisor is
        taken as 1, so the result is ``0.0`` when the hypothesis is empty too
        and otherwise the number of hypothesis words.
    """
    reference_words = reference.split()
    hypothesis_words = hypothesis.split()

    # Plain Python ints are used on purpose: a fixed-width 8-bit matrix wraps
    # around as soon as a distance exceeds 255.
    # `previous` is the row for the first i-1 reference words; row 0 is the
    # cost of inserting the first j hypothesis words.
    previous = list(range(len(hypothesis_words) + 1))

    for i, ref_word in enumerate(reference_words, start=1):
        # Column 0: deleting the first i reference words.
        current = [i]
        for j, hyp_word in enumerate(hypothesis_words, start=1):
            if ref_word == hyp_word:
                current.append(previous[j - 1])
            else:
                current.append(1 + min(previous[j], current[j - 1], previous[j - 1]))
        previous = current

    # max(..., 1) avoids dividing by zero for an empty reference.
    return previous[-1] / max(len(reference_words), 1)

In [None]:
def _resolve_start_token_id(tokenizer):
    """Return the id that opens every target sequence for ``tokenizer``."""
    start_id = getattr(tokenizer, "bos_token_id", None)
    if start_id is None:
        # BERT-style tokenizers have no BOS token; their sequences begin with [CLS].
        start_id = getattr(tokenizer, "cls_token_id", None)
    if start_id is None:
        raise ValueError("tokenizer defines neither bos_token_id nor cls_token_id")
    return start_id


def _resolve_end_token_id(tokenizer):
    """Return the id that closes a target sequence, or None if the tokenizer has none."""
    end_id = getattr(tokenizer, "eos_token_id", None)
    if end_id is None:
        # BERT-style tokenizers terminate sequences with [SEP] instead of EOS.
        end_id = getattr(tokenizer, "sep_token_id", None)
    return end_id


def _teacher_forced_loss(model, criterion, images, token_ids, device):
    """Next-token loss for one batch; ``token_ids`` is batch-first ``[B, T]``."""
    decoder_input = token_ids[:, :-1]  # [B, T-1], everything but the last token
    # The model returns sequence-first logits, so only the expected tokens are
    # transposed. The decoder input stays batch-first because the model
    # transposes it internally.
    expected = token_ids[:, 1:].permute(1, 0)  # [T-1, B]
    causal_mask = nn.Transformer.generate_square_subsequent_mask(decoder_input.size(1)).to(device)
    logits = model(images, decoder_input, tgt_mask=causal_mask)  # [T-1, B, V]
    return criterion(logits.reshape(-1, logits.size(-1)), expected.reshape(-1))


def _greedy_decode(model, images, start_id, num_steps, device):
    """Generate ``num_steps`` tokens per image by repeatedly taking the arg-max token.

    Returns a ``[B, 1 + num_steps]`` tensor whose first column is ``start_id``.
    """
    decoded = torch.full((images.size(0), 1), start_id, dtype=torch.long, device=device)
    memory = model.encoder(images).permute(1, 0, 2)  # [S, B, D]
    for _ in range(num_steps):
        causal_mask = nn.Transformer.generate_square_subsequent_mask(decoded.size(1)).to(device)
        # The decoder is called directly here, so it gets sequence-first ids.
        output = model.decoder(decoded.permute(1, 0), memory, tgt_mask=causal_mask)  # [t, B, V]
        next_token = output[-1].argmax(-1)  # [B], prediction after the newest token
        decoded = torch.cat([decoded, next_token.unsqueeze(1)], dim=1)
    return decoded


def _generated_ids(sequence, end_id):
    """Drop the leading start token and everything from the first end token onward."""
    body = sequence[1:]
    if end_id is not None and end_id in body:
        body = body[: body.index(end_id)]
    return body


def train_and_validate(model, train_dataloader, val_dataloader, tokenizer, optimizer, criterion, device, num_epochs=10, pad_token_id=0):
    """Train ``model`` and, after every epoch, score it on ``val_dataloader``.

    Each epoch runs one teacher-forced pass over the training data, then
    greedily decodes the validation images, reports the validation loss and
    the mean word error rate, and saves the weights to ``best_ocr_model.pt``
    whenever the mean WER improves.

    Args:
        model: Called as ``model(images, tgt_seq, tgt_mask=...)`` with
            batch-first ``tgt_seq``; must expose ``encoder`` and ``decoder``.
        train_dataloader: Yields ``(images, token_ids)`` batches for training.
        val_dataloader: Yields ``(images, token_ids)`` batches for validation.
        tokenizer: Provides the start / end token ids and ``decode``.
        optimizer: Optimizer over the model parameters.
        criterion: Loss over ``[N, vocab]`` logits and ``[N]`` target ids.
        device: Device the batches are moved to.
        num_epochs: Number of passes over the training data.
        pad_token_id: Accepted for interface compatibility; not read here.
            Padding is expected to be handled by ``criterion`` (for example
            through its ``ignore_index``).

    Raises:
        ValueError: If the tokenizer has neither a BOS nor a CLS token id.
    """
    start_id = _resolve_start_token_id(tokenizer)
    end_id = _resolve_end_token_id(tokenizer)
    best_wer = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for images, tgt_input_ids in tqdm(train_dataloader, desc=f"Epoch {epoch+1} [Train]"):
            images = images.to(device)
            tgt_input_ids = tgt_input_ids.to(device)

            loss = _teacher_forced_loss(model, criterion, images, tgt_input_ids, device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # max(..., 1) keeps an empty loader from dividing by zero.
        avg_train_loss = train_loss / max(len(train_dataloader), 1)
        print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0.0
        references = []
        predictions = []

        with torch.no_grad():
            for images, tgt_input_ids in tqdm(val_dataloader, desc=f"Epoch {epoch+1} [Val]"):
                images = images.to(device)
                tgt_input_ids = tgt_input_ids.to(device)

                # T-1 generated tokens plus the start token cover the full
                # length of a reference sequence.
                decoded = _greedy_decode(model, images, start_id, tgt_input_ids.size(1) - 1, device)

                for predicted_ids, true_ids in zip(decoded.tolist(), tgt_input_ids.tolist()):
                    # Tokens produced after the end token are not constrained
                    # by training (those target positions are padding), so
                    # they are cut off before scoring.
                    predictions.append(tokenizer.decode(_generated_ids(predicted_ids, end_id), skip_special_tokens=True))
                    references.append(tokenizer.decode(true_ids, skip_special_tokens=True))

                # Optional loss (for logging)
                val_loss += _teacher_forced_loss(model, criterion, images, tgt_input_ids, device).item()

        avg_val_loss = val_loss / max(len(val_dataloader), 1)
        print(f"[Epoch {epoch+1}] Val Loss: {avg_val_loss:.4f}")

        # Calculate Average WER
        wers = [wer(ref, hyp) for ref, hyp in zip(references, predictions)]
        # With no validation samples there is nothing to score; inf never
        # counts as an improvement, so no checkpoint is written.
        avg_wer = sum(wers) / len(wers) if wers else float('inf')
        print(f"[Epoch {epoch+1}] Val WER: {avg_wer:.4f}")

        # Save best model
        if avg_wer < best_wer:
            best_wer = avg_wer
            torch.save(model.state_dict(), "best_ocr_model.pt")
            print(f"✅ Best model saved with WER: {best_wer:.4f}")


In [None]:
# Run the full training schedule; the held-out split created earlier serves as
# the validation set.
train_and_validate(
    # what is trained and how
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    num_epochs=10,
    # data
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    # text handling
    tokenizer=vocab_tokenizer,
    pad_token_id=vocab_tokenizer.pad_token_id,
)
