In [None]:
# Auto-reload edited modules (e.g. gantools) between cells and render figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
from gantools import plot
from gantools.data import load
from scipy.signal import firwin
import matplotlib.pyplot as plt
from gantools.blocks import downsample

In [None]:
# Load the NSynth dataset through gantools. shuffle=False presumably keeps the
# sample order fixed so hard-coded sample indices are reproducible — confirm in gantools.
dataset = load.load_nsynth_dataset(shuffle=False)

In [None]:
# Number of dataset samples to pull for experimentation.
N_SAMPLES = 16
samples = dataset.get_samples(N=N_SAMPLES)

In [None]:
# Running example: samples [256, 512) of dataset sample number 10.
SAMPLE_IDX = 10
SEG_START, SEG_STOP = 256, 512
sample0 = samples[SAMPLE_IDX][SEG_START:SEG_STOP]

In [None]:
# Shape of the excerpt (at most 256 entries along the first axis, given the slice).
np.shape(sample0)

In [None]:
# Prototype low-pass FIR: 32 taps, cutoff 1/3 (scipy's firwin cutoff is relative to Nyquist by default).
Nwin = 32
win = firwin(Nwin, 1 / 3)

In [None]:
# Inspect the prototype filter: impulse response (left) and magnitude response (right).
# The FFT is zero-padded to NFFT points; an Nwin-point FFT gives only Nwin//2 bins,
# which is too coarse to see the transition band.
NFFT = 512
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.plot(win)
ax1.set(title=f'{Nwin}-tap impulse response', xlabel='tap index', ylabel='coefficient')
ax2.plot(np.fft.rfftfreq(NFFT), np.abs(np.fft.rfft(win, n=NFFT)))
ax2.set(title='Magnitude response', xlabel='frequency (cycles/sample)', ylabel='|H(f)|');

In [None]:
def downsample_1d(sig, s=2, Nwin=2):
    """Low-pass filter and decimate a signal by a power-of-two factor.

    The signal goes through log2(s) halving stages. Each stage convolves with
    an ``Nwin``-tap FIR low-pass (``scipy.signal.firwin`` with cutoff 2/5,
    relative to Nyquist) in ``'same'`` mode. It then keeps every second
    sample, starting at index 1.

    Parameters
    ----------
    sig : array_like
        A 1-D signal, or an N-D array whose last axis holds independent
        signals (each one is processed separately).
    s : int, optional
        Total decimation factor. It must be a power of two >= 1; ``s=1``
        returns a copy of the input.
    Nwin : int, optional
        Number of filter taps passed to ``firwin``.

    Returns
    -------
    numpy.ndarray
        The decimated signal(s). Each stage halves the last axis (floor division).

    Raises
    ------
    ValueError
        If ``s`` is not a power of two >= 1.
    """
    sig = np.asarray(sig)
    if s < 1:
        raise ValueError(f"s must be a power of two >= 1, got {s!r}")
    log2_s = np.log2(s)
    ntimes = int(round(log2_s))
    # Compare with the *rounded* exponent, so float error in log2 cannot fail
    # spuriously. (The old np.int alias was removed in NumPy 1.24.)
    if abs(log2_s - ntimes) > 1e-6:
        raise ValueError(f"s must be a power of two >= 1, got {s!r}")
    if sig.ndim > 1:
        # Filter every signal along the last axis independently (axis 1 for 2-D input).
        return np.apply_along_axis(downsample_1d, -1, sig, s=s, Nwin=Nwin)
    win = firwin(numtaps=Nwin, cutoff=2/5)
    new_sig = sig.copy()
    for _ in range(ntimes):
        new_sig = np.convolve(new_sig, win, 'same')
        # Keep odd indices: after all stages, output j comes from input index s*(j+1)-1.
        new_sig = new_sig[1::2]
    return new_sig

def upsamler_1d(sig, s=2, Nwin=2):
    """Upsample a signal by a power-of-two factor using hold-and-smooth stages.

    The signal goes through log2(s) doubling stages. Each stage repeats every
    sample twice and then convolves with an ``Nwin``-tap FIR low-pass
    (``scipy.signal.firwin`` with cutoff 1/2, relative to Nyquist) in
    ``'same'`` mode.

    Parameters
    ----------
    sig : array_like
        A 1-D signal, or an N-D array whose last axis holds independent
        signals (each one is processed separately).
    s : int, optional
        Total upsampling factor. It must be a power of two >= 1; ``s=1``
        returns a copy of the input.
    Nwin : int, optional
        Number of filter taps passed to ``firwin``.

    Returns
    -------
    numpy.ndarray
        The upsampled signal(s); the last axis grows by a factor of ``s``.

    Raises
    ------
    ValueError
        If ``s`` is not a power of two >= 1.
    """
    sig = np.asarray(sig)
    if s < 1:
        raise ValueError(f"s must be a power of two >= 1, got {s!r}")
    log2_s = np.log2(s)
    ntimes = int(round(log2_s))
    # Compare with the *rounded* exponent, so float error in log2 cannot fail
    # spuriously. (The old np.int alias was removed in NumPy 1.24.)
    if abs(log2_s - ntimes) > 1e-6:
        raise ValueError(f"s must be a power of two >= 1, got {s!r}")
    if sig.ndim > 1:
        # Process every signal along the last axis independently (axis 1 for 2-D input).
        return np.apply_along_axis(upsamler_1d, -1, sig, s=s, Nwin=Nwin)
    win = firwin(numtaps=Nwin, cutoff=1/2)
    # Start from a copy so s=1 (zero stages) returns a value instead of an unbound name.
    tsig = sig.copy()
    for _ in range(ntimes):
        # Sample-and-hold to twice the length ([a, b] -> [a, a, b, b]), then smooth.
        tsig = np.convolve(np.repeat(tsig, 2), win, 'same')
    return tsig

In [None]:
# Cross-check the NumPy reference against gantools' downsample on random 2-D input.
# A fixed seed keeps the check reproducible under Restart & Run All.
rng = np.random.default_rng(0)
x = rng.random((25, 256))
for factor in (2, 4, 8):
    err = np.sum(np.abs(downsample_1d(x, s=factor) - downsample(x, s=factor)))
    assert err < 1e-5, f"s={factor}: total absolute error {err:.3g}"


In [None]:
# Repeat the cross-check on the real audio excerpt. gantools' downsample is
# given a single-row 2-D array, while the NumPy reference takes the 1-D signal.
sample0_batch = np.reshape(sample0, [1, len(sample0)])
for factor in (2, 4, 8):
    reference = downsample_1d(sample0, s=factor)
    candidate = downsample(sample0_batch, s=factor)
    assert(np.sum(np.abs(reference - candidate)) < 1e-5)

In [None]:
# Compare a long (Nwin=30) and a short (Nwin=2, "old") filter at decimation factor s,
# in the time domain (left) and the frequency domain (right).
s = 8

ds = downsample_1d(sample0, s=s, Nwin=30)
ds2 = downsample_1d(sample0, s=s, Nwin=2)

dus = upsamler_1d(ds, s=s, Nwin=30)
dus2 = upsamler_1d(ds2, s=s, Nwin=2)
ns = len(ds)

# Frequencies in cycles per *original* sample. Bin k of the ns-point FFT of the
# decimated signal lies at k / (ns * s).
freq_ds = np.arange(ns // 2) / ns / s
freq_full = np.arange(ns * s // 2) / ns / s
# Decimated sample j comes from original index s*(j+1) - 1 (each stage keeps odd indices).
ds_positions = s * (1 + np.arange(ns)) - 1

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.plot(ds_positions, ds, 'x', label='downsampled')
# ax1.plot(ds_positions, ds2, 'x-', label='downsampled old')
ax1.plot(dus, label='Smooth')
# ax1.plot(dus2, label='Smooth old')
ax1.plot(sample0, label='original')
ax1.set(title=f'Time domain (s={s})', xlabel='sample index', ylabel='amplitude')
ax1.legend()
# Downsampled spectra are scaled by s to offset the ns-point vs (ns*s)-point FFT length.
ax2.plot(freq_ds, np.abs(np.fft.fft(ds)[:ns // 2]) * s, label='Downsampled')
ax2.plot(freq_ds, np.abs(np.fft.fft(ds2)[:ns // 2]) * s, label='Downsampled old')
ax2.plot(freq_full, np.abs(np.fft.fft(sample0)[:ns * s // 2]), label='Original')
ax2.plot(freq_full, np.abs(np.fft.fft(dus)[:ns * s // 2]), label='Smooth')
ax2.set(title='Magnitude spectrum', xlabel='frequency (cycles/original sample)', ylabel='|FFT|')
ax2.legend();