In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Audio

In [None]:
import numpy as np
from tqdm import tqdm_notebook
import librosa

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import nn

In [None]:
from zachary.datasets import AtemporalDataset
from zachary.utils import get_torch_device, get_num_trainable_params
from zachary.weight_initializers import initialize_model

In [None]:
# Number of samples per mini-batch; passed to the DataLoader further down.
BATCH_SIZE = 64
# Device the model and every batch are moved to. Selected by the project helper
# (presumably a GPU when one is available — confirm in zachary.utils).
DEVICE = get_torch_device()

In [None]:
# Build the project dataset with its default arguments. Later cells index it
# with ints and slices and call tensor methods (.numpy(), .unsqueeze()) on the
# result, so items are presumably torch tensors — confirm in zachary.datasets.
dataset = AtemporalDataset()

In [None]:
# Sanity check: display the shape of a single dataset item.
dataset[0].shape

In [None]:
# Convert 10 seconds of audio (44.1 kHz sample rate, hop of 512 samples,
# FFT size 1024) into a frame count; used below to slice an excerpt of the dataset.
dur = librosa.time_to_frames(10, sr=44100, hop_length=512, n_fft=1024)

In [None]:
# Use a wide default figure size (rcParams is global matplotlib state).
plt.rcParams['figure.figsize'] = (18, 4)

# Show the first `dur` dataset items as an image. The transpose puts the item
# index on the horizontal axis; origin='lower' puts row 0 at the bottom.
fig, (ax1) = plt.subplots(1, 1)
ax1.imshow(dataset[:dur].numpy().T, aspect='auto', origin='lower')
# Bare `pass` keeps the cell from echoing the return value of imshow.
pass

In [None]:
def stft_to_signal(S, num_iters=15, hop_length=512, win_length=None):
    """Reconstruct a time-domain signal from a magnitude spectrogram.

    The phase is unknown, so it is estimated iteratively in Griffin-Lim
    fashion: start from uniformly random phase, invert to a signal,
    re-analyse that signal, keep the new phase and re-impose the given
    magnitudes.

    Args:
        S: Array-like of shape (frames, bins). It is transposed to
            (bins, frames) before being handed to librosa.
        num_iters: Number of inversion passes; must be at least 1.
        hop_length: Hop size in samples used for both analysis and synthesis.
        win_length: Window length in samples. Defaults to the FFT size
            implied by the number of bins, ``2 * (bins - 1)`` (1024 for the
            513-bin spectrograms used in this notebook).

    Returns:
        The signal returned by the last ``librosa.istft`` call.

    Raises:
        ValueError: If ``num_iters`` is smaller than 1 (previously this
            silently returned ``None``).
    """
    if num_iters < 1:
        raise ValueError('num_iters must be at least 1')

    S_T = np.asarray(S).T

    # The FFT size follows from the number of frequency bins; deriving it
    # keeps the analysis step consistent with whatever spectrogram is given.
    n_fft = 2 * (S_T.shape[0] - 1)
    if win_length is None:
        win_length = n_fft

    # Initial phase estimate: uniform in [-pi, pi).
    phase = 2 * np.pi * np.random.random_sample(S_T.shape) - np.pi
    signal = None
    for idx in range(num_iters):
        # Combine the fixed magnitudes with the current phase estimate.
        D = S_T * np.exp(1j * phase)
        signal = librosa.istft(D, hop_length=hop_length, win_length=win_length)
        # The phase is not needed after the last pass, so skip the re-analysis.
        if idx < num_iters - 1:
            phase = np.angle(
                librosa.stft(signal, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
            )

    return signal

In [None]:
# Invert the first `dur` dataset items back to audio, using 30 phase-estimation
# passes instead of the function's default of 15.
sig = stft_to_signal(dataset[:dur].numpy(), num_iters=30)

In [None]:
plt.rcParams['figure.figsize'] = (18, 4)
# Sample i sits at i / sample_rate seconds. (linspace with an inclusive
# endpoint would stretch every timestamp by a factor of N / (N - 1).)
t = np.arange(len(sig)) / 44100

fig, (ax1) = plt.subplots(1, 1)
ax1.plot(t, sig)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Amplitude')
# Bare `pass` keeps the cell from echoing the last return value.
pass

In [None]:
# Render an inline audio player for the reconstructed signal at 44.1 kHz.
Audio(sig, rate=44100)

In [None]:
class Encoder(nn.Module):
    """Fully connected encoder mapping a 513-value vector to a 32-value code.

    Layer widths are 513 -> 256 -> 128 -> 64 -> 32. A ReLU follows every
    layer except the last, so the code itself is unbounded.
    """

    def __init__(self):
        super().__init__()

        self.c1 = nn.Linear(in_features=513, out_features=256)
        self.c2 = nn.Linear(in_features=256, out_features=128)
        self.c3 = nn.Linear(in_features=128, out_features=64)
        self.c4 = nn.Linear(in_features=64, out_features=32)

    def forward(self, x):
        """Return the 32-dimensional code for ``x`` (last dimension 513)."""
        hidden = x
        # ReLU-activated hidden layers.
        for layer in (self.c1, self.c2, self.c3):
            hidden = F.relu(layer(hidden))
        # Linear output layer, no activation.
        return self.c4(hidden)

In [None]:
class Decoder(nn.Module):
    """Fully connected decoder mapping a 32-value code to a 513-value vector.

    Layer widths are 32 -> 64 -> 128 -> 256 -> 513. A ReLU follows every
    layer except the last, so the output is unbounded (it can be negative).
    """

    def __init__(self):
        super().__init__()

        self.c1 = nn.Linear(in_features=32, out_features=64)
        self.c2 = nn.Linear(in_features=64, out_features=128)
        self.c3 = nn.Linear(in_features=128, out_features=256)
        self.c4 = nn.Linear(in_features=256, out_features=513)

    def forward(self, x):
        """Return the 513-value reconstruction for ``x`` (last dimension 32)."""
        hidden = x
        # ReLU-activated hidden layers.
        for layer in (self.c1, self.c2, self.c3):
            hidden = F.relu(layer(hidden))
        # Linear output layer, no activation.
        return self.c4(hidden)

In [None]:
class Autoencoder(nn.Module):
    """Encoder followed by Decoder; the output has the same width as the input."""

    def __init__(self):
        super().__init__()

        # Exposed as attributes so each half can also be used on its own.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` to its latent code and decode it straight back."""
        return self.decoder(self.encoder(x))

In [None]:
# Build the autoencoder and apply the project's weight initialisation
# (the scheme lives in zachary.weight_initializers and is not visible here).
model = Autoencoder()
initialize_model(model)
# Print the value returned by the project's parameter-count helper.
print('\t', get_num_trainable_params(model))

In [None]:
# Reconstruction loss: mean squared error between model output and input.
loss_fn = F.mse_loss

In [None]:
# Adam over all model parameters, with the library's default hyperparameters.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# Shuffled mini-batches of BATCH_SIZE items, loaded by a single worker process
# into pinned host memory.
data_loader = DataLoader(dataset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

### This is the training loop

In [None]:
# Train the autoencoder to reproduce its own input.
NUM_EPOCHS = 2

model.to(DEVICE)
model.train()
for i in range(NUM_EPOCHS):
    # The bar counts samples rather than batches, hence total=len(dataset).
    with tqdm_notebook(total=len(dataset)) as pbar:
        for absolute in data_loader:
            optimizer.zero_grad()
            cc = absolute.to(DEVICE)
            # The input doubles as the target.
            loss = loss_fn(model(cc), cc)

            # .item() yields a plain Python float without going through the
            # legacy `.data` attribute.
            pbar.set_description(f'Epoch: {i + 1} - loss: {loss.item():.2E}')
            pbar.update(absolute.shape[0])

            loss.backward()
            optimizer.step()

## Test performance

In [None]:
from zachary.datasets import load_audio_file

In [None]:
# Take the first `dur` dataset items, add a leading dimension of size 1 and
# move the result to the training device.
sample = dataset[:dur].unsqueeze(0).to(DEVICE)

In [None]:
# Run the excerpt through the trained autoencoder without tracking gradients.
model.eval()
with torch.no_grad():
    sample_hat = model(sample)

# Drop the leading dimension added by unsqueeze(0) exactly once. A second
# squeeze(0) would be a no-op for more than one frame and would wrongly
# remove the frame axis for a single-frame excerpt.
sample_hat = sample_hat.squeeze(0).cpu()
sample_hat_np = sample_hat.numpy()

In [None]:
# Invert the autoencoder's output back to audio, again with 30 passes.
signal_hat = stft_to_signal(sample_hat_np, num_iters=30)

In [None]:
plt.rcParams['figure.figsize'] = (18, 4)
# Build the time axis from the signal that is actually plotted, not from
# `sig`. Sample i sits at i / sample_rate seconds.
t = np.arange(len(signal_hat)) / 44100

fig, (ax1) = plt.subplots(1, 1)
ax1.plot(t, signal_hat)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Amplitude')
# Bare `pass` keeps the cell from echoing the last return value.
pass

In [None]:
# Render an inline audio player for the autoencoder reconstruction at 44.1 kHz.
Audio(signal_hat, rate=44100)

In [None]:
import scipy.interpolate as si

In [None]:
def bspline(cv, n=100, degree=3, periodic=False):
    """Evaluate ``n`` evenly spaced samples on a B-spline.

    Args:
        cv: Array-like of control vertices with shape (count, dims).
        n: Number of samples to return.
        degree: Spline degree; must be at least 1.
        periodic: If True, the control vertices are repeated (wrapped
            around) and a uniform knot vector is used. If False, a clamped
            knot vector is used and at least ``degree + 1`` vertices are
            required.

    Returns:
        A float array of shape (n, dims) holding the sampled points.

    Raises:
        ValueError: If ``degree < 1``, if ``cv`` is not two-dimensional, if
            ``cv`` is empty in periodic mode, or if there are fewer than
            ``degree + 1`` vertices in open mode.
    """
    if degree < 1:
        raise ValueError('degree cannot be less then 1!')

    # Accept lists/tuples as well as arrays; the code below relies on .shape
    # and column slicing.
    cv = np.asarray(cv)
    if cv.ndim != 2:
        raise ValueError('cv must be a 2-D array of shape (count, dims)')
    count = len(cv)

    if periodic:
        if count == 0:
            raise ValueError('cv must contain at least one control vertex')
        # Extend the vertex array to count + degree + 1 entries by wrapping.
        factor, fraction = divmod(count + degree + 1, count)
        cv = np.concatenate((cv,) * factor + (cv[:fraction],))
        count = len(cv)
    else:
        # An open spline needs at least degree + 1 control vertices.
        if count < degree + 1:
            raise ValueError('number of cvs must be at least degree + 1')

    # Calculate knot vector
    if periodic:
        kv = np.arange(0 - degree, count + degree + degree - 1, dtype='int')
    else:
        kv = np.array([0] * degree + list(range(count - degree + 1)) + [count - degree] * degree, dtype='int')

    # Calculate query range; `periodic` is used as a number here, so the
    # range starts at 1 for periodic splines and at 0 for open ones.
    u = np.linspace(periodic, (count - degree), n)

    # Evaluate each coordinate as an independent 1-D spline.
    points = np.zeros((len(u), cv.shape[1]))
    for i in range(cv.shape[1]):
        points[:, i] = si.splev(u, (kv, cv[:, i], degree))

    return points


def sample_z(z_dims, mean, std, num_cv, resolution, degree, is_periodic):
    """Sample a smooth random path through a ``z_dims``-dimensional space.

    ``num_cv`` control vertices are drawn from a normal distribution with the
    given ``mean`` and ``std``; a B-spline of the given ``degree`` through
    them is then evaluated at ``num_cv * resolution`` points.

    Returns:
        The array produced by ``bspline``, one row per sampled point.
    """
    control_vertices = np.random.normal(loc=mean, scale=std, size=(num_cv, z_dims))
    return bspline(control_vertices, num_cv * resolution, degree, is_periodic)

In [None]:
# Sample a latent path whose dimensionality matches the encoder's output layer:
# mean 0, std 1, 100 control vertices, 25 points per vertex (2500 points),
# degree-2 spline, periodic.
zs = sample_z(model.encoder.c4.out_features, 0., 1., 100, 25, 2, True)

In [None]:
# Cast to float32 (presumably to match the model's parameter dtype), wrap as a
# tensor and move it to the training device.
zs_t = torch.from_numpy(zs.astype('float32')).to(DEVICE)

In [None]:
# Decode the sampled latent path with the decoder half only, without tracking
# gradients.
model.eval()
with torch.no_grad():
    y = model.decoder(zs_t)

# Bring the result back to host memory and convert it to NumPy for inversion.
y_hat = y.cpu()
y_hat_np = y_hat.numpy()

In [None]:
# Invert the decoded frames to audio using the function's default of 15 passes.
s_hat = stft_to_signal(y_hat_np)

In [None]:
plt.rcParams['figure.figsize'] = (18, 4)
# Sample i sits at i / sample_rate seconds. (linspace with an inclusive
# endpoint would stretch every timestamp by a factor of N / (N - 1).)
t = np.arange(len(s_hat)) / 44100

fig, (ax1) = plt.subplots(1, 1)
ax1.plot(t, s_hat)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Amplitude')
# Bare `pass` keeps the cell from echoing the last return value.
pass

In [None]:
# Render an inline audio player for the generated signal at 44.1 kHz.
Audio(s_hat, rate=44100)