In [1]:
import soundfile as snd
import numpy as np
import torch
import torch.nn as nn
from poutyne.framework import Model
import plotly.graph_objects as go
from pathlib import Path

fold = Path.home() / 'Music' / 'Ardour' / 'CNN Training Data' / 'interchange' / 'CNN Training Data' / 'audiofiles'
x = snd.read(fold / 'Raw Guitar-1.wav')[0]
y = snd.read(fold / 'Orange Amp Heavy-1.wav')[0]

In [2]:
def audio_plot(x, n=1000):
    from scipy.ndimage.morphology import grey_dilation, grey_erosion
    m = max(1,len(x)//n)
    xmax = grey_dilation(x, size=m)[::m]
    xmin = grey_erosion (x, size=m)[::m]
    rng = np.arange(0,len(x),m)
    return [go.Scatter(x=rng, y=xmin, mode='lines', fill=None, line_color='indigo'), go.Scatter(x=rng, y=xmax, mode='lines', fill='tonexty', line_color='indigo')]

In [3]:
go.FigureWidget(data=audio_plot(x))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [4]:
go.FigureWidget(data=audio_plot(y))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [5]:
from math import floor

class StackedNet(nn.Module):
    def __init__(self, layer_width, n_layers):
        super().__init__()
        self.layer_width = layer_width
        self.n_layers = n_layers
        pw = 1
        i = 0
        self.layers = []
        for d in range(n_layers):
            d = floor(2**d)
            self.layers += [nn.Sequential(
                nn.Conv1d(pw, layer_width, kernel_size=3, dilation=d, padding=d),
                nn.ReLU()
            )]
            pw = layer_width
            i += layer_width
        self.layers = nn.Sequential(*self.layers)
        self.olayer = nn.Conv1d(i,1,kernel_size=1)
        
    def forward(self, x):
        xl = [x]
        for l in self.layers:
            xl += [l(xl[-1])]
        return self.olayer(torch.cat(xl[1:], dim=1))

In [6]:
from torch.utils.data import Dataset, DataLoader
class CroppedDataset(Dataset):
    def __init__(self, x, y, w):
        self.x = torch.FloatTensor(x)
        self.y = torch.FloatTensor(y)
        self.w = w
        self.l = min(len(x),len(y))-w
        assert self.l > 0
        self.idx = np.random.permutation(range(self.l))
        
    def __getitem__(self, i):
        i = self.idx[i]
        return self.x[i:i+self.w].reshape(1,-1), self.y[i:i+self.w].reshape(1,-1)
    
    def __len__(self):
        return self.l

In [72]:
from poutyne.framework import Callback
from IPython.display import display
from time import time
from scipy.ndimage import gaussian_filter, median_filter
from math import ceil

class Plotter(Callback):
    def __init__(self):
        self.fig = None
        self.data = []
    
    def on_batch_end(self, batch, logs):
        self.data += [logs['loss']]
        if self.fig is None:
            ysmooth = median_filter(self.data, ceil(len(self.data)/100))
            self.fig = go.FigureWidget(data=[go.Scattergl(y=self.data),go.Scattergl(y=ysmooth)])
            display(self.fig)
            self.last_t = time()
        else:
            t = time()
            if t - self.last_t > 0.5:
                self.fig.data[0].y = self.data
                self.fig.data[1].y = median_filter(self.data, ceil(len(self.data)/100))
                self.last_t = t


In [68]:
from torch.optim import AdamW

ds = CroppedDataset(x,y[500:], 10000)
dl = DataLoader(ds, batch_size=1)

In [19]:
from math import floor
import torch.nn.functional as F

class PolySoftPlusFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        r1 = (x<=-0.5)
        r2 = (~r1)&(x<0.5)
        r1,r2 = torch.nonzero(r1).T,torch.nonzero(r2).T
        ctx.save_for_backward(x,r1,r2)
        y = x.clone()
        r1,r2 = tuple(r1),tuple(r2)
        y[r1] = 0
        y[r2] = 0.5*(x[r2]+0.5)**2
        return y
    
    @staticmethod
    def backward(ctx, g):
        x,r1,r2 = ctx.saved_tensors
        r1,r2 = tuple(r1), tuple(r2)
        go = g.clone()
        go[r1] = 0
        go[r2] *= x[r2]+0.5
        return go

class PolySoftPlus(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = PolySoftPlusFunction()
        
    def forward(self, x):
        return self.f.apply(x)
        #y = torch.empty_like(x)
        #r2 = (x>-0.5)&(x<0.5)
        #y[~r2] = F.relu(x[~r2])
        #y[r2] = 0.5*(x[r2]+0.5)**2
        #return y

class StackedNet2(nn.Module):
    def __init__(self, layer_width, dilations):
        super().__init__()
        self.layer_width = layer_width
        self.n_layers = len(dilations)
        print(dilations)
        pw = 1
        n = 0
        self.layers = []
        for d in dilations:
            k = min(2,1+d)
            self.layers += [
                nn.Conv1d(pw, layer_width, kernel_size=k, dilation=max(d,1)),
                nn.ReLU(),
                #PolySoftPlus(),
            ]
            n += pw*layer_width*k + layer_width
            pw = layer_width
        self.layers = nn.Sequential(*self.layers, nn.Conv1d(pw,1,1))
        n += pw + 1
        print(n, ' parameters')
        
    def forward(self, x):
        return self.layers(x)

In [73]:
#net = StackedNet2(10, 2**np.arange(10))
#net = StackedNet2(8, 2**np.arange(1,10))
#net = StackedNet2(8, 2**np.arange(10))
#net = StackedNet2(8, [100,110,120,130,140,150,160,170,180,190])
#net = StackedNet2(3, np.arange(50,100,10))
#net = StackedNet2(5, [0,0,0,60,70,80,90,100,0,0,0])
#net = StackedNet2(5,[1,2,4,8,16,32,64])

def trunc_mse(ytrue,y):
    m = min(ytrue.shape[-1],y.shape[-1])
    return torch.mean((ytrue[:,:,-m:]-y[:,:,-m:])**2)

#net = StackedNet2(10, np.arange(100,200,10))
primes = [101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,179,181,191,193,197,199,211]
#net = StackedNet2(10,[101,131,167,181,211])
#net = StackedNet2(4,[8,16,32,64,128,256])
#net = StackedNet2(4,[8,16,32,150])
#net = StackedNet2(7,[8,16,32,64,128,256])
#net = StackedNet2(10,[8,16,32,64,128,256])
net = StackedNet2(10,list(reversed([2,4,8,16,32,64,128,256])))
#net = StackedNet2(30,[128,64,32,16,8,256,128,64,32,16])

l1loss = nn.L1Loss()

def reg_loss(ytrue,y):
    return trunc_mse(ytrue,y) + 0.005*sum(l1loss(p,torch.zeros_like(p)) for p in net.parameters())

model = Model(net, AdamW(net.parameters()), trunc_mse)

model.fit_generator(dl, callbacks=[Plotter()])

[256, 128, 64, 32, 16, 8, 4, 2]
1511  parameters
Epoch 1/1000 ETA 168783s Step 1/11254000: loss: 0.046454

FigureWidget({
    'data': [{'type': 'scattergl', 'uid': 'f67e7e87-4824-45d1-b954-c3f50491775a', 'y': [0.04645…

Epoch 1/1000 ETA 258862s Step 24325/11254000: loss: 0.002140

KeyboardInterrupt: 

In [74]:
xtest = x[1010000:1014000]
go.FigureWidget(data=audio_plot(xtest))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [75]:
ytrue = y[800:][1010000:1014000]
go.FigureWidget(data=audio_plot(ytrue))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [76]:
xpred = model.predict_on_batch(torch.FloatTensor(xtest).reshape(1,1,-1))[0,0]
go.FigureWidget(data=audio_plot(xpred))

FigureWidget({
    'data': [{'line': {'color': 'indigo'},
              'mode': 'lines',
              'type':…

In [14]:
print(sum(len(p) for p in net.parameters()))

86


In [15]:
go.FigureWidget(data=[go.Scatter(y=l.weight.detach().numpy().ravel(), mode='lines') for l in net.layers if hasattr(l,'weight')])

FigureWidget({
    'data': [{'mode': 'lines',
              'type': 'scatter',
              'uid': '16cf30d7-…

In [16]:
go.FigureWidget(data=[go.Heatmap(z=net.layers[8].weight.detach().numpy()[:,:,0])])

FigureWidget({
    'data': [{'type': 'heatmap',
              'uid': 'f975ec9f-9a2d-4caf-8148-f7c53cd3c029',
 …

Yay! It works! We can train a neural network to function as a drop-in replacement for distortion. This is really fascinating....

Now, time to generate code to implement this.

In [347]:
def mat_to_str(m):
    if len(m.shape) == 0:
        return str(m.item())
    else:
        return '{' + ','.join(mat_to_str(mi) for mi in m) + '}'
    
def layer_to_str(layer, name, xi, xo, latency):
    m = layer.in_channels
    n = layer.out_channels
    d, = layer.dilation
    k, = layer.kernel_size

    def join(s,l,j,e):
        return s + j.join(str(li) for li in l) + e

    #weights = mat_to_str(layer.weight)
    weights = mat_to_str(np.moveaxis(layer.weight.detach().numpy(),2,0))
    bias = mat_to_str(layer.bias)

    r = f"""
    // auto-generated code for layer {name}: {layer}
    const float w_{name}[{n}][{m}][{k}] = {weights};
    const float b_{name}[{n}] = {bias};
    
    // Fill with biases for {name}
    for (int i = 0; i < {n}; i++) {{
        float bi = b_{name}[i];
        for (int l = {latency}; l < L; l++) {{
            {xo}[i][l] = bi;
        }}
    }}
    
    // Apply main filter for {name}
    for (int i = 0; i < {n}; i++){{
        for (int j = 0; j < {m}; j++){{
            for (int k = 0; k < {k}; k++) {{
                float wijk = w_{name}[i][j][k];
                int offset = ({k-1}-k)*{d};
                for (int l = {latency}; l < L; l++){{
                    {xo}[i][l] += wijk * {xi}[j][l-offset];
                }}
            }}
        }}
    }}
    
    """
    
    r = f"""
    // auto-generated code for layer {name}: {layer}
    const float w_{name}[{k}][{n}][{m}] = {weights};
    const float b_{name}[{n}] = {bias};
    
    // Fill with biases for {name}
    for (int i = 0; i < {n}; i++) {{
        for (int l = {latency}; l < L; l++) {{
            {xo}[i][l] = b_{name}[i];
        }}
    }}
    
    // Apply main filter for {name}
    // {xo}[:,{latency}:] = sum(w[k]@{xi}[:,{latency}-({k-1}-k)*{d}:L-({k-1}-k)*{d}] for k in w.shape[0])
    for (int k = 0; k < {k}; k++) {{
        int offset = ({k-1}-k)*{d};
        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, {n}, L-{latency}, {m}, 1.0, &w_{name}[k][0][0], {m}, &{xi}[0][{latency}-offset], MAX_L, 1.0, &{xo}[0][{latency}], MAX_L);
    }}
    
    """
    return r

def relu_to_str(size,xi,xo,latency):
    r = f"""
    // Rectified Linear Unit (ReLU)
    for (int i = 0; i < {size}; i++) {{
        for (int l = {latency}; l < L; l++) {{
            {xo}[i][l] = {xi}[i][l] > 0 ? {xi}[i][l] : 0;
        }}
    }}

    """
    return r


In [348]:
def sequential_to_str(seq):
    
    #r = """void apply_cnn(float* x, float* y, int L) {
    #float* xin[1] = {x};
    #"""
    
    max_w = max(l.out_channels for l in seq if hasattr(l,'out_channels'))
    latency = sum((l.kernel_size[0]-1)*l.dilation[0] for l in seq if hasattr(l,'out_channels'))
    
    #r += f"""
    #float x_even[{max_w}][L];
    #float x_odd [{max_w}][L];
    #"""
    
    r = f"""
extern "C" {{
#include <cblas.h>
}};

// About {max_w*2*(8192+latency)/1e6} MB
const int latency = {latency};
const int MAX_L = {8192+latency};
float x_even[{max_w}][MAX_L];
float x_odd [{max_w}][MAX_L];

void apply_cnn(float* x, float* y, int L) {{

    // Ensure we don't segfault
    L = L > MAX_L ? MAX_L : L;
    
    for (int i = 0; i < L; i++) {{
        x_odd[0][i] = x[i];
    }}
    """
    
    s = None
    i = 0
    latency = 0
    xevenodd = ["x_even","x_odd"]
    xi = xevenodd[1]
    for l in seq:
        if isinstance(l, nn.ReLU):
            r += relu_to_str(s,xi,xi,latency)
        else:
            latency += (l.kernel_size[0]-1)*l.dilation[0]
            #xo = f"x{i}"
            xo = xevenodd[i%2]
            r += layer_to_str(l,f"layer_{i}",xi,xo,latency)
            s = l.out_channels
            i += 1
            xi = xo
    r += f"""
    // Copy result back to y
    for (int l = {latency}; l < L; l++) {{
        y[l] = {xo}[0][l];
    }}
}}
    """
    return r

In [349]:
s = sequential_to_str(net.layers)
print(s)


extern "C" {
#include <cblas.h>
};

// About 0.121744 MB
const int latency = 504;
const int MAX_L = 8696;
float x_even[7][MAX_L];
float x_odd [7][MAX_L];

void apply_cnn(float* x, float* y, int L) {

    // Ensure we don't segfault
    L = L > MAX_L ? MAX_L : L;
    
    for (int i = 0; i < L; i++) {
        x_odd[0][i] = x[i];
    }
    
    // auto-generated code for layer layer_0: Conv1d(1, 7, kernel_size=(2,), stride=(1,), dilation=(8,))
    const float w_layer_0[2][7][1] = {{{0.43975532054901123},{-0.9576542973518372},{-0.7286499738693237},{-0.37760645151138306},{0.6205844879150391},{0.8448692560195923},{0.1774430125951767}},{{0.13842950761318207},{-0.4070027768611908},{-0.18621504306793213},{0.006337991915643215},{-0.39210474491119385},{0.13554632663726807},{-0.6126473546028137}}};
    const float b_layer_0[7] = {0.20157429575920105,-0.09349554777145386,-0.2968412935733795,1.174821138381958,1.1558347940444946,0.07309959083795547,0.02421138435602188};
    
    // Fill with biases

In [350]:
main = """
#include <iostream>
#include <chrono>
int main() {
        int N = 1000;
        auto now = std::chrono::high_resolution_clock::now;
        float x[N];
        float y[N];
        for (int i = 0; i < N; i++) {
            x[i] = rand();
        }
        auto start = now();
        for (int i = 0; i < 1000; i++) {
            apply_cnn(x,y,N);
        }
        auto end = now();
        std::cout << 100*std::chrono::duration<double>(end-start).count()/1000/(N-504)*96e3 << "%" << std::endl;
}
"""
with open('/tmp/test.cpp', 'w') as f:
    f.write(s+main)

In [354]:
!g++ -O3 /tmp/test.cpp -o /tmp/test -lopenblas -L/usr/lib64

In [367]:
!/tmp/test

1.82049%


Bam! Sub 2% of the CPU usage. That's really good. With 4 CPUs, it's even less.

In [77]:
torch.save(net.layers, '/home/michael/Other/code/cnn_distortion/models/cnn_dist_v2.pt')


Couldn't retrieve source code for container of type Sequential. It won't be checked for correctness upon loading.


Couldn't retrieve source code for container of type Conv1d. It won't be checked for correctness upon loading.


Couldn't retrieve source code for container of type ReLU. It won't be checked for correctness upon loading.

