In [None]:
%load_ext lab_black
import time

import numpy as np
from scipy import signal
from scipy.io import wavfile as wav
import matplotlib.pyplot as plt

import plotly.graph_objects as go
from ipywidgets import interact, IntSlider

# Dark theme for matplotlib figures and compact numpy array printing.
plt.style.use("dark_background")
np.set_printoptions(linewidth=150, precision=2)

# Shared plotly layout used by every FigureWidget below:
# dark template, no grid lines, tight 20px margins on all sides.
base_layout = {
    "template": "plotly_dark",
    "xaxis_showgrid": False,
    "yaxis_showgrid": False,
    "margin": {"l": 20, "r": 20, "t": 20, "b": 20},
}

In [None]:
nperseg = 1024 * 32
rate, data = wav.read("moon_river_clipped.wav")

# Convert to mono by summing the channels. For PCM files wavfile.read returns the
# file's native integer dtype (e.g. int16), and adding two int16 columns in that
# dtype silently wraps around on loud passages -- so accumulate in float64.
# A file that is already mono (1-D) is used as-is.
if data.ndim > 1:
    data = data.sum(axis=1, dtype=np.float64)
else:
    data = data.astype(np.float64)

f, t, Zxx = signal.stft(data, rate, nperseg=nperseg)

# Keep only the lowest `clip` frequency bins. Bin spacing is rate / nperseg Hz,
# so this retains frequencies below roughly clip * rate / nperseg Hz.
clip = nperseg // 64
Zxx = Zxx[:clip]
f = f[:clip]
Zxx.shape

In [None]:
# TODO use only plotly
# Static spectrogram of the retained frequency band: colour encodes |Zxx|,
# with the colour scale floored at 0.
_, ax = plt.subplots(figsize=(20, 6))
ax.pcolormesh(t, f, np.abs(Zxx), shading="gouraud", vmin=0)

ax.set_ylabel("Frequency [Hz]")
ax.set_title("STFT Magnitude")
ax.set_xlabel("Time [sec]")

In [None]:
def get_peaks_cwt(fft, width_range=(3, 13), height=400):
    """Find prominent peaks in one spectrum frame using a wavelet (CWT) peak detector.

    Parameters
    ----------
    fft : ndarray
        One spectrum frame (complex or real); only its magnitude is used.
    width_range : tuple of int
        ``(start, stop)`` (optionally ``(start, stop, step)``) passed to
        ``np.arange`` to build the wavelet widths, in bins, handed to
        ``signal.find_peaks_cwt``. ``stop`` is exclusive.
    height : float
        Peaks whose magnitude is not strictly greater than this are discarded.

    Returns
    -------
    peaks : ndarray of int
        Bin indices of the kept peaks (empty if none survive).
    volumes : ndarray of float
        Magnitude of the frame at each kept peak.
    """
    fft = np.abs(fft)
    widths = np.arange(*width_range)
    # When no ridge line survives, find_peaks_cwt returns an empty *float* array,
    # which numpy refuses to use as an index. Forcing an integer dtype makes silent
    # or flat frames return empty results instead of raising IndexError.
    peaks = np.asarray(signal.find_peaks_cwt(fft, widths), dtype=np.intp)
    volumes = fft[peaks]
    mask = volumes > height
    return peaks[mask], volumes[mask]


# Interactive view of a single STFT frame: its magnitude spectrum as a line trace,
# plus an orange vertical marker at every peak returned by get_peaks_cwt.
fig = go.FigureWidget(layout=base_layout)
fig.update_layout(
    width=1000,
    height=300,
    xaxis_range=[0, Zxx.shape[0]],  # one x unit per retained frequency bin
    yaxis_range=[0, 4000],  # fixed y range; magnitudes above 4000 fall off-screen
)
fig.add_scattergl()  # trace 0: magnitude of the selected frame
xrange = Zxx.shape[1] - 1  # index of the last STFT frame


@interact(x=(0, xrange))
def update(x=xrange // 2):
    """Redraw the spectrum of STFT frame ``x`` and mark its detected peaks."""
    spectrum = Zxx[:, x]
    peak_bins, peak_mags = get_peaks_cwt(spectrum)

    markers = [
        {
            "type": "line",
            "line_color": "orange",
            "x0": bin_,
            "y0": 0,
            "x1": bin_,
            "y1": mag,
        }
        for bin_, mag in zip(peak_bins, peak_mags)
    ]

    # Apply both changes in one widget sync.
    with fig.batch_update():
        fig.data[0].y = np.abs(spectrum)
        fig.layout.shapes = markers
    print(peak_mags)


fig

In [None]:
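# NOTE: disabled alternative to get_peaks_cwt (Savitzky-Golay smoothing followed by
# signal.find_peaks), kept for reference. While this cell stays commented out,
# get_peaks is undefined.
# def get_peaks(fft, window=17, height=400, distance=10):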
#     fft = np.abs(fft)
#     smooth = signal.savgol_filter(fft, window, 3)
#     peaks, peaks_prop = signal.find_peaks(smooth, height=height, distance=distance)
#     signal.find_peaks_cwt
#     # TODO maybe use prominences
#     heights = peaks_prop["peak_heights"]
#     return peaks, heights, smooth


# fig = go.FigureWidget(layout=base_layout)
# fig.update_layout(
#     xaxis_range=[0, Zxx.shape[0]],
#     yaxis_range=[0, 4000],
#     width=1000,
#     height=300,
# )
# fig.add_scattergl()
# fig.add_scattergl(line_color="green")

# xrange = Zxx.shape[1] - 1


# @interact(x=(0, xrange))
# def update(x=xrange // 2):
#     freq_slice = Zxx[:, x]
#     peaks, heights, smooth = get_peaks(freq_slice)

#     shapes = list()
#     for peak, height in zip(peaks, heights):
#         shapes.append(
#             {
#                 "type": "line",
#                 "line_color": "orange",
#                 "x0": peak,
#                 "y0": 0,
#                 "x1": peak,
#                 "y1": height,
#             }
#         )

#     with fig.batch_update():
#         fig.data[0].y = np.abs(freq_slice)
#         fig.data[1].y = smooth
#         fig.layout.shapes = shapes


# fig

In [None]:
def get_ratios(peaks):
    """Pairwise ratio matrix of peak positions.

    ``peaks`` is flattened; the result has entry ``[i, j] = peaks[j] / peaks[i]``,
    i.e. each row is every peak expressed relative to one reference peak.
    """
    flat = peaks.reshape(-1)
    as_row = flat[np.newaxis, :]
    as_col = flat[:, np.newaxis]
    return as_row / as_col


def get_strengths(heights):
    """Pairwise geometric-mean matrix of peak heights.

    ``heights`` is flattened; entry ``[i, j]`` is ``sqrt(heights[i] * heights[j])``.
    The result is symmetric.
    """
    flat = heights.reshape(-1)
    products = flat[:, np.newaxis] * flat[np.newaxis, :]
    return products ** 0.5


# significant = get_strengths(heights) > 400
# significant = np.triu(significant)
# get_ratios(peaks) * significant

In [None]:
# Square scatter canvas whose x and y axes both span the retained frequency-bin
# index range; starts with one empty marker trace.
# NOTE(review): this rebinds `f`, shadowing the STFT frequency array from the
# loading cell, so the matplotlib spectrogram cell fails if re-run after this one.
# A rename has to be made together with the loop cell that updates `f`.
f = go.FigureWidget(layout=base_layout)
f.update_layout(
    width=600,
    height=600,
    xaxis_range=[0, Zxx.shape[0]],
    yaxis_range=[0, Zxx.shape[0]],
)

f.add_scattergl(mode="markers")
f

In [None]:
speedup = 1
time_len = 60  # in seconds, max 60
pace = False  # set True to sleep between frames so the animation roughly spans time_len / speedup seconds
# NOTE(review): assumes the loaded clip is 60 s long, so the frame count scales
# linearly with time_len -- confirm against the actual file duration.
i_range = Zxx.shape[1] * time_len // 60

for i in range(i_range):
    freq_slice = Zxx[:, i]
    # get_peaks (the Savitzky-Golay variant) is commented out, so use the active
    # CWT detector, which returns two values: (peak bins, peak magnitudes).
    peaks, heights = get_peaks_cwt(freq_slice)

    # One marker per ordered pair of peaks: x = peaks[j], y = peaks[i],
    # size proportional to sqrt(heights[i] * heights[j]).
    xs, ys = np.meshgrid(peaks, peaks)
    xs = xs.flatten()
    ys = ys.flatten()
    sizes = get_strengths(heights).flatten() / 50
    ratios = get_ratios(peaks)  # not plotted yet; presumably meant for marker_color

    with f.batch_update():
        f.update_traces(x=xs, y=ys, marker_size=sizes)  # set marker_color
    if pace:
        time.sleep(time_len / i_range / speedup)
print("done")