In [1]:
import sys

import pandas as pd

sys.path.insert(0, '..')

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import importlib
import scipy.signal as signal_scipy
import os

from scipy.ndimage import gaussian_filter1d

import superlets_package.morlet as morlet
import superlets_package.superlets as superlet


In [None]:
# Re-execute superlets_package.superlets so edits to its source are picked up without restarting the kernel.
importlib.reload(superlet)

In [None]:
# Re-execute superlets_package.morlet so edits to its source are picked up without restarting the kernel.
importlib.reload(morlet)

# Create simulated EMG signals

### Signal parameters

In [None]:
# Sampling setup shared by every cell below.
fs = 1024  # sampling frequency (Hz)
signal_duration = 3  # total signal length (s)
# endpoint=False gives exactly fs samples per second with spacing 1/fs, so t[k] == k / fs.
# With the default endpoint=True the spacing would be duration / (N - 1), slightly larger than
# 1/fs, and t would drift from the sample-index time axes used later (np.arange(n) / fs and the
# STFT frame times).
t = np.linspace(0, signal_duration, signal_duration * fs, endpoint=False)

### Baseline noise

In [None]:
# Reproducible baseline noise: zero-mean Gaussian with standard deviation 0.01,
# one sample for every entry of the time vector `t`.
np.random.seed(10)
noise = np.random.normal(loc=0, scale=0.01, size=len(t))


# Known-frequency contaminations

In [None]:
# Optional deterministic contamination added on top of the simulated EMG:
#   - a 50 Hz sine burst between 1.0 s and 1.5 s,
#   - a 100 Hz cosine burst between 0.0 s and 0.5 s.
# With `contamination = False` the base signal stays all zeros.
contamination = False
amplitud = 0.2  # amplitude of both contamination bursts

base_signal = np.zeros_like(t)
if contamination:
    # (start s, end s, frequency Hz, waveform) for each contamination burst.
    for start_s, end_s, freq_hz, waveform in ((1, 1.5, 50, np.sin), (0, 0.5, 100, np.cos)):
        mask = (t >= start_s) & (t < end_s)
        base_signal[mask] += amplitud * waveform(2 * np.pi * freq_hz * t[mask])

fig, ax = plt.subplots(figsize=(15, 2), dpi=300)
ax.set_xlabel("Time (s)")
ax.plot(t, base_signal)

### Burst parameters

In [None]:
# Ground-truth timing of the simulated burst (seconds).
burst_start = 1
burst_duration = 1
real_t_onset, real_t_offset = burst_start, burst_start + burst_duration

### De Luca method to create burst with a known frequency

In [None]:
# Generate a burst with a known mean frequency using the De Luca method.
# NOTE(review): the meaning of the positional arguments 120 and 60 is defined inside
# superlet.fdeluca — confirm there.
(filtro,
 PdeLuca,
 burst,
 MNF_ideal,
 MNF_analytic) = superlet.fdeluca(120, 60, fs, burst_duration, plot=False)
print(f'MNF ideal = {MNF_ideal}')
print(f'MNF analytic = {MNF_analytic}')

### Add noise to the signal with a particular value for SNR

In [None]:
# Cache file for the noise + burst signal. The results directory can be overridden with the
# SUPERLETS_RESULTS_DIR environment variable; the default keeps the original location.
RESULTS_DIR = os.environ.get(
    "SUPERLETS_RESULTS_DIR",
    "/Users/neuralrehabilitationgroup/PycharmProjects/Superlets-Marina/RESULTS",
)
filename = os.path.join(RESULTS_DIR, "signal_burst_1.npy")

In [None]:
# Force regeneration: delete any cached signal so the next cell rebuilds it from `noise` and
# `burst` (in a linear run this makes the next cell's load branch unreachable).
cache_present = os.path.exists(filename)
if cache_present:
    os.remove(filename)

In [None]:
# Load the cached noise + burst signal if present; otherwise build it and cache it.
if os.path.exists(filename):
    burst_signal = np.load(filename)
else:
    burst_signal = noise.copy()
    # Insert the De Luca burst starting at burst_start seconds.
    burst_signal[int(burst_start * fs):int(burst_start * fs + len(burst))] += burst
    # Save to the same path checked above (previously a duplicated literal path), creating the
    # directory if needed.
    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
    np.save(filename, burst_signal)

plt.plot(noise)
plt.figure()
plt.plot(burst_signal)

### Signal visualization

In [None]:
# Re-execute superlets_package.superlets so edits to its source are picked up without restarting the kernel.
importlib.reload(superlet)

In [None]:
# Welch PSD of the noise + burst signal. Its mean frequency (MNF) is the reference used by every
# MAE-in-frequency computation below. It is computed before `base_signal` and before the
# SNR-specific noise are added.
welch_output = superlet.compute_psd_welch(
    burst_signal, fs, plot=True, muscle='Simulated EMG signal with burst'
)
psd_welch, f_welch, MNF = welch_output
MNF_1 = [MNF]

In [None]:
# Add the optional contamination in place (all zeros when `contamination` is False).
# NOTE: not idempotent — re-running this cell adds `base_signal` again.
burst_signal[:] = burst_signal + base_signal

In [None]:
# Plot the signal (noise + burst + optional contamination) against time in seconds.
# The axis uses a sample spacing of exactly 1/fs, matching the time vectors built by the analysis
# cells; jnp.linspace(0, n / fs, n) stretched it slightly and needlessly went through JAX.
fig, ax = plt.subplots(figsize=(15, 2), dpi=300)
ax.set_xlabel("Time (s)")
ax.plot(np.arange(len(burst_signal)) / fs, burst_signal)

In [None]:
# Reference mean frequency (Hz) from the Welch PSD, stored as a one-element list.
MNF_1

### Frequency vector definition

In [None]:
# Frequency grid shared by the wavelet and superlet transforms:
# int(fs / 2) evenly spaced values from 20 Hz to 500 Hz, both ends included.
freqs = np.linspace(20, 500, int(fs / 2))
freq_min, freq_max = freqs[0], freqs[-1]  # lowest / highest analysed frequency (Hz)
num_freqs = freqs.size  # number of frequency bins (a count, not a resolution in Hz)

In [None]:
def calcular_snr(signal, noise):
    """Return the signal-to-noise ratio in decibels.

    SNR_dB = 10 * log10(mean(signal ** 2) / mean(noise ** 2)), i.e. the ratio of the mean
    powers of the two arrays.
    """
    power_ratio = np.mean(signal ** 2) / np.mean(noise ** 2)
    return 10 * np.log10(power_ratio)

In [None]:
# Add white noise for the target SNR (dB) and check the SNR actually achieved.
SNR = 5
signal_1, xn = superlet.add_noise(burst_signal, SNR, fs, plot=True)
# The achieved SNR compares the clean signal with the added noise. Passing the noisy signal
# (signal_1, which already contains xn) inflated the estimate, e.g. about 6.2 dB for a 5 dB
# target if signal and noise are uncorrelated.
# NOTE(review): assumes xn is the noise add_noise added to burst_signal — confirm.
snr_real = calcular_snr(burst_signal, xn)

In [None]:
# Welch PSD of the noisy signal at the chosen SNR (overwrites psd_welch, f_welch and MNF).
psd_welch, f_welch, MNF = superlet.compute_psd_welch(
    signal_1,
    fs,
    plot=True,
    muscle='Simulated EMG signal with burst',
)

In [None]:
# SNR (dB) computed by calcular_snr in the previous noise cell.
snr_real

In [None]:
# Plot the noisy signal against time in seconds.
# Fixes: the axis label said "ms" while the data are in seconds, and the same trace was drawn a
# second time with plt.plot on the same axes.
fig, ax = plt.subplots(figsize=(15, 2), dpi=300)
ax.set_xlabel("Time (s)")
ax.plot(np.arange(len(signal_1)) / fs, signal_1)
print(f"Reference mean frequency: {MNF_1[0]} Hz")

# WAVELET PARAMETERS = 3, 16, 33, 55, 60 cycles

In [None]:
# Wavelet settings to compare; each value c is inserted into the wavelet name f'cmor{c}-1.0'
# used by the wavelet cells below.
params_wavelet = [
    3,
    16,
    33,
    55,
    60,
]

In [None]:
# Figure switches for the analysis cells below.
plot_scalogram: bool = True          # time-frequency maps, one panel per parameter set
plot_imnf: bool = False              # instantaneous mean frequency, one panel per parameter set
plot_response: bool = True           # overlaid instantaneous-frequency and energy curves
plot_means: bool = False             # not referenced in the cells shown here
plot_time_estimation: bool = False   # not referenced in the cells shown here

In [None]:
importlib.reload(superlet)

# Complex-Morlet wavelet analysis of signal_1, one run per value in params_wavelet.
# For each setting: power scalogram -> instantaneous mean frequency (power-weighted centroid over
# frequency) -> burst onset/offset from the energy plateau -> MAE against the ground truth.
if plot_scalogram:
    fig, ax = plt.subplots(ncols=len(params_wavelet), figsize=(18, 5), dpi=300, sharey=True, sharex=True)
if plot_imnf:
    fig_2, ax_2 = plt.subplots(ncols=len(params_wavelet), figsize=(18, 5), dpi=300)
if plot_response:
    fig_3, (ax_3_1, ax_3_2) = plt.subplots(1, 2, figsize=(18, 5))

# Smoothing passed to superlet.find_plateau_region; later cells reuse this global.
smooth_sigma = 20

# Sample range of the true burst (loop-invariant; later cells also read these globals).
start_idx = np.searchsorted(t, burst_start)
end_idx = np.searchsorted(t, burst_start + burst_duration)

mae_wavelet_f = []
std_wavelet_f = []

mae_wavelet_t = []
std_wavelet_t = []

total_time = []
total_freq = []

total_scalogram_wavelet = []

for i, c in enumerate(params_wavelet):
    wavelet = f'cmor{c}-1.0'

    cwtmatr, f, physical_freqs = morlet.wavelet_transform_2(signal_1, wavelet, freqs, fs)
    # Power of the coefficients, computed in NumPy. The previous extra round trip through jnp
    # could silently downcast to float32 (JAX's default unless 64-bit mode is enabled).
    cwtmatr = np.abs(cwtmatr)
    scalogram_2 = cwtmatr ** 2

    if plot_scalogram:
        pcm = ax[i].pcolormesh(t, f, scalogram_2, shading='gouraud', cmap='jet')
        ax[i].set_title(f'c={c}')
        # Only the leftmost panel carries the shared y label.
        ax[i].set_ylabel('Frequency (Hz)' if i == 0 else '')
        ax[i].set_xlabel('Time (s)')
        ax[i].set_ylim(f[0], f[-1])

    total_scalogram_wavelet.append(scalogram_2)

    # Instantaneous mean frequency: power-weighted average of f at each time sample.
    instant_freq = np.sum(scalogram_2 * f[:, np.newaxis], axis=0) / np.sum(scalogram_2, axis=0)

    # Total power per time sample; its plateau is used to locate the burst.
    energia_temporal = np.sum(scalogram_2, axis=0)

    first_index, last_index = superlet.find_plateau_region(energia_temporal, burst_start, burst_duration,
                                                           t, fs, smooth_sigma=smooth_sigma)
    t_onset = t[first_index]
    t_offset = t[last_index]

    instant_mean_freq_burst = instant_freq[start_idx:end_idx]

    if plot_imnf:
        ax_2[i].plot(t, instant_freq)
        ax_2[i].set_xlabel("Time (s)")
        ax_2[i].set_ylabel("Frequency (Hz)")
        # Title this panel explicitly; plt.title acted on whichever axes happened to be current.
        ax_2[i].set_title("Instantaneous frequency")
        ax_2[i].grid(True)

    if plot_response:
        ax_3_1.plot(t, instant_freq, label=f"c={c}")
        ax_3_2.plot(t, energia_temporal, label=f"c={c}")

    # MAE in frequency
    print(f'Estimated mean frequency: {np.mean(instant_mean_freq_burst)} Hz')
    mae_f, std_f = superlet.calculate_mae(instant_mean_freq_burst, MNF_1)
    mae_wavelet_f.append(mae_f)
    std_wavelet_f.append(std_f)

    # MAE in time
    print(f'Estimated onset = {t_onset} s, offset = {t_offset} s')
    mae_t, std_t = superlet.calculate_mae((t_onset, t_offset), (real_t_onset, real_t_offset))
    mae_wavelet_t.append(mae_t)
    std_wavelet_t.append(std_t)

    time = np.arange(scalogram_2.shape[1]) / fs
    total_time.append(time)
    total_freq.append(freqs)

if plot_response:
    # Decorate the overlay axes once, after all curves are drawn.
    for axis, ylabel, title in ((ax_3_1, "Frequency (Hz)", "Instantaneous frequency"),
                                (ax_3_2, "Energy (V²)", "Energy over time")):
        axis.set_xlabel("Time (s)")
        axis.set_ylabel(ylabel)
        axis.legend()
        axis.grid(True)
        axis.set_title(title)

res_wavelet = superlet.compute_avg_response_resolution(total_scalogram_wavelet, total_time, total_freq, params_wavelet)

if plot_scalogram:
    cbar = fig.colorbar(pcm, ax=ax, orientation='horizontal', pad=0.15, shrink=0.1)
    cbar.set_label('Power (V²)', rotation=0, labelpad=15)

plt.tight_layout()
plt.show()
    

In [None]:
# MAE in frequency (Hz) per wavelet setting, in params_wavelet order.
mae_wavelet_f

In [None]:
# MAE in onset/offset time (s) per wavelet setting, in params_wavelet order.
mae_wavelet_t

## SUPERLET PARAMETERS: base_cycle, min_order, max_order = [1, 1, 3, 5, 5], [3, 5, 1, 1, 1], [30, 40, 40, 50, 60]

In [None]:
# Superlet settings, one column per configuration (the cells below iterate zip(*params_superlet)).
params_superlet = [
    [1, 1, 3, 5, 5],       # base_cycle
    [3, 5, 1, 1, 1],       # min_order
    [30, 40, 40, 50, 60],  # max_order
]

In [None]:
importlib.reload(superlet)

# Adaptive superlet transform (mode "mul") of signal_1 for each configuration in params_superlet,
# scored like the wavelet cell: instantaneous mean frequency vs the reference MNF and
# plateau-based onset/offset vs the true burst limits.
superlet_settings = list(zip(*params_superlet))  # [(base_cycle, min_order, max_order), ...]

if plot_scalogram:
    fig, ax = plt.subplots(ncols=len(superlet_settings), figsize=(18, 5), dpi=300, sharey=True, sharex=True)
if plot_imnf:
    fig_2, ax_2 = plt.subplots(ncols=len(superlet_settings), figsize=(18, 5), dpi=300)
if plot_response:
    fig_3, (ax_3_1, ax_3_2) = plt.subplots(1, 2, figsize=(18, 5))

# Sample range of the true burst, recomputed here instead of relying on the wavelet cell's loop.
start_idx = np.searchsorted(t, burst_start)
end_idx = np.searchsorted(t, burst_start + burst_duration)

mae_superlet_f = []
std_superlet_f = []

mae_superlet_t = []
std_superlet_t = []

total_time = []
total_freq = []

total_scalogram_superlet = []

for i, (base_cycle, min_order, max_order) in enumerate(superlet_settings):
    setting_label = f"$c_1$: {base_cycle}, o: {min_order}-{max_order}"

    wv, scalogram = superlet.adaptive_superlet_transform(signal_1, freqs, sampling_freq=fs,
                                                         base_cycle=base_cycle, min_order=min_order,
                                                         max_order=max_order, mode="mul")
    # Power (squared magnitude), computed once and reused for plotting and analysis.
    scalogram_2 = np.abs(scalogram) ** 2

    if plot_scalogram:
        im = ax[i].imshow(scalogram_2, aspect='auto', cmap="jet", interpolation="none", origin="lower",
                          extent=[0, len(signal_1) / fs, freqs[0], freqs[-1]])
        ax[i].set_title(setting_label)
        ax[i].set_xlabel("Time (s)")
        ax[i].set_ylabel("Frequency (Hz)" if i == 0 else "")
        ax[i].set_ylim(freqs[0], freqs[-1])

    total_scalogram_superlet.append(scalogram_2)

    # Instantaneous mean frequency: power-weighted average of freqs at each time sample.
    instant_freq = np.sum(scalogram_2 * freqs[:, np.newaxis], axis=0) / np.sum(scalogram_2, axis=0)

    # Total power per time sample; its plateau is used to locate the burst.
    energia_temporal = np.sum(scalogram_2, axis=0)

    # smooth_sigma is defined by the wavelet cell above.
    first_index, last_index = superlet.find_plateau_region(energia_temporal, burst_start, burst_duration,
                                                           t, fs, smooth_sigma=smooth_sigma)
    t_onset = t[first_index]
    t_offset = t[last_index]

    instant_mean_freq_burst = instant_freq[start_idx:end_idx]

    if plot_imnf:
        ax_2[i].plot(t, instant_freq)
        ax_2[i].set_xlabel("Time (s)")
        ax_2[i].set_ylabel("Frequency (Hz)")
        # Title this panel explicitly; plt.title acted on whichever axes happened to be current.
        ax_2[i].set_title("Instantaneous frequency")
        ax_2[i].grid(True)

    if plot_response:
        ax_3_1.plot(t, instant_freq, label=setting_label)
        ax_3_2.plot(t, energia_temporal, label=setting_label)

    # MAE in frequency
    print(f'Estimated mean frequency: {np.mean(instant_mean_freq_burst)} Hz')
    mae_f, std_f = superlet.calculate_mae(instant_mean_freq_burst, MNF_1)
    mae_superlet_f.append(mae_f)
    std_superlet_f.append(std_f)

    # MAE in time
    print(f'Estimated onset = {t_onset} s, offset = {t_offset} s')
    mae_t, std_t = superlet.calculate_mae((t_onset, t_offset), (real_t_onset, real_t_offset))
    mae_superlet_t.append(mae_t)
    std_superlet_t.append(std_t)

    # Per-setting diagnostic: instantaneous frequency with the estimated onset/offset.
    plt.figure(figsize=(12, 6))
    plt.subplot(2, 1, 1)
    plt.plot(t, instant_freq)
    plt.axvline(t_onset, color='green', linestyle='--', label='t_onset')
    plt.axvline(t_offset, color='red', linestyle='--', label='t_offset')
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    plt.legend()
    plt.grid()

    # Exactly one time axis and one frequency axis per scalogram. The previous version appended
    # twice per iteration, misaligning these lists with total_scalogram_superlet.
    time = np.arange(scalogram_2.shape[1]) / fs
    total_time.append(time)
    total_freq.append(freqs)

if plot_response:
    # Decorate the overlay axes once, after all curves are drawn.
    for axis, ylabel, title in ((ax_3_1, "Frequency (Hz)", "Instantaneous frequency"),
                                (ax_3_2, "Energy (V²)", "Energy over time")):
        axis.set_xlabel("Time (s)")
        axis.set_ylabel(ylabel)
        axis.legend()
        axis.grid(True)
        axis.set_title(title)

res_superlet = superlet.compute_avg_response_resolution(
    total_scalogram_superlet, total_time, total_freq,
    # Label each configuration with its own base cycle and order range. The previous
    # comprehension shifted the fields: c_1 showed max_order and o showed base_cycle-min_order.
    [f'$c_1$={c1}, o:{c2}-{c3}' for c1, c2, c3 in superlet_settings])

if plot_scalogram:
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', pad=0.15, shrink=0.1)
    cbar.set_label('Power (V²)', rotation=0, labelpad=15)

In [None]:
# MAE in frequency (Hz) per superlet configuration, in params_superlet column order.
mae_superlet_f

In [None]:
# MAE in onset/offset time (s) per superlet configuration, in params_superlet column order.
mae_superlet_t

## STFT PARAMETERS (window = duration/cycles): window lengths = 38, 200, 413, 550, 600 samples (≈ ms at fs = 1024 Hz)

In [None]:
# Offsets, both zero (not referenced by the cells shown here).
offset_burst = offset = 0

In [None]:
# Re-execute superlets_package.superlets so edits to its source are picked up without restarting the kernel.
importlib.reload(superlet)

In [None]:
from scipy.signal import stft
from scipy.signal import windows

# STFT analysis of signal_1. Each entry is a window length in SAMPLES (passed as nperseg); titles
# and labels convert it to ms with fs. The list name is kept because later cells use it.
windows_ms = [38, 200, 413, 550, 600]

if plot_scalogram:
    fig, ax = plt.subplots(ncols=len(windows_ms), figsize=(18, 5), dpi=300, sharey=True, sharex=True)
if plot_imnf:
    fig_2, ax_2 = plt.subplots(ncols=len(windows_ms), figsize=(18, 5), dpi=300)
if plot_response:
    fig_3, (ax_3_1, ax_3_2) = plt.subplots(1, 2, figsize=(18, 5))

mae_stft_f = []
std_stft_f = []

mae_stft_t = []
std_stft_t = []

total_time = []
total_freq = []

total_scalogram_stft = []

# FFT length of 2 * len(freqs): nfft // 2 + 1 one-sided bins from 0 Hz to fs / 2.
nfft = 2 * len(freqs)

for i, w in enumerate(windows_ms):
    window_label = f'W={w * 1000 // fs} ms'

    # Blackman window of w samples with a one-sample hop (noverlap = w - 1).
    f, t_stft, Zxx = stft(signal_1, fs=fs, window='blackman', nperseg=w, noverlap=w - 1, nfft=nfft,
                          detrend=False, return_onesided=True, boundary='zeros', padded=True,
                          axis=-1, scaling='spectrum')

    # Power (squared magnitude) of the STFT coefficients.
    Zxx_power = np.abs(Zxx) ** 2

    if plot_scalogram:
        pcm = ax[i].pcolormesh(t_stft, f, Zxx_power, shading='gouraud', cmap='jet')
        ax[i].set_title(window_label)
        ax[i].set_ylabel('Frecuencia [Hz]' if i == 0 else '')
        ax[i].set_xlabel('Time [s]')
        ax[i].set_ylim(f[0], f[-1])

    scalogram_2 = Zxx_power
    total_scalogram_stft.append(scalogram_2)

    # Instantaneous mean frequency: power-weighted average of f in each frame.
    # NOTE(review): f spans 0 Hz to fs / 2 here, wider than the 20-500 Hz `freqs` grid of the
    # wavelet and superlet cells — confirm this is the intended comparison.
    instant_freq = np.sum(scalogram_2 * f[:, np.newaxis], axis=0) / np.sum(scalogram_2, axis=0)

    # Total power per frame; its plateau is used to locate the burst.
    energia_temporal = np.sum(scalogram_2, axis=0)

    # Work on the STFT frame times. With boundary='zeros' the number of frames need not equal
    # len(t) (an even window yields one extra frame), so frame indices must not index `t`.
    first_index, last_index = superlet.find_plateau_region(energia_temporal, burst_start, burst_duration,
                                                           t_stft, fs, smooth_sigma=smooth_sigma)
    t_onset = t_stft[first_index]
    t_offset = t_stft[last_index]

    stft_start_idx = np.searchsorted(t_stft, burst_start)
    stft_end_idx = np.searchsorted(t_stft, burst_start + burst_duration)
    instant_mean_freq_burst = instant_freq[stft_start_idx:stft_end_idx]

    if plot_imnf:
        ax_2[i].plot(t_stft, instant_freq)
        ax_2[i].set_xlabel("Time (s)")
        ax_2[i].set_ylabel("Frequency (Hz)")
        # Title this panel explicitly; plt.title acted on whichever axes happened to be current.
        ax_2[i].set_title("Instantaneous frequency")
        ax_2[i].grid(True)

    if plot_response:
        ax_3_1.plot(t_stft, instant_freq, label=window_label)
        ax_3_2.plot(t_stft, energia_temporal, label=window_label)

    # MAE in frequency
    print(f'Estimated mean frequency: {np.mean(instant_mean_freq_burst)} Hz')
    mae_f, std_f = superlet.calculate_mae(instant_mean_freq_burst, MNF_1)
    mae_stft_f.append(mae_f)
    std_stft_f.append(std_f)

    # MAE in time
    print(f'Estimated onset = {t_onset} s, offset = {t_offset} s')
    mae_t, std_t = superlet.calculate_mae((t_onset, t_offset), (real_t_onset, real_t_offset))
    mae_stft_t.append(mae_t)
    std_stft_t.append(std_t)

    # Per-setting diagnostic: instantaneous frequency with the estimated onset/offset.
    plt.figure(figsize=(12, 6))
    plt.subplot(2, 1, 1)
    plt.plot(t_stft, instant_freq)
    plt.axvline(t_onset, color='green', linestyle='--', label='t_onset')
    plt.axvline(t_offset, color='red', linestyle='--', label='t_offset')
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    plt.legend()
    plt.grid()
    plt.tight_layout()

    total_time.append(np.linspace(t_stft[0], t_stft[-1], scalogram_2.shape[1]))
    total_freq.append(np.linspace(f[0], f[-1], scalogram_2.shape[0]))

if plot_response:
    # Decorate the overlay axes once, after all curves are drawn.
    for axis, ylabel, title in ((ax_3_1, "Frequency (Hz)", "Instantaneous frequency"),
                                (ax_3_2, "Energy (V²)", "Energy over time")):
        axis.set_xlabel("Time (s)")
        axis.set_ylabel(ylabel)
        axis.legend()
        axis.grid(True)
        axis.set_title(title)

res_stft = superlet.compute_avg_response_resolution(total_scalogram_stft, total_time, total_freq, windows_ms, stft=[freqs[0], freqs[-1]])

if plot_scalogram:
    cbar = fig.colorbar(pcm, ax=ax, orientation='horizontal', pad=0.15, shrink=0.1)
    cbar.set_label('Power (V²)', rotation=0, labelpad=15)

In [None]:
# MAE in frequency (Hz) per STFT window length, in windows_ms order.
mae_stft_f

In [None]:
# MAE in onset/offset time (s) per STFT window length, in windows_ms order.
mae_stft_t

# Comparison of the three methods at the same SNR

## MAE IN FREQUENCY

In [None]:
params_stft = windows_ms

# Grouped bars: for each parameter set, the STFT, wavelet and superlet MAE in frequency with
# their std as error bars.
num_conditions = len(mae_wavelet_f)
x = np.arange(num_conditions)  # one group of bars per parameter set
bar_width = 0.25

plt.figure(figsize=(14, 7))

for shift, maes, stds, method, color in (
    (-bar_width, mae_stft_f, std_stft_f, 'STFT', 'skyblue'),
    (0.0, mae_wavelet_f, std_wavelet_f, 'Wavelet', 'lightgreen'),
    (bar_width, mae_superlet_f, std_superlet_f, 'Superlet', 'salmon'),
):
    plt.bar(x + shift, maes, width=bar_width, yerr=stds, capsize=5, label=method, color=color, alpha=0.7)

# STFT windows are lengths in samples; convert them to ms with fs, as the STFT panel titles do
# (previously the raw sample count was printed with an "ms" unit).
labels = [
    f"W={p_stft * 1000 // fs} ms\n\nc={p_w}\n\nc$_1$={p_s[0]}, o: {p_s[1]}-{p_s[2]}"
    for p_w, p_s, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]

plt.xticks(x, labels, rotation=90, ha='center', fontsize=18)
plt.yticks(np.arange(0, 101, 10))

plt.ylabel('MAE (Hz)')
plt.title('MAE in FREQUENCY (SNR = ' + str(SNR) + ')', fontsize=18)
plt.ylim(0, 100)
plt.legend(fontsize=14)
plt.grid(True)

plt.tight_layout()
plt.show()

## MAE IN TIME

In [None]:
params_stft = windows_ms

# Grouped bars: for each parameter set, the STFT, wavelet and superlet MAE in onset/offset time
# with their std as error bars.
num_conditions = len(mae_wavelet_t)
x = np.arange(num_conditions)  # one group of bars per parameter set
bar_width = 0.25

plt.figure(figsize=(14, 7))

for shift, maes, stds, method, color in (
    (-bar_width, mae_stft_t, std_stft_t, 'STFT', 'skyblue'),
    (0.0, mae_wavelet_t, std_wavelet_t, 'Wavelet', 'lightgreen'),
    (bar_width, mae_superlet_t, std_superlet_t, 'Superlet', 'salmon'),
):
    plt.bar(x + shift, maes, width=bar_width, yerr=stds, capsize=5, label=method, color=color, alpha=0.7)

# STFT windows are lengths in samples; convert them to ms with fs, as the STFT panel titles do
# (previously the raw sample count was printed with an "ms" unit).
labels = [
    f"W={p_stft * 1000 // fs} ms\n\nc={p_w}\n\nc$_1$={p_s[0]}, o: {p_s[1]}-{p_s[2]}"
    for p_w, p_s, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]

plt.xticks(x, labels, rotation=90, ha='center', fontsize=18)
plt.yticks(np.arange(0, 0.21, 0.1))

plt.ylabel('MAE (s)')
plt.title('MAE in TIME (SNR = ' + str(SNR) + ')', fontsize=18)
plt.ylim(0, 0.2)
plt.legend(fontsize=14)
plt.grid(True)

plt.tight_layout()
plt.show()

# Compute resolution

In [None]:
# Re-execute superlets_package.superlets so edits to its source are picked up without restarting the kernel.
importlib.reload(superlet)

In [None]:
# Kept for the optional reference line, which is not currently drawn.
rayleigh_limit = 1 / (4 * np.pi)

# One tick per parameter set. STFT windows are lengths in samples, converted to ms with fs as in
# the STFT panel titles.
labels = [
    f"W={p_stft * 1000 // fs} ms\nc={p_wavelet}\nc$_1$={p_superlet[0]}, o={p_superlet[1]}-{p_superlet[2]}"
    for p_wavelet, p_superlet, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]

x = np.arange(1, len(params_wavelet) + 1)

# Explicit axes, so the global `ax` used by other cells is not overwritten by a tick loop.
fig_res, (ax_freq, ax_time) = plt.subplots(2, 1, figsize=(12, 8))

# Entries of res_* are plotted as r[1] = frequency resolution and r[0] = time resolution.
for res, line_style, method in ((res_wavelet, 'go-', 'Wavelet'),
                                (res_stft, 'bo-', 'STFT'),
                                (res_superlet, 'ro-', 'Superlet')):
    ax_freq.plot(x, [r[1] for r in res], line_style, label=f"{method} (Frequency)")
    ax_time.plot(x, [r[0] for r in res], line_style, label=f"{method} (Time)")

ax_freq.set_xlabel("Window Size / Cycles / Order")
ax_freq.set_ylabel("Frequency Resolution (Hz)")
ax_freq.legend()

ax_time.set_xlabel("Window Size / Cycles / Order")
ax_time.set_ylabel("Time Resolution (s)")
ax_time.legend()

for axis in (ax_freq, ax_time):
    axis.set_xticks(x)
    axis.set_xticklabels(labels, rotation=90, ha='center', fontsize=10)

fig_res.tight_layout()
plt.show()

# Result accumulators — run the initialization cells only once (re-running them discards the stored results)

In [None]:
# Accumulators for MAE in frequency across SNR runs: one list per method plus matching std lists.
MAES_f = {
    key: []
    for key in ('stft', 'wavelet', 'superlet', 'std_stft', 'std_wavelet', 'std_superlet')
}

In [None]:
# Store this run's MAE-in-frequency results (one entry appended per key).
for key, values in (
    ('wavelet', mae_wavelet_f),
    ('stft', mae_stft_f),
    ('superlet', mae_superlet_f),
    ('std_wavelet', std_wavelet_f),
    ('std_stft', std_stft_f),
    ('std_superlet', std_superlet_f),
):
    MAES_f[key].append(values)

In [None]:
# Accumulated MAE-in-frequency results (one appended entry per stored run).
MAES_f

In [None]:
# Accumulators for MAE in time across SNR runs: one list per method plus matching std lists.
MAES_t = {
    key: []
    for key in ('stft', 'wavelet', 'superlet', 'std_stft', 'std_wavelet', 'std_superlet')
}

In [None]:
# Store this run's MAE-in-time results (one entry appended per key).
for key, values in (
    ('wavelet', mae_wavelet_t),
    ('stft', mae_stft_t),
    ('superlet', mae_superlet_t),
    ('std_wavelet', std_wavelet_t),
    ('std_stft', std_stft_t),
    ('std_superlet', std_superlet_t),
):
    MAES_t[key].append(values)

In [None]:
# Accumulated MAE-in-time results (one appended entry per stored run).
MAES_t

# Repeat the whole analysis with a different SNR

In [None]:
# Rebuild noise + burst + optional contamination for the second SNR run.
if os.path.exists(filename):
    burst_signal = np.load(filename) + base_signal
else:
    # Same construction as the first run, so both branches yield the same signal. Previously this
    # branch used a constant 1e-5 floor instead of `noise` and skipped `base_signal`.
    burst_signal = noise.copy()
    burst_signal[int(burst_start * fs):int(burst_start * fs + len(burst))] += burst
    burst_signal = burst_signal + base_signal

plt.plot(noise)
plt.figure()
plt.plot(burst_signal)

In [None]:
# Add white noise for the second target SNR (dB) and check the SNR actually achieved.
SNR = 15
signal_2, xn = superlet.add_noise(burst_signal, SNR, fs, plot=True)
# The achieved SNR compares the clean signal with the added noise. Passing the noisy signal
# (signal_2, which already contains xn) inflated the estimate.
# NOTE(review): assumes xn is the noise add_noise added to burst_signal — confirm.
snr_real = calcular_snr(burst_signal, xn)

In [None]:
# Welch PSD of the second noisy signal (overwrites psd_welch, f_welch and MNF).
psd_welch, f_welch, MNF = superlet.compute_psd_welch(
    signal_2,
    fs,
    plot=True,
    muscle='Simulated EMG signal with burst',
)

In [None]:
# SNR (dB) computed by calcular_snr for the second noisy signal.
snr_real

In [None]:
# Overlay both noisy signals on a shared time axis in seconds, with a legend so the two runs can
# be told apart (previously both were plotted against sample index without labels).
fig_cmp, ax_cmp = plt.subplots()
ax_cmp.plot(np.arange(len(signal_1)) / fs, signal_1, label="signal_1")
ax_cmp.plot(np.arange(len(signal_2)) / fs, signal_2, label="signal_2", alpha=0.7)
ax_cmp.set_xlabel("Time (s)")
ax_cmp.legend()
plt.show()

In [None]:
# Plot the second noisy signal against time in seconds. The same trace was previously drawn a
# second time with plt.plot on the same axes.
fig, ax = plt.subplots(figsize=(15, 2), dpi=300)
ax.set_xlabel("Time (s)")
ax.plot(np.arange(len(signal_2)) / fs, signal_2)
print(f"Reference mean frequency: {MNF_1[0]} Hz")

# WAVELET PARAMETERS = 3, 16, 33, 55, 60 cycles

In [None]:
# Wavelet analysis repeated on signal_2 (second SNR), scored exactly as for signal_1.
if plot_scalogram:
    fig, ax = plt.subplots(ncols=len(params_wavelet), figsize=(18, 5), dpi=300, sharey=True, sharex=True)
if plot_imnf:
    fig_2, ax_2 = plt.subplots(ncols=len(params_wavelet), figsize=(18, 5), dpi=300)
if plot_response:
    fig_3, (ax_3_1, ax_3_2) = plt.subplots(1, 2, figsize=(18, 5))

# Sample range of the true burst, recomputed here instead of relying on the first wavelet cell.
start_idx = np.searchsorted(t, burst_start)
end_idx = np.searchsorted(t, burst_start + burst_duration)

mae_wavelet_f = []
std_wavelet_f = []

mae_wavelet_t = []
std_wavelet_t = []

total_time = []
total_freq = []

total_scalogram_wavelet = []

for i, c in enumerate(params_wavelet):
    wavelet = f'cmor{c}-1.0'

    cwtmatr, f, physical_freqs = morlet.wavelet_transform_2(signal_2, wavelet, freqs, fs)
    # Power of the coefficients, computed in NumPy. The previous extra round trip through jnp
    # could silently downcast to float32 (JAX's default unless 64-bit mode is enabled).
    cwtmatr = np.abs(cwtmatr)
    scalogram_2 = cwtmatr ** 2

    if plot_scalogram:
        pcm = ax[i].pcolormesh(t, f, scalogram_2, shading='gouraud', cmap='jet')
        ax[i].set_title(f'c={c}')
        # Only the leftmost panel carries the shared y label.
        ax[i].set_ylabel('Frequency (Hz)' if i == 0 else '')
        ax[i].set_xlabel('Time (s)')
        ax[i].set_ylim(f[0], f[-1])

    total_scalogram_wavelet.append(scalogram_2)

    # Instantaneous mean frequency: power-weighted average of f at each time sample.
    instant_freq = np.sum(scalogram_2 * f[:, np.newaxis], axis=0) / np.sum(scalogram_2, axis=0)

    # Total power per time sample; its plateau is used to locate the burst.
    energia_temporal = np.sum(scalogram_2, axis=0)

    # smooth_sigma is defined by the first wavelet cell.
    first_index, last_index = superlet.find_plateau_region(energia_temporal, burst_start, burst_duration,
                                                           t, fs, smooth_sigma=smooth_sigma)
    t_onset = t[first_index]
    t_offset = t[last_index]

    instant_mean_freq_burst = instant_freq[start_idx:end_idx]

    if plot_imnf:
        ax_2[i].plot(t, instant_freq)
        ax_2[i].set_xlabel("Time (s)")
        ax_2[i].set_ylabel("Frequency (Hz)")
        # Title this panel explicitly; plt.title acted on whichever axes happened to be current.
        ax_2[i].set_title("Instantaneous frequency")
        ax_2[i].grid(True)

    if plot_response:
        ax_3_1.plot(t, instant_freq, label=f"c={c}")
        ax_3_2.plot(t, energia_temporal, label=f"c={c}")

    # MAE in frequency
    print(f'Estimated mean frequency: {np.mean(instant_mean_freq_burst)} Hz')
    mae_f, std_f = superlet.calculate_mae(instant_mean_freq_burst, MNF_1)
    mae_wavelet_f.append(mae_f)
    std_wavelet_f.append(std_f)

    # MAE in time
    print(f'Estimated onset = {t_onset} s, offset = {t_offset} s')
    mae_t, std_t = superlet.calculate_mae((t_onset, t_offset), (real_t_onset, real_t_offset))
    mae_wavelet_t.append(mae_t)
    std_wavelet_t.append(std_t)

    # Per-setting diagnostic: instantaneous frequency with the estimated onset/offset.
    plt.figure(figsize=(8, 5))
    plt.subplot(2, 1, 1)
    plt.plot(t, instant_freq)
    plt.axvline(t_onset, color='green', linestyle='--', label='t_onset')
    plt.axvline(t_offset, color='red', linestyle='--', label='t_offset')
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    plt.legend()
    plt.grid()
    plt.tight_layout()

    time = np.arange(scalogram_2.shape[1]) / fs
    total_time.append(time)
    total_freq.append(freqs)

if plot_response:
    # Decorate the overlay axes once, after all curves are drawn.
    for axis, ylabel, title in ((ax_3_1, "Frequency (Hz)", "Instantaneous frequency"),
                                (ax_3_2, "Energy (V²)", "Energy over time")):
        axis.set_xlabel("Time (s)")
        axis.set_ylabel(ylabel)
        axis.legend()
        axis.grid(True)
        axis.set_title(title)

res_wavelet = superlet.compute_avg_response_resolution(total_scalogram_wavelet, total_time, total_freq, params_wavelet)

if plot_scalogram:
    cbar = fig.colorbar(pcm, ax=ax, orientation='horizontal', pad=0.15, shrink=0.1)
    cbar.set_label('Power (V²)', rotation=0, labelpad=15)

In [None]:
# MAE in frequency (Hz) per wavelet setting for signal_2, in params_wavelet order.
mae_wavelet_f

In [None]:
# MAE in onset/offset time (s) per wavelet setting for signal_2, in params_wavelet order.
mae_wavelet_t

## SUPERLET PARAMETERS: base_cycle, min_order, max_order = [1, 1, 3, 5, 5], [3, 5, 1, 1, 1], [30, 40, 40, 50, 60]

In [None]:
# Superlet analysis repeated on signal_2 (second SNR), scored exactly as for signal_1.
superlet_settings = list(zip(*params_superlet))  # [(base_cycle, min_order, max_order), ...]

if plot_scalogram:
    fig, ax = plt.subplots(ncols=len(superlet_settings), figsize=(18, 5), dpi=300, sharey=True, sharex=True)
if plot_imnf:
    fig_2, ax_2 = plt.subplots(ncols=len(superlet_settings), figsize=(18, 5), dpi=300)
if plot_response:
    fig_3, (ax_3_1, ax_3_2) = plt.subplots(1, 2, figsize=(18, 5))

# Sample range of the true burst, recomputed here instead of relying on an earlier cell's loop.
start_idx = np.searchsorted(t, burst_start)
end_idx = np.searchsorted(t, burst_start + burst_duration)

mae_superlet_f = []
std_superlet_f = []

mae_superlet_t = []
std_superlet_t = []

total_time = []
total_freq = []

total_scalogram_superlet = []

for i, (base_cycle, min_order, max_order) in enumerate(superlet_settings):
    setting_label = f"$c_1$: {base_cycle}, o: {min_order}-{max_order}"

    wv, scalogram = superlet.adaptive_superlet_transform(signal_2, freqs, sampling_freq=fs,
                                                         base_cycle=base_cycle, min_order=min_order,
                                                         max_order=max_order, mode="mul")
    # Power (squared magnitude), computed once and reused for plotting and analysis.
    scalogram_2 = np.abs(scalogram) ** 2

    if plot_scalogram:
        # Extent built from signal_2, the signal actually transformed here.
        im = ax[i].imshow(scalogram_2, aspect='auto', cmap="jet", interpolation="none", origin="lower",
                          extent=[0, len(signal_2) / fs, freqs[0], freqs[-1]])
        ax[i].set_title(setting_label)
        ax[i].set_xlabel("Time (s)")
        ax[i].set_ylabel("Frequency (Hz)" if i == 0 else "")
        ax[i].set_ylim(freqs[0], freqs[-1])

    total_scalogram_superlet.append(scalogram_2)

    # Instantaneous mean frequency: power-weighted average of freqs at each time sample.
    instant_freq = np.sum(scalogram_2 * freqs[:, np.newaxis], axis=0) / np.sum(scalogram_2, axis=0)

    # Total power per time sample; its plateau is used to locate the burst.
    energia_temporal = np.sum(scalogram_2, axis=0)

    # smooth_sigma is defined by the first wavelet cell.
    first_index, last_index = superlet.find_plateau_region(energia_temporal, burst_start, burst_duration,
                                                           t, fs, smooth_sigma=smooth_sigma)
    t_onset = t[first_index]
    t_offset = t[last_index]

    instant_mean_freq_burst = instant_freq[start_idx:end_idx]

    if plot_imnf:
        ax_2[i].plot(t, instant_freq)
        ax_2[i].set_xlabel("Time (s)")
        ax_2[i].set_ylabel("Frequency (Hz)")
        # Title this panel explicitly; plt.title acted on whichever axes happened to be current.
        ax_2[i].set_title("Instantaneous frequency")
        ax_2[i].grid(True)

    if plot_response:
        ax_3_1.plot(t, instant_freq, label=setting_label)
        ax_3_2.plot(t, energia_temporal, label=setting_label)

    # MAE in frequency
    print(f'Estimated mean frequency: {np.mean(instant_mean_freq_burst)} Hz')
    mae_f, std_f = superlet.calculate_mae(instant_mean_freq_burst, MNF_1)
    mae_superlet_f.append(mae_f)
    std_superlet_f.append(std_f)

    # MAE in time
    print(f'Estimated onset = {t_onset} s, offset = {t_offset} s')
    mae_t, std_t = superlet.calculate_mae((t_onset, t_offset), (real_t_onset, real_t_offset))
    mae_superlet_t.append(mae_t)
    std_superlet_t.append(std_t)

    # Per-setting diagnostic: instantaneous frequency with the estimated onset/offset.
    plt.figure(figsize=(12, 6))
    plt.subplot(2, 1, 1)
    plt.plot(t, instant_freq)
    plt.axvline(t_onset, color='green', linestyle='--', label='t_onset')
    plt.axvline(t_offset, color='red', linestyle='--', label='t_offset')
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    plt.legend()
    plt.grid()

    # Exactly one time axis and one frequency axis per scalogram. The previous version appended
    # twice per iteration, misaligning these lists with total_scalogram_superlet.
    time = np.arange(scalogram_2.shape[1]) / fs
    total_time.append(time)
    total_freq.append(freqs)

if plot_response:
    # Decorate the overlay axes once, after all curves are drawn.
    for axis, ylabel, title in ((ax_3_1, "Frequency (Hz)", "Instantaneous frequency"),
                                (ax_3_2, "Energy (V²)", "Energy over time")):
        axis.set_xlabel("Time (s)")
        axis.set_ylabel(ylabel)
        axis.legend()
        axis.grid(True)
        axis.set_title(title)

res_superlet = superlet.compute_avg_response_resolution(
    total_scalogram_superlet, total_time, total_freq,
    # Label each configuration with its own base cycle and order range. The previous
    # comprehension shifted the fields: c_1 showed max_order and o showed base_cycle-min_order.
    [f'$c_1$={c1}, o:{c2}-{c3}' for c1, c2, c3 in superlet_settings])

if plot_scalogram:
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', pad=0.15, shrink=0.1)
    cbar.set_label('Power (V²)', rotation=0, labelpad=15)

In [None]:
# Superlet MAE in frequency for this SNR pass (one entry per superlet configuration).
mae_superlet_f

In [None]:
# Superlet MAE in onset/offset time for this SNR pass (one entry per superlet configuration).
mae_superlet_t

## STFT PARAMETERS (window = duration/cycles) = 38, 200, 413, 550, 600, 824 ms

In [None]:
# Both offsets start at zero. NOTE(review): neither name is read by the cells visible in this
# section — confirm whether they are still needed.
offset_burst = offset = 0

In [None]:
# STFT pass: one spectrogram of signal_2 per Blackman window length in `windows_ms`.
# `w` is passed to scipy as `nperseg`, i.e. a length in samples; titles and legend labels
# convert it to milliseconds with w * 1000 // fs.
n_windows = len(windows_ms)
if plot_scalogram:
    fig, ax = plt.subplots(ncols=n_windows, figsize=(18,5), dpi=300, sharey=True, sharex=True)
if plot_imnf:
    fig_2, ax_2 = plt.subplots(ncols=n_windows, figsize=(18, 5), dpi=300)
if plot_response:
    fig_3, (ax_3_1, ax_3_2) = plt.subplots(1, 2, figsize=(18, 5))

mae_stft_f = []
std_stft_f = []

mae_stft_t = []
std_stft_t = []

total_time = []
total_freq = []

res_stft = []
total_scalogram_stft = []

# FFT length tied to the size of the shared frequency grid `freqs` so the one-sided STFT
# frequency axis (nfft // 2 + 1 bins) is comparable to the other methods; independent of `w`.
nfft = 2 * len(freqs)

for i, w in enumerate(windows_ms):
    window_label = f'W={w * 1000 // fs} ms'

    # Blackman window passed by name; hop of one sample (noverlap = w - 1) and zero-padded
    # boundaries, so the STFT time axis starts at 0 with a spacing of 1 / fs.
    f, t_stft, Zxx = stft(signal_2, fs=fs, window='blackman', nperseg=w, noverlap=w-1, nfft=nfft, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1, scaling='spectrum')

    # Power representation (squared magnitude)
    Zxx_power = np.abs(Zxx) ** 2
    scalogram_2 = Zxx_power
    total_scalogram_stft.append(scalogram_2)

    if plot_scalogram:
        pcm = ax[i].pcolormesh(t_stft, f, Zxx_power, shading='gouraud', cmap='jet')
        ax[i].set_title(window_label)
        # Only the first panel carries the y-axis label
        ax[i].set_ylabel('Frecuencia [Hz]' if i == 0 else '')
        ax[i].set_xlabel('Time [s]')
        ax[i].set_ylim(f[0], f[-1])

    # Total power per frame and the power-weighted mean frequency per frame.
    # There is no zero-power guard: a frame with zero total power would yield NaN.
    energia_temporal = np.sum(scalogram_2, axis=0)
    instant_freq = np.sum(scalogram_2 * f[:, np.newaxis], axis=0) / energia_temporal

    first_index, last_index = superlet.find_plateau_region(energia_temporal, burst_start, burst_duration, t, fs, smooth_sigma=smooth_sigma)

    # The indices address STFT frames, so read the times from the STFT time axis
    # (t_stft need not have the same length or spacing as the signal's axis `t`).
    t_onset = t_stft[first_index]
    t_offset = t_stft[last_index]

    # Frames used for the frequency error (start_idx / end_idx come from an earlier cell)
    instant_mean_freq_burst = instant_freq[start_idx:end_idx]

    if plot_imnf:
        ax_2[i].plot(t_stft, instant_freq)
        ax_2[i].set_xlabel("Time (s)")
        ax_2[i].set_ylabel("Frequency (Hz)")
        # Title this panel explicitly; plt.title would target whichever axes is current
        ax_2[i].set_title("Instantaneous frequency")
        ax_2[i].grid(True)

    if plot_response:
        ax_3_1.plot(t_stft, instant_freq, label=window_label)
        ax_3_2.plot(t_stft, energia_temporal, label=window_label)

    # MAE in frequency
    print(f'Estimated mean frequency: {np.mean(instant_mean_freq_burst)} Hz')
    mae_f, std_f = superlet.calculate_mae(instant_mean_freq_burst, MNF_1)
    mae_stft_f.append(mae_f)
    std_stft_f.append(std_f)

    # MAE in time (detected vs. true burst onset/offset)
    print(f'Estimated onset = {t_onset} s, offset = {t_offset} s')
    mae_t, std_t = superlet.calculate_mae((t_onset,t_offset), (real_t_onset,real_t_offset))
    mae_stft_t.append(mae_t)
    std_stft_t.append(std_t)

    # Per-window diagnostic: instantaneous frequency with the detected onset/offset
    plt.figure(figsize=(12, 6))
    plt.subplot(2, 1, 1)
    plt.plot(t_stft, instant_freq)
    plt.axvline(t_onset, color='green', linestyle='--', label='t_onset')
    plt.axvline(t_offset, color='red', linestyle='--', label='t_offset')
    plt.title(window_label)
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    plt.legend()
    plt.grid()
    plt.tight_layout()

    total_time.append(np.linspace(t_stft[0], t_stft[-1], scalogram_2.shape[1]))
    total_freq.append(np.linspace(f[0], f[-1], scalogram_2.shape[0]))

# Decorate the shared response axes once, after every window has been drawn
if plot_response:
    ax_3_1.set_xlabel("Time (s)")
    ax_3_1.set_ylabel("Frequency (Hz)")
    ax_3_1.set_title("Instantaneous frequency")
    ax_3_1.legend()
    ax_3_1.grid(True)

    ax_3_2.set_xlabel("Time (s)")
    ax_3_2.set_ylabel("Energy (V²)")
    ax_3_2.set_title("Energy over time")
    ax_3_2.legend()
    ax_3_2.grid(True)

res_stft = superlet.compute_avg_response_resolution(total_scalogram_stft, total_time, total_freq, windows_ms, stft=[freqs[0],freqs[-1]])

if plot_scalogram:
    cbar = fig.colorbar(pcm, ax=ax, orientation='horizontal', pad=0.15, shrink=0.1)
    cbar.set_label('Power (V²)', rotation=0, labelpad=15)

In [None]:
# STFT MAE in frequency for this SNR pass (one entry per window length).
mae_stft_f

In [None]:
# STFT MAE in onset/offset time for this SNR pass (one entry per window length).
mae_stft_t

# Comparison of the three methods at the same SNR

## MAE IN FREQUENCY

In [None]:
params_stft = windows_ms

# Grouped bar chart: one group per parameter configuration, one bar per method.
num_conditions = len(mae_wavelet_f)
x = np.arange(num_conditions)
bar_width = 0.25

# (x offset, MAE values, error bars, legend label, colour) for each method, left to right
method_bars = (
    (-bar_width, mae_stft_f, std_stft_f, 'STFT', 'skyblue'),
    (0, mae_wavelet_f, std_wavelet_f, 'Wavelet', 'lightgreen'),
    (bar_width, mae_superlet_f, std_superlet_f, 'Superlet', 'salmon'),
)

fig_mae_f, ax_mae_f = plt.subplots(figsize=(14, 7))
for shift, heights, errors, method_name, colour in method_bars:
    ax_mae_f.bar(x + shift, heights, width=bar_width, yerr=errors, capsize=5,
                 label=method_name, color=colour, alpha=0.7)

# Tick label: STFT window, wavelet cycles and superlet configuration compared in each group
labels = [
    f"W={p_stft} ms\n\nc={p_w}\n\nc$_1$={p_s[0]}, o: {p_s[1]}-{p_s[2]}"
    for p_w, p_s, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]
ax_mae_f.set_xticks(x)
ax_mae_f.set_xticklabels(labels, rotation=90, ha='center', fontsize=18)
ax_mae_f.set_yticks(np.arange(0, 101, 10))

ax_mae_f.set_ylabel('MAE (Hz)')
ax_mae_f.set_title(f'MAE in FREQUENCY (SNR = {SNR})', fontsize=18)
ax_mae_f.set_ylim(0, 100)
ax_mae_f.legend(fontsize=14)
ax_mae_f.grid(True)

fig_mae_f.tight_layout()
plt.show()

## MAE IN TIME

In [None]:
params_stft = windows_ms

# Grouped bar chart: one group per parameter configuration, one bar per method.
num_conditions = len(mae_wavelet_t)
x = np.arange(num_conditions)
bar_width = 0.25

# (x offset, MAE values, error bars, legend label, colour) for each method, left to right
method_bars = (
    (-bar_width, mae_stft_t, std_stft_t, 'STFT', 'skyblue'),
    (0, mae_wavelet_t, std_wavelet_t, 'Wavelet', 'lightgreen'),
    (bar_width, mae_superlet_t, std_superlet_t, 'Superlet', 'salmon'),
)

fig_mae_t, ax_mae_t = plt.subplots(figsize=(14, 7))
for shift, heights, errors, method_name, colour in method_bars:
    ax_mae_t.bar(x + shift, heights, width=bar_width, yerr=errors, capsize=5,
                 label=method_name, color=colour, alpha=0.7)

# Tick label: STFT window, wavelet cycles and superlet configuration compared in each group
labels = [
    f"W={p_stft} ms\n\nc={p_w}\n\nc$_1$={p_s[0]}, o: {p_s[1]}-{p_s[2]}"
    for p_w, p_s, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]
ax_mae_t.set_xticks(x)
ax_mae_t.set_xticklabels(labels, rotation=90, ha='center', fontsize=18)
ax_mae_t.set_yticks(np.arange(0, 0.21, 0.1))

ax_mae_t.set_ylabel('MAE (s)')
ax_mae_t.set_title(f'MAE in TIME (SNR = {SNR})', fontsize=18)
ax_mae_t.set_ylim(0, 0.2)
ax_mae_t.legend(fontsize=14)
ax_mae_t.grid(True)

fig_mae_t.tight_layout()
plt.show()

# COMPUTE RESOLUTION

In [None]:
rayleigh_limit = 1 / (4 * np.pi)

# Tick label: the STFT window, wavelet cycles and superlet configuration compared at each position
labels = [
    f"W={p_stft} ms\nc={p_wavelet}\nc$_1$={p_superlet[0]}, o={p_superlet[1]}-{p_superlet[2]}"
    for p_wavelet, p_superlet, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]

x = np.arange(1, len(params_wavelet) + 1)

# (line format, per-configuration results, frequency label, time label) for each method.
# Each result r is read as r[0] -> time resolution, r[1] -> frequency resolution.
resolution_series = (
    ('go-', res_wavelet, "Wavelet (Frequency)", "Wavelet (Time)"),
    ('bo-', res_stft, "STFT (Frequency)", "STFT (Time)"),
    ('ro-', res_superlet, "Superlet (Frequency)", "Superlet (Time)"),
)

fig_res = plt.figure(figsize=(12, 8))
ax_res_freq = fig_res.add_subplot(2, 1, 1)
ax_res_time = fig_res.add_subplot(2, 1, 2)

for line_fmt, results, freq_label, time_label in resolution_series:
    ax_res_freq.plot(x, [r[1] for r in results], line_fmt, label=freq_label)
    ax_res_time.plot(x, [r[0] for r in results], line_fmt, label=time_label)

# NOTE(review): 1/(4π) is the Gabor bound on the product Δt·Δf, yet it is drawn on the
# frequency-resolution panel only — confirm this is the intended reference line.
ax_res_freq.axhline(rayleigh_limit, color='k', linestyle='--', label="Rayleigh Limit")

for panel, y_label in ((ax_res_freq, "Frequency Resolution (Hz)"), (ax_res_time, "Time Resolution (s)")):
    panel.set_xlabel("Window Size / Cycles / Order")
    panel.set_ylabel(y_label)
    panel.legend()

# Same configuration tick labels on both panels
for ax in fig_res.get_axes():
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=90, ha='center', fontsize=10)

fig_res.tight_layout()
plt.show()

In [None]:
# Frequency-MAE results accumulated over SNR passes: a dict of per-method lists (see the next cell).
MAES_f

In [None]:
# Store this SNR pass's per-configuration frequency errors (and their std) in MAES_f, one list
# per method. Not idempotent: re-running this cell appends the same pass a second time.
for maes_key, pass_values in (
    ('wavelet', mae_wavelet_f),
    ('stft', mae_stft_f),
    ('superlet', mae_superlet_f),
    ('std_wavelet', std_wavelet_f),
    ('std_stft', std_stft_f),
    ('std_superlet', std_superlet_f),
):
    MAES_f[maes_key].append(pass_values)

In [None]:
# Time-MAE results accumulated over SNR passes: a dict of per-method lists (see the next cell).
MAES_t

In [None]:
# Store this SNR pass's per-configuration onset/offset errors (and their std) in MAES_t, one list
# per method. Not idempotent: re-running this cell appends the same pass a second time.
for maes_key, pass_values in (
    ('wavelet', mae_wavelet_t),
    ('stft', mae_stft_t),
    ('superlet', mae_superlet_t),
    ('std_wavelet', std_wavelet_t),
    ('std_stft', std_stft_t),
    ('std_superlet', std_superlet_t),
):
    MAES_t[maes_key].append(pass_values)

# REPEAT EVERYTHING WITH A DIFFERENT SNR

In [None]:
# Clean burst signal for this pass: reuse the cached signal if present, otherwise build it from a
# constant 1e-5 floor plus the De Luca burst placed at `burst_start`. In both branches the
# optional known-frequency contamination (`base_signal`) is added.
if os.path.exists(filename):
    burst_signal = np.load(filename) + base_signal
else:
    burst_signal = np.full(len(t), 0.00001)
    burst_signal[int(burst_start * fs):int(burst_start * fs + len(burst))] += burst
    # Previously missing here, which silently ignored `contamination = True` when no cache existed
    burst_signal += base_signal

plt.plot(noise)
plt.figure()
plt.plot(burst_signal)

In [None]:
# Noise level for this pass of the analysis (dB).
SNR = 100
# add_noise returns the noisy signal and a second array `xn` that is passed to calcular_snr below
# (presumably the added noise — NOTE(review): confirm against superlet.add_noise).
signal_3, xn = superlet.add_noise(burst_signal, SNR, fs, plot = True)
# SNR actually measured on the generated signal, to check it against the requested value.
snr_real = calcular_snr(signal_3, xn)

In [None]:
# SNR measured by calcular_snr on signal_3 and xn (compare with the requested SNR).
snr_real

In [None]:
# Overlay the noisy signals generated so far (signal_1, signal_2 and this pass's signal_3),
# labelled so that each trace can be identified.
fig_overlay, ax_overlay = plt.subplots()
for trace_name, trace in (("signal_1", signal_1), ("signal_2", signal_2), ("signal_3", signal_3)):
    ax_overlay.plot(trace, label=trace_name)
ax_overlay.set_xlabel("Sample")
ax_overlay.legend();

In [None]:
# Noisy signal analysed in this pass (signal_3), drawn once against the shared time axis `t`.
fig, ax = plt.subplots(figsize=(15, 2), dpi=300)
ax.set_xlabel("Time (s)")
ax.plot(t, signal_3)
print(f"Reference mean frequency: {MNF_1} Hz")

# WAVELET PARAMETERS = 3, 16, 33, 55, 60, 115 cycles

In [None]:
# Disable the scalogram grids for this pass; the loops below only draw them when this flag is True.
plot_scalogram = False

In [None]:
# Wavelet pass for the current SNR: one complex-Morlet scalogram per cycle count in
# `params_wavelet`. This pass must analyse signal_3, the signal generated for this SNR
# (analysing signal_2 here would just repeat the previous SNR's results).
n_wavelets = len(params_wavelet)
if plot_scalogram:
    fig, ax = plt.subplots(ncols=n_wavelets, figsize=(18,5), dpi=300, sharey=True, sharex=True)
if plot_imnf:
    fig_2, ax_2 = plt.subplots(ncols=n_wavelets, figsize=(18, 5), dpi=300)
if plot_response:
    fig_3, (ax_3_1, ax_3_2) = plt.subplots(1, 2, figsize=(12, 5))

mae_wavelet_f = []
std_wavelet_f = []

mae_wavelet_t = []
std_wavelet_t = []

total_time = []
total_freq = []

total_scalogram_wavelet = []

for i, c in enumerate(params_wavelet):
    wavelet = f'cmor{c}-1.0'

    cwtmatr, f, physical_freqs = morlet.wavelet_transform_2(signal_3, wavelet, freqs, fs)
    cwtmatr = np.abs(cwtmatr)
    # Power = squared magnitude. NOTE(review): routing through jnp presumably computes in float32
    # unless jax_enable_x64 is set; kept so results stay comparable with the other passes.
    scalogram_2 = np.array(jnp.abs(cwtmatr) ** 2)

    if plot_scalogram:
        pcm = ax[i].pcolormesh(t, f, scalogram_2, shading='gouraud', cmap='jet')
        ax[i].set_title(f'c={c}')
        # Only the first panel carries the y-axis label
        ax[i].set_ylabel('Frequency (Hz)' if i == 0 else '')
        ax[i].set_xlabel('Time (s)')
        ax[i].set_ylim(f[0], f[-1])

    total_scalogram_wavelet.append(scalogram_2)

    # Total power per time sample and the power-weighted mean frequency per sample
    energia_temporal = np.sum(scalogram_2, axis=0)
    instant_freq = np.sum(scalogram_2 * f[:, np.newaxis], axis=0) / energia_temporal

    first_index, last_index = superlet.find_plateau_region(energia_temporal, burst_start, burst_duration, t, fs, smooth_sigma=smooth_sigma)

    t_onset = t[first_index]
    t_offset = t[last_index]

    # Samples used for the frequency error (start_idx / end_idx come from an earlier cell)
    instant_mean_freq_burst = instant_freq[start_idx:end_idx]

    if plot_imnf:
        ax_2[i].plot(t, instant_freq)
        ax_2[i].set_xlabel("Time (s)")
        ax_2[i].set_ylabel("Frequency (Hz)")
        # Title this panel explicitly; plt.title would target whichever axes is current
        ax_2[i].set_title("Instantaneous frequency")
        ax_2[i].grid(True)

    if plot_response:
        ax_3_1.plot(t, instant_freq, label=f"c={c}")
        ax_3_2.plot(t, energia_temporal, label=f"c={c}")

    # MAE in frequency
    print(f'Estimated mean frequency: {np.mean(instant_mean_freq_burst)} Hz')
    mae_f, std_f = superlet.calculate_mae(instant_mean_freq_burst, MNF_1)
    mae_wavelet_f.append(mae_f)
    std_wavelet_f.append(std_f)

    # MAE in time (detected vs. true burst onset/offset)
    print(f'Estimated onset = {t_onset} s, offset = {t_offset} s')
    mae_t, std_t = superlet.calculate_mae((t_onset,t_offset), (real_t_onset,real_t_offset))
    mae_wavelet_t.append(mae_t)
    std_wavelet_t.append(std_t)

    # Per-configuration diagnostic: instantaneous frequency with the detected onset/offset
    plt.figure(figsize=(8, 5))
    plt.subplot(2, 1, 1)
    plt.plot(t, instant_freq)
    plt.axvline(t_onset, color='green', linestyle='--', label='t_onset')
    plt.axvline(t_offset, color='red', linestyle='--', label='t_offset')
    plt.title(f'c={c}')
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    plt.legend()
    plt.grid()
    plt.tight_layout()

    time = np.arange(scalogram_2.shape[1]) / fs
    total_time.append(time)
    total_freq.append(freqs)

# Decorate the shared response axes once, after every configuration has been drawn
if plot_response:
    ax_3_1.set_xlabel("Time (s)")
    ax_3_1.set_ylabel("Frequency (Hz)")
    ax_3_1.set_title("Instantaneous frequency")
    ax_3_1.legend()
    ax_3_1.grid(True)

    ax_3_2.set_xlabel("Time (s)")
    ax_3_2.set_ylabel("Energy (V²)")
    ax_3_2.set_title("Energy over time")
    ax_3_2.legend()
    ax_3_2.grid(True)

res_wavelet = superlet.compute_avg_response_resolution(total_scalogram_wavelet, total_time, total_freq, params_wavelet)

if plot_scalogram:
    cbar = fig.colorbar(pcm, ax=ax, orientation='horizontal', pad=0.15, shrink=0.1)
    cbar.set_label('Power (V²)', rotation=0, labelpad=15)

In [None]:
# Wavelet MAE in frequency for this SNR pass (one entry per cycle count).
mae_wavelet_f

In [None]:
# Wavelet MAE in onset/offset time for this SNR pass (one entry per cycle count).
mae_wavelet_t

## SUPERLET PARAMETERS: base_cycle, min_order, max_order = [3, 5, 1, 1, 1, 1], [1, 1, 5, 10, 20, 30], [30, 30, 40, 100, 100, 200]

In [None]:
# Superlet pass for the current SNR: one adaptive superlet scalogram per
# (base_cycle, min_order, max_order) configuration in `params_superlet`, analysing signal_3
# (the signal generated for this SNR).
superlet_configs = list(zip(*params_superlet))
if plot_scalogram:
    fig, ax = plt.subplots(ncols=len(superlet_configs), figsize=(18,5), dpi=300, sharey=True, sharex=True)
if plot_imnf:
    fig_2, ax_2 = plt.subplots(ncols=len(superlet_configs), figsize=(18, 5), dpi=300)
if plot_response:
    fig_3, (ax_3_1, ax_3_2) = plt.subplots(1, 2, figsize=(12, 5))

mae_superlet_f = []
std_superlet_f = []

mae_superlet_t = []
std_superlet_t = []

total_time = []
total_freq = []

total_scalogram_superlet = []

for i, (base_cycle, min_order, max_order) in enumerate(superlet_configs):
    config_label = f"$c_1$: {base_cycle}, o: {min_order}-{max_order}"

    wv, scalogram = superlet.adaptive_superlet_transform(signal_3, freqs, sampling_freq=fs,
                                                         base_cycle=base_cycle, min_order=min_order,
                                                         max_order=max_order, mode="mul")
    # Power = squared magnitude of the superlet coefficients
    scalogram_2 = np.abs(scalogram) ** 2
    total_scalogram_superlet.append(scalogram_2)

    if plot_scalogram:
        im = ax[i].imshow(scalogram_2, aspect='auto', cmap="jet", interpolation="none", origin="lower",
                          extent=[0, len(signal_3) / fs, freqs[0], freqs[-1]])
        ax[i].set_title(config_label)
        ax[i].set_xlabel("Time (s)")
        # Only the first panel carries the y-axis label
        ax[i].set_ylabel("Frequency (Hz)" if i == 0 else "")
        ax[i].set_ylim(freqs[0], freqs[-1])

    # Total power per time sample and the power-weighted mean frequency per sample
    energia_temporal = np.sum(scalogram_2, axis=0)
    instant_freq = np.sum(scalogram_2 * freqs[:, np.newaxis], axis=0) / energia_temporal
    if plot_means:
        plt.figure()
        plt.plot(instant_freq)
        plt.title('Frecuencia instantánea')

    first_index, last_index = superlet.find_plateau_region(energia_temporal, burst_start, burst_duration, t, fs, smooth_sigma=smooth_sigma)

    t_onset = t[first_index]
    t_offset = t[last_index]

    # Samples used for the frequency error (start_idx / end_idx come from an earlier cell)
    instant_mean_freq_burst = instant_freq[start_idx:end_idx]

    if plot_imnf:
        ax_2[i].plot(t, instant_freq)
        ax_2[i].set_xlabel("Time (s)")
        ax_2[i].set_ylabel("Frequency (Hz)")
        ax_2[i].set_title("Instantaneous frequency")
        ax_2[i].grid(True)

    if plot_response:
        ax_3_1.plot(t, instant_freq, label=config_label)
        ax_3_2.plot(t, energia_temporal, label=config_label)

    # MAE in frequency
    print(f'Estimated mean frequency: {np.mean(instant_mean_freq_burst)} Hz')
    mae_f, std_f = superlet.calculate_mae(instant_mean_freq_burst, MNF_1)
    mae_superlet_f.append(mae_f)
    std_superlet_f.append(std_f)

    # MAE in time (detected vs. true burst onset/offset)
    print(f'Estimated onset = {t_onset} s, offset = {t_offset} s')
    mae_t, std_t = superlet.calculate_mae((t_onset,t_offset), (real_t_onset,real_t_offset))
    mae_superlet_t.append(mae_t)
    std_superlet_t.append(std_t)

    # Per-configuration diagnostic: instantaneous frequency with the detected onset/offset
    plt.figure(figsize=(12, 6))
    plt.subplot(2, 1, 1)
    plt.plot(t, instant_freq)
    plt.axvline(t_onset, color='green', linestyle='--', label='t_onset')
    plt.axvline(t_offset, color='red', linestyle='--', label='t_offset')
    plt.title(config_label)
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    plt.legend()
    plt.grid()
    plt.tight_layout()

    # Exactly one time axis and one frequency axis per scalogram, so the three lists stay aligned
    time = np.arange(scalogram_2.shape[1]) / fs
    total_time.append(time)
    total_freq.append(freqs)

# Decorate the shared response axes once, after every configuration has been drawn
if plot_response:
    ax_3_1.set_xlabel("Time (s)")
    ax_3_1.set_ylabel("Frequency (Hz)")
    ax_3_1.set_title("Instantaneous frequency")
    ax_3_1.legend()
    ax_3_1.grid(True)

    ax_3_2.set_xlabel("Time (s)")
    ax_3_2.set_ylabel("Energy (V²)")
    ax_3_2.set_title("Energy over time")
    ax_3_2.legend()
    ax_3_2.grid(True)

# Labels follow the (base_cycle, min_order, max_order) order used by the loop above
res_superlet = superlet.compute_avg_response_resolution(
    total_scalogram_superlet, total_time, total_freq,
    [f'$c_1$={cycle}, o:{o_min}-{o_max}' for cycle, o_min, o_max in superlet_configs])

if plot_scalogram:
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', pad=0.15, shrink=0.1)
    cbar.set_label('Power (V²)', rotation=0, labelpad=15)


In [None]:
# Superlet MAE in frequency for this SNR pass (one entry per superlet configuration).
mae_superlet_f

In [None]:
# Superlet MAE in onset/offset time for this SNR pass (one entry per superlet configuration).
mae_superlet_t

## STFT PARAMETERS (window = duration/cycles) = 38, 200, 413, 550, 600, 824 ms

In [None]:
# STFT pass for the current SNR: one spectrogram of signal_3 (the signal generated for this SNR)
# per Blackman window length in `windows_ms`. `w` is passed to scipy as `nperseg`, i.e. a length
# in samples; titles and legend labels convert it to milliseconds with w * 1000 // fs.
n_windows = len(windows_ms)
if plot_scalogram:
    fig, ax = plt.subplots(ncols=n_windows, figsize=(18,5), dpi=300, sharey=True, sharex=True)
if plot_imnf:
    fig_2, ax_2 = plt.subplots(ncols=n_windows, figsize=(18, 5), dpi=300)
if plot_response:
    fig_3, (ax_3_1, ax_3_2) = plt.subplots(1, 2, figsize=(12, 5))

mae_stft_f = []
std_stft_f = []

mae_stft_t = []
std_stft_t = []

total_time = []
total_freq = []

res_stft = []
total_scalogram_stft = []

# FFT length tied to the size of the shared frequency grid `freqs` so the one-sided STFT
# frequency axis (nfft // 2 + 1 bins) is comparable to the other methods; independent of `w`.
nfft = 2 * len(freqs)

for i, w in enumerate(windows_ms):
    window_label = f'W={w * 1000 // fs} ms'

    # Blackman window passed by name; hop of one sample (noverlap = w - 1) and zero-padded
    # boundaries, so the STFT time axis starts at 0 with a spacing of 1 / fs.
    f, t_stft, Zxx = stft(signal_3, fs=fs, window='blackman', nperseg=w, noverlap=w-1, nfft=nfft, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1, scaling='spectrum')

    # Power representation (squared magnitude)
    Zxx_power = np.abs(Zxx) ** 2
    scalogram_2 = Zxx_power
    total_scalogram_stft.append(scalogram_2)

    if plot_scalogram:
        pcm = ax[i].pcolormesh(t_stft, f, Zxx_power, shading='gouraud', cmap='jet')
        ax[i].set_title(window_label)
        # Only the first panel carries the y-axis label
        ax[i].set_ylabel('Frecuencia [Hz]' if i == 0 else '')
        ax[i].set_xlabel('Time [s]')
        ax[i].set_ylim(f[0], f[-1])

    # Total power per frame and the power-weighted mean frequency per frame.
    # There is no zero-power guard: a frame with zero total power would yield NaN.
    energia_temporal = np.sum(scalogram_2, axis=0)
    instant_freq = np.sum(scalogram_2 * f[:, np.newaxis], axis=0) / energia_temporal

    first_index, last_index = superlet.find_plateau_region(energia_temporal, burst_start, burst_duration, t, fs, smooth_sigma=smooth_sigma)

    # The indices address STFT frames, so read the times from the STFT time axis
    # (t_stft need not have the same length or spacing as the signal's axis `t`).
    t_onset = t_stft[first_index]
    t_offset = t_stft[last_index]

    # Frames used for the frequency error (start_idx / end_idx come from an earlier cell)
    instant_mean_freq_burst = instant_freq[start_idx:end_idx]

    if plot_imnf:
        ax_2[i].plot(t_stft, instant_freq)
        ax_2[i].set_xlabel("Time (s)")
        ax_2[i].set_ylabel("Frequency (Hz)")
        # Title this panel explicitly; plt.title would target whichever axes is current
        ax_2[i].set_title("Instantaneous frequency")
        ax_2[i].grid(True)

    if plot_response:
        ax_3_1.plot(t_stft, instant_freq, label=window_label)
        ax_3_2.plot(t_stft, energia_temporal, label=window_label)

    # MAE in frequency
    print(f'Estimated mean frequency: {np.mean(instant_mean_freq_burst)} Hz')
    mae_f, std_f = superlet.calculate_mae(instant_mean_freq_burst, MNF_1)
    mae_stft_f.append(mae_f)
    std_stft_f.append(std_f)

    # MAE in time (detected vs. true burst onset/offset)
    print(f'Estimated onset = {t_onset} s, offset = {t_offset} s')
    mae_t, std_t = superlet.calculate_mae((t_onset,t_offset), (real_t_onset,real_t_offset))
    mae_stft_t.append(mae_t)
    std_stft_t.append(std_t)

    # Per-window diagnostic: instantaneous frequency with the detected onset/offset
    plt.figure(figsize=(12, 6))
    plt.subplot(2, 1, 1)
    plt.plot(t_stft, instant_freq)
    plt.axvline(t_onset, color='green', linestyle='--', label='t_onset')
    plt.axvline(t_offset, color='red', linestyle='--', label='t_offset')
    plt.title(window_label)
    plt.xlabel('Time (s)')
    plt.ylabel('Freq (Hz)')
    plt.legend()
    plt.grid()
    plt.tight_layout()

    total_time.append(np.linspace(t_stft[0], t_stft[-1], scalogram_2.shape[1]))
    total_freq.append(np.linspace(f[0], f[-1], scalogram_2.shape[0]))

# Decorate the shared response axes once, after every window has been drawn
if plot_response:
    ax_3_1.set_xlabel("Time (s)")
    ax_3_1.set_ylabel("Frequency (Hz)")
    ax_3_1.set_title("Instantaneous frequency")
    ax_3_1.legend()
    ax_3_1.grid(True)

    ax_3_2.set_xlabel("Time (s)")
    ax_3_2.set_ylabel("Energy (V²)")
    ax_3_2.set_title("Energy over time")
    ax_3_2.legend()
    ax_3_2.grid(True)

res_stft = superlet.compute_avg_response_resolution(total_scalogram_stft, total_time, total_freq, windows_ms, stft=[freqs[0],freqs[-1]])

if plot_scalogram:
    cbar = fig.colorbar(pcm, ax=ax, orientation='horizontal', pad=0.15, shrink=0.1)
    cbar.set_label('Power (V²)', rotation=0, labelpad=15)

In [None]:
# STFT MAE in frequency for this SNR pass (one entry per window length).
mae_stft_f

In [None]:
# STFT MAE in onset/offset time for this SNR pass (one entry per window length).
mae_stft_t

# Comparison of the three methods at the same SNR

## MAE IN FREQUENCY

In [None]:
params_stft = windows_ms

# Grouped bar chart: one group per parameter configuration, one bar per method.
num_conditions = len(mae_wavelet_f)
x = np.arange(num_conditions)
bar_width = 0.25

# (x offset, MAE values, error bars, legend label, colour) for each method, left to right
method_bars = (
    (-bar_width, mae_stft_f, std_stft_f, 'STFT', 'skyblue'),
    (0, mae_wavelet_f, std_wavelet_f, 'Wavelet', 'lightgreen'),
    (bar_width, mae_superlet_f, std_superlet_f, 'Superlet', 'salmon'),
)

fig_mae_f, ax_mae_f = plt.subplots(figsize=(14, 7))
for shift, heights, errors, method_name, colour in method_bars:
    ax_mae_f.bar(x + shift, heights, width=bar_width, yerr=errors, capsize=5,
                 label=method_name, color=colour, alpha=0.7)

# Tick label: STFT window, wavelet cycles and superlet configuration compared in each group
labels = [
    f"W={p_stft} ms\n\nc={p_w}\n\nc$_1$={p_s[0]}, o: {p_s[1]}-{p_s[2]}"
    for p_w, p_s, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]
ax_mae_f.set_xticks(x)
ax_mae_f.set_xticklabels(labels, rotation=90, ha='center', fontsize=18)
ax_mae_f.set_yticks(np.arange(0, 101, 10))

ax_mae_f.set_ylabel('MAE (Hz)')
ax_mae_f.set_title(f'MAE in FREQUENCY (SNR = {SNR})', fontsize=18)
ax_mae_f.set_ylim(0, 100)
ax_mae_f.legend(fontsize=14)
ax_mae_f.grid(True)

fig_mae_f.tight_layout()
plt.show()

## MAE IN TIME

In [None]:
params_stft = windows_ms

# Grouped bar chart: one group per parameter configuration, one bar per method.
num_conditions = len(mae_wavelet_t)
x = np.arange(num_conditions)
bar_width = 0.25

# (x offset, MAE values, error bars, legend label, colour) for each method, left to right
method_bars = (
    (-bar_width, mae_stft_t, std_stft_t, 'STFT', 'skyblue'),
    (0, mae_wavelet_t, std_wavelet_t, 'Wavelet', 'lightgreen'),
    (bar_width, mae_superlet_t, std_superlet_t, 'Superlet', 'salmon'),
)

fig_mae_t, ax_mae_t = plt.subplots(figsize=(14, 7))
for shift, heights, errors, method_name, colour in method_bars:
    ax_mae_t.bar(x + shift, heights, width=bar_width, yerr=errors, capsize=5,
                 label=method_name, color=colour, alpha=0.7)

# Tick label: STFT window, wavelet cycles and superlet configuration compared in each group
labels = [
    f"W={p_stft} ms\n\nc={p_w}\n\nc$_1$={p_s[0]}, o: {p_s[1]}-{p_s[2]}"
    for p_w, p_s, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]
ax_mae_t.set_xticks(x)
ax_mae_t.set_xticklabels(labels, rotation=90, ha='center', fontsize=18)
ax_mae_t.set_yticks(np.arange(0, 0.21, 0.1))

ax_mae_t.set_ylabel('MAE (s)')
ax_mae_t.set_title(f'MAE in TIME (SNR = {SNR})', fontsize=18)
ax_mae_t.set_ylim(0, 0.2)
ax_mae_t.legend(fontsize=14)
ax_mae_t.grid(True)

fig_mae_t.tight_layout()
plt.show()

# COMPUTE RESOLUTION

In [None]:
rayleigh_limit = 1 / (4 * np.pi)

# Tick label: the STFT window, wavelet cycles and superlet configuration compared at each position
labels = [
    f"W={p_stft} ms\nc={p_wavelet}\nc$_1$={p_superlet[0]}, o={p_superlet[1]}-{p_superlet[2]}"
    for p_wavelet, p_superlet, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]

x = np.arange(1, len(params_wavelet) + 1)

# (line format, per-configuration results, frequency label, time label) for each method.
# Each result r is read as r[0] -> time resolution, r[1] -> frequency resolution.
resolution_series = (
    ('go-', res_wavelet, "Wavelet (Frequency)", "Wavelet (Time)"),
    ('bo-', res_stft, "STFT (Frequency)", "STFT (Time)"),
    ('ro-', res_superlet, "Superlet (Frequency)", "Superlet (Time)"),
)

fig_res = plt.figure(figsize=(12, 8))
ax_res_freq = fig_res.add_subplot(2, 1, 1)
ax_res_time = fig_res.add_subplot(2, 1, 2)

for line_fmt, results, freq_label, time_label in resolution_series:
    ax_res_freq.plot(x, [r[1] for r in results], line_fmt, label=freq_label)
    ax_res_time.plot(x, [r[0] for r in results], line_fmt, label=time_label)

# NOTE(review): 1/(4π) is the Gabor bound on the product Δt·Δf, yet it is drawn on the
# frequency-resolution panel only — confirm this is the intended reference line.
ax_res_freq.axhline(rayleigh_limit, color='k', linestyle='--', label="Rayleigh Limit")

for panel, y_label in ((ax_res_freq, "Frequency Resolution (Hz)"), (ax_res_time, "Time Resolution (s)")):
    panel.set_xlabel("Window Size / Cycles / Order")
    panel.set_ylabel(y_label)
    panel.legend()

# Same configuration tick labels on both panels
for ax in fig_res.get_axes():
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=90, ha='center', fontsize=10)

fig_res.tight_layout()
plt.show()

In [None]:
# Frequency-MAE results accumulated over SNR passes: a dict of per-method lists (see the next cell).
MAES_f

In [None]:
# Store this SNR pass's per-configuration frequency errors (and their std) in MAES_f, one list
# per method. Not idempotent: re-running this cell appends the same pass a second time.
for maes_key, pass_values in (
    ('wavelet', mae_wavelet_f),
    ('stft', mae_stft_f),
    ('superlet', mae_superlet_f),
    ('std_wavelet', std_wavelet_f),
    ('std_stft', std_stft_f),
    ('std_superlet', std_superlet_f),
):
    MAES_f[maes_key].append(pass_values)

In [None]:
# Time-MAE results accumulated over SNR passes: a dict of per-method lists (see the next cell).
MAES_t

In [None]:
# Store this SNR pass's per-configuration onset/offset errors (and their std) in MAES_t, one list
# per method. Not idempotent: re-running this cell appends the same pass a second time.
for maes_key, pass_values in (
    ('wavelet', mae_wavelet_t),
    ('stft', mae_stft_t),
    ('superlet', mae_superlet_t),
    ('std_wavelet', std_wavelet_t),
    ('std_stft', std_stft_t),
    ('std_superlet', std_superlet_t),
):
    MAES_t[maes_key].append(pass_values)

### COMPARISON ACROSS ALL SNR LEVELS

# COMPARISON OF MAE IN FREQUENCY

In [None]:
# Compare the frequency MAE of every method across all SNR passes stored in MAES_f.
from matplotlib.patches import Patch

params_stft = windows_ms  # STFT window lengths, one per compared configuration

# Per-SNR lists of per-configuration MAE / std values (one entry per pass appended to MAES_f)
mae_stft_2, std_stft_2 = MAES_f['stft'], MAES_f['std_stft']
mae_wavelet_2, std_wavelet_2 = MAES_f['wavelet'], MAES_f['std_wavelet']
mae_superlet_2, std_superlet_2 = MAES_f['superlet'], MAES_f['std_superlet']

# SNR levels; assumed to match the order in which the passes were appended to MAES_f
snr_labels = ["5 dB", "15 dB", "100 dB"]

num_params = len(params_stft)  # number of parameter configurations (bar groups)
num_snr = len(snr_labels)  # number of noise levels (bars per method within a group)
x = np.arange(num_params)  # base position of each group
bar_width = 0.1
group_width = bar_width * num_snr  # width taken by one method's block of SNR bars

# Colormap position (shade) for each SNR level
tonalidades = {
    "5 dB": 0.4,
    "15 dB": 0.6,
    "100 dB": 0.8,
}

# Base colormap for each method
colores_base = {
    "STFT": plt.colormaps["Blues"],
    "Wavelet": plt.colormaps["Greens"],
    "Superlet": plt.colormaps["Oranges"],
}

# (MAE lists, std lists, horizontal offset of the method's block relative to x) per method
method_data = {
    "STFT": (mae_stft_2, std_stft_2, -0.01 - group_width),
    "Wavelet": (mae_wavelet_2, std_wavelet_2, 0.0),
    "Superlet": (mae_superlet_2, std_superlet_2, 0.01 + group_width),
}

plt.figure(figsize=(18, 8))

for i, snr in enumerate(snr_labels):
    for method, (maes, stds, shift) in method_data.items():
        positions = x + shift + i * bar_width
        heights = np.asarray(maes[i])[:num_params]
        plt.bar(
            positions,
            heights,
            width=bar_width,
            color=colores_base[method](tonalidades[snr]),
            label=f'{method} ({snr})'
        )
        plt.errorbar(
            positions,
            heights,
            yerr=np.asarray(stds[i])[:num_params],
            fmt='none',
            ecolor='black',
            capsize=3
        )

# Tick label: the STFT window, wavelet cycles and superlet configuration of each group
labels = [
    f"W={p_stft} ms\n\nc={p_wavelet}\n\n\nc$_1$={p_superlet[0]}, o={p_superlet[1]}-{p_superlet[2]}"
    for p_wavelet, p_superlet, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]

# A group spans from the first STFT bar to the last Superlet bar, so its centre sits
# (num_snr - 1) / 2 bar widths to the right of x; put the tick there.
group_centres = x + bar_width * (num_snr - 1) / 2

# One legend entry per (method, SNR) shade, grouped by method
leyenda_snr = [
    Patch(color=colores_base[method](tonalidades[snr]), label=f"{method} (SNR = {snr})")
    for method in colores_base
    for snr in snr_labels
]

plt.xticks(group_centres, labels, rotation=90, ha='center', fontsize=18)
plt.yticks(np.arange(0, 71, 10))
plt.ylabel('MAE (Hz)', fontsize=14)
plt.ylim(0, 70)
plt.title('Comparison of MAE in frequency according to SNR')
plt.legend(handles=leyenda_snr, loc='upper right', fontsize=12, title_fontsize=14)

plt.tight_layout()
plt.grid(True)

plt.show()

# Comparison of MAE in time

In [None]:
# Grouped bar chart of the MAE in time (s) for STFT, Wavelet and Superlet: one group
# per parameter set on the x axis, one bar per (method, SNR level), with std error bars.
# params_wavelet and params_superlet come from earlier cells.
from matplotlib.patches import Patch

params_stft = windows_ms  # STFT window sizes in ms

# MAE in time and its standard deviation per technique, indexed [snr_index][param_index].
# NOTE(review): the outer index is assumed to follow snr_labels below — confirm
# against the cell that fills MAES_t.
mae_stft = MAES_t['stft']
std_stft = MAES_t['std_stft']

mae_wavelet = MAES_t['wavelet']
std_wavelet = MAES_t['std_wavelet']

mae_superlet = MAES_t['superlet']
std_superlet = MAES_t['std_superlet']

snr_labels = ["5 dB", "15 dB", "100 dB"]

# Layout of the grouped bar chart
num_params = len(params_stft)  # number of parameter groups on the x axis
num_snr = len(snr_labels)  # number of SNR bars inside each method block
x = np.arange(num_params)  # base x position of each parameter group
bar_width = 0.1  # width of a single bar
group_width = bar_width * num_snr  # width of one method block of SNR bars
method_gap = 0.01  # extra gap between neighbouring method blocks

# Colormap sample (shade) per SNR level
tonalidades = {
    "5 dB": 0.4,
    "15 dB": 0.6,
    "100 dB": 0.8
}

# Base colormap per method
colores_base = {
    "STFT": plt.colormaps["Blues"],
    "Wavelet": plt.colormaps["Greens"],
    "Superlet": plt.colormaps["Oranges"],
}

# Per method: (MAE per SNR, std per SNR, x offset of its block relative to x).
# The Wavelet block starts at x, STFT sits to its left and Superlet to its right.
methods = {
    "STFT": (mae_stft, std_stft, -method_gap - group_width),
    "Wavelet": (mae_wavelet, std_wavelet, 0.0),
    "Superlet": (mae_superlet, std_superlet, method_gap + group_width),
}

fig, ax = plt.subplots(figsize=(18, 8))

# Bar handles collected per method so they can be listed method-major below
handles_by_method = {method: [] for method in methods}

for i, snr in enumerate(snr_labels):
    for method, (mae, std, offset) in methods.items():
        positions = x + offset + i * bar_width
        heights = np.asarray(mae[i])[:num_params]
        bar = ax.bar(
            positions,
            heights,
            width=bar_width,
            color=colores_base[method](tonalidades[snr]),
            label=f'{method} ({snr})'
        )
        handles_by_method[method].append(bar[0])
        ax.errorbar(
            positions,
            heights,
            yerr=np.asarray(std[i])[:num_params],
            fmt='none',
            ecolor='black',
            capsize=3
        )

# Legend entries grouped by method (all SNR levels of STFT, then Wavelet, then
# Superlet); handles[k] is the bar described by labels_legend[k].
labels_legend = [f'{method} ({snr})' for method in methods for snr in snr_labels]
handles = [h for method_handles in handles_by_method.values() for h in method_handles]

# One tick label per parameter group. The extra blank lines spread the STFT window,
# the Wavelet c and the Superlet c1 / order range sideways once the label is rotated
# 90 degrees (presumably so each sits near its method's bars).
labels = [
    f"W={p_stft} ms\n\nc={p_wavelet}\n\n\nc$_1$={p_superlet[0]}, o={p_superlet[1]}-{p_superlet[2]}"
    for p_wavelet, p_superlet, p_stft in zip(params_wavelet, zip(*params_superlet), params_stft)
]

# One legend patch per (method, SNR) pair, using the same shades as the bars.
leyenda_snr = [
    Patch(color=colores_base[method](tonalidades[snr]), label=f"{method} (SNR = {snr})")
    for method in methods
    for snr in snr_labels
]

ax.set_xticks(x, labels, rotation=90, ha='center', fontsize=18)
# NOTE(review): this yields ticks at 0 and 0.1 s only, and bars above 0.1 s are clipped.
ax.set_yticks(np.arange(0, 0.11, 0.1))
ax.set_ylabel('MAE (s)', fontsize=14)
ax.set_ylim(0, 0.1)
ax.set_title('Comparison of MAE in time according to SNR')

# SNR-shade legend. ax.legend() already attaches it to the axes; wrapping it in
# add_artist() as well would register (and draw) the same legend twice.
ax.legend(handles=leyenda_snr, loc='upper left', fontsize=12)

fig.tight_layout()
ax.grid(True)

plt.show()