In [None]:

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import MaxNLocator
import librosa.display
import numpy as np

def get_shape_of_wavelet_transform(duration,
                                   sampling_frequency,
                                   wavelet_frequency_resolution):
    """A helper function to get the shape of the wavelet transform.
    
    Parameters
    ----------
    duration: float
        The duration of the data segment.
    sampling_frequency: float
        The sampling frequency of the data segment.
    wavelet_frequency_resolution: float
        The frequency resolution of the wavelet transform.
    
    Returns
    -------
    Nt, Nf: int, int
        The number of time and frequency bins in the wavelet transform.
    """
    Nf = int(sampling_frequency / 2 / wavelet_frequency_resolution)
    Nt = int(duration*sampling_frequency / Nf)
    return Nt, Nf



def plot_spectrogram(spectrogram,
                     duration,
                     sampling_frequency,
                     wavelet_frequency_resolution,
                     title=None, savefig=None, dpi=100):
    Nt, Nf = get_shape_of_wavelet_transform(duration=duration,
                                            sampling_frequency=sampling_frequency,
                                            wavelet_frequency_resolution=wavelet_frequency_resolution)
    print(Nt, Nf)
    sampling_times = np.arange(Nt) * duration / Nt
    sampling_frequencies = np.arange(Nf) * wavelet_frequency_resolution
    cmap = plt.get_cmap("viridis")
    levels = MaxNLocator(nbins=15).tick_values(np.min(spectrogram),np.max(spectrogram))
    norm = BoundaryNorm(levels,ncolors=cmap.N,clip=True)
    fig, ax = plt.subplots()
    img = librosa.display.specshow(spectrogram.T,y_axis="log",x_axis="s",cmap=cmap,norm=norm,x_coords=sampling_times,y_coords=sampling_frequencies,snap=True, ax=ax)
    ax.set_title(title)
    ax.set_ylabel("Frequency (Hz)")
    cbar = fig.colorbar(img,format="%.0e", ax=ax)
    cbar.set_label("Normalized Energy")
    cbar.ax.yaxis.set_label_position("left")
    plt.xlabel("Time (s)")
    if savefig is not None:
        plt.savefig(savefig,dpi=dpi,bbox_inches="tight")
    plt.clf()
    plt.close()

In [2]:
spectrogram = np.random.randn(512, 64).astype(bool)
duration = 16
sampling_frequency = 2048
wavelet_frequency_resolution = 16
plot_spectrogram(spectrogram, duration, sampling_frequency, wavelet_frequency_resolution, savefig='test.png')

512 64
