In [1]:
import numpy as np
from numpy import pi, exp, log10
from numpy.fft import fft, fftshift, fftfreq, ifft
from bokeh.plotting import figure, show, output_notebook
import bokeh.palettes as pl
import panel as pn
pn.extension()
from glob import glob
from tqdm import tqdm 
import os

import utils

import matplotlib.pyplot as plt
import scipy

In [None]:
fs, x1 = scipy.io.wavfile.read('freebird.wav')
x1 = np.array(x1, dtype=np.float32)
fs, x1.shape

FS = 44100
stft = utils.get_stft(fs=FS)

# cut down waveform size to two minutes centered around the start of the freebird solo
start = 5*60 + 41 #+ 30
end = 7*60 + 41 #- 30
x = x1[start*fs:end*fs]

l = x[:, 0]
r = x[:, 1]
scipy.io.wavfile.write('freebird_cut.wav', fs, np.array([l, r]).T.astype(np.int16))


s = (l + r) / 2 # mono signal
N = len(l)

P_IID, P_IC = utils.encode(stft, l, r)
y_mixed = utils.decode(stft, s, P_IID, P_IC)
l1 = y_mixed.T[:, 0]
r1 = y_mixed.T[:, 1]
scipy.io.wavfile.write('freebird_decoded.wav', fs, y_mixed.T.astype(np.int16))

In [None]:
import ffmpeg

files = glob('./data/**/**.mp3')

# convert all to wav
# for f in files:
#     out = f.replace('.mp3', '.wav')
#     ffmpeg.input(f).output(out).run()

In [None]:
files = glob('./data/**/**.wav')

# mark any file that won't load as .wav.bad
for f in tqdm(files):
    try:
        fs, l, r = utils.load_wav(f)
    except:
        print(f)
        os.rename(f, f + '.bad')

# reverse the above
# for f in glob('./data/**/**.wav.bad'):
#   os.rename(f, f.replace('.bad', ''))

In [None]:
files = glob('./data/**/**.wav')

FS = 44100
stft = utils.get_stft(fs=FS)

# generate PS coding, then save as npz
for f in tqdm(files):
    fs, l, r = utils.load_wav(f)
    N = len(l)
    P_IID, P_IC = utils.encode(stft, l, r)
    out = f + '_ps.npz'
    np.savez_compressed(out, P_IID=P_IID, P_IC=P_IC)

In [2]:
import os
from glob import glob
# delete all files with _hat.wav
for f in glob('data/**/**_hat.wav'):
    os.remove(f)

In [3]:
n_files = 50
n_test = 5
n_frames = 20

train = glob('data/**/**.wav')[:n_files]
test = glob('data/**/**.wav')[n_files:n_files+n_test]
len(train), len(test)

(50, 5)

In [9]:
# train the model
FS = 44100
stft = utils.get_stft(fs=FS)

n_samples = 200
features = []
labels = []

from sklearn.neighbors import KNeighborsRegressor
knn = KNeighborsRegressor(n_neighbors=1)

for f in tqdm(train):
    if f.find('_hat.wav') > 0: # TODO: really need a more reliable way of doing this
        continue
   
    s, BS, P_IID, P_IC, P, m = utils.file_to_feature_label(f, stft)

    for i in range(n_samples):
        # take a random set of frames from the spectrogram S
        idx = np.random.randint(0, BS.shape[1] - n_frames)

        feature = BS[:, idx:idx+n_frames]
        feature = utils.complex_to_real(feature).flatten()
        features.append(feature)
        
        # the label is the PS parameters of the last frame
        label = P[:, idx+n_frames]
        label = utils.complex_to_real(label)
        labels.append(label)

knn = KNeighborsRegressor(n_neighbors=1).fit(features, labels)
knn

100%|██████████| 50/50 [00:12<00:00,  3.89it/s]


In [5]:
# test the model
for f in test:
    print(f)
    s, BS, P_IID, P_IC, P, m = utils.file_to_feature_label(f, stft)

    P_hat = np.zeros_like(P)

    features = []
    for i in range(BS.shape[1] - n_frames):
        feature = BS[:, i:i+n_frames]
        feature = utils.complex_to_real(feature).flatten()
        features.append(feature)
    
    labels = knn.predict(features)
    for i in range(len(labels)):
        P_hat[:, i+n_frames] = utils.real_to_complex(labels[i])
    
    P_IID_hat, P_IC_hat = utils.parameters_split(P_hat)
    l_hat, r_hat = utils.decode(stft=stft, s=s, P_IID=P_IID_hat, P_IC=P_IC_hat)
    scipy.io.wavfile.write(f.replace('.wav', '_hat.wav'), FS, np.array([l_hat*m, r_hat*m]).astype(np.int16).T)

    # print mse
    # P_IID /= np.linalg.norm(P_IID)
    # P_IC /= np.linalg.norm(P_IC)
    # P_IID_hat /= np.linalg.norm(P_IID_hat)
    # P_IC_hat /= np.linalg.norm(P_IC_hat)

    mse_IID = np.mean(np.abs(P_IID - 0)**2)
    mse_IC = np.mean(np.abs(P_IC - 0)**2)
    print(f'mse_IID={mse_IID:.2f}, mse_IC={mse_IC:.2f}', mse_IID/mse_IC)

data/006/006443.wav
mse_IID=8.63, mse_IC=0.28 30.95949838613903
data/006/006611.wav
mse_IID=14.62, mse_IC=0.42 35.168218706819445
data/006/006439.wav
mse_IID=18.18, mse_IC=0.19 97.66403926573513
data/006/006762.wav
mse_IID=67.74, mse_IC=0.21 319.12829857595176
data/006/006463.wav
mse_IID=32.61, mse_IC=0.34 96.12001543414927


In [6]:
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN



2024-11-09 10:16:52.756354: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-09 10:16:52.757990: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-11-09 10:16:52.762593: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-11-09 10:16:52.774012: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1731165412.793383   17646 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1731165412.79

In [7]:
X = 10**(np.abs(P_IID)/10)

This creates the $\mathbf{B}$ matrix which consists of 1s and 0s and sums the 2049 FFT bins into 34 frequency bands

The cross spectrogram is defined as:
$$\rho(\mathbf{X}, \mathbf{Y}) = \mathbf{B}(\mathbf{X} \times \mathbf{Y}^*)$$
where $\times$ denotes element-wise multiplication, and $*$ denotes element-wise complex conjugation.

In [8]:
from bokeh import palettes as pl
from bokeh.models import LogColorMapper, ColorBar

m = LogColorMapper(palette=pl.Inferno256, low=X.min(), high=X.max())

p = figure(width=1500, height=700, title='Spectrogram', x_axis_label='Time (s)', y_axis_label='Frequency (kHz)')
p.min_border=0

p.image(image=[X], x=0, y=0, dw=X.shape[1]*int(4096*(1-0.75))/fs, dh=fs/2/1e3, color_mapper=m)
p.add_layout(ColorBar(color_mapper=m), 'right')

show(p)