In [1]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display
import numpy as np
import torch
import matplotlib.pyplot as plt
import scipy.signal as signal
from helpers import audio
from nn_modules import nn_proc

In [2]:
# Shorthand for NumPy's uniform sampler on the half-open interval [0, 1).
randfunc = np.random.rand

def synth_input_sample(t, chooser):
    """Synthesize one test input signal on the time grid ``t``.

    Parameters
    ----------
    t : numpy.ndarray
        Sample times handed to the ``helpers.audio`` generators;
        ``t.shape[0]`` is the number of noise samples drawn for the
        noisy variants.
    chooser : str
        One of ``'sine'``, ``'box'``, ``'noisysine'``, ``'noisybox'``
        or ``'pluck'``.

    Returns
    -------
    Whatever the selected ``helpers.audio`` generator returns, combined
    with uniform noise for the two noisy variants.

    Raises
    ------
    ValueError
        If ``chooser`` is not one of the names above. Without this check
        the function would return ``None`` and the failure would only
        surface later, far from the typo that caused it.
    """
    if chooser == 'sine':
        return audio.randsine(t)
    if chooser == 'box':
        return audio.box(t)
    if chooser == 'noisysine':
        # additive noise, uniform in [-0.1, 0.1)
        return audio.randsine(t) + 0.1*(2*np.random.rand(t.shape[0])-1)
    if chooser == 'noisybox':
        # multiplicative: each sample of the box is scaled by noise uniform in [-1, 1)
        return audio.box(t) * (2*np.random.rand(t.shape[0])-1)
    if chooser == 'pluck':
        return audio.pluck(t)
    raise ValueError(
        "unknown signal type %r; expected one of "
        "'sine', 'box', 'noisysine', 'noisybox', 'pluck'" % (chooser,)
    )
   
# Number of samples in every synthetic test signal.
# NOTE(review): this equals 8192 // shrink_factor (time_series_length) in the
# model-setup cell; presumably the two have to stay in sync -- confirm.
N_samples = 4096
t = np.linspace(0,1,N_samples)  # N_samples evenly spaced points on [0, 1], both ends included
old_signal_type = 'box'  # type of the signal currently held in x; demowidget only regenerates x when its selection differs
x = synth_input_sample(t,old_signal_type)

device = torch.device('cpu')
# NOTE: torch.set_default_tensor_type is deprecated in recent PyTorch releases
# (torch.set_default_dtype is the documented successor); kept here so the
# module-level side effect stays exactly as before.
torch.set_default_tensor_type('torch.FloatTensor')
# Plain tensors replace the deprecated torch.autograd.Variable wrapper: a
# tensor built by from_numpy already has requires_grad=False.
x_torch = torch.from_numpy(x).to(device).float()


In [3]:
# [min, max] limits of the three compressor controls, one row per knob.
# Row order is threshold, ratio, attack/release -- the order in which
# demowidget indexes this array.
knobranges = np.array([
    (-30, 0),
    (1, 5),
    (10, 2048),
])

# Data settings
shrink_factor = 2  # reduce dimensionality of run by this factor
time_series_length = 8192 // shrink_factor
sampling_freq = 44100. // shrink_factor  # float operand, so the result stays a float

# Analysis parameters
ft_size = 1024 // shrink_factor
hop_size = 384 // shrink_factor
# ceil(time_series_length / hop_size) + ceil(ft_size / hop_size), as an int
expected_time_frames = int(sum(
    np.ceil(n_samples / float(hop_size))
    for n_samples in (time_series_length, ft_size)
))

# Define model
model = nn_proc.MPAEC(expected_time_frames, ft_size=ft_size, hop_size=hop_size)

# Load model weights
# torch.load unpickles the file, so only point this at a checkpoint you trust.
checkpoint_file = 'modelcheckpoint.tar'
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['state_dict'])


In [6]:
def _normalize_knobs(knob_values, ranges):
    """Rescale knob settings from their [min, max] rows in ``ranges`` to [-0.5, 0.5]."""
    ranges = np.asarray(ranges, dtype=float)  # float math regardless of the integer dtype of the ranges
    lo, hi = ranges[:, 0], ranges[:, 1]
    return (np.asarray(knob_values, dtype=float) - lo) / (hi - lo) - 0.5


# Define interactive widgets and their handler routine
@interact(signal_type=['box','sine','pluck','noisybox','noisysine'],
          threshold=(knobranges[0][0],knobranges[0][1],1),
          ratio=(knobranges[1][0],knobranges[1][1],0.1),
          attackrelease=(knobranges[2][0],knobranges[2][1],50))
def demowidget(signal_type, threshold, ratio, attackrelease):
    """Plot the input, the reference compressor output and the network's prediction.

    ipywidgets calls this every time one of the controls changes.

    Parameters
    ----------
    signal_type : str
        Name of the synthetic input, forwarded to ``synth_input_sample``.
    threshold, ratio, attackrelease : number
        Compressor settings in the units of ``knobranges``. They are passed
        unchanged to ``audio.compressor`` and, rescaled to [-0.5, 0.5], to
        the network.

    Side effects
    ------------
    Rebinds the module-level ``old_signal_type``, ``x`` and ``x_torch``,
    reloads the checkpoint into ``model`` and draws a matplotlib figure.
    """
    global old_signal_type, x, x_torch

    # update the model: the checkpoint is re-read on every interaction,
    # presumably so a file rewritten on disk (e.g. by a training run in
    # progress) is picked up -- confirm before caching it.
    # torch.load unpickles the file, so only use a checkpoint you trust.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    # NOTE(review): the model is never put into eval() mode. If MPAEC contains
    # dropout or batch-norm layers the prediction below runs in training
    # mode -- verify against nn_proc.

    if signal_type != old_signal_type:  # don't regen x unless input changed
        x = synth_input_sample(t, signal_type)
        x_torch = torch.from_numpy(x).to(device).float()
    old_signal_type = signal_type

    y_true = audio.compressor(x, threshold, ratio, attackrelease)

    knobs = _normalize_knobs([threshold, ratio, attackrelease], knobranges)
    knobs_torch = torch.from_numpy(knobs).to(device).float()

    # Inference only: skip building the autograd graph. unsqueeze(0) adds a
    # leading axis of size 1 to both inputs.
    with torch.no_grad():
        y_pred, _mag, _mag_hat = model(x_torch.unsqueeze(0), knobs_torch.unsqueeze(0))

    plt.figure(figsize=(8,5))
    plt.plot(t,x,c='b',lw=1.5, label='Input')
    plt.plot(t,y_true,c='r',lw=1.5, label='Target')
    plt.plot(t,y_pred.squeeze(0).detach().cpu().numpy(),c=(0,0.5,0,0.75),lw=1.5, label='Predicted')

    # show threshold line: 10**(threshold/20) turns the knob value, treated
    # here as decibels, into a linear amplitude
    thresh_line = 10**(threshold/20.0)*np.ones(2)
    plt.plot([t[0],t[-1]],thresh_line,c='k',lw=1, linestyle='dashed', label='Threshold')
    plt.legend(loc='lower right')
    plt.ylim(-1,1)
    plt.show()

A Jupyter Widget