In [None]:
import cv2
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import numpy as np
import argparse

# Constants
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Side length (pixels) of the square image handed to the network.
IMG_SIZE = 224
# Path to the trained model checkpoint.
# NOTE(review): a previous comment pointed at a local copy named
# resnet_fruit_model.pth (no "v1") -- confirm which checkpoint is intended.
MODEL_PATH = "resnetv1_fruit_model.pth"

# Classes
# Labels are indexed by the argmax of the corresponding model output, so the
# order of each list matters.
FRUIT_CLASSES = [
    "1. Green",
    "1. Ripe",
    "1. Semi-Ripe",
    "2. Green Defect",
    "2. Ripe Defect",
    "2. Semi-Ripe Defect",
]
BRUISED_CLASSES = [
    "Not Bruised",
    "Bruised",
]

"""
# Construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required = True,
	help = "Path to the image")
args = vars(ap.parse_args())
"""


usage: ipykernel_launcher.py [-h] -i IMAGE
ipykernel_launcher.py: error: the following arguments are required: -i/--image


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [7]:

# Preprocessing transformations
# Steps run in order: resize to IMG_SIZE x IMG_SIZE, convert the PIL image to
# a tensor, then normalise each of the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics -- confirm they match what the training script used.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


In [4]:

# Model definition (same as training script)
class ResNetClassifier(nn.Module):
    """ResNet-50 feature extractor feeding two independent linear heads.

    The backbone's original ``fc`` layer is replaced by ``nn.Identity`` so the
    backbone returns its pooled feature vector; that vector is passed to both
    heads. The attribute names ``base_model``, ``classifier`` and
    ``bruised_classifier`` define the state-dict keys, so they must stay in
    sync with the checkpoint being loaded.
    """

    def __init__(self, num_classes, bruised_classes, pretrained=True):
        """Build the backbone and the two classification heads.

        Args:
            num_classes: Output size of the fruit/ripeness head.
            bruised_classes: Output size of the bruised/not-bruised head.
            pretrained: When True (the default, matching the previous
                behaviour) the backbone is created with
                ``ResNet50_Weights.DEFAULT``. Pass False to skip requesting
                those weights, e.g. when the parameters are about to be
                overwritten by ``load_state_dict``.
        """
        super().__init__()
        backbone_weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.base_model = models.resnet50(weights=backbone_weights)

        # Read the feature width before the fc layer is swapped out below.
        in_features = self.base_model.fc.in_features

        # With fc replaced by Identity the backbone outputs raw features.
        self.base_model.fc = nn.Identity()

        # Two heads on top of the same feature vector.
        self.classifier = nn.Linear(in_features, num_classes)
        self.bruised_classifier = nn.Linear(in_features, bruised_classes)

    def forward(self, x):
        """Return ``(fruit_logits, bruised_logits)`` for input batch ``x``."""
        features = self.base_model(x)  # Feature extraction
        fruit_class = self.classifier(features)  # Fruit and freshness classification
        bruised_class = self.bruised_classifier(features)  # Bruised/Not Bruised classification
        return fruit_class, bruised_class


In [5]:

# Load model
# Build the network on the target device, then restore the trained parameters.
model = ResNetClassifier(len(FRUIT_CLASSES), len(BRUISED_CLASSES)).to(DEVICE)
# map_location lets a checkpoint saved on another device (e.g. CUDA) load on
# this one; weights_only=True restricts unpickling to tensors and plain
# containers instead of arbitrary pickled objects.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
)
# Switch to inference mode (affects layers such as batch norm and dropout).
model.eval()

# Fruit counter
"""
def count_fruits(frame):
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    _, thresh = cv2.threshold(blurred, 50, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return len(contours)
"""


  model.load_state_dict(torch.load(MODEL_PATH))


'\ndef count_fruits(frame):\n    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)\n    blurred = cv2.GaussianBlur(gray, (5, 5), 0)\n    _, thresh = cv2.threshold(blurred, 50, 255, cv2.THRESH_BINARY_INV)\n    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n    return len(contours)\n'

In [None]:

# Real-time inference
def live_inference(camera_index=0):
    """Classify frames from a camera and show the annotated video stream.

    Each captured frame is converted from BGR to RGB, run through the
    module-level ``transform`` and ``model``, and the predicted labels from
    ``FRUIT_CLASSES`` and ``BRUISED_CLASSES`` are drawn onto the frame. The
    loop ends when a frame cannot be read or when 'q' is pressed in the
    display window.

    Args:
        camera_index: Device index passed to ``cv2.VideoCapture``. Defaults
            to 0, the value that was previously hard-coded.

    Raises:
        RuntimeError: If the camera cannot be opened.
    """
    cap = cv2.VideoCapture(camera_index)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open camera at index {camera_index}")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Fruit Count
            # num_fruits = count_fruits(frame)

            # Classification
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img_tensor = transform(img).unsqueeze(0).to(DEVICE)  # add batch dimension
            with torch.no_grad():
                fruit_pred, bruised_pred = model(img_tensor)
                fruit_class = torch.argmax(fruit_pred, dim=1).item()
                bruised_class = torch.argmax(bruised_pred, dim=1).item()

            # Get classifications
            fruit_name = FRUIT_CLASSES[fruit_class]
            bruise_status = BRUISED_CLASSES[bruised_class]

            # Display Results
            # cv2.putText(frame, f"Fruits Detected: {num_fruits}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            cv2.putText(frame, f"Bruised: {bruise_status}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            cv2.putText(frame, f"Type: {fruit_name}", (10, 70),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            cv2.imshow("Fruit Detector", frame)

            # Mask to the low byte before comparing with the key code.
            if (cv2.waitKey(1) & 0xFF) == ord('q'):
                break
    finally:
        # Always free the camera and close windows, even if inference raised.
        cap.release()
        cv2.destroyAllWindows()


In [None]:
# Run live detection: starts the capture loop defined in live_inference(),
# which runs until a frame cannot be read or 'q' is pressed in the window.
live_inference()