In [None]:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from pathlib import Path
import os
import cv2
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torchvision import models, transforms
import numpy as np

# Select the compute device first; the GPU is used whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Rebuild the Faster R-CNN (ResNet-50 FPN) architecture that the trained state dict
# is loaded into in the next cell.
# NOTE(review): pretrained=True downloads weights that load_state_dict overwrites anyway.
# Switching to weights=None may change how the backbone's normalization layers are built
# (and therefore the expected state-dict keys) — verify against the checkpoint first.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Replace the box-classification head so its output size matches our dataset.
# num_classes must equal the value the checkpoint was trained with.
num_classes = 3
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# nn.Module.to moves parameters in place and returns the same module object.
model = model.to(device)
print(device)

In [None]:
# Path to the trained state dict to load into `model`; change accordingly.
model_path = 'test_model_01_RCNN.pth'

# Pure-Python replacement for `!ls -lh`: works on every OS (including Windows) and
# fails early with a clear message when the checkpoint is missing.
checkpoint_file = Path(model_path)
if not checkpoint_file.is_file():
    raise FileNotFoundError(f"Trained model not found: {checkpoint_file.resolve()}")
print(f"{checkpoint_file} ({checkpoint_file.stat().st_size / 1024**2:.1f} MiB)")
# Optional: test (not extract) the archive's integrity with  !unzip -t {model_path}

# Reuse `device` from the setup cell so the weights are mapped to where the model lives.
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered
# .pth file cannot execute arbitrary code (requires torch >= 1.13).
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Switch to evaluation mode for inference; the trailing `;` hides the long module repr.
model.eval();

In [None]:
# Convert an RGB numpy frame into the float tensor format the detector expects.
transform = transforms.Compose([
    transforms.ToTensor()
])

CAMERA_INDEX = 0              # 0 for default camera
SCORE_THRESHOLD = 0.5         # only detections scoring above this are drawn
WINDOW_NAME = 'Live Fire Detection'
BOX_COLOR = (0, 255, 0)       # OpenCV colours are BGR; this is green


def draw_predictions(frame, predictions, score_threshold=SCORE_THRESHOLD):
    """Draw a box and a "label score" caption for each confident detection, in place.

    Args:
        frame: the BGR image returned by ``cap.read()``; it is modified in place.
        predictions: one element of the model's output list, with ``'boxes'``,
            ``'scores'`` and ``'labels'`` tensors (possibly on the GPU).
        score_threshold: detections whose score is not above this value are skipped.
    """
    # Copy to plain Python values in one transfer per tensor. The original per-box
    # `.numpy()` call raises on CUDA tensors. `.int()` truncates coordinates as before.
    boxes = predictions['boxes'].int().cpu().tolist()
    scores = predictions['scores'].cpu().tolist()
    labels = predictions['labels'].cpu().tolist()

    for (x1, y1, x2, y2), score, label in zip(boxes, scores, labels):
        if score > score_threshold:
            cv2.rectangle(frame, (x1, y1), (x2, y2), BOX_COLOR, 2)
            # Keep the caption inside the image when the box touches the top edge.
            cv2.putText(frame, f'{label} {score:.2f}', (x1, max(y1 - 10, 10)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, BOX_COLOR, 2)


# Open webcam (cv2.imshow below needs a desktop GUI session, not a remote/hosted kernel).
cap = cv2.VideoCapture(CAMERA_INDEX)
if not cap.isOpened():
    cap.release()
    raise RuntimeError('Camera could not be opened, check device settings/permissions')

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            print('Could not read a frame from the camera; stopping.')
            break

        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # The model was moved to `device`, so its input must live there too.
        image_tensor = transform(image).to(device)

        with torch.no_grad():
            predictions = model([image_tensor])[0]

        draw_predictions(frame, predictions)
        cv2.imshow(WINDOW_NAME, frame)

        # Press 'q' (with the video window focused) to quit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Runs on 'q', end of stream, errors and kernel interrupts alike, so the camera
    # is not left locked until the kernel restarts.
    cap.release()
    cv2.destroyAllWindows()