In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from IPython.display import display, HTML
from ipywidgets import interact, FloatSlider, IntSlider, Button, Output 
display(HTML("""<s|tyle>
.rendered_html.text_cell_render {max-width:600px;}
</style>""")) 

# audio.mels

## Mel spectrograms (`ToMel`) and sliding windows (`cut_up`)


In [None]:
#|default_exp audio.mels
#|export
import warnings
warnings.filterwarnings("ignore")
import torch
from cgnai.utils import cgnai_home, sliding_window_ind
from torchvision.transforms import Compose
from torchaudio.transforms import Resample, MelSpectrogram

In [None]:
#|export
def ToMel(sr, wav_cut_spec:"(width, displacement)", n_mels:"features"):
    """
    Higher-order function: returns a callable that maps a waveform
    to its mel spectrogram in decibels.

    Args:
        sr: sampling rate (Hz) of the waveforms the returned callable
            will be applied to; used to build the mel filterbank.
        wav_cut_spec: (width, displacement) of the STFT framing, used as
            `n_fft` and `hop_length` respectively.
        n_mels: number of mel bands.

    Returns:
        A `Compose` that applies `MelSpectrogram` and then
        `AmplitudeToDB(stype='power', top_db=80)`.
    """
    # Local import: this module's top-level imports only bring selected
    # transforms into scope, not the `torchaudio` package itself, so
    # `torchaudio.transforms.AmplitudeToDB` would be a NameError here.
    from torchaudio.transforms import AmplitudeToDB

    w, d = wav_cut_spec

    components = [
        MelSpectrogram(
            sample_rate = sr,
            n_fft       = w,
            hop_length  = d,
            pad         = 0,
            n_mels      = n_mels,
            normalized  = False),
        # stype='power' because MelSpectrogram's default output is a
        # power spectrogram (power=2.0); values more than 80 dB below
        # the peak are clipped.
        AmplitudeToDB(
            stype  = 'power',
            top_db = 80),
    ]

    return Compose(components)

In [None]:
#|export
def cut_up(x, cut_spec:"(width, displacement)"):
    """
    Sliding windows over the LAST dimension of `x`.

    `cut_spec` is a `(width, displacement)` pair: every window holds
    `width` consecutive entries, and consecutive windows start
    `displacement` entries apart. Only complete windows are produced.

    The result has the last dimension of `x` replaced by
    `(num_windows, width)` (see `Tensor.unfold`).

    Note: despite the name, the windows may overlap — this is a
    sliding window, not a partition into disjoint cuts.
    """
    width, displacement = cut_spec
    return x.unfold(-1, width, displacement)

In [None]:
#|export
class MelCuts():
    """
    Callable mapping a waveform to its mel spectrogram and to sliding
    windows ("cuts") taken over the spectrogram's last dimension.

    Args:
        sr: sampling rate (Hz) passed to `ToMel`; should be the rate of
            the waveforms this object is called on.
        wav_cut_spec: (width, displacement) of the STFT framing, forwarded to `ToMel`.
        mel_cut_spec: (width, displacement) of the windows taken over the
            mel spectrogram's last dimension.
        n_mels: number of mel bands, forwarded to `ToMel`.
    """
    def __init__(self, sr, wav_cut_spec:"(width, displacement)", mel_cut_spec:"(width, displacement)", n_mels):
        self.sr = sr
        self.wav_cut_spec = wav_cut_spec
        self.mel_cut_spec = mel_cut_spec
        self.n_mels = n_mels

        self.to_mel = ToMel(sr, wav_cut_spec=wav_cut_spec, n_mels=n_mels)

    def cut_up(self, x):
        """Sliding windows over the last dim of `x`, using this instance's `mel_cut_spec`."""
        # Must read the instance attribute: a bare `mel_cut_spec` would pick up
        # a notebook global (or raise NameError once exported to a module).
        return cut_up(x, self.mel_cut_spec)

    def __call__(self, wav):
        """Return `(mel, cuts)`: the mel spectrogram of `wav` and its sliding windows."""
        mel = self.to_mel(wav)
        cuts = self.cut_up(mel)
        return mel, cuts

In [None]:
#|export
def to_samples(i, wav_cut_spec:"(width, displacement)", mel_cut_spec:"(width, displacement)"):
    """
    Map a mel-cut index to the range of waveform samples it covers.

    Cut `i` of `cut_up(mel, mel_cut_spec)` contains mel frames
    `i*d_ ... i*d_ + w_ - 1` (exactly `w_` frames). Frame `f` is taken
    to cover samples `f*d ... f*d + w - 1`.

    Returns:
        (start, stop): the half-open sample range [start, stop).

    NOTE(review): this assumes non-centered STFT framing. torchaudio's
    `MelSpectrogram` defaults to `center=True`, which pads the signal so
    frame `f` is centered on sample `f*d` (i.e. shifted by `w // 2`)
    — confirm which framing is intended.
    """
    w , d  = wav_cut_spec
    w_, d_ = mel_cut_spec
    first_frame = i*d_
    last_frame  = i*d_ + w_ - 1   # a window of w_ frames ends at first + w_ - 1
    return first_frame*d, last_frame*d + w
            
def to_ms(s, sr):
    """Convert `s` samples at sampling rate `sr` (Hz) into milliseconds."""
    seconds = s / sr
    return seconds * 1000

## Testing and visualization

In [None]:
import matplotlib.pyplot as plt
import torchaudio
import torch
from cgnai.fileio import *
from IPython.display import Audio, display
import numpy as np

In [None]:
# Local Common Voice 8.0 (English) extract; 'train.tsv' is the metadata file to read.
path = cgnai_home() / "local" / "data" / "cv-corpus-8.0-2022-01-19" / "en"
D = torchaudio.datasets.COMMONVOICE(path, 'train.tsv')

In [None]:
wav_, sr_, _ = D[10]

# ---------------------
# Raw clip: tensor shape, native sampling rate, playback and waveform.
print(f"wav': {wav_.size()}")
print(f"ORIGINAL sampling rate: {sr_:_.0f}")
display(Audio(wav_, rate=sr_))

fig, ax = plt.subplots(figsize=(12, 2))
ax.set(ylim=(-1, 1), xlabel="Samples", ylabel="Amplitude")
ax.plot(wav_[0]);

In [None]:
# Target sampling rate (Hz) the waveform is resampled to.
sr = 16_000

# STFT framing on the waveform: (window width, hop) in samples.
fft_width = 512
fft_step = fft_width // 2
wav_cut_spec = (fft_width, fft_step)

# Mel bands per frame, and sliding windows over mel frames: (width, hop) in frames.
n_mels = 64
mel_width = 40
mel_step = 20
mel_cut_spec = (mel_width, mel_step)

In [None]:
resample = Resample(sr_, sr)

# The mel filterbank must be built for the rate of the waveform it is
# applied to — the *resampled* rate `sr`, not the original `sr_`.
cut_it = MelCuts(sr=sr, 
                 wav_cut_spec=wav_cut_spec, 
                 mel_cut_spec=mel_cut_spec, 
                 n_mels=n_mels)

wav = resample(wav_)
wav_cuts = cut_up(wav, wav_cut_spec)

# `cut_it` returns both the mel spectrogram and its sliding-window cuts.
mel, cuts = cut_it(wav)


# ----------------
print(mel.size(), cuts.size(), wav_cuts.size())

In [None]:
w , d  = wav_cut_spec
w_, d_ = mel_cut_spec

# Mel cut `i` spans mel frames fr0..fr1 (inclusive): windows from `cut_up`
# hold exactly w_ frames, so the last frame is fr0 + w_ - 1.
i   = 5
fr0 = i*d_
fr1 = i*d_ + w_ - 1
# Half-open sample range [s0, s1) of frames fr0..fr1, assuming frame f
# covers samples f*d .. f*d + w - 1.
# NOTE(review): MelSpectrogram centers frames by default (center=True),
# so mel frame f is actually centered on sample f*d — TODO confirm.
s0  = fr0*d
s1  = fr1*d + w

# One shared colour scale for every mel plot below.
mel_min = np.amin(mel.numpy())
mel_max = np.amax(mel.numpy())


# --------------------
plt.figure(figsize=(14,2))
plt.ylim(-1,1)
plt.plot(wav[0])
plt.gca().axvspan(s0, s1, alpha=0.1, color='k')
plt.xticks([s0,s1])


# Assumes `wav_cuts` is (channels, frames, w): show channel 0's frames
# around the cut, shading the ones inside it.
lo = max(fr0 - 5, 0)
hi = min(fr1 + 11, wav_cuts.shape[1])
wcuts = wav_cuts[0, lo:hi]
fig, axs = plt.subplots(1, len(wcuts), figsize=(24,2), sharex=False, sharey=True, squeeze=False)
plt.ylim(-1,1)
for frame, ax, wc in zip(range(lo, hi), axs[0], wcuts):
    ax.set_title(f"{frame}")
    ax.plot(wc)
    ax.set_xticks([])
    if fr0 <= frame <= fr1:
        ax.set_facecolor((0,0,0,0.1))


plt.figure(figsize=(15,4))
plt.imshow(mel[0])
# Highlight the frames belonging to cut `i` (pixel centres sit on integers).
plt.gca().axvspan(fr0 - 0.5, fr1 + 0.5, alpha=0.2, color='w')
plt.xticks([fr0,fr1])


# Assumes `cuts` is (channels, n_mels, n_cuts, w_): reorder channel 0 to
# (n_cuts, n_mels, w_) so each iteration yields one cut as an image.
mel_cuts = cuts[0].transpose(0, 1)
fig, axs = plt.subplots(1, len(mel_cuts), figsize=(14,3), sharex=False, sharey=True, squeeze=False)
for t, (ax, mc) in enumerate(zip(axs[0], mel_cuts)):
    ax.axis("off")
    ax.set_title(f"{t}")
    ax.imshow(mc, cmap="plasma", vmin=mel_min, vmax=mel_max)