# Test Cervical Cancer Model

### Import necessary libraries

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from src.models.cervical_cancer_model import CervicalModel

### Load the model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CervicalModel(num_classes=5)  # Must match the num_classes the checkpoint was trained with
checkpoint_path = '../results/cervical_checkpoints/best_checkpoint.pth'  # Checkpoint path

# Load the saved weights (state dict).
# map_location remaps tensors that were saved on a GPU onto the selected device,
# so the checkpoint also loads on CPU-only machines; weights_only=True restricts
# unpickling to tensors and plain containers.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
_ = model.to(device)
_ = model.eval()  # Put layers such as dropout/batch-norm into inference mode; `_ =` suppresses the notebook echo

### Define the image transformations

In [None]:
# Per-channel statistics used by Normalize (the standard ImageNet values).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Resize(256) scales the shorter image side to 256 px, keeping the aspect ratio.
# CenterCrop(224) then cuts a 224x224 patch, presumably the input size the model expects.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # PIL image -> float CHW tensor with values in [0, 1]
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
transform = transforms.Compose(preprocessing_steps)


### Load the test dataset

In [None]:
# Root of the test split in ImageFolder layout (one sub-directory per class).
# Set the CERVICAL_TEST_DIR environment variable to point at the data on other
# machines; the default is the original hard-coded location.
test_dataset_path = os.environ.get('CERVICAL_TEST_DIR', 'D:/Data/cervical_cancer_data/test')
test_dataset = datasets.ImageFolder(root=test_dataset_path, transform=transform)

# Iterate one image at a time in a fixed (unshuffled) order so results are reproducible
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

### Helper function to display an image

In [None]:
def unnormalize(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert ``transforms.Normalize`` on a CHW image tensor and return an HWC array.

    Args:
        img: Image tensor laid out as (channels, height, width), as produced by
            ``transforms.ToTensor()`` followed by ``transforms.Normalize``.
        mean: Per-channel means subtracted during normalization. The default is
            the ImageNet mean used by this notebook's ``transform``.
        std: Per-channel standard deviations divided out during normalization.
            The default is the ImageNet std used by this notebook's ``transform``.

    Returns:
        A NumPy array of shape (height, width, channels) clipped to [0, 1],
        ready for ``plt.imshow``.
    """
    # Work on a detached CPU copy so GPU tensors and tensors with grad can be converted.
    img = img.detach().cpu()
    # Reshape stats to (C, 1, 1) so they broadcast over height and width.
    mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    # Normalize computed (x - mean) / std, so invert it; clamp away rounding overshoot.
    img = (img * std + mean).clamp(0.0, 1.0)
    # matplotlib expects channels last: CHW -> HWC.
    return np.transpose(img.numpy(), (1, 2, 0))


def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized CHW image tensor with matplotlib.

    Args:
        img: Normalized image tensor of shape (channels, height, width).
        mean: Per-channel means used when the image was normalized.
        std: Per-channel standard deviations used when the image was normalized.
    """
    plt.imshow(unnormalize(img, mean, std))
    plt.show()

### Make a prediction on a test image

In [None]:
# Take the first batch from the (unshuffled) test loader
dataiter = iter(test_loader)
images, labels = next(dataiter)

# Display the first image of the batch
imshow(images[0])

# Move image and labels to the device (GPU or CPU)
images, labels = images.to(device), labels.to(device)

# Inference only: no_grad skips building the autograd graph, saving memory and time
with torch.no_grad():
    outputs = model(images)
_, predicted = torch.max(outputs, 1)

# Convert the first sample's tensors to plain Python ints so they can index
# class_names and compare cleanly (works for any batch size, not just 1)
predicted_idx = predicted[0].item()
true_label = labels[0].item()

class_names = test_dataset.classes
predicted_class = class_names[predicted_idx]

# Print the predicted and true class labels
print(f'Predicted: {predicted_class}')
print(f'True Label: {class_names[true_label]}')

# Print whether the prediction is correct
if predicted_idx == true_label:
    print("Prediction is correct.")
else:
    print("Prediction is incorrect.")
