In [None]:
pip install jax jaxlib flax optax pillow numpy



In [None]:
!unzip -l /content/handwriting_dataset_1.zip


Archive:  /content/handwriting_dataset_1.zip
  Length      Date    Time    Name
---------  ---------- -----   ----
  2097139  2025-02-09 01:08   handwriting_dataset_1.zip/words_new.txt
        0  2025-03-07 01:46   handwriting_dataset_1.zip/iam_words/words/k07/k07-185/
        0  2025-03-07 01:46   handwriting_dataset_1.zip/iam_words/words/k07/k07-125/
        0  2025-03-07 01:46   handwriting_dataset_1.zip/iam_words/words/k03/k03-144/
        0  2025-03-07 01:46   handwriting_dataset_1.zip/iam_words/words/k03/k03-152/
        0  2025-03-07 01:46   handwriting_dataset_1.zip/iam_words/words/k07/k07-152/
        0  2025-03-07 01:46   handwriting_dataset_1.zip/iam_words/words/k03/k03-164/
        0  2025-03-07 01:46   handwriting_dataset_1.zip/iam_words/words/k03/k03-157/
        0  2025-03-07 01:46   handwriting_dataset_1.zip/iam_words/words/k03/k03-138/
        0  2025-03-07 01:46   handwriting_dataset_1.zip/iam_words/words/k07/k07-146/
        0  2025-03-07 01:46   handwriting_dataset_

In [None]:
!unzip /content/handwriting_dataset_1.zip -d /content/handwriting_dataset_1


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
   creating: /content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/d06/d06-076/
   creating: /content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/c04/c04-160/
   creating: /content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/d06/d06-082/
   creating: /content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/c04/c04-075/
   creating: /content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/d06/d06-072/
   creating: /content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/e02/e02-018/
   creating: /content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/g04/g04-007/
   creating: /content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/d06/d06-111/
   creating: /content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/c04/c04-150/
   creating: /content/handwriting_dataset_1/

In [None]:
!pip install --upgrade pip
!pip install --upgrade jax jaxlib flax optax


Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m70.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.0.1


In [None]:
!pip install --upgrade pip
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting nvidia-cuda-nvcc-cu12>=12.6.85 (from jax-cuda12-plugin[with_cuda]<=0.5.2,>=0.5.1; extra == "cuda"->jax[cuda])
  Downloading nvidia_cuda_nvcc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Downloading nvidia_cuda_nvcc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (40.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 MB[0m [31m121.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: nvidia-cuda-nvcc-cu12
  Attempting uninstall: nvidia-cuda-nvcc-cu12
    Found existing installation: nvidia-cuda-nvcc-cu12 12.5.82
    Uninstalling nvidia-cuda-nvcc-cu12-12.5.82:
      Successfully uninstalled nvidia-cuda-nvcc-cu12-12.5.82
Successfully installed nvidia-cuda-nvcc-cu12-12.8.93


In [None]:
import os
import jax
import jax.numpy as jnp
import numpy as np
import optax
import string
import random
import editdistance
from flax import linen as nn
from PIL import Image, UnidentifiedImageError

# ======= Character Set for IAM Dataset =======
CHARS = string.ascii_lowercase + string.digits + ".,;:'\"!?()&- "  # Ensure space is included at the end
char_to_index = {char: i + 1 for i, char in enumerate(CHARS)}  # Start from 1 (0 is blank)
index_to_char = {i: char for char, i in char_to_index.items()}


class IAMDataset:
    def __init__(self, img_root, label_file, img_size=(32, 128)):
        self.img_root = img_root
        self.img_size = img_size
        self.valid_samples = []

        with open(label_file, "r", encoding="utf-8") as f:
            for line in f:
                if line.startswith("#"):
                    continue
                parts = line.strip().split(" ")
                img_name = parts[0]
                text = " ".join(parts[8:]).lower()
                label = [char_to_index[c] for c in text if c in char_to_index]
                folder1, folder2 = img_name.split("-")[:2]
                img_path = os.path.normpath(os.path.join(self.img_root, folder1, f"{folder1}-{folder2}", img_name + ".png"))

                if os.path.exists(img_path) and self._is_valid_image(img_path):
                    self.valid_samples.append((img_path, label))

    def __len__(self):
        return len(self.valid_samples)

    def __getitem__(self, idx):
        img_path, label = self.valid_samples[idx]
        try:
            image = Image.open(img_path).convert("L")
            image = image.resize((128, 32), Image.BILINEAR)
            image = np.array(image, dtype=np.float32) / 255.0
            image = np.expand_dims(image, axis=-1)

            return image, jnp.array(label, dtype=jnp.int32)
        except (UnidentifiedImageError, OSError):
            return None

    def _is_valid_image(self, img_path):
        try:
            with Image.open(img_path) as img:
                img.verify()
            return True
        except (UnidentifiedImageError, OSError):
            return False

def jax_dataloader(dataset, batch_size=32, shuffle=True):
    indices = list(range(len(dataset)))
    if shuffle:
        random.shuffle(indices)

    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        batch_samples = [dataset[i] for i in batch_indices if dataset[i] is not None]

        if len(batch_samples) == 0:
            continue

        images, labels = zip(*batch_samples)
        images = jnp.stack(images)

        max_label_length = max(len(label) for label in labels)
        padded_labels = jnp.array([
            jnp.pad(label, (0, max_label_length - len(label)), constant_values=0)  # Ensure blank token at the end
            for label in labels
        ])

        yield images, padded_labels, jnp.array([len(label) for label in labels])


# ======= CRNN Model =======
class CRNN(nn.Module):
    img_height: int
    num_classes: int
    lstm_hidden_size: int = 512  # Fix 2: Increased LSTM size
    num_lstm_layers: int = 2  # Fix 2: Added second LSTM layer

    def setup(self):
      self.conv1 = nn.Conv(features=64, kernel_size=(3, 3), strides=1, padding='SAME')
      self.conv2 = nn.Conv(features=128, kernel_size=(3, 3), strides=1, padding='SAME')
      self.conv3 = nn.Conv(features=256, kernel_size=(3, 3), strides=1, padding='SAME')  # Added third conv layer
      self.conv4 = nn.Conv(features=256, kernel_size=(3, 3), strides=1, padding='SAME')  # Second 256 layer
      self.conv5 = nn.Conv(features=512, kernel_size=(3, 3), strides=1, padding='SAME')  # Final deeper layer
      self.fc = nn.Dense(features=self.num_classes)

      # Bidirectional LSTMs (for Modification 3)
      self.lstm_fw = [nn.LSTMCell(features=self.lstm_hidden_size) for _ in range(self.num_lstm_layers)]
      self.lstm_bw = [nn.LSTMCell(features=self.lstm_hidden_size) for _ in range(self.num_lstm_layers)]

    @nn.compact
    def __call__(self, x, train=True):  # Add train argument
      x = nn.relu(self.conv1(x))
      x = nn.BatchNorm(use_running_average=not train)(x)  # Pass use_running_average
      x = nn.Dropout(0.3, deterministic=not train)(x)  # Dropout must also use deterministic
      x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

      x = nn.relu(self.conv2(x))
      x = nn.BatchNorm(use_running_average=not train)(x)
      x = nn.Dropout(0.3, deterministic=not train)(x)
      x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

      x = nn.relu(self.conv3(x))
      x = nn.BatchNorm(use_running_average=not train)(x)
      x = nn.Dropout(0.3, deterministic=not train)(x)
      x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

      x = nn.relu(self.conv4(x))
      x = nn.BatchNorm(use_running_average=not train)(x)
      x = nn.Dropout(0.3, deterministic=not train)(x)

      x = nn.relu(self.conv5(x))
      x = nn.BatchNorm(use_running_average=not train)(x)
      x = nn.Dropout(0.3, deterministic=not train)(x)

      b, h, w, c = x.shape
      x = x.reshape(b, w, h * c)

      # Bidirectional LSTMs
      carries_fw = [lstm.initialize_carry(jax.random.PRNGKey(0), (b, self.lstm_hidden_size)) for lstm in self.lstm_fw]
      carries_bw = [lstm.initialize_carry(jax.random.PRNGKey(0), (b, self.lstm_hidden_size)) for lstm in self.lstm_bw]

      outputs = []
      for t in range(x.shape[1]):
          for i, lstm in enumerate(self.lstm_fw):
              carries_fw[i], x_t_fw = lstm(carries_fw[i], x[:, t, :])
          for i, lstm in enumerate(self.lstm_bw):
              carries_bw[i], x_t_bw = lstm(carries_bw[i], x[:, x.shape[1] - t - 1, :])  # Reverse order for BW LSTM
          outputs.append(jnp.concatenate([x_t_fw, x_t_bw], axis=-1))  # Concatenating both directions

      lstm_out = jnp.stack(outputs, axis=1)
      output = self.fc(lstm_out)
      return jax.nn.log_softmax(output, axis=-1)



# ======= Compute Loss Function =======
def loss_fn_with_batch_stats(params, images, labels):
    logits, new_model_state = model.apply(
        {'params': params, 'batch_stats': batch_stats},
        images,
        train=True,
        mutable=['batch_stats'],
        rngs={'dropout': rng}  # Ensure dropout gets a PRNG key
    )
    logit_paddings = jnp.zeros((logits.shape[0], logits.shape[1]), dtype=jnp.int32)
    label_paddings = (labels == 0).astype(jnp.int32)
    loss = optax.ctc_loss(logits, logit_paddings, labels, label_paddings).mean()
    return loss, new_model_state['batch_stats']


# ======= Training Function =======
@jax.jit
def train_step(params, batch_stats, opt_state, images, labels):
    def loss_fn_with_batch_stats(params, images, labels):
        logits, new_model_state = model.apply(
            {'params': params, 'batch_stats': batch_stats},
            images, train=True, mutable=['batch_stats'],
            rngs={'dropout': rng}  # Pass Dropout PRNG
        )
        logit_paddings = jnp.zeros((logits.shape[0], logits.shape[1]), dtype=jnp.int32)
        label_paddings = (labels == 0).astype(jnp.int32)
        loss = optax.ctc_loss(logits, logit_paddings, labels, label_paddings).mean()
        return loss, new_model_state['batch_stats']  # Return updated batch_stats

    (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn_with_batch_stats, has_aux=True)(params, images, labels)
    grads = jax.tree_util.tree_map(lambda g: jnp.clip(g, -1.0, 1.0) if g is not None else 0, grads)

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    return params, new_batch_stats, opt_state, loss


def compute_cer(preds, targets):
    total_chars, total_errors = 0, 0

    for pred, target in zip(preds, targets):
        pred_text = ''.join([index_to_char[i] for i in pred if i in index_to_char])
        target_text = ''.join([index_to_char[i] for i in target if i in index_to_char])

        total_chars += len(target_text)
        total_errors += editdistance.eval(pred_text, target_text)  # Using editdistance

    return total_errors / total_chars if total_chars > 0 else float("inf")

def greedy_decode(preds):
    decoded = []
    prev_char = None
    for i in preds:
        if i == 0:  # Skip blank characters (CTC blank)
            prev_char = None
            continue
        if i != prev_char:  # Avoid repeated characters (CTC rule)
            decoded.append(index_to_char.get(i, "?"))
        prev_char = i
    return "".join(decoded).strip()  # Remove leading/trailing spaces


def evaluate_model(dataset):
    total_loss, batch_count = 0, 0
    all_preds, all_targets = [], []

    for images, labels, label_lengths in jax_dataloader(dataset, batch_size=32, shuffle=False):
        loss, _ = loss_fn_with_batch_stats(params, images, labels)  # Correct loss function call
        total_loss += float(loss)  # Convert to float for accumulation
        batch_count += 1

        logits = model.apply({'params': params, 'batch_stats': batch_stats}, images, train=False, mutable=False)
        preds = jnp.argmax(logits, axis=-1).tolist()

        all_preds.extend(preds)
        all_targets.extend(labels.tolist())

    # Print a few decoded results
    for i in range(min(10, len(all_preds))):  # Avoid index errors if dataset is small
        pred_text = greedy_decode(all_preds[i])  # Use improved CTC decoding
        target_text = "".join([index_to_char.get(c, "?") for c in all_targets[i] if c > 0])  # Ignore padding

        print(f"Decoded Target: {target_text}")
        print(f"Decoded Prediction: {pred_text}")
        print("=" * 40)

    # Compute CER
    cer = compute_cer(all_preds, all_targets)
    avg_loss = total_loss / batch_count if batch_count > 0 else float("inf")

    return avg_loss, cer

# ----------------- In Progress ------------------
# Want to make it get input/print output

def predict_single_image(model, params, batch_stats, image_path):
    # Load and preprocess the image
    try:
        image = Image.open(image_path).convert("L")  # Convert to grayscale
        image = image.resize((128, 32), Image.BILINEAR)  # Resize to match model input
        image = np.array(image, dtype=np.float32) / 255.0  # Normalize pixel values
        image = np.expand_dims(image, axis=-1)  # Add channel dimension
        image = np.expand_dims(image, axis=0)  # Add batch dimension (1, 32, 128, 1)
        
        image_jax = jnp.array(image)  # Convert to JAX array
        
        # Run inference through the model
        logits = model.apply({'params': params, 'batch_stats': batch_stats}, image_jax, train=False, mutable=False)
        preds = jnp.argmax(logits, axis=-1).squeeze().tolist()  # Get predicted indices

        # Decode the prediction
        predicted_text = greedy_decode(preds)

        print(f"Predicted Text: {predicted_text}")
        return predicted_text
    except (UnidentifiedImageError, OSError):
        print("Error loading image.")
        return None



# ======= Training Configuration =======
img_dir = "/content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/"
label_file = "/content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words.txt"

dataset = IAMDataset(img_dir, label_file)
split_idx = int(0.8 * len(dataset))

train_dataset = IAMDataset(img_dir, label_file)
val_dataset = IAMDataset(img_dir, label_file)
train_dataset.valid_samples = dataset.valid_samples[:split_idx]
val_dataset.valid_samples = dataset.valid_samples[split_idx:]

num_classes = len(CHARS) + 1
model = CRNN(img_height=32, num_classes=num_classes, lstm_hidden_size=512)
rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 32, 128, 1))
print("Dummy input shape:", dummy_input.shape)

variables = model.init({'params': rng, 'dropout': jax.random.PRNGKey(1)}, dummy_input, train=True)
params = variables['params']  # Extract model parameters
batch_stats = variables['batch_stats']  # Extract BatchNorm statistics

schedule = optax.exponential_decay(
    init_value=1e-3,  # Lowered learning rate
    transition_steps=500,
    decay_rate=0.95
)

optimizer = optax.adam(schedule)
opt_state = optimizer.init(params)  # Now params is defined


# ======= Train Model Function =======
def train_model(train_dataset, val_dataset, epochs=10):
    global params, batch_stats, opt_state
    for epoch in range(epochs):
        epoch_loss, batch_count = 0, 0

        for images, labels, label_lengths in jax_dataloader(train_dataset, batch_size=32):
            params, batch_stats, opt_state, loss = train_step(params, batch_stats, opt_state, images, labels)
            epoch_loss += float(loss)
            batch_count += 1

        avg_train_loss = epoch_loss / batch_count if batch_count > 0 else float("inf")
        val_loss, cer = evaluate_model(val_dataset)

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, CER: {cer:.4f}")


# ======= Train Model =======
train_model(train_dataset, val_dataset, epochs=40)

Dummy input shape: (1, 32, 128, 1)
Decoded Target: "
Decoded Prediction: .
Decoded Target: my
Decoded Prediction: .,
Decoded Target: background
Decoded Prediction: .
Decoded Target: is
Decoded Prediction: .,.
Decoded Target: a
Decoded Prediction: .
Decoded Target: doctor
Decoded Prediction: .,
Decoded Target: of
Decoded Prediction: .,.
Decoded Target: 68
Decoded Prediction: .,.,
Decoded Target: ,
Decoded Prediction: .
Decoded Target: who
Decoded Prediction: .
Decoded Target: has
Decoded Prediction: .
Decoded Target: practised
Decoded Prediction: .,.
Decoded Target: medicine
Decoded Prediction: .,.
Decoded Target: for
Decoded Prediction: .,.
Decoded Target: 43
Decoded Prediction: .,
Decoded Target: years
Decoded Prediction: .,.
Decoded Target: ,
Decoded Prediction: ..
Decoded Target: chiefly
Decoded Prediction: .,.
Decoded Target: as
Decoded Prediction: ...
Decoded Target: a
Decoded Prediction: .
Epoch [1/50], Loss: 72.3488, Val Loss: 12.3025, CER: 3.4824
Decoded Target: "
Decoded Predi

In [None]:
import os
import jax
import jax.numpy as jnp
import numpy as np
import optax
import string
import random
import editdistance
from flax import linen as nn
from PIL import Image, UnidentifiedImageError

# ======= Character Set for IAM Dataset =======
CHARS = string.ascii_lowercase + string.digits + ".,;:'\"!?()&- "  # Ensure space is included at the end
char_to_index = {char: i + 1 for i, char in enumerate(CHARS)}  # Start from 1 (0 is blank)
index_to_char = {i: char for char, i in char_to_index.items()}


class IAMDataset:
    def __init__(self, img_root, label_file, img_size=(32, 128)):
        self.img_root = img_root
        self.img_size = img_size
        self.valid_samples = []

        with open(label_file, "r", encoding="utf-8") as f:
            for line in f:
                if line.startswith("#"):
                    continue
                parts = line.strip().split(" ")
                img_name = parts[0]
                text = " ".join(parts[8:]).lower()
                label = [char_to_index[c] for c in text if c in char_to_index]
                folder1, folder2 = img_name.split("-")[:2]
                img_path = os.path.normpath(os.path.join(self.img_root, folder1, f"{folder1}-{folder2}", img_name + ".png"))

                if os.path.exists(img_path) and self._is_valid_image(img_path):
                    self.valid_samples.append((img_path, label))

    def __len__(self):
        return len(self.valid_samples)

    def __getitem__(self, idx):
        img_path, label = self.valid_samples[idx]
        try:
            image = Image.open(img_path).convert("L")
            image = image.resize((128, 32), Image.BILINEAR)
            image = np.array(image, dtype=np.float32) / 255.0
            image = np.expand_dims(image, axis=-1)

            return image, jnp.array(label, dtype=jnp.int32)
        except (UnidentifiedImageError, OSError):
            return None

    def _is_valid_image(self, img_path):
        try:
            with Image.open(img_path) as img:
                img.verify()
            return True
        except (UnidentifiedImageError, OSError):
            return False

def jax_dataloader(dataset, batch_size=32, shuffle=True):
    indices = list(range(len(dataset)))
    if shuffle:
        random.shuffle(indices)

    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        batch_samples = [dataset[i] for i in batch_indices if dataset[i] is not None]

        if len(batch_samples) == 0:
            continue

        images, labels = zip(*batch_samples)
        images = jnp.stack(images)

        max_label_length = max(len(label) for label in labels)
        padded_labels = jnp.array([
            jnp.pad(label, (0, max_label_length - len(label)), constant_values=0)  # Ensure blank token at the end
            for label in labels
        ])

        yield images, padded_labels, jnp.array([len(label) for label in labels])


class CRNN(nn.Module):
    img_height: int
    num_classes: int
    lstm_hidden_size: int = 512
    num_lstm_layers: int = 2

    def setup(self):
        self.conv1 = nn.Conv(features=64, kernel_size=(3, 3), strides=1, padding='SAME')
        self.conv2 = nn.Conv(features=128, kernel_size=(3, 3), strides=1, padding='SAME')
        self.conv3 = nn.Conv(features=256, kernel_size=(3, 3), strides=1, padding='SAME')
        self.conv_proj = nn.Conv(features=256, kernel_size=(1, 1), strides=1, padding='SAME')
        self.fc = nn.Dense(features=self.num_classes)

    @nn.compact
    def __call__(self, x, train=True):
        def residual_block(x, conv, conv_proj, train):
            # main path
            y = nn.relu(conv(x))
            y = nn.BatchNorm(use_running_average=not train)(y)
            y = nn.Dropout(0.2, deterministic=not train)(y)
            # skip path (projection to match 256 channels)
            skip = conv_proj(x)
            return skip + y

        # --- Block 1 ---
        x = nn.relu(self.conv1(x))
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Dropout(0.2, deterministic=not train)(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

        # --- Block 2 ---
        x = nn.relu(self.conv2(x))
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Dropout(0.2, deterministic=not train)(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

        # --- Block 3: Single residual block ---
        x = residual_block(x, self.conv3, self.conv_proj, train)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

        # Flatten to (batch_size, width, height*channels)
        b, h, w, c = x.shape
        x_seq = x.reshape(b, w, h * c)

        # --- One Transformer Block (instead of two) ---
        def transformer_block(x):
            y = nn.LayerNorm()(x)
            y = nn.SelfAttention(num_heads=4, qkv_features=x.shape[-1])(y)  # fewer heads
            y = nn.Dropout(0.2, deterministic=not train)(y)
            return x + y

        x_seq = transformer_block(x_seq)

        # Final dense + log_softmax
        logits = self.fc(x_seq)
        return jax.nn.log_softmax(logits, axis=-1)



# ======= Compute Loss Function =======
def loss_fn_with_batch_stats(params, images, labels):
    logits, new_model_state = model.apply(
        {'params': params, 'batch_stats': batch_stats},
        images,
        train=True,
        mutable=['batch_stats'],
        rngs={'dropout': rng}  # Ensure dropout gets a PRNG key
    )
    logit_paddings = jnp.zeros((logits.shape[0], logits.shape[1]), dtype=jnp.int32)
    label_paddings = (labels == 0).astype(jnp.int32)
    loss = optax.ctc_loss(logits, logit_paddings, labels, label_paddings).mean()
    return loss, new_model_state['batch_stats']


# ======= Updated Train Step Function =======
@jax.jit
def train_step(params, batch_stats, opt_state, images, labels):
    def loss_fn_with_batch_stats(params, images, labels):
        logits, new_model_state = model.apply(
            {'params': params, 'batch_stats': batch_stats},
            images,
            train=True,
            mutable=['batch_stats'],
            rngs={'dropout': rng}  # Pass Dropout PRNG
        )
        logit_paddings = jnp.zeros((logits.shape[0], logits.shape[1]), dtype=jnp.int32)
        label_paddings = (labels == 0).astype(jnp.int32)
        loss = optax.ctc_loss(logits, logit_paddings, labels, label_paddings).mean()
        return loss, new_model_state['batch_stats']

    (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn_with_batch_stats, has_aux=True)(params, images, labels)
    # Updated gradient clipping threshold (from ±1.0 to ±0.5)
    grads = jax.tree_util.tree_map(lambda g: jnp.clip(g, -0.5, 0.5) if g is not None else 0, grads)

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    return params, new_batch_stats, opt_state, loss


def compute_cer(preds, targets):
    total_chars, total_errors = 0, 0

    for pred, target in zip(preds, targets):
        pred_text = ''.join([index_to_char[i] for i in pred if i in index_to_char])
        target_text = ''.join([index_to_char[i] for i in target if i in index_to_char])

        total_chars += len(target_text)
        total_errors += editdistance.eval(pred_text, target_text)  # Using editdistance

    return total_errors / total_chars if total_chars > 0 else float("inf")

def greedy_decode(preds):
    decoded = []
    prev_char = None
    for i in preds:
        if i == 0:  # Skip blank characters (CTC blank)
            prev_char = None
            continue
        if i != prev_char:  # Avoid repeated characters (CTC rule)
            decoded.append(index_to_char.get(i, "?"))
        prev_char = i
    return "".join(decoded).strip()  # Remove leading/trailing spaces


def evaluate_model(dataset):
    total_loss, batch_count = 0, 0
    all_preds, all_targets = [], []

    for images, labels, label_lengths in jax_dataloader(dataset, batch_size=32, shuffle=False):
        loss, _ = loss_fn_with_batch_stats(params, images, labels)  # Correct loss function call
        total_loss += float(loss)  # Convert to float for accumulation
        batch_count += 1

        logits = model.apply({'params': params, 'batch_stats': batch_stats}, images, train=False, mutable=False)
        preds = jnp.argmax(logits, axis=-1).tolist()

        all_preds.extend(preds)
        all_targets.extend(labels.tolist())

    # Print a few decoded results
    for i in range(min(10, len(all_preds))):  # Avoid index errors if dataset is small
        pred_text = greedy_decode(all_preds[i])  # Use improved CTC decoding
        target_text = "".join([index_to_char.get(c, "?") for c in all_targets[i] if c > 0])  # Ignore padding

        print(f"Decoded Target: {target_text}")
        print(f"Decoded Prediction: {pred_text}")
        print("=" * 40)

    # Compute CER
    cer = compute_cer(all_preds, all_targets)
    avg_loss = total_loss / batch_count if batch_count > 0 else float("inf")

    return avg_loss, cer


# ======= Training Configuration =======
img_dir = "/content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words/"
label_file = "/content/handwriting_dataset_1/handwriting_dataset_1.zip/iam_words/words.txt"

dataset = IAMDataset(img_dir, label_file)
split_idx = int(0.8 * len(dataset))

train_dataset = IAMDataset(img_dir, label_file)
val_dataset = IAMDataset(img_dir, label_file)
train_dataset.valid_samples = dataset.valid_samples[:split_idx]
val_dataset.valid_samples = dataset.valid_samples[split_idx:]

num_classes = len(CHARS) + 1
model = CRNN(img_height=32, num_classes=num_classes, lstm_hidden_size=512)
rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 32, 128, 1))
print("Dummy input shape:", dummy_input.shape)

variables = model.init({'params': rng, 'dropout': jax.random.PRNGKey(1)}, dummy_input, train=True)
params = variables['params']  # Extract model parameters
batch_stats = variables['batch_stats']  # Extract BatchNorm statistics

schedule = optax.exponential_decay(
    init_value=5e-4,  # Start lower
    transition_steps=500,
    decay_rate=0.95
)

optimizer = optax.adam(schedule)
opt_state = optimizer.init(params)  # Now params is defined


# ======= Train Model Function =======
def train_model(train_dataset, val_dataset, epochs=10):
    global params, batch_stats, opt_state
    for epoch in range(epochs):
        epoch_loss, batch_count = 0, 0

        for images, labels, label_lengths in jax_dataloader(train_dataset, batch_size=32):
            params, batch_stats, opt_state, loss = train_step(params, batch_stats, opt_state, images, labels)
            epoch_loss += float(loss)
            batch_count += 1

        avg_train_loss = epoch_loss / batch_count if batch_count > 0 else float("inf")
        val_loss, cer = evaluate_model(val_dataset)

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, CER: {cer:.4f}")


# ======= Train Model =======
train_model(train_dataset, val_dataset, epochs=40)

Dummy input shape: (1, 32, 128, 1)
Decoded Target: "
Decoded Prediction: 
Decoded Target: my
Decoded Prediction: 
Decoded Target: background
Decoded Prediction: 
Decoded Target: is
Decoded Prediction: 
Decoded Target: a
Decoded Prediction: 
Decoded Target: doctor
Decoded Prediction: 
Decoded Target: of
Decoded Prediction: 
Decoded Target: 68
Decoded Prediction: 
Decoded Target: ,
Decoded Prediction: 
Decoded Target: who
Decoded Prediction: 
Epoch [1/40], Loss: 75.2767, Val Loss: 13.9033, CER: 1.0022
Decoded Target: "
Decoded Prediction: "
Decoded Target: my
Decoded Prediction: 
Decoded Target: background
Decoded Prediction: ,
Decoded Target: is
Decoded Prediction: 
Decoded Target: a
Decoded Prediction: 
Decoded Target: doctor
Decoded Prediction: 
Decoded Target: of
Decoded Prediction: 
Decoded Target: 68
Decoded Prediction: 
Decoded Target: ,
Decoded Prediction: 
Decoded Target: who
Decoded Prediction: 
Epoch [2/40], Loss: 71.7740, Val Loss: 13.3718, CER: 0.9975
Decoded Target: "
Decod