In [None]:
import numpy as np
import pandas as pd

import datetime
from data_collection.data_collection import Logger
import plotly.express as px
import scipy.signal as ss

import torchaudio as ta
import torch as tch
tafn = ta.functional
tatx = ta.transforms

from IPython.display import Audio
from nb_tools import *
from scipy.fft import fft, ifft, fftfreq
from functools import partial
%load_ext autoreload
%autoreload 2

# 1. Noise mixing

In [None]:
from pathlib import Path
@show_global_variables
def load_wav(path):
    """Load every ``.wav`` file found directly inside ``path``.

    Args:
        path: directory to scan (anything accepted by ``pathlib.Path``).

    Returns:
        dict mapping each file's stem (name without extension) to the
        signal returned by ``torchaudio.load``. The sample rate reported
        by ``torchaudio.load`` is discarded.
    """
    folder = Path(path)
    return {
        wav.stem: ta.load(wav)[0]  # [0] keeps the signal, drops the sample rate
        for wav in folder.iterdir()
        if wav.suffix == '.wav'
    }

# Sample rate (Hz) used for slicing, spectrograms and playback below.
# NOTE(review): load_wav drops the rate reported by each file, so this
# assumes all recordings are 16 kHz — TODO confirm.
rate=16000
# Rebind Audio so the player defaults to rate=16000 when no rate is given.
Audio = partial(Audio, rate=16000)

# Dict of signals keyed by file stem, loaded from the log folder.
data = load_wav('../log/audio_data')

In [None]:
def sec(s, rate=rate):
    """Convert a duration in seconds into a sample count.

    Args:
        s: duration in seconds (may be fractional).
        rate: samples per second; defaults to the notebook-level ``rate``.

    Returns:
        ``s * rate`` truncated to an ``int``.
    """
    n_samples = s * rate
    return int(n_samples)

In [None]:
# Trim each recording along its last axis. Slice bounds are sample indices:
# `rate*N` and `sec(N)` both denote N seconds.
# Noise recording: drop the first 25 s and the last 14 s.
noise = data['running_noise_c0'][:, rate*25:-rate*14]
# 0 deg recordings (suffixes c0 / c1): drop 5 s at each end.
ee0_0 = data['ee0deg_floor_c0'][:, rate*5:-rate*5]
ee0_1 = data['ee0deg_floor_c1'][:, rate*5:-rate*5]

# 90 deg recordings: keep 2.3 s .. 16 s.
ee90_0 = data['ee_90deg_floor_c0'][:, sec(2.3):sec(16)]
ee90_1 = data['ee_90deg_floor_c1'][:, sec(2.3):sec(16)]

# 180 deg recordings: keep the first 21 s.
ee180_0 = data['ee_180deg_floor_c0'][:, sec(0):sec(21)]
ee180_1 = data['ee_180deg_floor_c1'][:, sec(0):sec(21)]


In [None]:
def show_spectrum(sig, rate=rate, **kwargs):
    """Plot the log10 spectrogram of ``sig`` as a plotly heat map.

    Args:
        sig: samples handed to ``scipy.signal.spectrogram``.
        rate: sampling frequency in Hz; defaults to the notebook-level ``rate``.
        **kwargs: forwarded to ``scipy.signal.spectrogram`` (for example
            ``nperseg``, ``noverlap`` or ``window``). Previously these were
            accepted but silently ignored.

    Returns:
        The figure produced by ``plotly.express.imshow`` with frequency on
        the y axis and time on the x axis.
    """
    f, t, sxx = ss.spectrogram(sig, fs=rate, **kwargs)
    return px.imshow(np.log10(sxx), y=f, x=t, aspect='auto')
#show_spectrum(data['ee_plus_runnng_c0'][0], nperseg=4096)

In [None]:
def mix(s1, s2, w1, w2):
    """Overlay ``s2`` on ``s1``, repeating ``s2`` until ``s1`` is covered.

    ``s1`` is processed in consecutive chunks of ``s2``'s length; each chunk
    is combined as ``s1_chunk * w1 + s2_chunk * w2``. The final chunk uses
    only the leading part of ``s2`` needed to reach the end of ``s1``.

    Args:
        s1: 2-D signal whose last axis is time; defines the output length.
        s2: 2-D signal to tile along the last axis; must not be longer
            than ``s1``.
        w1: weight applied to ``s1``.
        w2: weight applied to ``s2``.

    Returns:
        numpy array with the same last-axis length as ``s1``.

    Raises:
        AssertionError: if ``s2`` is longer than ``s1``.
        ValueError: if ``s2`` has zero length (previously this looped
            forever), or if ``s1`` is empty.
    """
    l1 = s1.shape[-1]
    l2 = s2.shape[-1]
    assert l1 >= l2
    if l2 == 0:
        # A zero-length s2 would never advance the cursor through s1.
        raise ValueError("s2 must contain at least one sample")

    # TODO: flip the noise every other iteration
    segments = []
    for start in range(0, l1, l2):
        stop = min(l1, start + l2)
        segments.append(s1[:, start:stop]*w1 + s2[:, :stop-start]*w2)
    return np.concatenate(segments, axis=-1)
    

In [None]:
from nb_tools import show_global_variables
@np.vectorize
def tri(idx, length):
    """Map a running index onto a triangle wave over ``0 .. length-1``.

    The result counts up from 0 to ``length - 1``, then back down to 0,
    and so on, so a segment of ``length`` items can be traversed back and
    forth. Vectorized over both arguments via ``np.vectorize``.

    Args:
        idx: running index (0, 1, 2, ...).
        length: number of positions in the segment.

    Returns:
        Position within ``[0, length - 1]`` for each ``idx``.
    """
    end = length - 1
    if end <= 0:
        # A segment with a single position has nowhere to move; this also
        # avoids dividing by zero below. `idx * 0` keeps idx's numeric type.
        return idx * 0
    if not (idx // end) % 2:
        # Even lap: moving forward through the segment.
        return idx % end
    # Odd lap: moving backward from the end of the segment.
    return end - (idx % end)

@show_global_variables
def loop(sig, fs, dur, seg_st, seg_ed):
    """Build a ``dur``-second signal by ping-ponging over a segment of ``sig``.

    The power spectrogram of ``sig`` is computed, frames between ``seg_st``
    and ``seg_ed`` are replayed forwards and backwards (see ``tri``) until
    ``dur`` seconds of frames exist, and a waveform is reconstructed with
    Griffin-Lim.

    Args:
        sig: waveform passed to ``torchaudio.transforms.Spectrogram``.
        fs: sampling frequency of ``sig`` in Hz.
        dur: duration of the output in seconds.
        seg_st: start of the looped segment in seconds.
        seg_ed: end of the looped segment in seconds.

    Returns:
        The output of ``torchaudio.transforms.GriffinLim`` for the looped
        frames.
    """
    hop_length = 200

    # Convert seconds to whole spectrogram frames. Computing
    # `seconds * fs // hop_length` and casting to int keeps the frame
    # indices integral (float indices cannot index the spectrogram) and
    # avoids the rounding error of dividing by the inexact `hop_length/fs`.
    seg_st = int(seg_st * fs // hop_length)
    seg_ed = int(seg_ed * fs // hop_length)
    n_win = int(dur * fs // hop_length)

    spectra = tatx.Spectrogram(power=2, hop_length=hop_length)(sig)

    # Triangle-wave frame positions inside the segment, offset to its start.
    idx = tri(np.arange(n_win), seg_ed - seg_st) + seg_st
    idx = tch.as_tensor(idx, dtype=tch.long)

    return tatx.GriffinLim(hop_length=hop_length)(spectra[..., idx])
#sigl = loop(ee0_0*1e14, 16000, 20, 4, 8)

In [None]:
@show_global_variables
def freq_phase(n, fs, freq):
    """Return the wrapped phase of a constant-frequency tone at ``n`` samples.

    Args:
        n: number of samples.
        fs: sampling frequency in Hz.
        freq: tone frequency in Hz.

    Returns:
        Array of ``n`` phases in ``[0, 2*pi)``, starting at 0.
    """
    total_phase = n/fs*freq*np.pi*2
    ramp = np.linspace(0, total_phase, n, endpoint=False)
    return ramp % (2*np.pi)

@show_global_variables
def fixedtime(fs, te):
    """Return uniformly spaced sample times for ``te`` seconds at ``fs`` Hz.

    Args:
        fs: sampling frequency in Hz.
        te: total duration in seconds.

    Returns:
        Array of ``int(te*fs)`` times built by accumulating a constant
        step of ``1/fs``; the first entry is ``1/fs`` (not 0).
    """
    count = int(te*fs)
    steps = np.full(count, 1/fs)
    return np.cumsum(steps)

@show_global_variables
def jitteredtime(fs, te):
    """Return sample times whose sampling interval performs a random walk.

    Each step, the interval changes by a uniform random amount in
    ``[-0.5, 0.5) / (fs*25)``; the walk is centred on the nominal ``1/fs``.
    Uses the global ``np.random`` state.

    Args:
        fs: nominal sampling frequency in Hz.
        te: total duration in seconds.

    Returns:
        Array of ``int(te*fs)`` cumulative times.
    """
    count = int(te*fs)
    interval_change = (np.random.random(count)-0.5)*(1/(fs*25))
    intervals = np.cumsum(interval_change) + 1/fs
    return np.cumsum(intervals)

@show_global_variables
def varyingtime(fs, te):
    """Return sample times whose sampling interval drifts smoothly.

    A uniform random sequence in ``[-0.5, 0.5) / (fs*20000000)`` is
    accumulated twice to obtain the deviation of the sampling interval
    from the nominal ``1/fs``; the intervals are then accumulated into
    times. Uses the global ``np.random`` state.

    Args:
        fs: nominal sampling frequency in Hz.
        te: total duration in seconds.

    Returns:
        Array of ``int(te*fs)`` cumulative times.
    """
    count = int(te*fs)
    second_diff = (np.random.random(count)-0.5)*(1/(fs*20000000))
    first_diff = np.cumsum(second_diff)
    intervals = np.cumsum(first_diff) + 1/fs
    return np.cumsum(intervals)

@show_global_variables
def time2phase(t, f):
    """Convert times to the wrapped phase of a tone of frequency ``f``.

    Args:
        t: time or array of times in seconds.
        f: frequency in Hz.

    Returns:
        ``2*pi*f*t`` wrapped into ``[0, 2*pi)``.
    """
    full_turn = 2*np.pi
    unwrapped = t*np.pi*2*f
    return unwrapped % full_turn
    
@show_global_variables
def harmonics(times, basefreq, n_har=30):
    """Generate sine waves at the first ``n_har`` multiples of ``basefreq``.

    Args:
        times: 1-D sequence of sample times in seconds.
        basefreq: fundamental frequency in Hz.
        n_har: number of harmonics to generate (the fundamental counts as
            the first).

    Returns:
        Array of shape ``(n_har, len(times))``; row ``k`` holds the sine of
        frequency ``basefreq * (k + 1)``.
    """
    sigs = np.zeros((n_har, len(times)))
    for row in range(n_har):
        order = row + 1  # harmonic number: 1 is the fundamental
        sigs[row] = np.sin(time2phase(times, basefreq*order))
    return sigs

@show_global_variables
def find_nearest(x, findfrom):
    """For each value in ``x``, find the index of the closest ``findfrom`` entry.

    Builds the full ``len(x) x len(findfrom)`` distance table, so memory
    grows with the product of the two lengths.

    Args:
        x: 1-D array of query values.
        findfrom: 1-D array of candidate values.

    Returns:
        Integer array with one index into ``findfrom`` per entry of ``x``.
    """
    gaps = x[:, np.newaxis] - findfrom[np.newaxis, :]
    return np.abs(gaps).argmin(axis=1)

@show_global_variables
def butterworth_highpass(sig, fs, lf):
    """Apply a zero-phase 8th-order Butterworth high-pass filter.

    Args:
        sig: signal to filter (filtered along its last axis).
        fs: sampling frequency in Hz.
        lf: cut-off frequency in Hz.

    Returns:
        The filtered signal from ``scipy.signal.sosfiltfilt``
        (forward-backward filtering).
    """
    sos = ss.butter(8, lf, btype='high', output='sos', fs=fs)
    return ss.sosfiltfilt(sos, sig)

@show_global_variables    
def butterworth_lowpass(sig, fs, lf):
    """Apply a zero-phase 8th-order Butterworth low-pass filter.

    Args:
        sig: signal to filter (filtered along its last axis).
        fs: sampling frequency in Hz.
        lf: cut-off frequency in Hz.

    Returns:
        The filtered signal from ``scipy.signal.sosfiltfilt``
        (forward-backward filtering).
    """
    sos = ss.butter(8, lf, btype='low', output='sos', fs=fs)
    return ss.sosfiltfilt(sos, sig)


@show_global_variables
def get_avg_freq_from_sigs_spectro(sigs, rate=rate):
    """Average the spectrogram magnitude over time and over several signals.

    For each signal the square root of its ``scipy.signal.spectrogram``
    output is averaged along the time axis; the per-signal profiles are
    then averaged together.

    Args:
        sigs: non-empty iterable of signals.
        rate: sampling frequency in Hz; defaults to the notebook-level ``rate``.

    Returns:
        Tuple ``(f, profile)``: the frequency axis of the last signal's
        spectrogram and the mean magnitude profile across all signals.
    """
    profiles = []
    for sig in sigs:
        f, _, sxx = ss.spectrogram(sig, fs=rate)
        # sqrt of the spectrogram, then mean over the time axis (last).
        profiles.append((sxx ** 0.5).mean(-1))

    stacked = np.stack(profiles, axis=-1)
    return f, stacked.mean(-1)


@show_global_variables
def fftfilter(sig,fs, mags, magfreqs):
    """Shape the spectrum of ``sig`` with a sampled magnitude response.

    Every FFT bin is scaled by the entry of ``mags`` whose frequency in
    ``magfreqs`` is closest to the bin frequency.

    Args:
        sig: signal to filter (FFT is taken along the last axis).
        fs: sampling frequency in Hz.
        mags: gain values.
        magfreqs: frequencies in Hz at which ``mags`` is sampled.

    Returns:
        The complex output of the inverse FFT; callers that need a real
        signal take ``.real``.
    """
    spectrum = fft(sig)
    bin_freqs = fftfreq(sig.shape[-1], 1/fs)
    gains = mags[find_nearest(bin_freqs, magfreqs)]

    # In-place scaling keeps the dtype of the FFT output.
    spectrum *= gains
    return ifft(spectrum)

@show_global_variables
def get_mag_of_freq_smooth(siglist, fs=16000, smooth_window=8000):
    """Compute a smoothed, down-sampled magnitude spectrum of several signals.

    The signals are concatenated along the last axis, the FFT magnitude of
    the first row is normalised to a maximum of 1, and a moving average of
    ``smooth_window`` bins is evaluated every ``smooth_window // 2`` bins.
    Window ``i`` averages bins ``i .. i + smooth_window - 1`` and is
    labelled with the frequency of bin ``i``.

    Windows that run past the last bin wrap around to the first bins (the
    DFT is periodic). The previous ``as_strided`` implementation read
    memory beyond the end of the array for those windows.

    Args:
        siglist: sequence of 2-D signals sharing their leading dimension.
        fs: sampling frequency in Hz.
        smooth_window: number of FFT bins per averaging window.

    Returns:
        Tuple ``(freqs, mags_re)``: bin frequencies from ``fftfreq`` and
        the smoothed magnitudes, both sampled every ``smooth_window // 2``
        bins.
    """
    fullsig = np.concatenate(siglist, axis=-1)
    mags = np.abs(fft(fullsig))[0]
    mags = mags / mags.max()

    step = max(1, smooth_window // 2)

    # Wrap-pad so every one of the len(mags) windows is fully defined.
    padded = np.pad(mags, (0, smooth_window - 1), mode='wrap')
    windows = np.lib.stride_tricks.sliding_window_view(padded, smooth_window)
    # Select the windows first, then average: only the kept rows are computed.
    mags_re = windows[::step].mean(-1)

    freqs = fftfreq(len(mags), 1/fs)[::step]

    return freqs, mags_re


In [None]:
# Synthesis setup: sampling frequency in Hz.
fs = 16000
# 10 s of sample times produced by varyingtime.
ts = varyingtime(fs, 10)
# Number of harmonics and the fundamental frequency in Hz.
n_har = 20
base_freq=200
# One sine per harmonic; `sigs` has one row per harmonic.
sigs = harmonics(ts, basefreq=base_freq, n_har=n_har)

In [None]:
def weight_n_add(sigs, sigfreqs, mags, magfreqs):
    """Weight each row of ``sigs`` by a magnitude and average the rows.

    Args:
        sigs: array with one signal per row.
        sigfreqs: frequency in Hz of each row of ``sigs``.
        mags: magnitude values.
        magfreqs: frequencies in Hz at which ``mags`` is sampled.

    Returns:
        The mean over rows of ``sigs`` after scaling each row by the entry
        of ``mags`` whose frequency is nearest to that row's frequency.
    """
    nearest = find_nearest(sigfreqs, magfreqs)
    gains = mags[nearest, None]  # column vector: one gain per row
    weighted = sigs * gains
    return weighted.mean(0)


# Time- and signal-averaged spectrogram magnitude of the six ee* recordings.
f, r = get_avg_freq_from_sigs_spectro([ee0_0, ee0_1, ee90_0, ee90_1, ee180_0, ee180_1])

# Weight each synthetic harmonic by the measured magnitude nearest to its
# frequency, then average the harmonics into one signal.
s = weight_n_add(sigs, (np.arange(n_har)+1)*base_freq, r[0], f)

In [None]:
Audio(s, rate=16000)

In [None]:
# Alternative weighting source: smoothed FFT magnitude of the same recordings.
freqs, mags = get_mag_of_freq_smooth([ee0_0, ee0_1, ee90_0, ee90_1, ee180_0, ee180_1])   

In [None]:
# Rebuilds `s` with the FFT-based magnitudes (overwrites the previous `s`).
s = weight_n_add(sigs, (np.arange(n_har)+1)*base_freq, mags, freqs)

In [None]:
Audio(s, rate=16000)

In [None]:
# Shape the plain average of the harmonics with the measured magnitudes.
sigf = fftfilter(sigs.mean(0),16000, mags, freqs)

In [None]:
# Same filter applied to the already-weighted signal (overwrites `sigf`).
sigf = fftfilter(s,16000, mags, freqs)

In [None]:
# fftfilter returns complex values; play the real part without the first
# and last 10000 samples.
Audio(sigf.real[10000:-10000], rate=16000)

In [None]:
px.line(sigf.real)

In [None]:
# Pitch track of the weighted harmonic signal.
px.line(tafn.detect_pitch_frequency(tch.tensor(s), 16000))

In [None]:
s

In [None]:
show_spectrum(sigf.real)

In [None]:
# First 10000 samples of the filtered signal.
px.line(sigf.real[:10000])

In [None]:
# Spectrogram after low-pass filtering at 6000 Hz.
show_spectrum(butterworth_lowpass(sigf.real, fs=16000, lf=6000))

In [None]:

# Average spectrogram magnitude of the synthetic signal `s`.
# The helper defined in this notebook is get_avg_freq_from_sigs_spectro;
# the shorter name used here before does not exist.
# NOTE: `fs` is rebound to the frequency axis here, shadowing the sampling
# frequency of the same name until a later cell assigns fs = 16000 again.
fs, rs = get_avg_freq_from_sigs_spectro([s])


In [None]:
# Overlay the max-normalised profiles `rs` and `r[0]`; the x axis is the
# frequency vector currently held in `fs`.
px.line(y=[rs/rs.max(), r[0]/r.max(), ], x=fs)

In [None]:
def find_nearest(x, findfrom):
    """Return, for each value in ``x``, the index of the closest ``findfrom`` entry.

    Re-defines the ``find_nearest`` helper declared earlier in this
    notebook; running this cell replaces that binding.

    Args:
        x: 1-D array of query values.
        findfrom: 1-D array of candidate values.

    Returns:
        Integer array with one index into ``findfrom`` per entry of ``x``.
    """
    distance = np.abs(x[:, np.newaxis] - findfrom[np.newaxis, :])
    return distance.argmin(axis=1)

# Indices into `f` nearest to 30 multiples of 200 Hz.
# NOTE(review): 200 and 30 are hard-coded here, while the setup cell uses
# base_freq=200 and n_har=20, so `idx` has more entries than `sigs` has rows.
idx = find_nearest(200*np.arange(1, 30+1), f)

In [None]:
# NOTE(review): np.array() is called without arguments, which raises a
# TypeError; a per-harmonic weight array was presumably intended — confirm
# which weights belong here before running this cell.
show_spectrum((sigs*np.array()).mean(0), fs)

In [None]:
len(sigs)

In [None]:
# `r` is plotted whole here; other cells plot r[0].
px.line(y=r/r.max(), x=f)

In [None]:
# Restore the sampling frequency and synthesise a 10 s, 300 Hz sine on a
# uniform time base. `time` is not defined in this notebook; `fixedtime`
# is the deterministic (fs, te) time-base helper.
fs = 16000
sig = np.sin(time2phase(fixedtime(fs, 10), 300))
Audio(sig, rate=fs)

In [None]:
# Scratch check: uniform random values shifted to [-0.5, 0.5).
np.random.random(10)-0.5

In [None]:
show_spectrum(sig, 16000)

In [None]:
# Pitch track of the synthetic sine.
px.line(tafn.detect_pitch_frequency(tch.tensor(sig), fs, win_length=3, frame_time=0.5))

In [None]:
show_spectrum(ee0_0[0])

In [None]:
# Overlay the noise recording on ee0_0 with equal weights.
mixed = mix(ee0_0, noise, 1, 1)
show_spectrum(mixed[0])

In [None]:
show_spectrum(ee0_0[0])

In [None]:
ee0_0

In [None]:
# Pitch track of the mixed signal's first row.
px.line(tafn.detect_pitch_frequency(tch.tensor(mixed[0]), rate))

In [None]:
# Inspect a 0.2 s excerpt: added component, mixture, and original signal.
st = 7
se = 7.2
px.line(y=[
        (mixed-ee0_0.numpy())[0, sec(st):sec(se)],
        mixed[0, sec(st):sec(se)], 
        ee0_0.numpy()[0, sec(st):sec(se)],
     ])

In [None]:
px.line(tafn.detect_pitch_frequency(tch.tensor(mixed), 16000)[0])

In [None]:

# Low-pass ee0_0 at 300 Hz (overwrites the earlier synthetic `s`).
s = butterworth_lowpass(ee0_0, 16000, 300)

In [None]:
Audio(s)

In [None]:
noise

In [None]:
# Overlay the noise recording on the low-passed signal with equal weights.
mixed_high = mix(s, noise.numpy(), 1, 1)

In [None]:
Audio(mixed_high)

In [None]:
# Uses the `st` / `se` excerpt bounds set in an earlier cell.
px.line(s[0, sec(st):sec(se)])

In [None]:
Audio(mixed_high)

In [None]:
# Inspect a 1 s excerpt: added component, mixture, and low-passed signal.
st = 24
se = 25
px.line(y=[
        (mixed_high-s)[0, sec(st):sec(se)],
        mixed_high[0, sec(st):sec(se)], 
        s[0, sec(st):sec(se)],
     ])

In [None]:
# Spectrogram of the same excerpt; the nperseg keyword requests
# 8192-sample segments.
show_spectrum(s[0, sec(st): sec(se)], nperseg=8192)

In [None]:
show_spectrum(mixed_high[0, sec(st): sec(se)], nperseg=8192)

In [None]:
show_spectrum(mixed_high[0], nperseg=8192)

In [None]:
# NOTE(review): `ee` is not assigned anywhere in the visible notebook, so
# this cell and the next two fail on a fresh kernel; it may be a stale
# name for one of the ee* recordings (e.g. ee0_0) — confirm.
show_spectrum(ee[0])


In [None]:
# Noise weight 0: the mixture equals `ee` itself.
mixed = mix(ee, noise, 1, 0)
#show_spectrum(mixed[0])
px.line(tafn.detect_pitch_frequency(tch.tensor(mixed[0]), rate))

In [None]:
# Noise weight 0.5.
mixed = mix(ee, noise, 1, 0.5)
#show_spectrum(mixed[0])
px.line(tafn.detect_pitch_frequency(tch.tensor(mixed[0]), rate))

In [None]:
# Max-normalised average magnitude profile against frequency.
px.line(y=r[0]/r.max(), x=f)

In [None]:
# Natural log of the whole `r` array; the next cell plots only r[0].
px.line(y=np.log(r), x=f)

In [None]:
px.line(y=np.log(r[0]), x=f)