In [None]:
# Auto-reload imported modules before each cell runs, so edits to local code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2 

In [None]:
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch

from PIL import Image
from collections import deque, Counter
from common import utils
from ultralytics import YOLO
from transformers import AutoImageProcessor

# Directory layout: every asset folder sits directly under the working directory.
CURRENT_DIR = os.getcwd()
IMAGES_DIR, VIDEOS_DIR = (
    os.path.join(CURRENT_DIR, folder) for folder in ("images", "videos")
)
CHORD_CLASSIFIER_MODEL_DIR = os.path.join(CURRENT_DIR, "chord-classifier-model")
FRETBOARD_RECOGNIZER_MODEL_DIR = os.path.join(CURRENT_DIR, "fretboard-recognizer-model")

# Look up the weight file and the JSON config inside each model directory.
# Lookup order (chord weights, chord config, fretboard weights, fretboard config)
# is kept as-is because `utils.find_files` is an opaque project helper.
chord_clf_model_path, chord_clf_config_path = (
    utils.find_files(CHORD_CLASSIFIER_MODEL_DIR, [".safetensors", ".pt"]),
    utils.find_files(CHORD_CLASSIFIER_MODEL_DIR, [".json"]),
)
fretboard_rec_model_path, fretboard_rec_config_path = (
    utils.find_files(FRETBOARD_RECOGNIZER_MODEL_DIR, [".safetensors", ".pt"]),
    utils.find_files(FRETBOARD_RECOGNIZER_MODEL_DIR, [".json"]),
)

# Check all four artefacts up front; `names` pairs positionally with the paths.
# Presumably this raises/reports when a path is missing — see `common.utils`.
utils.ensure_files_exist(
    chord_clf_model_path, fretboard_rec_model_path,
    chord_clf_config_path, fretboard_rec_config_path,
    names=[
        "Chord Classifier model",
        "Fretboard Recognizer model",
        "Chord Classifier config",
        "Fretboard Recognizer config",
    ],
)

In [None]:
# Chord classifier: loaded through the shared helper, no custom class passed.
chord_clf_model = utils.load_model(
    chord_clf_model_path,
    config_path=chord_clf_config_path,
)

# Fretboard recognizer: same helper, with YOLO supplied as the custom class.
fretboard_rec_model = utils.load_model(
    fretboard_rec_model_path,
    config_path=fretboard_rec_config_path,
    custom_class=YOLO,
)

print("Models loaded successfully.")

In [None]:
def process_video(
        video_path,
        chord_clf_model=None,
        feature_extractor=None,
        fretboard_rec_model=None,
        fretboard_class_id=80,
        bbox_expansion=0.90,
):
    """Classify the chord shown in a video, one majority vote per ~1 second window.

    Every frame is (optionally) cropped around the most confident fretboard
    detection, classified, and the label is pushed into a window holding as
    many frames as the video's FPS. When the window is full, the most common
    label is printed and the window is emptied. Nothing is returned.

    Args:
        video_path: Path of the video file, passed to ``cv2.VideoCapture``.
        chord_clf_model: Callable classifier; called as ``model(**inputs)`` and
            expected to return an object with ``.logits``. Its
            ``.config.id2label`` maps the winning index to a label. Required.
        feature_extractor: Callable invoked as
            ``feature_extractor(images=pil_image, return_tensors="pt")`` whose
            result is unpacked into the classifier. Required.
        fretboard_rec_model: Detector exposing ``predict(image)``; the first
            result's ``.boxes`` must provide ``cls``, ``conf`` and ``data``
            tensors. When ``None`` the full frame is classified (the same
            thing that happens for a frame with no detection).
        fretboard_class_id: Class id selecting fretboard detections. The
            default of 80 is the value the notebook has always used —
            presumably the fretboard class of the fine-tuned detector.
        bbox_expansion: Fraction by which the detected box is grown in each
            dimension (half on every side) before cropping.

    Raises:
        ValueError: If ``chord_clf_model`` or ``feature_extractor`` is missing.
        IOError: If the video cannot be opened.
    """
    if chord_clf_model is None or feature_extractor is None:
        raise ValueError("chord_clf_model and feature_extractor are required.")

    # Open the video file
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video: {video_path}")

        # Get video properties. Some containers report 0 (or NaN) FPS; a zero
        # sized window would make the majority vote below fail, so fall back.
        raw_fps = cap.get(cv2.CAP_PROP_FPS)
        if raw_fps >= 1:
            fps = int(raw_fps)
        else:
            fps = 30
            print(f"Could not read FPS from video (got {raw_fps}); assuming {fps}.")

        # Total frame count is only used for progress output and may be unknown.
        raw_total = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        total_frames = int(raw_total) if raw_total > 0 else None

        print(f"Video FPS: {fps}")

        recent_classifications = deque(maxlen=fps)

        current_frame = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            current_frame += 1

            # OpenCV decodes to BGR; PIL and the image processor expect RGB
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_frame)

            if fretboard_rec_model is not None:
                pil_image = _crop_to_fretboard(
                    pil_image, fretboard_rec_model, fretboard_class_id, bbox_expansion
                )

            # Preprocess the image
            inputs = feature_extractor(images=pil_image, return_tensors="pt")

            # Perform inference
            with torch.no_grad():
                outputs = chord_clf_model(**inputs)

            probabilities = F.softmax(outputs.logits, dim=-1)

            # Get the predicted class
            predicted_class_idx = probabilities.argmax(-1).item()
            predicted_class = chord_clf_model.config.id2label[predicted_class_idx]

            # Add the prediction to recent classifications
            recent_classifications.append(predicted_class)

            # Once a full window is collected, report its majority label
            if len(recent_classifications) == fps:
                print(recent_classifications)
                most_common_class = Counter(recent_classifications).most_common(1)[0][0]
                print(f"Frame {current_frame}: Most common classification in last {fps} frames: {most_common_class}")
                recent_classifications.clear()

            # Progress report every 100 frames
            if current_frame % 100 == 0:
                if total_frames is not None:
                    print(f"Processed {current_frame}/{total_frames} frames")
                else:
                    print(f"Processed {current_frame} frames")
    finally:
        # Always free the capture handle, even if inference raised
        cap.release()


def _crop_to_fretboard(pil_image, fretboard_rec_model, fretboard_class_id, bbox_expansion):
    """Return `pil_image` cropped around its best fretboard detection, or unchanged if none."""
    boxes = fretboard_rec_model.predict(pil_image)[0].boxes
    indices = (boxes.cls == fretboard_class_id).nonzero(as_tuple=True)[0]
    if len(indices) == 0:
        return pil_image

    # Keep only the most confident detection of the requested class
    best_index = indices[boxes.conf[indices].argmax()]

    # `.tolist()` yields plain Python floats, so this also works when the
    # detector returns tensors that live on a GPU.
    x1, y1, x2, y2 = boxes.data[best_index][:4].tolist()

    # Grow the box by `bbox_expansion` of its size, half on each side,
    # clamped to the image bounds.
    increase_x = (x2 - x1) * bbox_expansion / 2
    increase_y = (y2 - y1) * bbox_expansion / 2
    crop_box = (
        max(0, x1 - increase_x),
        max(0, y1 - increase_y),
        min(pil_image.width, x2 + increase_x),
        min(pil_image.height, y2 + increase_y),
    )
    return pil_image.crop(crop_box)

In [None]:
# Video to analyse. The default is a machine-specific absolute path; set the
# CHORD_VIDEO_PATH environment variable to run the notebook elsewhere.
video_path = os.environ.get(
    "CHORD_VIDEO_PATH",
    "/home/dhimitriosduka/Videos/Screencasts/Screencast from 2024-08-26 11-52-18.mp4",
)

# Fail loudly here instead of letting video decoding silently yield zero frames.
if not os.path.isfile(video_path):
    raise FileNotFoundError(
        f"Video not found: {video_path!r}. Set CHORD_VIDEO_PATH to an existing file."
    )

# Image preprocessing taken from the `facebook/dinov2-small` checkpoint;
# presumably the same preprocessing the chord classifier was trained with — TODO confirm.
feature_extractor = AutoImageProcessor.from_pretrained("facebook/dinov2-small")

process_video(
    video_path,
    chord_clf_model=chord_clf_model,
    feature_extractor=feature_extractor,
    fretboard_rec_model=fretboard_rec_model,
)