In [None]:
from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.dataloader import ThreeRoomDataset, load_dataset, custom_collate, RoomDataset
from diff_gfdn.utils import db
from pathlib import Path
from numpy.typing import NDArray, ArrayLike
import torch
import numpy as np
import torch.nn.functional as F
from scipy.fft import rfftfreq, rfft
from scipy.signal import fftconvolve
import matplotlib.pyplot as plt
from typing import Optional

We want to use a loss based on the energy decay relief (EDR) to check how well our DiffGFDN fits the data. The EDR loss is defined as:

$$\text{EDR}_{\text{loss}}(k,m) = 10 \log_{10} \left(\frac{\sum_{\tau=m}^M |H(k, \tau) |^2} {\sum_{\tau=m}^M |\hat{H}(k, \tau) |^2}\right)$$

For this, we need the STFT bins of the RIRs, $H(k, m)$. However, we only have access to their DFT bins, $H(k)$. Can we get the EDR without calculating an explicit STFT?

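\begin{align}
H(k,m) &= \sum_{n=-\infty}^{\infty} h(n)\, w(n-mR)\, e^{-j\omega_k n} \\
&= \text{DTFT}\left( h \cdot \text{shift}_{mR}(w) \right)_k \\
&= \frac{1}{2\pi} \left( \text{DTFT}(h) \circledast \text{DTFT}\left(\text{shift}_{mR}(w)\right) \right)_k \\
&= \frac{1}{2\pi} \left( H \circledast W_m \right)_k, \qquad W_m(\omega) = W(\omega)\, e^{-j\omega mR}
\end{align}

Note that the phase term $e^{-j\omega mR}$ of the shifted window sits inside the convolution, so it cannot be factored out as a simple multiplication of $(H \circledast W)_k$.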

In [None]:
def plot_spectrogram(S: torch.Tensor, freqs: ArrayLike, time_frames: ArrayLike, title: Optional[str] = None):
    """Display the magnitude of a time-frequency representation on a dB scale.

    Args:
        S: tensor to display. Its first axis is drawn along the vertical
            (frequency) axis and its second axis along the horizontal (time) axis.
            The magnitude of `S` is converted to dB with `db` before plotting.
        freqs: frequency values in Hz; only the minimum and maximum are used
            to label the vertical axis.
        time_frames: frame times in seconds; only the minimum and maximum are
            used to label the horizontal axis.
        title: optional figure title.
    """
    # Detach from the autograd graph and move to host memory *before* any
    # further processing, so that tensors requiring grad or living on an
    # accelerator can be plotted as well.
    magnitude = torch.abs(S.detach().cpu())
    magnitude_db = db(magnitude).numpy()

    freqs = np.asarray(freqs)
    time_frames = np.asarray(time_frames)

    fig, ax = plt.subplots()
    image = ax.imshow(magnitude_db, aspect='auto', origin='lower',
                      extent=[time_frames.min(), time_frames.max(), freqs.min(), freqs.max()])
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    cbar = fig.colorbar(image, ax=ax)
    cbar.set_label('dB')
    if title is not None:
        ax.set_title(title)
    plt.show()
    

def get_stft_torch(rir: torch.Tensor, sample_rate: float,
                    win_size: int, hop_size: int, window: Optional[torch.Tensor] = None,
                    nfft: int = 2**10, time_axis: int = -1):
    """Compute the one-sided STFT of a signal with `torch.stft` (no centre padding).

    Args:
        rir: time-domain signal, 1-D or batched.
        sample_rate: sampling rate used to convert bins/frames to Hz/seconds.
        win_size: window length in samples.
        hop_size: hop between consecutive frames in samples.
        window: analysis window of length `win_size`. When omitted, a Hann
            window is used and `hop_size` must equal `win_size // 2`.
        nfft: FFT size of each frame.
        time_axis: axis of `rir` that holds time; it is moved to the last
            position before processing.

    Returns:
        S: complex STFT whose last two axes are (nfft // 2 + 1, num_frames).
        freqs: bin centre frequencies in Hz, length nfft // 2 + 1.
        time_frames: start time of each frame in seconds, length num_frames.
    """
    # F.pad (with a 2-element pad) and torch.stft both act on the last axis,
    # so bring the time axis there first.
    rir = torch.movedim(rir, time_axis, -1)

    # zero pad the end of the signal to a whole number of hops
    remainder = rir.shape[-1] % hop_size
    if remainder != 0:
        rir = F.pad(input=rir, pad=(0, hop_size - remainder), mode='constant', value=0)

    if window is None:
        assert hop_size == win_size // 2
        window = torch.hann_window(
            win_size,
            dtype=rir.dtype if rir.is_floating_point() else None,
            device=rir.device,
        )

    # complex tensor of shape (B (optional batch), N (frequency bins), T (frames))
    S = torch.stft(rir, nfft, hop_length=hop_size, win_length=win_size, window=window, center=False,
                   normalized=False, onesided=True, return_complex=True)

    freqs = rfftfreq(nfft, d=1.0 / sample_rate)
    # Derive the frame times from the frames torch.stft actually produced; the
    # count depends on nfft, not only on the hop size.
    time_frames = np.arange(S.shape[-1]) * hop_size / sample_rate
    # frequency is the second-to-last axis, also for batched input
    assert len(freqs) == S.shape[-2]
    assert len(time_frames) == S.shape[-1]
    return S, freqs, time_frames

def get_edr_from_stft(S: torch.Tensor):
    """Compute the energy decay relief (EDR) in dB from an STFT.

    For every frequency bin, the EDR at frame m is the energy remaining from
    frame m to the last frame, i.e. the sum of |S|^2 over frames m, m+1, ...

    Args:
        S: STFT with time frames along the last axis (any leading axes, e.g.
            frequency or batch x frequency, are preserved).

    Returns:
        Tensor with the same shape as `S` holding the EDR converted to dB with
        `db(..., is_squared=True)`. The result keeps the real precision and
        device of `S`.
    """
    power = torch.abs(S) ** 2
    # Backward (tail) cumulative sum along time: flip, accumulate, flip back.
    # This replaces one full sum per frame with a single pass.
    remaining_energy = torch.flip(
        torch.cumsum(torch.flip(power, dims=[-1]), dim=-1),
        dims=[-1],
    )
    return db(remaining_energy, is_squared=True)
    

### Get STFT

In [None]:
# Default DiffGFDN configuration; it provides the path to the room dataset.
config_dict = DiffGFDNConfig()
room_data = ThreeRoomDataset(
    Path(config_dict.room_dataset_path).resolve()
)
# Sample rate as stored with the dataset, and the first entry along the
# leading axis of the dataset's RIR array as a torch tensor.
fs = room_data.sample_rate
single_rir = torch.from_numpy(room_data.rirs[0, ...])

In [None]:
# 512-sample frames with 50 % overlap; the FFT size equals the window length.
win_size = 1 << 9
hop_size = win_size // 2
S, freqs, time_frames = get_stft_torch(
    single_rir,
    fs,
    win_size=win_size,
    hop_size=hop_size,
    nfft=win_size,
)
plot_spectrogram(S, freqs, time_frames, title='RIR STFT')

### Plot EDR from STFT

In [None]:
edr = get_edr_from_stft(S)

# get_edr_from_stft already returns values in dB, whereas plot_spectrogram
# converts its input to dB itself. Draw the EDR directly so that the colour
# scale shows the EDR in dB rather than the dB of its absolute value.
fig, ax = plt.subplots()
edr_image = ax.imshow(edr.detach().cpu().numpy(), aspect='auto', origin='lower',
                      extent=[time_frames.min(), time_frames.max(), freqs.min(), freqs.max()])
ax.set_xlabel('Time (s)')
ax.set_ylabel('Frequency (Hz)')
fig.colorbar(edr_image, ax=ax).set_label('dB')
ax.set_title('EDR')
plt.show()

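### Try calculating the STFT from the DFT directly
A bin-wise product of the sampled spectra, $H(k)\,W(k)\,e^{-j\omega_k mR}$, does not give a good match with the STFT: windowing in time corresponds to a (circular) convolution of the spectra rather than to their product, and we only have the DFT samples of the signal, not its DTFT.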

In [None]:
def get_custom_stft_from_dft(rir_response: torch.Tensor, sample_rate: float, win_size: int, hop_size: int,
                             window: Optional[torch.Tensor] = None, nfft: int = 2**10, freq_axis: int = -1):
    """
    Compute the Short-Time Fourier Transform (STFT) of a real signal from its
    one-sided DFT using the convolution theorem.

    Windowing in time is a circular convolution of the signal's DFT with the
    DFT of the shifted window. That convolution is evaluated here through the
    time domain (inverse DFT, multiplication by the shifted window, DFT of the
    frame), which is mathematically identical and far cheaper than convolving
    the spectra bin by bin. When len(window) == nfft the result matches
    torch.stft(..., center=False) applied to the underlying signal.

    Parameters:
        rir_response (torch.Tensor): One-sided DFT of the signal. With K bins
            along `freq_axis` the signal is taken to have 2 * (K - 1) samples.
        sample_rate (float): Sampling rate of the signal.
        win_size (int): Length of the default Hann window (used only when
            `window` is None).
        hop_size (int): The number of samples to hop between windows.
        window (torch.Tensor): The window function (in time-domain). When
            omitted, a Hann window of `win_size` samples is used and
            `hop_size` must equal `win_size // 2`.
        nfft (int): The FFT size of each frame; frames shorter than `nfft`
            are zero padded at the end.
        freq_axis (int): Axis of `rir_response` that holds frequency.

    Returns:
        stft_matrix (torch.Tensor): Complex STFT whose last two axes are
            (nfft // 2 + 1, num_windows).
        freqs (numpy array): Bin frequencies in Hz, length nfft // 2 + 1.
        time_frames (numpy array): Start time of each frame in seconds.

    Raises:
        ValueError: If the window is longer than `nfft` or longer than the
            signal.
    """
    spectrum = torch.movedim(rir_response, freq_axis, -1)
    # a one-sided spectrum with K bins corresponds to 2 * (K - 1) time samples
    num_freq_bins = (spectrum.shape[-1] - 1) * 2
    rir = torch.fft.irfft(spectrum, n=num_freq_bins, dim=-1)

    if window is None:
        assert hop_size == win_size // 2
        window = torch.hann_window(win_size, dtype=rir.dtype, device=rir.device)

    frame_len = len(window)
    if frame_len > nfft:
        raise ValueError(f"window length ({frame_len}) must not exceed nfft ({nfft})")
    if num_freq_bins < frame_len:
        raise ValueError(
            f"signal length ({num_freq_bins}) is shorter than the window ({frame_len})"
        )

    # Frame m covers samples [m * hop_size, m * hop_size + frame_len):
    # shape (..., num_windows, frame_len)
    frames = rir.unfold(-1, frame_len, hop_size)
    num_windows = frames.shape[-2]

    # DFT of every windowed frame, then put frequency before time so that the
    # layout agrees with `freqs` / `time_frames`.
    stft_matrix = torch.fft.rfft(frames * window, n=nfft, dim=-1).transpose(-1, -2)

    freqs = rfftfreq(nfft, d=1.0 / sample_rate)
    time_frames = np.arange(num_windows) * hop_size / sample_rate

    return stft_matrix, freqs, time_frames

In [None]:
# Same framing as for the reference STFT: 512-sample window, 50 % overlap.
win_size = 1 << 9
hop_size = win_size // 2
# First entry along the leading axis of the dataset's stored frequency response.
single_rir_response = torch.from_numpy(room_data.rir_mag_response[0, ...])
S_custom, freqs, time_frames = get_custom_stft_from_dft(
    single_rir_response,
    fs,
    win_size=win_size,
    hop_size=hop_size,
    nfft=win_size,
)
plot_spectrogram(S_custom, freqs, time_frames, title='Custom RIR STFT')