In [2]:
import torch
from torchvision import transforms
from PIL import Image
import yaml
from model import CheatDetectionModel 
import os
from utils import readConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model state dictionary.
# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# trusted source. On PyTorch >= 1.13, weights_only=True restricts loading to
# tensors and plain containers, which is sufficient for a state dict.
model_path = 'models/best_model.pth'
state_dict = torch.load(model_path, map_location=torch.device(device))

# Instantiate your model based on the configuration used during training
config_file = 'config/model_config.yaml'
config = readConfig(config_file)
model = CheatDetectionModel(config['model']['num_classes'])

# Load the state dict into the model, then move the model to `device`.
# load_state_dict copies values into the model's existing tensors, and
# map_location only controls where the loaded checkpoint lives. Without the
# explicit .to(device), the model would stay on the CPU even when CUDA exists.
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # Set the model to evaluation mode

# Preprocessing: resize to 640x640, convert to a float tensor, and normalize
# per channel with the standard ImageNet mean/std values. The 3-element
# mean/std require the input image to have exactly 3 channels.
data_transforms = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def predict_image_class(image_path):
    """Return the predicted class index for a single image.

    The image is converted to RGB, preprocessed with ``data_transforms``,
    and fed to ``model`` on ``device`` as a batch of one.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        int: Index of the highest-scoring entry along dimension 1 of the
        model output.
    """
    # Force 3 channels so grayscale/RGBA/palette/CMYK inputs match the
    # 3-channel Normalize above. The context manager closes the file handle;
    # convert() returns a fully loaded, independent image.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Add a batch dimension and put the input on the same device as the model.
    batch = data_transforms(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)

    # argmax(dim=1) returns the same index as torch.max(outputs, 1)[1].
    return outputs.argmax(dim=1).item()

# Example usage
if __name__ == "__main__":
    # Demo entry point: classify one sample image and print the class index.
    sample_path = 'cat.99.jpg'  # Replace with your actual image path
    label = predict_image_class(sample_path)
    print(f"Predicted class: {label}")


True