In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from scipy import signal
import glob
import librosa
import time
import pickle
import subprocess
import os
from scipy.io import wavfile

import skimage.io
import sys
sys.path.append("../src")
from imtools import get_voronoi_image, splat_voronoi_image_1nn
from audioutils import extract_loudness
from tsp import get_tsp_tour
sys.path.append("../models")
from networks import CurveEncoder, CurveDecoder
from losses import mss_loss
sys.path.append("../data")
from datasets import CurveData, AudioData

## Datasets

In [None]:
# Sizes shared by the curve and audio datasets. Their precise meaning is
# defined by CurveData / AudioData (in ../data/datasets.py, not shown here);
# T, win_length and sr are reused by later cells, so every name stays a
# notebook-level global.
n_samples, T, samples_per_batch = 3000, 300, 10000
win_length, sr = 1024, 44100

# Curve datasets: train is built from np.arange(10), test from np.arange(10, 15)
# (presumably disjoint seeds/ids -- confirm against CurveData).
curve_args = (n_samples, T, samples_per_batch)
curve_train = CurveData(np.arange(10), *curve_args)
curve_test = CurveData(np.arange(10, 15), *curve_args)

# Audio datasets: one glob pattern per MUSDB18-HQ split, same sizes for both.
audio_args = (T, samples_per_batch, win_length, sr)
audio_train = AudioData("../data/musdb18hq/train/*/mixture.wav", *audio_args)
audio_test = AudioData("../data/musdb18hq/test/*/mixture.wav", *audio_args)

In [None]:
# Preview one random batch of training curves: print the batch shape, then
# draw the first curve as a black polyline with its points coloured by the
# remaining columns (columns 0/1 are used as x/y, columns 2+ as colours).
loader = DataLoader(curve_train, batch_size=16, shuffle=True)
Y = next(iter(loader)).numpy()
print(Y.shape)
first_curve = Y[0]
curve_x, curve_y = first_curve[:, 0], first_curve[:, 1]
plt.scatter(curve_x, curve_y, c=first_curve[:, 2:])
plt.plot(curve_x, curve_y, c='k', linewidth=1)

# Preview one random batch of training audio; each item is a pair (X, L).
# The final expression renders a player for the first clip in the batch.
loader = DataLoader(audio_train, batch_size=16, shuffle=True)
X, L = next(iter(loader))
ipd.Audio(X[0], rate=sr)

# Encoder / Decoder Architectures

In [None]:
# Smoke test of the architectures: push one curve batch and one audio batch
# through a freshly constructed (untrained) encoder and decoder to check that
# the pieces fit together.
curve_loader = DataLoader(curve_train, batch_size=16, shuffle=True)
audio_loader = DataLoader(audio_train, batch_size=16, shuffle=True)
Y = next(iter(curve_loader))
X, L = next(iter(audio_loader))

# Small model for the smoke test; the training cell builds its own, larger one.
mlp_depth = 3
n_units = 256
n_taps = 50
encoder = CurveEncoder(mlp_depth, n_units, n_taps, win_length)
decoder = CurveDecoder(mlp_depth, n_units, win_length)
print("Encoder params", encoder.get_num_parameters())
print("Decoder params", decoder.get_num_parameters())

# Encode the curve (with the audio's L input) into a signal that is added to the audio
N = encoder(Y, L)
# A bare ipd.Audio(...) in the middle of a cell is silently discarded (only the
# last expression of a cell is rendered), so display the player explicitly.
ipd.display(ipd.Audio(N.detach().numpy()[1, :], rate=sr))
XN = X + N
YOut = decoder(XN)

# Test Example - Layla And Smooth Criminal

In [None]:
# Build the test curve from the Layla image: fit a Voronoi point set to the
# image, order the points with a TSP tour, and preview the result as a
# 200x200 nearest-neighbour splat. This stage is run on the CPU.
device = 'cpu'
n_points = 3000
n_iters = 100
I = skimage.io.imread("../data/images/layla.png")
N = min(I.shape[:2])  # shorter image side (not used further in this cell)

voronoi_options = dict(
    n_neighbs=2,
    n_iters=n_iters,
    verbose=False,
    plot_iter_interval=0,
    use_lsqr=False,
)
J, YLayla, final_cost = get_voronoi_image(I, device, n_points, **voronoi_options)
YLayla = get_tsp_tour(YLayla)
# Alternative test curve taken from the training set:
#YLayla = curve_train.Ys[100]
layla_preview = splat_voronoi_image_1nn(YLayla, 200, 200)
plt.imshow(layla_preview)

In [None]:
# Prepare the fixed test pair used during training: the Layla curve as a
# (1, TLayla, 5) float32 tensor, plus a matching excerpt of "Smooth Criminal"
# and its loudness curve, all moved to the training device.
# NOTE: this cell overwrites YLayla with a tensor, so it must be preceded by a
# fresh run of the cell that builds the curve.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hop_length = win_length//2

# Curve: numpy float32 -> tensor with a leading batch dimension of 1
YLayla = np.array(YLayla, dtype=np.float32)
TLayla = YLayla.shape[0]
YLayla = torch.from_numpy(YLayla).to(device).view(1, TLayla, 5)

# Audio: load at the dataset sample rate and skip the first 15 seconds
xsmooth, _ = librosa.load("../data/tunes/smoothcriminal.mp3", sr=sr)
#xsmooth, _ = librosa.load("../data/musdb18hq/test/Arise - Run Run Run/mixture.wav", sr=sr)
xsmooth = xsmooth[sr*15:] # Cut off quieter beginning

# Loudness: keep one value per curve point
lsmooth = extract_loudness(xsmooth, sr, hop_length, win_length)
lsmooth = np.array(lsmooth[:TLayla], dtype=np.float32)

# Keep exactly the samples spanned by TLayla windows at this hop length
n_audio_samples = hop_length*(TLayla-1) + win_length
xsmooth = np.array(xsmooth[:n_audio_samples], dtype=np.float32)

# Add batch (and, for loudness, trailing feature) dimensions
xsmooth = torch.from_numpy(xsmooth[None, :]).to(device)
lsmooth = torch.from_numpy(lsmooth[None, :, None]).to(device)
print(xsmooth.shape, lsmooth.shape)

# Train Loop

In [None]:
# Training setup: choose the device, build the encoder/decoder, create the
# optimizer and learning-rate schedule, and initialise the loss histories
# that the training loop appends to once per epoch.

# Try to use the GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
print("Device: ", device)

## Step 2: Create model with a test batch
noise_only = False  # If True, the decoder sees only the encoder output (no audio added)
mlp_depth = 3
n_units = 512
n_taps = 200
pre_scale = 1
lam_xy = 1000 # Weight for geometry curve fit
lam_rgb = 200 # Weight for rgb curve fit
encoder = CurveEncoder(mlp_depth, n_units, n_taps, win_length, pre_scale)
encoder = encoder.to(device)
decoder = CurveDecoder(mlp_depth, n_units, win_length)
decoder = decoder.to(device)
print("Encoder params", encoder.get_num_parameters())
print("Decoder params", decoder.get_num_parameters())

## Step 3: Setup the loss function
batch_size=16
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
# Start at learning rate of 1e-4, rate decay of 0.98 factor roughly every 10,000 steps.
# StepLR expects an integer number of epochs: with a float step_size its
# `last_epoch % step_size` check only fires on exact multiples, so the decay
# could be skipped almost entirely. Round to a whole number of epochs (>= 1).
lr_step_epochs = max(1, int(round(10000/len(curve_train))))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_epochs, gamma=0.98)

n_epochs = 10000
# Per-epoch training loss histories
xy_losses = []
rgb_losses = []
audio_losses = []

# Per-epoch test loss histories
test_xy_losses = []
test_rgb_losses = []
test_audio_losses = []

In [None]:
# Main training loop. Each epoch:
#   1. trains the encoder/decoder on paired curve/audio batches,
#   2. measures the same losses on the test sets,
#   3. hides the Layla curve in the Smooth Criminal excerpt and decodes it both
#      directly and after an mp3 round trip through ffmpeg,
#   4. saves a 3x3 summary figure as Epoch{epoch}.png.
# Side effects per epoch: writes Epoch{epoch}.wav, Epoch{epoch}Noise.wav,
# Epoch{epoch}.mp3 and Epoch{epoch}.png into the working directory.
fac = 0.7
plt.figure(figsize=(fac*18, fac*18))


def _curve_l1_losses(Y_true, Y_pred):
    """Return (xy, rgb) mean absolute errors: columns 0:2 and columns 2 onward."""
    loss_xy = torch.mean(torch.abs(Y_true[:, :, 0:2]-Y_pred[:, :, 0:2]))
    loss_rgb = torch.mean(torch.abs(Y_true[:, :, 2::]-Y_pred[:, :, 2::]))
    return loss_xy, loss_rgb


def _first_curve_clipped(Y_pred):
    """Return the first curve of a batch as a numpy array, columns 2+ clipped to [0, 1]."""
    curve = Y_pred.detach().cpu().numpy()[0, :, :]
    curve[:, 2::] = np.clip(curve[:, 2::], 0, 1)
    return curve


def _to_int16_pcm(x):
    """Peak-normalise a float signal and convert it to int16 samples.

    Normalises by the absolute peak (not the positive maximum) and scales by
    32767 so that neither polarity can overflow int16; an all-zero signal is
    left as zeros instead of dividing by zero.
    """
    peak = np.max(np.abs(x))
    if peak > 0:
        x = x/peak
    return np.array(x*32767, dtype=np.int16)


# The loaders do not change between epochs, so build them once. Shuffling
# loaders still draw a new permutation every time they are iterated.
train_curve_loader = DataLoader(curve_train, batch_size=batch_size, shuffle=True)
train_audio_loader = DataLoader(audio_train, batch_size=batch_size, shuffle=True)
test_curve_loader = DataLoader(curve_test, batch_size=batch_size)
test_audio_loader = DataLoader(audio_test, batch_size=batch_size)

# Ground-truth Layla curve on the CPU for the scatter plots (loop invariant)
YLayla_np = YLayla.detach().cpu().numpy()[0]

for epoch in range(n_epochs):
    #####################    STEP 1:  TRAIN     ##################### 
    train_xy_loss = 0
    train_rgb_loss = 0
    train_audio_loss = 0
    n_train_batches = 0
    for Y, (X, L) in zip(train_curve_loader, train_audio_loader): # Go through each mini batch
        # Move inputs/outputs to GPU
        Y = Y.to(device)
        X = X.to(device)
        L = L.to(device)
        # Reset the optimizer's gradients
        optimizer.zero_grad()
        
        # Encode the curve
        N = encoder(Y, L)
        # Add filtered noise to audio and decode
        if noise_only:
            XN = N
        else:
            XN = X + N
        # NOTE(review): get_mp3_noise is not defined or imported in any visible
        # cell of this notebook; it must be in scope before this cell is run.
        mp3noise = get_mp3_noise(XN, sr) # Add on mp3 noise as a constant
        YOut = decoder(XN + mp3noise)
        
        # Add a loss term for the MSS fit of XN to X, as well as the L1 fit of Y to YOut
        loss_audio = mss_loss(X, XN)
        loss_xy, loss_rgb = _curve_l1_losses(Y, YOut)
        
        loss = lam_xy*loss_xy + lam_rgb*loss_rgb + loss_audio
        
        train_xy_loss += loss_xy.item()
        train_rgb_loss += loss_rgb.item()
        train_audio_loss += loss_audio.item()
        n_train_batches += 1
        
        # Compute the gradients of the loss function with respect
        # to all of the parameters of the model
        loss.backward()
        optimizer.step()
        
    last_Y = Y
    last_YOut = YOut
        
    # zip() stops at the shorter loader, so average over the batches that
    # actually ran rather than over len(audio_loader).
    n_train_batches = max(n_train_batches, 1)
    audio_losses.append(train_audio_loss/n_train_batches)
    xy_losses.append(train_xy_loss/n_train_batches)
    rgb_losses.append(train_rgb_loss/n_train_batches)
    
    print("Epoch {}, audio loss {:.3f}, xy loss {:.3f}, rgb loss {:.3f}".format(epoch, audio_losses[-1], xy_losses[-1], rgb_losses[-1]))
    scheduler.step()
    
    #####################    STEP 2:  TEST     ##################### 
    test_xy_loss = 0
    test_rgb_loss = 0
    test_audio_loss = 0
    n_test_batches = 0
    # No parameters are updated here, so skip building autograd graphs
    with torch.no_grad():
        for Y, (X, L) in zip(test_curve_loader, test_audio_loader): # Go through each mini batch
            # Move inputs/outputs to GPU
            Y = Y.to(device)
            X = X.to(device)
            L = L.to(device)
            # Encode the curve
            N = encoder(Y, L)
            # Add filtered noise to audio and decode (no mp3 noise at test time)
            if noise_only:
                XN = N
            else:
                XN = X + N
            YOut = decoder(XN)
            # Same losses as in training: MSS fit of XN to X, L1 fit of Y to YOut
            loss_audio = mss_loss(X, XN)
            loss_xy, loss_rgb = _curve_l1_losses(Y, YOut)
            test_xy_loss += loss_xy.item()
            test_rgb_loss += loss_rgb.item()
            test_audio_loss += loss_audio.item()
            n_test_batches += 1
        
    last_Y_test = Y
    last_YOut_test = YOut
        
    n_test_batches = max(n_test_batches, 1)
    test_audio_losses.append(test_audio_loss/n_test_batches)
    test_xy_losses.append(test_xy_loss/n_test_batches)
    test_rgb_losses.append(test_rgb_loss/n_test_batches)
    
    
    #####################    STEP 3:  LAYLA     ##################### 
    wav_path = "Epoch{}.wav".format(epoch)
    noise_path = "Epoch{}Noise.wav".format(epoch)
    mp3_path = "Epoch{}.mp3".format(epoch)
    with torch.no_grad():
        N = encoder(YLayla, lsmooth)
        if noise_only:
            XN = N
        else:
            XN = xsmooth + N
        YOut = decoder(XN)
        loss_xy_layla, loss_rgb_layla = _curve_l1_losses(YLayla, YOut)
        loss_xy_layla = loss_xy_layla.item()
        loss_rgb_layla = loss_rgb_layla.item()
        YOut = _first_curve_clipped(YOut)
        
        # Save the mixed signal and the encoder output alone as 16-bit WAV files
        wavfile.write(wav_path, sr, _to_int16_pcm(XN.detach().cpu().numpy()[0, :]))
        wavfile.write(noise_path, sr, _to_int16_pcm(N.detach().cpu().numpy()[0, :]))
        
        # Round trip through mp3. "-y" overwrites an mp3 left over from an
        # earlier run; without it ffmpeg refuses and a stale file is decoded.
        subprocess.call(["ffmpeg", "-y", "-i", wav_path, mp3_path], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
        # Discard the returned rate so the notebook-level sr is not rebound
        z, _ = librosa.load(mp3_path, sr=sr)
        z = torch.from_numpy(z[None, :]).to(XN)  # match XN's device and dtype
        YOut_mp3 = decoder(z)
        loss_xy_layla_mp3, loss_rgb_layla_mp3 = _curve_l1_losses(YLayla, YOut_mp3)
        loss_xy_layla_mp3 = loss_xy_layla_mp3.item()
        loss_rgb_layla_mp3 = loss_rgb_layla_mp3.item()
        YOut_mp3 = _first_curve_clipped(YOut_mp3)
    
    
    #####################    STEP 4: PLOT    ##################### 
    plt.clf()
    plt.subplot(331)
    plt.imshow(splat_voronoi_image_1nn(YOut, 200, 200))
    plt.title("Layla Wav Reconstruction")
    plt.subplot(332)
    plt.imshow(splat_voronoi_image_1nn(YOut_mp3, 200, 200))
    plt.title("Layla Mp3 Reconstruction")
    
    # Training losses, scaled by the weights used in the total loss
    plt.subplot(333)
    plt.plot(audio_losses)
    plt.plot(lam_xy*np.array(xy_losses))
    plt.plot(lam_rgb*np.array(rgb_losses))
    plt.legend(["Audio ({:.3f})".format(audio_losses[-1]), 
                "xy ({:.3f})".format(xy_losses[-1]),
                "rgb ({:.3f})".format(rgb_losses[-1])])
    plt.xlabel("Epoch")
    plt.ylabel("Scaled Loss")
    plt.title("Epoch {}, Train Losses".format(epoch)) 
    
    # Ground truth vs. decoded values for column 0 (x) and column 2 (r)
    plt.subplot(334)
    plt.scatter(YLayla_np[:, 0], YOut[:, 0])
    plt.axis("equal")
    plt.title("X Coord Layla (Loss {:.3f})".format(loss_xy_layla))
    plt.subplot(335)
    plt.scatter(YLayla_np[:, 2], YOut[:, 2])
    plt.axis("equal")
    plt.title("R Coord Layla (Loss {:.3f})".format(loss_rgb_layla))
    
    # Test losses, scaled the same way as the training losses
    plt.subplot(336)
    plt.plot(test_audio_losses)
    plt.plot(lam_xy*np.array(test_xy_losses))
    plt.plot(lam_rgb*np.array(test_rgb_losses))
    plt.legend(["Audio ({:.3f})".format(test_audio_losses[-1]), 
                "xy ({:.3f})".format(test_xy_losses[-1]),
                "rgb ({:.3f})".format(test_rgb_losses[-1])])
    plt.xlabel("Epoch")
    plt.ylabel("Scaled Loss")
    plt.title("Test Losses")
    
    # Same scatter plots for the curve decoded after the mp3 round trip
    plt.subplot(337)
    plt.scatter(YLayla_np[:, 0], YOut_mp3[:, 0])
    plt.axis("equal")
    plt.title("Mp3 X Coord Layla (Loss {:.3f})".format(loss_xy_layla_mp3))
    plt.subplot(338)
    plt.scatter(YLayla_np[:, 2], YOut_mp3[:, 2])
    plt.axis("equal")
    plt.title("Mp3 R Coord Layla (Loss {:.3f})".format(loss_rgb_layla_mp3))
    
    # Text panel recording the hyperparameters of this run
    plt.subplot(339)
    textstr = "Epoch {}\n\nlam_xy = {}\nlam_rgb = {}\npre_scale={}\nn_units={}\nn_taps={}\npre_scale={}".format(epoch, lam_xy, lam_rgb, pre_scale, n_units, n_taps, pre_scale)
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax = plt.gca()
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=14,
            verticalalignment='top', bbox=props)
    plt.axis("off")
    
    plt.savefig("Epoch{}.png".format(epoch), bbox_inches='tight')