In [12]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
import sys
from torchvision import transforms
import torch.nn.functional as F
import pandas as pd
from collections import OrderedDict
import cv2

sys.path.append('./')
from modules.model import Model
from modules.utils import CTCLabelConverter, AttnLabelConverter

# Pick the compute device once for the whole notebook: the GPU when CUDA is
# usable, otherwise the CPU. Later cells move models and tensors onto it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

Using device: cpu


In [13]:
class NepaliTextRecognitionConfig:
    """Settings bundle for the Nepali text-recognition model.

    Instantiating the class assembles the character inventory (digits, symbols
    and Devanagari code points), prints a short summary of it, and then exposes
    the architecture, input-geometry and decoding options as plain attributes.
    """

    def __init__(self):
        # Character inventory, built from three literal groups. The literals are
        # load-bearing: their content and order fix every character's position
        # in ``self.character``. ``len()`` counts code points, and a few code
        # points occur more than once in the Devanagari group.
        digits = '०१२३४५६७८९0123456789'
        punctuation = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{}~।॥—‘’“”… "
        devanagari = 'अआइईउऊऋएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहक्षत्रज्ञािीुूृेैोौंःँॅॉ'
        self.number, self.symbol, self.lang_char = digits, punctuation, devanagari

        self.character = ''.join((self.number, self.symbol, self.lang_char))
        print(f"Total characters: {len(self.character)}")
        print(f"Numbers: {self.number}")
        print(f"Language characters sample: {self.lang_char}...")

        # Names of the four model stages.
        self.Transformation = 'None'
        self.FeatureExtraction = 'ResNet'
        self.SequenceModeling = 'BiLSTM'
        self.Prediction = 'Attn'

        # Input geometry and layer widths.
        self.imgH, self.imgW = 80, 1220
        self.input_channel = 3
        self.output_channel = 256
        self.hidden_size = 256
        self.num_fiducial = 20

        # Sequence length limit and data-handling switches.
        self.batch_max_length = 200
        self.sensitive = True
        self.PAD = False
        self.rgb = True
        self.contrast_adjust = False

        # Decoding strategy ('greedy'; the original comment names beam search as
        # the alternative) and the class count derived from the raw inventory.
        self.decode = 'greedy'
        self.num_class = len(self.character)

def print_character_set(config):
    """Print the size of each character group held by ``config``.

    Reads ``character``, ``number``, ``symbol`` and ``lang_char`` and reports
    ``len()`` of each. The counts are code-point counts, so repeated code points
    are included in every figure.
    """
    print("Character Set Analysis:")
    rows = (
        ("Total unique characters", "character", ""),
        ("Numbers", "number", " characters"),
        ("Symbols", "symbol", " characters"),
        ("Language characters", "lang_char", " characters"),
    )
    for label, attribute, suffix in rows:
        print(f"{label}: {len(getattr(config, attribute))}{suffix}")

# Build the configuration object that the following cells read as the global
# ``config`` (constructing it also prints the inventory summary).
config = NepaliTextRecognitionConfig()
# Report the size of each character group ...
print_character_set(config)
# ... and dump the full inventory string so it can be inspected by eye.
print(config.character)

Total characters: 132
Numbers: ०१२३४५६७८९0123456789
Language characters sample: अआइईउऊऋएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहक्षत्रज्ञािीुूृेैोौंःँॅॉ...
Character Set Analysis:
Total unique characters: 132
Numbers: 20 characters
Symbols: 40 characters
Language characters: 72 characters
०१२३४५६७८९0123456789!"#$%&'()*+,-./:;<=>?@[\]^_`{}~।॥—‘’“”… अआइईउऊऋएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहक्षत्रज्ञािीुूृेैोौंःँॅॉ


In [14]:
def load_crnn_model(model_path, config):
    """Build the recognition model, load a checkpoint into it and return it.

    Args:
        model_path: Path of the checkpoint file to load.
        config: Configuration object. ``config.Prediction`` selects the label
            converter ("CTC" or "Attn") and ``config.character`` supplies the
            character inventory. ``config.num_class`` is overwritten with the
            converter's class count as a side effect.

    Returns:
        ``(model, converter)`` with the model moved to the global ``device``
        and switched to eval mode.

    Raises:
        ValueError: ``config.Prediction`` is neither "CTC" nor "Attn".
        FileNotFoundError: ``model_path`` does not exist.
    """
    # Validate the decoder type first; previously an unknown value fell through
    # both branches and crashed later on the unbound ``converter`` name.
    if config.Prediction not in ("CTC", "Attn"):
        raise ValueError(
            f"Unsupported Prediction type: {config.Prediction!r} (expected 'CTC' or 'Attn')"
        )

    # Fail fast, before paying for network construction.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    if config.Prediction == "CTC":
        converter = CTCLabelConverter(config.character)
    else:
        converter = AttnLabelConverter(config.character)
    # The converter adds its own special tokens, so the class count comes from it.
    config.num_class = len(converter.character)

    model = Model(config)

    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)

    # Some checkpoints wrap the weights in a dict under 'state_dict'.
    if isinstance(state_dict, dict) and 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    # Strip the 'module.' prefix when every key carries it.
    prefix = 'module.'
    if all(key.startswith(prefix) for key in state_dict.keys()):
        state_dict = OrderedDict(
            (key[len(prefix):], value) for key, value in state_dict.items()
        )

    try:
        model.load_state_dict(state_dict)
    except Exception as e:
        # Best-effort fallback kept from the original; report what was skipped so
        # a partially initialised model is not mistaken for a clean load.
        print(f"Error loading state dict: {e}")
        print("Trying to load with strict=False...")
        load_result = model.load_state_dict(state_dict, strict=False)
        print("✓ Model loaded with strict=False")
        missing = list(getattr(load_result, "missing_keys", None) or [])
        unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
        if missing or unexpected:
            print(
                f"  Warning: {len(missing)} missing and {len(unexpected)} unexpected keys; "
                "parameters without a checkpoint entry keep their initial values."
            )

    model = model.to(device)
    model.eval()

    return model, converter

def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set contribute.
    """
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total

def center_and_resize_image(img, target_size=(1220, 80)):
    """Fit an image onto a black RGB canvas of ``target_size``.

    An image larger than the target in either dimension is shrunk (aspect ratio
    preserved, LANCZOS resampling); smaller images are never enlarged. The
    result is always pasted centred on a new black canvas.

    Args:
        img: A PIL image, or a path (``str`` or ``os.PathLike``) to open.
        target_size: ``(width, height)`` of the returned canvas.

    Returns:
        A new RGB PIL image of exactly ``target_size``. The image passed in is
        left unmodified.
    """
    owns_image = False
    if isinstance(img, (str, os.PathLike)):
        # Copy the pixels out and close the file handle straight away.
        with Image.open(img) as opened:
            img = opened.copy()
        owns_image = True

    target_w, target_h = target_size
    img_w, img_h = img.size

    if img_w > target_w or img_h > target_h:
        # thumbnail() resizes in place, so work on a private copy rather than
        # silently shrinking the caller's image.
        if not owns_image:
            img = img.copy()
        img.thumbnail((target_w, target_h), Image.LANCZOS)

    canvas = Image.new("RGB", (target_w, target_h), color="black")
    offset = ((target_w - img.width) // 2, (target_h - img.height) // 2)
    canvas.paste(img, offset)

    return canvas

def preprocess_crnn_image(image_path, config, return_tensor=True):
    """Turn one image into the input the recognition model expects.

    The image is letter-boxed to ``(config.imgW, config.imgH)`` by
    ``center_and_resize_image`` and then normalised according to
    ``config.input_channel``:

    * 1 channel: converted to grayscale, scaled to [0, 1], optionally
      standardised when ``config.contrast_adjust`` is set, then mapped to
      [-1, 1] with mean 0.5 / std 0.5.
    * otherwise: ``ToTensor`` followed by normalisation with the
      mean/std constants below.

    Args:
        image_path: Path to an image, or a PIL image.
        config: Model configuration object.
        return_tensor: When False, return the letter-boxed PIL image instead.

    Returns:
        A float tensor of shape ``(1, C, H, W)``, or the RGB PIL image when
        ``return_tensor`` is False.
    """
    input_channels = getattr(config, 'input_channel', 1)

    processed_pil = center_and_resize_image(image_path, (config.imgW, config.imgH))

    if not return_tensor:
        return processed_pil

    if input_channels == 1:
        # The letter-boxed canvas is always RGB; collapse it to one channel so
        # the tensor is (1, H, W) and not (1, H, W, 3) as it was before.
        gray = np.array(processed_pil.convert('L')).astype(np.float32) / 255.0

        if getattr(config, 'contrast_adjust', False):
            gray = (gray - np.mean(gray)) / (np.std(gray) + 1e-8)

        image_tensor = torch.from_numpy(gray).unsqueeze(0)
        image_tensor = (image_tensor - 0.5) / 0.5
    else:
        # NOTE(review): contrast_adjust has never influenced this path (the
        # adjusted array used to be computed and then discarded). It is left
        # that way so outputs stay unchanged; confirm against the training
        # pipeline before enabling it here.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image_tensor = transform(processed_pil)

    # Add batch dimension: (1, C, H, W)
    return image_tensor.unsqueeze(0)

def visualize_crnn_preprocessing(image_path, config):
    """Plot the original image next to its preprocessed version.

    The left panel shows the file as loaded (RGB or grayscale according to
    ``config.rgb``). The right panel shows the tensor produced by
    ``preprocess_crnn_image`` after undoing its normalisation.

    Args:
        image_path: Path of the image file to display.
        config: Model configuration object.
    """
    # Convert inside the context manager so the file handle is released.
    with Image.open(image_path) as source:
        original_image = source.convert('RGB' if config.rgb else 'L')

    processed_tensor = preprocess_crnn_image(image_path, config)
    processed_image = processed_tensor.squeeze(0).permute(1, 2, 0).numpy()

    # Undo the normalisation with the same switch preprocess_crnn_image uses
    # (input_channel); keying this on config.rgb applied the wrong constants
    # whenever the two settings disagreed.
    multi_channel = getattr(config, 'input_channel', 1) != 1
    if multi_channel:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        processed_image = processed_image * std + mean
    else:
        processed_image = (processed_image * 0.5) + 0.5
    processed_image = np.clip(processed_image, 0, 1)

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Original image
    axes[0].imshow(original_image, cmap='gray' if not config.rgb else None)
    axes[0].set_title(f'Original Image\nSize: {original_image.size}')
    axes[0].axis('off')

    # Processed image
    if multi_channel:
        axes[1].imshow(processed_image)
    else:
        axes[1].imshow(processed_image[:, :, 0], cmap='gray')
    axes[1].set_title(f'Preprocessed Image\nSize: {config.imgW}x{config.imgH}')
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
def inference_crnn_single_image(model, converter, image_path, config):
    """Run the PyTorch model on one image and return the decoded text.

    Args:
        model: The loaded recognition model.
        converter: Label converter matching ``config.Prediction``.
        image_path: Path to an image, or a PIL image.
        config: Model configuration object.

    Returns:
        For "CTC": whatever ``converter.decode_greedy`` returns.
        For "Attn": a list with one string per image, cut at the first
        ``'[s]'`` marker when one is present.
        ``None`` for any other ``config.Prediction`` value.
    """
    image_tensor = preprocess_crnn_image(image_path, config).to(device)
    batch_size = image_tensor.size(0)
    # All-zero placeholder text input passed alongside the image.
    text_for_pred = torch.LongTensor(batch_size, config.batch_max_length).fill_(0).to(device)

    with torch.no_grad():
        preds = model(image_tensor, text_for_pred, is_train=False)

        if config.Prediction == "CTC":
            # Greedy decoding: best class per time step, flattened over the batch.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            _, preds_index = preds.max(2)
            preds_index = preds_index.view(-1)
            return converter.decode_greedy(preds_index.data, preds_size.data)

        if config.Prediction == "Attn":
            # Keep the first batch_max_length - 1 steps, as the original did
            # ("remove last timestep to align with loss logic").
            preds = preds[:, :config.batch_max_length - 1, :]
            _, preds_index = preds.max(2)
            preds_str = converter.decode(
                preds_index, torch.IntTensor([config.batch_max_length] * batch_size)
            )
            # Everything from the first '[s]' marker onwards is dropped. The
            # per-step softmax confidences were computed here but never used,
            # so that work has been removed.
            return [pred.split('[s]', 1)[0] for pred in preds_str]

    return None
 
# Collect the crop files. os.listdir returns entries in arbitrary order, so the
# list is sorted to make the previewed file the same on every run.
all_entries = os.listdir("crops")
files_only = sorted(
    os.path.join("crops", entry)
    for entry in all_entries
    if os.path.isfile(os.path.join("crops", entry))
)

print(config.Prediction)

# Load the checkpoint once, outside the loop, so raising the preview limit does
# not reload the weights for every image.
# NOTE(review): the checkpoint must match config.Prediction printed above -
# confirm which decoder iter_60000.pth was trained with.
model, converter = load_crnn_model("models/iter_60000.pth", config)

# Preview only the first crop.
for string in files_only[:1]:
    visualize_crnn_preprocessing(string, config)
    print(inference_crnn_single_image(model, converter, string, config))

In [11]:
import torch.onnx
import onnx
from onnx import checker
#CTC
def export_crnn_to_onnx(model, converter, config, output_path, input_shape=None,batch_size=1, opset_version=11):
    """
    Export the recognition model to ONNX, tracing it with an image and a text input.

    The graph is exported with fixed shapes (no dynamic axes), so the batch
    size chosen here is baked into the ONNX file.

    Args:
        model: The PyTorch CRNN model
        converter: The label converter (CTCLabelConverter or AttnLabelConverter)
        config: Model configuration object
        output_path (str): Path where to save the ONNX model
        input_shape (tuple): Input tensor shape (batch_size, channels, height, width).
                            Default: (batch_size, config.input_channel,
                            config.imgH, config.imgW). When given, it takes
                            precedence over ``batch_size``.
        batch_size (int): Batch size used when ``input_shape`` is not given (default: 1)
        opset_version (int): ONNX opset version (default: 11)

    Returns:
        str: Path to the exported ONNX model, or None when validating the
        exported file fails.
    """
    model.eval()

    if input_shape is None:
        input_shape = (batch_size, config.input_channel, config.imgH, config.imgW)
    print(input_shape)

    # Build the dummy inputs from input_shape itself; it used to be printed but
    # ignored, so a caller-supplied shape silently had no effect.
    dummy_input = torch.randn(*input_shape).to(device)
    dummy_text = torch.LongTensor(input_shape[0], config.batch_max_length).fill_(0).to(device)

    with torch.no_grad():
        torch.onnx.export(
            model,
            (dummy_input, dummy_text),
            output_path,
            export_params=True,
            opset_version=opset_version,
            input_names=['input', 'text'],
            output_names=['output'],
            verbose=False
        )

    print(f"Model exported to ONNX format: {output_path}")

    # Validation is best-effort: a failure is reported and signalled with None.
    try:
        onnx_model = onnx.load(output_path)
        checker.check_model(onnx_model)
        print("✓ ONNX model validation passed!")

        print("ONNX Model Information:")
        print(f"  Input shape: {input_shape}")
        print(f"  Output classes: {len(converter.character)}")
        print(f"  Opset version: {opset_version}")

        return output_path

    except Exception as e:
        print(f"❌ ONNX model validation failed: {e}")
        return None

import copy

# This cell exports the CTC variant, so pin the decoder type on a private copy
# of the configuration instead of relying on whatever value the shared
# ``config`` object currently holds. The shared object is left untouched.
ctc_config = copy.copy(config)
ctc_config.Prediction = "CTC"

model, converter = load_crnn_model("models/iter_60000.pth", ctc_config)

export_crnn_to_onnx(model, converter, ctc_config, "models/ResNetBiLSTMCTCv1Batch16.onnx", batch_size=16)

No Transformation module specified
(16, 3, 80, 1220)
Model exported to ONNX format: models/ResNetBiLSTMCTCv1Batch16.onnx
✓ ONNX model validation passed!
ONNX Model Information:
  Input shape: (16, 3, 80, 1220)
  Output classes: 133
  Opset version: 11


'models/ResNetBiLSTMCTCv1Batch16.onnx'

In [16]:
import torch.onnx
import onnx
from onnx import checker
#Attn
def export_crnn_to_onnx(model, converter, config, output_path, input_shape=None, opset_version=11,batch_size=1):
    """
    Export the recognition model to ONNX, tracing it with the image input only.

    This definition replaces any earlier ``export_crnn_to_onnx`` in the kernel.
    The graph is exported with fixed shapes (no dynamic axes), so the batch
    size chosen here is baked into the ONNX file.

    Args:
        model: The PyTorch CRNN model
        converter: The label converter (CTCLabelConverter or AttnLabelConverter)
        config: Model configuration object
        output_path (str): Path where to save the ONNX model
        input_shape (tuple): Input tensor shape (batch_size, channels, height, width).
                            Default: (batch_size, config.input_channel,
                            config.imgH, config.imgW). When given, it takes
                            precedence over ``batch_size``.
        opset_version (int): ONNX opset version (default: 11)
        batch_size (int): Batch size used when ``input_shape`` is not given (default: 1)

    Returns:
        str: Path to the exported ONNX model, or None when validating the
        exported file fails.
    """
    model.eval()

    if input_shape is None:
        input_shape = (batch_size, config.input_channel, config.imgH, config.imgW)
    print(input_shape)

    # Build the dummy image from input_shape itself; it used to be printed but
    # ignored. The unused dummy text tensor has been dropped because only the
    # image is passed to the exporter.
    dummy_input = torch.randn(*input_shape).to(device)

    with torch.no_grad():
        torch.onnx.export(
            model,
            (dummy_input,),
            output_path,
            export_params=True,
            opset_version=opset_version,
            input_names=['input'],
            output_names=['output'],
            verbose=False
        )

    print(f"Model exported to ONNX format: {output_path}")

    # Validation is best-effort: a failure is reported and signalled with None.
    try:
        onnx_model = onnx.load(output_path)
        checker.check_model(onnx_model)
        print("✓ ONNX model validation passed!")

        print("ONNX Model Information:")
        print(f"  Input shape: {input_shape}")
        print(f"  Output classes: {len(converter.character)}")
        print(f"  Opset version: {opset_version}")

        return output_path

    except Exception as e:
        print(f"❌ ONNX model validation failed: {e}")
        return None

# Load the checkpoint used for the attention variant; the decoder type comes
# from config.Prediction.
model, converter = load_crnn_model("models/iter_70000.pth",config)

# Export with a fixed batch of 32 using the export function defined in this
# cell. This is the cell's last expression, so its return value is displayed.
export_crnn_to_onnx(model, converter,config,"models/ResNetBiLSTMAttnv1Batch32.onnx",batch_size=32)

No Transformation module specified
(32, 3, 80, 1220)
False
Model exported to ONNX format: models/ResNetBiLSTMAttnv1Batch32.onnx
✓ ONNX model validation passed!
ONNX Model Information:
  Input shape: (32, 3, 80, 1220)
  Output classes: 134
  Opset version: 11


'models/ResNetBiLSTMAttnv1Batch32.onnx'

In [None]:

import onnxruntime as ort
import torch
import torch.nn.functional as F
import numpy as np
import os

def load_onnx_model(onnx_path):
    """Open the ONNX file at ``onnx_path`` and return an onnxruntime session for it."""
    return ort.InferenceSession(onnx_path)

def inference_crnn_single_image_onnx(session, converter, image_path, config):
    """Run an ONNX session on one image and return the decoded text.

    Args:
        session: onnxruntime session whose image input is named "input".
        converter: Label converter matching ``config.Prediction``.
        image_path: Path to an image, or a PIL image.
        config: Model configuration object.

    Returns:
        For "CTC": whatever ``converter.decode_greedy`` returns.
        For "Attn": a list with one string per image, cut at the first
        ``'[s]'`` marker when one is present.
        ``None`` for any other ``config.Prediction`` value.
    """
    image_tensor = preprocess_crnn_image(image_path, config)  # [1, C, H, W]
    image_np = image_tensor.cpu().numpy().astype(np.float32)
    batch_size = image_np.shape[0]

    # The models in this notebook are exported without dynamic axes, so the
    # session may demand a fixed batch size. Pad with all-zero images up to
    # that size and drop the padded rows afterwards.
    # Assumes samples in a batch do not influence each other - TODO confirm.
    image_input = next((node for node in session.get_inputs() if node.name == "input"), None)
    static_batch = image_input.shape[0] if image_input is not None and image_input.shape else None
    if isinstance(static_batch, int) and static_batch > batch_size:
        filler = np.zeros((static_batch - batch_size,) + image_np.shape[1:], dtype=image_np.dtype)
        image_np = np.concatenate([image_np, filler], axis=0)

    # Only the image is fed, whichever decoder is configured.
    preds = session.run(None, {"input": image_np})[0][:batch_size]  # [batch, seq_len, num_classes]

    if config.Prediction == "CTC":
        preds_size = np.array([preds.shape[1]] * batch_size, dtype=np.int32)
        preds_index = preds.argmax(axis=2).flatten()
        return converter.decode_greedy(preds_index, preds_size)

    if config.Prediction == "Attn":
        preds = preds[:, :config.batch_max_length - 1, :]
        preds_index = preds.argmax(axis=2)
        preds_str = converter.decode(torch.from_numpy(preds_index),
                                     torch.IntTensor([config.batch_max_length] * batch_size))
        # Everything from the first '[s]' marker onwards is dropped. The unused
        # softmax confidences and text placeholder have been removed.
        return [pred.split('[s]', 1)[0] for pred in preds_str]

    return None

# =========================
# Example usage
# =========================
import copy

# This cell evaluates the CTC export, so pin the decoder type on a private copy
# of the configuration instead of relying on whatever value the shared
# ``config`` object currently holds. The shared object is left untouched.
ctc_config = copy.copy(config)
ctc_config.Prediction = "CTC"

model, converter = load_crnn_model("models/iter_60000.pth", ctc_config)
onnx_session = load_onnx_model("models/ResNetBiLSTMCTCv1.onnx")

all_entries = os.listdir("crops")
files_only = [os.path.join("crops", entry) for entry in all_entries if os.path.isfile(os.path.join("crops", entry))]

# Preview the first five crops.
for string in files_only[:5]:
    visualize_crnn_preprocessing(string, ctc_config)
    print(inference_crnn_single_image_onnx(onnx_session, converter, string, ctc_config))


In [None]:

import onnxruntime as ort
import torch
import torch.nn.functional as F
import numpy as np
import os

def load_onnx_model(onnx_path):
    """Create and return an onnxruntime inference session for the given ONNX file."""
    runtime_session = ort.InferenceSession(onnx_path)
    return runtime_session

def inference_crnn_single_image_onnx(session, converter, image_path, config):
    """Run an ONNX session on one image and return the decoded text.

    Args:
        session: onnxruntime session whose image input is named "input".
        converter: Label converter matching ``config.Prediction``.
        image_path: Path to an image, or a PIL image.
        config: Model configuration object.

    Returns:
        For "CTC": whatever ``converter.decode_greedy`` returns.
        For "Attn": a list with one string per image, cut at the first
        ``'[s]'`` marker when one is present.
        ``None`` for any other ``config.Prediction`` value.
    """
    image_tensor = preprocess_crnn_image(image_path, config)  # [1, C, H, W]
    image_np = image_tensor.cpu().numpy().astype(np.float32)
    batch_size = image_np.shape[0]

    # The models in this notebook are exported without dynamic axes, so the
    # session may demand a fixed batch size. Pad with all-zero images up to
    # that size and drop the padded rows afterwards.
    # Assumes samples in a batch do not influence each other - TODO confirm.
    image_input = next((node for node in session.get_inputs() if node.name == "input"), None)
    static_batch = image_input.shape[0] if image_input is not None and image_input.shape else None
    if isinstance(static_batch, int) and static_batch > batch_size:
        filler = np.zeros((static_batch - batch_size,) + image_np.shape[1:], dtype=image_np.dtype)
        image_np = np.concatenate([image_np, filler], axis=0)

    # Only the image is fed, whichever decoder is configured.
    preds = session.run(None, {"input": image_np})[0][:batch_size]  # [batch, seq_len, num_classes]

    if config.Prediction == "CTC":
        preds_size = np.array([preds.shape[1]] * batch_size, dtype=np.int32)
        preds_index = preds.argmax(axis=2).flatten()
        return converter.decode_greedy(preds_index, preds_size)

    if config.Prediction == "Attn":
        preds = preds[:, :config.batch_max_length - 1, :]
        preds_index = preds.argmax(axis=2)
        preds_str = converter.decode(torch.from_numpy(preds_index),
                                     torch.IntTensor([config.batch_max_length] * batch_size))
        # Everything from the first '[s]' marker onwards is dropped. The unused
        # softmax confidences and text placeholder have been removed.
        return [pred.split('[s]', 1)[0] for pred in preds_str]

    return None

# =========================
# Example usage
# =========================
# Load the checkpoint (only its label converter is used below) and open the
# exported attention model.
model, converter = load_crnn_model("models/iter_70000.pth",config)
onnx_session = load_onnx_model("models/ResNetBiLSTMAttnv1Batch32.onnx")

# Keep regular files only; sub-directories of "crops" are ignored.
all_entries = os.listdir("crops")
files_only = []
for entry in all_entries:
    candidate = os.path.join("crops", entry)
    if os.path.isfile(candidate):
        files_only.append(candidate)

# Show and transcribe the first five crops.
for string in files_only[:5]:
    visualize_crnn_preprocessing(string, config)
    print(inference_crnn_single_image_onnx(onnx_session, converter, string, config))


In [None]:

import onnxruntime as ort
import torch
import torch.nn.functional as F
import numpy as np
import os
#Batch
def load_onnx_model(onnx_path):
    """Return an onnxruntime session created from the model stored at ``onnx_path``."""
    return ort.InferenceSession(onnx_path)

def preprocess_crnn_image_batch(image_paths, config):
    """Preprocess several images and stack them into one batch tensor.

    Each image goes through ``preprocess_crnn_image``; its leading batch
    dimension is dropped before all results are stacked along a new dim 0.
    """
    per_image = []
    for image_path in image_paths:
        single = preprocess_crnn_image(image_path, config)
        per_image.append(single[0])
    return torch.stack(per_image, dim=0)

def inference_crnn_batch_image_onnx(session, converter, image_paths, config):
    """Run an ONNX session on a list of images and return the decoded texts.

    Args:
        session: onnxruntime session whose image input is named "input".
        converter: Label converter matching ``config.Prediction``.
        image_paths: Iterable of image paths (or PIL images).
        config: Model configuration object.

    Returns:
        For "CTC": whatever ``converter.decode_greedy`` returns.
        For "Attn": a list with one string per image, cut at the first
        ``'[s]'`` marker when one is present.
        ``None`` for any other ``config.Prediction`` value.
    """
    image_tensors = preprocess_crnn_image_batch(image_paths, config)
    image_np = image_tensors.cpu().numpy().astype(np.float32)
    batch_size = image_np.shape[0]

    # The models in this notebook are exported without dynamic axes, so the
    # session may demand a fixed batch size. Pad with all-zero images (rather
    # than uninitialised memory) up to that size and drop the padded rows
    # afterwards. A batch larger than the fixed size is passed through
    # unchanged and left for onnxruntime to reject.
    # Assumes samples in a batch do not influence each other - TODO confirm.
    image_input = next((node for node in session.get_inputs() if node.name == "input"), None)
    static_batch = image_input.shape[0] if image_input is not None and image_input.shape else None
    if isinstance(static_batch, int) and static_batch > batch_size:
        filler = np.zeros((static_batch - batch_size,) + image_np.shape[1:], dtype=image_np.dtype)
        image_np = np.concatenate([image_np, filler], axis=0)

    # Only the image is fed, whichever decoder is configured.
    preds = session.run(None, {"input": image_np})[0][:batch_size]  # [batch, seq_len, num_classes]

    if config.Prediction == "CTC":
        preds_size = np.array([preds.shape[1]] * batch_size, dtype=np.int32)
        preds_index = preds.argmax(axis=2).flatten()
        return converter.decode_greedy(preds_index, preds_size)

    if config.Prediction == "Attn":
        preds = preds[:, :config.batch_max_length - 1, :]
        preds_index = preds.argmax(axis=2)
        preds_str = converter.decode(torch.from_numpy(preds_index),
                                     torch.IntTensor([config.batch_max_length] * batch_size))
        # Everything from the first '[s]' marker onwards is dropped. The unused
        # softmax confidences, text placeholder and debug print were removed.
        return [pred.split('[s]', 1)[0] for pred in preds_str]

    return None

# =========================
# Example usage
# =========================
model, converter = load_crnn_model("models/iter_70000.pth",config)
onnx_session = load_onnx_model("models/ResNetBiLSTMAttnv2Batch.onnx")

all_entries = os.listdir("crops")
files_only = [os.path.join("crops", entry) for entry in all_entries if os.path.isfile(os.path.join("crops", entry))]

# Transcribe the first two crops as one batch. Slicing avoids the IndexError
# that indexing files_only[index + 1] raised when only one crop exists, and
# the guard skips inference when there are no crops at all.
batch_paths = files_only[:2]
if batch_paths:
    print(inference_crnn_batch_image_onnx(onnx_session, converter, batch_paths, config))
