In [1]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display

import numpy as np
import torch
import matplotlib.pyplot as plt
import scipy.signal as signal
from helpers import audio
#%matplotlib inline  # not necessary

In [2]:
randfunc=np.random.rand   # alias kept for compatibility; not referenced by the cells visible here

def synth_input_sample(t, chooser):
    """
    Synthesize one demo input signal of the requested kind.

    Inputs:
       t:        1-D numpy array of sample positions; t.shape[0] sets the
                 length of the noise used by the 'noisy*' variants
       chooser:  one of 'sine', 'box', 'noisysine', 'noisybox', 'pluck'
    Returns:
       the signal produced by the matching helpers.audio generator
       (randsine, box or pluck), optionally combined with uniform noise
    Raises:
       ValueError if chooser is not one of the names above (previously the
       function silently returned None, which only failed later on)
    """
    if 'sine' == chooser:
        return audio.randsine(t)
    elif 'box' == chooser:
        return audio.box(t)
    elif 'noisysine' == chooser:
        # additive uniform noise in [-0.1, 0.1)
        return audio.randsine(t) + 0.1*(2*np.random.rand(t.shape[0])-1)
    elif 'noisybox' == chooser:
        # multiplicative uniform noise in [-1, 1): the box acts as an envelope
        return audio.box(t) * (2*np.random.rand(t.shape[0])-1)
    elif 'pluck' == chooser:
        return audio.pluck(t)
    raise ValueError(
        "unknown signal type %r; expected one of "
        "'sine', 'box', 'noisysine', 'noisybox', 'pluck'" % (chooser,))
   
# Initial demo signal: N_samples points evenly spaced over [0, 1].
N_samples = 4096
t = np.linspace(0,1,N_samples)
old_sig_type = 'box'   # remembers which signal type x currently holds
x = synth_input_sample(t,old_sig_type)

# Everything in this notebook runs on the CPU in single precision.
device = torch.device('cpu')
# set_default_dtype replaces the deprecated set_default_tensor_type('torch.FloatTensor')
torch.set_default_dtype(torch.float32)
# torch.autograd.Variable is a deprecated no-op wrapper; a tensor created by
# from_numpy already has requires_grad=False.
x_torch = torch.from_numpy(x).to(device).float()


In [3]:
# Real-world [min, max] range of each compressor knob, one row per knob,
# in the order the knobs are passed around: thresh, ratio, attack.
knobranges = np.array([
    [-30,   0],   # thresh
    [  1,   5],   # ratio
    [ 10, 800],   # attack
])

def compressor(x, thresh=-24, ratio=2, attack=2048, dtype=np.float32):
    """
    simple compressor effect, code thanks to Eric Tarr @hackaudio
    Inputs:
       x:        the input waveform (1-D array-like)
       thresh:   threshold in dB
       ratio:    compression ratio (must be non-zero)
       attack:   attack & release time (it's a simple compressor!) in samples;
                 1/attack is used as the normalised cutoff of a digital
                 Butterworth filter, which scipy expects to lie strictly
                 between 0 and 1
       dtype:    numpy dtype of the intermediate envelopes and of the result
    Returns:
       the compressed waveform as an array of `dtype`, same length as x;
       an empty input yields an empty array
    Raises:
       ValueError if ratio is zero
    """
    if ratio == 0:
        # would otherwise divide by zero below and silently yield inf/NaN gain
        raise ValueError("ratio must be non-zero")
    x = np.asarray(x)
    if x.size == 0:
        # nothing to filter; dB[0] below would raise IndexError
        return x.astype(dtype)

    fc = 1.0/float(attack)               # this is like 1/attack time
    b, a = signal.butter(1, fc, analog=False, output='ba')
    zi = signal.lfilter_zi(b, a)
    dB = 20. * np.log10(np.abs(x) + 1e-6).astype(dtype)   # 1e-6 avoids log10(0)
    # scaling zi by the first sample starts the filter in its steady state
    in_env, _ = signal.lfilter(b, a, dB, zi=zi*dB[0])  # input envelope calculation
    out_env = np.copy(in_env).astype(dtype)               # output envelope
    above = in_env > thresh                 # compress where input env exceeds thresh
    out_env[above] = thresh + (in_env[above]-thresh)/ratio
    # NOTE(review): the level was computed as 20*log10(|x|), so an amplitude gain
    # would conventionally be 10**(dB/20). The /10 is kept unchanged here because
    # the saved model was presumably trained against this exact target -- confirm.
    gain = np.power(10.0,(out_env-in_env)/10).astype(dtype)
    return (x * gain).astype(dtype)


In [4]:
from nn_modules import cls_fe_dft, nn_proc
from helpers import audio
from losses import loss_functions
import torch.nn as nn

class MPAEC(nn.Module):  # mag-phase autoencoder
    """
        Magnitude/phase autoencoder.

        The input is passed through an analysis transform that yields a real
        and an imaginary part, converted to magnitude and phase, each of which
        is processed by its own knob-conditioned autoencoder, and then
        converted back and passed through the synthesis transform.
    """
    def __init__(self, expected_time_frames, ft_size=1024, hop_size=384, decomposition_rank=25):
        """
        Inputs:
           expected_time_frames: first argument given to both autoencoders
           ft_size, hop_size:    forwarded to the analysis and synthesis transforms
           decomposition_rank:   second argument of the magnitude autoencoder
                                 (the phase autoencoder always uses 2)
        """
        super(MPAEC, self).__init__()
        transform_args = dict(ft_size=ft_size, hop_size=hop_size)
        # sub-modules are registered in the same order as before so that the
        # parameter / state_dict ordering is unchanged
        self.dft_analysis = cls_fe_dft.Analysis(**transform_args)
        self.dft_synthesis = cls_fe_dft.Synthesis(**transform_args)
        self.aenc = nn_proc.AutoEncoder(expected_time_frames, decomposition_rank)
        self.phs_aenc = nn_proc.AutoEncoder(expected_time_frames, 2)


    def clip_grad_norm_(self):
        """
        Clip the L1 gradient norm of the analysis and synthesis parameters
        to 1. The autoencoder parameters are not included.
        """
        transform_params = list(self.dft_analysis.parameters())
        transform_params = transform_params + list(self.dft_synthesis.parameters())
        torch.nn.utils.clip_grad_norm_(transform_params, max_norm=1., norm_type=1)

    def forward(self, x_cuda, knobs_cuda):
        """
        Inputs:
           x_cuda:     input signal tensor (this notebook passes CPU tensors,
                       despite the name)
           knobs_cuda: knob-settings tensor handed to both autoencoders
        Returns:
           (x_hat, mag, mag_hat): the resynthesised signal, the magnitude
           before the autoencoder, and the magnitude after it
        """
        # analysis transform (a trainable STFT, per the original author's note)
        re_part, im_part = self.dft_analysis.forward(x_cuda)

        # magnitude = 2-norm over the (real, imag) pair; phase via atan2
        re_im = torch.cat((re_part.unsqueeze(0), im_part.unsqueeze(0)), 0)
        mag = torch.norm(re_im, 2, dim=0)
        phs = torch.atan2(im_part, re_part+1e-6)

        # magnitude and phase are processed independently
        mag_hat = self.aenc.forward(mag, knobs_cuda, skip_connections='sf')
        # the phase autoencoder output is added to the input phase
        # (original author's note: slightly smoother convergence)
        phs_delta = self.phs_aenc.forward(phs, knobs_cuda, skip_connections=False)
        phs_hat = phs_delta + phs

        # polar -> rectangular, then synthesis
        x_hat = self.dft_synthesis.forward(mag_hat * torch.cos(phs_hat),
                                           mag_hat * torch.sin(phs_hat))

        return x_hat, mag, mag_hat

    
# Data settings
shrink_factor = 2  # reduce dimensionality of run by this factor
time_series_length = 8192 // shrink_factor
sampling_freq = 44100. // shrink_factor

# Analysis parameters
ft_size = 1024 // shrink_factor
hop_size = 384 // shrink_factor
# number of hops needed to cover the series plus those needed to cover one transform window
expected_time_frames = int(np.ceil(time_series_length/float(hop_size)) + np.ceil(ft_size/float(hop_size)))
decomposition_rank = 25


model = MPAEC(expected_time_frames, ft_size=ft_size, hop_size=hop_size, decomposition_rank=decomposition_rank)
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load('modelcheckpoint.tar', map_location=device)
model.load_state_dict(checkpoint['state_dict'])
# The model is only used for inference in this notebook. eval() puts any
# dropout / batch-norm layers into inference mode (no effect if there are none).
model.eval()


In [9]:
@interact(sig_type=['box','sine','pluck','noisybox','noisysine'],
          thresh=(-0.5,0.5,0.1), ratio=(-0.5,0.5,0.1), attackrelease=(-0.5,0.5,0.1))
def demowidget(sig_type, thresh, ratio, attackrelease):
    """
    Interactive comparison of the reference compressor and the model.

    Inputs (all set by the widget controls):
       sig_type:       which synthetic input signal to use
       thresh, ratio, attackrelease:
                       normalised knob positions in [-0.5, 0.5]; they are fed
                       to the model as-is and rescaled through `knobranges`
                       for the reference compressor
    Side effects:
       reloads the checkpoint from disk on every call, regenerates the global
       input signal when sig_type changes, and draws a matplotlib figure.
    """
    global old_sig_type, x, x_torch

    # update the model: re-read the checkpoint so the latest saved weights are used
    checkpoint = torch.load('modelcheckpoint.tar', map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    if sig_type != old_sig_type: # don't regen x unless input changed
        x = synth_input_sample(t, sig_type)
        x_torch = torch.from_numpy(x).to(device).float()
    old_sig_type = sig_type

    # map the normalised knobs from [-0.5, 0.5] onto their real-world ranges
    knobs = np.array([thresh, ratio, attackrelease])
    range_lo, range_hi = knobranges[:, 0], knobranges[:, 1]
    thresh_w, ratio_w, attackrelease_w = range_lo + (knobs + 0.5) * (range_hi - range_lo)

    y_true = compressor(x, thresh_w, ratio_w, attackrelease_w)

    knobs_torch = torch.from_numpy(knobs).to(device).float()
    # inference only: skip building the autograd graph on every slider move
    with torch.no_grad():
        y_pred, _, _ = model.forward(x_torch.unsqueeze(0), knobs_torch.unsqueeze(0))

    plt.figure(figsize=(8,5))
    plt.plot(t,x,c='b',lw=1.5, label='Input')
    plt.plot(t,y_true,c='r',lw=1.5, label='Target')
    plt.plot(t,y_pred.squeeze(0).detach().cpu().numpy(),c=(0,0.5,0,0.75),lw=1.5, label='Predicted')

    thresh_line = 10**(thresh_w/20.0)*np.ones(2) # show threshold line (dB -> linear amplitude)
    plt.plot([t[0],t[-1]],thresh_line,c='k',lw=1, linestyle='dashed', label='Threshold')
    plt.legend(loc='lower right')
    plt.ylim(-1,1)
    plt.show()

A Jupyter Widget