In [1]:
import cv2
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image

# Runtime configuration
IMG_SIZE = 224  # side length, in pixels, of the square image fed to the models
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU when one is available

# Checkpoint files holding the trained weights for each backbone
RESNET_MODEL_PATH = "resnet_fruit_model.pth"
EFFICIENTNET_MODEL_PATH = "efficientnet_fruit_model.pth"

# Label names for the two prediction heads; a predicted class index is used
# directly as a position in these lists.
# NOTE(review): presumably the order must match the label encoding used at
# training time -- confirm against the training code.
FRUIT_CLASSES = [
    "1. Green",
    "1. Ripe",
    "1. Semi-Ripe",
    "2. Green Defect",
    "2. Ripe Defect",
    "2. Semi-Ripe Defect",
]
BRUISED_CLASSES = ["Not Bruised", "Bruised"]

# Preprocessing pipeline: resize to IMG_SIZE x IMG_SIZE, convert to a tensor,
# then normalize each channel with the given mean / std.
# NOTE(review): these mean/std values look like the usual ImageNet statistics.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:

# Model definition (same for both models)
class MultiTaskClassifier(nn.Module):
    """Shared backbone with two linear heads: fruit class and bruised status.

    The backbone's own classification head is replaced with ``nn.Identity`` so
    that it returns features, which are then fed to both linear heads.
    """

    def __init__(self, base_model, num_classes, bruised_classes, model_type="resnet"):
        """Wrap ``base_model`` and attach the two classification heads.

        Args:
            base_model: Backbone module. For ``"resnet"`` it must expose an
                ``fc`` layer; for ``"efficientnet"`` a ``classifier`` whose
                element at index 1 has ``in_features``.
            num_classes: Output size of the fruit head.
            bruised_classes: Output size of the bruised head.
            model_type: ``"resnet"`` or ``"efficientnet"``.

        Raises:
            ValueError: If ``model_type`` is neither of the supported values.
        """
        super().__init__()
        self.base_model = base_model
        self.model_type = model_type

        # Find the backbone's original head and the width of the features feeding it.
        if model_type == "resnet":
            head_attr = "fc"
            feature_dim = base_model.fc.in_features
        elif model_type == "efficientnet":
            head_attr = "classifier"
            feature_dim = base_model.classifier[1].in_features
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        # Swap the original head for a pass-through so the backbone emits features.
        setattr(base_model, head_attr, nn.Identity())

        # Task-specific heads operating on the shared features.
        self.classifier = nn.Linear(feature_dim, num_classes)
        self.bruised_classifier = nn.Linear(feature_dim, bruised_classes)

    def forward(self, x):
        """Return ``(fruit_scores, bruised_scores)`` for the input batch ``x``."""
        features = self.base_model(x)
        return self.classifier(features), self.bruised_classifier(features)



In [None]:

# Load the trained checkpoints into freshly built backbones.
#
# * weights=None: load_state_dict is strict by default, so every parameter and
#   buffer is overwritten by the checkpoint; fetching the ImageNet weights
#   first would only add a pointless download.
# * map_location=DEVICE: lets a checkpoint that was saved on a GPU be loaded on
#   a CPU-only machine.
# * NOTE: torch.load unpickles the file, so only load checkpoints from a
#   trusted source (newer PyTorch releases also accept weights_only=True).

# Load ResNet model
resnet_model = MultiTaskClassifier(
    models.resnet50(weights=None),
    len(FRUIT_CLASSES),
    len(BRUISED_CLASSES),
    model_type="resnet",
).to(DEVICE)
resnet_model.load_state_dict(torch.load(RESNET_MODEL_PATH, map_location=DEVICE))
resnet_model.eval()  # inference mode

# Load EfficientNet model
efficientnet_model = MultiTaskClassifier(
    models.efficientnet_b0(weights=None),
    len(FRUIT_CLASSES),
    len(BRUISED_CLASSES),
    model_type="efficientnet",
).to(DEVICE)
efficientnet_model.load_state_dict(torch.load(EFFICIENTNET_MODEL_PATH, map_location=DEVICE))
efficientnet_model.eval();  # trailing ';' keeps the full model repr out of the cell output


In [None]:

# Fruit counter (currently disabled).
# The function below is wrapped in a bare string literal, so `count_fruits` is
# never defined or executed; the string is only evaluated and echoed back as
# the cell's output. Remove the surrounding triple quotes to enable it.
"""
def count_fruits(frame):
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    _, thresh = cv2.threshold(blurred, 50, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return len(contours)
"""

'\ndef count_fruits(frame):\n    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)\n    blurred = cv2.GaussianBlur(gray, (5, 5), 0)\n    _, thresh = cv2.threshold(blurred, 50, 255, cv2.THRESH_BINARY_INV)\n    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n    return len(contours)\n'

In [None]:

# Inference function
def inference_window(model, window_name, camera_index=0):
    """Classify live camera frames and show the annotated feed in an OpenCV window.

    Each captured frame is converted from BGR to RGB, run through the
    module-level ``transform`` and passed to ``model`` as a batch of one. The
    argmax of each of the two model outputs is mapped to a label through
    ``FRUIT_CLASSES`` / ``BRUISED_CLASSES`` and drawn onto the frame. The loop
    ends when a frame cannot be read or when 'q' is pressed in the window.

    The camera and the OpenCV windows are released even when the loop is
    interrupted (for example by a kernel interrupt) or an error is raised
    while processing a frame.

    Args:
        model: Callable that takes the preprocessed batch tensor and returns a
            ``(fruit_scores, bruised_scores)`` pair. It is used as-is, so it
            should already live on ``DEVICE`` and be in eval mode.
        window_name: Title of the OpenCV display window.
        camera_index: Device index handed to ``cv2.VideoCapture``. Defaults to
            0, the value that used to be hard-coded.

    Returns:
        None.
    """
    cap = cv2.VideoCapture(camera_index)
    try:
        if not cap.isOpened():
            # Report the problem instead of returning silently; still no exception,
            # so a caller that runs several models in sequence keeps going.
            print(f"Could not open camera {camera_index}; skipping '{window_name}'.")
            return

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Classification: OpenCV delivers BGR frames, PIL/torchvision expect RGB.
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img_tensor = transform(img).unsqueeze(0).to(DEVICE)  # add the batch dimension
            with torch.no_grad():
                fruit_pred, bruised_pred = model(img_tensor)
                fruit_class = torch.argmax(fruit_pred, dim=1).item()
                bruised_class = torch.argmax(bruised_pred, dim=1).item()

            # Map the predicted indices to their human-readable labels.
            fruit_name = FRUIT_CLASSES[fruit_class]
            bruise_status = BRUISED_CLASSES[bruised_class]

            # Overlay the results on the frame and display it.
            cv2.putText(frame, f"Bruised: {bruise_status}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            cv2.putText(frame, f"Type: {fruit_name}", (10, 70),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            cv2.imshow(window_name, frame)

            # waitKey also services the GUI event loop; 'q' ends the session.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the camera and close the windows, whatever ended the loop.
        cap.release()
        cv2.destroyAllWindows()


In [None]:

# Run the live demo once per model: ResNet first, then EfficientNet.
# Each session runs until 'q' is pressed in its window or the camera stops
# delivering frames.
inference_runs = [
    ("Starting ResNet inference...", resnet_model, "ResNet Fruit Detector"),
    ("Starting EfficientNet inference...", efficientnet_model, "EfficientNet Fruit Detector"),
]
for start_message, fruit_model, window_title in inference_runs:
    print(start_message)
    inference_window(fruit_model, window_title)
