In [10]:
import numpy as np
import pywt
import librosa 
import os
from data import WaveData, SR
import matplotlib.pyplot as plt
import tqdm
import soundfile as sf

# Import a test dataset

DATA_PATH = 'data/scott_one_song_moises/songs/'
PERM_FOLDERS = os.listdir(DATA_PATH)
print(PERM_FOLDERS)

def get_perms(data_path=DATA_PATH, perm_folders=PERM_FOLDERS, sr=SR):
    perms = []
    data = {}
    print(perm_folders)
    for folder in perm_folders:
        if folder == '.DS_Store':
            continue
        print(folder)
        # called folder/Test -> read the Test part with librosa
        if 'Test' in os.listdir(data_path + folder):
            
            
            for song in os.listdir(data_path + folder + '/Test'):
            
                perms.append(song)

                if song == '.DS_Store':
                    continue
                
                waveform, sr = librosa.load(data_path + folder + '/Test/' + song, mono=False, sr=sr)
                # print(waveform.shape)
                waveform = np.transpose(waveform)
                # print(waveform.shape)
                data[song] = WaveData(song, waveform, None)
            
    return perms, data

perms, TRAIN_DATA = get_perms()


['song_bass1_drums1_piano1', '.DS_Store', 'song_piano1_vocals1', 'song_bass1_guitar1_piano1', 'song_bass1_drums1', 'song_drums1_guitar1_piano1', 'song_bass1_drums1_guitar1_vocals1', 'song_drums1_piano1']
['song_bass1_drums1_piano1', '.DS_Store', 'song_piano1_vocals1', 'song_bass1_guitar1_piano1', 'song_bass1_drums1', 'song_drums1_guitar1_piano1', 'song_bass1_drums1_guitar1_vocals1', 'song_drums1_piano1']
song_bass1_drums1_piano1
song_piano1_vocals1
song_bass1_guitar1_piano1
song_bass1_drums1
song_drums1_guitar1_piano1
song_bass1_drums1_guitar1_vocals1
song_drums1_piano1


In [11]:

# Get song length analytics

def length_analysis(perms=perms, data=TRAIN_DATA):
    lengths = []
    for song in perms:
        try:
            lengths.append(data[song].waveform.shape[0])
        except Exception as e:
            print(e)
            print(song)
    return lengths

def print_length_analytics(lengths):
    print('Max length:', max(lengths))
    print('Min length:', min(lengths))
    print('Mean length:', np.mean(lengths))
    print('Median length:', np.median(lengths))
    # graph of distribution
    plt.hist(lengths, bins=20)
    plt.show()

lengths = length_analysis()
# print_length_analytics(lengths)

In [None]:
# TODO: Remove songs above a certain length


In [22]:

# Get the wavelet transform of the songs 

def get_wavelet_transform(data=TRAIN_DATA):
    for i, song in tqdm.tqdm(enumerate(perms)):
        dwt = pywt.wavedec(data[song].waveform, 'db1')
        print(f"len dwt = {len(dwt)}")
        # convert to numpy array
        npdwt = np.array(dwt)
        
        print(npdwt.shape)
        data[song].dwt = dwt
        data[song].npdwt = npdwt
    return data

TRAIN_DATA = get_wavelet_transform()
# Pad up to max dims -- print dims
def pad_dwt(data=TRAIN_DATA):
    max_shape = None
    for song in perms:
        if max_shape is None:
            max_shape = data[song].npdwt.shape
        else:
            max_shape = np.maximum(max_shape, data[song].npdwt.shape)

    new_dwt = np.zeros_like(max_shape)
    for song in perms:
        # pad the dwt
        # print(new_wavedata.dwt.shape)
        pad_width = [(0, max_shape[0] - data[song].npdwt.shape[0])]
        for i in range(1, len(max_shape)):
            pad_width.append((0, max_shape[i] - data[song].npdwt.shape[i]))
        new_dwt = np.pad(data[song].npdwt, pad_width, mode='constant', constant_values=0)
        data[song].npdwt = new_dwt
        print(new_dwt.shape)
    return data

TRAIN_DATA = pad_dwt()
# save the numpy file to song folder

1it [00:00,  5.16it/s]

len dwt = 2
(2, 9154482, 1)


2it [00:00,  4.81it/s]

len dwt = 2
(2, 9154482, 1)


3it [00:00,  4.75it/s]

len dwt = 2
(2, 9154482, 1)


4it [00:00,  4.75it/s]

len dwt = 2
(2, 9154482, 1)


5it [00:01,  4.78it/s]

len dwt = 2
(2, 9154482, 1)


6it [00:01,  4.77it/s]

len dwt = 2
(2, 9154482, 1)


8it [00:01,  5.11it/s]

len dwt = 2
(2, 9154482, 1)
len dwt = 2
(2, 9154482, 1)


10it [00:01,  5.75it/s]

len dwt = 2
(2, 9154482, 1)
len dwt = 2
(2, 9154482, 1)


12it [00:02,  6.09it/s]

len dwt = 2
(2, 9154482, 1)
len dwt = 2
(2, 9154482, 1)


13it [00:02,  5.40it/s]


len dwt = 2
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)
(2, 9154482, 1)


In [30]:
# try idwt to see if it works
def idwt(data=TRAIN_DATA):
    idwts = []
    for i, song in tqdm.tqdm(enumerate(perms)):
        # convert to  (list, tuple) from numpy array
        idwt = pywt.waverec(data[song].dwt, 'db1')
        idwts.append(idwt)
    return idwts

idwts = idwt()
print(idwts[0].shape)

# save the idwt to a file\
import soundfile as sf
def save_idwt(data=TRAIN_DATA, idwts=idwts):
    song = perms[0]
    sf.write('idwt.wav', idwts[0], SR)

save_idwt()


13it [00:01, 10.95it/s]

(9154482, 2)



