# Use the Model and Evaluate It on Your Own Dataset
Use and test the model on a custom input image and let it make a prediction.
Or select a dataset folder and run validation (e.g., to test the model on a custom dataset).

In [2]:
import torch
from torchvision import models, transforms
import ipywidgets as widgets
from IPython.display import display
from PIL import Image
import os
import io
import sys

# Class labels, indexed by the model's predicted output index.
class_names = ['clean', 'avgDirty', 'dirty']

# Shared notebook state written by the widget callbacks:
#   img       - the most recently uploaded PIL image (set in on_file_select)
#   modelPath - checkpoint path last used (set in load_model / on_predFolder_btn_clicked)
# A module-level `global` statement is a no-op and does not create these names,
# so they are initialised explicitly here.
img = None
modelPath = None

def use_model(model, device, img):
    """Run one forward pass and return the top-1 prediction.

    Args:
        model: classifier module; it is moved to ``device`` and put in
            evaluation mode.
        device: torch device to run inference on.
        img: input tensor with a leading batch dimension. ``preds.item()``
            requires the batch to contain exactly one sample.

    Returns:
        (pred, prob): index of the highest-scoring class and its softmax
        probability expressed as a percentage (0-100).
    """
    model = model.to(device)
    model.eval()  # inference mode (e.g. dropout off, batch-norm uses running stats)

    with torch.no_grad():  # no gradients needed for inference
        imgs = img.to(device)

        # -- forward pass --
        outputs = model(imgs)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        # argmax over the raw logits; softmax is monotonic, so it picks the same class
        _, preds = torch.max(outputs, 1)
        pred = preds.item()

        prob = probs[0][pred].item() * 100

    return pred, prob

def predImage():
    """Classify the most recently uploaded image with the selected model.

    Reads the PIL image stored in the module-level ``img`` (set by
    ``on_file_select``), loads the model chosen in the dropdown via
    ``load_model``, and prints the predicted class and its confidence both
    to stdout and to the ``output`` widget.

    The global ``img`` is not modified (the transformed tensor is kept in a
    local variable), so the same upload can be predicted repeatedly or with a
    different model without re-uploading it.
    """
    # globals().get also covers the case where no image was ever uploaded.
    source_img = globals().get('img')
    if source_img is None:
        with output:
            print("No image loaded. Please select an image first.")
        return

    model, device = load_model()

    # --- Prepare Image ---
    transform = transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Normalize expects exactly 3 channels; uploads may be RGBA, grayscale or palette images.
    img_tensor = transform(source_img.convert('RGB'))
    img_tensor = img_tensor.unsqueeze(0)  # add batch dimension

    pred, prob = use_model(model, device, img_tensor)

    print(f"Prediction: {class_names[pred]} ({prob:.2f}%)")
    with output:
        print(f"Prediction: {class_names[pred]} ({prob:.2f}%)")

# Function to handle file selection
def on_file_select(change):
    """Handle a new upload from ``file_picker``: decode and preview the image.

    Stores the decoded PIL image in the module-level ``img`` (read later by
    ``predImage``) and shows a preview in the ``output`` widget.

    Args:
        change: traitlets change dict. ``change['new']`` is indexed as a
            sequence of file-info dicts with a ``'content'`` key.
            NOTE(review): that matches the ipywidgets 8 ``FileUpload`` value
            format; ipywidgets 7 uses a dict keyed by filename instead.
    """
    global img

    uploads = change['new']
    if not uploads:  # value was cleared/reset; nothing to load
        return

    file_info = uploads[0]  # only single-file uploads are enabled
    file_content = file_info['content']

    # Convert the file content to an image
    try:
        image = Image.open(io.BytesIO(file_content))
    except OSError as err:  # PIL's UnidentifiedImageError is an OSError subclass
        img = None  # don't let a later prediction silently reuse a stale image
        output.clear_output()
        with output:
            print(f"Could not open the selected file as an image: {err}")
        return

    img = image
    print("Image loaded")

    output.clear_output()
    displayImage = widgets.Image(value=file_content, format='jpg', width=512, height=512)
    with output:
        display(displayImage)

def get_datasets_from_folder(folder):
    """Return the names of the immediate sub-directories of ``folder``.

    Plain files are skipped. Names are returned in ``os.listdir`` order.
    """
    subdirs = []
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        if os.path.isdir(full_path):
            subdirs.append(entry)
    return subdirs

def load_model():
    """Build the Inception-v3 classifier and load the selected checkpoint.

    The checkpoint is read from
    ``./trainingOutput/<folder selected in trainingFolderWidget>/best_model.pth``.
    That path is also stored in the module-level ``modelPath``.

    Returns:
        (model, device): the model with the checkpoint weights loaded, on
        ``device`` (``cuda:0`` if available, otherwise CPU).
    """
    global modelPath

    # Get widget values
    trainingFolder = trainingFolderWidget.value
    modelPath = os.path.join('./trainingOutput', trainingFolder, 'best_model.pth')

    # --- Prepare Model ---
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The ImageNet weights are overwritten by load_state_dict below.
    # NOTE(review): pretrained=True is kept because torchvision's pretrained
    # Inception-v3 also changes the architecture config (aux logits,
    # transform_input); presumably the checkpoint was trained that way. Confirm
    # before switching to an untrained constructor to avoid the download.
    model = models.inception_v3(pretrained=True)
    num_ftrs = model.fc.in_features
    # one output per entry in class_names (3 classes)
    model.fc = torch.nn.Linear(num_ftrs, len(class_names))
    model.to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(modelPath, map_location=device))

    print(f"Loaded model: {modelPath}")

    return model, device


def on_predImage_btn_clicked(b):
    """Click handler for the "Predict Image" button.

    Clears the output widget, then runs ``predImage``. The button instance
    ``b`` passed by ipywidgets is not used.
    """
    output.clear_output()
    predImage()

def on_predFolder_btn_clicked(b):
    """Click handler for "Evaluate Folder": validate the model on a custom folder.

    Builds a command that launches ``./src/trainCNN/validateCNN.py`` in a new
    Windows console (``start cmd /k``). The command passes the selected
    checkpoint, the dataset folder from ``valSet_path`` and the results folder
    from ``output_path``. It also sets the module-level ``modelPath``.
    Progress messages are written to the ``output`` widget.

    NOTE(review): the command is a shell string built from text-box values
    and run with ``os.system``. Acceptable for a local notebook, but not for
    untrusted input.
    """
    global modelPath

    output.clear_output()

    trainingFolder = trainingFolderWidget.value
    modelPath = os.path.join('./trainingOutput', trainingFolder, 'best_model.pth')
    val_set_path = os.path.abspath(valSet_path.value)
    output_folder = os.path.abspath(output_path.value)
    # Quote every path (including the model path) so folders with spaces survive the shell.
    command = f"""start cmd /k python ./src/trainCNN/validateCNN.py --model_path "{modelPath}" --val_set_path "{val_set_path}" --output_folder "{output_folder}" """

    # run script; messages go inside the output widget so they are visible
    with output:
        print(f"Running validation on folder: {valSet_path.value}")
        print(f"val_path: {val_set_path}")
        print(f"\nExecuting command: {command}")
        os.system(command)

def on_eval_btn_clicked(b):
    """Click handler for "Evaluate": run validation for the selected model.

    Launches ``./src/trainCNN/validateCNN.py`` in a new Windows console
    (``start cmd /k``), passing only ``--model_path``. Whatever default
    validation set the script uses is defined in validateCNN.py, not here.
    """
    output.clear_output()

    # Get widget values
    trainingFolder = trainingFolderWidget.value
    modelPath = os.path.join('./trainingOutput', trainingFolder, 'best_model.pth')

    # run script
    with output:
        # quoted so a model folder containing spaces is passed as one argument
        command = f'start cmd /k python ./src/trainCNN/validateCNN.py --model_path "{modelPath}"'
        print(f"\nExecuting command: {command}")
        os.system(command)

# --- Widgets ---
# Build the notebook UI: a model selector, three action panels (evaluate on the
# validation set, predict a single image, evaluate a custom folder) and a shared
# output area that the callbacks above write into.

# pick model: one dropdown entry per sub-folder of ./trainingOutput
title_modelSelect = widgets.HTML(value="<h2>Select Model</h2>")
trainingFolderWidget = widgets.Dropdown(description='Model', options=get_datasets_from_folder('./trainingOutput'))
vboxFolderSelect = widgets.VBox([title_modelSelect, trainingFolderWidget])

# --- eval with val set ---
title_evalModel = widgets.HTML(value="<h2>Evaluate with <br> Validation Dataset</h2>")
btn_evalModel = widgets.Button(description="Evaluate", button_style='success')
btn_evalModel.on_click(on_eval_btn_clicked)
vbox_evalModel = widgets.VBox([title_evalModel, btn_evalModel])

# --- predict Image ---
title_imagePred = widgets.HTML(value="<h2>Predict Image</h2>")
file_picker = widgets.FileUpload(
    accept='image/*',  # only image files can be decoded by on_file_select
    multiple=False,  # Only single file upload
    description='Select Image',
)
btn_predImage = widgets.Button(description="Predict Image", button_style='success')
btn_predImage.on_click(on_predImage_btn_clicked)

vboxpredImage = widgets.VBox([title_imagePred, file_picker, btn_predImage])

# --- predict Folder ---
title_folderPred = widgets.HTML(value="<h2>Evaluate custom folder</h2>")
valSet_path = widgets.Text(value='./trainingOutput', description='Dataset Path:', disabled=False)
output_path = widgets.Text(value='./trainingOutput', description='Output Path:', disabled=False)
btn_predFolder = widgets.Button(description="Evaluate Folder", button_style='success')
btn_predFolder.on_click(on_predFolder_btn_clicked)

vboxpredFolder = widgets.VBox([title_folderPred, valSet_path, output_path, btn_predFolder])

# the same spacer instance is reused between the three panels
spacer = widgets.HTML(value="<div style='width:50px'></div>")
hbox = widgets.HBox([vbox_evalModel, spacer, vboxpredImage, spacer, vboxpredFolder])

# shared output area; the callbacks reference this module-level `output`
outputTitle = widgets.HTML(value="<h2>Output</h2>")
output = widgets.Output()
output.layout = {
    'border': '1px solid black',
    'overflow_y': 'auto',  # Add a vertical scrollbar in case of overflow
    'overflow_x': 'auto',  # Add a horizontal scrollbar in case of overflow
}

# Display widgets
display(vboxFolderSelect, hbox, outputTitle, output)

# Observe changes in the file picker
file_picker.observe(on_file_select, names='value')






VBox(children=(HTML(value='<h2>Select Model</h2>'), Dropdown(description='Model', options=('2024-06-07_13-34-1…

HBox(children=(VBox(children=(HTML(value='<h2>Evaluate with <br> Validation Dataset</h2>'), Button(button_styl…

HTML(value='<h2>Output</h2>')

Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid black', border_right='1px solid b…