In [32]:
import torch
from PIL import Image
from torchvision import transforms
import torch.nn as nn

# Define the model structure again (since we need the model class to load it)
class SimpleModel(nn.Module):
    """Small CNN classifier: single-channel 28x28 images -> 10 class scores.

    Three unpadded 3x3 convolutions shrink the spatial size 28 -> 26 -> 24 -> 22
    while ending at 64 channels; the result is flattened and fed to one linear
    layer that emits 10 raw scores (no softmax is applied).

    The attribute names ``conv_layers`` / ``fc_layers`` and the layer order in
    each ``nn.Sequential`` define the ``state_dict`` keys, so they must not
    change or previously saved weights will no longer load.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: each unpadded 3x3 conv trims 2 pixels from H and W.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3), nn.ReLU(),
        )
        # Classifier head: 64 channels x 22 x 22 positions -> 10 scores.
        flat_features = 64 * 22 * 22
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 10),
        )

    def forward(self, x):
        """Return the raw (unnormalized) class scores for the batch ``x``."""
        features = self.conv_layers(x)
        return self.fc_layers(features)

# Load the model and its weights.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; without it torch.load tries to restore them onto the GPU.
model = SimpleModel()
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
model.eval()  # Set the model to evaluation mode

# Define image preprocessing steps: resize to the 28x28 input the model's
# Linear layer is sized for, reduce to one grayscale channel, and convert to a
# float tensor scaled to [0, 1].
# NOTE(review): no normalization is applied here. If the model was trained with
# transforms.Normalize (or with inverted colors), the same step must be added
# here, or predictions will be unreliable. Confirm against the training code.
transform = transforms.Compose([
    transforms.Resize((28, 28)),                  # Resize image to 28x28
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.ToTensor(),                        # Convert to tensor in [0, 1]
])

# Load your image (make sure it's the right path to the image).
# The `with` block closes the underlying file once the tensor has been built.
img_path = 'image_test.png'  # Replace with your image path
with Image.open(img_path) as pil_img:
    img = transform(pil_img)

# Add a batch dimension (PyTorch models expect a batch of images, not a single one)
img = img.unsqueeze(0)  # Shape becomes [1, 1, 28, 28] (1 image, 1 color channel, 28x28 pixels)

# Make the prediction (no need to compute gradients during inference)
with torch.no_grad():
    output = model(img)

# Predicted label: index of the highest score along the class dimension
predicted_label = output.argmax(dim=1)

# Print the predicted label
print(f'The model predicts the number is: {predicted_label.item()}')


The model predicts the number is: 3
