In [None]:
# default_exp utils

In [None]:
import nbdev.showdoc as literacy

In [None]:
#export
from speechsep.imports import *
from speechsep.base import *

# Utils

This module contains helper functions used for spectrogram creation and audio processing.

## Fourier Transforms

In [None]:
#export
def time_bins(X, window_size, overlap):
    """
    Slice a 1D signal into fixed-size, overlapping windows.

    The signal is zero-padded at the end up to the next multiple of
    `window_size` (a whole extra window of zeros is appended when the length
    is already a multiple), then one window is taken every
    `int(window_size * (1 - overlap))` samples.

    Parameters
    ----------
    X : ndarray, shape=(n_samples,)
        Input signal to window and overlap
    window_size : int
        Size of windows to take (must be even)
    overlap : float
        Fraction of `window_size` shared by two consecutive windows. The step
        between windows is `int(window_size * (1 - overlap))` and must be at
        least one sample.

    Returns
    -------
    X_strided : ndarray, shape=(num_windows, window_size)
        2D array of overlapped X, where
        `num_windows = (padded_length - window_size) // step`.

    Raises
    ------
    ValueError
        If `window_size` is odd, or if `overlap` leaves a step smaller than
        one sample.

    Notes
    -----
    The window count formula does not emit the window that starts at the last
    valid offset of the padded signal. It is kept as is so that the output
    shape seen by existing callers does not change.
    """
    if window_size % 2 != 0:
        raise ValueError(f"Window size must be even! Recieved {window_size}")
    slide_length = int(window_size*(1-overlap))
    if slide_length < 1:
        # A zero step would otherwise surface as a bare ZeroDivisionError below.
        raise ValueError(
            f"overlap={overlap} leaves a step of {slide_length} samples; "
            "the step between windows must be at least 1"
        )
    padding = np.zeros(int(window_size - len(X) % window_size))
    X = np.concatenate((X, padding))
    num_windows = (len(X) - window_size) // slide_length

    # Gather every window in one indexing operation:
    # row i holds X[i*slide_length : i*slide_length + window_size].
    starts = np.arange(num_windows) * slide_length
    return X[starts[:, None] + np.arange(window_size)[None, :]]

In [None]:
#export
def stft(X, fftsize=512, win_mult=2, overlap=0.5, normalize=False):
    """
    Compute STFT for 1D real valued input X

    Parameters
    ----------
    X : ndarray, shape=(n_samples,)
        Input signal
    fftsize : int
        Base size; the analysis window length is `fftsize * win_mult`
    win_mult : int
        Multiplier applied to `fftsize` to get the window length
    overlap : float
        Fraction of the window shared by consecutive windows
    normalize : bool
        When True, scale the result by a real factor so that its largest
        magnitude equals 256

    Returns
    -------
    ndarray of complex, shape=(fftsize * win_mult, num_windows)
        Full (two-sided) FFT of every window, one window per column
    """
    win_size = fftsize*win_mult
    frames = time_bins(X, win_size, overlap)
    # The 0.54 / 0.46 coefficients are those of a Hamming window.
    window = .54 - .46 * np.cos(2 * np.pi * np.arange(win_size) / (win_size - 1))
    spec = np.fft.fft(frames * window.reshape((1, win_size))).T
    if normalize and spec.size:
        # max() on a complex array orders values lexicographically and yields a
        # complex number, which would rotate every phase; scale by the peak
        # magnitude instead so the factor is real.
        peak = np.abs(spec).max()
        if peak > 0: spec *= 256/peak
    return spec

In [None]:
#export
def istft(X, fftsize=512, win_mult=2, overlap=0.5, normalize=False):
    """
    Rebuild a 1D signal from a spectrogram laid out as (frequency bins, windows).

    Each column is inverse-transformed, divided by the analysis window, and the
    windows are stitched together: the leading part of the first window,
    followed by the trailing `int(win_size * (1 - overlap))` samples of every
    window.

    `fftsize`, `win_mult` and `normalize` are accepted but not used; the window
    length is read from the data itself.
    """
    frames = np.fft.ifft(X.T).real
    win_size = len(frames[0])
    hop = int(win_size*(1-overlap))

    # Undo the analysis window (same 0.54 / 0.46 coefficients as the forward pass).
    window = .54 - .46 * np.cos(2 * np.pi * np.arange(win_size) / (win_size - 1))
    frames = frames/window.reshape((1, win_size))

    pieces = [frames[0][0:win_size-hop]]
    pieces.extend(frame[-hop:] for frame in frames)
    return np.concatenate(pieces)

# Utility functions

In [None]:
#export
def fill(sig, shape):
    """
    Zero-pad `sig` at the end so that it holds at least `shape` samples.

    Parameters
    ----------
    sig : array-like
        Signal to pad (assumed 1D; TODO confirm against callers)
    shape : int
        Target length

    Returns
    -------
    ndarray
        `sig` followed by `shape - len(sig)` zeros. A signal that is already
        `shape` samples or longer is returned with no padding added (it is
        not truncated).
    """
    # Clamp at zero: taking abs() of the difference would pad signals that are
    # already longer than the target, making them even longer.
    missing = max(shape - len(sig), 0)
    return np.pad(sig, (0, missing), 'constant', constant_values=(0, 0))

In [None]:
#export
def randomComplex(shape):
    """
    Draw complex samples whose real and imaginary parts are independent
    standard normal variables.

    Parameters
    ----------
    shape : int or tuple of ints
        Shape of the returned array (any number of dimensions)

    Returns
    -------
    ndarray of complex, shape=`shape`
    """
    samples = np.random.multivariate_normal([0,0], [[1,0],[0,1]], shape)
    # The last axis holds the (real, imag) pair; ellipsis indexing keeps this
    # working for any number of leading dimensions, not only two.
    return samples[..., 0] + samples[..., 1]*1j

In [None]:
#export
def complex2real(data):
    """
    Split complex data into real and imaginary channels.

    The real and imaginary parts are stacked on a new trailing axis and the
    result is transposed (all axes reversed), so the two channels end up on
    the first axis. Non-complex input is returned unchanged.
    """
    if not np.iscomplexobj(data):
        return data
    channels_last = np.stack((data.real, data.imag), axis=-1)
    return channels_last.T

def real2complex(data):
    """
    Combine a two-channel real input into a complex array.

    `data` must expose `.numpy()`. The last and third-from-last axes are
    swapped, then index 0 of the resulting trailing axis is used as the real
    part and index 1 as the imaginary part.
    """
    channels_last = data.numpy().swapaxes(-1,-3)
    real_part = channels_last[..., 0]
    imag_part = channels_last[..., 1]
    return real_part + imag_part*1j

In [None]:
#export
def complex_mult(x,y):
    """
    Multiply two complex-valued tensors stored as real/imaginary channels.

    On the third-from-last axis, index 0 is treated as the real part and the
    remaining entries as the imaginary part. Computes
    (a + ib)(c + id) = (ac - bd) + i(ad + bc) and concatenates the two results
    back on that same axis.
    """
    a, b = x[..., :1, :, :], x[..., 1:, :, :]
    c, d = y[..., :1, :, :], y[..., 1:, :, :]
    real_part = a*c - b*d
    imag_part = a*d + b*c
    return torch.cat([real_part, imag_part], dim=-3)

In [None]:
#export
def get_shape(olist):
    """
    Get the shape shared by all items in iterable.

    Parameters
    ----------
    olist : iterable
        Objects exposing a hashable `.shape` attribute

    Returns
    -------
    The single shape common to every item.

    Raises
    ------
    ValueError
        If the iterable is empty or its items do not all share one shape.
    """
    shapes = {o.shape for o in olist}
    if not shapes: raise ValueError("Cannot get the shape of an empty collection.")
    # The exception must actually be raised; merely constructing it lets
    # mismatched inputs through and returns an arbitrary one of the shapes.
    if len(shapes) > 1: raise ValueError("To generate Masks make sure that the length of files are equal.")
    return shapes.pop()

In [None]:
#export
def join_audios(audioList):
    """
    Sum a list of audio signals element-wise into a single mix.

    Items may be raw ndarrays or objects carrying the samples in `.data`; the
    type of the first item decides how the whole list is read. The mix is
    accumulated in a zero array shaped by `get_shape`.
    """
    if isinstance(audioList[0], ndarray):
        signals = audioList
    else:
        signals = [audio.data for audio in audioList]
    total = np.zeros(get_shape(signals))
    for signal in signals:
        total += signal
    return total

class Mixer(Transform):
    """
    Transform that prepends the sum of a list of audios to that list.

    The mix is built with `join_audios` and wrapped in the same type as the
    first item, reusing that item's `sr`.
    """
    # Transform flag from the base class; presumably makes `encodes` receive
    # the whole list as one item - TODO confirm against fastcore.
    as_item_force=True
    def encodes(self, audioList):
        mixed = join_audios(audioList)
        first = audioList[0]
        mix_item = type(first)(mixed, first.sr)
        return Tuple(mix_item) + Tuple(audioList)

In [None]:
#export
class Unet_Trimmer(Transform):
    """
    Transform that trims axis 1 of a spectrogram down to the largest multiple
    of `trim_val`.

    `SpecBase` inputs are rebuilt as the same type from the trimmed `.data`
    plus their `sr` and `fn`; any other input is sliced directly.
    """
    def __init__(self, trim_val):
        self.trim_val=trim_val
    def encodes(self, spec):
        if isinstance(spec, SpecBase):
            trim = spec.data.shape[1]//self.trim_val*self.trim_val
            return type(spec)(spec.data[:,:trim], spec.sr, spec.fn)
        trim = spec.shape[1]//self.trim_val*self.trim_val
        # Slice the input itself; this branch previously referenced an
        # undefined name and raised NameError.
        return spec[:,:trim]

## Documentation