<a href="https://colab.research.google.com/github/felipednegredo/tcc-emocoes-musicais-codigo/blob/main/TCC_STFT_e_Fingerprint.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Bibliotecas

In [None]:
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple
import io, zlib, os
import numpy as np
import pandas as pd
from IPython.display import Audio, display

# Import pyarrow and pyarrow.dataset
import pyarrow as pa
import pyarrow.dataset as ds


# Limit BLAS/OpenMP thread pools to one thread each (meant for when work is
# parallelized across processes). setdefault keeps any value already present
# in the environment.
# NOTE(review): numpy is already imported earlier in this cell, before these
# variables are set. BLAS libraries usually read them when first loaded, so
# they presumably only affect libraries loaded afterwards and child processes
# that inherit the environment. Confirm, or move them before `import numpy`.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

# Optional dependency: soundfile. _HAS_SF records whether the import succeeded.
try:
    import soundfile as sf
    _HAS_SF = True
except Exception:
    _HAS_SF = False

# ---------- utilities ----------
# Maps dtype names/aliases to NumPy dtypes. The None key maps to int16
# (presumably the default when no dtype is specified; confirm with callers).
_DTYPE_ALIASES = {
    None: np.int16,
    "int16": np.int16, "i2": np.int16, "pcm_s16le": np.int16,
    "int32": np.int32, "i4": np.int32, "pcm_s32le": np.int32,
    "float32": np.float32, "f4": np.float32, "pcm_f32le": np.float32,
    "float64": np.float64, "f8": np.float64, "pcm_f64le": np.float64,
}

# librosa is mandatory: fail fast with an installation hint if it is missing.
try:
    import librosa as lb
except Exception as e:
    raise RuntimeError("Instale librosa: pip install librosa") from e

## Utils

In [None]:
def ndarray_to_bytes(arr: np.ndarray, compress: bool = True) -> bytes:
    """Serialize an array to bytes in NumPy ``.npy`` format.

    The ``.npy`` header records dtype and shape, so the array (float32,
    complex64, ...) can be rebuilt exactly. Pickling is disabled.

    Args:
        arr: array to serialize; it is made C-contiguous first (dtype is kept).
        compress: when True the ``.npy`` payload is zlib-compressed (level 6).

    Returns:
        The ``.npy`` bytes, zlib-compressed when ``compress`` is True.
    """
    with io.BytesIO() as sink:
        np.save(sink, np.ascontiguousarray(arr), allow_pickle=False)
        payload = sink.getvalue()
    if not compress:
        return payload
    return zlib.compress(payload, level=6)

In [None]:
def bytes_to_ndarray(b: bytes, compressed: bool = True) -> np.ndarray:
    """Rebuild an array from ``.npy`` bytes.

    Args:
        b: serialized payload.
        compressed: True when ``b`` is zlib-compressed ``.npy`` data,
            False when it is raw ``.npy`` data.

    Returns:
        The array with the dtype and shape stored in the ``.npy`` header.
        Pickled payloads are refused (``allow_pickle=False``).
    """
    payload = b
    if compressed:
        payload = zlib.decompress(payload)
    return np.load(io.BytesIO(payload), allow_pickle=False)

In [None]:
def _read_parquet_safe(p: Path) -> pd.DataFrame:
    try:
        return pd.read_parquet(p)
    except Exception as e:
        print(f"⚠️  Falha ao ler '{p.name}': {e}")
        return pd.DataFrame()

## STFT

### Funções auxiliares

In [None]:
def timeSeries(frequence, startTime, finalTime):
    """Build a time vector (in seconds) sampled at ``frequence`` Hz.

    The samples are ``startTime + k * (1 / frequence)`` for every integer
    ``k >= 0`` whose instant lies strictly before ``finalTime``. This is the
    half-open interval ``[startTime, finalTime)``, same contract as ``np.arange``.

    Parameters:
        frequence: sampling rate in Hz.
        startTime: first instant in seconds (usually 0.0).
        finalTime: end instant in seconds (excluded).

    Returns:
        np.ndarray of shape (N,), with N = ceil((finalTime - startTime) * frequence),
        or an empty array when the interval is empty.
    """
    T = 1. / frequence
    # np.arange with a float step can emit one extra sample at (or past)
    # finalTime because of round-off: np.arange(1, 1.3, 0.1) has 4 items.
    # Count the samples explicitly instead. The 1e-9 tolerance is in units of
    # samples and absorbs that round-off.
    n_samples = int(np.ceil((finalTime - startTime) * frequence - 1e-9))
    n_samples = max(n_samples, 0)
    return startTime + np.arange(n_samples) * T

In [None]:
def genSignal(frequence, amplitude, timeSeries):
    """Sample the sinusoid ``amplitude * sin(2*pi*frequence*t)`` at ``timeSeries``.

    Parameters:
        frequence (float): frequency in Hz.
        amplitude (float): peak amplitude.
        timeSeries (np.ndarray): instants in seconds, shape (N,).

    Returns:
        np.ndarray (N,): the sampled sine wave.
    """
    omega = 2.0 * np.pi * frequence   # angular frequency in rad/s
    phase = omega * timeSeries
    return amplitude * np.sin(phase)

In [None]:
def plotSignal(
    signal,
    time,
    titulo: str = "Sinal no tempo",
    figsize: tuple = (900, 350),
    xlim: Optional[tuple] = None,
    ylim: Optional[tuple] = None,
    color: str = "royalblue",
    savepath: Optional[str] = None,
    focus_cycles: int = 5,                  # number of cycles in the initial zoom
    frequency: Optional[float] = None       # frequency in Hz -> defines the period
):
    """Plot a 1-D signal against time with Plotly (interactive).

    Parameters:
        signal: samples; squeezed and required to match ``time`` in shape.
        time: instants in seconds.
        titulo: figure title; when ``frequency`` > 0, f and T are appended.
        figsize: (width, height) in pixels.
        xlim: initial X range in seconds. When given it always wins. Otherwise,
            if ``frequency`` > 0, the view opens on ``focus_cycles`` periods
            starting at the first time sample. Otherwise Plotly autoscales.
        ylim: Y range. When None: min/max of the finite samples padded by 10%
            of their span, or (-1, 1) if the span is 0 or nothing is finite.
        color: line color.
        savepath: when set, the figure is written there (".html" or ".png")
            instead of being shown.
        focus_cycles: number of periods in the initial zoom (needs ``frequency``).
        frequency: signal frequency in Hz, used for the title and the zoom.

    Raises:
        ValueError: if ``signal`` and ``time`` differ in shape, or ``savepath``
            has an unsupported extension.
    """
    # NOTE(review): `go` (plotly.graph_objects) is not imported in the cells
    # visible here; it must be in the notebook namespace before calling this.
    signal = np.asarray(signal).squeeze()
    time = np.asarray(time).squeeze()

    if signal.shape != time.shape:
        raise ValueError("signal e time devem ter o mesmo formato")

    # Automatic Y scale from the finite samples only: np.min fails on an empty
    # array and returns NaN limits when the signal contains NaN.
    if ylim is None:
        finite = signal[np.isfinite(signal)]
        if finite.size:
            smin, smax = float(np.min(finite)), float(np.max(finite))
            span = smax - smin
            ylim = (smin - 0.1 * span, smax + 0.1 * span) if span > 0 else (-1, 1)
        else:
            ylim = (-1, 1)

    # Title and period derived from the frequency, if provided.
    period = None
    if frequency and frequency > 0:
        period = 1.0 / frequency
        titulo = f"{titulo} • f = {frequency:.2f} Hz • T = {period:.4f} s"

    # Initial X range: an explicit xlim is never overridden by the focus.
    if xlim is not None:
        x_range = xlim
    elif period is not None:
        t0 = float(time.flat[0]) if time.size else 0.0
        x_range = (t0, t0 + focus_cycles * period)
    else:
        x_range = None  # let Plotly autoscale

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=time, y=signal,
        mode="lines",
        name="Sinal",
        line=dict(color=color)
    ))

    fig.update_layout(
        title=titulo,
        xaxis_title="Tempo (s)",
        yaxis_title="Amplitude",
        width=figsize[0],
        height=figsize[1],
        template="plotly_white",
        margin=dict(l=60, r=20, t=50, b=50),
        hovermode="x unified",
        dragmode="zoom"
    )

    # NOTE(review): the Plotly docs state range selectors are only available on
    # date-typed x axes, so these buttons likely do not render on this numeric
    # time axis. Confirm before relying on them.
    fig.update_xaxes(
        range=x_range,
        showgrid=True,
        zeroline=False,
        rangeslider=dict(visible=True),
        rangeselector=dict(
            buttons=list([
                dict(count=0.1, label="100ms", step="second", stepmode="backward"),
                dict(count=1,   label="1s",    step="second", stepmode="backward"),
                dict(count=5,   label="5s",    step="second", stepmode="backward"),
                dict(step="all",label="Tudo")
            ])
        )
    )

    fig.update_yaxes(range=ylim, showgrid=True, zeroline=False)

    # Export or display.
    if savepath:
        if savepath.endswith(".html"):
            fig.write_html(savepath)
        elif savepath.endswith(".png"):
            fig.write_image(savepath)
        else:
            raise ValueError("Formato de arquivo não suportado. Use .html ou .png")
    else:
        fig.show()


In [None]:
def sineWave(frequence, amplitude, timeSeries):
    """Generate a sine wave with ``genSignal`` and plot it with ``plotSignal``.

    Parameters:
        frequence (float or int): frequency in Hz (also shown in the title).
        amplitude (float): peak amplitude.
        timeSeries (np.ndarray): instants in seconds.

    Returns:
        np.ndarray: the generated samples.
    """
    wave = genSignal(frequence, amplitude, timeSeries)
    plotSignal(wave, timeSeries, titulo=f"Senoide {frequence} Hz")
    return wave

### Filtro

In [None]:
def bessel_filter(x, sr, type='low', cutoff_freqs=None, order=4, axis=-1):
    """Zero-phase digital Bessel filter (forward-backward filtering).

    The Bessel design (``norm='phase'``) is built as second-order sections and
    applied forward and backward with ``scipy.signal.sosfiltfilt``. The output
    therefore has zero phase shift and the squared magnitude response.

    Parameters:
        x: input signal (converted to a float array).
        sr: sampling rate in Hz (> 0).
        type: 'low' | 'high' | 'band' (alias 'bandpass') | 'bandstop'
            (alias 'band-stop').
        cutoff_freqs:
            low/high  -> a real number in Hz (Python or NumPy int/float)
            band/stop -> (low_hz, high_hz) with 0 < low_hz < high_hz
        order: filter order (>= 1).
        axis: axis of ``x`` along which the filter runs.

    Returns:
        np.ndarray: the filtered signal, same shape as ``x``.

    Raises:
        ValueError: invalid ``sr``, ``order``, ``type`` or cutoff frequencies.

    Notes:
        Cutoffs are normalized by Nyquist and clamped into (0, 1). For band
        filters ``high_hz`` is capped at 0.95 * Nyquist, and ``low_hz`` is
        lowered only when that cap leaves it at or above ``high_hz``.
    """
    # Local import: the notebook's import cell does not import scipy.signal.
    from scipy.signal import bessel, sosfiltfilt

    if sr is None or sr <= 0:
        raise ValueError("sr inválido.")
    if order < 1:
        raise ValueError("order deve ser >= 1")

    x = np.asarray(x, dtype=float)
    nyq, eps = 0.5 * float(sr), 1e-6

    def norm_clamp(hz):
        """Normalize a cutoff in Hz to (0, 1) relative to Nyquist, lightly clamped."""
        return min(max(float(hz) / nyq, eps), 1 - eps)

    if type in ('low', 'high'):
        # NumPy integers are valid cutoffs; bool (an int subclass) is not.
        is_number = (isinstance(cutoff_freqs, (int, float, np.integer, np.floating))
                     and not isinstance(cutoff_freqs, bool))
        if not is_number:
            raise ValueError("Para 'low'/'high', cutoff_freqs deve ser número (Hz).")
        Wn = norm_clamp(cutoff_freqs)
        btype = 'low' if type == 'low' else 'high'
        sos = bessel(order, Wn, btype=btype, analog=False, output='sos', norm='phase')

    elif type in ('band', 'bandpass', 'band-stop', 'bandstop'):
        if not (isinstance(cutoff_freqs, (tuple, list, np.ndarray)) and len(cutoff_freqs) == 2):
            raise ValueError("Para 'band'/'bandstop', cutoff_freqs deve ser (low_hz, high_hz).")
        low_hz, high_hz = float(cutoff_freqs[0]), float(cutoff_freqs[1])
        if not (low_hz > 0 and high_hz > 0 and low_hz < high_hz):
            raise ValueError("Banda inválida: garanta 0 < low_hz < high_hz.")

        # Keep the upper edge safely below Nyquist.
        high_hz = min(high_hz, 0.95 * nyq)
        # Move the lower edge only if the cap above pushed high_hz down to (or
        # below) it. Valid narrow bands such as (300, 400) Hz are kept as requested
        # instead of being widened.
        if low_hz >= high_hz:
            low_hz = min(low_hz, max(10.0, 0.5 * high_hz))

        low, high = norm_clamp(low_hz), norm_clamp(high_hz)
        if not (0 < low < high < 1):
            raise ValueError(f"Banda normalizada inválida: low={low:.6f}, high={high:.6f}.")

        btype = 'bandpass' if type in ('band', 'bandpass') else 'bandstop'
        sos = bessel(order, [low, high], btype=btype, analog=False, output='sos', norm='phase')
    else:
        raise ValueError("type inválido. Use 'low', 'high', 'band', 'bandstop'.")

    # Second-order sections are numerically much more robust than (b, a)
    # polynomials for higher orders and narrow or low-frequency bands.
    return sosfiltfilt(sos, x, axis=axis)


## Calcular o STFT

In [None]:
def calculateSTFT(signal, sr, n_fft=1024, hop_length=256, window='hann',
                  db=True, ylim=None, xlim=None, zlim_db=(-80, 0),
                  return_data=False, reconstruct=False, return_components=False,
                  width=900, height=360, title='Espectrograma (STFT)',
                  frequency: Optional[float] = None,  # Hz (period and time focus)
                  focus_cycles: int = 5,              # cycles shown in the initial zoom
                  show_period_markers: bool = True,   # vertical line at every period
                  show: bool = True):
    """Compute the STFT of ``signal``, optionally plot it and/or invert it.

    The STFT is ``scipy.signal.stft`` with ``nperseg = nfft = n_fft``,
    ``noverlap = n_fft - hop_length`` and zero padding at both boundaries.

    Parameters:
        signal: samples (squeezed and converted to float).
        sr: sampling rate in Hz.
        n_fft, hop_length: frame length and hop size in samples.
        window: window name for ``scipy.signal.get_window`` or an array of
            length ``n_fft``.
        db: if True ``S_plot = 20*log10(|Z| + 1e-10)``, otherwise ``|Z|``.
        ylim: fmax (float) or (fmin, fmax) in Hz, clipped to [0, sr/2].
            Crops ``f`` and ``S_plot``.
        xlim: (tmin, tmax) in seconds for the initial view. If None:
            (0, min(duration, focus_cycles / frequency)) when ``frequency`` > 0,
            else (0, duration).
        zlim_db: (zmin, zmax) color scale in dB (used only when ``db``).
        frequency, focus_cycles: see ``xlim``. The frequency is also added to
            the title together with its period T.
        show_period_markers: draw dotted vertical lines at k*T in the view.
        show: build and display the Plotly figure.

    Returns:
        - return_components=True, reconstruct=False -> (f, t, S_plot, Z)
        - return_components=True, reconstruct=True  -> (f, t, S_plot, Z, x_rec)
        - return_components=False, return_data=True:
            * reconstruct=False -> (f, t, S_plot)
            * reconstruct=True  -> (f, t, S_plot, x_rec)
        - otherwise -> None (plot only)

    Notes:
        - ``S_plot`` is in dB if ``db`` is True, else linear magnitude. It has
          one row per entry of the (possibly cropped) ``f``.
        - ``Z`` is the complex spectrum (true phase) and is NOT cropped by
          ``ylim``. It keeps all ``n_fft // 2 + 1`` rows, which is what
          ``scipy.signal.istft`` needs, so ``Z.shape[0]`` can differ from ``len(f)``.
        - ``x_rec`` is the inverse STFT, trimmed or zero-padded to ``len(signal)``.
    """
    # Local import: the notebook's import cell does not import scipy.signal.
    from scipy.signal import get_window, istft, stft

    # --- Preparation ---
    x = np.asarray(signal, float).squeeze()
    dur = len(x) / float(sr)

    noverlap = n_fft - hop_length
    if isinstance(window, str):
        win = get_window(window, n_fft, fftbins=True)
    else:
        win = np.asarray(window)

    # --- STFT --- (boundary='zeros' pads n_fft // 2 zeros at both ends)
    f, t, Z = stft(x, fs=sr, window=win,
                   nperseg=n_fft, noverlap=noverlap,
                   nfft=n_fft, return_onesided=True,
                   boundary='zeros')
    S = np.abs(Z)

    # --- Magnitude -> dB (or linear amplitude) ---
    if db:
        S_plot = 20.0 * np.log10(S + 1e-10)  # absolute dB; 1e-10 avoids log10(0)
        zlab = 'Magnitude (dB)'
        zmin, zmax = (None, None) if zlim_db is None else (float(zlim_db[0]), float(zlim_db[1]))
    else:
        S_plot = S
        zlab = 'Amplitude'
        zmin, zmax = None, None

    # --- Frequency range (Y), clipped to [0, Nyquist] ---
    if ylim is None:
        fmin, fmax = 0.0, sr / 2.0
    else:
        if np.isscalar(ylim):
            fmin, fmax = 0.0, float(ylim)
        else:
            fmin, fmax = float(ylim[0]), float(ylim[1])
        fmax = min(fmax, sr / 2.0)
        fmin = max(fmin, 0.0)

    # Crop f and S_plot together; Z stays complete for istft and callers.
    y_mask = (f >= fmin) & (f <= fmax)
    f = f[y_mask]
    S_plot = S_plot[y_mask, :]

    # --- Initial time window (X): an explicit xlim wins, as documented ---
    period = 1.0 / float(frequency) if (frequency is not None and frequency > 0) else None
    if xlim is not None:
        tmin, tmax = float(xlim[0]), float(xlim[1])
    elif period is not None:
        tmin, tmax = 0.0, min(dur, focus_cycles * period)
    else:
        tmin, tmax = 0.0, dur

    # The figure is only needed for display, so skip building it otherwise.
    if show:
        _show_stft_figure(
            f, t, S_plot,
            zlab=zlab, zmin=zmin, zmax=zmax, title=title,
            frequency=frequency, period=period,
            width=width, height=height,
            tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, dur=dur,
            show_period_markers=show_period_markers,
        )

    # --- Optional reconstruction ---
    x_rec = None
    if reconstruct:
        # boundary=True strips the n_fft // 2 zeros that stft(boundary='zeros')
        # prepended. Without it, x_rec would lag the input by n_fft // 2 samples.
        _, x_rec = istft(Z, fs=sr, window=win,
                         nperseg=n_fft, noverlap=noverlap,
                         nfft=n_fft, input_onesided=True, boundary=True)
        if len(x_rec) > len(x):
            x_rec = x_rec[:len(x)]
        elif len(x_rec) < len(x):
            x_rec = np.pad(x_rec, (0, len(x) - len(x_rec)))

    # --- Outputs ---
    if return_components:
        if reconstruct:
            return f, t, S_plot, Z, x_rec
        return f, t, S_plot, Z

    if return_data:
        return (f, t, S_plot, x_rec) if reconstruct else (f, t, S_plot)
    return None


def _show_stft_figure(f, t, S_plot, *, zlab, zmin, zmax, title, frequency,
                      period, width, height, tmin, tmax, fmin, fmax, dur,
                      show_period_markers, max_markers=500):
    """Build and display the Plotly spectrogram heatmap for ``calculateSTFT``."""
    # NOTE(review): `go` (plotly.graph_objects) is not imported in the cells
    # visible here; it must be in the notebook namespace before plotting.
    fig = go.Figure(
        data=go.Heatmap(
            x=t, y=f, z=S_plot,
            colorscale='Viridis',
            colorbar=dict(title=zlab),
            zsmooth=False,
            zmin=zmin, zmax=zmax,
            hovertemplate="t=%{x:.3f}s<br>f=%{y:.0f} Hz<br>Valor=%{z:.2f}<extra></extra>"
        )
    )

    # Title carries f and T when a frequency was provided.
    if period is not None:
        title = f"{title} • f = {frequency:.2f} Hz • T = {period:.4f} s"

    fig.update_layout(
        title=title,
        xaxis_title="Tempo (s)",
        yaxis_title="Frequência (Hz)",
        width=width,
        height=height,
        template="plotly_white",
        margin=dict(l=60, r=20, t=50, b=40),
        hovermode="x unified",
        dragmode="zoom"
    )

    # NOTE(review): the Plotly docs state range selectors are only available on
    # date-typed x axes, so these buttons likely do not render on this numeric
    # time axis. Confirm before relying on them.
    fig.update_xaxes(
        range=[tmin, tmax],
        showgrid=True,
        zeroline=False,
        rangeslider=dict(visible=True),
        rangeselector=dict(
            buttons=list([
                dict(count=0.1, label="100ms", step="second", stepmode="backward"),
                dict(count=1,   label="1s",    step="second", stepmode="backward"),
                dict(count=5,   label="5s",    step="second", stepmode="backward"),
                dict(step="all",label="Tudo")
            ])
        )
    )

    # Fixed linear Y axis (for a log axis use fig.update_yaxes(type="log")).
    fig.update_yaxes(range=[fmin, fmax], showgrid=True, zeroline=False)

    # Dotted vertical line at every k*T (k >= 1) inside the initial window and
    # never past the end of the signal. Capped at max_markers because a wide
    # explicit xlim combined with a high frequency could otherwise create
    # thousands of Plotly shapes.
    if show_period_markers and (period is not None) and (period > 0):
        limit = min(dur, tmax)
        k = max(1, int(np.ceil(tmin / period)))
        shapes = []
        while k * period <= limit and len(shapes) < max_markers:
            xk = k * period
            shapes.append(dict(
                type="line",
                xref="x", yref="paper",
                x0=xk, x1=xk, y0=0, y1=1,
                line=dict(color="rgba(0,0,0,0.25)", width=1, dash="dot")
            ))
            k += 1
        if shapes:
            fig.update_layout(shapes=shapes)

    fig.show()

## Execução do STFT

In [None]:
parquet_dir = Path("/content/drive/MyDrive/Datasets/DEAM/")