In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
%matplotlib inline

In [4]:
# Preprocessing pipeline applied to every image before it is fed to the model.
# NOTE(review): 112x112 is presumably the resolution the checkpoint was trained
# with (torchvision's own ResNet recipes use 224x224) -- confirm against the
# training notebook before changing it.
INPUT_SIZE = (112, 112)  # (height, width) passed to transforms.Resize
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel mean, ImageNet stats
IMAGENET_STD = [0.229, 0.224, 0.225]  # per-channel std, ImageNet stats

preprocess = transforms.Compose(
    [
        transforms.Resize(INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

def load_model(model_path, num_classes):
    """
    Build a ResNet-34 classifier and restore its weights from disk.

    Args:
    - model_path (str): Path to a file holding the saved state dict.
    - num_classes (int): Number of output classes for the final layer.

    Returns:
    - torch.nn.Module: The model with loaded weights, in evaluation mode.
    """
    # Start from an untrained backbone; every weight comes from the checkpoint.
    net = models.resnet34(pretrained=False)

    # Swap the stock classification head for one sized to our label set.
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, num_classes)

    # Deserialize onto the CPU so the checkpoint loads without a GPU.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    net.load_state_dict(state_dict)

    net.eval()  # Set model to evaluation mode
    return net

# Inference function
def predict_char(image_path, model, class_names, show_input=True):
    """
    Perform inference on a single image.

    Args:
    - image_path (str): Path to the input image.
    - model (torch.nn.Module): Trained PyTorch model. It is moved to the GPU
      when one is available, otherwise to the CPU.
    - class_names (list of str): List of class names, indexed by the model's
      output position.
    - show_input (bool): When True (default), plot the preprocessed tensor
      before running the model. Pass False to skip plotting, e.g. in loops.

    Returns:
    - str: Predicted class name.
    """
    # Load and preprocess the image; the context manager releases the file
    # handle once the pixels have been converted to RGB.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    input_tensor = preprocess(image)

    if show_input:
        # Show the channels stacked vertically as one 2-D map. The row width is
        # taken from the tensor itself: a hard-coded 224 scrambles the picture
        # whenever the preprocessing resizes to any other width (e.g. 112).
        row_width = input_tensor.shape[-1]
        plt.imshow(input_tensor.reshape(-1, row_width).cpu().numpy())
        plt.show()

    input_batch = input_tensor.unsqueeze(0)  # Create a mini-batch as expected by the model

    # Move the input to the same device as the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_batch = input_batch.to(device)
    model = model.to(device)

    # Perform inference without tracking gradients
    with torch.no_grad():
        output = model(input_batch)
        predicted_idx = output.argmax(dim=1)

    # Get the class label
    predicted_class = class_names[predicted_idx.item()]
    return predicted_class

In [6]:
# Checkpoint location and label set for the character classifier.
model_path = "./char_classifier.pth"

# One label per character, in ascending ASCII order; position i in this list
# is the label reported for model output index i.
class_names = list(
    "-"
    "0123456789"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "[]_"
    "abcdefghijklmnopqrstuvwxyz"
)

# Rebuild the network with one output per label and restore its weights.
model = load_model(model_path, num_classes=len(class_names))

