In [7]:
import cv2
import torch
import torchvision.transforms as T
from PIL import Image
from ultralytics import YOLO
import timm

In [8]:
# Load YOLO model
yolo_model = YOLO("best.pt")

In [9]:
vit_model = timm.create_model('vit_tiny_patch16_224', pretrained=False, num_classes=6)
state_dict = torch.load("best_vit_model.pth", map_location='cpu')
vit_model.load_state_dict(state_dict)
vit_model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop2): Dropout(p=0.0, inplace=Fals

In [10]:
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

In [11]:
labels = ['BIODEGRADABLE', 'CARDBOARD', 'GLASS', 'METAL', 'PAPER', 'PLASTIC']

In [15]:
mode = 'yolo'  # 'yolo' or 'vit'

cap = cv2.VideoCapture(0)
cv2.namedWindow("Garbage Classifier")

while True:
    ret, frame = cap.read()
    if not ret:
        break

    if mode == 'yolo':
        # Run YOLO
        results = yolo_model(frame)[0]
        boxes = results.boxes

        if boxes is not None and len(boxes) > 0:
            # Get confidences
            confidences = boxes.conf.cpu().numpy()

            # Get index of highest confidence
            top_idx = confidences.argmax()

            # Get just the top box
            top_box = boxes[top_idx]

            # Plot manually (optional: use results.plot() if fine)
            annotated = frame.copy()
            x1, y1, x2, y2 = map(int, top_box.xyxy[0])
            label = results.names[int(top_box.cls)]
            conf = top_box.conf.item()

            cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 255), 2)
            cv2.putText(annotated, f"{label.upper()} {conf:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 255), 2)
            cv2.putText(annotated, "YOLO Detection", (10, annotated.shape[0] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    
    elif mode == 'vit':
        # Run ViT on full frame
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(img_rgb)
        input_tensor = transform(pil_img).unsqueeze(0)

        with torch.no_grad():
            outputs = vit_model(input_tensor)
            _, predicted = torch.max(outputs, 1)
            label = labels[predicted.item()]

        annotated = frame.copy()
        cv2.putText(annotated, f"ViT: {label}", (10, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
        cv2.putText(annotated, "ViT Classification", (10, annotated.shape[0] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

    # Show result
    cv2.imshow("Garbage Classifier", annotated)

    # Exit condition
    if (cv2.waitKey(1) & 0xFF == ord('q')) or cv2.getWindowProperty("Garbage Classifier", cv2.WND_PROP_VISIBLE) < 1:
        break

cap.release()
cv2.destroyAllWindows()


0: 320x416 1 PAPER, 78.5ms
Speed: 2.1ms preprocess, 78.5ms inference, 0.7ms postprocess per image at shape (1, 3, 320, 416)

0: 320x416 1 PAPER, 75.4ms
Speed: 0.9ms preprocess, 75.4ms inference, 0.6ms postprocess per image at shape (1, 3, 320, 416)

0: 320x416 1 PAPER, 71.8ms
Speed: 0.9ms preprocess, 71.8ms inference, 0.8ms postprocess per image at shape (1, 3, 320, 416)

0: 320x416 1 PAPER, 73.8ms
Speed: 1.2ms preprocess, 73.8ms inference, 0.6ms postprocess per image at shape (1, 3, 320, 416)

0: 320x416 1 PAPER, 80.1ms
Speed: 1.0ms preprocess, 80.1ms inference, 0.6ms postprocess per image at shape (1, 3, 320, 416)

0: 320x416 1 PAPER, 93.5ms
Speed: 1.1ms preprocess, 93.5ms inference, 1.0ms postprocess per image at shape (1, 3, 320, 416)

0: 320x416 1 PAPER, 90.2ms
Speed: 1.0ms preprocess, 90.2ms inference, 0.9ms postprocess per image at shape (1, 3, 320, 416)

0: 320x416 1 PAPER, 97.6ms
Speed: 1.1ms preprocess, 97.6ms inference, 0.8ms postprocess per image at shape (1, 3, 320, 416)
