# check stft and fft

In [1]:
import os
import numpy as np

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt

In [3]:
import torch
import torch.nn.functional as F

In [4]:
from scipy.signal import stft
from scipy.fftpack import fft, fftshift

In [5]:
import h5py
from musicnet import MusicNetHDF5

In [6]:
# Validation split of MusicNet; the path is relative to the notebook's working directory.
VALID_H5_PATH = "./data/musicnet/musicnet_11khz_valid.h5"

# The handle is intentionally left open: `dataset` is constructed from it and is
# indexed in later cells (presumably reading from the file on demand -- TODO confirm).
h5_in = h5py.File(VALID_H5_PATH, mode="r")
dataset = MusicNetHDF5(h5_in, window=4096, stride=512, at=None)

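Original implementation, quoted from around [dataset.py#L179](https://github.com/ChihebTrabelsi/deep_complex_networks/blob/master/musicnet/musicnet/dataset.py#L179). The `# <-` comments are annotations added here and are not part of the upstream code.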
```python
def aggregate_raw_batch(features, output, kind=None, cplx=True):
    n_samples, n_window, n_one = features.shape
    assert n_one == 1

    n_channels = 2 if cplx else 1
    features_out = np.zeros([n_samples, n_window, n_channels])

    if kind == "fourier":
        if cplx:
            data = fft(features, axis=1)
            features_out[:, :, 0] = np.real(data[:, :, 0])
            features_out[:, :, 1] = np.imag(data[:, :, 0])

        else:
            data = np.abs(fft(features, axis=1))
            features_out = data

    elif kind == "stft":
        #  scipy.signal.stft: `... the last axis ... corresponds to the segment times`.
        _, _, data = stft(features, nperseg=120, noverlap=60, axis=1)
        length = data.shape[1]
        n_feats = data.shape[3]

        if cplx:
            features_out = np.zeros([n_samples, length, n_feats * 2])
            features_out[:, :, :n_feats] = np.real(data)  # <- data is 4d
            features_out[:, :, n_feats:] = np.imag(data)

        else:
            features_out = np.abs(data[:, :, 0, :])

    else:
        features_out = features  # <- references a new object, does not overwrite contents 

    return features_out, output


def train_iterator(self, **kwargs):
    features = np.zeros([len(self.keys), self.window])
    if True:
        output = np.zeros([len(self.keys), self.n_outputs])
        for j, key in enumerate(self.keys):
            features[j], output[j] = self[np.random.randint(*self.keys[key])]

        yield aggregate_raw_batch(features[:, :, None], output, **kwargs)
```

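Rewritten version of the two functions above: the real/imaginary concatenation is shared by all kinds, an optional `dim_fix` flag swaps the frequency and time axes of the STFT, and the RNG is seeded so that the same samples are drawn on every call.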

In [7]:
def aggregate_raw_batch(features, output, kind=None, cplx=True, dim_fix=False):
    """Turn a batch of raw windows into model inputs.

    Parameters
    ----------
    features : ndarray
        Batch of windows with a trailing singleton axis, i.e. indexed as
        ``features[sample, time, 0]``.
    output : any
        Targets for the batch; passed through untouched.
    kind : {None, "fourier", "stft"}
        ``"fourier"`` applies an FFT along axis 1, ``"stft"`` applies a
        short-time Fourier transform (``nperseg=120, noverlap=60``) to
        ``features[:, :, 0]``. Any other value leaves the features as they are.
    cplx : bool
        If true, the real and imaginary parts are concatenated along the last
        axis. Otherwise the magnitude is returned for the two transforms, and
        the untouched features for every other `kind`.
    dim_fix : bool
        Only used for ``kind="stft"``: swap the last two axes so that the
        frequency axis is last and is the one that gets doubled by the
        real/imaginary concatenation.

    Returns
    -------
    (features_out, output) : tuple
    """
    if kind == "stft":
        # with axis=-1 on a 2-d input the result is (sample, frequency, frame)
        _, _, spectrum = stft(features[:, :, 0], nperseg=120, noverlap=60, axis=-1)
        data = spectrum.transpose(0, 2, 1) if dim_fix else spectrum
        transformed = True

    elif kind == "fourier":
        data = fft(features, axis=1)
        transformed = True

    else:
        data, transformed = features, False

    if cplx:
        # for real-valued data `.imag` is all zeros, so raw features simply gain
        # a zero-filled imaginary channel
        return np.concatenate([data.real, data.imag], axis=-1), output

    if transformed:
        return abs(data), output

    return data, output


def train_iterator(self, **kwargs):
    """Yield a single, reproducible training batch.

    `self` is any object that exposes a `keys` mapping whose values can be
    unpacked into the arguments of `RandomState.randint`, and that returns a
    `(features, output)` pair when indexed with the drawn integer.

    One sample is drawn per key. The RNG is rigged: a fresh
    `RandomState(111111)` is created for every key, so each index is the first
    draw of the same stream and repeated calls produce the same batch.

    All keyword arguments are forwarded to `aggregate_raw_batch`.
    """
    samples = []
    for key in self.keys:
        rigged_rng = np.random.RandomState(111111)
        samples.append(self[rigged_rng.randint(*self.keys[key])])

    # transpose the list of (features, output) pairs and stack each column
    features, output = (np.stack(column) for column in zip(*samples))

    yield aggregate_raw_batch(features[..., np.newaxis], output, **kwargs)

In [8]:
from sklearn.model_selection import ParameterGrid

# Every combination of transform kind and real/complex output, always with the
# STFT axis fix switched on.
param_space = {
    "kind": [None, "fourier", "stft"],
    "cplx": [False, True],
    "dim_fix": [True],
}
grid = ParameterGrid(param_space)

In [9]:
# For each parameter combination record the settings, the shape of the feature
# batch, and a small corner of the first sample for eyeballing the layout.
shapes = []
for par in grid:
    batches = train_iterator(dataset, **par)
    bx, by = next(batches)
    preview = bx[0, :5, :2]
    shapes.append((par, bx.shape, preview))

In [10]:
shapes

[({'cplx': False, 'dim_fix': True, 'kind': None},
  (6, 4096, 1),
  array([[0.15863739],
         [0.1207215 ],
         [0.07565732],
         [0.03595928],
         [0.00253459]], dtype=float32)),
 ({'cplx': False, 'dim_fix': True, 'kind': 'fourier'},
  (6, 4096, 1),
  array([[0.5083604 ],
         [0.23709638],
         [0.20849799],
         [0.39537755],
         [0.39036623]], dtype=float32)),
 ({'cplx': False, 'dim_fix': True, 'kind': 'stft'},
  (6, 70, 61),
  array([[0.00354679, 0.00526401],
         [0.00076194, 0.00482726],
         [0.00132237, 0.00793825],
         [0.00181491, 0.00279176],
         [0.00238974, 0.00560864]], dtype=float32)),
 ({'cplx': True, 'dim_fix': True, 'kind': None},
  (6, 4096, 2),
  array([[0.15863739, 0.        ],
         [0.1207215 , 0.        ],
         [0.07565732, 0.        ],
         [0.03595928, 0.        ],
         [0.00253459, 0.        ]], dtype=float32)),
 ({'cplx': True, 'dim_fix': True, 'kind': 'fourier'},
  (6, 4096, 2),
  array([

In [11]:
n_wnd, n_ovr = 33, 32


def expected_n_segments(n_len, n_wnd, n_ovr, padded, boundary):
    """Predict the number of frames `scipy.signal.stft` produces.

    Parameters
    ----------
    n_len : int
        Length of the input signal.
    n_wnd : int
        Segment length (`nperseg`).
    n_ovr : int
        Overlap between consecutive segments (`noverlap`).
    padded : bool
        Value passed as `padded` to `stft`.
    boundary : str or None
        Value passed as `boundary` to `stft`.

    Returns
    -------
    int
        The expected size of the last axis of the STFT output.
    """
    n_step = n_wnd - n_ovr

    # The order matters: scipy first extends the signal by `n_wnd // 2` on both
    # ends and only then pads the *extended* signal up to a whole number of
    # segments. With a hop of 1 the padding is always zero, which hides the
    # difference, but for larger hops the two orders disagree.
    if boundary is not None:
        n_len += 2 * (n_wnd // 2)

    if padded:
        n_len += (-(n_len - n_wnd) % n_step) % n_wnd

    return (n_len - n_ovr) // n_step


# Check every padded/boundary combination instead of only the last one assigned.
# The order is chosen so that `padded`, `boundary`, `n_len`, `z` and `ll` end up
# with the values of the (False, None) configuration.
for padded, boundary in [(True, "zeros"), (False, "zeros"), (True, None), (False, None)]:
    for n_len in range(40, 121):
        _, _, z = stft(np.r_[:n_len], nperseg=n_wnd, noverlap=n_ovr,
                       boundary=boundary, padded=padded)
        ll = z.shape[-1]

        assert ll == expected_n_segments(n_len, n_wnd, n_ovr, padded, boundary), \
            (n_len, padded, boundary)

<br>