In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch_audiomentations import Compose, PitchShift
from scipy import signal
import glob
import librosa
import time
import pickle
import subprocess
from skimage.transform import resize
from scipy.io import wavfile

import skimage.io
import sys
sys.path.append("../src")
from imtools import get_voronoi_image, splat_voronoi_image_1nn
from audioutils import extract_loudness, get_mp3_noise, get_chroma_filterbank, get_batch_chroma
from tsp import get_tsp_tour
from wavelets2d import get_color_wavelet_tsp, invert_sparse_coefficients
sys.path.append("../models")
from networks import CurveEncoder, CurveDecoder
from losses import mss_loss
sys.path.append("../data")
from dataset import CurveData, AudioData

## Datasets

In [None]:
# Audio framing parameters
sr = 44100                    # sample rate handed to the audio loaders (and librosa.load below)
win_length = 1024             # analysis window length in samples (also the Hann window size)
hop_length = win_length // 2  # hop between windows: 50% overlap

# Sequence / batch sizes shared by both dataset types
# (their exact meaning is defined by CurveData / AudioData in ../data/dataset.py)
T = 300
n_samples = 3000
samples_per_batch = 10000

# Curve datasets: CurveData gets disjoint index ranges for train (0-9) and test (10-14)
curve_train = CurveData(np.arange(10), n_samples, T, samples_per_batch, voronoi=False)
curve_test = CurveData(np.arange(10, 15), n_samples, T, samples_per_batch, voronoi=False)

# Audio datasets: MUSDB18-HQ mixture tracks from the train and test splits
audio_train = AudioData("../data/musdb18hq/train/*/mixture.wav", T, samples_per_batch, win_length, sr)
audio_test = AudioData("../data/musdb18hq/test/*/mixture.wav", T, samples_per_batch, win_length, sr)

# Test Example - Layla And Smooth Criminal

In [None]:
# Turn the Layla test image into a coefficient sequence with the project helper
# get_color_wavelet_tsp; n_levels, n_points and box_scale are its tuning knobs and
# must match the values later passed to invert_sparse_coefficients.
n_levels = 5
n_points = 3000
box_scale = 50
I = resize(skimage.io.imread("../data/images/layla.png"), (256, 256), anti_aliasing=True)
YLayla = get_color_wavelet_tsp(I, n_levels, n_points, box_scale)

# Sanity check: invert the coefficients back to an image (left) and draw the
# first two coefficient columns as a 2D path (right)
J = invert_sparse_coefficients(YLayla, I.shape[0], n_levels, box_scale)
fig, (ax_recon, ax_path) = plt.subplots(1, 2, figsize=(8, 4))
ax_recon.imshow(J)
ax_path.plot(YLayla[:, 0], YLayla[:, 1])

In [None]:
# Try to use the GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: ", device)

# Curve to hide: convert the coefficient array to a (1, TLayla, n_cols) float32 tensor.
# Downstream losses read columns 0:2 as xy and 2: as rgb.
# NOTE: this rebinds YLayla from a numpy array to a tensor, so re-run the image cell
# above before re-running this one.
YLayla = np.array(YLayla, dtype=np.float32)
TLayla = YLayla.shape[0]
YLayla = torch.from_numpy(YLayla).to(device).unsqueeze(0)

# Carrier audio (alternative carrier kept for experiments):
#xsmooth, _ = librosa.load("../data/musdb18hq/test/Arise - Run Run Run/mixture.wav", sr=sr)
xsmooth, _ = librosa.load("../data/tunes/smoothcriminal.mp3", sr=sr)
xsmooth = xsmooth[sr*15::] # Cut off quieter beginning
lsmooth = extract_loudness(xsmooth, sr, hop_length, win_length)

# Keep one loudness value per curve point, and exactly TLayla windows of win_length
# samples spaced hop_length apart. Slicing past the end would silently return shorter
# arrays and only fail later inside the encoder, so check the lengths up front.
n_carrier_samples = hop_length*(TLayla-1) + win_length
assert len(lsmooth) >= TLayla, \
    "carrier audio yields {} loudness frames but the curve has {} points".format(len(lsmooth), TLayla)
assert len(xsmooth) >= n_carrier_samples, \
    "carrier audio has {} samples but {} are needed".format(len(xsmooth), n_carrier_samples)
lsmooth = np.array(lsmooth[0:TLayla], dtype=np.float32)
xsmooth = np.array(xsmooth[0:n_carrier_samples], dtype=np.float32)
xsmooth = torch.from_numpy(xsmooth[None, :]).to(device)       # (1, n_carrier_samples)
lsmooth = torch.from_numpy(lsmooth[None, :, None]).to(device)  # (1, TLayla, 1)
print(xsmooth.shape, lsmooth.shape)

# Train Loop

In [None]:
## Parameters for encoder/decoder
mlp_depth = 3
n_units = 512
n_taps = 200
noise_eps = 0.2        # Scale of the noise added to XN before it is decoded
use_mp3_noise = True   # Noise source: get_mp3_noise (True) or Gaussian torch.randn (False)
pre_scale = 1
lam_xy = 1000 # Weight for geometry curve fit
lam_rgb = 400 # Weight for rgb curve fit
encoder = CurveEncoder(mlp_depth, n_units, n_taps, win_length, pre_scale)
encoder = encoder.to(device)
decoder = CurveDecoder(mlp_depth, n_units, win_length, voronoi=False)
decoder = decoder.to(device)
print("Encoder params", encoder.get_num_parameters())
print("Decoder params", decoder.get_num_parameters())

## Setup the optimizer and learning-rate schedule
batch_size = 16
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
# Start at a learning rate of 1e-4 and multiply it by 0.98 every `lr_decay_epochs` epochs,
# i.e. roughly every 10,000 training curves (each epoch draws up to len(curve_train)).
# scheduler.step() is called once per epoch, so StepLR's step_size counts epochs. It must
# be a positive integer: StepLR only decays when `epoch % step_size == 0`, which a
# fractional step_size (e.g. 10000/3000) almost never satisfies, so the rate never decayed.
lr_decay_epochs = max(1, round(10000/len(curve_train)))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_epochs, gamma=0.98)

n_epochs = 100
xy_losses = []
rgb_losses = []
audio_losses = []

test_xy_losses = []
test_rgb_losses = []
test_audio_losses = []

## Setup chroma filterbank and hann window for STFT
chroma_filterbank = get_chroma_filterbank(sr, win_length).to(device)
hann = torch.hann_window(win_length).to(device)
CLayla = get_batch_chroma(xsmooth, win_length, hop_length, hann, chroma_filterbank)

## Setup data augmentation
pitch_shifter = PitchShift(sample_rate=sr)

In [None]:
# Smoke test: push a single batch through encoder and decoder and print the output shapes
curve_loader = DataLoader(curve_train, batch_size=batch_size, shuffle=True)
audio_loader = DataLoader(audio_train, batch_size=batch_size, shuffle=True)

(Y, (X, L)) = next(zip(curve_loader, audio_loader))
Y = Y.to(device)
X = X.to(device)
L = L.to(device)
# A shape check needs no autograd graph; skipping it avoids keeping activations on the device
with torch.no_grad():
    C = get_batch_chroma(X, win_length, hop_length, hann, chroma_filterbank)
    N = encoder(Y, L, C)
    XN = X + N
    XYOut, RGBOut = decoder(XN)
print(XYOut.shape, RGBOut.shape)

In [None]:
def curve_losses(Y_true, XYOut, RGBOut):
    """Curve-fit losses between target curves and decoder outputs.

    Y_true is indexed [batch, time, column]; columns 0:2 are compared with XYOut
    and columns 2: with RGBOut (both must broadcast against those slices).

    Returns (loss_xy, loss_rgb) as scalar tensors:
      loss_xy  = mean(|xy_true - XYOut| / |0.1 + xy_true|)
      loss_rgb = mean(|rgb_true - RGBOut|)
    NOTE(review): loss_xy is singular where xy_true == -0.1; presumably xy targets
    are non-negative -- confirm.
    """
    xy_true = Y_true[:, :, 0:2]
    loss_xy = torch.mean(torch.abs(xy_true - XYOut) / torch.abs(0.1 + xy_true))
    loss_rgb = torch.mean(torch.abs(Y_true[:, :, 2::] - RGBOut))
    return loss_xy, loss_rgb


def write_wav_int16(path, x, sr):
    """Peak-normalize a 1-D float signal and write it to `path` as 16-bit PCM.

    Normalizes by the largest *absolute* value and scales by 32767, so no sample
    can overflow int16 (the previous max(x)/32768 scaling wrapped around on
    negative peaks and on a peak of exactly 1.0). An all-zero signal is written as-is.
    """
    peak = np.max(np.abs(x))
    if peak > 0:
        x = x / peak
    wavfile.write(path, sr, np.array(x * 32767, dtype=np.int16))


def coeffs_to_image(xy, rgb, side):
    """Concatenate decoded (T, .) xy and rgb arrays column-wise, invert them to an image, clip to [0, 1]."""
    coeffs = np.concatenate((xy, rgb), axis=1)
    image = invert_sparse_coefficients(coeffs, side, n_levels, box_scale)
    return np.clip(image, 0, 1)


def plot_scaled_losses(audio, xy, rgb, title):
    """Plot per-epoch losses with xy/rgb multiplied by their loss weights; the legend shows the latest unscaled values."""
    plt.plot(audio)
    plt.plot(lam_xy*np.array(xy))
    plt.plot(lam_rgb*np.array(rgb))
    plt.legend(["Audio ({:.3f})".format(audio[-1]),
                "xy ({:.3f})".format(xy[-1]),
                "rgb ({:.3f})".format(rgb[-1])])
    plt.xlabel("Epoch")
    plt.ylabel("Scaled Loss")
    plt.title(title)


def scatter_panel(truth, pred, title):
    """Scatter ground truth against prediction on equal axes."""
    plt.scatter(truth, pred)
    plt.axis("equal")
    plt.title(title)


fac = 0.7
plt.figure(figsize=(fac*18, fac*18))
# Ground-truth Layla curve as a numpy array for the scatter plots (loop-invariant)
YLayla_np = YLayla.detach().cpu().numpy()[0]

epoch = 0
while epoch < n_epochs:
    #####################    STEP 1:  TRAIN     #####################
    curve_loader = DataLoader(curve_train, batch_size=batch_size, shuffle=True)
    audio_loader = DataLoader(audio_train, batch_size=batch_size, shuffle=True)
    # zip() below stops at the shorter loader, so that is how many batches are really used
    n_per_batch = min(len(curve_loader), len(audio_loader))

    train_xy_loss = 0
    train_rgb_loss = 0
    train_audio_loss = 0
    tic = time.time()
    for batch_num, (Y, (X, L)) in enumerate(zip(curve_loader, audio_loader)): # Go through each mini batch
        # Reset the optimizer's gradients
        optimizer.zero_grad()

        # Move inputs/outputs to GPU
        Y = Y.to(device)
        X = X.to(device)
        L = L.to(device)

        # Data augmentation: random pitch shift of the audio. torch_audiomentations
        # works on (batch, channels, samples), so add and then drop a channel axis.
        X = pitch_shifter(X.view(X.shape[0], 1, X.shape[1]))[:, 0, :]

        # Compute chroma and encode the curve as a signal N that is added to the audio
        C = get_batch_chroma(X, win_length, hop_length, hann, chroma_filterbank)
        N = encoder(Y, L, C)
        XN = X + N

        # Decode; when noise is enabled, corrupt XN before decoding on about half the batches
        if noise_eps == 0 or np.random.rand() < 0.5:
            XYOut, RGBOut = decoder(XN)
        else:
            if use_mp3_noise:
                added_noise = get_mp3_noise(XN, sr)
            else:
                added_noise = torch.randn(XN.shape).to(XN)
            XYOut, RGBOut = decoder(XN + noise_eps*added_noise)

        # MSS loss keeps XN close to X; the curve losses make the decoder recover Y
        loss_audio = mss_loss(X, XN)
        loss_xy, loss_rgb = curve_losses(Y, XYOut, RGBOut)
        loss = lam_xy*loss_xy + lam_rgb*loss_rgb + loss_audio
        train_xy_loss += loss_xy.item()
        train_rgb_loss += loss_rgb.item()
        train_audio_loss += loss_audio.item()

        # Compute the gradients of the loss with respect to all model parameters and step
        loss.backward()
        optimizer.step()

        # Progress report with an estimate of the time left in this epoch
        dt = time.time()-tic
        rate = dt/(batch_num+1)
        time_left = (rate*n_per_batch - dt)
        minutes = int(np.floor(time_left/60))
        seconds = int(time_left-60*minutes)
        ipd.clear_output(wait=True)
        print("Epoch {}, batch {} of {}, {}m{}s left".format(epoch, batch_num+1, n_per_batch, minutes, seconds), flush=True)

    last_Y = Y
    last_XYOut = XYOut
    last_RGBOut = RGBOut

    audio_losses.append(train_audio_loss/n_per_batch)
    xy_losses.append(train_xy_loss/n_per_batch)
    rgb_losses.append(train_rgb_loss/n_per_batch)

    print("Epoch {}, audio loss {:.3f}, xy loss {:.3f}, rgb loss {:.3f}".format(epoch, audio_losses[-1], xy_losses[-1], rgb_losses[-1]))
    scheduler.step()

    #####################    STEP 2:  TEST     #####################
    curve_loader = DataLoader(curve_test, batch_size=batch_size)
    audio_loader = DataLoader(audio_test, batch_size=batch_size)
    n_test_batches = min(len(curve_loader), len(audio_loader))
    test_xy_loss = 0
    test_rgb_loss = 0
    test_audio_loss = 0
    # Evaluation only: no autograd graph is needed
    with torch.no_grad():
        for Y, (X, L) in zip(curve_loader, audio_loader):
            Y = Y.to(device)
            X = X.to(device)
            L = L.to(device)
            C = get_batch_chroma(X, win_length, hop_length, hann, chroma_filterbank)
            N = encoder(Y, L, C)
            XN = X + N
            if not use_mp3_noise:
                added_noise = torch.randn(XN.shape).to(XN)
                XYOut, RGBOut = decoder(XN + noise_eps*added_noise)
            else:
                # Skip adding mp3 noise for speed
                XYOut, RGBOut = decoder(XN)
            loss_audio = mss_loss(X, XN)
            loss_xy, loss_rgb = curve_losses(Y, XYOut, RGBOut)
            test_xy_loss += loss_xy.item()
            test_rgb_loss += loss_rgb.item()
            test_audio_loss += loss_audio.item()

    last_Y_test = Y
    last_XYOut_test = XYOut
    last_RGBOut_test = RGBOut

    test_audio_losses.append(test_audio_loss/n_test_batches)
    test_xy_losses.append(test_xy_loss/n_test_batches)
    test_rgb_losses.append(test_rgb_loss/n_test_batches)

    #####################    STEP 3:  LAYLA     #####################
    with torch.no_grad():
        N = encoder(YLayla, lsmooth, CLayla)
        XN = xsmooth + N
        added_noise = torch.randn(XN.shape).to(XN)
        XYOut, RGBOut = decoder(XN + noise_eps*added_noise)
        loss_xy_layla, loss_rgb_layla = curve_losses(YLayla, XYOut, RGBOut)
    XYOut = XYOut.cpu().numpy()[0, :, :]
    RGBOut = RGBOut.cpu().numpy()[0, :, :]
    LaylaOut = coeffs_to_image(XYOut, RGBOut, I.shape[0])

    # Save the encoded audio and the added signal as wav, then round-trip the audio through mp3
    wav_path = "Epoch{}.wav".format(epoch)
    mp3_path = "Epoch{}.mp3".format(epoch)
    write_wav_int16(wav_path, XN.cpu().numpy()[0, :], sr)
    write_wav_int16("Epoch{}Noise.wav".format(epoch), N.cpu().numpy()[0, :], sr)
    # -y: overwrite an mp3 left by a previous run; without it ffmpeg prompts, gets no answer,
    # and the stale file would be decoded below. check=True surfaces encoder failures.
    subprocess.run(["ffmpeg", "-y", "-i", wav_path, mp3_path],
                   stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
                   stderr=subprocess.STDOUT, check=True)
    z, _ = librosa.load(mp3_path, sr=sr)
    # NOTE(review): mp3 encoding can pad the signal, so z may be longer than XN; the losses
    # below assume the decoder output still lines up with YLayla -- confirm.
    with torch.no_grad():
        XYOut_mp3, RGBOut_mp3 = decoder(torch.from_numpy(z[None, :]).to(XN))
        loss_xy_layla_mp3, loss_rgb_layla_mp3 = curve_losses(YLayla, XYOut_mp3, RGBOut_mp3)
    XYOut_mp3 = XYOut_mp3.cpu().numpy()[0, :, :]
    RGBOut_mp3 = RGBOut_mp3.cpu().numpy()[0, :, :]
    LaylaOut_mp3 = coeffs_to_image(XYOut_mp3, RGBOut_mp3, I.shape[0])

    #####################    STEP 4: PLOT    #####################
    plt.clf()
    plt.subplot(331)
    plt.imshow(LaylaOut)
    plt.title("Layla Wav Reconstruction")
    plt.subplot(332)
    plt.imshow(LaylaOut_mp3)
    plt.title("Layla Mp3 Reconstruction")

    plt.subplot(333)
    plot_scaled_losses(audio_losses, xy_losses, rgb_losses, "Epoch {}, Train Losses".format(epoch))

    plt.subplot(334)
    scatter_panel(YLayla_np[:, 0], XYOut[:, 0], "Wav X Coord Layla (Loss {:.3f})".format(loss_xy_layla.item()))
    plt.subplot(335)
    scatter_panel(YLayla_np[:, 2], RGBOut[:, 0], "Wav R Coord Layla (Loss {:.3f})".format(loss_rgb_layla.item()))

    plt.subplot(336)
    plot_scaled_losses(test_audio_losses, test_xy_losses, test_rgb_losses, "Test Losses")

    plt.subplot(337)
    scatter_panel(YLayla_np[:, 0], XYOut_mp3[:, 0], "Mp3 X Coord Layla (Loss {:.3f})".format(loss_xy_layla_mp3.item()))
    plt.subplot(338)
    scatter_panel(YLayla_np[:, 2], RGBOut_mp3[:, 0], "Mp3 R Coord Layla (Loss {:.3f})".format(loss_rgb_layla_mp3.item()))

    plt.subplot(339)
    textstr = "Epoch {}\n\nlam_xy = {}\nlam_rgb = {}\npre_scale={}\nn_units={}\nn_taps={}\nnoise_eps={}".format(epoch, lam_xy, lam_rgb, pre_scale, n_units, n_taps, noise_eps)
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax = plt.gca()
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=14,
            verticalalignment='top', bbox=props)
    plt.axis("off")

    plt.savefig("Epoch{}.png".format(epoch), bbox_inches='tight')

    epoch += 1