In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch_audiomentations import Compose, PitchShift
from scipy import signal
import glob
import librosa
import time
import pickle
import subprocess
from scipy.io import wavfile

import skimage.io
import sys
sys.path.append("../src")
from imtools import get_voronoi_image, splat_voronoi_image_1nn
from audioutils import extract_loudness, get_mp3_noise, get_batch_stft_noise
from tsp import get_tsp_tour
sys.path.append("../models")
from networks import CurveSTFTEncoder, CurveDecoder, BinaryDecoder
from losses import mss_loss
sys.path.append("../data")
from dataset import CurveData, AudioData

## Datasets

In [None]:
# Select the compute device: CUDA when it is available, the CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device: ", device)

# Audio framing parameters: the hop is half of the analysis window
sr = 44100
win_length = 1024
hop_length = win_length//2

# Sizes handed to the dataset constructors below
n_samples, T, samples_per_batch = 3000, 300, 10000

## Setup hann window for encoder
hann = torch.hann_window(win_length).to(device)

# Curve datasets; the first argument is np.arange(10) for training and
# np.arange(10, 15) for testing (presumably curve/image indices -- confirm in CurveData)
curve_train, curve_test = (
    CurveData(curve_idxs, n_samples, T, samples_per_batch)
    for curve_idxs in (np.arange(10), np.arange(10, 15))
)

# Audio datasets built from the glob patterns of the musdb18hq train/test mixtures
audio_train, audio_test = (
    AudioData(wav_pattern, T, samples_per_batch, win_length, sr)
    for wav_pattern in ("../data/musdb18hq/train/*/mixture.wav",
                        "../data/musdb18hq/test/*/mixture.wav")
)

# Test Example - Layla And Smooth Criminal

In [None]:
# Build the Layla test curve: load the image, fit n_points Voronoi sites to it
# on the CPU, then reorder the sites with get_tsp_tour
# (presumably a travelling-salesman ordering -- see ../src/tsp.py)
n_points, n_iters = 3000, 100
I = skimage.io.imread("../data/images/layla.png")
N = min(I.shape[0], I.shape[1])  # Smaller of the first two image dimensions
J, YLayla, final_cost = get_voronoi_image(
    I, 'cpu', n_points,
    n_neighbs=2,
    n_iters=n_iters,
    verbose=False,
    plot_iter_interval=0,
    use_lsqr=False,
)
YLayla = get_tsp_tour(YLayla)

In [None]:
# Curve target: float32 tensor on the device, reshaped to (1, TLayla, 5)
YLayla = np.array(YLayla, dtype=np.float32)
TLayla = len(YLayla)
YLayla = torch.from_numpy(YLayla).view(1, TLayla, 5).to(device)

# Carrier audio for the test example
xsmooth, _ = librosa.load("../data/tunes/smoothcriminal.mp3", sr=sr)
#xsmooth, _ = librosa.load("../data/musdb18hq/test/Arise - Run Run Run/mixture.wav", sr=sr)
xsmooth = xsmooth[15*sr:] # Cut off quieter beginning

# Loudness curve, truncated to one value per curve point
lsmooth = extract_loudness(xsmooth, sr, hop_length, win_length)
lsmooth = np.array(lsmooth[:TLayla], dtype=np.float32)
lsmooth = torch.from_numpy(lsmooth)[None, :, None].to(device)

# Keep exactly the samples spanned by TLayla windows at the given hop
xsmooth = np.array(xsmooth[:hop_length*(TLayla-1)+win_length], dtype=np.float32)
xsmooth = torch.from_numpy(xsmooth).unsqueeze(0).to(device)

print(xsmooth.shape, lsmooth.shape)
plt.imshow(splat_voronoi_image_1nn(YLayla.detach().cpu()[0], 512, 512))

# Train Loop

In [None]:
## Parameters for encoder/decoder
mlp_depth = 3
n_units = 512
noise_eps = 0.2       # Scale of the noise added to the encoded audio before decoding
use_mp3_noise = True  # Use get_mp3_noise instead of Gaussian noise during training
n_taps = 10
max_lag = 512
tap_amp = 0.02
tap_sigma = 5

lam_xy = 150 # Weight for geometry curve fit
lam_rgb = 50 # Weight for rgb curve fit
lam_taps = 1 # Weight for change of taps

encoder = CurveSTFTEncoder(mlp_depth, n_units, win_length, n_taps, max_lag, tap_amp, tap_sigma, 5)
encoder = encoder.to(device)
decoder = CurveDecoder(mlp_depth, n_units, win_length, 5)
decoder = decoder.to(device)
print("Encoder params", encoder.get_num_parameters())
print("Decoder params", decoder.get_num_parameters())

## Setup the optimizer and learning rate schedule
batch_size = 16
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
# Start at a learning rate of 1e-4 and decay it by a factor of 0.98 every
# 10000/len(curve_train) epochs (the scheduler is stepped once per epoch).
# StepLR decides when to decay with an integer modulo test, so step_size must
# be a positive integer: a float such as 3.33 would silently never decay
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=max(1, round(10000/len(curve_train))),
    gamma=0.98,
)

## Per-epoch loss histories (train, then test); appended to by the training loop
n_epochs = 1000
xy_losses = []
rgb_losses = []
taps_losses = []
audio_losses = []

test_taps_losses = []
test_xy_losses = []
test_rgb_losses = []
test_audio_losses = []

## Setup data augmentation
pitch_shifter = PitchShift(sample_rate=sr)

In [None]:
def _curve_fit_losses(Y_pred, Y_true, w_xy=1, w_rgb=1):
    """
    Mean absolute error between a decoded curve and its target, split into
    the first two channels (xy) and the remaining channels (rgb)

    Parameters
    ----------
    Y_pred, Y_true: torch tensors indexed as [batch, time, channel]
    w_xy, w_rgb: Scalar weights applied to the xy / rgb terms

    Returns
    -------
    (loss_xy, loss_rgb): Two scalar tensors
    """
    loss_xy  = w_xy *torch.mean(torch.abs(Y_pred[:, :, 0:2]-Y_true[:, :, 0:2]))
    loss_rgb = w_rgb*torch.mean(torch.abs(Y_pred[:, :, 2::]-Y_true[:, :, 2::]))
    return loss_xy, loss_rgb


def _taps_variation_loss(taps):
    """
    lam_taps times the mean (over batch and time) of the summed absolute
    change of the taps between consecutive steps along axis 1
    """
    diff = taps[:, 1::, :] - taps[:, 0:-1, :]
    return lam_taps*torch.mean(torch.sum(torch.abs(diff), dim=2))


def _to_int16_pcm(x):
    """
    Peak-normalize a float signal and quantize it to int16.  Normalizing by the
    largest magnitude (not the largest value) and scaling by 32767 keeps every
    sample inside the int16 range, so nothing wraps around
    """
    peak = np.max(np.abs(x)) if x.size > 0 else 0
    if peak > 0:
        x = x/peak
    return np.array(x*32767, dtype=np.int16)


fac = 0.7
plt.figure(figsize=(fac*18, fac*24))

epoch = 0
while epoch < n_epochs:
    #####################    STEP 1:  TRAIN     #####################
    curve_loader = DataLoader(curve_train, batch_size=batch_size, shuffle=True)
    audio_loader = DataLoader(audio_train, batch_size=batch_size, shuffle=True)
    # zip() below stops at the shorter loader, so this is the real batch count
    n_train_batches = min(len(curve_loader), len(audio_loader))

    train_xy_loss = 0
    train_rgb_loss = 0
    train_audio_loss = 0
    train_taps_loss = 0
    tic = time.time()
    for batch_num, (Y, (X, L)) in enumerate(zip(curve_loader, audio_loader)): # Go through each mini batch
        # Reset the optimizer's gradients
        optimizer.zero_grad()

        # Move inputs/outputs to GPU
        Y = Y.to(device)
        X = X.to(device)

        # Do data augmentation on the audio samples X
        # (the pitch shifter is given a singleton channel axis, removed afterwards)
        X = pitch_shifter(X.view(X.shape[0], 1, X.shape[1]))[:, 0, :]

        # Encode the curve
        N, taps = encoder(X, Y)

        # Decode either the clean encoded audio or, half of the time when
        # noise is enabled, the encoded audio plus scaled noise
        XN = N
        if noise_eps == 0 or np.random.rand() < 0.5:
            YOut = decoder(XN)
        else:
            if use_mp3_noise:
                added_noise = get_mp3_noise(XN, sr)
            else:
                added_noise = torch.randn(XN.shape).to(XN)
            YOut = decoder(XN + noise_eps*added_noise)

        # Loss terms
        loss_audio = mss_loss(X, XN)
        loss_xy, loss_rgb = _curve_fit_losses(YOut, Y, lam_xy, lam_rgb)
        ## Minimize the amount by which the taps change from step to step
        loss_taps = _taps_variation_loss(taps)

        # The taps term is tracked for plotting but is currently not optimized
        loss = loss_xy + loss_rgb + loss_audio #+ loss_taps
        train_xy_loss += loss_xy.item()
        train_rgb_loss += loss_rgb.item()
        train_audio_loss += loss_audio.item()
        train_taps_loss += loss_taps.item()

        # Compute the gradients of the loss function with respect
        # to all of the parameters of the model
        loss.backward()
        optimizer.step()

        # Progress report with a linear estimate of the time left in this epoch
        dt = time.time()-tic
        rate = dt/(batch_num+1)
        time_left = (rate*n_train_batches - dt)
        minutes = int(np.floor(time_left/60))
        seconds = int(time_left-60*minutes)
        ipd.clear_output(wait=True)
        print("Epoch {}, batch {} of {}, {}m{}s left".format(epoch, batch_num+1, n_train_batches, minutes, seconds), flush=True)

    last_Y = Y
    last_YOut = YOut

    audio_losses.append(train_audio_loss/n_train_batches)
    xy_losses.append(train_xy_loss/n_train_batches)
    rgb_losses.append(train_rgb_loss/n_train_batches)
    taps_losses.append(train_taps_loss/n_train_batches)

    print("Epoch {}, audio loss {:.3f}, xy loss {:.3f}, rgb loss {:.3f}".format(epoch, audio_losses[-1], xy_losses[-1], rgb_losses[-1]))
    scheduler.step()

    #####################    STEP 2:  TEST     #####################
    curve_loader = DataLoader(curve_test, batch_size=batch_size)
    audio_loader = DataLoader(audio_test, batch_size=batch_size)
    n_test_batches = min(len(curve_loader), len(audio_loader))
    test_xy_loss = 0
    test_rgb_loss = 0
    test_audio_loss = 0
    test_taps_loss = 0
    # No gradients are needed for evaluation; skipping the graph saves memory
    with torch.no_grad():
        for batch_num, (Y, (X, L)) in enumerate(zip(curve_loader, audio_loader)): # Go through each mini batch
            # Move inputs/outputs to GPU
            Y = Y.to(device)
            X = X.to(device)
            # Encode the curve
            N, taps = encoder(X, Y)
            # Add filtered noise to audio and decode
            XN = N
            if not use_mp3_noise:
                added_noise = torch.randn(XN.shape).to(XN)
                YOut = decoder(XN + noise_eps*added_noise)
            else:
                # Skip adding mp3 noise for speed
                YOut = decoder(XN)
            # Loss terms
            loss_audio = mss_loss(X, XN)
            loss_xy, loss_rgb = _curve_fit_losses(YOut, Y, lam_xy, lam_rgb)
            # Taps of *this* test batch (not the ones left over from training)
            loss_taps = _taps_variation_loss(taps)

            test_xy_loss += loss_xy.item()
            test_rgb_loss += loss_rgb.item()
            test_audio_loss += loss_audio.item()
            test_taps_loss += loss_taps.item()

    last_Y_test = Y
    last_YOut_test = YOut

    test_audio_losses.append(test_audio_loss/n_test_batches)
    test_xy_losses.append(test_xy_loss/n_test_batches)
    test_rgb_losses.append(test_rgb_loss/n_test_batches)
    test_taps_losses.append(test_taps_loss/n_test_batches)


    #####################    STEP 3:  LAYLA     #####################
    with torch.no_grad():
        N, taps = encoder(xsmooth, YLayla)
        XN = N
        added_noise = torch.randn(XN.shape).to(XN)
        YOut = decoder(XN + noise_eps*added_noise)
        loss_layla_xy, loss_layla_rgb = _curve_fit_losses(YOut, YLayla)
    loss_layla_xy = loss_layla_xy.item()
    loss_layla_rgb = loss_layla_rgb.item()
    YOut = YOut.detach().cpu()[0, :, :].numpy()

    # Save as an mp3 and repeat
    x = _to_int16_pcm(XN.detach().cpu().numpy()[0, :])
    wavfile.write("Epoch{}.wav".format(epoch), sr, x)
    # N and XN refer to the same tensor, so this file currently duplicates the
    # one above; it is still written so the set of output files is unchanged
    wavfile.write("Epoch{}Noise.wav".format(epoch), sr, x)

    # -y overwrites an mp3 left over from a previous run instead of prompting
    # (which would leave the stale file to be loaded below)
    subprocess.call(["ffmpeg", "-y", "-i", "Epoch{}.wav".format(epoch), "Epoch{}.mp3".format(epoch)],
                    stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
    z, _ = librosa.load("Epoch{}.mp3".format(epoch), sr=sr)
    # The mp3 round trip may change the number of samples; trim / zero-pad to
    # the length of the wav so the decoded curve lines up with YLayla
    n_audio = XN.shape[-1]
    if len(z) >= n_audio:
        z = z[0:n_audio]
    else:
        z = np.pad(z, (0, n_audio-len(z)))
    z = torch.from_numpy(z[None, :]).to(XN)
    with torch.no_grad():
        YOut_mp3 = decoder(z)
        loss_layla_xy_mp3, loss_layla_rgb_mp3 = _curve_fit_losses(YOut_mp3, YLayla)
    loss_layla_xy_mp3 = loss_layla_xy_mp3.item()
    loss_layla_rgb_mp3 = loss_layla_rgb_mp3.item()
    YOut_mp3 = YOut_mp3.detach().cpu()[0, :, :].numpy()



    #####################    STEP 4: PLOT    #####################
    YLayla_np = YLayla.detach().cpu().numpy()[0, :, :]
    taps_np = taps.detach().cpu().numpy()[0, :, :]

    plt.clf()
    plt.subplot(431)
    plt.imshow(splat_voronoi_image_1nn(YOut, 512, 512))
    plt.title("Layla Wav Reconstruction")
    plt.subplot(432)
    plt.imshow(splat_voronoi_image_1nn(YOut_mp3, 512, 512))
    plt.title("Layla Mp3 Reconstruction")

    plt.subplot(433)
    plt.plot(audio_losses)
    plt.plot(xy_losses)
    plt.plot(rgb_losses)
    plt.plot(taps_losses)
    plt.legend(["Audio ({:.3f})".format(audio_losses[-1]),
                "xy ({:.3f})".format(xy_losses[-1]),
                "rgb ({:.3f})".format(rgb_losses[-1]),
                "taps ({:.3f})".format(taps_losses[-1])])
    plt.xlabel("Epoch")
    plt.ylabel("Scaled Loss")
    plt.title("Epoch {}, Train Losses".format(epoch))


    plt.subplot(434)
    plt.scatter(YLayla_np[:, 0], YOut[:, 0])
    plt.axis("equal")
    plt.title("Wav X Coord Layla (Mean Error {:.3f})".format(loss_layla_xy))
    plt.subplot(435)
    plt.scatter(YLayla_np[:, 2], YOut[:, 2])
    plt.axis("equal")
    plt.title("Wav R Coord Layla (Mean Error {:.3f})".format(loss_layla_rgb))

    plt.subplot(436)
    plt.plot(test_audio_losses)
    plt.plot(test_xy_losses)
    plt.plot(test_rgb_losses)
    plt.plot(test_taps_losses)
    plt.legend(["Audio ({:.3f})".format(test_audio_losses[-1]),
                "xy ({:.3f})".format(test_xy_losses[-1]),
                "rgb ({:.3f})".format(test_rgb_losses[-1]),
                "taps ({:.3f})".format(test_taps_losses[-1])])
    plt.xlabel("Epoch")
    plt.ylabel("Scaled Loss")
    plt.title("Test Losses")

    plt.subplot(437)
    plt.scatter(YLayla_np[:, 0], YOut_mp3[:, 0])
    plt.axis("equal")
    plt.title("Mp3 X Coord Layla (Mean Error {:.3f})".format(loss_layla_xy_mp3))
    plt.subplot(438)
    plt.scatter(YLayla_np[:, 2], YOut_mp3[:, 2])
    plt.axis("equal")
    plt.title("Mp3 R Coord Layla (Mean Error {:.3f})".format(loss_layla_rgb_mp3))

    plt.subplot(439)
    textstr = "Epoch {}\n\nlam_xy = {}\nlam_rgb = {}\nlam_taps = {}\nn_units={}\nn_taps={}\nmax_lag={}\ntap_amp={}\ntap_sigma={}".format(epoch, lam_xy, lam_rgb, lam_taps, n_units, n_taps, max_lag, tap_amp, tap_sigma)
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax = plt.gca()
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=14,
            verticalalignment='top', bbox=props)
    plt.axis("off")

    # Bottom row: the Layla taps as an image, then three individual rows of it
    plt.subplot2grid((4, 3), (3, 0))
    plt.imshow(taps_np[:, 1::], aspect='auto', interpolation='none', cmap='magma_r')
    plt.subplot2grid((4, 3), (3, 1), colspan=2)
    plt.plot(taps_np[0, 1::])
    plt.plot(taps_np[2, 1::])
    plt.plot(taps_np[3, 1::])

    plt.savefig("Epoch{}.png".format(epoch), bbox_inches='tight')

    epoch += 1

In [None]:
# Write the encoder's state_dict to encoder.pkl in the current working directory
torch.save(obj=encoder.state_dict(), f="encoder.pkl")

In [None]:
# Write the decoder's state_dict to decoder.pkl in the current working directory
torch.save(obj=decoder.state_dict(), f="decoder.pkl")