# check stft and fft

In [1]:
import os
import numpy as np

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt

In [3]:
import torch
import torch.nn.functional as F

In [4]:
from scipy.signal import stft
from scipy.fftpack import fft, fftshift

In [5]:
import h5py
import tqdm
from musicnet import MusicNetHDF5

In [6]:
from ipywidgets import widgets

<br>

## Visualizer

In [7]:
# Open the "valid" MusicNet HDF5 file read-only and wrap it in the dataset
# class, configured with a 4096-sample window and a stride of 512.
VALID_H5_PATH = "./data/musicnet/musicnet_11khz_valid.h5"
h5_in = h5py.File(VALID_H5_PATH, mode="r")
dataset = MusicNetHDF5(h5_in, window=4096, stride=512, at=None)

In [8]:
def float_slider(min, max, step, value=None, continuous_update=False):
    """Build a wide ``FloatSlider`` widget.

    Parameters
    ----------
    min, max : float
        Lower and upper bound of the slider. The names shadow the builtins
        inside this function only; they are kept so existing keyword calls
        continue to work.
    step : float
        Slider increment.
    value : float, optional
        Initial slider position. Falls back to ``min`` when omitted.
    continuous_update : bool
        Forwarded to the widget unchanged.

    Returns
    -------
    widgets.FloatSlider
        Slider laid out with ``min_width='500px'`` and ``display='flex'``.
    """
    # Fall back to the lower bound only when no starting value was supplied;
    # an explicitly requested value must be honoured.
    initial = min if value is None else value
    layout = widgets.Layout(min_width='500px', display='flex')
    return widgets.FloatSlider(min=min, max=max, step=step, value=initial,
                               continuous_update=continuous_update, layout=layout)

def int_slider(min, max, step, value=None, continuous_update=True):
    """Build a wide ``IntSlider`` widget.

    Parameters
    ----------
    min, max : int
        Lower and upper bound of the slider. The names shadow the builtins
        inside this function only; they are kept so existing keyword calls
        continue to work.
    step : int
        Slider increment.
    value : int, optional
        Initial slider position. Falls back to ``min`` when omitted.
    continuous_update : bool
        Forwarded to the widget unchanged.

    Returns
    -------
    widgets.IntSlider
        Slider laid out with ``min_width='500px'`` and ``display='flex'``.
    """
    # Fall back to the lower bound only when no starting value was supplied;
    # an explicitly requested value must be honoured.
    initial = min if value is None else value
    layout = widgets.Layout(min_width='500px', display='flex')
    return widgets.IntSlider(min=min, max=max, step=step, value=initial,
                             continuous_update=continuous_update, layout=layout)


def update_w_sample(*args):
    """Re-range the sample slider for the key selected in ``w_keys``.

    Intended as an observer callback: whatever the caller passes in ``*args``
    is ignored, and the current selection is read from ``w_keys.value``.
    The slider's upper bound becomes the last valid offset inside the
    selected key, and its current value is clamped into ``[0, upper bound]``.
    """
    beg, end = dataset.limits[w_keys.value]
    # ``limits`` holds half-open bounds (the same pair is used as
    # ``slice(beg, end)``), so the last offset that still belongs to this key
    # is ``end - beg - 1``; ``end - beg`` would address the element after it.
    last = max(0, end - beg - 1)
    w_sample.max = last
    w_sample.value = min(last, max(0, w_sample.value))

# Key selector plus a sample-offset slider; the slider is re-ranged every
# time a different key is picked, and once up front for the initial key.
w_keys = widgets.Dropdown(options=dir(dataset))
w_sample = int_slider(min=0, max=1, step=1)

w_keys.observe(update_w_sample, names='value')
update_w_sample()

In [9]:
def draw(key, ix):
    """Plot the STFT magnitude of one dataset item.

    Parameters
    ----------
    key
        Entry of ``dataset.limits``; its first bound is the offset of the
        key inside the dataset.
    ix
        Position relative to that offset; the item shown is
        ``dataset[ix + offset]``.

    The labels returned with the item are not used. The STFT is taken
    along the last axis with ``fs=11000``, ``nperseg=120``, ``noverlap=60``.
    """
    offset, _ = dataset.limits[key]
    signal, _ = dataset[ix + offset]

    freqs, times, spectrum = stft(signal, fs=11000, nperseg=120,
                                  noverlap=60, axis=-1)

    fig, ax = plt.subplots(figsize=(12, 3))
    ax.pcolormesh(times, freqs, np.abs(spectrum), vmin=0, cmap=plt.cm.jet)
    ax.set(
        title=f'STFT Magnitude {key}',
        ylabel='Frequency [Hz]',
        xlabel='Time [sec]',
    )

    plt.show()

In [None]:
# Bind the key dropdown and the sample slider to `draw`, so the plot is
# regenerated when either widget changes; the trailing `;` hides the repr.
widgets.interact(draw, key=w_keys, ix=w_sample);

<hr>

In [11]:
# Batches of 32 dataset items, served in dataset order (no shuffling).
feed = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)

In [12]:
# Walk through every batch without doing any work on it; the tqdm bar then
# reports the raw iteration rate of the loader.
for bx, by in tqdm.tqdm(feed):
    pass

100%|██████████| 1416/1416 [00:08<00:00, 169.53it/s]


## Test random access to HDF5 file.

Compare the cost of random access to the objects (the data) and to the labels (stored as NCLS, i.e. nested containment lists).

In [13]:
# Alias the dataset and pull out the attributes that the timing cells
# below access directly.
cache = dataset
objects = cache.objects
labels = cache.labels
window = cache.window

Prepare indices and pointers

In [14]:
# Length of every object and the offsets of the objects inside their
# (virtual) concatenation: object k covers the flat positions
# indptr[k] .. indptr[k + 1] - 1.
lengths = np.array([len(obj) for obj in objects])
indptr = np.r_[0, lengths.cumsum()]

# Draw the flat positions from a seeded generator, so that every timing
# cell below sees the same access pattern on each run of the notebook.
indices = np.random.RandomState(0).randint(indptr[-1], size=1000)

# Map each flat position to the index of the object that contains it.
lookup = indptr.searchsorted(indices, side="right") - 1
assert np.all(0 <= lookup) and np.all(lookup < len(objects))

# Offset inside the containing object, capped at `length - window` so that
# `window` consecutive elements starting there stay inside the object.
ix_access = np.minimum(indices - indptr[lookup], lengths[lookup] - window)

### Data access

Measure the cost of indexing an object and of copying one window of data out of it.

In [15]:
# Preallocated destination buffer of `window` elements, and an all-zero
# source of the same length used by the baseline timing.
out = np.empty(window)
zeros = np.zeros(window)

In [16]:
%%timeit -n 100
# Baseline: look the object up and copy the zero buffer into `out`.
# Nothing is sliced out of `obj`, so `ix` is unused here.
for key, ix in zip(lookup, ix_access):
    obj = objects[key]
    out[:] = zeros[:]

2.55 ms ± 168 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
%%timeit -n 10
# Same loop as the baseline, but the copied data is the slice
# obj[ix:ix+window] taken from the object itself.
for key, ix in zip(lookup, ix_access):
    obj = objects[key]
    out[:] = obj[ix:ix+window]

163 ms ± 1.95 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Label access

In [18]:
# Destination for an 84-element label vector (the same length as the
# vectors built in the timing cells below).
out = np.empty(84)

In [19]:
%%timeit -n 1000
# Baseline for label access: look the label container up, allocate a fresh
# 84-element zero vector and copy it into `out`. The container is never
# queried; `pass` sits where the next cell runs its overlap loop.
for key, ix in zip(lookup, ix_access):
    lab = labels[key]
    zeros = np.zeros(84)
    pass
    out[:] = zeros[:]

1.76 ms ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [20]:
%%timeit -n 100
# Query the label container with find_overlap(ix, ix+1) and, in a Python
# loop, set the entry given by the third element of every returned triple.
for key, ix in zip(lookup, ix_access):
    lab = labels[key]
    zeros = np.zeros(84)
    for a, b, i in lab.find_overlap(ix, ix+1):
        zeros[i] = 1
    out[:] = zeros[:]

5.82 ms ± 475 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
%%timeit -n 100
# Same query, but collect the indices into a list first and derive the
# 0/1 vector from np.bincount(...) > 0 instead of item-wise assignment.
for key, ix in zip(lookup, ix_access):
    lab = labels[key]
    ind = [i for _, _, i in lab.find_overlap(ix, ix+1)]
    out[:] = np.bincount(ind, minlength=84) > 0

9.31 ms ± 284 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


<hr>

In [22]:
# List the keys the dataset object reports through dir().
dir(cache)

['id_1792', 'id_1876', 'id_2131', 'id_2384', 'id_2514', 'id_2567']

In [23]:
# Sanity check: for every key, indexing by the key itself must return the
# same values as indexing by the slice spanned by that key's limits.
for key in dir(cache):
    beg, end = cache.limits[key]
    result = cache[beg:end]
    assert np.allclose(cache[key], result)

In [24]:
%%timeit
# Time fetching everything stored under one key via string indexing.
chunk = cache['id_1876']

11.1 ms ± 8.43 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
%%timeit
# Time fetching the same key via the slice built from its limits, for
# comparison with the string-key access timed above.
result = cache[slice(*cache.limits['id_1876'])]

1.13 s ± 1.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


<br>