In [None]:
import os

import torch
import torchvision.transforms as tt
from torchvision.datasets import ImageFolder

import matplotlib.pyplot as plt

from resnet_9_model import *

## Initialise Data

In [None]:
# Images are resized to this (height, width) before being fed to the model.
resize_shape = (48, 48)

# Root of the test set: one sub-folder per class (ImageFolder layout).
# Set BRAND_CLASSIFICATION_TEST_DIR to point elsewhere instead of editing
# the developer-specific default path below.
data_dir = os.environ.get(
    'BRAND_CLASSIFICATION_TEST_DIR',
    '/home/selimon/Desktop/AI/wdwyl_ros1/src/perception/brand_classification/data/testing',
)
# Sorted so the listing reads in the same order ImageFolder assigns label ids.
print(sorted(os.listdir(data_dir)))

# Per-channel (mean, std) for Normalize. These are the widely used CIFAR-10
# statistics. NOTE(review): they must match the normalisation used when the
# checkpoint was trained -- confirm against the training code.
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Evaluation-time pipeline: resize, convert to a tensor, normalise.
valid_tfms = tt.Compose([
    tt.Resize(resize_shape),
    tt.ToTensor(),
    tt.Normalize(*stats),
])

# Load dataset & apply transformation
test_ds = ImageFolder(data_dir, valid_tfms)

In [None]:
def get_default_device():
    """Return the CUDA device when a GPU is usable, otherwise the CPU device.

    As a side effect, the chosen backend name ('cuda' or 'cpu') is printed.
    """
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(backend)
    return torch.device(backend)
    
    
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to `device`.

    Sequences come back as lists (tuples are not preserved). Each leaf
    tensor is transferred with ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))


def predict_image(img, model, classes=None, device=None):
    """Classify one image and return the name of the predicted class.

    Before returning, the softmax probability of every class is printed,
    one ``name: prob`` line per class.

    Args:
        img: A single image tensor without a batch dimension. A batch
            dimension of size 1 is added before calling ``model``.
        model: A ``torch.nn.Module`` returning one row of logits per input,
            i.e. shape (1, num_classes) here. Its train/eval mode is left
            untouched, so call ``model.eval()`` beforehand.
        classes: Class names indexed by class id. Defaults to the global
            ``test_ds.classes`` (the original behaviour).
        device: Device to run inference on. Defaults to the device of the
            model's first parameter, falling back to CPU for a model with
            no parameters.

    Returns:
        The entry of ``classes`` at the index of the highest logit.

    Raises:
        ValueError: If the number of logits differs from ``len(classes)``.
    """
    if classes is None:
        classes = test_ds.classes
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Add a batch dimension of 1 and move the input next to the model.
    batch = img.unsqueeze(0).to(device, non_blocking=True)
    # Pure inference: don't build an autograd graph (saves memory and time,
    # and keeps grad_fn noise out of the printed values).
    with torch.no_grad():
        logits = model(batch)[0]

    # A mismatch here means label ids and model outputs don't line up.
    if logits.numel() != len(classes):
        raise ValueError(
            f'model returned {logits.numel()} scores but {len(classes)} '
            f'class names were given: {list(classes)}'
        )

    probs = torch.softmax(logits, dim=0)
    for name, prob in zip(classes, probs.tolist()):
        print(f'{name}: {prob:.4f}')

    # Softmax is monotonic, so the arg-max of the logits is the most probable class.
    return classes[int(logits.argmax())]

## Initialise Model

In [None]:
# Number of output classes the checkpoint was trained with.
NUM_CLASSES = 4
# Trained weights, resolved relative to the notebook's working directory.
CHECKPOINT_PATH = 'resnet9_model.pth'

# ImageFolder derives label ids from the (sorted) class sub-folders. If the
# test set has a different number of class folders than the model has
# outputs, labels and predictions cannot line up, so fail loudly here.
assert len(test_ds.classes) == NUM_CLASSES, (
    f'expected {NUM_CLASSES} class folders in {data_dir}, found {test_ds.classes}'
)

device = get_default_device()

# Build the network on the target device and load the trained weights.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted
# checkpoints (newer PyTorch versions accept weights_only=True).
model = ResNet9Pretrained(NUM_CLASSES).to(device)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()  # Set the model to evaluation mode

## Run Model

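### Testing with Individual Images

Each test image is classified on its own. The model's per-class probabilities are printed, and the image is plotted with its predicted and actual labels. The running count of correct predictions feeds the accuracy computed below.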

In [None]:
# Normalize's per-channel mean/std, shaped (C, 1, 1) so they broadcast over a
# channel-first image tensor. They are used only to undo normalisation for display.
vis_mean = torch.tensor(stats[0]).view(-1, 1, 1)
vis_std = torch.tensor(stats[1]).view(-1, 1, 1)

correct_predictions = 0
total_images = len(test_ds)

# Inference only: disable autograd so no computation graph is built per image.
with torch.no_grad():
    for i in range(total_images):
        # Get the image and its label from the dataset
        img, label = test_ds[i]
        actual_class = test_ds.classes[label]

        # Make a prediction for the image using the model
        predicted_class = predict_image(img, model)
        if predicted_class == actual_class:
            correct_predictions += 1

        # The tensor is normalised, so clamping it directly distorts the
        # colours. Undo Normalize first, then clamp to the [0, 1] display range.
        fig, ax = plt.subplots()
        ax.imshow((img * vis_std + vis_mean).clamp(0, 1).permute(1, 2, 0))
        ax.set_title(f"Predicted: {predicted_class}, Actual: {actual_class}")
        ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per image


## Accuracy

In [None]:
# Share of test images whose predicted class matched their folder label.
accuracy = float(correct_predictions) / total_images
print("Accuracy on the test dataset: " + format(accuracy, ".2%"))