In [1]:
# TODO: later export script to py file

In [2]:
import torch
import torch.utils.data as data
import pandas as pd
import os
import re
from utils import str2midi
from warnings import warn

In [3]:
%%bash
# open output in text editor
cat /homes/fy105/pytorch-mums/src/directories.csv

bash: /import/linux/miniconda/3/4.7.12/lib/libtinfo.so.6: no version information available (required by bash)


subpath,instrument_name,instrument_family,instrument_source,type,notes
DVD 2/KEYBOARDS/CELESTA,celesta,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Pedals_BarPlen_Reeds,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/SymPlenum_56,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Posaune pedals,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Prinzipal 2,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Nasat,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Crumhorn,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Hauptwerk All_27,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Dulzian,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/SoloCornet_29,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Scharf4_28,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Prinzipal4_28,organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS/ORGAN/Koppelflote_Positiv,organ,keyboard,acoustic,note,
"DVD 2/KEYBOARDS/ORGAN/Brustwerk,all stops",organ,keyboard,acoustic,note,
DVD 2/KEYBOARDS

In [4]:
# from directory_tree import DisplayTree

import yaml

with open('../../cosi/config/mums.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [5]:
class MUMS(data.Dataset):
    """ PyTorch dataset for MUMS.
        Adapted from pytorch-nsynth: https://github.com/kwon-young/pytorch-nsynth
    
    Args:
        root (string): Root directory of dataset.
        transform (callable, optional): A function/transform that takes in
                a sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        include_dirs (list): List of bottom-level directories to include in the dataset.
            If empty, all bottom-level directories are included. Refer to directories.csv.
        blacklist_pattern (list): List of strings used to blacklist dataset elements.
            If one of the strings is present in the audio filename, this sample
            together with its metadata is removed from the dataset. Case-insensitive.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 include_dirs=[], blacklist_pattern=[]):
        
        assert(isinstance(root, str))
        assert(isinstance(include_dirs, list))
        assert(isinstance(blacklist_pattern, list))

        self.root = root
        self.include_dirs = include_dirs

        PATH_DIRECTORIES = './directories.csv'  # csv file listing bottom-level directories
        df_directories = pd.read_csv(PATH_DIRECTORIES)
        
        if self.include_dirs:    # otherwise include all directories by default
            df_directories = df_directories[df_directories['subpath'].isin(self.include_dirs)]
    
        self.filenames = []
        self.json_data = {} # metadata

        blacklist = lambda x: any(re.search(pattern, x, re.IGNORECASE) for pattern in blacklist_pattern)

        for index, row in df_directories.iterrows():
            path_dir = os.path.join(self.root, row['subpath'])

            if blacklist(path_dir):
                continue

            instrument_name_str = row['instrument_name']
            instrument_family_str = row['instrument_family']
            instrument_source_str = row['instrument_source']
            type_str = row['type']

            # TODO: LabelEncoder for instrument_name, instrument_family, instrument_source, type?

            for f in os.listdir(path_dir):
                if f.endswith('.wav'):
                    print(f)    # DEBUG
                    if blacklist(f):
                        continue

                    path_f = os.path.join(path_dir, f)
                    self.filenames.append(path_f)

                    targets_f = {'instrument_name_str': instrument_name_str,
                                  'instrument_family_str': instrument_family_str,
                                  'instrument_source_str': instrument_source_str,
                                  'type_str': type_str}

                    if type_str == 'note':
                        pitch_height_str = re.search('[A-Ga-g]#?\d', f)

                        if pitch_height_str:
                            pitch_height_str = pitch_height_str.group(0)
                            pitch_class_str = re.search('[A-Ga-g]#?', pitch_height_str).group(0)
                            pitch = str2midi(pitch_height_str)

                        else:
                            warn(f"Pitch height not found in {f}")
                            warn(f"Pitch class not found in {f}")
                            warn(f"Pitch not found in {f}")
                            pitch_height_str = None
                            pitch_class_str = None
                            pitch = None
                            
                        targets_f['pitch_height_str'] = pitch_height_str
                        targets_f['pitch_class_str'] = pitch_class_str
                        targets_f['pitch'] = pitch

                    elif type_str == 'chord':
                        if 'ELECTRIC GUITAR' in path_f:
                            pitch_class_str = re.search('_[A-Ga-g]#?', f).group(0)
                            targets_f['root_pitch_class_str'] = pitch_class_str[1:]  # remove underscore

                            chord_quality_str = re.search('GUITAR( |_)[A-Z](( ?(([A-Z0-9]+)?))+)?', path_dir).group(0)[7:]   # remove leading 'GUITAR'
                            if re.match(r'[A-Z](( ?(([A-Z0-9]+)?))+)?S$', chord_quality_str):
                                chord_quality_str = chord_quality_str[:-1]  # remove trailing 'S'
                            elif re.match(r'[A-Z](( ?(([A-Z0-9]+)?))+)? STOPPED', chord_quality_str):
                                chord_quality_str = chord_quality_str[:-8]  # remove trailing ' STOPPED'

                            targets_f['chord_quality_str'] = chord_quality_str  # default case

                        elif 'ACCORDION' in path_f:
                            pitch_class_str = re.search(' [A-Ga-g]#? [A-Z]+', f).group(0)
                            if 'FLAT' in pitch_class_str:
                                targets_f['root_pitch_class_str'] = f'{pitch_class_str[1:2]}b'  # use b symbol
                                chord_quality_str = re.search(' [A-Ga-g]#? [A-Z]+ ?([A-Z]+)? ?([A-Z0-9]+)?', f).group(0)[8:]   # remove leading pitch char and 'FLAT'

                            elif '#' in pitch_class_str:
                                targets_f['root_pitch_class_str'] = pitch_class_str[1:3]  # remove trailing word
                                chord_quality_str = re.search(' [A-Ga-g]#? [A-Z]+ ?([A-Z]+)? ?([A-Z0-9]+)?', f).group(0)[4:]   # remove leading pitch char and '#'
                            else:
                                targets_f['root_pitch_class_str'] = pitch_class_str[1:2]  # remove trailing word
                                chord_quality_str = re.search(' [A-Ga-g]#? [A-Z]+ ?([A-Z]+)? ?([A-Z0-9]+)?', f).group(0)[3:]   # remove leading pitch char

                            targets_f['chord_quality_str'] = chord_quality_str

                        elif 'ORGAN' in path_f:
                            warn(f"No pitch class for {f}")
                            warn(f"No chord quality for {f}")

                    self.json_data[path_f] = targets_f
            
            # if index > 4:  # DEBUG
            #     break

        # TODO: categorical field lists for LabelEncoder on metadata?

        # self.categorical_field_list = categorical_field_list
        # self.le = []
        # for i, field in enumerate(self.categorical_field_list):
        #     self.le.append(LabelEncoder())
        #     field_values = [value[field] for value in self.json_data.values()]
        #     self.le[i].fit(field_values)

        self.transform = transform
        self.target_transform = target_transform

        return
    
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx) -> tuple[torch.Tensor, list, dict]:
        pass
        # TODO
        # name = self.filenames[index]
        # _, sample = scipy.io.wavfile.read(name)
        # target = self.json_data[os.path.splitext(os.path.basename(name))[0]]
        # categorical_target = [
        #     le.transform([target[field]])[0]
        #     for field, le in zip(self.categorical_field_list, self.le)]
        # if self.transform is not None:
        #     sample = self.transform(sample)
        # if self.target_transform is not None:
        #     target = self.target_transform(target)
        # return [sample, *categorical_target, target]

In [6]:
if __name__ == "__main__":
    pass
    # TODO
    # # audio samples are loaded as an int16 numpy array
    # # rescale intensity range as float [-1, 1]
    # toFloat = transforms.Lambda(lambda x: x / np.iinfo(np.int16).max)
    # # use instrument_family and instrument_source as classification targets
    # dataset = NSynth(
    #     "../nsynth-test",
    #     transform=toFloat,
    #     blacklist_pattern=["string"],  # blacklist string instrument
    #     categorical_field_list=["instrument_family", "instrument_source"])
    # loader = data.DataLoader(dataset, batch_size=32, shuffle=True)
    # for samples, instrument_family_target, instrument_source_target, targets \
    #         in loader:
    #     print(samples.shape, instrument_family_target.shape,
    #           instrument_source_target.shape)
    #     print(torch.min(samples), torch.max(samples))

In [7]:
MUMS(root=config['path']['mums'],
     include_dirs=config['include_dirs'],
     blacklist_pattern=config['blacklist_pattern'])

CELESTA_F5.wav
CELESTA_D4.wav
CELESTA_G#6.wav
CELESTA_F7.wav
CELESTA_C#7.wav
CELESTA_C#4.wav
CELESTA_C5.wav
CELESTA_A#7.wav
CELESTA_F#4.wav
CELESTA_C7.wav
CELESTA_E6.wav
CELESTA_E4.wav
CELESTA_G7.wav
CELESTA_B4.wav
CELESTA_G#4.wav
CELESTA_C6.wav
CELESTA_D6.wav
CELESTA_C8.wav
CELESTA_B6.wav
CELESTA_A5.wav
CELESTA_B5.wav
CELESTA_F#7.wav
CELESTA_C4.wav
CELESTA_A#6.wav
CELESTA_D#5.wav
CELESTA_B7.wav
CELESTA_G#5.wav
CELESTA_D#7.wav
CELESTA_G4.wav
CELESTA_F#6.wav
CELESTA_E7.wav
CELESTA_C#6.wav
CELESTA_C#5.wav
CELESTA_A7.wav
CELESTA_G5.wav
CELESTA_G#7.wav
CELESTA_F4.wav
CELESTA_A#5.wav
CELESTA_D7.wav
CELESTA_A6.wav
CELESTA_F#5.wav
CELESTA_D5.wav
CELESTA_G6.wav
CELESTA_E5.wav
CELESTA_A#4.wav
CELESTA_F6.wav
CELESTA_D#6.wav
CELESTA_A4.wav
CELESTA_D#4.wav
Ped_BarPlen_Reeds_c2.wav
Ped_BarPlen_Reeds_e2.wav
Ped_BarPlen_Reeds_d3.wav
Ped_BarPlen_Reeds_f#3.wav
Ped_BarPlen_Reeds_d1.wav
Ped_BarPlen_Reeds_c1.wav
Ped_BarPlen_Reeds_g#1.wav
Ped_BarPlen_Reeds_e1.wav
Ped_BarPlen_Reeds_a#2.wav
Ped_BarPlen_Reeds

Koppelflote f#3.wav
Koppelflote e2.wav
Koppelflote e6.wav
Koppelflote f#4.wav
Koppelflote g#4.wav
Koppelflote f#5.wav
Koppelflote a#5.wav
Koppelflote c2.wav
Koppelflote c5.wav
Koppelflote d4.wav
Koppelflote f#6.wav
Koppelflote g#3.wav
Koppelflote e4.wav
Koppelflote c6.wav
Koppelflote a#3.wav
Koppelflote d6.wav
Koppelflote d2.wav
Koppelflote g#2.wav
Koppelflote g#5.wav
Koppelflote a#4.wav
Koppelflote c4.wav
Koppelflote a#2.wav
Koppelflote d5.wav
Koppelflote c3.wav
Koppelflote e5.wav
Koppelflote f#2.wav
Koppelflote d3.wav
Koppelflote e3.wav
Brustwerk,all stops e6.wav
Brustwerk,all stops c4.wav
Brustwerk,all stops d4.wav
Brustwerk,all stops a#3.wav
Brustwerk,all stops g#3.wav
Brustwerk,all stops e2.wav
Brustwerk,all stops f#5.wav
Brustwerk,all stops a#5.wav
Brustwerk,all stops c6.wav
Brustwerk,all stops d5.wav
Brustwerk,all stops e3.wav
Brustwerk,all stops f#2.wav
Brustwerk,all stops e4.wav
Brustwerk,all stops g#5.wav
Brustwerk,all stops g#4.wav
Brustwerk,all stops a#2.wav
Brustwerk,all s

  warn(f"No pitch class for {f}")
  warn(f"No chord quality for {f}")
  warn(f"No pitch class for {f}")
  warn(f"No chord quality for {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch

PERCUSSION_CHA CHA MIX  #1.wav
PERCUSSION_CHA CHA MIX  #2.wav
PERCUSSION_WAWONKO MIX  #5.wav
PERCUSSION_WAWONKO MIX  #7.wav
PERCUSSION_WAWONKO MIX  #4.wav
PERCUSSION_WAWONKO MIX  #10.wav
PERCUSSION_WAWONKO MIX  #8.wav
PERCUSSION_WAWONKO MIX  #2.wav
PERCUSSION_WAWONKO MIX  #1.wav
PERCUSSION_WAWONKO MIX  #3.wav
PERCUSSION_WAWONKO MIX  #9.wav
PERCUSSION_WAWONKO MIX  #6.wav
TRIANGLE PATTERN #3.wav
TRIANGLE _WAWONKO PATTERN .wav
TRIANGLE PATTERN #2.wav
TRIANGLE PATTERN #1.wav
TRIANGLE _SAMBA PATTERN #1.wav
TRIANGLE _SAMBA PATTERN #2.wav
TRIANGLE PATTERN #4.wav
PERCUSSION_EGYPTIAN MIX  #1.wav
PERCUSSION_EGYPTIAN MIX  #3.wav
PERCUSSION_EGYPTIAN MIX  #2.wav
WHISTLE PATTERNS.wav
DARBUKA PATTERN.wav
BURMA TEMPLE BELLSPATTERN#4.wav
BURMA TEMPLE BELLSPATTERN#3.wav
BURMA TEMPLE BELLSPATTERN#5.wav
BURMA TEMPLE BELLSPATTERN#2.wav
BURMA TEMPLE BELLSPATTERN#1.wav
PANDERO PATTERN #4.wav
PANDERO PATTERN #1.wav
PANDERO PATTERN #3.wav
PANDERO PATTERN #2.wav
PANDERO PATTERN #5.wav
FINGER CYMBAL OPEN_CLOSE.w

  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")
  warn(f"Pitch not found in {f}")
  warn(f"Pitch height not found in {f}")
  warn(f"Pitch class not found in {f}")

<__main__.MUMS at 0x74a39d68fe50>

In [8]:
#TODO: listen to VlnAstp_4.20sec - confirm pitch class
#       if cannot resolve through regex, reconfigure code to assign
#       pitch class as None - can resolve through custom collate_fn

In [9]:
        # def split(line):    # split line by comma, ignore commas inside (double) quotes
        #     lexer = shlex.shlex(line)
        #     lexer.quotes = '"'
        #     lexer.whitespace = ','
        #     lexer.whitespace_split = True
        #     return list(lexer)

        # self.root = root
        # if not include_dirs:    # include all directories by default
        #     with open(path_folders, 'r') as f:
        #         next(f)   # skip header
        #         self.include_dirs = [split(row)[0] for row in f]
        # else:
        #     self.include_dirs = include_dirs

        # print(self.include_dirs)  # DEBUG
    
        # self.filenames = []