In [12]:
import cv2
import torch
import torchvision.transforms as T
from PIL import Image
from ultralytics import YOLO
import timm

In [13]:
# Load YOLO model
yolo_model = YOLO("best_yolo.pt")

In [14]:
vit_model = timm.create_model('vit_tiny_patch16_224', pretrained=False, num_classes=6)
state_dict = torch.load("best_vit_model.pth", map_location='cpu')
vit_model.load_state_dict(state_dict)
vit_model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop2): Dropout(p=0.0, inplace=Fals

In [15]:
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

In [16]:
labels = ['BIODEGRADABLE', 'CARDBOARD', 'GLASS', 'METAL', 'PAPER', 'PLASTIC']

In [11]:
import cv2
import torch
from PIL import Image

mode = 'vit'  # 'yolo' or 'vit'

cap = cv2.VideoCapture(0)
cv2.namedWindow("Garbage Classifier")

# Get frame size and initialize video writer
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'XVID')  # Use 'MP4V' for .mp4
out = cv2.VideoWriter('output.avi', fourcc, 20.0, (frame_width, frame_height))

while True:
    ret, frame = cap.read()
    if not ret:
        break

    if mode == 'yolo':
        # Run YOLO
        results = yolo_model(frame)[0]
        boxes = results.boxes

        if boxes is not None and len(boxes) > 0:
            # Get confidences
            confidences = boxes.conf.cpu().numpy()
            top_idx = confidences.argmax()
            top_box = boxes[top_idx]

            annotated = frame.copy()
            x1, y1, x2, y2 = map(int, top_box.xyxy[0])
            label = results.names[int(top_box.cls)]
            conf = top_box.conf.item()

            cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 255), 2)
            cv2.putText(annotated, f"{label.upper()} {conf:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 255), 2)
            cv2.putText(annotated, "YOLO Detection", (10, annotated.shape[0] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
        else:
            annotated = frame.copy()

    elif mode == 'vit':
        # Run ViT on full frame
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(img_rgb)
        input_tensor = transform(pil_img).unsqueeze(0)

        with torch.no_grad():
            outputs = vit_model(input_tensor)
            _, predicted = torch.max(outputs, 1)
            label = labels[predicted.item()]

        annotated = frame.copy()
        cv2.putText(annotated, f"ViT: {label}", (10, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
        cv2.putText(annotated, "ViT Classification", (10, annotated.shape[0] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
    else:
        annotated = frame.copy()

    # Show result
    cv2.imshow("Garbage Classifier", annotated)

    # Save to video file
    out.write(annotated)

    # Exit condition
    if (cv2.waitKey(1) & 0xFF == ord('q')) or cv2.getWindowProperty("Garbage Classifier", cv2.WND_PROP_VISIBLE) < 1:
        break

# Release resources
cap.release()
out.release()
cv2.destroyAllWindows()


In [17]:
import gradio as gr
import numpy as np
from PIL import Image

def predict(image, mode):
    frame = np.array(image)
    annotated = frame.copy()

    if mode == 'yolo':
        results = yolo_model(frame)[0]
        boxes = results.boxes

        if boxes is not None and len(boxes) > 0:
            confidences = boxes.conf.cpu().numpy()
            top_idx = confidences.argmax()
            top_box = boxes[top_idx]

            x1, y1, x2, y2 = map(int, top_box.xyxy[0])
            label = results.names[int(top_box.cls)]
            conf = top_box.conf.item()

            cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 255), 2)
            cv2.putText(annotated, f"{label.upper()} {conf:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 255), 2)
            cv2.putText(annotated, "YOLO Detection", (10, annotated.shape[0] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    elif mode == 'vit':
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(img_rgb)
        input_tensor = transform(pil_img).unsqueeze(0)

        with torch.no_grad():
            outputs = vit_model(input_tensor)
            _, predicted = torch.max(outputs, 1)
            label = labels[predicted.item()]

        cv2.putText(annotated, f"ViT: {label}", (10, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
        cv2.putText(annotated, "ViT Classification", (10, annotated.shape[0] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

    return annotated

In [19]:
gr.Interface(
    fn=predict,
    inputs=[
        # gr.Image(source="webcam", tool="editor", label="Input Image"),  # Webcam input
        gr.Image(label="Upload or Take Picture"), #Upload or Take Picture
        gr.Radio(["yolo", "vit"], label="Choose Model", value="yolo")   # Model selector
    ],
    outputs=gr.Image(label="Annotated Output"),
    live=True,
    title="Garbage Classifier (YOLO + ViT)"
).launch(inbrowser=True)



Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




IMPORTANT: You are using gradio version 3.41.2, however version 4.44.1 is available, please upgrade.
--------

0: 288x416 1 PAPER, 109.0ms
Speed: 2.9ms preprocess, 109.0ms inference, 1.3ms postprocess per image at shape (1, 3, 288, 416)

0: 288x416 1 PAPER, 96.2ms
Speed: 1.6ms preprocess, 96.2ms inference, 0.8ms postprocess per image at shape (1, 3, 288, 416)


Traceback (most recent call last):
  File "c:\Users\Gauriel\anaconda3\envs\DL\lib\site-packages\gradio\routes.py", line 488, in run_predict
    output = await app.get_blocks().process_api(
  File "c:\Users\Gauriel\anaconda3\envs\DL\lib\site-packages\gradio\blocks.py", line 1434, in process_api
    data = self.postprocess_data(fn_index, result["prediction"], state)
  File "c:\Users\Gauriel\anaconda3\envs\DL\lib\site-packages\gradio\blocks.py", line 1335, in postprocess_data
    prediction_value = block.postprocess(prediction_value)
  File "c:\Users\Gauriel\anaconda3\envs\DL\lib\site-packages\gradio\components\image.py", line 314, in postprocess
    return processing_utils.encode_array_to_base64(y)
  File "c:\Users\Gauriel\anaconda3\envs\DL\lib\site-packages\gradio\processing_utils.py", line 104, in encode_array_to_base64
    pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
  File "c:\Users\Gauriel\anaconda3\envs\DL\lib\site-packages\gradio\processing_utils.