# Transfer learning using YAMNet

Heavily based on [this](https://www.tensorflow.org/tutorials/audio/transfer_learning_audio) tutorial.

> YAMNet requires audio to be downsampled to 16 kHz. Does this affect the quality of our sounds?

In [None]:
# Silence TensorFlow's C++ logging. This has to run before TensorFlow is
# imported (in a later cell) for the setting to be picked up.
import os

os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")

In [None]:
from __future__ import annotations

In [None]:
import json
from pathlib import Path

import pandas as pd


class DataLoader:
    """Index a directory of ``.wav`` clips together with their class labels.

    Layout read by this class:

    - ``<data_dir>/labels.json``: JSON object mapping a class name to its label.
    - ``<data_dir>/**/*.wav``: audio clips. The part of the file name before the
      first ``_`` is looked up in ``labels.json``; the name of the clip's parent
      directory is used as its fold (e.g. ``train`` / ``val``).
    """

    def __init__(self, data_dir: Path):
        self.data_dir = data_dir
        self.labels_path = self.data_dir / "labels.json"
        self.loaded = False

    def load(self):
        """Read the class-name -> label mapping from ``labels.json``."""
        self.labels = json.loads(self.labels_path.read_bytes())
        self.loaded = True

    def _ensure_loaded(self) -> None:
        """Load the labels on first use instead of failing on a missing attribute."""
        if not self.loaded:
            self.load()

    def get_classes(self) -> list[str]:
        """Return the class names ordered by their label value.

        With labels ``0..n-1`` (what the sparse-categorical training expects),
        position ``i`` of the result is the class whose label is ``i``. An argmax
        over the model's outputs can therefore index this list directly, whatever
        order the keys appear in inside ``labels.json``.
        """
        self._ensure_loaded()
        return sorted(self.labels, key=self.labels.__getitem__)

    def get_metadata(self):
        """Build a DataFrame with one row per ``.wav`` file under ``data_dir``.

        Returns:
            DataFrame with columns ``filename`` (path as ``str``), ``target``
            (label looked up from the file-name prefix) and ``fold`` (parent
            directory name). Rows are sorted by path so the order is reproducible
            across file systems.

        Raises:
            KeyError: if a file-name prefix is not a key of ``labels.json``.
        """
        self._ensure_loaded()
        audio_paths = sorted(self.data_dir.glob("**/*.wav"))
        df = pd.DataFrame({"filename": audio_paths})
        df["target"] = df.filename.map(lambda f: self.labels[f.name.split("_")[0]])
        df["fold"] = df.filename.map(lambda f: f.parent.name)
        df["filename"] = df.filename.map(str)
        return df

In [None]:
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_io as tfio


class YamnetModel:
    """Thin wrapper around a pretrained YAMNet model loaded from TensorFlow Hub."""

    def __init__(self, model_handle: str = "https://tfhub.dev/google/yamnet/1") -> None:
        """Store the hub handle; nothing is downloaded until ``load()``.

        Args:
            model_handle: TF Hub URL or local path of the YAMNet model. The
                default is the public YAMNet v1 handle.
        """
        self.model_handle = model_handle

    def load(self) -> None:
        """Download the model and read its class names.

        ``class_map_path()`` points to a CSV bundled with the model; its
        ``display_name`` column supplies one human-readable name per class.
        """
        # Download the model from Tensorflow Hub
        self.model = hub.load(self.model_handle)
        # Load classes mapping
        class_map_path = self.model.class_map_path().numpy().decode("utf-8")
        self.classes = pd.read_csv(class_map_path)["display_name"].tolist()

    def predict(self, audio_data: tf.Tensor) -> str:
        """Return the name of the YAMNet class with the highest average score.

        For now, audio must have been already preprocessed. Presumably that means
        a 1-D float32 waveform at 16 kHz mono, as produced by
        ``load_wav_16k_mono`` (TODO confirm).
        """
        scores, _, _ = self.model(audio_data)
        # Average the scores over axis 0 (assumed to be frames) to get one
        # score per class for the whole clip.
        class_scores = tf.reduce_mean(scores, axis=0)
        # Convert the scalar tensor explicitly rather than relying on
        # EagerTensor.__index__ for list indexing.
        top_class = int(tf.math.argmax(class_scores))
        return self.classes[top_class]

In [None]:
# Dataset root: <parent of the current working directory>/data/data_small.
loader = DataLoader(Path.cwd().parent / "data" / "data_small")
loader.load()

In [None]:
# Per-clip metadata, split into the three columns fed to tf.data below.
metadata = loader.get_metadata()
filenames, targets, folds = metadata["filename"], metadata["target"], metadata["fold"]

In [None]:
# Pretrained YAMNet. It is only run for inference here: extract_embedding calls
# model.model to turn waveforms into embeddings.
model = YamnetModel()
# Downloads the model from TF Hub and reads its class-name CSV.
model.load()

In [None]:
@tf.function
def load_wav_16k_mono(filename: str):
    """Read a WAV file and return it as a 16 kHz, single-channel waveform.

    The file is decoded as mono (``desired_channels=1``). The trailing channel
    axis is then squeezed away, and the signal is resampled from its native rate
    to 16 kHz, the rate YAMNet is fed in this notebook.
    """
    raw_bytes = tf.io.read_file(filename)
    waveform, native_rate = tf.audio.decode_wav(raw_bytes, desired_channels=1)
    mono = tf.squeeze(waveform, axis=-1)
    return tfio.audio.resample(
        mono,
        rate_in=tf.cast(native_rate, dtype=tf.int64),
        rate_out=16000,
    )


def load_wav_for_map(filename, label, fold):
    """``Dataset.map`` adapter: replace the file path with its 16 kHz waveform."""
    waveform = load_wav_16k_mono(filename)
    return waveform, label, fold


def extract_embedding(wav_data, label, fold):
    """Run YAMNet on one waveform and return per-frame embeddings with tags.

    YAMNet returns a batch of embedding rows for the clip. The clip's label and
    fold are repeated once per row so that a subsequent ``unbatch()`` yields one
    ``(embedding, label, fold)`` element per row.
    """
    # Only the embeddings are used; the scores and spectrogram outputs are discarded.
    _, embeddings, _ = model.model(wav_data)
    num_embeddings = tf.shape(embeddings)[0]
    return (
        embeddings,
        tf.repeat(label, num_embeddings),
        tf.repeat(fold, num_embeddings),
    )

In [None]:
# One element per clip: (path, label, fold).
main_ds = tf.data.Dataset.from_tensor_slices((filenames, targets, folds))
# Decode + resample each clip, then expand it into one element per YAMNet embedding.
main_ds = (
    main_ds
    .map(load_wav_for_map)
    .map(extract_embedding)
    .unbatch()
)

In [None]:
# Cache the expensive YAMNet embeddings once, then split them by fold.
# Elements whose fold is neither "train" nor "val" are dropped here.
cached_ds = main_ds.cache()
train_ds = cached_ds.filter(lambda embedding, label, fold: fold == "train")
val_ds = cached_ds.filter(lambda embedding, label, fold: fold == "val")


def remove_fold_column(embedding, label, fold):
    """Drop the fold tag so Keras receives (features, target) pairs."""
    return embedding, label


train_ds = train_ds.map(remove_fold_column)
val_ds = val_ds.map(remove_fold_column)

train_ds = train_ds.cache().shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.cache().batch(32).prefetch(tf.data.AUTOTUNE)

In [None]:
# Class names from labels.json. The classifier head below has one output per
# entry, and the inference cell indexes this list with the argmax of its output.
my_classes = loader.get_classes()

In [None]:
# Small classification head trained on top of YAMNet embeddings.
my_model = tf.keras.Sequential(
    [
        # One 1024-dim YAMNet embedding per example. `shape` must be a tuple:
        # `(1024)` is just the int 1024, which Keras 3 rejects.
        tf.keras.layers.Input(shape=(1024,), dtype=tf.float32, name="input_embedding"),
        tf.keras.layers.Dense(512, activation="relu"),
        # Raw logits, one per class (the loss is configured with from_logits=True).
        tf.keras.layers.Dense(len(my_classes)),
    ],
    name="my_model",
)

my_model.summary()

In [None]:
my_model.compile(
    # Integer class targets against the un-normalised logits of the last layer.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["accuracy"],
)

# Stop when the validation loss stops improving and roll back to the best epoch.
# Watching the training loss would not protect against overfitting. Note: this
# needs a non-empty "val" fold, otherwise `val_loss` is never reported.
callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True
)

In [None]:
history = my_model.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    # Keras documents `callbacks` as a list of Callback instances.
    callbacks=[callback],
)

In [None]:
# Path of the 11th clip (positional), used as the inference example below.
filenames.iloc[10]

In [None]:
# End-to-end check on one clip: resample -> YAMNet embeddings -> classifier head.
testing_wav_data = load_wav_16k_mono(filenames.iloc[10])
scores, embeddings, spectrogram = model.model(testing_wav_data)
result = my_model(embeddings).numpy()

# Average the per-embedding logits over the clip and report the strongest class.
inferred_class = my_classes[result.mean(axis=0).argmax()]
print(f"The main sound is: {inferred_class}")

---

In [None]:
import librosa
import numpy as np
import matplotlib.pyplot as plt


def plot_mel_spectrogram(S: np.ndarray, sampling_rate: float) -> None:
    """Plot a mel power spectrogram in decibels.

    Args:
        S: Mel power spectrogram, e.g. the output of
            ``librosa.feature.melspectrogram``.
        sampling_rate: Sampling rate of the audio ``S`` was computed from; used
            to label the time and mel-frequency axes.
    """
    # `librosa.display` is a submodule that older librosa releases do not expose
    # through a bare `import librosa`, so import it explicitly.
    import librosa.display

    fig, ax = plt.subplots()

    # Actual rendering: power -> dB relative to the loudest bin.
    S_db = librosa.power_to_db(S, ref=np.max)
    img = librosa.display.specshow(
        S_db, x_axis="time", y_axis="mel", sr=sampling_rate, ax=ax
    )

    # Image formatting
    fig.colorbar(img, ax=ax, format="%+2.f dB")
    ax.set(title="Mel-frequency spectrogram")

    # plt.show() works under both GUI and notebook-inline backends, whereas
    # fig.show() warns on non-GUI backends such as the inline one.
    plt.show()

In [None]:
# `y` and `sampling_rate` were never defined. Load the same sample clip used above
# at its native rate (sr=None disables librosa's default resampling), so the plot
# shows the original audio rather than the 16 kHz version fed to YAMNet.
y, sampling_rate = librosa.load(filenames.iloc[10], sr=None)
S = librosa.feature.melspectrogram(y=y, sr=sampling_rate)
plot_mel_spectrogram(S, sampling_rate)