# Testing Notebook

Load a trained model checkpoint and test it against images of your choice.

Reference for the `ipywidgets` `interact` controls used below: https://medium.com/data-science/interactive-controls-for-jupyter-notebooks-f5c94829aee6

In [None]:
import torch
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Loading Model

In [None]:
# Loading Model
import os
from setup import SAVE_PATH, FINE_TUNE
from torchvision.models import resnet18
from ipywidgets import interact

# Mutable holder so later cells can read whichever model was picked in the dropdown below.
data = {'model': None}

@interact
def choose_model(selected_model=os.listdir(SAVE_PATH)):
    """Load the checkpoint selected in the dropdown into ``data['model']``.

    ``interact`` renders ``selected_model`` as a dropdown of the entries in
    ``SAVE_PATH`` and re-runs this function whenever the selection changes.

    Args:
        selected_model: File name, relative to ``SAVE_PATH``, of a checkpoint
            saved as a dict whose ``'model'`` entry is a state dict.
    """
    if not FINE_TUNE:
        # Only the fine-tuned ResNet-18 architecture is handled here.
        print('FINE_TUNE is disabled: no model architecture is defined for this checkpoint; '
              "data['model'] was not changed")
        return

    checkpoint_path = os.path.join(SAVE_PATH, selected_model)
    # weights_only=False unpickles arbitrary Python objects: only load checkpoints you trust.
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines; the
    # model below is built on the CPU anyway.
    saved_state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=False)['model']

    # No pretrained download needed: load_state_dict (strict by default) overwrites every
    # parameter and buffer, so the result is identical to starting from pretrained weights.
    model = resnet18()
    # Replace the classifier head with a 3-way output layer to match the saved weights.
    model.fc = torch.nn.Linear(model.fc.in_features, 3)
    model.load_state_dict(saved_state_dict)
    model.eval()  # inference mode for BatchNorm/Dropout
    data['model'] = model
    print('Model loaded successfully')

## Testing On Specific Image

In [None]:
from IPython.display import Image
from setup import DATA_PATH, IMAGE_HEIGHT, IMAGE_WIDTH
import PIL
from torchvision import transforms

CLASS_NAME = 'Covid'
SUBSET = 'train'
PATH = os.path.join(DATA_PATH, SUBSET, CLASS_NAME)

# The preprocessing is the same for every image, so build it once instead of on every
# dropdown change.
# NOTE(review): torchvision's Resize takes size as (height, width), but this passes
# (IMAGE_WIDTH, IMAGE_HEIGHT). That is harmless if the two are equal; otherwise confirm it
# matches the transform used during training before changing the order.
TRANSFORM = transforms.Compose([
    transforms.Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

@interact
def change(img_name=os.listdir(PATH)):
    """Run the loaded model on one image from ``PATH`` and display the prediction.

    ``interact`` renders ``img_name`` as a dropdown of the files in ``PATH`` and
    re-runs this function whenever the selection changes (including once on definition).

    Args:
        img_name: File name inside ``PATH``.

    The predicted class is shown as an integer index; it presumably follows the order
    ['Covid', 'Normal', 'Viral Pneumonia'] passed to evaluate_single in the evaluation
    cell. TODO confirm against the training class order.
    """
    # `import PIL` alone does not guarantee that the PIL.Image submodule is loaded.
    import PIL.Image

    img_path = os.path.join(PATH, img_name)
    print(img_path)

    model = data['model']
    if model is None:
        print('No model loaded yet: select a checkpoint in the "Loading Model" cell first')
        return

    # Close the file handle once the pixels have been converted.
    with PIL.Image.open(img_path) as pil_image:
        rgb_image = pil_image.convert('RGB')

    # ToTensor already yields float32. Add a batch dimension and put the input on the same
    # device as the model's weights.
    model_device = next(model.parameters()).device
    model_input = TRANSFORM(rgb_image).unsqueeze(0).to(model_device)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        result = model(model_input)
    probabilities = torch.softmax(result, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

    # Display Results
    display(f'Predicted Class: {predicted_class}')
    display(f'Class: {CLASS_NAME}')
    display(Image(img_path))



## Evaluating Image By Image

In [None]:
from data.evaluation import evaluate_single

# evaluate_single receives whatever model the "Loading Model" cell stored. Fail early with a
# clear message instead of passing None through.
if data['model'] is None:
    raise RuntimeError('No model loaded: run the "Loading Model" cell and select a checkpoint first')
# NOTE(review): the list order presumably matches the class indices used during training -
# confirm.
evaluate_single(model=data['model'], classes_list=['Covid', 'Normal', 'Viral Pneumonia'], device=device)