In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import os
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Test-time preprocessing applied to every image before inference:
#   1. resize to exactly 224x224 (a (h, w) tuple does not preserve aspect ratio),
#   2. convert the PIL image to a float tensor scaled to [0, 1],
#   3. normalize each RGB channel with the standard ImageNet mean / std.
# NOTE(review): presumably this mirrors the training-time pipeline — confirm,
# since any mismatch silently degrades predictions.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Rebuild the network architecture the checkpoint was saved from: a ResNet-18
# with randomly initialised weights (they are replaced by the checkpoint that is
# loaded next) whose final fully connected layer is resized to `num_classes` outputs.
# `weights=None` is the torchvision >= 0.13 spelling of the deprecated
# `pretrained=False` argument and avoids the deprecation warning.
model = models.resnet18(weights=None)
num_classes = 4  # must equal the number of classes the checkpoint was trained with
model.fc = nn.Linear(model.fc.in_features, num_classes)

In [None]:
# Load the trained weights into the architecture defined above.
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the model is never moved to a GPU in this notebook, so the
# CPU is the right target either way.
# NOTE(review): torch.load unpickles the file — only load checkpoints from
# trusted sources (on torch >= 1.13 consider passing weights_only=True).
# Assumes the file holds a bare state_dict (torch.save(model.state_dict(), ...)) — confirm.
state_dict = torch.load('model_checkpoint.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval();  # inference mode (e.g. BatchNorm uses running stats); `;` hides the long module repr

In [None]:
# Root folder of the test data: one sub-folder per class, each containing that
# class's images (the layout the loop below walks). Set the path here, or export
# TEST_DATA_ROOT in the environment before starting the kernel.
test_data_root = os.environ.get('TEST_DATA_ROOT', '')

# Fail fast with a clear message instead of an opaque FileNotFoundError from
# os.listdir('') in the next cell.
assert os.path.isdir(test_data_root), (
    f'test_data_root={test_data_root!r} is not a directory; '
    'set it to the folder that contains one sub-folder per class'
)

In [None]:
# Class names are the sub-folder names of test_data_root, sorted alphabetically
# (the same rule torchvision's ImageFolder uses to assign label indices).
# Plain files such as .DS_Store are skipped: they are not classes and would make
# the per-class loop below crash when it tries to list them as folders.
# NOTE(review): index i is only meaningful as model output i if the training labels
# were assigned the same way — confirm.
class_names = sorted(
    entry
    for entry in os.listdir(test_data_root)
    if os.path.isdir(os.path.join(test_data_root, entry))
)

# A different number of folders means predicted indices cannot map onto these names.
assert len(class_names) == num_classes, (
    f'found {len(class_names)} class folders {class_names}, '
    f'but the model has {num_classes} outputs'
)
class_names

In [None]:
# Helper to visualise a preprocessed image tensor.
def imshow(image, ax=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized (C, H, W) image tensor, undoing the normalization first.

    Args:
        image: tensor laid out channels-first, as produced by ``transforms.ToTensor``
            followed by ``transforms.Normalize(mean, std)``. It may require grad or
            live on a GPU; it is detached and copied to the CPU before plotting.
        ax: matplotlib Axes to draw on; defaults to the current axes (``plt.gca()``).
        mean: per-channel mean that was subtracted during normalization.
        std: per-channel standard deviation that was divided out during normalization.
    """
    if ax is None:
        ax = plt.gca()
    # detach().cpu() so tensors tracking gradients or stored on a GPU can be converted.
    pixels = image.detach().cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    # Invert Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean.
    pixels = np.asarray(std) * pixels + np.asarray(mean)
    # Rounding can push values slightly outside [0, 1], which imshow would warn about.
    pixels = np.clip(pixels, 0, 1)
    ax.imshow(pixels)
    ax.axis('off')

In [None]:
# Run the model on every test image, show each image with its true and predicted
# label, and report the overall accuracy at the end.

# Only files with these extensions (the list torchvision's ImageFolder accepts) are
# treated as images, so stray files such as .DS_Store or notes do not crash the loop.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

num_correct = 0
num_total = 0
for class_name in class_names:
    class_folder = os.path.join(test_data_root, class_name)
    # sorted() makes the display order deterministic across file systems.
    for test_image_name in sorted(os.listdir(class_folder)):
        if not test_image_name.lower().endswith(IMAGE_EXTENSIONS):
            continue
        test_image_path = os.path.join(class_folder, test_image_name)

        # Load and preprocess the test image. convert('RGB') guarantees 3 channels
        # (grayscale, RGBA or palette files would otherwise break Normalize and the
        # model's first convolution); the context manager closes the file handle.
        with Image.open(test_image_path) as pil_image:
            test_image = transform(pil_image.convert('RGB')).unsqueeze(0)  # add batch dimension

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            outputs = model(test_image)

        # Index of the highest-scoring class for the single image in the batch.
        predicted_index = outputs.argmax(dim=1).item()
        predicted_label = class_names[predicted_index]

        num_total += 1
        num_correct += int(predicted_label == class_name)

        # Display the test image along with labels
        plt.figure()
        plt.title(f'Actual: {class_name}, Predicted: {predicted_label}')
        imshow(test_image.squeeze(0))  # remove the batch dimension
        plt.show()

if num_total:
    print(f'Accuracy: {num_correct}/{num_total} = {num_correct / num_total:.2%}')
else:
    print(f'No images found under {test_data_root!r}')