In [1]:
import tkinter as tk
from tkinter import Label, Button
import cv2
from PIL import Image, ImageTk
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
import numpy as np

In [2]:
# Hugging Face checkpoint used for garbage classification. Both the image
# processor (resizing/normalisation) and the classifier weights are fetched
# from the same repository so their preprocessing settings match.
MODEL_ID = "yangy50/garbage-classification"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

preprocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/883 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/343M [00:00<?, ?B/s]

In [18]:
def classify_frame(frame, image_processor=None, classifier=None):
    """Classify a single OpenCV frame with the garbage-classification model.

    Parameters
    ----------
    frame : numpy.ndarray
        Image in OpenCV's BGR channel order (e.g. as returned by
        ``cv2.VideoCapture.read``).
    image_processor : optional
        Hugging Face image processor. Defaults to the module-level
        ``processor``.
    classifier : optional
        Hugging Face image-classification model. Defaults to the module-level
        ``model``.

    Returns
    -------
    str
        The predicted label, looked up in ``classifier.config.id2label``.

    Raises
    ------
    ValueError
        If ``frame`` is ``None`` or contains no pixels (e.g. an empty crop).
    """
    # Fail early with a clear message; cv2.cvtColor raises an opaque
    # cv2.error assertion on empty input.
    if frame is None or getattr(frame, "size", 0) == 0:
        raise ValueError("classify_frame() received an empty frame")

    # Resolve defaults lazily so the globals are only needed when no explicit
    # processor/model is supplied.
    if image_processor is None:
        image_processor = processor
    if classifier is None:
        classifier = model

    # OpenCV delivers BGR; PIL and the HF processor expect RGB.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_frame)
    inputs = image_processor(images=pil_image, return_tensors="pt")

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = classifier(**inputs)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    return classifier.config.id2label[predicted_class_idx]

In [28]:
class GarbageClassifierApp:
    """Tkinter GUI that shows a live camera feed and classifies its centre.

    A green square is drawn in the middle of every displayed frame. Pressing
    "Classify" grabs a fresh frame, crops the same centred square, runs it
    through ``classify_frame`` and shows the predicted label together with a
    biodegradable / non-biodegradable verdict.

    Note: the constructor enters ``window.mainloop()`` and therefore blocks
    until the window is closed.
    """

    # Side length (pixels) of the centred square that is both drawn and
    # classified. One shared constant keeps the classified region identical to
    # what the user sees on screen.
    SQUARE_SIZE = 200
    # Number of camera indices offered in the dropdown (0 .. MAX_CAMERAS - 1).
    MAX_CAMERAS = 5
    # Model labels reported as "Biodegradable"; every other label is reported
    # as "Non-Biodegradable".
    # NOTE(review): confirm that "trash" should count as biodegradable for this
    # model's label set.
    BIODEGRADABLE_LABELS = frozenset({"cardboard", "paper", "trash"})

    def __init__(self, window, window_title):
        """Build the widgets, open the default camera and run the Tk main loop.

        Raises
        ------
        ValueError
            If the default camera (index 0) cannot be opened.
        """
        self.window = window
        self.window.title(window_title)

        # Default video source
        self.video_source = 0
        # Id of the pending `after` callback that drives the video loop.
        self._after_id = None
        # Single canvas image item that every new frame is written into.
        self._image_item = None
        # Keeps the current PhotoImage alive (Tk holds no Python reference).
        self.photo = None

        # Dropdown menu to select camera
        self.camera_label = tk.Label(window, text="Select Camera:")
        self.camera_label.pack(anchor=tk.W)
        self.camera_var = tk.IntVar(value=self.video_source)  # selected camera index
        self.camera_menu = tk.OptionMenu(window, self.camera_var, *range(self.MAX_CAMERAS),
                                         command=self.change_camera)
        self.camera_menu.pack(anchor=tk.W)

        # Open the video source
        self.vid = cv2.VideoCapture(self.video_source)
        if not self.vid.isOpened():
            raise ValueError("Unable to open video source", self.video_source)

        # Canvas sized to the camera's reported frame size.
        self.canvas = tk.Canvas(window,
                                width=int(self.vid.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                height=int(self.vid.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        self.canvas.pack()

        # Button to capture and classify the current frame
        self.btn_classify = Button(window, text="Classify", width=50, command=self.classify_current_frame)
        self.btn_classify.pack(anchor=tk.CENTER, expand=True)

        # Label to display the classification result
        self.label_result = Label(window, text="Classification Result: ")
        self.label_result.pack(anchor=tk.CENTER, expand=True)

        # Stop the video loop and release the camera when the window is closed.
        self.window.protocol("WM_DELETE_WINDOW", self.on_close)

        # Refresh interval of the video loop, in milliseconds.
        self.delay = 15
        self.update()

        self.window.mainloop()

    @staticmethod
    def _square_bounds(width, height, size):
        """Return ((x0, y0), (x1, y1)) of a size x size square centred in the frame.

        The coordinates are clamped to [0, width] x [0, height]. Without the
        clamp, negative start indices would wrap around when used to slice a
        NumPy array and select the wrong region.
        """
        half = size // 2
        cx, cy = width // 2, height // 2
        top_left = (max(cx - half, 0), max(cy - half, 0))
        bottom_right = (min(cx + half, width), min(cy + half, height))
        return top_left, bottom_right

    @classmethod
    def _category_for(cls, label):
        """Map a model label to the "Biodegradable" / "Non-Biodegradable" verdict."""
        return "Biodegradable" if label in cls.BIODEGRADABLE_LABELS else "Non-Biodegradable"

    def change_camera(self, camera_index):
        """Switch the live feed to ``camera_index`` (the value chosen in the dropdown)."""
        # Release the current video source
        if self.vid.isOpened():
            self.vid.release()

        # Open the new video source
        self.video_source = int(camera_index)
        self.vid = cv2.VideoCapture(self.video_source)
        if not self.vid.isOpened():
            self.label_result.config(text=f"Unable to open camera {self.video_source}")
            return

        # Cameras may report different frame sizes; resize the canvas so the
        # whole frame stays visible.
        self.canvas.config(width=int(self.vid.get(cv2.CAP_PROP_FRAME_WIDTH)),
                           height=int(self.vid.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        self.label_result.config(text=f"Camera {self.video_source} selected")

    def update(self):
        """Show one frame with the target square drawn on it, then reschedule itself."""
        ret, frame = self.vid.read()
        if ret:
            height, width = frame.shape[:2]
            top_left, bottom_right = self._square_bounds(width, height, self.SQUARE_SIZE)

            # Draw a green (BGR) square, 2 px thick, marking the classified region.
            cv2.rectangle(frame, top_left, bottom_right, (0, 255, 0), 2)

            # OpenCV BGR -> RGB -> PIL -> Tk photo image.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            imgtk = ImageTk.PhotoImage(image=Image.fromarray(rgb_frame))

            # Reuse one canvas item instead of stacking a new item every frame,
            # which would grow the canvas item list without bound.
            if self._image_item is None:
                self._image_item = self.canvas.create_image(0, 0, anchor=tk.NW, image=imgtk)
            else:
                self.canvas.itemconfig(self._image_item, image=imgtk)
            # Hold a reference so the PhotoImage is not garbage-collected.
            self.photo = imgtk

        # Repeat after a delay; keep the id so on_close can cancel it.
        self._after_id = self.window.after(self.delay, self.update)

    def classify_current_frame(self):
        """Grab a fresh frame, classify the centred square and display the verdict."""
        ret, frame = self.vid.read()
        if not ret:
            self.label_result.config(text="Unable to read a frame from the camera")
            return

        # Crop exactly the square drawn by update() (same size, same clamping).
        height, width = frame.shape[:2]
        (x0, y0), (x1, y1) = self._square_bounds(width, height, self.SQUARE_SIZE)
        roi = frame[y0:y1, x0:x1]

        label = classify_frame(roi)
        category = self._category_for(label)
        # Show both the raw label and the derived category.
        self.label_result.config(text=f"Classification Result: {label} ({category})")

    def on_close(self):
        """Cancel the video loop, release the camera and destroy the window."""
        if self._after_id is not None:
            self.window.after_cancel(self._after_id)
            self._after_id = None
        if self.vid.isOpened():
            self.vid.release()
        self.window.destroy()

    def __del__(self):
        # `vid` may be missing if construction failed early.
        vid = getattr(self, "vid", None)
        if vid is not None and vid.isOpened():
            vid.release()


# Create the root window and run the app; the constructor blocks in
# mainloop() until the window is closed. If construction fails (e.g. no
# camera can be opened), destroy the half-built root window so it does not
# linger, then re-raise. Assigning the instance keeps the cell from echoing
# the object's repr as output.
root = tk.Tk()
try:
    app = GarbageClassifierApp(root, "Garbage Classifier")
except Exception:
    root.destroy()
    raise

<__main__.GarbageClassifierApp at 0x22878f49510>

<__main__.GarbageClassifierApp at 0x228775b2610>

<__main__.GarbageClassifierApp at 0x22878f576d0>