# Prep

In [17]:
# build training set
# imports
import os
import glob
import pickle
import pandas as pd
import numpy as np

from sklearn.preprocessing import LabelEncoder
from torchvision import transforms as vt

import torch
import torch.nn as nn


## User Defined Functions

In [6]:
def check_dict(py_dict,dict_key):
    """Look up ``dict_key`` in ``py_dict``, falling back to an empty string.

    Returns the stored value when the key is present, otherwise ``''``.
    """
    return py_dict[dict_key] if dict_key in py_dict else ''

In [7]:
def find_matching_strings(string_list, target_string):
    """Return the lower-cased entries of ``string_list`` that occur in ``target_string``.

    The comparison is a case-insensitive substring test; matches are returned
    in the order they appear in ``string_list``.
    """
    haystack = target_string.lower()
    hits = []
    for candidate in string_list:
        needle = candidate.lower()
        if needle in haystack:
            hits.append(needle)
    return hits

# Preprocess Data

In [8]:
instrument_list = ['Clarinet','Sax Alto','Flute','Violin','Trumpet','Cello','Sax Tenor','Piccolo','Sax Soprano','Sax Baritone','Oboe','Double Bass']
def create_music_record(dir_path, instrument_list):
    """Build one metadata record for the sound stored under ``dir_path``.

    The directory is searched recursively for ``*.wav`` files and must contain
    exactly one. Its ``sound_metadata.pkl`` is then loaded and flattened into
    a plain dict.

    Parameters
    ----------
    dir_path : str
        Directory holding the wav file and ``sound_metadata.pkl``.
    instrument_list : list of str
        Instrument names to look for (case-insensitively) in the sound's name.

    Returns
    -------
    dict or None
        The record, or ``None`` when the directory does not contain exactly
        one wav file or has no ``sound_metadata.pkl``. ``instrument_name`` is
        ``''`` when the metadata has no ``name`` entry or the name does not
        match exactly one instrument; metadata fields that are absent are
        reported as ``''``.
    """
    # find wav file; skip the directory if more/less than 1 file is found
    wav_files = glob.glob(os.path.join(dir_path, '**', '*.wav'), recursive=True)
    if len(wav_files) != 1:
        print(f"incorrect number of wav files found in {dir_path}")
        return None
    wav_file = wav_files[0]

    # get sound_metadata; without it there is nothing to record
    metadata_file = os.path.join(dir_path,'sound_metadata.pkl')
    if not os.path.isfile(metadata_file):
        return None

    # NOTE: pickle.load can execute arbitrary code - only point this at
    # metadata files that come from a trusted source.
    with open(metadata_file, 'rb') as file:
        sound_metadata = pickle.load(file)

    # The target variable is derived from the sound's name. A missing 'name'
    # entry raises KeyError (not IndexError), so catch that as well and fall
    # back to a blank target instead of aborting the whole data-set build.
    try:
        sound_name = sound_metadata['name']
        found_instruments = find_matching_strings(instrument_list,sound_name)
    except (KeyError, IndexError):
        found_instruments = []

    # zero or several matches are ambiguous: default to blank and investigate later
    sound_instr = found_instruments[0] if len(found_instruments) == 1 else ''

    music_record = {"relative_path":wav_file,
                    "channels":check_dict(sound_metadata,"channels"),
                    "filesize":check_dict(sound_metadata,"filesize"),
                    "bitrate":check_dict(sound_metadata,"bitrate"),
                    "bitdepth":check_dict(sound_metadata,"bitdepth"),
                    "duration":check_dict(sound_metadata,"duration"),
                    "samplerate":check_dict(sound_metadata,"samplerate"),
                    "instrument_name":sound_instr}
    return music_record
    
    
    

# create_music_record(dir_list[1], instrument_list)



In [9]:
def create_music_set(dir_list):
    """Assemble a DataFrame with one row per usable sound directory.

    Every directory in ``dir_list`` is passed to ``create_music_record``
    together with the module-level ``instrument_list``; directories for which
    no record could be built (``None``) are left out.
    """
    records = []
    for dir_ in dir_list:
        record = create_music_record(dir_,instrument_list = instrument_list)
        # None marks a directory that was skipped
        if record is not None:
            records.append(record)
    return pd.DataFrame(records)

# Collect every directory found below ./freesound, at any depth
dir_list = []
for root, dirs, files in os.walk("./freesound"):
    dir_list.extend(os.path.join(root, directory) for directory in dirs)
print(len(dir_list))

# One row per directory that yields a usable record
music_df = create_music_set(dir_list)

7554
incorrect number of wav files found in ./freesound\247146


## Clean Up

In [10]:
# clean up: drop the records whose instrument could not be identified
missing_mask = music_df['instrument_name'] == ''
n_missing = int(missing_mask.sum())
n_total = music_df.shape[0]
# a percentage needs a factor of 100; an empty frame reports 0.0 instead of dividing by zero
pct_missing = round(100*n_missing/n_total,2) if n_total else 0.0
print(f"Records missing target variable: {n_missing}. Removing  {pct_missing}% of records from our data")


# .copy() gives an independent frame, so columns added to it later are not
# written to a slice of the unfiltered data (SettingWithCopyWarning)
music_df = music_df.loc[~missing_mask].copy()

Records missing target variable: 1. Removing  0.0% of records from our data


## Numerically Encode Target Variable

In [11]:
# Turn each instrument name into an integer class id
encoder = LabelEncoder()
encoded_targets = encoder.fit_transform(music_df['instrument_name'])
music_df['target_instrument'] = encoded_targets.astype('int64')
music_df['target_instrument'].unique()

# Name/id pairs with the number of records for each pair
name_id_pairs = music_df[['instrument_name','target_instrument']]
instrument_map = name_id_pairs.value_counts().reset_index()
instrument_map

Unnamed: 0,instrument_name,target_instrument,count
0,clarinet,1,1670
1,violin,11,1052
2,trumpet,10,829
3,cello,0,726
4,sax alto,6,717
5,flute,3,694
6,sax tenor,9,449
7,piccolo,5,388
8,sax soprano,8,334
9,sax baritone,7,288


# Preprocess for transformer architecture

In [12]:
def patchify(images, n_patches):
    """Split a batch of images into a grid of flattened, non-overlapping patches.

    Parameters
    ----------
    images : torch.Tensor
        Batch of shape ``(n, c, h, w)``.
    n_patches : int
        Number of patches along each spatial axis; must divide both ``h`` and
        ``w``. Images do not have to be square.

    Returns
    -------
    torch.Tensor
        New tensor of shape
        ``(n, n_patches ** 2, c * (h // n_patches) * (w // n_patches))`` in
        torch's default dtype, placed on the same device as ``images``.
        Patch ``i * n_patches + j`` is the flattened
        ``(c, patch_h, patch_w)`` crop at grid row ``i`` and grid column ``j``.

    Raises
    ------
    ValueError
        If ``n_patches`` is not positive or does not divide ``h`` and ``w``.
    """
    n, c, h, w = images.shape

    if n_patches <= 0 or h % n_patches != 0 or w % n_patches != 0:
        raise ValueError(
            f"Image size ({h}, {w}) is not divisible into {n_patches} patches per side"
        )

    patch_h, patch_w = h // n_patches, w // n_patches

    # View the batch as (n, c, grid_row, patch_row, grid_col, patch_col), then
    # bring the two grid axes forward so every patch is a (c, patch_h, patch_w) crop.
    grid = images.reshape(n, c, n_patches, patch_h, n_patches, patch_w)
    grid = grid.permute(0, 2, 4, 1, 3, 5)

    # Allocate the result on the input's device and copy the crops in one go;
    # the result never shares memory with ``images``.
    patches = torch.empty(
        n, n_patches ** 2, c * patch_h * patch_w,
        dtype=torch.get_default_dtype(), device=images.device,
    )
    patches.view(n, n_patches, n_patches, c, patch_h, patch_w).copy_(grid)
    return patches

In [15]:
class MyViT(nn.Module):
    """Module that currently only cuts its input batch into flattened patches.

    ``chw`` is the expected (channels, height, width) of one image and
    ``n_patches`` the number of patches per spatial axis; the constructor
    asserts that height and width are both multiples of ``n_patches``.
    """

    def __init__(self, chw=(1, 28, 28), n_patches=7):
        super().__init__()

        # Keep the configuration around for forward()
        self.chw = chw # (C, H, W)
        self.n_patches = n_patches

        not_divisible = "Input shape not entirely divisible by number of patches"
        height = chw[1]
        assert height % n_patches == 0, not_divisible
        width = chw[2]
        assert width % n_patches == 0, not_divisible

    def forward(self, images):
        """Return ``patchify(images, self.n_patches)``."""
        return patchify(images, self.n_patches)

In [96]:
torch.manual_seed(17)

# Current model; chw is kept consistent with the 2-channel dummy batch below
model = MyViT(
  chw=(2, 28, 28),
  n_patches=7
)

x = torch.randn(1, 2, 28, 28) # Dummy images
x2 = model(x).numpy()
print(x2.shape) # (1, 49, 32): 1 image, 7*7 patches, 2*4*4 values per patch
print(x2)

(1, 49, 32)
[[[-1.0461915   1.2305212   1.8662122  ...  1.1032946  -1.3344064
   -2.581236  ]
  [ 0.06754854  0.44813395 -0.5985899  ... -0.76082724 -0.72297764
    0.5215839 ]
  [ 0.21132568  0.13313377  0.25223497 ... -0.06073891 -0.8029356
    0.06508627]
  ...
  [-0.88790995  0.79799265 -0.36270887 ...  1.0173298  -0.5579972
    0.11708623]
  [-0.59282196 -1.5264444  -0.07835507 ...  0.31156117 -0.1824684
    1.9655087 ]
  [-1.2054605  -0.81623    -0.9593591  ...  0.44824982  0.52507913
   -1.6396205 ]]]


In [95]:
from skimage.util.shape import view_as_blocks

# Same seed and shape as the dummy batch passed to MyViT
torch.manual_seed(17)
x = torch.randn((1, 2, 28, 28)).numpy()

# block_shape is the size of one block, not the number of blocks:
# 7x7 blocks tile each 28x28 channel into a 4x4 grid
x3 = view_as_blocks(x, block_shape=(1, 1, 7, 7))

np.squeeze(x3).shape

(2, 4, 4, 7, 7)

In [102]:
# Rectangular input: 2 channels of 64x324, tiled with 8x18 blocks
torch.manual_seed(17)
x = torch.randn((1, 2, 64, 324)).numpy()

x3 = view_as_blocks(x, block_shape=(1, 1, 8, 18))

# NOTE(review): 64/8 = 8 and 324/18 = 18, so this gives 2*8*18 = 288 blocks of
# 8*18 = 144 values each. Reshaping to 36 rows therefore packs 8 consecutive
# blocks into every row rather than one patch per row - confirm this is intended.
x3.reshape(36, -1)

array([[-1.0461915e+00,  1.2305212e+00,  1.8662122e+00, ...,
         3.8504535e-01, -7.6659721e-01,  1.5816286e+00],
       [ 2.3505518e-01,  1.2705117e+00, -4.9797949e-01, ...,
        -9.2189276e-01,  1.3849835e+00, -1.0581790e+00],
       [ 1.5265250e+00,  5.8004040e-01,  1.8565146e+00, ...,
        -3.8039985e-01,  1.2446707e+00,  1.4971067e+00],
       ...,
       [-6.0182762e-01, -5.2705139e-01,  7.2333544e-01, ...,
        -1.1251336e+00, -4.3829238e-01, -7.4847996e-01],
       [ 8.6367428e-01,  1.0425885e+00, -1.1460442e-03, ...,
        -1.1520212e+00,  1.2362085e-01,  5.3020716e-01],
       [ 3.1318608e-01, -1.1552500e+00, -1.2039675e+00, ...,
         1.8449354e+00,  7.2837526e-01, -1.8400691e+00]], dtype=float32)

In [101]:
# 2*64*324 = 41472 values in the dummy batch, spread over 36 rows
41472/36

1152.0

In [87]:
# number of 18-wide blocks that fit across the 324-wide input
324/18

18.0