In [1]:
import os
from PIL import Image, ImageFilter
import numpy as np
import random

# Binarization
def binarize_image(image):
    img = image.convert('L')  # Convert image to grayscale
    # Binarize using a threshold
    threshold = 128
    binary_img = img.point(lambda p: 255 if p > threshold else 0)
    return binary_img

# Deskewing
def deskew_image(image):
    # Convert to NumPy array
    img_array = np.array(image)

    # Find the coordinates of the black (non-background) pixels
    coords = np.column_stack(np.where(img_array == 0))
    if len(coords) == 0:
        return image  # If there are no black pixels, return the original image

    # Calculate the angle of skew
    angle = np.rad2deg(np.arctan2(coords[:, 1].mean() - img_array.shape[0] / 2, coords[:, 0].mean() - img_array.shape[1] / 2))

    if angle < -45:
        angle = -(90 + angle)
    else:
        angle = -angle

    # Rotate the image
    pil_image = image.rotate(angle, expand=True)  # Use PIL to rotate
    return pil_image

# Noise removal
def remove_noise(image):
    # Apply a Gaussian blur to reduce noise
    denoised_image = image.filter(ImageFilter.GaussianBlur(radius=1))
    return denoised_image

# Preprocess an image
def preprocess_image(img):
    if isinstance(img, str):  # If the input is a file path, open the image
        img = Image.open(img)

    # Binarize the image (assuming this function is defined)
    binary_img = binarize_image(img)

    # Deskew the image (assuming this function is defined)
    deskewed_img = deskew_image(binary_img)

    # Resize the image to a fixed size (for example 1024x1024)
    resized_img = deskewed_img.resize((1024, 1024))

    return resized_img


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
import random

# Load and preprocess data
def load_washington_data(images_folder, transcription_file):
    images = []
    transcriptions = []

    # Load images as PIL Image objects
    for filename in os.listdir(images_folder):
        if filename.endswith('.png'):
            img_path = os.path.join(images_folder, filename)
            img = Image.open(img_path)  # Open image using PIL
            images.append(img)

    # Load transcriptions
    with open(transcription_file, 'r') as f:
        transcriptions = [line.strip() for line in f.readlines()]

    return images, transcriptions

import random
from PIL import Image, ImageEnhance
from torchvision import transforms

# Define the AdvancedAugmentation class
class AdvancedAugmentation:
    def __init__(self):
        self.transforms = transforms.Compose([
            transforms.RandomRotation(15),  # Rotate images randomly within a 15-degree range
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # Randomly change brightness, contrast, saturation, and hue
            transforms.RandomAffine(translate=(0.1, 0.1), degrees=10, shear=10)  # Random affine transformations
        ])

    def __call__(self, img):
        return self.transforms(img)

# Function to augment images
def augment_images_with_labels(images, transcriptions):
    augmented_images = []
    augmented_transcriptions = []

    # Initialize the advanced augmentation
    augmentation = AdvancedAugmentation()

    for img, transcription in zip(images, transcriptions):
        augmented = augmentation(img)  # Apply the advanced augmentation to the image
        augmented_images.append(augmented)
        augmented_transcriptions.append(transcription)  # Keep the original transcription for augmented image

    return augmented_images, augmented_transcriptions



# Split the data into train, validation, and test sets
def split_data(images, transcriptions, train_ratio=0.7, val_ratio=0.15):
    total_size = len(images)
    indices = list(range(total_size))
    random.shuffle(indices)

    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    train_images = [images[i] for i in train_indices]
    train_transcriptions = [transcriptions[i] for i in train_indices]

    val_images = [images[i] for i in val_indices]
    val_transcriptions = [transcriptions[i] for i in val_indices]

    test_images = [images[i] for i in test_indices]
    test_transcriptions = [transcriptions[i] for i in test_indices]

    return (train_images, train_transcriptions), (val_images, val_transcriptions), (test_images, test_transcriptions)


In [4]:
!pip install pytesseract

Collecting pytesseract
  Downloading pytesseract-0.3.13-py3-none-any.whl.metadata (11 kB)
Downloading pytesseract-0.3.13-py3-none-any.whl (14 kB)
Installing collected packages: pytesseract
Successfully installed pytesseract-0.3.13


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

def modify_resnet_for_grayscale(resnet_model):
    # Modify the first convolutional layer of ResNet-18 to accept 1 channel
    resnet_model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    return resnet_model

class OCRModel(nn.Module):
    def __init__(self, num_classes, hidden_size=128, num_layers=2, dropout_rate=0.5):
        super(OCRModel, self).__init__()

        # Load ResNet-18 and modify it for grayscale input
        resnet = models.resnet18(pretrained=True)
        resnet = modify_resnet_for_grayscale(resnet)  # Modify ResNet for grayscale

        # Freeze the parameters in the ResNet layers except the final layer
        for name, param in resnet.named_parameters():
            if 'conv1' not in name:
                param.requires_grad = False

        # Remove the fully connected layer from ResNet
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-2])

        self.combining_conv = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.batch_norm = nn.BatchNorm2d(512)  # Add BatchNorm
        self.relu = nn.ReLU()
        self.pooling = nn.AdaptiveAvgPool2d((32, 32))  # Further reduces dimensions

        # Dropout after feature extraction and convolution
        self.feature_dropout = nn.Dropout2d(dropout_rate)  # Dropout for 2D features

        # Sequence modeling using LSTM
        self.rnn = nn.LSTM(input_size=512, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout_rate)

        # Dropout after LSTM
        self.rnn_dropout = nn.Dropout(dropout_rate)

        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Extract features using ResNet
        features = self.feature_extractor(x)
        features = F.normalize(features, p=2, dim=1)  # Normalize features

        features = F.interpolate(features, size=(128, 128), mode="bilinear", align_corners=False)
        features = self.combining_conv(features)
        features = self.batch_norm(features)  # Batch normalization
        features = self.relu(features)
        features = self.pooling(features)  # Reduce spatial dimensions

        # Apply dropout after the feature extraction
        features = self.feature_dropout(features)

        # Reshape features for LSTM input
        B, C, H, W = features.size()
        features = features.view(B, H * W, C)

        # Sequence modeling with LSTM
        rnn_out, _ = self.rnn(features)

        # Apply dropout after LSTM
        rnn_out = self.rnn_dropout(rnn_out)

        # Classify with a fully connected layer
        logits = self.fc(rnn_out)

        return logits

# Example usage:
num_classes = 97  # Adjust to your specific number of classes
model = OCRModel(num_classes=num_classes, dropout_rate=0.5)  # Set dropout rate as needed

# Example input (batch_size, channels, height, width)
input_tensor = torch.randn(16, 1, 128, 128)  # Batch size of 16, grayscale images of 128x128

# Forward pass
output = model(input_tensor)
print(output.shape)  # Should output (batch_size, sequence_length, num_classes)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 167MB/s]


torch.Size([16, 1024, 97])


In [6]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from sklearn.model_selection import train_test_split

# Main workflow
images_folder = '/content/drive/My Drive/washingtondb-v1.0/data/line_images_normalized'
transcription_file = '/content/drive/My Drive/washingtondb-v1.0/ground_truth/transcription.txt'

# Load data
images, transcriptions = load_washington_data(images_folder, transcription_file)

# Preprocess images
preprocessed_images = [preprocess_image(img) for img in images]

# Augment data
augmented_images, augmented_transcriptions = augment_images_with_labels(preprocessed_images, transcriptions)
all_images = preprocessed_images + augmented_images
all_transcriptions = transcriptions * 2  # Duplicate transcriptions for augmented data

# Split data
(train_images, train_transcriptions), (val_images, val_transcriptions), (test_images, test_transcriptions) = split_data(all_images, all_transcriptions)

In [7]:
import torch
import numpy as np
import cv2
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class CroppedImagesDataset(Dataset):
    def __init__(self, images, targets, transform=None):
        """
        Args:
            images (list): List of file paths or PIL Image objects.
            targets (list): List of transcription strings (labels).
            transform (callable, optional): Optional transforms to be applied to images.
        """
        self.images = images
        self.targets = targets
        self.transform = transform or transforms.Compose([
            transforms.Resize((128, 128)),  # Resize to 128x128
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # If the image is already a PIL Image, skip opening it
        image = self.images[idx]
        if not isinstance(image, Image.Image):
            image = Image.open(image).convert("L")  # Open as grayscale if not already PIL

        # Apply the transform
        if self.transform:
            image = self.transform(image)

        target = self.targets[idx]
        return image, target


# Step 1: Define a custom collate_fn
def collate_fn(batch):
    images, targets = zip(*batch)

    # Stack the images to create a batch (torch.stack assumes all images are the same shape)
    images_stack = torch.stack(images)

    # Use the encode_target function to encode the target strings
    encoded_targets, target_lengths = encode_target(targets)

    # Stack the targets (now they're tensors of indices)
    targets_stack = torch.nn.utils.rnn.pad_sequence(encoded_targets, batch_first=True, padding_value=char_to_index['<blank>'])

    return images_stack, targets_stack


# Assuming segmented_images and targets are already defined
train_dataset = CroppedImagesDataset(train_images, train_transcriptions)
val_dataset = CroppedImagesDataset(val_images, val_transcriptions)
test_dataset = CroppedImagesDataset(test_images, test_transcriptions)

# Create a DataLoader for batching
from torch.utils.data import DataLoader

batch_size = 16  # Adjust based on your system's capacity


data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Function to display a batch of images and their labels

# Create DataLoaders



In [8]:
import string

# Generating all printable ASCII characters (from 32 to 126 inclusive)
ascii_chars = [chr(i) for i in range(32, 127)]

# Add non-printable but commonly used ASCII characters (e.g., newline '\n')
special_chars = ['\n']

# Combine special and printable characters into the vocabulary
vocab = ['<blank>'] + ascii_chars + special_chars

# Create mapping dictionaries
char_to_index = {char: idx for idx, char in enumerate(vocab)}
index_to_char = {idx: char for idx, char in enumerate(vocab)}

# Print the index of <blank> token and the size of the vocabulary
print(f"Index of <blank>: {char_to_index['<blank>']}")
print(f"Size of vocabulary: {len(vocab)}")






Index of <blank>: 0
Size of vocabulary: 97


In [9]:
def encode_target(target_texts):
    encoded_targets = []  # List to hold encoded target texts (as tensors)
    target_lengths = []   # List to hold lengths of each encoded target text
    unknown_chars = []    # List to hold characters that are not found in the vocabulary

    for text in target_texts:
        text = str(text)  # Ensure text is a string
        encoded_text = []  # List to hold indices for the current text

        for char in text:
            if char in char_to_index:
                # Directly map the character to its index if it exists in vocab
                encoded_text.append(char_to_index[char])
            else:
                # If the character is not found, use <blank> token
                encoded_text.append(char_to_index['<blank>'])
                unknown_chars.append(char)  # Track unknown characters

        # Append the encoded text as a PyTorch tensor of long integers
        encoded_targets.append(torch.tensor(encoded_text, dtype=torch.long))

        # Append the length of the current encoded text
        target_lengths.append(len(encoded_text))

    # Optionally, log unknown characters for debugging purposes
    if unknown_chars:
        print(f"Unknown characters encountered: {set(unknown_chars)}")

    return encoded_targets, target_lengths




In [10]:
# Example target texts (strings)
target_texts = ["Hello", "World", "Test"]

# Encode target texts
encoded_targets, target_lengths = encode_target(target_texts)
print(encoded_targets)

# Print encoded outputs
for idx, encoded_text in enumerate(encoded_targets):
    print(f"Text: {target_texts[idx]}, Encoded: {encoded_text}, Length: {target_lengths[idx]}")

[tensor([41, 70, 77, 77, 80]), tensor([56, 80, 83, 77, 69]), tensor([53, 70, 84, 85])]
Text: Hello, Encoded: tensor([41, 70, 77, 77, 80]), Length: 5
Text: World, Encoded: tensor([56, 80, 83, 77, 69]), Length: 5
Text: Test, Encoded: tensor([53, 70, 84, 85]), Length: 4


In [11]:
import torch

# Example mapping of indices to characters


def decoding_target(encoded_tensors):
    decoded_texts = []

    for encoded_tensor in encoded_tensors:
        # Convert tensor to list of integers
        encoded_list = encoded_tensor.tolist()
        # Decode each index to a character
        decoded_text = ''.join([index_to_char.get(idx, '?') for idx in encoded_list])
        # Append the decoded string to the result
        decoded_texts.append(decoded_text)

    return decoded_texts

# Example usage
# Use a list of tensors for variable-length sequences
encoded_targets = [torch.tensor([40, 69, 76, 76, 79]), torch.tensor([55, 79, 82, 76, 68]), torch.tensor([52, 69, 83, 84])]


decoded_texts = decoding_target(encoded_targets)
print(decoded_texts)


['Gdkkn', 'Vnqkc', 'Sdrs']


In [12]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.6.0-py3-none-any.whl (926 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m926.4/926.4 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.9-py3-none-any.whl (28 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.9 torchmetrics-1.6.0


In [13]:
num_classes = len(vocab)
print(num_classes)

97


In [14]:
def validate(model, validation_loader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in validation_loader:
            images = images.to(device)
            targets, target_lengths = encode_target(labels)
            targets = torch.nn.utils.rnn.pad_sequence(
                targets, batch_first=True, padding_value=char_to_index['<blank>']
            ).to(device)

            logits = model(images)
            logits = torch.clamp(logits, min=-10, max=10)
            input_lengths = torch.full(
                (images.size(0),), logits.size(1), dtype=torch.long
            ).to(device)
            target_lengths = torch.tensor(target_lengths, dtype=torch.long).to(device)
            log_probs = F.log_softmax(logits, dim=-1)

            loss_value = criterion(
                log_probs.permute(1, 0, 2),  # [T, N, C] for CTC
                targets,
                input_lengths,
                target_lengths,
            )
            val_loss += loss_value.item()

    return val_loss / len(validation_loader)

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchmetrics.text import WordErrorRate  # WER metric
import time
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau


# Define the device (CUDA if available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Feature extractor and OCR model
num_classes = len(vocab)

# Initialize the OCR model with both U-Net and ResNet
ocr_model = OCRModel(num_classes).to(device)

# Loss and optimizer
criterion = nn.CTCLoss(blank=char_to_index['<blank>']).to(device)  # CTC loss
optimizer = optim.Adam(ocr_model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)


best_val_loss = float('inf')  # Initialize to infinity


# Train the OCR model
def train_ocr_model(model, dataloader, validation_loader, loss, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch_idx, (images, labels) in enumerate(dataloader):

            # Input validation

            images = images.to(device)

            # Encode targets
            targets, target_lengths = encode_target(labels)
            targets = torch.nn.utils.rnn.pad_sequence(
                targets, batch_first=True, padding_value=char_to_index['<blank>']
            ).to(device)

            optimizer.zero_grad()

            # Forward pass
            logits = model(images)  # Output shape: [B, W, num_classes]

            # Clamp logits to stabilize training
            logits = torch.clamp(logits, min=-10, max=10)
            # Define lengths
            input_lengths = torch.full(
                (images.size(0),), logits.size(1), dtype=torch.long
            ).to(device)  # Sequence length
            target_lengths = torch.tensor(target_lengths, dtype=torch.long).to(device)

            # Compute log probabilities for CTC
            log_probs = F.log_softmax(logits, dim=-1)

            # Loss computation with exception handling
            try:
                loss_value = loss(
                    log_probs.permute(1, 0, 2),  # [T, N, C] for CTC
                    targets,
                    input_lengths,
                    target_lengths,
                )
                if torch.isnan(loss_value) or torch.isinf(loss_value):
                    print("NaN or Inf detected in loss. Skipping batch.")
                    continue

                # Backpropagation
                loss_value.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                epoch_loss += loss_value.item()

            except Exception as e:
                print(f"Error during loss computation: {e}. Skipping batch.")
                continue
        model.eval()
        val_loss = validate(model, validation_loader, loss)
        scheduler.step(val_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}, Average Train Loss: {epoch_loss / len(dataloader):.4f}, Total Train Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Save the best model
        global best_val_loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_ocr_model.pth")
            print("Best model saved!")

        model.train()

# Train the model
train_ocr_model(ocr_model, data_loader, val_loader, criterion, optimizer)

# Save model weights after training
torch.save(ocr_model.state_dict(), "final_ocr_model.pth")





Epoch 1/10, Average Train Loss: 2.9478, Total Train Loss: 170.9701, Validation Loss: 14.9772
Best model saved!
Epoch 2/10, Average Train Loss: 2.0626, Total Train Loss: 119.6318, Validation Loss: 14.8991
Best model saved!
Epoch 3/10, Average Train Loss: 2.0455, Total Train Loss: 118.6411, Validation Loss: 17.1279
Epoch 4/10, Average Train Loss: 2.0086, Total Train Loss: 116.4962, Validation Loss: 19.9591
Epoch 5/10, Average Train Loss: 1.9643, Total Train Loss: 113.9300, Validation Loss: 17.4745
Epoch 6/10, Average Train Loss: 1.9128, Total Train Loss: 110.9444, Validation Loss: 17.2562
Epoch 7/10, Average Train Loss: 1.8651, Total Train Loss: 108.1760, Validation Loss: 21.7246
Epoch 8/10, Average Train Loss: 1.8672, Total Train Loss: 108.2996, Validation Loss: 18.3433
Epoch 9/10, Average Train Loss: 1.8320, Total Train Loss: 106.2559, Validation Loss: 17.7067
Epoch 10/10, Average Train Loss: 1.8283, Total Train Loss: 106.0412, Validation Loss: 21.2786


In [16]:

# Decode target indices back to text
def decode_target(target):
    return ''.join([index_to_char[idx.item()] for idx in target if idx.item() != char_to_index['<blank>']])


In [17]:
import editdistance

def calculate_cer(pred_texts, target_texts):
    """
    Calculate the Character Error Rate (CER) between predicted and target texts.
    CER = Levenshtein Distance / Number of characters in the target text.
    """
    total_cer = 0
    total_chars = 0

    for pred, target in zip(pred_texts, target_texts):
        # Levenshtein distance
        edit_dist = editdistance.eval(pred, target)
        total_cer += edit_dist
        total_chars += len(target)

    cer = total_cer / total_chars if total_chars > 0 else 0
    return cer

In [18]:
import torch
import torch.nn.functional as F
from torchmetrics import WordErrorRate
from torch.nn.utils.rnn import pad_sequence

def evaluate(model, loader):
    model.eval()
    wer_metric = WordErrorRate()
    temperature =  1 # You can adjust this value based on experimentation

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)

            # Encode the labels using your encode_target function
            targets, target_lengths = encode_target(labels)

            # Pad the targets and move to device
            targets = pad_sequence(targets, batch_first=True, padding_value=char_to_index['<blank>']).to(device)
            target_lengths = torch.tensor(target_lengths, dtype=torch.long).to(device)  # Convert list to tensor and move to device

            # Forward pass through the model
            logits = model(images)

            # Apply softmax with temperature scaling (logits are of shape [batch_size, seq_len, num_classes])
            probs = torch.softmax(logits / temperature, dim=-1)  # Apply temperature to logits

            # Get the most probable character indices
            pred_indices = torch.argmax(probs, dim=-1)  # shape (batch_size, sequence_length)

            # Convert indices to characters (ensure you have a mapping from indices to characters)
            predicted_texts = []
            for indices in pred_indices:
                predicted_text = ''.join([vocab[i.item()] for i in indices])
                predicted_texts.append(predicted_text)

            # Decode the target labels to text
            target_strings = [decode_target(target) for target in targets]  # Process each target sequence

            # Print out the predictions and targets
            for pred, target in zip(predicted_texts, target_strings):
                print(f"Predicted: {pred}")
                print(f"Ground Truth: {target}")
                print("-" * 50)

            cer = calculate_cer(decoded_texts, target_strings)
            print(f"CER: {cer:.2f}")

            # Use WER metric (works with strings now)
            wer = wer_metric(predicted_texts, target_strings)
            print(f"WER: {wer.item():.2f}")

# Assuming you have a model and a test loader
evaluate(ocr_model, test_loader)  # This would call the evaluate function with your OCR model and test loader




Predicted: tennsoor(([[2<blank>,,,,,,,,,,,,,,,,,,,,,,,,,,      <blank><blank>   <blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><blank><bl

In [19]:
output_file = "/content/drive/My Drive/justresnetall.txt"

def process_transcription_and_save_separately(transcription_output, output_file):
    """
    Save all transcription outputs into a single text file.

    Args:
        transcription_output (list): A list of transcription strings.
        output_file (str): The file where all transcriptions will be saved.
    """
    with open(output_file, 'w') as f:
        for idx, line in enumerate(transcription_output):
            f.write(line)
    print(f"All transcriptions saved to {output_file}")

In [20]:
import torch
import torch.nn.functional as F
from torchmetrics import WordErrorRate
from torch.nn.utils.rnn import pad_sequence

def decode_ctc(logits, char_to_index, index_to_char):
    batch_size = logits.size(0)
    seq_len = logits.size(1)

    if logits.dim() == 2:  # Shape (batch_size, seq_len)
        pred_indices = logits  # Assuming logits are already the predicted class indices
    elif logits.dim() == 3:
        _, pred_indices = logits.max(dim=-1)  # Get indices with highest probability (batch_size, seq_len)
    else:
        raise ValueError("Unexpected logits shape")

    decoded_texts = []
    for batch_idx in range(batch_size):
        text = []
        prev_idx = -1

        for seq_idx in range(seq_len):
            idx = pred_indices[batch_idx, seq_idx].item()  # Convert to Python integer

            # Skip <blank> and consecutive duplicates
            if idx != char_to_index['<blank>'] and idx != prev_idx:
                text.append(index_to_char[idx])  # Convert index to character
            prev_idx = idx

        decoded_texts.append(''.join(text))

    return decoded_texts


def evaluate(model, loader, vocab, char_to_index, index_to_char):
    model.eval()
    wer_metric = WordErrorRate()
    temperature =  1  # You can adjust this value based on experimentation

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)

            # Encode the labels using your encode_target function
            targets, target_lengths = encode_target(labels)

            # Pad the targets and move to device
            targets = pad_sequence(targets, batch_first=True, padding_value=char_to_index['<blank>']).to(device)
            target_lengths = torch.tensor(target_lengths, dtype=torch.long).to(device)  # Convert list to tensor and move to device

            # Forward pass through the model
            logits = model(images)

            # Apply softmax with temperature scaling (logits are of shape [batch_size, seq_len, num_classes])
            probs = torch.softmax(logits / temperature, dim=-1)  # Apply temperature to logits


            # Decode the logits using CTC decoding
            decoded_texts = decode_ctc(probs, char_to_index, index_to_char)  # Using decode_ctc function

            # Decode the target labels to text
            target_strings = [decode_target(target) for target in targets]  # Process each target sequence

            # Print out the predictions and targets
            for pred, target in zip(decoded_texts, target_strings):
                print(f"Predicted: {pred}")
                print(f"Ground Truth: {target}")
                print("-" * 50)
                process_transcription_and_save_separately(pred, output_file)

            # Use WER metric (works with strings now)
            wer = wer_metric(decoded_texts, target_strings)
            print(f"WER: {wer.item():.2f}")

# Assuming you have a model, test loader, vocab, char_to_index, and index_to_char
evaluate(ocr_model, test_loader, vocab, char_to_index, index_to_char)  # This would call the evaluate function with your OCR model and test loader


Predicted: tensor([2,  
Ground Truth: 277-21 Y-o-u|a-r-e|h-e-r-e-b-y|o-r-d-e-r-e-d-s_cm|a-s|s-o-o-n|a-s
--------------------------------------------------
All transcriptions saved to /content/drive/My Drive/justresnetall.txt
Predicted: tensor([2, , 
Ground Truth: 277-32 o-r|a-n-y|o-t-h-e-r-s-s_cm|s_mi|o-r|P-l-a-n-t-a-t-i-o-n-s-s_pt|s_mi|L-i-e-u-t-e-n-a-n-t|F-r-a-s_mi
--------------------------------------------------
All transcriptions saved to /content/drive/My Drive/justresnetall.txt
Predicted: tensor([2, , 
Ground Truth: 276-07 h-a-v-e|C-a-t-t-l-e|d-e-l-i-v-e-r-e-d|h-e-r-e|s_et-c-s_pt|b-y|t-h-e|s_1st-s_pt|o-f|n-e-x-t
--------------------------------------------------
All transcriptions saved to /content/drive/My Drive/justresnetall.txt
Predicted: tensor([2, , 
Ground Truth: 273-25 t-h-e|m-e-n|c-o-m-p-l-e-t-e-l-y|f-u-r-n-i-s-h-e-d|w-i-t-h|b-o-t-h-s_sq|a-n-d
--------------------------------------------------
All transcriptions saved to /content/drive/My Drive/justresnetall.txt
Predict