In [1]:
import numpy as np
from numpy import pi, exp, log10
from numpy.fft import fft, fftshift, fftfreq, ifft
from bokeh.plotting import figure, show, output_notebook
import bokeh.palettes as pl
import panel as pn
pn.extension()

import utils

import matplotlib.pyplot as plt
import scipy

def labels(p, x='', y='', t=''):
    """Set the x-axis label, y-axis label and title of the Bokeh figure `p`.

    `x`, `y` and `t` default to empty strings, which leaves the corresponding
    text blank.
    """
    p.title = t
    p.yaxis.axis_label = y
    p.xaxis.axis_label = x

def fontsize(p, size='12pt'):
    """Apply a uniform font size to the axes and title of the Bokeh figure `p`.

    Sets the tick-label size and axis-label size of both axes, and the title
    size, to `size` (default '12pt'; pass e.g. '14pt' for larger figures).
    Also sets the axis-label font style to 'normal' and hides the Bokeh
    toolbar logo.
    """
    for axis in (p.xaxis, p.yaxis):
        axis.major_label_text_font_size = size
        axis.axis_label_text_font_size = size
        axis.axis_label_text_font_style = 'normal'
    p.title.text_font_size = size
    p.toolbar.logo = None

def ticks(p, x_minor=10, y_minor=3):
    """Set the number of minor ticks between major ticks on the figure `p`.

    Defaults (10 on the x-axis, 3 on the y-axis) match the original fixed values.
    """
    p.xaxis.ticker.num_minor_ticks = x_minor
    p.yaxis.ticker.num_minor_ticks = y_minor

def bplot(x):
    """Show `x` as an 800x400 line plot against its sample index 0..len(x)-1."""
    sample_index = np.arange(len(x))
    fig = figure(width=800, height=400)
    fig.toolbar.logo = None
    fig.line(sample_index, x, line_width=2)
    show(fig)

def bplot2(n, x):
    """Show `x` as an 800x400 line plot against the explicit x-coordinates `n`."""
    fig = figure(width=800, height=400)
    fig.toolbar.logo = None
    fig.line(n, x)
    show(fig)

In [2]:
import ffmpeg
from glob import glob

files = glob('./data/**/**.mp3')

# One-off conversion of the MP3 sources to WAV (the cells below read *.wav).
# Disabled by default; set to True to (re)run the conversion with ffmpeg.
CONVERT_MP3 = False
if CONVERT_MP3:
    for f in files:
        out = f.replace('.mp3', '.wav')
        ffmpeg.input(f).output(out).run()

In [3]:
# Audio settings shared by the cells below.
FS = 44100  # expected sampling rate (Hz); files at any other rate are skipped
gain = 1    # samples are divided by this after loading and multiplied by it before writing
stft = utils.get_stft(fs=FS)

In [4]:
from sklearn.neighbors import KNeighborsRegressor

# Unfitted 1-nearest-neighbour regressor; the training cell rebinds `knn` to a fitted one.
knn = KNeighborsRegressor(
    n_neighbors=1,
)

In [5]:
# train the model
n_files = 60*3   # number of WAV files used for training
n_test = 5       # number of held-out WAV files used by the test cell
n_samples = 5    # random windows drawn per training file

n_frames = 20    # consecutive band-spectrogram frames per feature window

# Sort for a reproducible split, and drop the `_hat.wav` reconstructions that the
# test cell writes next to the inputs so they never occupy a train/test slot.
wav_files = sorted(f for f in glob('data/**/**.wav') if not f.endswith('_hat.wav'))
train = wav_files[:n_files]
test = wav_files[n_files:n_files+n_test]

rng = np.random.default_rng(0)  # fixed seed so the sampled windows are reproducible

# Not called `features`/`labels`: `labels` would shadow the labels() plot helper.
train_features = []
train_targets = []

for f in train:
    fs, x = scipy.io.wavfile.read(f)
    x = x / gain  # also promotes integer PCM to float, so `l + r` below cannot overflow
    if fs != FS:
        print('Error: wrong sampling rate')
        continue
    if x.ndim < 2:
        print('Error: mono file')
        continue

    l = x[:, 0]
    r = x[:, 1]
    s = (l + r) / 2
    BS = utils.b_matrix() @ utils.spectrogram(stft, s)
    if BS.shape[1] <= n_frames:
        # no room for a full window plus the following target frame
        print(f'Skipping {f}: shorter than {n_frames + 1} frames')
        continue

    P_IID, P_IC = utils.encode(stft=stft, l=l, r=r)
    P_IID = np.nan_to_num(P_IID)
    P_IC = np.nan_to_num(P_IC)
    P = utils.parameters_concat(P_IID, P_IC)

    for _ in range(n_samples):
        # take a random window of n_frames frames from the band spectrogram BS
        idx = rng.integers(0, BS.shape[1] - n_frames)
        feature = BS[:, idx:idx+n_frames]
        train_features.append(utils.complex_to_real(feature).flatten())

        # the target is the PS parameter vector of the frame immediately AFTER
        # the window (index idx+n_frames); the test cell uses the same offset
        label = P[:, idx+n_frames]
        train_targets.append(utils.complex_to_real(label))

knn = KNeighborsRegressor(n_neighbors=1).fit(train_features, train_targets)
knn

KeyboardInterrupt: 

In [6]:
# test the model: predict PS parameters for each held-out file from its mono
# downmix and write the reconstructed stereo signal as `<name>_hat.wav`.
int16_info = np.iinfo(np.int16)

for f in test:
    if f.endswith('_hat.wav'):
        continue  # output of a previous run, not an input
    fs, x = scipy.io.wavfile.read(f)
    if fs != FS:
        print('Error: wrong sampling rate')
        continue
    if x.ndim < 2:
        print('Error: mono file')
        continue
    # same scaling as training; the float result also stops `l + r` overflowing int16
    x = x / gain

    l = x[:, 0]
    r = x[:, 1]
    s = (l + r) / 2
    BS = utils.b_matrix() @ utils.spectrogram(stft, s)
    if BS.shape[1] <= n_frames:
        print(f'Skipping {f}: shorter than {n_frames + 1} frames')
        continue

    P_IID, P_IC = utils.encode(stft=stft, l=l, r=r)
    P = utils.parameters_concat(P_IID, P_IC)

    # the first n_frames frames have no preceding window and stay zero
    P_hat = np.zeros_like(P)

    test_features = [
        utils.complex_to_real(BS[:, i:i+n_frames]).flatten()
        for i in range(BS.shape[1] - n_frames)
    ]
    predictions = knn.predict(test_features)
    for i, pred in enumerate(predictions):
        # each window predicts the frame right after it, as in training
        P_hat[:, i+n_frames] = utils.real_to_complex(pred)

    P_IID_hat, P_IC_hat = utils.parameters_split(P_hat)
    l_hat, r_hat = utils.decode(stft=stft, s=s, P_IID=P_IID_hat, P_IC=P_IC_hat)
    stereo = np.array([l_hat*gain, r_hat*gain])
    # clip before casting: out-of-range samples would otherwise wrap around
    stereo = np.clip(stereo, int16_info.min, int16_info.max).astype(np.int16)
    scipy.io.wavfile.write(f.replace('.wav', '_hat.wav'), FS, stereo.T)

In [6]:
fs, x1 = scipy.io.wavfile.read('freebird.wav')

import os
# delete the `_hat.wav` reconstructions written by the test cell
for f in glob('data/**/**_hat.wav'):
    os.remove(f)

# kept as the last expression so the notebook actually displays rate and shape
fs, x1.shape


In [8]:
# cut the waveform down to ONE minute centred on the start of the Free Bird solo (6:41)
solo_center = 6*60 + 41  # seconds
half_width = 30          # seconds kept on each side of the centre
start = solo_center - half_width   # 371 s
end = solo_center + half_width     # 431 s
x = x1[start*fs:end*fs]/gain

The $\mathbf{B}$ matrix returned by `utils.b_matrix()` consists of 1s and 0s; multiplying by it sums the 2049 FFT bins into 34 frequency bands.

For two complex STFT matrices $\mathbf{X}$ and $\mathbf{Y}$ (FFT bins × frames), the cross spectrogram is defined as:
$$\rho(\mathbf{X}, \mathbf{Y}) = \mathbf{B}(\mathbf{X} \times \mathbf{Y}^*)$$
where $\times$ denotes element-wise multiplication, and $*$ denotes element-wise complex conjugation.

In [9]:
l = x[:, 0]
r = x[:, 1]
s = (l + r) / 2 # mono signal
N = len(l)

P_IID, P_IC = utils.encode(stft, l, r)
# encode may return NaNs (the training cell sanitises them the same way); replace
# them so X and the later decode step stay finite
P_IID = np.nan_to_num(P_IID)
P_IC = np.nan_to_num(P_IC)
# presumably IID is in dB: 10**(|IID|/10) maps it to a linear ratio >= 1 — confirm in utils
X = 10**(np.abs(P_IID)/10)

In [10]:
from bokeh import palettes as pl
from bokeh.models import LogColorMapper, ColorBar

# STFT framing assumed for the time axis: 4096-sample frames with 75% overlap,
# i.e. consecutive columns of X are `hop` samples apart.
n_fft = 4096
overlap = 0.75
hop = int(n_fft*(1 - overlap))

m = LogColorMapper(palette=pl.Inferno256, low=X.min(), high=X.max())

p = figure(width=1500, height=700, title='Spectrogram',
           x_axis_label='Time (s)', y_axis_label='Frequency (kHz)')
p.min_border = 0

duration_s = X.shape[1]*hop/fs   # image width in seconds
nyquist_khz = fs/2/1e3           # image height in kHz
p.image(image=[X], x=0, y=0, dw=duration_s, dh=nyquist_khz, color_mapper=m)
p.add_layout(ColorBar(color_mapper=m), 'right')

show(p)

In [11]:
# utils.decode returns the (left, right) pair; np.asarray stacks it into a (2, N)
# array whether it comes back as a tuple or an ndarray, so `* gain` scales the
# samples instead of repeating a tuple
y_mixed = np.asarray(utils.decode(stft, s, P_IID, P_IC))*gain
# clip before casting: samples outside the int16 range would otherwise wrap around
int16_info = np.iinfo(np.int16)
scipy.io.wavfile.write('freebird_decoded.wav', fs,
                       np.clip(y_mixed, int16_info.min, int16_info.max).T.astype(np.int16))