In [2]:
import torch, torchvision, PIL, numpy as np
import pathlib
import PIL
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm.auto import tqdm

In [None]:
# Download the class-name table (CSV) and the DeepScores2017 classification archive
# into the current working directory.
# More DeepScores archives are listed at
# https://repository.cloudlab.zhaw.ch/artifactory/deepscores/archives/2017/
# NOTE(review): wget saves a fresh numbered copy (e.g. `file.1`) on every re-run —
# consider `wget -nc` if this cell is executed more than once.
!wget https://tuggeluk.github.io/class_names/class_names.csv
!wget https://repository.cloudlab.zhaw.ch/artifactory/deepscores/classification/DeepScores2017_classification.zip


In [None]:
# Extract the downloaded archive into the music_dataset/ directory.
!unzip DeepScores2017_classification.zip -d music_dataset/

In [None]:
class ObjectDetectionDataset:
    """Object-detection dataset backed by a flat directory of images and labels.

    ``root_dir`` must contain a ``classes.txt`` file (one class name per line)
    and, for every ``*.jpg`` image, a sibling file with the same stem and a
    ``.txt`` suffix holding one bounding box per line as five
    whitespace-separated fields: ``<class-index> <cx> <cy> <sx> <sy>``.

    Args:
        root_dir: Directory holding ``classes.txt``, the images and label files.
        transform: Optional callable applied to the opened PIL image.
        transform_label: Optional callable applied to the list of box tuples.

    Attributes:
        root_dir: Resolved (absolute) ``pathlib.Path`` of the dataset directory.
        classes: Class names read from ``classes.txt``, in file order.
        fns_labels: List of ``(image_path, boxes)`` pairs sorted by image path,
            where ``boxes`` is a list of ``(int, float, float, float, float)``.

    Raises:
        FileNotFoundError: If ``classes.txt`` or an image's label file is missing.
        ValueError: If a non-blank label line does not have exactly five fields
            or a field cannot be converted to a number.
    """

    def __init__(self, root_dir, transform=None, transform_label=None):
        root_dir = pathlib.Path(root_dir).resolve()
        self.root_dir = root_dir
        self.transform = transform
        self.transform_label = transform_label
        with open(root_dir / 'classes.txt') as f:
            # One class name per line; strip the newline and surrounding whitespace.
            self.classes = [w.strip() for w in f]
        # Pair every image with the boxes parsed from its sibling label file.
        # Sorting keeps the sample order deterministic across runs.
        self.fns_labels = [(imgfn, self._parse_boxes(imgfn.with_suffix('.txt')))
                           for imgfn in sorted(root_dir.glob('*.jpg'))]

    @staticmethod
    def _parse_box(line):
        """Convert one label line into ``(class_index, cx, cy, sx, sy)``."""
        kls, cx, cy, sx, sy = line.split()
        return int(kls), float(cx), float(cy), float(sx), float(sy)

    @classmethod
    def _parse_boxes(cls, fn):
        """Read every box from label file ``fn``, ignoring blank lines."""
        with open(fn) as f:
            # Blank / whitespace-only lines (e.g. a trailing newline at the end
            # of the file) carry no box and would otherwise fail to unpack.
            return [cls._parse_box(line) for line in f if line.strip()]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.fns_labels)

    def __getitem__(self, i):
        """Return the ``(image, label)`` pair at index ``i``.

        ``i`` may be a plain index or a single-element tensor. The image is
        opened with PIL and passed through ``transform`` when one is set; the
        box list is passed through ``transform_label`` when one is set.
        """
        if torch.is_tensor(i):
            i = i.item()
        imgfn, label = self.fns_labels[i]
        img = PIL.Image.open(imgfn)
        # Image.open is lazy: force the pixel data to be read now so the
        # underlying file handle is released instead of staying open for as
        # long as the returned image is alive.
        img.load()
        if self.transform:
            img = self.transform(img)
        if self.transform_label:
            label = self.transform_label(label)
        return (img, label)