In [6]:
import torch
import cv2
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import tkinter as tk
from tkinter import filedialog, Label, Button, PhotoImage, Frame
from tkinter.font import Font
import os

In [7]:
class CNNModel(nn.Module):
    """Three-stage convolutional binary classifier with a sigmoid output.

    Each stage is a 3x3 convolution (padding 1) followed by ReLU and a 2x2
    max pool, so every stage halves the spatial size. The fully connected
    head is sized for a 128 x 16 x 16 feature map, which is what a
    3 x 128 x 128 input produces after the three pooling steps.

    Submodule names (``conv1``, ``conv2``, ``conv3``, ``fc1``, ``fc2``) are
    part of the checkpoint format: ``load_state_dict`` matches on them.
    """

    # Shape (channels, height, width) the head expects after the last pool.
    _FEATURE_SHAPE = (128, 16, 16)
    _FLAT_FEATURES = 128 * 16 * 16

    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return one sigmoid score in [0, 1] per input image.

        Args:
            x: Image tensor with 3 channels whose spatial size reduces to
                16 x 16 after three 2x2 poolings (e.g. 128 x 128).

        Returns:
            Tensor of shape ``(batch, 1)`` holding the sigmoid outputs.

        Raises:
            ValueError: If the pooled feature map is not 128 x 16 x 16.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flattening with a fixed row length would otherwise silently
        # re-chunk the batch for wrongly sized inputs (a single 256x256 image
        # would come out as four "predictions"), so reject them explicitly.
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise ValueError(
                "CNNModel expects inputs that reduce to a 128x16x16 feature "
                "map (e.g. 3x128x128 images); got a feature map of shape "
                f"{tuple(x.shape)}"
            )
        x = x.reshape(-1, self._FLAT_FEATURES)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc2(x))
        return x

<h3>Loading the Model</h3>
<p>Create a function to load the model:</p>

In [3]:
def load_model(model_path):
    """Build a ``CNNModel`` and load trained weights into it.

    Args:
        model_path: Path to a checkpoint file holding a ``state_dict``
            saved for ``CNNModel``.

    Returns:
        The model with the weights loaded, switched to evaluation mode.
    """
    model = CNNModel()
    # Map all storages to the CPU: the model is created on the CPU, and
    # without this a checkpoint saved from a GPU fails to load on a machine
    # that has no CUDA device.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model



<h3>Preprocessing the Input Image</h3>
<p>Create a function to preprocess the input image:</p>

In [4]:
def preprocess_image(image_path):
    """Read an image file and turn it into a single-image input batch.

    The image is read with OpenCV, converted from BGR to RGB, resized to
    128 x 128, converted to a tensor and normalized with mean 0.5 and
    standard deviation 0.5.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The transformed image tensor with a leading batch dimension of 1.

    Raises:
        FileNotFoundError: If OpenCV returns no image for ``image_path``.
    """
    bgr_pixels = cv2.imread(image_path)
    if bgr_pixels is None:
        raise FileNotFoundError(f"Image not found at {image_path}")
    rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)

    steps = [
        transforms.ToPILImage(),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    pipeline = transforms.Compose(steps)

    # unsqueeze(0) adds the batch dimension in front of the image tensor.
    return pipeline(rgb_pixels).unsqueeze(0)

<h3>Making a Prediction</h3>
<p>Create a function to make a prediction and interpret the result:</p>

In [5]:
def predict(image_path, model):
    """Classify one image file as ``'accepted'`` or ``'rejected'``.

    Args:
        image_path: Path of the image file, passed to ``preprocess_image``.
        model: Callable model returning a single score for the image batch.

    Returns:
        ``'accepted'`` when the model's score is above 0.5, otherwise
        ``'rejected'``.
    """
    batch = preprocess_image(image_path)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        score = model(batch)
    above_threshold = (score > 0.5).item()
    if above_threshold:
        return 'accepted'
    return 'rejected'