In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import string

In [2]:
characters = string.ascii_letters + string.digits
char_to_idx = {char: idx + 1 for idx, char in enumerate(characters)}
char_to_idx['<BLANK>'] = 0
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

In [5]:
class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.linear = nn.Linear(nHidden * 2, nOut)

    def forward(self, x):
        recurrent, _ = self.rnn(x)
        T, b, h = recurrent.size()
        output = self.linear(recurrent.view(T * b, h))
        output = output.view(T, b, -1)
        return output

class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh):
        super(CRNN, self).__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(nc, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(512, 512, 2, 1, 0), nn.BatchNorm2d(512), nn.ReLU(),
        )

        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass)
        )

    def forward(self, x):
        conv = self.cnn(x)
        b, c, h, w = conv.size()
        assert h == 1, 'Expected height of conv features to be 1'
        conv = conv.squeeze(2) 
        conv = conv.permute(2, 0, 1)  
        output = self.rnn(conv)
        return output

In [8]:
img_h = 32 
nc = 1
nclass = len(char_to_idx)
nh = 256

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CRNN(img_h, nc, nclass, nh)
model.load_state_dict(torch.load('./model_weights/receipts_ocr_model.pth', map_location=device))
model = model.to(device)
model.eval()

  model.load_state_dict(torch.load('./model_weights/receipts_ocr_model.pth', map_location=device))


CRNN(
  (cnn): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU()
    (14): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation

In [9]:
transform = transforms.Compose([
    transforms.Resize((32, 128)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [10]:
def decode_predictions(outputs):
    outputs = outputs.permute(1, 0, 2)
    batch_size = outputs.size(0)
    _, preds = outputs.max(2)
    preds = preds.cpu().numpy()

    decoded_texts = []
    for i in range(batch_size):
        pred = preds[i]
        pred_text = ''
        prev_idx = -1
        for idx in pred:
            if idx != prev_idx and idx != 0:
                char = idx_to_char.get(idx, '')
                pred_text += char
            prev_idx = idx
        decoded_texts.append(pred_text)
    return decoded_texts

In [11]:
image_path = input("Enter the path to the image file: ")

In [24]:
def recognize_text_from_image(image):
    image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    image_pil = transform(image_pil)

    image_tensor = image_pil.unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image_tensor)
        outputs = outputs.log_softmax(2)

        decoded_texts = decode_predictions(outputs)
        recognized_text = decoded_texts[0]

    return recognized_text

In [25]:
print(recognize_text_from_image(cv2.imread(image_path)))

3000
