In [None]:
import numpy as np
from numpy import pi, exp, log10
from numpy.fft import fft, fftshift, fftfreq, ifft
from bokeh.plotting import figure, show, output_notebook
import bokeh.palettes as pl
import panel as pn
pn.extension()
from glob import glob
from tqdm import tqdm 
import os

import utils

import matplotlib.pyplot as plt
import scipy

In [None]:
fs, x1 = scipy.io.wavfile.read('freebird.wav')
x1 = np.array(x1, dtype=np.float32)
fs, x1.shape

FS = 44100
stft = utils.get_stft(fs=FS)

# cut down waveform size to two minutes centered around the start of the freebird solo
start = 5*60 + 41 #+ 30
end = 7*60 + 41 #- 30
x = x1[start*fs:end*fs]

l = x[:, 0]
r = x[:, 1]
scipy.io.wavfile.write('freebird_cut.wav', fs, np.array([l, r]).T.astype(np.int16))


s = (l + r) / 2 # mono signal
N = len(l)

P_IID, P_IC = utils.encode(stft, l, r)
y_mixed = utils.decode(stft, s, P_IID, P_IC)
l1 = y_mixed.T[:, 0]
r1 = y_mixed.T[:, 1]
scipy.io.wavfile.write('freebird_decoded.wav', fs, y_mixed.T.astype(np.int16))

In [None]:
import ffmpeg

files = glob('./data/**/**.mp3')

# convert all to wav
for f in tqdm(files):
    out = f.replace('.mp3', '.wav')
    ffmpeg.input(f).output(out).run(overwrite_output=True, quiet=True)

In [None]:
files = glob('./data/**/**.wav')

# mark any file that won't load as .wav.bad
for f in tqdm(files):
    try:
        fs, l, r = utils.load_wav(f)
    except:
        # print(f)
        os.rename(f, f + '.bad')

# reverse the above
# for f in glob('./data/**/**.wav.bad'):
#    os.rename(f, f.replace('.bad', ''))

In [None]:
files = glob('./data/**/**.wav')

FS = 44100
stft = utils.get_stft(fs=FS)

# generate PS coding, then save as npz
for f in tqdm(files):
    fs, l, r = utils.load_wav(f)
    N = len(l)
    P_IID, P_IC = utils.encode(stft, l, r)
    out = f + '_ps.npz'
    np.savez_compressed(out, P_IID=P_IID, P_IC=P_IC)

In [None]:
import os
from glob import glob
# delete all files with _hat.wav
for f in glob('data/**/**_hat.wav'):
    os.remove(f)

FS = 44100
stft = utils.get_stft(fs=FS)

In [None]:
n_train = 20
n_test = 3
n_frames = 20

train = glob('data/**/**.wav')[:n_train]
test = glob('data/**/**.wav')[n_train:n_train+n_test]
len(train), len(test)

In [None]:
# train the model

n_samples = 200
features = []
labels = []

from sklearn.neighbors import KNeighborsRegressor
knn = KNeighborsRegressor(n_neighbors=1)

for f in tqdm(train):
    if f.find('_hat.wav') > 0: # TODO: really need a more reliable way of doing this
        continue
   
    s, BS, P_IID, P_IC, P, m = utils.file_to_feature_label(f, stft)

    for i in range(n_samples):
        # take a random set of frames from the spectrogram S
        idx = np.random.randint(0, BS.shape[1] - n_frames)

        feature = BS[:, idx:idx+n_frames]
        feature = utils.complex_to_real(feature).flatten()
        features.append(feature)
        
        # the label is the PS parameters of the last frame
        label = P[:, idx+n_frames]
        label = utils.complex_to_real(label)
        labels.append(label)

knn = KNeighborsRegressor(n_neighbors=1).fit(features, labels)
knn

In [None]:
# test the model
for f in test:
    print(f)
    s, BS, P_IID, P_IC, P, m = utils.file_to_feature_label(f, stft)
    P_hat = np.zeros_like(P)

    features = []
    for i in range(BS.shape[1] - n_frames):
        feature = BS[:, i:i+n_frames]
        feature = utils.complex_to_real(feature).flatten()
        features.append(feature)
    
    labels = knn.predict(features)
    for i in range(len(labels)):
        P_hat[:, i+n_frames] = utils.real_to_complex(labels[i])
    
    P_IID_hat, P_IC_hat = utils.parameters_split(P_hat)
    l_hat, r_hat = utils.decode(stft=stft, s=s, P_IID=P_IID_hat, P_IC=P_IC_hat)
    scipy.io.wavfile.write(f.replace('.wav', '_hat.wav'), FS, np.array([l_hat*m, r_hat*m]).astype(np.int16).T)

    # print mse
    # P_IID /= np.linalg.norm(P_IID)
    # P_IC /= np.linalg.norm(P_IC)
    # P_IID_hat /= np.linalg.norm(P_IID_hat)
    # P_IC_hat /= np.linalg.norm(P_IC_hat)

    mse_IID = np.mean(np.abs(P_IID - P_IID_hat)**2)
    mse_IC = np.mean(np.abs(P_IC - P_IC_hat)**2)
    print(f'mse_IID={mse_IID:.2f}, mse_IC={mse_IC:.2f}', mse_IID/mse_IC)

In [None]:
import os
from glob import glob
# delete all files with _hat.wav
for f in glob('data/**/**_hat.wav'):
    os.remove(f)

FS = 44100
stft = utils.get_stft(fs=FS)

In [None]:
n_train = 300
n_test = 50

files = glob('data/**/**.wav')
train = files[:n_train]
test = files[n_train:n_train+n_test]

train_features = np.zeros((n_train, 1294, utils.FREQ_BINS*2))
train_labels = np.zeros((n_train, 1294, utils.FREQ_BINS*2))

test_features = np.zeros((n_test, 1294, utils.FREQ_BINS*2))
test_labels = np.zeros((n_test, 1294, utils.FREQ_BINS*2))

for i, f in tqdm(enumerate(train)):
    s, BS, P_IID, P_IC, P, m = utils.file_to_feature_label(f, stft)
    train_features[i] = utils.complex_to_real(BS).T
    train_labels[i] = utils.complex_to_real(P_IC).T

for i, f in tqdm(enumerate(test)):
    s, BS, P_IID, P_IC, P, m = utils.file_to_feature_label(f, stft)
    test_features[i] = utils.complex_to_real(BS).T
    test_labels[i] = utils.complex_to_real(P_IC).T

2it [00:00,  5.74it/s]

[[ 0.75013838 -0.22872972  0.06775588 ...  0.78225283  0.77540292
   0.86139403]
 [ 0.90733926  0.38855803  0.43985821 ...  0.48604539  0.62873925
   0.70555426]
 [-0.09651976 -0.17780221 -0.01063181 ...  0.5428377   0.69144383
   0.68308714]
 ...
 [ 0.17235238 -0.08866042 -0.05647785 ...  0.99986576  0.9999623
   0.9999721 ]
 [-0.09696846 -0.13084458 -0.1513419  ...  0.99987938  0.99996953
   0.99998621]
 [ 0.0301151   0.27069207  0.27276086 ...  0.99990236  0.99999096
   0.99999678]]
[[ 0.84493707  0.87568237  0.84435323 ...  0.77720969  0.78367652
   0.76530098]
 [ 0.99619339  0.97061538  0.98075035 ...  0.93622768  0.92077887
   0.92538475]
 [ 0.55229379  0.89787161  0.88688152 ...  0.86588353  0.84968861
   0.81120776]
 ...
 [ 0.29573141  0.00384122 -0.11495963 ... -0.94724108 -0.98041472
  -0.98371376]
 [-0.08460059  0.00200618 -0.03844582 ... -0.92099013 -0.98624072
  -0.99656598]
 [-0.66898418  0.03028949  0.14746236 ... -0.90263105 -0.98517886
  -0.99502645]]


3it [00:00,  5.14it/s]

[[ 0.63168802  0.74947525  0.95335763 ...  0.83770782  0.88787163
   0.90593832]
 [ 0.99617393  0.99495088  0.99526963 ...  0.92194177  0.92719658
   0.89351239]
 [ 0.99052749  0.99815282  0.99732442 ...  0.78429358  0.87539759
   0.90759967]
 ...
 [-0.5888529  -0.10964038 -0.13741575 ... -0.027945   -0.2936149
  -0.40888141]
 [-0.00632056  0.178721   -0.09478092 ... -0.12552018 -0.07896963
  -0.29586755]
 [ 0.59511921 -0.13111548 -0.13747668 ... -0.16894161 -0.53144976
  -0.62159327]]


4it [00:00,  4.58it/s]

[[ 0.99997039  0.99998137  0.99997797 ...  0.99999087  0.99999176
   0.99999639]
 [ 0.99995862  0.99993383  0.99993134 ...  0.99999164  0.99999635
   0.99999767]
 [ 0.99962864  0.99930127  0.99917653 ...  0.99942772  0.9995133
   0.99975758]
 ...
 [ 0.41213658  0.27683169  0.33528498 ...  0.99990471  0.99999081
   0.99999715]
 [-0.01944529 -0.12411329 -0.06762995 ...  0.99987824  0.99998239
   0.99999476]
 [ 0.39166877  0.26340922  0.12733328 ...  0.99991316  0.99998195
   0.99999364]]


5it [00:01,  4.31it/s]

[[ 0.99211793  0.76908695  0.66581525 ...  0.83817348  0.93688957
   0.94651309]
 [-0.06845352  0.27978288  0.27432789 ...  0.96063786  0.95365525
   0.94461397]
 [-0.02092894  0.41786298  0.60388536 ...  0.95226811  0.95681276
   0.96991973]
 ...
 [-0.63528921 -0.66821183 -0.73264757 ...  0.07238795  0.74701154
   0.92561212]
 [-0.66153842 -0.77197643 -0.87065771 ...  0.48855061  0.8630665
   0.93473945]
 [-0.68085673 -0.51726761 -0.37456822 ...  0.99980003  0.99988267
   0.99988876]]


6it [00:01,  4.07it/s]

[[ 0.99999191  0.99999748  0.99999762 ...  0.99999868  0.99999922
   0.9999996 ]
 [ 0.99999957  0.99999974  0.99999988 ...  0.99999974  0.99999971
   0.99999978]
 [ 0.99999993  0.9999999   0.99999988 ...  1.          1.
   1.        ]
 ...
 [ 0.60858136  0.27183938  0.13034353 ...  0.99966304  0.99995218
   0.99998067]
 [ 0.51335266  0.35901379  0.22435268 ...  0.99956317  0.99992266
   0.99996931]
 [ 0.58330995  0.14241148 -0.1266562  ...  0.99965645  0.99995528
   0.99998891]]


7it [00:01,  3.94it/s]

[[ 0.96551176  0.96085962  0.96208868 ...  0.99953     0.99960218
   0.9996505 ]
 [ 0.91941943  0.24057774  0.28636036 ...  0.99701587  0.99518474
   0.99387581]
 [ 0.68555966  0.7934193   0.87763503 ...  0.99829079  0.99919848
   0.9986759 ]
 ...
 [-0.07984297  0.01949937  0.22156858 ...  0.99947998  0.99799443
   0.99569456]
 [ 0.10178636  0.01274263  0.00198571 ...  0.99476571  0.97228675
   0.9426446 ]
 [-0.01553566 -0.21764737 -0.17285732 ...  0.99030126  0.99933196
   0.99865161]]


8it [00:01,  3.81it/s]

[[-0.02462914 -0.34844978 -0.13398469 ...  0.99897851  0.99851698
   0.99901817]
 [-0.04736297 -0.00230033  0.07818976 ...  0.93926054  0.97316867
   0.98704918]
 [-0.39177347 -0.26104869 -0.08096054 ...  0.98193715  0.98835542
   0.99092265]
 ...
 [-0.64473311 -0.04138584 -0.16485043 ...  0.98609877  0.99904057
   0.9999088 ]
 [ 0.00724341  0.06236173  0.03558446 ...  0.98736231  0.99876122
   0.99988131]
 [ 0.64894394  0.48185914  0.05584582 ...  0.99788534  0.99951958
   0.99815304]]


9it [00:02,  3.72it/s]

[[0.99981426 0.99780185 0.99486468 ... 0.99840528 0.99872731 0.99951469]
 [0.98131987 0.97206991 0.98999683 ... 0.99577694 0.99289201 0.99486087]
 [0.99967171 0.99999693 0.99999503 ... 0.97842385 0.98189017 0.98349837]
 ...
 [0.97977638 0.98986761 0.99999628 ... 0.98655836 0.98649437 0.97790955]
 [0.99989001 0.99999986 0.99999993 ... 0.52462646 0.87734964 0.99532428]
 [0.99971554 0.99992551 0.9999237  ... 0.99787496 0.99308199 0.99733159]]


10it [00:02,  3.70it/s]

[[ 0.67821249  0.89797209  0.94400385 ...  0.85018835  0.8723236
   0.92652242]
 [ 0.73666411  0.76180293  0.77993281 ...  0.70381675  0.89033952
   0.97195965]
 [ 0.94915539  0.90432762  0.89630183 ...  0.82049575  0.91940995
   0.96947074]
 ...
 [-0.51322566 -0.35847358 -0.20932712 ...  0.99255311  0.99925615
   0.99980836]
 [ 0.29914073  0.11354708 -0.02790953 ...  0.9951203   0.99938582
   0.9998776 ]
 [-0.25768007 -0.44318867 -0.41555378 ...  0.99964294  0.99995655
   0.9999871 ]]


11it [00:02,  3.87it/s]

[[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]]





KeyboardInterrupt: 

In [56]:
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN, Input, LSTM

model = Sequential()
model.add(Input(shape=(None, utils.FREQ_BINS*2)))
model.add(SimpleRNN(units=utils.FREQ_BINS*2, return_sequences=True)) #activation=activation[0]))

# model.add(LSTM(units=utils.FREQ_BINS*2, return_sequences=True)) #activation=activation[1]))
# model.add(Dense(units=utils.FREQ_BINS*2, activation='tanh'))
model.add(Dense(units=utils.FREQ_BINS*2, activation=None))
model.compile(loss='mean_squared_error', optimizer='adam')

model.summary()


In [57]:
model.fit(train_features, train_labels, epochs=1, batch_size=1)

[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 92ms/step - loss: 0.1807


<keras.src.callbacks.history.History at 0x7f682d3dba60>

In [61]:
# model.evaluate(train_features, train_labels, batch_size=1)
# model.evaluate(test_features, test_labels, batch_size=1)
print(np.abs(train_labels[0]).round(2)[:1])
print(np.abs(model.predict(train_features[:1])[0]).round(2)[:1])

[[0.75 0.91 0.1  0.32 0.2  0.04 0.07 0.83 0.04 0.51 0.01 0.05 0.39 0.68
  0.37 0.17 0.14 0.22 0.05 0.78 0.48 0.12 0.21 0.18 0.52 0.3  0.32 0.17
  0.04 0.12 0.2  0.17 0.1  0.03 0.   0.   0.   0.   0.   0.   0.   0.
  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step
[[0.08 0.06 0.05 0.05 0.06 0.05 0.04 0.09 0.06 0.07 0.03 0.08 0.09 0.05
  0.07 0.08 0.1  0.08 0.05 0.05 0.03 0.04 0.05 0.04 0.05 0.05 0.06 0.02
  0.02 0.03 0.02 0.07 0.06 0.02 0.03 0.   0.02 0.01 0.01 0.06 0.   0.
  0.01 0.01 0.04 0.02 0.01 0.06 0.03 0.01 0.02 0.02 0.01 0.07 0.01 0.01
  0.05 0.07 0.06 0.01 0.02 0.05 0.03 0.   0.01 0.03 0.07 0.05]]


In [59]:
for f in test:
    print(f)
    s, BS, P_IID, P_IC, P, m = utils.file_to_feature_label(f, stft)
    P_hat = np.zeros_like(P)

    feature = utils.complex_to_real(BS).T
    # label = utils.complex_to_real(P).T
    
    P_hat = utils.real_to_complex(model.predict(np.array([feature]))[0].T)
    P_IID_hat, P_IC_hat = utils.parameters_split(P_hat)
    l_hat, r_hat = utils.decode(stft=stft, s=s, P_IID=P_IID_hat, P_IC=P_IC_hat)
    scipy.io.wavfile.write(f.replace('.wav', '_hat.wav'), FS, np.array([l_hat*m, r_hat*m]).astype(np.int16).T)

    # print mse
    # P_IID /= np.linalg.norm(P_IID)
    # P_IC /= np.linalg.norm(P_IC)
    # P_IID_hat /= np.linalg.norm(P_IID_hat)
    # P_IC_hat /= np.linalg.norm(P_IC_hat)

    mse_IID = np.mean(np.abs(P_IID - P_IID_hat)**2)
    mse_IC = np.mean(np.abs(P_IC - P_IC_hat)**2)
    print(np.mean(np.abs(P_IID_hat)), np.mean(np.abs(P_IC_hat)), sep='\t')
    print(np.mean(np.abs(P_IID)), np.mean(np.abs(P_IC)), sep='\t')
    print(f'mse_IID={mse_IID:.2f}, mse_IC={mse_IC:.2f}', mse_IID/mse_IC)

data/014/014588.wav
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 79ms/step


ValueError: operands could not be broadcast together with shapes (0,1294) (34,1294) 

In [None]:
X = 10**(np.abs(P_IID)/10)

This creates the $\mathbf{B}$ matrix which consists of 1s and 0s and sums the 2049 FFT bins into 34 frequency bands

The cross spectrogram is defined as:
$$\rho(\mathbf{X}, \mathbf{Y}) = \mathbf{B}(\mathbf{X} \times \mathbf{Y}^*)$$
where $\times$ denotes element-wise multiplication, and $*$ denotes element-wise complex conjugation.

In [None]:
from bokeh import palettes as pl
from bokeh.models import LogColorMapper, ColorBar

m = LogColorMapper(palette=pl.Inferno256, low=X.min(), high=X.max())

p = figure(width=1500, height=700, title='Spectrogram', x_axis_label='Time (s)', y_axis_label='Frequency (kHz)')
p.min_border=0

p.image(image=[X], x=0, y=0, dw=X.shape[1]*int(4096*(1-0.75))/fs, dh=fs/2/1e3, color_mapper=m)
p.add_layout(ColorBar(color_mapper=m), 'right')

show(p)