# In [None]:
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
import torchvision.transforms as transforms
from torchvision import transforms
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Define the class names. Predictions are mapped to these names by index, so
# the order presumably has to match the label order used during training.
classes = ['blank', 'paper', 'rock', 'scissors']

# Size of the classifier head, derived from the class list so the two can
# never drift apart.
num_classes = len(classes)

# Build a ResNet-18 and replace its final fully-connected layer with a
# num_classes-way classifier. No pretrained weights are requested because the
# parameters are overwritten by the checkpoint loaded below.
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# map_location='cpu' lets a checkpoint that was saved from a GPU be restored
# on a machine without CUDA; the model is never moved to another device here.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.load_state_dict(torch.load('model_checkpoint.pth', map_location='cpu'))

# Inference mode: batch-norm layers use their stored running statistics.
model.eval()

# Preprocessing applied to each frame before inference. Despite the name, no
# cropping happens here: the input is converted to a PIL image, resized to
# 224x224, turned into a tensor and normalised per channel.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
crop_transform = transforms.Compose(_preprocess_steps)

def crop_frame(frame, black_threshold=1):
    """Crop a BGR frame to the bounding box of its largest non-black region.

    Used to strip black borders from video frames before classification.

    Args:
        frame: Colour image as a 3-channel BGR ``numpy`` array, or ``None``.
        black_threshold: Grey levels less than or equal to this value are
            treated as black. The default of 1 keeps the original behaviour;
            raise it for compressed video whose borders are not perfectly
            black.

    Returns:
        A slice (view) of ``frame`` covering the bounding rectangle of the
        largest external contour, or ``None`` when the frame is missing,
        empty, or has no pixel above the threshold.
    """
    # A failed read yields None or an empty array; cvtColor would raise on it.
    if frame is None or frame.size == 0:
        return None

    gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Binary mask: 255 wherever the pixel is brighter than the threshold.
    _, mask = cv2.threshold(gray_frame, black_threshold, 255, cv2.THRESH_BINARY)

    # findContours returns (contours, hierarchy) in OpenCV 4.x but
    # (image, contours, hierarchy) in 3.x; the contour list is the
    # second-to-last element in both cases.
    contours = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    if not contours:
        return None

    # Keep only the region enclosed by the largest contour.
    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    return frame[y:y + h, x:x + w]

def process_frame(frame):
    """Turn a raw BGR video frame into a model-ready input batch.

    The frame is cropped with ``crop_frame`` to drop black borders, converted
    from BGR to RGB, and passed through ``crop_transform``.

    Args:
        frame: 3-channel BGR ``numpy`` array, or ``None``.

    Returns:
        The transformed frame with a leading batch dimension added, or
        ``None`` when the frame is missing or ``crop_frame`` finds no
        non-black content.
    """
    if frame is None:
        return None

    # Crop the frame to remove black edges
    cropped_frame = crop_frame(frame)
    if cropped_frame is None:
        return None

    # OpenCV frames are BGR, whereas ToPILImage interprets a 3-channel array
    # as RGB and the normalisation constants are listed in RGB order.
    rgb_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)

    # crop_transform begins with ToPILImage, which only accepts tensors and
    # ndarrays; handing it an already-built PIL image raises TypeError, so the
    # array is passed directly.
    return crop_transform(rgb_frame).unsqueeze(0)


if __name__ == "__main__":
    # Annotate every frame of the input video with the predicted class and
    # write the result to a new video, previewing a few frames along the way.
    video_path = "" # input video path here
    output_video_path = "" # place your video output path here

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Fail loudly instead of silently producing an empty output file.
        cap.release()
        raise IOError(f"Could not open input video: {video_path!r}")

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would make
    # the output play at a different speed than the input.
    fps = cap.get(cv2.CAP_PROP_FPS)
    out = cv2.VideoWriter(
        output_video_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (frame_width, frame_height),
    )
    if not out.isOpened():
        cap.release()
        out.release()
        raise IOError(f"Could not open output video: {output_video_path!r}")

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 2  # Increase font size
    font_color = (0, 255, 0)
    font_thickness = 2
    text_y = 200  # Lower position on the y-axis

    frame_count = 0
    # Previews are evenly spaced through the video (not random): roughly this
    # many frames are shown with matplotlib.
    frames_to_display = 10
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # max(1, ...) guards against a zero step when the frame count is small or
    # not reported by the container.
    frames_between_display = max(1, total_frames // frames_to_display)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Crop the frame to remove black edges
            cropped_frame = crop_frame(frame)

            if cropped_frame is not None:
                # OpenCV delivers BGR; the transform pipeline expects RGB.
                rgb_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
                transformed_frame = crop_transform(rgb_frame).unsqueeze(0)  # Add batch dimension

                # Perform inference
                with torch.no_grad():
                    outputs = model(transformed_frame)
                predicted_class = outputs.argmax(dim=1)
                predicted_label = classes[predicted_class.item()]

                # Horizontally centre the label on the full frame.
                text_size = cv2.getTextSize(
                    predicted_label, font, font_scale, font_thickness
                )[0]
                text_x_centered = (frame.shape[1] - text_size[0]) // 2
                cv2.putText(
                    frame,
                    predicted_label,
                    (text_x_centered, text_y),
                    font,
                    font_scale,
                    font_color,
                    font_thickness,
                    cv2.LINE_AA,
                )

                # Display the frame with prediction
                if frame_count % frames_between_display == 0:
                    # cropped_frame is a view into frame, so it already shows
                    # the label drawn above; no second crop is needed.
                    plt.imshow(cropped_frame[..., ::-1])  # Convert BGR to RGB
                    plt.title(f"Predicted: {predicted_label}")
                    plt.show()

            # Write every frame, annotated or not, so the output keeps the
            # same length and timing as the input.
            out.write(frame)
            frame_count += 1
    finally:
        # Release the video capture and writer even if processing fails.
        cap.release()
        out.release()