In [1]:
import numpy as np
import torch
from torchvision.io import read_video
from torch.utils.data import Dataset
import os
from note_detector.python.video_note_detector import generate_labels

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [2]:
# Quick debug helpers: uncomment the lines below to inspect a single training video.
vid_path = "./training_data/"
vid_name = "dumb_scale_youtube.mp4"
# cur_video_labels, num_frames = generate_labels(vid_path, vid_name)
# video_frames, _, _ = read_video(vid_path + vid_name)
# video_frames[0].shape


In [3]:
# Goal: with more than one video, the dataset should be a LINEARLY sampled set of frames from each video, each labeled with the note(s) played in that frame.
# With a fixed number of samples per video (SPV), the dataset length is known up front: (# videos) * SPV.
# Plan for __init__:
#   - Run the note-detection API (generate_labels) on every video to get (frame, note) pairs, and store them so that position idx holds the note(s) for sampled frame idx.
#   - Extract the frames of every video into a parallel list; with SPV sampling, position idx would be frame [idx % SPV] of video [idx // SPV].
# __getitem__(idx) returns the (frame, label) pair at position idx of those two lists.
# NOTE: SPV sampling is not implemented yet; the class below currently keeps every frame of every video.

class NoteDataset(Dataset):
    """Frame-level dataset of (video frame, notes played in that frame) pairs.

    Each video file in ``data_dir`` is decoded with ``read_video`` and labeled with
    ``generate_labels``. Every frame that has both a decoded image and a label slot
    becomes one sample. Its label is the list of notes ``generate_labels`` reports
    for that frame, or an empty list if none is reported.

    NOTE(review): labels are variable-length lists, so a ``DataLoader`` with
    batch_size > 1 will likely need a custom ``collate_fn``.

    Args:
        data_dir: Directory containing the training videos. Defaults to
            ``"./training_data"``, the directory this class previously hard-coded.
    """

    def __init__(self, data_dir="./training_data"):
        import warnings

        self.frame_labels = []
        self.frames = []

        # generate_labels is called with a directory that ends in a path separator,
        # matching the original "./training_data/" argument.
        labels_dir = os.path.join(data_dir, "")

        # Sort so that sample indices are reproducible: os.listdir order is arbitrary.
        for file in sorted(os.listdir(data_dir)):
            video_path = os.path.join(data_dir, file)
            # Skip sub-directories and hidden files (e.g. .DS_Store, .gitkeep),
            # which read_video cannot decode.
            if file.startswith(".") or not os.path.isfile(video_path):
                continue

            # Decoded frames of this video; one frame per index along the first axis.
            v_frames, _, _ = read_video(video_path, pts_unit="sec")

            # Bucket the (frame, note) pairs by frame index; a frame may hold several notes.
            cur_video_labels, num_frames = generate_labels(labels_dir, file)
            labels_per_frame = [[] for _ in range(num_frames)]
            for frame, note in cur_video_labels:
                labels_per_frame[int(frame)].append(note)

            # self.frames and self.frame_labels are parallel lists that share one index,
            # so each video must add the same number of entries to both. Otherwise every
            # later video's frames end up shifted against their labels.
            n_decoded = v_frames.shape[0]
            n_aligned = min(n_decoded, num_frames)
            if n_decoded != num_frames:
                warnings.warn(
                    f"{file}: read_video decoded {n_decoded} frames but generate_labels "
                    f"reported {num_frames}; keeping the first {n_aligned} of each so "
                    f"frames and labels stay aligned."
                )
            self.frames.extend(v_frames[:n_aligned])
            self.frame_labels.extend(labels_per_frame[:n_aligned])

    def __len__(self):
        """Return the number of (frame, label) samples."""
        return len(self.frame_labels)

    def __getitem__(self, idx):
        """Return ``(frame, notes)`` for sample ``idx``; ``notes`` is a list (possibly empty)."""
        return self.frames[idx], self.frame_labels[idx]