# In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torchvision import models, transforms
from torch.utils.data import DataLoader
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
import matplotlib.pyplot as plt
import numpy as np

# Compute device for the model and its inputs: the GPU when CUDA is usable,
# otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Define the ModelVGG16 class
class ModelVGG16(nn.Module):
    """VGG16 backbone with two classification heads: fruit type and freshness.

    The convolutional part (``features``) of a torchvision VGG16 feeds a shared
    projection ``block1`` (512*7*7 -> 128). Two heads read that projection:

    * ``block2`` -> 9 logits (fruit type)
    * ``block3`` -> 2 logits (fresh / stale)

    The module also owns its optimizers, loss function and accuracy metrics.
    None of these are used by the inference code in this file.
    """

    def __init__(self):
        super().__init__()
        # Not referenced in this file; presumably a loss-weighting factor used
        # during training. TODO confirm.
        self.alpha = 0.7

        self.base = models.vgg16(pretrained=True)

        # Freeze all but the last 15 parameter tensors of the backbone.
        # NOTE(review): this runs *before* the classifier is cleared below.
        # With torchvision's standard VGG16, 6 of those 15 tensors belong to the
        # discarded classifier, so only the last 9 tensors of ``features`` stay
        # trainable. Kept as-is to preserve the original training setup;
        # revisit before retraining.
        for param in list(self.base.parameters())[:-15]:
            param.requires_grad = False

        self.base.classifier = nn.Sequential()  # Drop the ImageNet classifier
        # NOTE(review): torchvision's VGG has no ``fc`` attribute. This only
        # registers an empty, parameter-free submodule and does not affect
        # forward() or the state_dict.
        self.base.fc = nn.Sequential()

        # Shared projection from the flattened 512x7x7 feature map.
        self.block1 = nn.Sequential(
            nn.Linear(512 * 7 * 7, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
        )

        # Fruit-type head.
        self.block2 = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 9)  # Assuming 9 classes of fruits
        )

        # Freshness head.
        self.block3 = nn.Sequential(
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 2)  # Fresh or Stale (binary classification)
        )

        # Optimizers: the backbone gets a much smaller learning rate than block1.
        self.optimizer1 = optim.Adam([{'params': self.base.parameters(), 'lr': 1e-5},
                                      {'params': self.block1.parameters(), 'lr': 3e-4}])
        self.optimizer2 = optim.Adam(self.block2.parameters(), lr=3e-4)
        self.optimizer3 = optim.Adam(self.block3.parameters(), lr=3e-4)

        # Loss function
        self.loss_fxn = nn.CrossEntropyLoss()

        # Accuracy metrics
        self.fruit_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=9)
        self.fresh_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=2)

    def forward(self, x):
        """Return ``(fruit_logits, freshness_logits)`` for an image batch ``x``.

        ``x`` is expected to be a normalized ``(N, 3, H, W)`` float tensor.
        The outputs have shapes ``(N, 9)`` and ``(N, 2)``.
        """
        x = self.base.features(x)  # VGG16 convolutional layers
        # Adaptive pooling to 7x7 makes block1's fixed 512*7*7 input size hold
        # for any image size. For 224-255 px inputs (including the 244 px used
        # here) the feature map is already 7x7, so this step is an identity.
        x = self.base.avgpool(x)
        x = torch.flatten(x, 1)  # (N, 512*7*7)
        x = self.block1(x)
        return self.block2(x), self.block3(x)


# Instantiate the model
model_vgg16 = ModelVGG16().to(device)

# **Load your pre-trained model**
# Replace 'model.pth' with the correct path to your downloaded model file
model_path = "model.pth"  # Ensure 'model.pth' is in the same directory or provide the full path
try:
    # The file is expected to contain only a state_dict (tensors). Loading with
    # weights_only=True refuses arbitrary pickled objects, which matters for a
    # model file downloaded from elsewhere. It also avoids torch.load's warning
    # about the weights_only default in recent PyTorch releases.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model_vgg16.load_state_dict(state_dict)
except Exception as e:
    # Best effort: the app keeps running. Note that block1-3 then keep their
    # random initialisation, so predictions will be meaningless.
    print(f"Error loading the model: {e}")
else:
    print(f"Model loaded successfully from {model_path}")

model_vgg16.eval()  # Evaluation mode: disables the Dropout layers

# Index -> fruit name for the 9 logits produced by the model's block2 head.
# NOTE(review): this order must match the label encoding used at training
# time; that cannot be verified from this file. Update the names if needed.
class_names = [
    'apple',       # 0
    'banana',      # 1
    'orange',      # 2
    'strawberry',  # 3
    'tomato',      # 4
    'grape',       # 5
    'pineapple',   # 6
    'mango',       # 7
    'blueberry',   # 8
]

# Function to preprocess the image
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized single-image batch.

    Args:
        image_path: path of the image file, passed to ``PIL.Image.open``.

    Returns:
        A float tensor of shape ``(1, 3, 244, 244)``, normalized with the
        ImageNet channel mean/std.
    """
    transform = transforms.Compose([
        # NOTE(review): 244 rather than the usual 224; presumably this matches
        # the training pipeline. Confirm before changing.
        transforms.Resize((244, 244)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # The `with` block closes the underlying file once the pixels have been read.
    with Image.open(image_path) as img:
        # The network expects exactly 3 channels. Converting only RGBA let
        # grayscale ("L"), palette ("P") and CMYK images through with 1 or 4
        # channels, which then crashed in Normalize.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img_t = transform(img).unsqueeze(0)  # Add batch dimension
    return img_t


# Function to run the prediction
def predict_freshness(image_path):
    """Predict the fruit type and freshness for the image at ``image_path``.

    Runs the module-level ``model_vgg16`` on the preprocessed image and maps
    the arg-max indices to labels.

    Returns:
        ``(fruit_label, freshness_label)``. ``fruit_label`` is an entry of
        ``class_names``; ``freshness_label`` is ``'Fresh'`` or ``'Stale'``.
    """
    image = preprocess_image(image_path).to(device)

    with torch.no_grad():  # Inference only: skip autograd bookkeeping
        fruit_logits, fresh_logits = model_vgg16(image)

    # The batch holds a single image, so .item() yields that image's index
    # as a plain Python int.
    fruit_pred = torch.argmax(fruit_logits, dim=1).item()
    fresh_pred = torch.argmax(fresh_logits, dim=1).item()

    fruit_label = class_names[fruit_pred]
    # NOTE(review): assumes freshness class 0 means fresh and 1 means stale.
    # Confirm against the training labels.
    freshness_label = 'Fresh' if fresh_pred == 0 else 'Stale'

    return fruit_label, freshness_label

# Function to open the file dialog and get the image
def open_image():
    """Ask the user for an image file, show a preview, and display the prediction.

    Updates the module-level ``panel`` (preview) and ``result_label`` (text).
    Any failure is reported in ``result_label`` instead of being raised, so the
    GUI keeps running.
    """
    file_path = filedialog.askopenfilename()
    if not file_path:
        return  # Dialog was cancelled

    try:
        # resize() returns an independent in-memory copy, so the source file
        # can be closed as soon as the preview has been made.
        with Image.open(file_path) as img:
            preview = img.resize((244, 244))  # Adjust to match display size
        img_tk = ImageTk.PhotoImage(preview)
        panel.config(image=img_tk)
        # Keep a reference: Tk does not, and the image would be garbage-collected.
        panel.image = img_tk

        # Run prediction
        fruit_label, freshness_label = predict_freshness(file_path)
        result_label.config(text=f"Prediction: {fruit_label}, {freshness_label}")
    except Exception as e:
        result_label.config(text=f"Error processing image: {e}")

# --- Tk user interface ----------------------------------------------------
# Main application window.
root = tk.Tk()
root.wm_title("Fruit Freshness Predictor")

# Widgets, top to bottom: load button, prediction text, image preview.
btn = tk.Button(root, text="Load Image", command=open_image)
result_label = tk.Label(root, text="Prediction: ", font=("Helvetica", 14))
panel = tk.Label(root)  # open_image() places the preview image here

# Stack the widgets vertically, in creation order, with equal spacing.
for _widget in (btn, result_label, panel):
    _widget.pack(pady=10)

# Hand control to the Tk event loop; this blocks until the window is closed.
root.mainloop()


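# --- Recorded cell output (notebook export) ---
# Warning output (message text truncated in the export; presumably torch.load's
# weights_only FutureWarning), pointing at:
#   model_vgg16.load_state_dict(torch.load(model_path, map_location=device))
#
# stdout:
#   Model loaded successfully from model.pth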
