In [None]:
import os
import numpy as np
import librosa.display
import pytorch_mel_fsgcc_cls_feature_class as cls
import cls_feature_class as cl
import parameters
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plot
from IPython.display import Audio
import torch

# Global matplotlib default: use a 22 pt font for every figure in this notebook.
plot.rcParams['font.size'] = 22

In [None]:
def extend_spectrogram(stft):
    """Rebuild a two-sided spectrum from a one-sided STFT.

    Axis 1 of ``stft`` is read as ``[DC, inner positive bins..., last bin]``.
    The output keeps those bins and appends the complex conjugate of the inner
    positive bins in reverse order, so for ``num_freqs >= 2`` axis 1 grows
    from ``num_freqs`` to ``2 * (num_freqs - 1)`` entries. The last input bin
    is treated as the Nyquist bin and its imaginary part is discarded.

    Args:
        stft: complex array indexed as (frames, frequency bins, channels).

    Returns:
        Complex array with the same first and last axes as ``stft`` and the
        extended frequency axis described above.
    """
    _, n_bins, _ = stft.shape
    last = n_bins - 1

    dc_bin = stft[:, 0:1, :]
    inner_bins = stft[:, 1:last, :]

    # The Nyquist bin of a real-valued signal is real: drop any imaginary part.
    nyquist_bin = stft[:, last:n_bins, :].real + 0j

    # Negative-frequency half: conjugated inner bins, highest frequency first.
    negative_bins = np.conj(inner_bins)[:, ::-1, :]

    # Order: DC, positive bins, Nyquist, mirrored negative bins.
    return np.concatenate((dc_bin, inner_bins, nyquist_bin, negative_bins), axis=1)

In [None]:
# Configuration, audio directory and feature extractor for the microphone recordings.
params = parameters.get_params()
aud_dir = os.path.join(params['dataset_dir'], 'mic_dev', 'mic')
feat_cls = cls.FeatureClass(params)

# Spectrogram of one recording, as returned by the feature class.
wav_path = os.path.join(aud_dir, 'fold5_room1_mix007.wav')
spect = feat_cls._get_spectrogram_for_file(wav_path)

# Copy four channels of `spect` into a dedicated complex array.
# NOTE(review): indices 4, 8, 12, 16 presumably pick the microphones of interest - confirm against the recording layout.
mic_channels = (4, 8, 12, 16)
spect_mics = np.zeros(spect.shape[:2] + (len(mic_channels),), dtype=np.complex128)
for mic_idx, src_channel in enumerate(mic_channels):
    spect_mics[:, :, mic_idx] = spect[:, :, src_channel]

mel_spect = feat_cls._get_mel_spectrogram_gcc(spect_mics)

# Two-sided spectrum with the first two axes swapped afterwards.
extended_spect = np.transpose(extend_spectrogram(spect_mics), (1, 0, 2))

In [None]:
# Shapes and dtypes of the raw and the extended STFT.
print("STFT Shape:", spect.shape, "dtype:", spect.dtype)
print("STFT extended Shape:", extended_spect.shape, "dtype:", extended_spect.dtype)

# GCC feature from the feature class, stored with its two axes swapped.
gcc = np.transpose(feat_cls._get_gcc(spect_mics), (1, 0))
print("GCC Shape:", gcc.shape, "dtype:", gcc.dtype)

In [None]:
# Shape and dtype of the selected-microphone STFT and of the mel/GCC feature.
for label, arr in (("STFT Shape:", spect_mics), ("Mel Spectrogram Shape_gcc:", mel_spect)):
    print(label, arr.shape, "dtype:", arr.dtype)

In [None]:
%matplotlib inline

# Derive the plotted arrays under new names so this cell can be re-run safely:
# the complex `spect_mics` and the un-transposed `mel_spect` are left untouched
# (overwriting them made a second run transpose `mel_spect` back and fed
# magnitudes to any earlier cell that was re-executed).
epsilon = 1e-10  # keeps log10 finite on zero-magnitude bins
spect_mag = np.abs(np.squeeze(spect_mics))
spectrogram_db = np.transpose(20 * np.log10(spect_mag + epsilon), (1, 0, 2))
mel_spect_t = np.transpose(mel_spect, (1, 0))

n_fft = 2048
# NOTE(review): hop assumed to be 128 samples at 24 kHz - confirm against the feature class.
hop_s = 128 / 24000
# axes
freqs = np.linspace(0, params['fs'] / 2, n_fft // 2 + 1)  # Frequencies in Hz (only positive part)
# One time stamp per plotted column; an integer arange avoids the occasional
# extra element a float-step np.arange can produce.
times = hop_s * np.arange(spectrogram_db.shape[1])

plot.figure(figsize=(40, 20))

# Top panel: dB magnitude of the first selected channel.
ax1 = plot.subplot(2, 1, 1)
aux1 = spectrogram_db[:, :, 0]
im1 = ax1.imshow(aux1, cmap='viridis', aspect='auto', origin='lower', extent=[times[0], times[-1], freqs[0], freqs[-1]])
ax1.set_title('Spectrogram in dB')
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Frequency (Hz)')
plot.colorbar(im1, ax=ax1, format='%+2.0f dB')

# Bottom panel: first 64 rows of the transposed mel/GCC feature.
ax2 = plot.subplot(2, 1, 2)
aux2 = mel_spect_t[:64, :]
im2 = ax2.imshow(aux2, cmap='viridis', aspect='auto', origin='lower')
ax2.set_title('Log Mel Spectrogram')
ax2.set_xlabel('Time Frames')
ax2.set_ylabel('Mel Bands')
plot.colorbar(im2, ax=ax2, format='%+2.0f dB')

plot.tight_layout()
plot.show()

In [None]:
%matplotlib inline
# Shape of the `_mel_wts` array held by the feature class (presumably the mel filter bank),
# followed by a line plot of its columns.
mel_weights = feat_cls._mel_wts
print(mel_weights.shape)
plot.plot(mel_weights)

In [None]:
%matplotlib inline

# Image of the first 64 rows of the transposed GCC feature (rows on the y-axis).
plot.figure(figsize=(40, 10))

aux1 = gcc[:64]
im1 = plot.imshow(aux1, aspect='auto', origin='lower', cmap='viridis')
# NOTE(review): the y-axis is labelled 'Mel Bands'; confirm that is what the GCC rows represent.
plot.gca().set(title='GCC Baseline', xlabel='Time Frames', ylabel='Mel Bands')

plot.tight_layout()
plot.show()

In [None]:
# Mel band edge frequencies (nb_mel_bins + 2 points) converted to FFT bin indices.
mel_bins_edges_hz = librosa.mel_frequencies(n_mels=feat_cls._nb_mel_bins + 2, fmin=0, fmax=feat_cls._fs / 2)
k_lims = np.round(mel_bins_edges_hz / feat_cls._fs * feat_cls._nfft).astype(int)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nfft = feat_cls._nfft
nb_mel = feat_cls._nb_mel_bins
Nframes = extended_spect.shape[1]

# Lag (in samples) corresponding to each index of the fft-shifted GCC.
lags = torch.arange(-(nfft / 2), nfft / 2, dtype=torch.float64, device=device)

# Maximum lag kept, as a plain Python int so it can be used directly as a
# slice bound and as a tensor dimension.
# NOTE(review): the original comment mentions a 1.2 m microphone separation,
# but the expression uses 2 * 6 / 343 s - confirm the intended distance.
max_lag = int(round(2 * 6 / 343 * feat_cls._fs))
# Keep the lag window inside the nfft-long GCC (no effect when it already fits).
max_lag = min(max_lag, nfft // 2 - 1)

# pairs = list(itertools.combinations(range(self._nb_channels), 2))
pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
pairs_ex = [(0, 1)]  # only this pair is processed below
win = 'boxcar'

print(max_lag)
maxlag_ind1 = nfft // 2 - max_lag
maxlag_ind2 = nfft // 2 + max_lag
print(maxlag_ind1, maxlag_ind2)

# 1 inside [-max_lag, +max_lag] of the fft-shifted lag axis, 0 elsewhere.
lagmask = torch.zeros(nfft, dtype=torch.float64, device=device)
lagmask[maxlag_ind1:maxlag_ind2 + 1] = 1
print(lagmask)

Xframes = torch.tensor(extended_spect, device=device)

# Per-band frequency windows do not depend on the pair or the batch, so they
# are built once here instead of inside the innermost loop.
band_widths = []
band_masks = []
for k in range(nb_mel):
    BW = int(k_lims[k + 2] - k_lims[k] + 1)
    BW = BW + BW % 2  # force an even width so it splits into two equal halves

    if win == 'hann':
        wind = torch.hann_window(BW, device=device)
    else:
        wind = torch.ones(BW, dtype=torch.float64, device=device)

    # Window laid out around index 0 (second half first, first half wrapped to the end).
    half_bw = BW // 2
    windmask = torch.zeros(nfft, dtype=torch.complex128, device=device)
    windmask[:half_bw] = wind[half_bw:]
    windmask[-half_bw:] = wind[:half_bw]

    band_widths.append(BW)
    band_masks.append(windmask)

# Result buffers (kept on the CPU). The per-band statistics are needed by the
# Meltde / Melmde / Melstd / Melavg plotting cells further down.
stat_shape = (nb_mel, Nframes, len(pairs_ex))
Meltde = torch.zeros(stat_shape, dtype=torch.float64, device='cpu')  # lag of the magnitude peak
Melmde = torch.zeros(stat_shape, dtype=torch.float64, device='cpu')  # magnitude at that peak
Melstd = torch.zeros(stat_shape, dtype=torch.float64, device='cpu')  # magnitude-weighted lag spread
Melavg = torch.zeros(stat_shape, dtype=torch.float64, device='cpu')  # magnitude-weighted mean lag
# NOTE(review): this buffer holds nb_mel * Nframes * pairs * (2 * max_lag + 1)
# complex128 values and can become very large for long recordings.
MelGCClag = torch.zeros((nb_mel, Nframes, len(pairs_ex), maxlag_ind2 - maxlag_ind1 + 1), dtype=torch.complex128, device='cpu')
batch_size = 100

for idx, (p1, p2) in enumerate(pairs_ex):
    # STFTs of the two microphones of the pair: [freq_bins, frames]
    X1 = Xframes[:, :, p1]
    X2 = Xframes[:, :, p2]

    # Unit-magnitude cross-spectrum (phase only) for every frame.
    GCC = torch.exp(1j * torch.angle(X2 * torch.conj(X1)))

    # Frames are processed in batches to bound the device memory in use.
    for start in range(0, Nframes, batch_size):
        end = min(start + batch_size, Nframes)

        batch_GCCm = torch.zeros((nb_mel, end - start, nfft), dtype=torch.complex128, device=device)

        for k in range(nb_mel):
            BW = band_widths[k]

            # Shift the spectrum by the band's centre bin, keep only the band,
            # and go to the lag domain; lags outside +/- max_lag are zeroed.
            GCCd = torch.roll(GCC[:, start:end], shifts=int(k_lims[k + 1]), dims=0)
            GCCd = GCCd * band_masks[k][:, None]
            aux = (1 / BW) * torch.fft.fftshift(torch.fft.ifft(GCCd, dim=0), dim=0) * lagmask[:, None]

            batch_GCCm[k, :, :] = aux.T  # [n_frames, nfft] for this band
            abs_aux = torch.abs(aux)

            # Peak magnitude and the lag at which it occurs, per frame.
            peak_val, max_ind = torch.max(abs_aux, dim=0)
            Melmde[k, start:end, idx] = peak_val.to('cpu')
            Meltde[k, start:end, idx] = lags[max_ind].to('cpu')

            # Treat the normalised magnitude as a distribution over lags.
            norm = abs_aux.sum(dim=0, keepdim=True).clamp_min(torch.finfo(torch.float64).tiny)
            auxpdf = abs_aux / norm
            lag_mean = (lags[:, None] * auxpdf).sum(dim=0)
            lag_std = torch.sqrt(((lags[:, None] - lag_mean[None, :]) ** 2 * auxpdf).sum(dim=0))
            Melavg[k, start:end, idx] = lag_mean.to('cpu')
            Melstd[k, start:end, idx] = lag_std.to('cpu')

        # Keep only the +/- max_lag window and move it to the CPU buffer.
        MelGCClag[:, start:end, idx, :] = batch_GCCm[:, :, maxlag_ind1:maxlag_ind2 + 1].to('cpu')

# Normalisation factors taken from the previously disabled code.
# NOTE(review): confirm these scalings are the intended ones.
Meltde = Meltde / max_lag
Melmde = Melmde / (0.5 * (1 / nfft))
Melstd = Melstd / max_lag
Melavg = Melavg / (0.5 * max_lag)

print("MELTDE Shape:", Meltde.shape)
print("MELMDE Shape:", Melmde.shape)
print("MELSTD Shape:", Melstd.shape)
print("MELAVG Shape:", Melavg.shape)
print("MELGCCLAG Shape:", MelGCClag.shape)


In [None]:
import matplotlib.animation as animation

In [None]:


# Mel-FSGCC of the first microphone pair: (nb_mel_bins, Nframes, num_lags)
dati = MelGCClag[:, :, 0, :]

# Magnitude of the complex values, as a NumPy array for matplotlib.
dati_abs = torch.abs(dati).cpu().numpy()
nb_mel_bins, Nframes, num_lags = dati_abs.shape

# Use one colour scale for the whole animation: `set_data` does not rescale,
# so without explicit limits every frame would be drawn with the range of the
# first lag slice only.
vmin, vmax = dati_abs.min(), dati_abs.max()

fig, ax = plot.subplots()
img = ax.imshow(dati_abs[:, :, 0], aspect='auto', origin='lower', vmin=vmin, vmax=vmax)
ax.set_xlabel('Frame')
ax.set_ylabel('Mel Bands')
ax.set_title('Lag: 0')


def update(lag):
    """Show the band-by-frame magnitude map for lag index ``lag`` and retitle the axes."""
    img.set_data(dati_abs[:, :, lag])
    ax.set_title(f'Lag: {lag}')
    return [img]


# blit=False: the title is changed on every frame but is not among the artists
# returned by update(), so a blitted redraw would not refresh it.
ani = animation.FuncAnimation(fig, update, frames=num_lags, interval=500, blit=False)
plot.show()

In [None]:
%matplotlib inline

# Band-by-frame magnitude of the Mel-FSGCC (first pair) at index 500 of the stored lag axis.
plot.figure(figsize=(40, 10))

aux1 = MelGCClag[:, :, 0, 500].abs().cpu().numpy()
im1 = plot.imshow(aux1, aspect='auto', origin='lower', cmap='viridis')
plot.gca().set(title='Mel-FSGCC Lag: 500', xlabel='Time Frames', ylabel='Mel Bands')

plot.tight_layout()
plot.show()

In [None]:
from IPython.display import HTML
# Convert the animation to its JavaScript/HTML representation and display it as the cell output.
animation_html = ani.to_jshtml()
HTML(animation_html)

In [None]:
%matplotlib inline
# Band-by-lag magnitude of the Mel-FSGCC (first pair) for a single frame.
# The frame index is defined once and reused in the title, so the label can no
# longer disagree with the data being shown (the slice used frame 9000 while
# the title said 4000).
frame_idx = 9000

plot.figure(figsize=(40, 10))

aux1 = torch.abs(MelGCClag[:, frame_idx, 0, :]).cpu().numpy()
im1 = plot.imshow(aux1, cmap='viridis', aspect='auto', origin='lower')
plot.title(f'Mel-FSGCC Frame: {frame_idx}')
plot.xlabel('Time Delay Lags')
plot.ylabel('Mel Bands')

plot.tight_layout()
plot.show()

In [None]:
%matplotlib inline
# Band-by-frame image of `Meltde` at index 0 of its last axis.
# `Meltde` must already be defined when this cell runs.
plot.figure(figsize=(40, 10))

aux1 = Meltde[:, :, 0].cpu().numpy()
im1 = plot.imshow(aux1, aspect='auto', origin='lower', cmap='viridis')
plot.gca().set(title='Meltde', xlabel='Time Frames', ylabel='Mel Bands')

plot.tight_layout()
plot.show()

In [None]:
%matplotlib inline
# Band-by-frame image of `Melmde` at index 0 of its last axis.
# `Melmde` must already be defined when this cell runs.
plot.figure(figsize=(40, 10))

aux1 = Melmde[:, :, 0].cpu().numpy()
im1 = plot.imshow(aux1, aspect='auto', origin='lower', cmap='viridis')
plot.gca().set(title='Melmde', xlabel='Time Frames', ylabel='Mel Bands')

plot.tight_layout()
plot.show()

In [None]:
%matplotlib inline
# Band-by-frame image of `Melstd` at index 0 of its last axis.
# `Melstd` must already be defined when this cell runs.
plot.figure(figsize=(40, 10))

aux1 = Melstd[:, :, 0].cpu().numpy()
im1 = plot.imshow(aux1, aspect='auto', origin='lower', cmap='viridis')
plot.gca().set(title='Melstd', xlabel='Time Frames', ylabel='Mel Bands')

plot.tight_layout()
plot.show()

In [None]:
%matplotlib inline
# Band-by-frame image of `Melavg` at index 0 of its last axis.
# `Melavg` must already be defined when this cell runs.
plot.figure(figsize=(40, 10))

aux1 = Melavg[:, :, 0].cpu().numpy()
im1 = plot.imshow(aux1, aspect='auto', origin='lower', cmap='viridis')
plot.gca().set(title='Melavg', xlabel='Time Frames', ylabel='Mel Bands')

plot.tight_layout()
plot.show()

In [None]:
def collect_classwise_data(_in_dict):
    """Regroup a key-indexed dictionary of entries by each entry's first field.

    Args:
        _in_dict: mapping ``key -> iterable of entries``; each entry is
            indexable and fields 0, 2 and 3 are read (field 1 is ignored).

    Returns:
        Dict mapping ``entry[0]`` to a list of ``[key, entry[0], entry[2],
        entry[3]]`` rows, in the order the entries were encountered.
    """
    _out_dict = {}
    for _key, _entries in _in_dict.items():
        for _entry in _entries:
            _group = _entry[0]
            _row = [_key, _group, _entry[2], _entry[3]]
            _out_dict.setdefault(_group, []).append(_row)
    return _out_dict


def plot_func(plot_data, hop_len_s, ind, plot_x_ax=False, plot_y_ax=False):
    """Scatter one column of class-wise rows against time on the current axes.

    Args:
        plot_data: mapping ``class index -> list of rows``; column 0 of each
            row is multiplied by ``hop_len_s`` to form the x value.
        hop_len_s: factor applied to column 0.
        ind: column of the rows plotted on the y-axis.
        plot_x_ax: when False, the x tick labels are removed.
        plot_y_ax: when False, the y tick labels are removed.

    The x-range is fixed to [0, 60]. Each class is drawn as dots in the colour
    at position ``class index`` of a 14-entry list.
    """
    cmap = ['b', 'r', 'g', 'y', 'k', 'c', 'm', 'orange', 'grey', 'lime', 'peru', 'maroon', 'lightpink', 'purple']
    for class_ind, rows in plot_data.items():
        table = np.array(rows)
        plot.plot(table[:, 0] * hop_len_s, table[:, ind],
                  marker='.', color=cmap[class_ind], linestyle='None', markersize=4)

    plot.grid()
    plot.xlim([0, 60])

    # Hide tick labels on the axes the caller did not ask for.
    if not plot_x_ax:
        plot.gca().axes.set_xticklabels([])
    if not plot_y_ax:
        plot.gca().axes.set_yticklabels([])

In [None]:
params = parameters.get_params()

# Prediction file (output format) to visualise.
pred = 'results_audio/6_1_dev_split0_multiaccdoa_mic_gcc_20250114111734_test/fold4_room24_mix007.csv'

# Reference metadata directory and audio directory.
# The audio file name is derived from the prediction file name below.
ref_dir = os.path.join(params['dataset_dir'], 'metadata_dev', 'dev-test-sony')
aud_dir = os.path.join(params['dataset_dir'], 'mic_dev', 'dev-test-sony')

feat_cls = cls.FeatureClass(params)

# Author's configuration note: 24 kHz sampling rate, Hann window of length 960,
# hop size 480 (50% overlap), nfft 1024, 64 mel bands.

# Load the matching audio file; spectrogram and mel spectrogram use its first channel only.
ref_filename = os.path.basename(pred).replace('.csv', '.wav')
audio, fs = feat_cls._load_audio(os.path.join(aud_dir, ref_filename))
first_channel = audio[:, :1]
stft = feat_cls._spectrogram(first_channel)
mel_spect = feat_cls._get_mel_spectrogram(stft)

# Magnitude in dB; epsilon keeps log10 finite on zero-magnitude bins.
epsilon = 1e-10
stft = np.abs(np.squeeze(stft))
spectrogram_db = 20 * np.log10(stft + epsilon)

# Axis values for the plots.
n_fft = 1024
freqs = np.linspace(0, params['fs'] / 2, n_fft // 2 + 1)  # Frequencies in Hz (only positive part)
# NOTE(review): `times` is sized from axis 1 of the squeezed STFT - confirm that axis is time and not frequency.
hop_s = params['hop_len_s']
times = np.arange(0, hop_s * stft.shape[1], hop_s)

In [None]:
# Audio player for the recording that matches the selected prediction file.
audio_path = os.path.join(aud_dir, ref_filename)
Audio(audio_path)

In [None]:
%matplotlib inline
# Shapes, dtypes and value ranges of the magnitude STFT and the mel spectrogram.
stft_abs = np.abs(stft)
print("STFT Shape:", stft.shape, "dtype:", stft.dtype)
print("STFT min/max:", stft_abs.min(), stft_abs.max())
print("Mel Spectrogram Shape:", mel_spect.shape, "dtype:", mel_spect.dtype)
print("Mel Spectrogram min/max:", np.min(mel_spect), np.max(mel_spect))

plot.figure(figsize=(40, 20))

# Top panel: dB spectrogram.
# NOTE(review): `spectrogram_db` is cut along axis 0 to n_fft // 2 + 1 rows and drawn
# with frequency on the y-axis - confirm axis 0 is frequency and not time.
ax1 = plot.subplot(2, 1, 1)
aux1 = spectrogram_db[:n_fft // 2 + 1, :]
im1 = ax1.imshow(aux1, aspect='auto', origin='lower', cmap='viridis',
                 extent=[times[0], times[-1], freqs[0], freqs[-1]])
ax1.set(title='Spectrogram in dB', xlabel='Time (s)', ylabel='Frequency (Hz)')
plot.colorbar(im1, ax=ax1, format='%+2.0f dB')

# Bottom panel: mel spectrogram exactly as returned by the feature class.
ax2 = plot.subplot(2, 1, 2)
aux2 = mel_spect
im2 = ax2.imshow(aux2, aspect='auto', origin='lower', cmap='viridis')
ax2.set(title='Log Mel Spectrogram', xlabel='Time Frames', ylabel='Mel Bands')
plot.colorbar(im2, ax=ax2, format='%+2.0f dB')

plot.tight_layout()
plot.show()

In [None]:
# load the predicted output format
pred_dict = feat_cls.load_output_format_file(pred)

# load the reference output format
# The reference CSV shares the prediction's base name. It gets its own variable
# so that `ref_filename` keeps pointing at the .wav file: rebinding it here made
# the audio cells look for a CSV in the audio directory when re-run.
ref_csv_name = os.path.basename(pred)
ref_dict_polar = feat_cls.load_output_format_file(os.path.join(ref_dir, ref_csv_name))

# Regroup both dictionaries by class for plotting.
pred_data = collect_classwise_data(pred_dict)
ref_data = collect_classwise_data(ref_dict_polar)

nb_classes = 13

In [None]:
plot.figure(figsize=(20, 15))

gs = gridspec.GridSpec(3, 4)

# One grid row per quantity:
# (column of the class-wise rows to plot, y-limits, reference title, predicted title)
panel_rows = (
    (1, [-1, nb_classes + 1], 'SED reference', 'SED predicted'),
    (2, [-180, 180], 'Azimuth reference', 'Azimuth predicted'),
    (3, [-90, 90], 'Elevation reference', 'Elevation predicted'),
)

for row, (col_ind, y_limits, ref_title, pred_title) in enumerate(panel_rows):
    # Left half of the row: reference (keeps its y tick labels).
    # Right half: prediction (y tick labels hidden).
    for data, cols, show_y, panel_title in (
        (ref_data, slice(None, 2), True, ref_title),
        (pred_data, slice(2, None), False, pred_title),
    ):
        plot.subplot(gs[row, cols])
        plot_func(data, params['label_hop_len_s'], ind=col_ind, plot_y_ax=show_y)
        plot.ylim(y_limits)
        plot.title(panel_title)

plot.show()