In [None]:
from collections import defaultdict
import random

from librosa.display import specshow
import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram, Resample
import gradio as gr
from transformers import ASTFeatureExtractor
import numpy as np
from IPython.display import Audio

# Load gradio's IPython extension; the UI cell at the bottom of the notebook
# relies on the `%%blocks` cell magic.
%load_ext gradio

# Sentinel entry for the class dropdown: when this value is selected, no class
# filter is applied and the file name (or a random clip) is used instead.
no_class_string = "Use filename instead"
# Feature extractor matching the pretrained AST checkpoint; its `sampling_rate`
# is the rate audio is resampled to on the "AST" path.
feature_extractor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
# Frame shift used to derive the hop length below; the division by 1000
# implies milliseconds.
astextractor_frame_shift = 10
# Samples between consecutive frames. This is a float because of the true
# division; it is passed to the plotting helper as `hop_length`.
astextractor_hop_length = feature_extractor.sampling_rate*(astextractor_frame_shift/1000)

In [None]:
# Load the dataset and build three lookup tables used by the fetch helpers:
#   filename_to_index : slice_file_name -> row index in dataset["train"]
#   class_to_class_id : class name      -> integer class id
#   class_id_files    : class id        -> list of slice_file_names in that class
dataset = load_dataset("danavery/urbansound8k")
filename_to_index = defaultdict(int)
class_to_class_id = defaultdict(int)
class_id_files = defaultdict(list)

# Read only the three metadata columns. Iterating over the rows themselves
# (`for item in dataset["train"]`) materialises the "audio" field of every row,
# i.e. decodes every clip, just to read three small values.
train_split = dataset["train"]
metadata_rows = zip(
    train_split["slice_file_name"],
    train_split["class"],
    train_split["classID"],
)
for index, (filename, class_name, class_id) in enumerate(metadata_rows):
    class_id = int(class_id)

    filename_to_index[filename] = index
    class_to_class_id[class_name] = class_id
    class_id_files[class_id].append(filename)

In [None]:
def fetch_random():
    """Return one row of the training split, picked uniformly at random."""
    train_split = dataset["train"]
    return train_split[random.randrange(len(train_split))]

In [None]:
def get_random_index_by_class(audio_class):
    """Return the dataset row index of a random clip belonging to `audio_class`.

    Args:
        audio_class: Class name, as stored in the `class_to_class_id` lookup.

    Returns:
        The row index recorded in `filename_to_index` for the chosen file
        (`None` if the file is missing from that lookup).

    Raises:
        ValueError: If `audio_class` is not a known class, or no files are
            indexed for it.
    """
    # Test membership before indexing: when the lookup is a defaultdict,
    # indexing an unknown key would insert it with id 0 and silently sample
    # from class 0 instead of failing.
    if audio_class not in class_to_class_id:
        raise ValueError(f"Unknown audio class: {audio_class!r}")

    class_id = class_to_class_id[audio_class]
    filenames = class_id_files.get(class_id, [])
    if not filenames:
        raise ValueError(f"No files indexed for audio class: {audio_class!r}")

    selected_filename = random.choice(filenames)
    return filename_to_index.get(selected_filename)

In [None]:
def fetch_example(file_name=None, audio_class=no_class_string):
    """Fetch one clip from the training split and return it as a tensor.

    Selection rules, in priority order:
      1. `audio_class` is a real class (not the sentinel): a random clip of
         that class is returned and `file_name` is ignored.
      2. `file_name` is given: that exact clip is returned.
      3. Otherwise: a uniformly random clip is returned.

    Args:
        file_name: `slice_file_name` of the wanted clip, or a falsy value.
        audio_class: Class name, or `no_class_string` for "no class filter".

    Returns:
        Tuple `(waveform, sr, slice_file_name, audio_class)`, where `waveform`
        is a float tensor with a leading dimension added in front of the raw
        sample array and `audio_class` is the class of the clip actually loaded.

    Raises:
        ValueError: If `file_name` is given but is not in the dataset index.
    """
    train_split = dataset["train"]

    if audio_class != no_class_string:
        example = train_split[get_random_index_by_class(audio_class)]
    elif file_name:
        # Check membership first: the index may be a defaultdict(int), in which
        # case indexing an unknown name would silently return row 0 (and add
        # the bad name to the index) instead of reporting the mistake.
        if file_name not in filename_to_index:
            raise ValueError(f"Unknown slice_file_name: {file_name!r}")
        example = train_split[filename_to_index[file_name]]
    else:
        example = fetch_random()

    waveform = torch.tensor(example["audio"]["array"]).float()
    waveform = torch.unsqueeze(waveform, 0)
    sr = example["audio"]["sampling_rate"]
    slice_file_name = example["slice_file_name"]
    audio_class = example["class"]
    return waveform, sr, slice_file_name, audio_class

In [None]:
def make_mel_spectrogram(
    audio: torch.Tensor, sample_rate, hop_length=256, n_fft=512, n_mels=64
) -> torch.Tensor:
    """Compute a mel spectrogram of `audio` on a decibel scale.

    Args:
        audio: Waveform tensor; the leading dimension of the transform output
            is squeezed away, so a single-channel input is expected.
        sample_rate: Sampling rate of `audio` in Hz.
        hop_length: Samples between successive frames.
        n_fft: FFT size.
        n_mels: Number of mel bands.

    Returns:
        The spectrogram after `AmplitudeToDB` conversion.
    """
    spec_transformer = MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    # squeeze(0) drops the leading dimension only when it has size 1.
    mel_spec = spec_transformer(audio).squeeze(0)

    amplitude_to_db_transformer = AmplitudeToDB()
    mel_spec_db = amplitude_to_db_transformer(mel_spec)
    return mel_spec_db


def resample(audio, file_sr, input_sr):
    """Mix `audio` down to one channel and resample it to `input_sr`.

    Args:
        audio: Waveform tensor whose first dimension is the channel axis.
        file_sr: Sampling rate of `audio` in Hz.
        input_sr: Target sampling rate in Hz.

    Returns:
        Tuple `(audio, num_samples, total_duration)`: the resampled waveform,
        its length along the last axis, and that length divided by `input_sr`.
    """
    channel_count = audio.shape[0]
    # Multi-channel input is averaged into a single channel (kept as dim 0).
    mono = audio if channel_count <= 1 else torch.mean(audio, dim=0, keepdim=True)

    resampled = Resample(file_sr, input_sr)(mono)

    sample_count = resampled.shape[-1]
    return resampled, sample_count, sample_count / input_sr

In [None]:
def normalize_spectrogram(spec):
    """Min-max scale `spec` into the range [0, 1].

    Args:
        spec: Spectrogram tensor (must be non-empty).

    Returns:
        A tensor of the same shape where the smallest input value maps to 0 and
        the largest to 1. A constant input (max == min) maps to all zeros
        rather than dividing by zero and producing NaNs.
    """
    lowest = torch.min(spec)
    value_range = torch.max(spec) - lowest
    if value_range == 0:
        # Constant spectrogram (e.g. a silent clip): there is no contrast to
        # scale, and 0/0 would fill the result with NaN.
        return torch.zeros_like(spec)
    return (spec - lowest) / value_range


def generate_spectrogram(audio, input_sr):
    """Build the mel spectrogram of `audio` using the default mel settings.

    Thin wrapper around `make_mel_spectrogram`; `input_sr` is the sampling
    rate of `audio`.
    """
    return make_mel_spectrogram(audio, input_sr)


def preprocess(waveform, file_sr, input_sr):
    """Resample a waveform and turn it into a normalized mel spectrogram.

    Args:
        waveform: Raw waveform as loaded from the dataset.
        file_sr: Sampling rate of `waveform`.
        input_sr: Sampling rate to resample to before computing the spectrogram.

    Returns:
        Tuple `(audio, spec)`: the resampled waveform and its normalized
        spectrogram.
    """
    resampled = resample(waveform, file_sr, input_sr)[0]
    normalized = normalize_spectrogram(generate_spectrogram(resampled, input_sr))
    return resampled, normalized

In [None]:
def preprocess_with_ast_feature_extractor(waveform, file_sr, output_sr):
    """Resample a waveform and compute its features with the AST extractor.

    Args:
        waveform: Raw waveform as loaded from the dataset.
        file_sr: Sampling rate of `waveform`.
        output_sr: Sampling rate to resample to; it is also passed to the
            feature extractor as `sampling_rate`.

    Returns:
        Tuple `(raw_audio, spec)`: the resampled waveform and the extractor
        output with its first two feature axes swapped and the time axis
        trimmed to the length of the actual audio.
    """
    raw_audio, _, _ = resample(waveform, file_sr, output_sr)

    inputs = feature_extractor(
        raw_audio.numpy(),
        sampling_rate=output_sr,
        padding="max_length",
        return_tensors="pt",
    )
    # Drop the leading (batch) dimension, then swap the two remaining axes so
    # the time axis comes last.
    # NOTE(review): assumes the extractor returns (1, frames, mel_bins) - confirm.
    spec = torch.transpose(torch.squeeze(inputs["input_values"], 0), 0, 1)

    # The extractor pads to a fixed length; keep only the frames covered by the
    # real audio. Samples per frame is derived from the sampling rate and the
    # shared frame-shift setting instead of hard-coding 160 (which is only
    # correct for 16 kHz with a 10 ms shift).
    samples_per_frame = output_sr * astextractor_frame_shift / 1000
    actual_frames = int(np.ceil(raw_audio.shape[-1] / samples_per_frame))
    spec = spec[:, :actual_frames]

    return raw_audio, spec

In [None]:
def load_file(file_name, audio_class):
    """Load a clip chosen by file name, by class, or at random.

    Args:
        file_name: `slice_file_name` of the wanted clip, or a falsy value.
        audio_class: Class name, or `no_class_string` for "no class filter".

    Returns:
        Tuple `(file_name, audio_class, waveform, file_sr)` describing the clip
        that was actually loaded.
    """
    has_selection = bool(file_name) or audio_class != no_class_string
    if has_selection:
        fetched = fetch_example(file_name=file_name, audio_class=audio_class)
    else:
        # Nothing was specified: rely on fetch_example's defaults.
        fetched = fetch_example()

    waveform, file_sr, file_name, audio_class = fetched
    return file_name, audio_class, waveform, file_sr

In [None]:
def plot_spectrogram(input_sr, spec, hop_length):
    """Draw `spec` on a new 5x2 figure with a time x-axis and a mel y-axis.

    Args:
        input_sr: Sampling rate of the audio the spectrogram was computed from.
        spec: Spectrogram tensor; it is converted with `.numpy()` for plotting.
        hop_length: Samples between frames, used to scale the time axis.

    Returns:
        The matplotlib figure holding the plot.
    """
    figure, axes = plt.subplots(figsize=(5, 2))
    display_options = {
        "sr": input_sr,
        "hop_length": hop_length,
        "x_axis": "time",
        "y_axis": "mel",
        "ax": axes,
    }
    specshow(spec.numpy(), **display_options)
    return figure

In [None]:
def process(file_name="", audio_class=no_class_string, model="AST"):
    """Load a clip, preprocess it for the chosen model and plot its spectrogram.

    Args:
        file_name: `slice_file_name` of the wanted clip, or "" for none.
        audio_class: Class name, or `no_class_string` for "no class filter".
        model: "AST" uses the AST feature extractor; any other value uses the
            local mel-spectrogram pipeline at 22050 Hz.

    Returns:
        Tuple `(fig, audio, file_name, audio_class, input_sr)`: the spectrogram
        figure, the first channel of the resampled audio as a NumPy array, the
        name and class of the clip that was loaded, and the sampling rate of
        `audio`.
    """
    file_name, audio_class, waveform, file_sr = load_file(file_name, audio_class)

    if model != "AST":
        # Local pipeline: fixed rate, and the hop length that
        # make_mel_spectrogram uses by default.
        input_sr, hop_length = 22050, 256
        audio, spec = preprocess(waveform, file_sr, input_sr)
    else:
        input_sr = feature_extractor.sampling_rate
        hop_length = astextractor_hop_length
        audio, spec = preprocess_with_ast_feature_extractor(waveform, file_sr, input_sr)

    figure = plot_spectrogram(input_sr, spec, hop_length)
    first_channel = audio[0].numpy()
    return figure, first_channel, file_name, audio_class, input_sr


In [None]:
# Smoke test of the AST path on one named clip: draw its spectrogram and
# render an inline audio player at the rate the audio was resampled to.
fig, audio, file_name, audio_class, input_sr = process(file_name="138031-2-0-45.wav", model="AST")
plt.show()
# Last expression of the cell, so the player is displayed as the cell output.
Audio(audio, rate=input_sr)

In [None]:
def generate_gradio_elements(file_name, class_picker, model):
    """Gradio callback: load and process a clip, then return updated components.

    Args:
        file_name: Value of the file-name textbox.
        class_picker: Value of the class dropdown.
        model: Value of the model dropdown ("AST" or "local").

    Returns:
        Tuple `(plot, audio, file_name, class_picker)` of gradio components
        holding the spectrogram figure, the audio clip, and the name and class
        of the clip that was actually loaded.
    """
    fig, audio, file_name, audio_class, input_sr = process(file_name, class_picker, model)
    return (
        gr.Plot(value=fig),
        # Use the rate the audio was actually resampled to. It differs per
        # model, so a fixed 22050 plays AST audio at the wrong speed and pitch.
        gr.Audio(value=(input_sr, audio)),
        gr.Textbox(value=file_name),
        gr.Dropdown(value=audio_class),
    )

In [None]:
# NOTE(review): `process` returns a 5-tuple (fig, audio, file_name,
# audio_class, input_sr), so `spec` holds that whole tuple rather than a
# spectrogram. The name is rebound to a gr.Plot in the UI cell below.
spec = process("137969-2-0-37.wav")

In [None]:
%%blocks
# Interactive demo: pick a model, then a clip (by file name or by class), and
# show its spectrogram together with an audio player.

# Dropdown choices are the dataset's class names, with the "no class" sentinel
# appended last so that classes[-1] can be used as the default selection.
classes = list(class_to_class_id.keys())
classes.append(no_class_string)

with gr.Blocks() as demo:
    # Row 1: the three inputs.
    with gr.Row():
        model = gr.Dropdown(choices=["AST", "local"], value="local", label="Choose a model")
        file_name = gr.Textbox(label="slice_file_name in dataset")
        class_picker = gr.Dropdown(
            choices=classes, label="Choose a category", value=classes[-1]
        )
    # Row 2: the outputs.
    with gr.Row():
        spec = gr.Plot(container=True)
        my_audio = gr.Audio()
    gen_button = gr.Button("Get Spec")
    # file_name and class_picker are listed as outputs too, so after each run
    # they are overwritten with the name and class of the clip that was loaded.
    gen_button.click(
        fn=generate_gradio_elements,
        inputs=[file_name, class_picker, model],
        outputs=[spec, my_audio, file_name, class_picker],
    )
    # NOTE(review): each example row supplies a single value (the file name)
    # while three input components are listed - confirm gradio handles the
    # missing class/model values as intended.
    gr.Examples(
        examples=[["100263-2-0-117.wav"], ["100852-0-0-0.wav"]],
        inputs=[file_name, class_picker, model],
        outputs=[spec, my_audio, file_name, class_picker],
        run_on_click=True,
        fn=generate_gradio_elements,
    )