In [78]:
import soundfile as snd
import numpy as np
import torch
import torch.nn as nn
from poutyne.framework import Model
import plotly.graph_objects as go
from pathlib import Path

# Paired recordings, judging by the file names: the dry guitar signal (x) and
# the amp-processed signal (y) that the network is trained to reproduce.
fold = Path.cwd() / '..' / 'data'
x, x_rate = snd.read(fold / 'Raw Guitar-1.flac')
y, y_rate = snd.read(fold / 'Orange Amp Heavy-1.flac')
# The model maps input samples to output samples one-to-one, so the two files
# must share a sample rate; fail early instead of training on mismatched audio.
assert x_rate == y_rate, f'sample-rate mismatch: {x_rate} Hz vs {y_rate} Hz'

In [2]:
def audio_plot(x, n=1000):
    """Build a filled min/max envelope plot of a signal using roughly ``n`` points.

    The signal is decimated with a stride of ``m = max(1, len(x) // n)``
    samples. A running maximum and a running minimum over a window of ``m``
    samples are each sampled once per stride, and the band between the two
    curves is filled.

    Args:
        x: Array of samples (expected to be 1-D; multi-channel input has not
            been considered here -- TODO confirm).
        n: Approximate number of points drawn per curve.

    Returns:
        A list of two ``go.Scatter`` traces: the lower envelope, then the
        upper envelope filled down to the previous trace.
    """
    # Use the public scipy.ndimage namespace; scipy.ndimage.morphology is a
    # deprecated alias that warns on import.
    from scipy.ndimage import grey_dilation, grey_erosion

    m = max(1, len(x) // n)
    xmax = grey_dilation(x, size=m)[::m]
    xmin = grey_erosion(x, size=m)[::m]
    rng = np.arange(0, len(x), m)
    return [
        go.Scatter(x=rng, y=xmin, mode='lines', fill=None, line_color='indigo'),
        go.Scatter(x=rng, y=xmax, mode='lines', fill='tonexty', line_color='indigo'),
    ]

In [3]:
# Envelope plot of x, the signal read from 'Raw Guitar-1.flac'.
go.FigureWidget(data=audio_plot(x))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [4]:
# Envelope plot of y, the signal read from 'Orange Amp Heavy-1.flac'.
go.FigureWidget(data=audio_plot(y))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [79]:
from torch.utils.data import Dataset, DataLoader
class CroppedDataset(Dataset):
    """Fixed-length crops taken at the same offset from two aligned signals.

    Every valid start offset is used exactly once; the order in which offsets
    are visited is a random permutation drawn at construction time.

    Args:
        x: Input signal (array-like accepted by ``torch.FloatTensor``).
        y: Target signal, aligned sample-for-sample with ``x``.
        w: Crop length in samples.

    Raises:
        AssertionError: if the shorter signal has fewer than ``w`` samples.
    """

    def __init__(self, x, y, w):
        self.x = torch.FloatTensor(x)
        self.y = torch.FloatTensor(y)
        self.w = w
        # A crop starting at i covers [i, i + w), so the last valid start is
        # n - w and there are n - w + 1 crops (the original n - w dropped the
        # final crop and rejected inputs of exactly w samples).
        self.l = min(len(x), len(y)) - w + 1
        assert self.l > 0, 'signals are shorter than the crop length w'
        self.idx = np.random.permutation(self.l)

    def __getitem__(self, i):
        """Return the ``i``-th (input, target) crop pair, each shaped ``(1, -1)``."""
        start = int(self.idx[i])
        stop = start + self.w
        return self.x[start:stop].reshape(1, -1), self.y[start:stop].reshape(1, -1)

    def __len__(self):
        """Number of distinct crops."""
        return self.l

In [80]:
from poutyne.framework import Callback
from IPython.display import display
from time import time
from scipy.ndimage import gaussian_filter, median_filter
from math import ceil

class Plotter(Callback):
    """Poutyne callback that live-plots the per-batch loss in a FigureWidget.

    Two traces are shown: the raw loss of every batch seen so far, and a
    median-filtered version whose window is ``ceil(len(history) / window_divisor)``.

    Args:
        refresh_seconds: Minimum wall-clock time between two widget updates.
        window_divisor: The median-filter window is the history length divided
            by this value, rounded up.
    """

    def __init__(self, refresh_seconds=0.5, window_divisor=100):
        # Let the base class set up whatever state the framework expects.
        super().__init__()
        self.refresh_seconds = refresh_seconds
        self.window_divisor = window_divisor
        self.fig = None
        self.data = []
        self.last_t = None

    def _smoothed(self):
        """Median-filtered copy of the loss history."""
        return median_filter(self.data, ceil(len(self.data) / self.window_divisor))

    def _refresh(self):
        """Push the current history and its smoothed version to the widget."""
        self.fig.data[0].y = self.data
        self.fig.data[1].y = self._smoothed()
        self.last_t = time()

    def on_batch_end(self, batch, logs):
        """Record the batch loss; create the figure on first call, then refresh it at a throttled rate."""
        self.data.append(logs['loss'])
        if self.fig is None:
            self.fig = go.FigureWidget(data=[go.Scattergl(y=self.data), go.Scattergl(y=self._smoothed())])
            display(self.fig)
            self.last_t = time()
        elif time() - self.last_t > self.refresh_seconds:
            # Re-sending the whole history is costly, so updates are throttled.
            self._refresh()

    def on_train_end(self, logs):
        """Draw the final state so batches recorded since the last throttled refresh are shown."""
        if self.fig is not None:
            self._refresh()


In [81]:
from torch.optim import AdamW

# Pair 10000-sample windows of x with windows of y at the same index, after
# dropping the first 500 samples of y (presumably to compensate for a delay
# between the two recordings -- TODO confirm).
ds = CroppedDataset(x=x, y=y[500:], w=10000)
# One crop per optimisation step.
dl = DataLoader(dataset=ds, batch_size=1)

In [19]:
from math import floor
import torch.nn.functional as F

class PolySoftPlusFunction(torch.autograd.Function):
    """Piecewise-polynomial smooth rectifier with a hand-written gradient.

    Forward:
        f(x) = 0                    for x <= -0.5
        f(x) = 0.5 * (x + 0.5)**2   for -0.5 < x < 0.5
        f(x) = x                    otherwise

    Backward:
        f'(x) = 0, x + 0.5, 1 on the same three regions.

    The regions are selected with ``torch.where`` rather than index tuples
    built from ``torch.nonzero``: for a 0-dim input those tuples are empty and
    would address the whole tensor regardless of the mask.
    """

    @staticmethod
    def forward(ctx, x):
        # Only the input is needed to rebuild the region masks in backward.
        ctx.save_for_backward(x)
        quadratic = 0.5 * (x + 0.5) ** 2
        y = torch.where(x < 0.5, quadratic, x)
        return torch.where(x <= -0.5, torch.zeros_like(x), y)

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        grad = torch.where(x < 0.5, g * (x + 0.5), g)
        # Force an exact zero in the flat region, whatever the incoming gradient is.
        return torch.where(x <= -0.5, torch.zeros_like(g), grad)

class PolySoftPlus(nn.Module):
    """Module wrapper that applies ``PolySoftPlusFunction`` element-wise."""

    def __init__(self):
        super().__init__()
        # Keep the `f` attribute for compatibility, but store the Function
        # class itself: autograd Functions are used through their static
        # `apply` and instantiating them is deprecated.
        self.f = PolySoftPlusFunction

    def forward(self, x):
        """Return the activation of ``x``; the output has the same shape as the input."""
        return PolySoftPlusFunction.apply(x)

class StackedNet2(nn.Module):
    """Stack of dilated 1-D convolutions with ReLU, closed by a 1x1 convolution.

    One ``Conv1d`` + ``ReLU`` pair is created per entry of ``dilations``. A
    dilation ``d >= 1`` yields a 2-tap kernel with that dilation; ``d == 0``
    yields a 1-tap kernel. The first convolution reads one channel, hidden
    layers use ``layer_width`` channels, and the closing convolution maps back
    to one channel. The constructor prints ``dilations`` and the number of
    trainable parameters.

    Args:
        layer_width: Number of channels in every hidden layer.
        dilations: Sequence of dilation factors, one per hidden layer.
    """

    def __init__(self, layer_width, dilations):
        super().__init__()
        self.layer_width = layer_width
        self.n_layers = len(dilations)
        print(dilations)

        stack = []
        param_total = 0
        in_channels = 1
        for dil in dilations:
            taps = min(2, 1 + dil)
            stack.append(nn.Conv1d(in_channels, layer_width, kernel_size=taps, dilation=max(dil, 1)))
            # PolySoftPlus() was tried here as an alternative activation.
            stack.append(nn.ReLU())
            # Weights (in * out * taps) plus one bias per output channel.
            param_total += in_channels * layer_width * taps + layer_width
            in_channels = layer_width

        # Closing 1x1 convolution back to a single channel: weights plus one bias.
        stack.append(nn.Conv1d(in_channels, 1, 1))
        param_total += in_channels + 1

        self.layers = nn.Sequential(*stack)
        print(param_total, ' parameters')

    def forward(self, x):
        """Run ``x`` through the whole stack; the time axis shrinks because no padding is applied."""
        return self.layers(x)

In [82]:
def trunc_mse(ytrue,y):
    """Mean squared error over the trailing samples the two tensors share.

    Both tensors are cropped along their last axis to the shorter length,
    keeping the final ``m`` samples of each, before the error is averaged.
    The squared error is symmetric, so the argument order does not affect
    the result.

    Args:
        ytrue: Tensor whose last axis is time; any number of leading axes.
        y: Tensor with leading axes compatible with ``ytrue``; its last axis
            may have a different length.

    Returns:
        A 0-dim tensor holding the mean squared difference.
    """
    m = min(ytrue.shape[-1], y.shape[-1])
    # The Ellipsis keeps the crop on the last axis whatever the rank is,
    # instead of hard-coding exactly three dimensions.
    return torch.mean((ytrue[..., -m:] - y[..., -m:]) ** 2)

# Ten channels per layer; dilations run from 256 down to 2 in powers of two.
net = StackedNet2(10, [2 ** e for e in range(8, 0, -1)])

# No epoch count is passed, so Poutyne's default applies (the log below shows
# "Epoch 1/1000"); the recorded run was stopped by hand.
optimizer = AdamW(net.parameters())
model = Model(net, optimizer, trunc_mse)
model.fit_generator(dl, callbacks=[Plotter()])

[256, 128, 64, 32, 16, 8, 4, 2]
1511  parameters
Epoch 1/1000 ETA 657623s Step 1/11254000: loss: 0.141564

FigureWidget({
    'data': [{'type': 'scattergl', 'uid': '2c5462ee-6d98-44e5-a8a3-c5459db252bd', 'y': [0.14156…

Epoch 1/1000 ETA 259298s Step 1572/11254000: loss: 0.003547

KeyboardInterrupt: 

In [74]:
# 4000-sample excerpt of the input signal used for a qualitative check.
xtest = x[1010000:1014000]
go.FigureWidget(data=audio_plot(xtest))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [75]:
# Target excerpt over the same index range as xtest, taken after dropping the
# first 800 samples of y.
# NOTE(review): the dataset cell aligns the signals with y[500:], not y[800:] --
# confirm which offset is the intended alignment.
ytrue = y[800:][1010000:1014000]
go.FigureWidget(data=audio_plot(ytrue))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [76]:
# Run the model on the excerpt, reshaped to (1, 1, -1), and keep element
# [0, 0] of the result. Despite its name, xpred is the network's output for
# xtest, not an input signal.
xpred = model.predict_on_batch(torch.FloatTensor(xtest).reshape(1,1,-1))[0,0]
go.FigureWidget(data=audio_plot(xpred))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [83]:
# One line trace per layer of net.layers that has a `weight` attribute,
# showing that layer's weights flattened to 1-D.
go.FigureWidget(data=[go.Scatter(y=l.weight.detach().numpy().ravel(), mode='lines') for l in net.layers if hasattr(l,'weight')])

FigureWidget({
    'data': [{'mode': 'lines',
              'type': 'scatter',
              'uid': '2084cb20-…

In [16]:
# Heatmap of index 0 along the last weight axis (the first kernel tap) for
# the layer at position 8 of net.layers.
# NOTE(review): position 8 is hard-coded and depends on the layer layout of
# whichever net was live when this cell ran -- confirm it is the intended layer.
go.FigureWidget(data=[go.Heatmap(z=net.layers[8].weight.detach().numpy()[:,:,0])])

FigureWidget({
    'data': [{'type': 'heatmap',
              'uid': 'f975ec9f-9a2d-4caf-8148-f7c53cd3c029',
 …

Yay! It works! We can train a neural network to act as a drop-in replacement for a distortion effect. This is really fascinating.

Now, time to generate code to implement this.

In [77]:
# Persist the whole nn.Sequential (pickled module object, not just a state_dict).
# NOTE(review): the destination is an absolute, user-specific path; consider a
# path relative to the project so the notebook runs on other machines.
torch.save(net.layers, '/home/michael/Other/code/cnn_distortion/models/cnn_dist_v3.pt')


Couldn't retrieve source code for container of type Sequential. It won't be checked for correctness upon loading.


Couldn't retrieve source code for container of type Conv1d. It won't be checked for correctness upon loading.


Couldn't retrieve source code for container of type ReLU. It won't be checked for correctness upon loading.

