# Deep Dive: CRNN OCR Pipeline

This notebook provides a detailed walkthrough of the CRNN OCR pipeline, from data preparation to inference. We visualize intermediate outputs, including the output of the final layer before the CTC loss, and explain the CTC loss calculation in detail.

**Core Steps:**
1.  **Data Preparation**: Image preprocessing and label encoding.
2.  **Model Forward**: CNN Features -> Reshape -> RNN -> Logits.
3.  **Visualization**: Heatmap of the output Logits (Time vs Class).
4.  **CTC Loss**: Understanding the inputs and mechanism.
5.  **Inference**: Greedy Decoding.

## 1. Setup and Imports

In [None]:
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import sys

# Add src to path to import modules
sys.path.append('src')

from dataset import OCRDataset, get_vocab
from model import CRNN
from utils import decode_greedy

# Configuration
DATA_ROOT = 'data'
CHECKPOINT_PATH = 'checkpoints/best_model.pth'
IMG_HEIGHT = 32
IMG_WIDTH = 100 # Consistent with training
DEVICE = torch.device('cpu') # Use CPU for easier debugging/visualization

## 2. Data Preparation (Preprocessing & Encoding)

**Goal**: Prepare the raw image and label for the network.

-   **Image**: Resize to height 32 × width 100, convert to grayscale (1 channel), and normalize pixel values to [-1, 1].
-   **Label**: Convert characters to integers using a Vocabulary Mapping.

The `OCRDataset` class handles this via `__getitem__`.
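A minimal sketch of the two transformations, assuming the conventions the rest of this notebook relies on: class index 0 is reserved for the CTC blank, characters map to 1..len(vocab), and 8-bit pixels are scaled linearly to [-1, 1]. `build_maps`, `encode_label`, and `normalize_pixel` are hypothetical helpers for illustration, not the actual `OCRDataset` internals.

```python
def build_maps(vocab):
    """Map characters to indices 1..len(vocab); index 0 is left for the CTC blank."""
    char2idx = {ch: i + 1 for i, ch in enumerate(vocab)}
    idx2char = {i: ch for ch, i in char2idx.items()}
    return char2idx, idx2char


def encode_label(text, char2idx):
    """Turn a label string into a list of class indices (never 0)."""
    return [char2idx[ch] for ch in text]


def normalize_pixel(value):
    """Scale an 8-bit grayscale value from [0, 255] to [-1, 1]."""
    return value / 127.5 - 1.0


# Hypothetical vocabulary, for illustration only.
vocab = "abcdefghijklmnopqrstuvwxyz0123456789"
char2idx, idx2char = build_maps(vocab)
encoded = encode_label("cat", char2idx)
print(encoded)                                   # [3, 1, 20]
print("".join(idx2char[i] for i in encoded))     # cat
print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```

Reserving 0 for the blank is what lets `n_class = len(vocab) + 1` and `nn.CTCLoss(blank=0)` line up later in the notebook.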

In [None]:
# 1. Get Vocabulary
vocab = get_vocab(os.path.join(DATA_ROOT, 'trainset'), os.path.join(DATA_ROOT, 'testset'))
print(f"Vocabulary ({len(vocab)}): {vocab}")

# 2. Initialize Dataset
train_ds = OCRDataset(os.path.join(DATA_ROOT, 'trainset'), vocab, height=IMG_HEIGHT, width=IMG_WIDTH)
idx2char = train_ds.idx2char

# 3. Get a sample
idx = 0  # Change this to see different samples
image_tensor, label_tensor, label_len = train_ds[idx]

print(f"\nSample Index: {idx}")
print(f"Image Shape: {image_tensor.shape} (Channel, Height, Width)")
print(f"Label Tensor: {label_tensor}")
print(f"Label Length: {label_len.item()}")
print(f"Decoded label string: {''.join([idx2char[i.item()] for i in label_tensor])}")

# Visualize
plt.imshow(image_tensor.squeeze(), cmap='gray')
plt.title(f"Preprocessed Image Input (Label: {''.join([idx2char[i.item()] for i in label_tensor])})")
plt.axis('on')
plt.show()

## 3. Model Forward Pass & Visualization

We load the trained model and pass this single image through it. Instead of calling the full forward pass, we run each major block separately and inspect the tensor shapes between them.

### Pipeline Stages:
1.  **CNN Backbone**: Extracts visual features. Input: `(B, 1, 32, 100)` -> Output: `(B, 512, 1, 26)`.
    *   *Note*: The height collapses to 1 (the vertical information is summarized in the 512 channels), while the width shrinks from 100 to 26 but keeps its left-to-right order.
    *   The width 26 corresponds to the number of "Timesteps" (T).
2.  **Reshape**: Prepare for RNN sequence. `(B, 512, 1, 26)` -> `(26, B, 512)` (Seq, Batch, Feature).
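3.  **RNN Head**: Bidirectional sequence modeling that adds left and right context to each timestep. Output: `(26, B, Hidden*2)`.
4.  **Linear Projection**: Map each timestep's features to class scores. Output: `(26, B, NumClasses)`. In this `CRNN` implementation both stages live inside `model.rnn`, so the code below gets the logits from that single call.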
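Where the 26 comes from: the sketch below traces height and width through the backbone using PyTorch's output-size formula `floor((size + 2*pad - kernel) / stride) + 1`. It assumes `model.cnn` follows the VGG-style layout of the widely used crnn.pytorch implementation (3x3 convolutions with padding 1, two 2x2 max-pools, two max-pools with stride (2, 1) and width padding 1, and a final 2x2 convolution without padding). If `src/model.py` differs, the layer list has to be adjusted.

```python
def out_size(size, kernel, stride, pad):
    """Output length of a conv/pool layer along one axis (PyTorch floor formula)."""
    return (size + 2 * pad - kernel) // stride + 1


def crnn_feature_size(height, width):
    """Trace (H, W) through the assumed CRNN backbone.

    3x3 convs with padding 1 keep H and W unchanged, so only the pooling
    layers and the final 2x2 conv are listed.
    Each entry: (kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w).
    """
    layers = [
        (2, 2, 2, 2, 0, 0),  # MaxPool2d(2, 2)
        (2, 2, 2, 2, 0, 0),  # MaxPool2d(2, 2)
        (2, 2, 2, 1, 0, 1),  # MaxPool2d((2, 2), (2, 1), (0, 1))
        (2, 2, 2, 1, 0, 1),  # MaxPool2d((2, 2), (2, 1), (0, 1))
        (2, 2, 1, 1, 0, 0),  # final Conv2d with kernel 2, padding 0
    ]
    h, w = height, width
    for kh, kw, sh, sw, ph, pw in layers:
        h = out_size(h, kh, sh, ph)
        w = out_size(w, kw, sw, pw)
    return h, w


# (32, 100) -> (16, 50) -> (8, 25) -> (4, 26) -> (2, 27) -> (1, 26)
print(crnn_feature_size(32, 100))  # (1, 26)
```

The stride-1 width pooling is the design choice that keeps enough horizontal resolution for T = 26 timesteps while the height is driven down to 1.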

In [None]:
# Load Model
n_class = len(vocab) + 1 # +1 for blank
model = CRNN(IMG_HEIGHT, 1, n_class, 256).to(DEVICE)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
model.eval()

# Prepare Input Batch (Batch Size = 1)
input_batch = image_tensor.unsqueeze(0).to(DEVICE) # [1, 1, 32, 100]

print("---- Forward Pass Tracking ----")
print(f"1. Input: {input_batch.shape}")

# A. CNN Pass
with torch.no_grad():
    cnn_features = model.cnn(input_batch)
print(f"2. CNN Output features: {cnn_features.shape} (Batch, Channel, Height, Width)")

# B. Pre-RNN Reshape
b, c, h, w = cnn_features.size()
assert h == 1, "Height must be 1"
reshaped_features = cnn_features.squeeze(2) # [B, C, W]
reshaped_features = reshaped_features.permute(2, 0, 1) # [W, B, C] -> [Time, Batch, Feature]
print(f"3. Input to RNN (Permuted): {reshaped_features.shape} (Time, Batch, Feature)")

# C. RNN Pass
with torch.no_grad():
    rnn_output = model.rnn(reshaped_features)
print(f"4. Model Output (Logits): {rnn_output.shape} (Time, Batch, NumClasses)")

## 4. Visualizing the Output (Features/Logits)

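This plot shows **the output of the final layer before the CTC loss**, i.e. the per-timestep class scores that the loss consumes.

The raw output is a matrix of shape `[Time (26), Class (39)]`, where the class count is `n_class = len(vocab) + 1`. For plotting we transpose it, so each column is a timestep and each row is a class; each cell holds the softmax probability of that class at that timestep.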

**How to read the plot:**
-   **Y-axis**: The characters (classes). Index 0 is the CTC blank (`<BLANK>`).
-   **X-axis**: Time steps (0 to 25).
-   **Color**: Probability (Softmax) of that character at that time.
-   **Bright Spots**: The model is confident that this character is present at roughly that horizontal position in the image.
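The heatmap normalizes each timestep independently: `softmax(dim=2)` runs over the class axis, so every column sums to 1 and a bright cell means that class wins at that timestep relative to the other classes. A minimal pure-Python sketch of that per-timestep softmax on made-up logits (the numbers are illustrative, not model output):

```python
import math


def softmax(logits):
    """Numerically stable softmax over one timestep's class scores."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy logits: 3 timesteps x 4 classes (index 0 = blank). Values are made up.
logits = [
    [4.0, 0.5, 0.1, 0.2],        # blank dominates
    [0.3, 3.5, 0.2, 0.1],        # class 1 dominates
    [1000.0, 999.0, 0.0, 0.0],   # huge scores: the max-shift keeps this finite
]
columns = [softmax(step) for step in logits]
for t, col in enumerate(columns):
    print(t, [round(p, 3) for p in col], "sum =", round(sum(col), 6))
```

Because the normalization is per column, the heatmap cannot be read across timesteps as "how much of character X is in the image"; it only says which class is most likely at each step.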

In [None]:
# Apply Softmax to convert Logits to Probabilities for visualizing
probs = rnn_output.softmax(dim=2).squeeze(1).numpy() # [Time, Classes]
probs = probs.T # Transpose for plotting: [Classes, Time]

# Setup Plot
plt.figure(figsize=(15, 8))
plt.imshow(probs, aspect='auto', cmap='viridis', interpolation='nearest')

# Annotate Y-axis with characters
tick_locs = np.arange(n_class)
tick_labels = ['<BLANK>'] + [idx2char[i] for i in range(1, n_class)]
plt.yticks(tick_locs, tick_labels, fontsize=8)

plt.xlabel("Timesteps (Output Width)")
plt.ylabel("Classes (Characters)")
plt.title(f"Probability Heatmap (Model Output)\nTrue Label: {''.join([idx2char[i.item()] for i in label_tensor])}")
plt.colorbar(label="Probability")
plt.show()

## 5. CTC Loss Calculation (Deep Dive)

**The Problem**: Our output has 26 timesteps, but a label such as "cat" has only 3 characters. We don't know *exactly* which timesteps align to 'c', 'a', or 't'.

**CTC Solution**: Sum the probabilities of **ALL valid paths** that decode to "cat".

**Valid Paths**: A path is one class choice per timestep. It maps to a label by:
1.  Collapsing repeated characters: `cc` -> `c`
2.  Removing blanks: `_` -> (nothing)

Example paths for "cat" (shown with 5 timesteps for brevity):
-   `c_a_t`
-   `cc_at`
-   `_caat`
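The collapse mapping can be checked directly. A minimal sketch, using `_` as the blank symbol and brute-force enumeration that is only feasible for tiny T (`collapse` and `valid_paths` are illustrative helpers, not part of `src/`). It also shows why CTC needs the blank at all: a genuinely doubled letter, as in "hello", survives collapsing only when a blank separates the two runs.

```python
from itertools import product

BLANK = "_"


def collapse(path):
    """CTC mapping: merge runs of repeated symbols, then drop blanks."""
    merged = [s for i, s in enumerate(path) if i == 0 or s != path[i - 1]]
    return "".join(s for s in merged if s != BLANK)


def valid_paths(label, alphabet, T):
    """Brute-force every length-T path that collapses to `label` (tiny T only)."""
    return ["".join(p) for p in product(alphabet, repeat=T) if collapse(p) == label]


print(collapse("c_a_t"), collapse("cc_at"), collapse("_caat"))  # cat cat cat
print(valid_paths("ab", [BLANK, "a", "b"], 3))  # ['_ab', 'a_b', 'aab', 'ab_', 'abb']
print(collapse("hello"), collapse("hel_lo"))    # helo hello
```

Even this toy case (2 characters, 3 timesteps) has 5 valid paths; with T = 26 the count explodes, which is why the sum is computed with dynamic programming rather than enumeration.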

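**The Formula**:

$$
\mathcal{L}_{\text{CTC}} = -\log P(\mathbf{y} \mid \mathbf{x}), \qquad
P(\mathbf{y} \mid \mathbf{x}) = \sum_{\pi \in \mathcal{B}^{-1}(\mathbf{y})} \prod_{t=1}^{T} p_t(\pi_t \mid \mathbf{x})
$$

where $\mathbf{y}$ is the label, $\pi$ ranges over all length-$T$ paths that the collapse mapping $\mathcal{B}$ turns into $\mathbf{y}$, and $p_t(\pi_t \mid \mathbf{x})$ is the softmax probability of symbol $\pi_t$ at timestep $t$.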
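Summing over paths explicitly is exponential in T, so CTC implementations use a forward (alpha) dynamic program over the blank-extended label `_ l1 _ l2 ... _`. The sketch below, in plain Python with made-up probabilities, computes P(label | input) both ways and checks that they agree. `brute_force_prob` and `forward_prob` are illustrative helpers, not PyTorch internals; PyTorch runs the same recursion in log space for numerical stability.

```python
import math
from itertools import product

BLANK = 0


def collapse(path):
    """CTC mapping on index paths: merge repeats, then drop blanks."""
    merged = [s for i, s in enumerate(path) if i == 0 or s != path[i - 1]]
    return [s for s in merged if s != BLANK]


def brute_force_prob(probs, label):
    """P(label | x) as a literal sum over all valid paths (exponential in T)."""
    T, C = len(probs), len(probs[0])
    total = 0.0
    for path in product(range(C), repeat=T):
        if collapse(path) == label:
            total += math.prod(probs[t][s] for t, s in enumerate(path))
    return total


def forward_prob(probs, label):
    """Same quantity via the CTC forward recursion, O(T * len(label))."""
    ext = [BLANK]
    for c in label:
        ext += [c, BLANK]  # extended label: _ l1 _ l2 ... _
    S = len(ext)
    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]  # start with a blank...
    if S > 1:
        alpha[1] = probs[0][ext[1]]  # ...or with the first character
    for t in range(1, len(probs)):
        new = [0.0] * S
        for s in range(S):
            a = alpha[s] + (alpha[s - 1] if s >= 1 else 0.0)
            # Skipping a blank is allowed only between two different characters.
            if s >= 2 and ext[s] != BLANK and ext[s] != ext[s - 2]:
                a += alpha[s - 2]
            new[s] = a * probs[t][ext[s]]
        alpha = new
    # A valid path ends on the last character or the trailing blank.
    return alpha[-1] + (alpha[-2] if S > 1 else 0.0)


# Made-up distributions: T = 4 steps, classes {0: blank, 1: 'a', 2: 'b'}.
probs = [
    [0.5, 0.3, 0.2],
    [0.2, 0.6, 0.2],
    [0.3, 0.1, 0.6],
    [0.6, 0.1, 0.3],
]
label = [1, 2]  # "ab"
p_brute = brute_force_prob(probs, label)
p_fwd = forward_prob(probs, label)
print(p_brute, p_fwd, "loss =", -math.log(p_fwd))
```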

**Inputs to PyTorch `ctc_loss`**:
1.  `log_probs`: `(Time, Batch, Classes)`, log-softmax values.
2.  `targets`: the labels of the whole batch concatenated into one flat 1-D tensor (a padded `(Batch, MaxLabelLen)` tensor is also accepted).
3.  `input_lengths`: shape `(Batch,)`, the number of valid timesteps for each sample (26 for every sample here).
4.  `target_lengths`: shape `(Batch,)`, the length of each label (e.g., 3 for "cat").
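For a batch, the per-sample labels are concatenated rather than padded. A pure-Python sketch of building `targets`, `target_lengths`, and `input_lengths` for two hypothetical labels, assuming the 1-based character indexing used in this notebook (`pack_targets` and `min_frames` are illustrative names). `min_frames` shows the feasibility constraint: CTC needs at least one timestep per character plus one blank between each pair of adjacent repeated characters. If an input length is smaller, no valid path exists and the loss is infinite, which is what `zero_infinity=True` guards against.

```python
def pack_targets(labels, char2idx):
    """Concatenate encoded labels into one flat list plus per-sample lengths.

    This matches the 1-D `targets` / `target_lengths` layout nn.CTCLoss accepts;
    wrap both in torch.tensor(..., dtype=torch.long) before calling the loss.
    """
    targets, lengths = [], []
    for text in labels:
        encoded = [char2idx[ch] for ch in text]
        targets.extend(encoded)
        lengths.append(len(encoded))
    return targets, lengths


def min_frames(encoded):
    """Fewest timesteps that can emit this label: one per char + a blank per repeat."""
    repeats = sum(1 for a, b in zip(encoded, encoded[1:]) if a == b)
    return len(encoded) + repeats


char2idx = {ch: i + 1 for i, ch in enumerate("abcdefghijklmnopqrstuvwxyz")}
targets, target_lengths = pack_targets(["cat", "dog"], char2idx)
input_lengths = [26] * len(target_lengths)  # every sample uses all T = 26 steps
print(targets)          # [3, 1, 20, 4, 15, 7]
print(target_lengths)   # [3, 3]
print(input_lengths)    # [26, 26]
print(min_frames([char2idx[c] for c in "hello"]))  # 6
```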

In [None]:
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

# 1. Prepare Log Probs (Required by PyTorch CTC)
log_probs = rnn_output.log_softmax(2) # [Time, Batch, Class]
print(f"Log Probs Shape: {log_probs.shape}")

# 2. Prepare Targets
targets = label_tensor # [Label Len]
print(f"Targets: {targets}")

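# 3. Input Lengths (the 'T' dimension), shape (Batch,)
T = log_probs.size(0)
input_lengths = torch.tensor([T], dtype=torch.long)  # tensor([26])
print(f"Input Lengths: {input_lengths}")

# 4. Target Lengths, shape (Batch,); reshape in case the dataset returns a 0-d tensor
target_lengths = label_len.reshape(1).long()
print(f"Target Lengths: {target_lengths}")

# Calculate Loss. With the default reduction='mean', PyTorch divides each
# sample's -log P(label | input) by its target length before averaging.
loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(f"\nCalculated CTC Loss (per label character): {loss.item():.4f}")
print(f"-log P(label | input): {(loss * target_lengths.float()).item():.4f}")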

## 6. Inference & Greedy Decoding

How do we turn the heatmap back into text?

**Greedy Decoding**:
1.  Take the index with the maximum probability at each timestep (`argmax`).
2.  This gives a "raw path".
3.  **CTC Collapse Rule**: Merge adjacent duplicates first, THEN remove blanks.

*Example*:
-   Raw Path Indices: `[0, 3, 3, 0, 1, 1, 0, 20]` (assuming 0=Blank, 3=c, 1=a, 20=t)
-   Raw String: `_cc_aa_t`
-   Collapse Repeats (Group): `_`, `c`, `_`, `a`, `_`, `t`
-   Remove Blanks: `c`, `a`, `t` -> "cat"
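A pure-Python sketch of the same two steps (argmax per timestep, then collapse) that reproduces the worked example above. It mirrors what `utils.decode_greedy` is described as doing, but `argmax_path` and `greedy_collapse` are illustrative stand-ins, not that function's actual code.

```python
def argmax_path(probs):
    """Index of the most probable class at each timestep (probs is [T][C])."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


def greedy_collapse(path, idx2char, blank=0):
    """Merge adjacent duplicates, then drop blanks (CTC best-path decoding)."""
    chars = []
    prev = None
    for class_idx in path:
        if class_idx != prev and class_idx != blank:
            chars.append(idx2char[class_idx])
        prev = class_idx  # blanks also update prev, so they split repeated runs
    return "".join(chars)


# The worked example: 0 = blank, 3 = 'c', 1 = 'a', 20 = 't'.
idx2char = {1: "a", 3: "c", 20: "t"}
print(greedy_collapse([0, 3, 3, 0, 1, 1, 0, 20], idx2char))  # cat
# A blank between two runs keeps a genuine double letter.
print(greedy_collapse([1, 1, 0, 1], idx2char))                # aa

# End to end on a toy [T=4][C=3] probability matrix (made-up values).
toy_probs = [[0.1, 0.8, 0.1], [0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.1, 0.1, 0.8]]
toy_chars = {1: "a", 2: "b"}
print(greedy_collapse(argmax_path(toy_probs), toy_chars))     # ab
```

Greedy decoding picks the single most likely path, which is not always the most likely label (several paths can add up to a better one); beam search trades speed for that extra accuracy.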

In [None]:
# 1. Argmax path
preds_indices = rnn_output.argmax(2).squeeze(1).tolist() # [Time]
print(f"Raw Path Indices (Best per timestep): {preds_indices}")

# 2. Convert to Chars (for visualization)
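# Note: index 0 is the <BLANK> class. A distinct loop variable is used so the
# sample index `idx` from Section 2 is not overwritten.
raw_chars = []
for class_idx in preds_indices:
    if class_idx == 0:
        raw_chars.append('-')  # Using '-' for blank visualization
    else:
        raw_chars.append(idx2char[class_idx])
print(f"Raw Path String: {''.join(raw_chars)}")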

# 3. Decoding Logic (Included in utils.decode_greedy)
decoded_text = decode_greedy(rnn_output, idx2char)[0]
print(f"\nFinal Decoded Text: {decoded_text}")
gt_text = ''.join([idx2char[i.item()] for i in label_tensor])
print(f"Ground Truth: {gt_text}")

if decoded_text == gt_text:
    print("\nSUCCESS: Prediction matches Ground Truth!")
else:
    print("\nMISMATCH: Prediction incorrect.")