In [1]:
import os
import json
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
from pathlib import Path

In [2]:
class CNNClassifier(nn.Module):
    def __init__(self, num_classes):
        super(CNNClassifier, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128*8*8, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [3]:
def infer(data_dir, model_path):
        
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    train_dataset = ImageFolder("characters_train")
    class_names = train_dataset.classes
    num_classes = len(class_names)
    
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                           [0.229, 0.224, 0.225])
    ])
    
    model = CNNClassifier(num_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    
    data_path = Path(data_dir)
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}
    image_files = [f for f in data_path.iterdir() 
                   if f.is_file() and f.suffix.lower() in image_extensions]
    
    predictions = {}
    
    with torch.no_grad():
        for img_path in image_files:
            img = Image.open(img_path).convert("RGB")
            img_tensor = transform(img).unsqueeze(0).to(device)
            
            output = model(img_tensor)
            _, pred_idx = torch.max(output, 1)
            pred_class = class_names[pred_idx.item()]
            
            predictions[img_path.name] = pred_class
    
    with open("results.json", 'w') as f:
        json.dump(predictions, f, indent=4)
    
    print(f"Predictions saved to results.json ({len(predictions)} images processed)")

In [4]:
infer("test_images","best_model.pth")

  model.load_state_dict(torch.load(model_path, map_location=device))


Predictions saved to results.json (27 images processed)
