In [2]:
# from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
import torch
import torchvision.transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import time

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model


# Load the trained model
num_classes = 2  # Update this based on your dataset (including the background class)
model = get_model_instance_segmentation(num_classes)
model.load_state_dict(torch.load('trained_model.pth'))
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Open the video
video_path = 'vids/pexels-evgenij-mikhailov-9921125 (360p).mp4'  # Replace with your video file path
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise Exception("Error opening video file")

# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'MP4V')
out = cv2.VideoWriter('output\output_video4.mp4', fourcc, 30.0, (int(cap.get(3)), int(cap.get(4))))

count_all_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
average_fps = 0
total_fps = 0

print(f"Total frames: {count_all_frames}")
# Define the transform to convert frame to tensor
transform = T.Compose([T.ToTensor()])

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    time_counter_start = time.time()

    # Convert frame to tensor and move to the device
    frame_tensor = transform(frame).unsqueeze(0).to(device)

    with torch.no_grad():
        predictions = model(frame_tensor)

    # Draw bounding boxes
    for element in range(len(predictions[0]['boxes'])):
        if predictions[0]['labels'][element] == 1:  # Label 1 corresponds to 'person'
            box = predictions[0]['boxes'][element].cpu().numpy()
            score = predictions[0]['scores'][element].cpu().numpy()
            if score > 0.5:  # Consider detections with a confidence score above 0.5
                cv2.rectangle(frame, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 255, 0), 2)

    # Write the processed frame
    out.write(frame)
    time_counter_end = time.time()
    total_time = time_counter_end - time_counter_start
    fps = 1 / total_time
    total_fps += fps
    average_fps = total_fps / cap.get((1))
    
    print(
        f"Processed frame {count_all_frames}/{cap.get(1)}, Time: {total_time:.3f}, FPS: {fps:.3f}, Average FPS: {average_fps:.3f}")

# Release everything when done
cap.release()
out.release()
cv2.destroyAllWindows()


Total frames: 431
Processed frame 431/1.0, Time: 0.348, FPS: 2.877, Average FPS: 2.877
Processed frame 431/2.0, Time: 0.167, FPS: 5.987, Average FPS: 4.432
Processed frame 431/3.0, Time: 0.166, FPS: 6.018, Average FPS: 4.960
Processed frame 431/4.0, Time: 0.168, FPS: 5.936, Average FPS: 5.204
Processed frame 431/5.0, Time: 0.208, FPS: 4.819, Average FPS: 5.127
Processed frame 431/6.0, Time: 0.177, FPS: 5.642, Average FPS: 5.213
Processed frame 431/7.0, Time: 0.159, FPS: 6.286, Average FPS: 5.366
Processed frame 431/8.0, Time: 0.159, FPS: 6.309, Average FPS: 5.484
Processed frame 431/9.0, Time: 0.160, FPS: 6.248, Average FPS: 5.569
Processed frame 431/10.0, Time: 0.160, FPS: 6.231, Average FPS: 5.635
Processed frame 431/11.0, Time: 0.158, FPS: 6.348, Average FPS: 5.700
Processed frame 431/12.0, Time: 0.155, FPS: 6.456, Average FPS: 5.763
Processed frame 431/13.0, Time: 0.164, FPS: 6.102, Average FPS: 5.789
Processed frame 431/14.0, Time: 0.176, FPS: 5.678, Average FPS: 5.781
Processed f