In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Audio

In [None]:
import numpy as np
from tqdm import tqdm_notebook
import librosa

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import nn

In [None]:
from zachary.datasets import AudioDataset, SpectrumDataset
from zachary.utils import get_torch_device, get_num_trainable_params
from zachary.weight_initializers import initialize_model
from zachary.plotting import plot_mag_phase
from zachary.modules import SeparableConv1d, SeparableConvTranspose1d

In [None]:
BATCH_SIZE = 128
DEVICE = get_torch_device()

# Seed both RNGs so a Restart & Run All is reproducible: torch drives default layer
# initialisation and DataLoader shuffling, NumPy drives np.random.normal in sample_z.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Training dataset; items are (input, target) pairs, inspected below as a spectrum and a signal.
dataset = SpectrumDataset()

In [None]:
# Short examples while inspecting the data; raised to 11 before training further down.
dataset.example_length = 9

In [None]:
# Shapes of the (input, target) pair returned for the first example.
dataset[0][0].shape, dataset[0][1].shape

In [None]:
def plot_spectrum(y, figsize=(18, 4)):
    """Show a 2-D array (e.g. a spectrum) as an image with row 0 at the bottom.

    Parameters
    ----------
    y : torch.Tensor or array_like, 2-D
        Values to display. Tensors are detached and moved to the CPU first, so
        GPU tensors and tensors that require grad are accepted too.
    figsize : tuple of float
        Figure size in inches. It applies to this figure only; the global
        ``plt.rcParams`` are left untouched.
    """
    if isinstance(y, torch.Tensor):
        y = y.detach().cpu().numpy()
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(y, aspect='auto', origin='lower')

In [None]:
def plot_signal(x, figsize=(18, 4)):
    """Plot the first channel ``x[0]`` of a multi-channel signal as a line.

    Parameters
    ----------
    x : torch.Tensor or array_like
        Signal indexed as ``x[channel]``. Only ``x[0]`` is drawn. Tensors are
        detached and moved to the CPU first.
    figsize : tuple of float
        Figure size in inches. It applies to this figure only; the global
        ``plt.rcParams`` are left untouched.
    """
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.plot(np.asarray(x[0]))

In [None]:
# Target of the last example, drawn with plot_signal (channel 0 as a line).
plot_signal(dataset[-1][1])

In [None]:
# Input of the last example, drawn as an image with plot_spectrum.
plot_spectrum(dataset[-1][0])

In [None]:
class Encoder(nn.Module):
    def __init__(self, channels=[8, 16, 32, 64, 128], separable=False):
        super(Encoder, self).__init__()
        
        if separable:
            convolution = SeparableConv1d
        else:
            convolution = nn.Conv1d
        
        entry_layers = 8
        entry_channels = 8
        kernel_channels = [entry_channels * entry_layers] + channels
        kernel_sizes = [4, 4, 4, 4, 4]
        
        self.e1 = convolution(in_channels=1, out_channels=entry_channels, kernel_size=3, dilation=2**0, padding=2**0)
        self.e2 = convolution(in_channels=1, out_channels=entry_channels, kernel_size=3, dilation=2**1, padding=2**1)
        self.e3 = convolution(in_channels=1, out_channels=entry_channels, kernel_size=3, dilation=2**2, padding=2**2)
        self.e4 = convolution(in_channels=1, out_channels=entry_channels, kernel_size=3, dilation=2**3, padding=2**3)
        self.e5 = convolution(in_channels=1, out_channels=entry_channels, kernel_size=3, dilation=2**4, padding=2**4)
        self.e6 = convolution(in_channels=1, out_channels=entry_channels, kernel_size=3, dilation=2**5, padding=2**5)
        self.e7 = convolution(in_channels=1, out_channels=entry_channels, kernel_size=3, dilation=2**6, padding=2**6)
        self.e8 = convolution(in_channels=1, out_channels=entry_channels, kernel_size=3, dilation=2**7, padding=2**7)
        
        
        self.c1 = convolution(in_channels=kernel_channels[0], out_channels=kernel_channels[1], kernel_size=kernel_sizes[0], stride=kernel_sizes[0])
        self.c2 = convolution(in_channels=kernel_channels[1], out_channels=kernel_channels[2], kernel_size=kernel_sizes[1], stride=kernel_sizes[1])
        self.c3 = convolution(in_channels=kernel_channels[2], out_channels=kernel_channels[3], kernel_size=kernel_sizes[2], stride=kernel_sizes[2])
        self.c4 = convolution(in_channels=kernel_channels[3], out_channels=kernel_channels[4], kernel_size=kernel_sizes[3], stride=kernel_sizes[3])
        self.c5 = convolution(in_channels=kernel_channels[4], out_channels=kernel_channels[5], kernel_size=kernel_sizes[4], stride=kernel_sizes[4] // 2)
        
        self.entry = [self.e1, self.e2, self.e3, self.e4, self.e5, self.e6, self.e7, self.e8]
        self.convolutions = [self.c1, self.c2, self.c3, self.c4, self.c5]

    def forward(self, x):
        xs = []
        for layer in self.entry:
            xs.append(F.relu(layer(x)))
        xs = torch.cat(xs, dim=1)
        
        for layer in self.convolutions[:-1]:
            xs = F.relu(layer(xs))
        xs = self.convolutions[-1](xs)
        return xs

In [None]:
class Decoder(nn.Module):
    def __init__(self, channels=[256, 128, 64, 32, 16], separable=False):
        super(Decoder, self).__init__()
        
        if separable:
            deconvolution = SeparableConvTranspose1d
        else:
            deconvolution = nn.ConvTranspose1d
        
        sizes = [8, 4, 4, 4, 2]
        
        self.c1 = deconvolution(in_channels=channels[0], out_channels=channels[1], kernel_size=sizes[0], stride=sizes[0]//8)
        self.c2 = deconvolution(in_channels=channels[1], out_channels=channels[2], kernel_size=sizes[1], stride=sizes[1])
        self.c3 = deconvolution(in_channels=channels[2], out_channels=channels[3], kernel_size=sizes[2], stride=sizes[2])
        self.c4 = deconvolution(in_channels=channels[3], out_channels=channels[4], kernel_size=sizes[3], stride=sizes[3])
        self.c5 = deconvolution(in_channels=channels[4], out_channels=1, kernel_size=sizes[4], stride=sizes[4])
        
        self.convolutions = [self.c1, self.c2, self.c3, self.c4, self.c5]

    def forward(self, z):
        for layer in self.convolutions[:-1]:
            z = F.relu(layer(z))
        return self.convolutions[-1](z)

In [None]:
class AutoEncoder(nn.Module):
    """Deterministic autoencoder: an ``Encoder`` followed by a mirrored ``Decoder``.

    Parameters
    ----------
    channels : list of int
        Encoder widths. The decoder uses the same widths in reverse order.
        The list passed in is not modified.
    separable : bool
        Forwarded to both ``Encoder`` and ``Decoder``.
    """

    def __init__(self, channels=[8, 16, 32, 64, 128], separable=False):
        super(AutoEncoder, self).__init__()

        self.encoder = Encoder(channels, separable)
        # Reverse a copy. Reversing `channels` in place would mutate the caller's list
        # and the shared default argument, so every other AutoEncoder() would be built
        # with flipped widths.
        self.decoder = Decoder(list(reversed(channels)), separable)

    def encode(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def decode(self, x):
        """Return the decoder output for latent ``x``."""
        return self.decoder(x)

    def forward(self, x):
        """Encode then decode ``x``."""
        return self.decoder(self.encoder(x))

In [None]:
class VAE(nn.Module):
    """Variational autoencoder built from ``Encoder``/``Decoder`` with Gaussian latent heads.

    Parameters
    ----------
    channels : list of int
        Encoder widths. The decoder uses the same widths in reverse order.
        The list passed in is not modified.
    separable : bool
        Forwarded to both ``Encoder`` and ``Decoder``.
    """

    def __init__(self, channels=[8, 16, 32, 64, 128], separable=False):
        super(VAE, self).__init__()

        self.encoder = Encoder(channels, separable)
        # Reverse a copy. Reversing `channels` in place would mutate the caller's list
        # and the shared default argument, so every other VAE() would be built with
        # flipped widths.
        self.decoder = Decoder(list(reversed(channels)), separable)
        ch = self.encoder.c5.out_channels
        # Length-preserving (kernel 3, stride 1, padding 1) heads for the posterior
        # mean and log-variance.
        self.mu_layer = nn.Conv1d(ch, ch, 3, 1, 1)
        self.logvar_layer = nn.Conv1d(ch, ch, 3, 1, 1)

    def encode(self, x):
        """Return ``(mu, logvar)`` computed from the ReLU'd encoder features of ``x``."""
        h1 = F.relu(self.encoder(x))
        return self.mu_layer(h1), self.logvar_layer(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Return the decoder output for latent ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(x_hat, mu, logvar)`` for input ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [None]:
# Decoder-only model: 513 input channels -> 1 output channel.
# NOTE(review): 513 presumably matches the spectrum's bin count in SpectrumDataset — confirm.
model = Decoder([513, 256, 128, 64, 32])
# Apply the project's weight initializer (zachary.weight_initializers).
initialize_model(model)
# Report the number of trainable parameters.
print(get_num_trainable_params(model))

In [None]:
# Plain MSE reconstruction loss, selected below for AutoEncoder models.
ae_loss = F.mse_loss

In [None]:
def vae_loss(result, x):
    """VAE objective: reconstruction MSE plus the KL divergence to a standard normal.

    Parameters
    ----------
    result : tuple
        ``(x_hat, mu, logvar)`` as returned by ``VAE.forward``.
    x : torch.Tensor
        Reconstruction target, same shape as ``x_hat``.

    NOTE(review): the MSE term is a mean (``F.mse_loss`` default) while the KL term
    is a sum over every element. The KL weight therefore grows with batch size
    and latent size; confirm that this balance is intended.
    """
    reconstruction, mu, logvar = result
    # Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over all elements.
    kl_divergence = -0.5 * (1 + logvar - mu ** 2 - torch.exp(logvar)).sum()
    return F.mse_loss(reconstruction, x) + kl_divergence

In [None]:
def diger_loss(signal_hat, signal):
    """Waveform loss: MSE plus one minus the mean cosine similarity.

    Both tensors have their axis 1 squeezed (a singleton channel axis, if present)
    before ``F.cosine_similarity`` compares them along its default ``dim=1``. The
    MSE term uses the unsqueezed tensors.
    """
    prediction = signal_hat.squeeze(1)
    reference = signal.squeeze(1)
    mean_similarity = F.cosine_similarity(prediction, reference).mean()
    return (1 - mean_similarity) + F.mse_loss(signal_hat, signal)

In [None]:
# Pick the loss that matches the model's output signature.
if isinstance(model, AutoEncoder):
    loss_fn = ae_loss
elif isinstance(model, VAE):
    loss_fn = vae_loss  # VAE.forward returns (x_hat, mu, logvar)
elif isinstance(model, Decoder):
    # Decoder-only model: its output is compared directly with the target signal.
    loss_fn = diger_loss
else:
    print('UNKNOWN MODEL')
    loss_fn = diger_loss

In [None]:
# Adam over all model parameters with PyTorch's default hyperparameters (lr=1e-3).
optimizer = torch.optim.Adam(model.parameters())

### Increase `example_length` for training (from the 9 used for inspection above to 11)

In [None]:
# Train on longer examples than the 9 used for inspection above, then build a shuffled, batched loader.
# pin_memory=True speeds up host-to-GPU copies when DEVICE is a CUDA device.
example_length = 11
dataset.example_length = example_length
data_loader = DataLoader(dataset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

### Training loop

In [None]:
NUM_EPOCHS = 100

model.to(DEVICE)
model.train()
for epoch in range(NUM_EPOCHS):
    # The bar counts examples, so it advances by the batch size on each step.
    with tqdm_notebook(total=dataset.examples.shape[0]) as pbar:
        for example, target in data_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(example.to(DEVICE)), target.to(DEVICE))
            loss.backward()
            optimizer.step()

            # .item() returns the scalar loss as a Python float.
            pbar.set_description(f'Epoch: {epoch + 1} - loss: {loss.item():.4f}')
            pbar.update(example.shape[0])

## Evaluate reconstruction on the first 30 s of the dataset audio (seen during training, not held out)

In [None]:
# First 30 s of raw audio at 44.1 kHz as (1, 1, samples), assuming `dataset.audio` is 1-D.
# NOTE(review): `sample_a` is not used by any later cell.
sample_a = dataset.audio[:30*44100].unsqueeze(0).unsqueeze(0).to(DEVICE)

In [None]:
# First 30 s of the spectrum (assumes 100 columns per second — TODO confirm against SpectrumDataset), with a leading batch axis.
sample = dataset.spectrum[:,:30*100].unsqueeze(0).to(DEVICE)

In [None]:
# Run the model on the spectrum excerpt without tracking gradients.
model.eval()
with torch.no_grad():
    sample_hat = model(sample)

# VAE.forward returns (x_hat, mu, logvar); keep only the reconstruction.
if isinstance(model, VAE):
    sample_hat, _, _ = sample_hat

# Drop the batch axis for plot_signal (which draws channel 0); the NumPy copy also
# drops the channel axis so Audio receives a flat signal.
sample_hat = sample_hat.squeeze(0).cpu()
sample_hat_np = sample_hat.squeeze(0).numpy()

In [None]:
# Reconstructed signal (channel 0).
plot_signal(sample_hat)

In [None]:
# Play the reconstruction at 44.1 kHz.
Audio(sample_hat_np, rate=44100)

In [None]:
# Ground-truth audio over the same number of samples as the reconstruction.
plot_signal(dataset.audio[:sample_hat.shape[1]].unsqueeze(0).cpu())

In [None]:
# Play the ground-truth excerpt for comparison.
Audio(dataset.audio[:sample_hat.shape[1]].cpu().numpy(), rate=44100)

In [None]:
# Latent representation of the spectrum excerpt.
model.eval()
with torch.no_grad():
    if hasattr(model, 'encode'):
        z = model.encode(sample)
    else:
        # Decoder-only models (e.g. `Decoder`) have no encoder. Their input is
        # already the code they decode, so show that instead of raising AttributeError.
        z = sample

# VAE.encode returns (mu, logvar); keep the mean.
if isinstance(model, VAE):
    z, _ = z

In [None]:
# Figure size is set for this figure only rather than via the global rcParams.
fig, ax1 = plt.subplots(1, 1, figsize=(18, 4))
ax1.imshow(z[0].cpu().numpy(), aspect='auto')
ax1.set_xlabel('time step')
ax1.set_ylabel('channel')
ax1.set_title('Latent code z (first batch item)')
plt.show()

In [None]:
import scipy.interpolate as si

In [None]:
def bspline(cv, n=100, degree=3, periodic=False):
    """Sample ``n`` points along a uniform B-spline defined by control vertices ``cv``.

    Parameters
    ----------
    cv : array_like, shape (count, dims)
        Control vertices, one row per vertex.
    n : int
        Number of samples to return.
    degree : int
        Spline degree (>= 1).
    periodic : bool
        If True the curve is closed: the last sample coincides with the first.
        Otherwise the curve is clamped to start at ``cv[0]`` and end at ``cv[-1]``.

    Returns
    -------
    numpy.ndarray, shape (n, dims)
        Points on the curve, evenly spaced in the spline parameter.

    Raises
    ------
    ValueError
        If ``degree < 1`` or ``cv`` is empty. For an open curve, also if there
        are fewer than ``degree + 1`` control vertices.
    """
    if degree < 1:
        raise ValueError('degree cannot be less then 1!')
    cv = np.asarray(cv)
    count = len(cv)
    if count == 0:
        raise ValueError('cv must contain at least one control vertex')

    if periodic:
        # Wrap the vertices so the array has count + degree + 1 rows. The extra rows
        # repeat the start of the curve, so the final segment closes the loop.
        factor, fraction = divmod(count + degree + 1, count)
        cv = np.concatenate((cv,) * factor + (cv[:fraction],))
        count = len(cv)
        # Uniform knots, exactly len(cv) + degree + 1 of them as splev requires.
        # A longer knot vector makes splev index coefficients past the end of cv
        # when it evaluates the last sample.
        kv = np.arange(-degree, count + 1, dtype='int')
    else:
        # An open (clamped) curve needs at least degree + 1 vertices.
        if count < degree + 1:
            raise ValueError('number of cvs must be at least degree + 1')
        # Clamped knot vector: `degree` extra repeated knots at each end make the
        # curve start at cv[0] and end at cv[-1].
        kv = np.array([0] * degree + list(range(count - degree + 1)) + [count - degree] * degree, dtype='int')

    # Open curves span the full domain [0, count - degree]. Periodic curves start at 1,
    # so [1, count - degree] covers exactly one period of the wrapped vertices.
    u = np.linspace(periodic, (count - degree), n)

    # Evaluate each coordinate as an independent 1-D spline.
    points = np.zeros((len(u), cv.shape[1]))
    for dim in range(cv.shape[1]):
        points[:, dim] = si.splev(u, (kv, cv[:, dim], degree))

    return points


def sample_z(z_dims, mean, std, num_cv, resolution, degree, is_periodic):
    """Draw a smooth random latent trajectory.

    Draws ``num_cv`` control vertices i.i.d. from N(mean, std**2) in ``z_dims``
    dimensions, using NumPy's global RNG. Then samples the B-spline of the given
    ``degree`` that they define at ``num_cv * resolution`` points (closed if
    ``is_periodic``).

    Returns an array of shape (num_cv * resolution, z_dims).
    """
    total_points = num_cv * resolution
    control_vertices = np.random.normal(mean, std, size=(num_cv, z_dims))
    return bspline(control_vertices, total_points, degree, is_periodic)

In [None]:
# Latent width: the encoder's output channels. A decoder-only model has no
# `.encoder`, so use its first layer's input channels instead.
if hasattr(model, 'encoder'):
    latent_channels = model.encoder.c5.out_channels
else:
    latent_channels = model.c1.in_channels
# Random smooth, closed latent trajectory:
#   - 100 control vertices ~ N(0, 2^2)
#   - quadratic spline
#   - 10 samples per control vertex
zs = sample_z(latent_channels, mean=0., std=2., num_cv=100, resolution=10, degree=2, is_periodic=True)

In [None]:
# (points, channels) spline -> float32 tensor of shape (1, channels, points) on DEVICE,
# the (batch, channels, length) layout the 1-D (transposed) convolutions consume.
zs_t = torch.from_numpy(zs.astype('float32').T).unsqueeze(0).to(DEVICE)

In [None]:
# Decode the random latent trajectory. Decoder-only models have no `.decode`;
# calling the module itself decodes.
decode_fn = getattr(model, 'decode', model)
model.eval()
with torch.no_grad():
    y = decode_fn(zs_t)

# Drop the batch axis for plot_signal; the NumPy copy also drops the channel axis for Audio.
y_hat = y.squeeze(0).cpu()
y_hat_np = y_hat.squeeze(0).numpy()

In [None]:
# Signal decoded from the random latent trajectory (channel 0).
plot_signal(y_hat)

In [None]:
# Play it at 44.1 kHz.
Audio(y_hat_np, rate=44100)