In [None]:
#!unzip /content/IAM_Correct.zip

In [None]:
import copy
import os
import random
import string

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from PIL import Image, UnidentifiedImageError

# ======= Character Set for IAM Dataset =======
# Supported vocabulary (labels are lowercased before lookup).
CHARS = string.ascii_lowercase + string.digits + ".,;:'\"!?()&- "
# Index 0 is reserved for the CTC blank, so real characters are numbered from 1.
index_to_char = dict(enumerate(CHARS, start=1))
# Inverse mapping: character -> class index, in the same order as CHARS.
char_to_index = {char: idx for idx, char in index_to_char.items()}

class IAMDataset:
    """Word-level IAM handwriting dataset read from a ``words.txt``-style label file.

    Every usable entry becomes an ``(img_path, label)`` pair in ``valid_samples``.
    ``label`` is a list of indices from ``char_to_index`` for the lowercased
    transcription; characters missing from ``char_to_index`` are dropped.

    Args:
        img_root: Root directory of the word images. An id such as
            ``a01-000u-00-00`` is looked up as
            ``<img_root>/a01/a01-000u/a01-000u-00-00.png``.
        label_file: Path to the label file. Lines starting with ``#`` are comments.
        img_size: ``(height, width)`` each image is resized to in ``__getitem__``.
    """

    def __init__(self, img_root, label_file, img_size=(32, 128)):
        self.img_root = img_root
        self.img_size = img_size
        self.valid_samples = []

        with open(label_file, "r", encoding="utf-8") as f:
            for line in f:
                if line.startswith("#"):
                    continue
                parts = line.strip().split(" ")
                # The transcription is taken from field 8 onward; a blank or
                # truncated line has no transcription and would otherwise become
                # an empty-label sample (or crash the id parsing below).
                if len(parts) < 9:
                    continue
                img_name = parts[0]
                id_parts = img_name.split("-")
                if len(id_parts) < 2:
                    continue  # id cannot be mapped onto the folder layout
                folder1, folder2 = id_parts[:2]

                text = " ".join(parts[8:]).lower()
                label = [char_to_index[c] for c in text if c in char_to_index]
                img_path = os.path.normpath(
                    os.path.join(self.img_root, folder1, f"{folder1}-{folder2}", img_name + ".png")
                )

                if os.path.exists(img_path) and self._is_valid_image(img_path):
                    self.valid_samples.append((img_path, label))

    def __len__(self):
        """Return the number of samples that passed validation."""
        return len(self.valid_samples)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, label)``, where ``image`` is a float32 array of shape
            ``(1, height, width)`` scaled to [0, 1] and ``label`` is an int32 JAX
            array of character indices. Returns ``None`` if the image can no
            longer be read.
        """
        img_path, label = self.valid_samples[idx]
        height, width = self.img_size
        try:
            # ``with`` closes the file handle PIL keeps open after ``Image.open``;
            # ``convert`` produces an independent image, so it outlives the block.
            with Image.open(img_path) as img:
                gray = img.convert("L").resize((width, height), Image.BILINEAR)
        except (UnidentifiedImageError, OSError):
            return None
        image = np.array(gray, dtype=np.float32) / 255.0
        # Leading channel axis: (1, height, width).
        image = np.expand_dims(image, axis=0)
        return image, jnp.array(label, dtype=jnp.int32)

    def _is_valid_image(self, img_path):
        """Return True if PIL can open and ``verify()`` the file, False otherwise."""
        try:
            with Image.open(img_path) as img:
                img.verify()
            return True
        except (UnidentifiedImageError, OSError):
            return False

# ======= Data Loader Function for JAX =======
def jax_dataloader(dataset, batch_size=32, shuffle=True, rng=None):
    """Yield padded JAX batches from an indexable dataset.

    Args:
        dataset: Object supporting ``len()`` and integer indexing. Items are
            ``(image, label)`` pairs, or ``None`` for samples that failed to load;
            ``None`` items are skipped.
        batch_size: Number of indices drawn per batch. Skipped samples make a
            batch smaller, and a batch with no usable sample is not yielded.
        shuffle: Visit samples in random order.
        rng: Optional ``random.Random`` used for shuffling (for a reproducible
            order). Defaults to the module-level ``random`` functions.

    Yields:
        ``(images, padded_labels, label_lengths)``:
        - ``images``: the batch's images stacked along a new leading axis.
        - ``padded_labels``: int32 ``(batch, max_len)``, right-padded with 0
          (the CTC blank index, never a real character).
        - ``label_lengths``: the true length of each label.
    """
    indices = list(range(len(dataset)))
    if shuffle:
        shuffler = random if rng is None else rng
        shuffler.shuffle(indices)

    for start in range(0, len(indices), batch_size):
        # Index each sample exactly once: indexing may decode an image from disk.
        loaded = (dataset[i] for i in indices[start:start + batch_size])
        batch_samples = [sample for sample in loaded if sample is not None]
        if not batch_samples:
            continue

        images, labels = zip(*batch_samples)
        lengths = [len(label) for label in labels]

        # Pad on the host in one buffer, then transfer once.
        padded = np.zeros((len(labels), max(lengths)), dtype=np.int32)
        for row, label in enumerate(labels):
            padded[row, :len(label)] = np.asarray(label)

        yield jnp.stack(images), jnp.asarray(padded), jnp.array(lengths)

# ======= CRNN Model =======
class CRNN(nn.Module):
    """Conv + LSTM recognizer emitting per-timestep log-probabilities for CTC.

    Attributes:
        img_height: Nominal input image height. Not read by the computation; the
            per-timestep feature size is taken from the actual input shape.
        num_classes: Output classes per timestep (character set plus the blank).
        lstm_hidden_size: Hidden size of the LSTM cell.
    """

    img_height: int
    num_classes: int
    lstm_hidden_size: int = 256

    def setup(self):
        self.conv1 = nn.Conv(features=64, kernel_size=(3, 3), strides=1, padding='SAME', dtype=jnp.float32)
        self.conv2 = nn.Conv(features=128, kernel_size=(3, 3), strides=1, padding='SAME', dtype=jnp.float32)
        self.fc = nn.Dense(features=self.num_classes)
        # Unrolls one LSTMCell over axis 1 (time); params are broadcast, i.e.
        # shared across timesteps.
        self.lstm = nn.scan(
            nn.LSTMCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1,
            out_axes=1,
        )(features=self.lstm_hidden_size)

    def __call__(self, x):
        """Map a channel-first image batch to CTC log-probabilities.

        Args:
            x: Images shaped ``(batch, channels, height, width)``. This is the
                layout produced by ``IAMDataset`` / ``jax_dataloader``.

        Returns:
            ``(batch, time, num_classes)`` log-softmax scores. ``time`` is the
            width after two stride-2 'SAME' poolings, ``ceil(width / 4)``
            (32 for 128-pixel-wide images). CTC needs ``time`` to cover the
            label length plus the blanks between repeated characters.
        """
        # flax.linen.Conv / max_pool treat the LAST axis as channels (NHWC).
        # Without this transpose, the model convolved a 1x32 "image" with 128
        # "channels" and then scanned the LSTM over the channel axis.
        x = jnp.transpose(x, (0, 2, 3, 1))

        x = nn.relu(self.conv1(x))
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')
        x = nn.relu(self.conv2(x))
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

        # One timestep per remaining image column: (b, h, w, c) -> (b, w, h * c).
        b, h, w, c = x.shape
        x = jnp.transpose(x, (0, 2, 1, 3)).reshape(b, w, h * c)

        carry = (
            jnp.zeros((b, self.lstm_hidden_size)),
            jnp.zeros((b, self.lstm_hidden_size)),
        )
        _, lstm_out = self.lstm(carry, x)
        # Dense acts on the last axis, so it is applied independently per timestep.
        output = self.fc(lstm_out)

        return jax.nn.log_softmax(output, axis=-1)

# ======= Compute Loss Function =======
def loss_fn(params, images, labels):
    """Mean CTC loss of the module-level ``model`` on one padded batch.

    Args:
        params: Variables for ``model.apply``.
        images: Image batch accepted by ``model``.
        labels: ``(batch, max_len)`` integer labels, right-padded with 0.
            Real characters start at 1, so every 0 is padding.

    Returns:
        Scalar mean of the per-sequence CTC losses from ``optax.ctc_loss``.
    """
    log_probs = model.apply(params, images)
    batch_size, num_steps = log_probs.shape[0], log_probs.shape[1]
    # All output timesteps are real, so no frame is marked as padding.
    frame_padding = jnp.zeros((batch_size, num_steps), dtype=jnp.float32)
    # 1.0 marks padded label positions.
    target_padding = jnp.where(labels == 0, 1.0, 0.0).astype(jnp.float32)
    per_sequence = optax.ctc_loss(log_probs, frame_padding, labels, target_padding)
    return jnp.mean(per_sequence)

# ======= Training Function =======
@jax.jit
def train_step(params, opt_state, images, labels):
    """Run one jitted optimizer step on a batch.

    Uses the module-level ``loss_fn`` and ``optimizer``. If the loss or any
    gradient is NaN/inf, the update is discarded and ``params`` / ``opt_state``
    are returned unchanged.

    Args:
        params: Current model variables.
        opt_state: Current optimizer state for ``params``.
        images: Image batch passed to ``loss_fn``.
        labels: Zero-padded label batch passed to ``loss_fn``.

    Returns:
        ``(params, opt_state, loss)``. ``loss`` is ``0.0`` for a discarded step;
        ``train_model`` leaves 0.0 losses out of the epoch average.
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)

    # Masking only the reported loss (as nan_to_num did) still applied NaN/inf
    # gradients, permanently corrupting the parameters and optimizer moments.
    grads_finite = jnp.all(
        jnp.array([jnp.all(jnp.isfinite(g)) for g in jax.tree_util.tree_leaves(grads)])
    )
    step_ok = jnp.isfinite(loss) & grads_finite

    def _select(new, old):
        return jnp.where(step_ok, new, old)

    params = jax.tree_util.tree_map(_select, new_params, params)
    opt_state = jax.tree_util.tree_map(_select, new_opt_state, opt_state)
    loss = jnp.where(step_ok, loss, 0.0)
    return params, opt_state, loss

# ======= Validation Function =======
def evaluate_model(dataset):
    """Average CTC loss over ``dataset`` using the module-level ``params``.

    Batches of 32 are visited in order. Each batch contributes its mean loss
    with equal weight.

    Returns:
        The mean of the per-batch losses, or ``float("inf")`` if the loader
        yields no batches.
    """
    batch_losses = [
        loss_fn(params, images, labels).item()
        for images, labels, _lengths in jax_dataloader(dataset, batch_size=32, shuffle=False)
    ]
    if not batch_losses:
        return float("inf")
    return sum(batch_losses) / len(batch_losses)

# ======= Train Model Function =======
def train_model(train_dataset, val_dataset, epochs=10):
    """Train the module-level ``params`` / ``opt_state`` and print per-epoch losses.

    Each epoch runs ``train_step`` over shuffled batches of 32. It then scores
    ``val_dataset`` with ``evaluate_model`` and prints both averages.

    Args:
        train_dataset: Dataset fed to ``jax_dataloader`` for training.
        val_dataset: Dataset passed to ``evaluate_model`` after each epoch.
        epochs: Number of passes over ``train_dataset``.
    """
    global params, opt_state
    for epoch in range(epochs):
        loss_sum = 0
        used_batches = 0
        batches = jax_dataloader(train_dataset, batch_size=32)
        for images, labels, _label_lengths in batches:
            params, opt_state, loss = train_step(params, opt_state, images, labels)
            # train_step reports an unusable (e.g. NaN) loss as 0.0; keep those
            # batches out of the average.
            if loss != 0.0:
                loss_sum += loss.item()
                used_batches += 1

        avg_train_loss = loss_sum / used_batches if used_batches else float("inf")
        val_loss = evaluate_model(val_dataset)

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

# ======= Training Configuration =======
img_dir = "/content/iam_words/words"
label_file = "/content/iam_words/words.txt"

# Parse the label file and verify every image exactly once. Building a fresh
# IAMDataset per split would repeat that full scan only to discard its result.
dataset = IAMDataset(img_dir, label_file)
dataset_size = len(dataset)
split_idx = int(0.8 * dataset_size)

# Shallow copies share img_root/img_size; each split gets its own slice of the
# sample list. The split follows label-file order (no shuffling).
train_dataset = copy.copy(dataset)
val_dataset = copy.copy(dataset)
train_dataset.valid_samples = dataset.valid_samples[:split_idx]
val_dataset.valid_samples = dataset.valid_samples[split_idx:]

# One extra class for the CTC blank at index 0.
num_classes = len(CHARS) + 1
model = CRNN(img_height=32, num_classes=num_classes)
# Dummy input in the (batch, channel, height, width) layout the data loader yields.
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 1, 32, 128)))
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

train_model(train_dataset, val_dataset, epochs=50)

Epoch [1/50], Loss: 16.7275, Val Loss: 13.1648
Epoch [2/50], Loss: 13.7482, Val Loss: 12.4942
Epoch [3/50], Loss: 13.2128, Val Loss: 12.1312
Epoch [4/50], Loss: 12.9496, Val Loss: 12.1484
Epoch [5/50], Loss: 12.7559, Val Loss: 11.8529
Epoch [6/50], Loss: 12.5726, Val Loss: 11.7964
Epoch [7/50], Loss: 12.3963, Val Loss: 11.6960
Epoch [8/50], Loss: 12.2310, Val Loss: 11.2938
Epoch [9/50], Loss: 12.0380, Val Loss: 11.1781
