In [6]:
# from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
import torch
import torchvision.transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import time

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model


# Load the trained model
num_classes = 2  # Update this based on your dataset (including the background class)
model = get_model_instance_segmentation(num_classes)
model.load_state_dict(torch.load('trained_model.pth'))
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Open the video
video_path = 'vids/production_id_4204454 (540p).mp4'  # Replace with your video file path
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise Exception("Error opening video file")

# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'MP4V')
out = cv2.VideoWriter('output\output_video3.mp4', fourcc, 30.0, (int(cap.get(3)), int(cap.get(4))))

count_all_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
average_fps = int(cap.get(cv2.CAP_PROP_FPS))

print(f"Total frames: {count_all_frames}")
# Define the transform to convert frame to tensor
transform = T.Compose([T.ToTensor()])

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    time_counter_start = time.time()

    # Convert frame to tensor and move to the device
    frame_tensor = transform(frame).unsqueeze(0).to(device)

    with torch.no_grad():
        predictions = model(frame_tensor)

    # Draw bounding boxes
    for element in range(len(predictions[0]['boxes'])):
        if predictions[0]['labels'][element] == 1:  # Label 1 corresponds to 'person'
            box = predictions[0]['boxes'][element].cpu().numpy()
            score = predictions[0]['scores'][element].cpu().numpy()
            if score > 0.5:  # Consider detections with a confidence score above 0.5
                cv2.rectangle(frame, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 255, 0), 2)

    # Write the processed frame
    out.write(frame)
    time_counter_end = time.time()
    total_time = time_counter_end - time_counter_start
    fps = 1 / total_time
    average_fps = (average_fps + fps) / cap.get(1)
    print(f"Processed frame {count_all_frames}/{cap.get(1)}, Time: {total_time:.3f}, FPS: {fps:.3f}, Average FPS: {average_fps:.3f}")

# Release everything when done
cap.release()
out.release()
cv2.destroyAllWindows()


Total frames: 442
Processed frame 442/1.0, Time: 0.485, FPS: 2.062, Average FPS: 31.062
Processed frame 442/2.0, Time: 0.208, FPS: 4.808, Average FPS: 17.935
Processed frame 442/3.0, Time: 0.206, FPS: 4.863, Average FPS: 7.599
Processed frame 442/4.0, Time: 0.218, FPS: 4.593, Average FPS: 3.048
Processed frame 442/5.0, Time: 0.219, FPS: 4.569, Average FPS: 1.523
Processed frame 442/6.0, Time: 0.205, FPS: 4.879, Average FPS: 1.067
Processed frame 442/7.0, Time: 0.200, FPS: 4.988, Average FPS: 0.865
Processed frame 442/8.0, Time: 0.224, FPS: 4.459, Average FPS: 0.666
Processed frame 442/9.0, Time: 0.205, FPS: 4.871, Average FPS: 0.615
Processed frame 442/10.0, Time: 0.193, FPS: 5.191, Average FPS: 0.581
Processed frame 442/11.0, Time: 0.204, FPS: 4.891, Average FPS: 0.497
Processed frame 442/12.0, Time: 0.198, FPS: 5.054, Average FPS: 0.463
Processed frame 442/13.0, Time: 0.188, FPS: 5.316, Average FPS: 0.444
Processed frame 442/14.0, Time: 0.191, FPS: 5.242, Average FPS: 0.406
Processed