In [None]:
from collections import defaultdict
import random

from librosa.display import specshow
import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram, Resample
import gradio as gr

# Load Gradio's IPython extension, which supplies the %%blocks cell magic used by the UI cell.
%load_ext gradio

# Sentinel dropdown choice: when the class picker holds this value, examples are
# looked up by file name (or chosen at random) instead of by class.
no_class_string = "Use filename instead"

In [None]:
# Load UrbanSound8K from the Hugging Face Hub and build the lookup tables the
# rest of the notebook relies on.
dataset = load_dataset("danavery/urbansound8k")

# slice_file_name -> row index in dataset["train"]
filename_to_index = defaultdict(int)
# class name -> integer class id
class_to_class_id = defaultdict(int)
# class id -> list of slice_file_names belonging to that class
class_id_files = defaultdict(list)

# Only metadata is needed here. Iterating the full split would decode every
# audio clip, so iterate a copy without the "audio" column instead; row order
# (and therefore the index) is unchanged.
metadata_only = dataset["train"].remove_columns("audio")
for index, item in enumerate(metadata_only):
    filename = item["slice_file_name"]
    class_name = item["class"]
    class_id = int(item["classID"])

    filename_to_index[filename] = index
    class_to_class_id[class_name] = class_id
    class_id_files[class_id].append(filename)

In [None]:
def fetch_random():
    """Return one example picked uniformly at random from the training split."""
    train_split = dataset["train"]
    return train_split[random.randrange(len(train_split))]

In [None]:
def get_random_index_by_class(audio_class):
    """Pick a random file of the given class and return its dataset row index.

    The class name is translated to its id via ``class_to_class_id``, a file
    name is drawn at random from ``class_id_files`` for that id, and the file
    name is mapped back to a row index through ``filename_to_index``.
    """
    candidates = class_id_files.get(class_to_class_id[audio_class], [])
    return filename_to_index.get(random.choice(candidates))

In [None]:
def fetch_example(file_name=None, audio_class=no_class_string):
    """Fetch one training example and return its waveform and metadata.

    Selection rules, in order of precedence:

    1. A real class in ``audio_class`` -> a random example of that class.
    2. Otherwise, a non-empty ``file_name`` -> that exact file.
    3. Otherwise -> a random example from the whole split.

    ``audio_class`` counts as "no class chosen" when it is the sentinel
    ``no_class_string``, ``None`` or an empty value.

    Args:
        file_name: ``slice_file_name`` of the wanted example, or None/"".
        audio_class: class name, or ``no_class_string``/None for no class.

    Returns:
        Tuple ``(waveform, sr, slice_file_name, audio_class)`` where
        ``waveform`` is a float tensor with a leading dimension added, ``sr``
        is the example's sampling rate, and the last two items are the file
        name and class of the example actually fetched.

    Raises:
        ValueError: if ``file_name`` is given but is not in the dataset.
    """
    train_split = dataset["train"]

    # Treat None / "" like the sentinel so they are never looked up as a class name.
    class_selected = bool(audio_class) and audio_class != no_class_string

    if class_selected:
        example = train_split[get_random_index_by_class(audio_class)]
    elif file_name:
        # filename_to_index is a defaultdict(int): indexing an unknown name
        # would silently yield row 0, so check membership first.
        if file_name not in filename_to_index:
            raise ValueError(f"Unknown slice_file_name: {file_name!r}")
        example = train_split[filename_to_index[file_name]]
    else:
        example = fetch_random()

    waveform = torch.tensor(example["audio"]["array"]).float()
    waveform = torch.unsqueeze(waveform, 0)
    sr = example["audio"]["sampling_rate"]
    slice_file_name = example["slice_file_name"]
    audio_class = example["class"]
    return waveform, sr, slice_file_name, audio_class

In [None]:
def make_mel_spectrogram(
    audio: torch.Tensor, sample_rate, hop_length=256, n_fft=512, n_mels=64
) -> torch.Tensor:
    """Compute a decibel-scaled mel spectrogram of ``audio``.

    Args:
        audio: waveform tensor passed straight to torchaudio's MelSpectrogram.
        sample_rate: sampling rate of ``audio`` in Hz.
        hop_length: hop between successive STFT frames, in samples.
        n_fft: FFT size.
        n_mels: number of mel filter banks.

    Returns:
        The mel spectrogram converted with ``AmplitudeToDB``; a leading
        dimension of size 1 is squeezed away.
    """
    spec_transformer = MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    mel_spec = spec_transformer(audio).squeeze(0)

    amplitude_to_db_transformer = AmplitudeToDB()
    return amplitude_to_db_transformer(mel_spec)

In [None]:
def resample(audio, file_sr):
    """Down-mix ``audio`` to one channel and resample it to 22050 Hz.

    Args:
        audio: waveform tensor whose first dimension is the channel axis.
        file_sr: sampling rate of ``audio`` in Hz.

    Returns:
        Tuple ``(audio, num_samples, total_duration)``: the resampled
        waveform, its length in samples, and its length in seconds.
    """
    target_rate = 22050

    # More than one channel: average them, keeping the channel dimension.
    is_multichannel = audio.shape[0] > 1
    if is_multichannel:
        audio = audio.mean(dim=0, keepdim=True)

    audio = Resample(file_sr, target_rate)(audio)

    sample_count = audio.shape[-1]
    return audio, sample_count, sample_count / target_rate


def normalize_spectrogram(spec):
    """Min-max scale ``spec`` into the range [0, 1].

    A constant input (for example the spectrogram of silent audio) has zero
    range; it is mapped to all zeros instead of producing NaN from 0/0.

    Args:
        spec: spectrogram tensor.

    Returns:
        Tensor of the same shape as ``spec`` with values in [0, 1].
    """
    lowest = torch.min(spec)
    value_range = torch.max(spec) - lowest
    # Guard against division by zero: with a zero range the numerator is all
    # zeros anyway, so dividing by 1 yields the all-zero result.
    safe_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )
    return (spec - lowest) / safe_range


def generate_spectrogram(audio):
    """Build the mel spectrogram of ``audio`` via ``make_mel_spectrogram``.

    The audio is treated as sampled at 22050 Hz; every other setting uses
    ``make_mel_spectrogram``'s defaults.
    """
    return make_mel_spectrogram(audio, sample_rate=22050)


def preprocess(waveform, sr):
    """Resample ``waveform`` and compute its normalized mel spectrogram.

    Args:
        waveform: raw waveform tensor.
        sr: sampling rate of ``waveform`` in Hz.

    Returns:
        Tuple ``(audio, norm_spec)``: the resampled waveform and its
        spectrogram after ``normalize_spectrogram``.
    """
    resampled, _num_samples, _duration = resample(waveform, sr)
    normalized = normalize_spectrogram(generate_spectrogram(resampled))
    return resampled, normalized

In [None]:
def generate_spec_figure(file_name="", audio_class=None):
    """Fetch an example and draw its normalized mel spectrogram.

    Args:
        file_name: ``slice_file_name`` to fetch; empty means "not specified".
        audio_class: class to sample from. ``None`` (the default) and
            ``no_class_string`` both mean "no class chosen".

    Returns:
        Tuple ``(fig, audio, file_name, audio_class)``: the matplotlib figure,
        the resampled audio of the first channel as a NumPy array, and the
        file name and class of the example that was actually fetched.
    """
    # Map the None default onto the sentinel so it is never looked up as a
    # class name; without this a file name passed on its own is ignored.
    if audio_class is None:
        audio_class = no_class_string

    # fetch_example already chooses between class, file name and random
    # selection, so a single call covers every case.
    waveform, sr, file_name, audio_class = fetch_example(
        file_name=file_name, audio_class=audio_class
    )
    audio, norm_spec = preprocess(waveform, sr)

    # NOTE: the figure is created through pyplot and is not closed here.
    fig, ax = plt.subplots(figsize=(5, 2))
    # sr and hop_length must match the values used to compute the spectrogram.
    specshow(
        norm_spec.numpy(), sr=22050, hop_length=256, x_axis="time", y_axis="mel", ax=ax
    )
    return fig, audio[0].numpy(), file_name, audio_class



In [None]:
def generate_gradio_elements(file_name, class_picker):
    """Gradio callback: build the output components for one fetched example.

    Args:
        file_name: current value of the file-name textbox.
        class_picker: current value of the class dropdown.

    Returns:
        Tuple of ``gr.Plot``, ``gr.Audio``, ``gr.Textbox`` and ``gr.Dropdown``
        holding the spectrogram figure, the audio at 22050 Hz, and the file
        name and class of the fetched example.
    """
    figure, samples, chosen_file, chosen_class = generate_spec_figure(
        file_name, class_picker
    )
    return (
        gr.Plot(value=figure),
        gr.Audio(value=(22050, samples)),
        gr.Textbox(value=chosen_file),
        gr.Dropdown(value=chosen_class),
    )

In [None]:
# Smoke-test the pipeline outside the UI. The call returns a
# (figure, audio samples, file name, class) tuple; note that the name `spec`
# is bound again to the gr.Plot component in the Blocks cell below.
spec = generate_spec_figure("137815-4-0-8.wav")

In [None]:
%%blocks
# Dropdown choices: every dataset class followed by the "no class" sentinel.
# class_to_class_id is a defaultdict, so a stray lookup can insert a non-class
# key such as None; keep only real (string) class names.
classes = [name for name in class_to_class_id.keys() if isinstance(name, str)]
classes.append(no_class_string)

with gr.Blocks() as demo:
    with gr.Row():
        file_name = gr.Textbox(label="slice_file_name in dataset")
        class_picker = gr.Dropdown(
            choices=classes, label="Choose a category", value=no_class_string
        )
    with gr.Row():
        spec = gr.Plot(container=True)
        my_audio = gr.Audio()
    gen_button = gr.Button("Get Spec")
    gen_button.click(
        fn=generate_gradio_elements,
        inputs=[file_name, class_picker],
        outputs=[spec, my_audio, file_name, class_picker],
    )
    gr.Examples(
        # One value per input component. The sentinel resets the class picker
        # so the example's file is loaded rather than a random clip of
        # whichever class is currently selected.
        examples=[
            ["100263-2-0-117.wav", no_class_string],
            ["100852-0-0-0.wav", no_class_string],
        ],
        inputs=[file_name, class_picker],
        outputs=[spec, my_audio, file_name, class_picker],
        run_on_click=True,
        fn=generate_gradio_elements,
    )