# Transfer learning using YAMNet

Heavily based on [this TensorFlow tutorial](https://www.tensorflow.org/tutorials/audio/transfer_learning_audio) on transfer learning for audio.

> YAMNet requires audio to be downsampled to 16 kHz. Does this affect the quality of our sounds?

In [None]:
# Silence TensorFlow's native debugging logs.
# Set here, ahead of the cell that imports TensorFlow, so the value is
# already in the environment when the library loads.
import os

os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "3"})

In [None]:
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_io as tfio

In [None]:
from biobuzz.yamnet.preprocessing import load_wav_16k_mono

In [None]:
from pathlib import Path
from biobuzz.metadata import MetadataLoader

# Point the metadata loader at the `data` directory located in the parent of
# the current working directory, then load it.
# NOTE(review): presumably `load()` reads the dataset's metadata from disk —
# confirm in biobuzz.metadata.
loader = MetadataLoader(Path.cwd().parent / "data")
loader.load()

In [None]:
from biobuzz.yamnet.model import YamnetModel

# Instantiate the project's YAMNet wrapper and load it.
# NOTE(review): presumably `load()` fetches the pretrained YAMNet model —
# confirm in biobuzz.yamnet.model.
yamnet_model = YamnetModel()
yamnet_model.load()

In [None]:
import tensorflow as tf
from biobuzz.yamnet.preprocessing import split_dataset

# Build the training and validation datasets. Whatever `get_metadata()`
# returns is unpacked into positional arguments, followed by the YAMNet
# wrapper.
# NOTE(review): presumably the wrapper is used to turn audio into embeddings
# inside `split_dataset` — confirm in biobuzz.yamnet.preprocessing.
train_ds, val_ds = split_dataset(*loader.get_metadata(), yamnet_model)

In [None]:
# Class labels as exposed by the metadata loader. Their count sizes the
# classifier's output layer, and predicted indices are mapped back through
# this collection at inference time.
my_classes = loader.get_classes()

In [None]:
# Small classification head: a 1024-wide input vector -> 512-unit ReLU hidden
# layer -> one raw logit per class (no softmax here; the loss is configured
# with from_logits=True).
my_model = tf.keras.Sequential(
    [
        # `shape` must be a tuple: `(1024)` is just the int 1024, which newer
        # Keras releases refuse to convert to a shape. `(1024,)` is accepted
        # by old and new versions alike.
        tf.keras.layers.Input(shape=(1024,), dtype=tf.float32, name="input_embedding"),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(len(my_classes)),
    ],
    name="my_model",
)

my_model.summary()

In [None]:
# The model's last Dense layer has no activation, so it emits raw logits;
# from_logits=True tells the loss to apply the softmax itself. The sparse
# variant of the loss is used.
# NOTE(review): that implies integer class indices as targets rather than
# one-hot vectors — confirm against what split_dataset produces.
my_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
    metrics=["accuracy"],
)

# Stop training once the monitored loss has failed to improve for 3 epochs
# in a row, and roll the model back to the best weights seen.
# NOTE(review): this monitors the *training* loss even though validation data
# is passed to fit(); "val_loss" may be the intended quantity — confirm.
callback = tf.keras.callbacks.EarlyStopping(
    monitor="loss", patience=3, restore_best_weights=True
)

In [None]:
# Train the classification head for at most 20 epochs, evaluating on the
# validation dataset along the way; the early-stopping callback may end
# training sooner.
history = my_model.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    # Keras documents `callbacks` as a list of Callback instances, so wrap
    # the single callback instead of relying on implicit flattening.
    callbacks=[callback],
)

In [None]:
# Show the path of the audio file used in the inference example that follows.
# NOTE(review): `filenames` is not defined in any visible cell — presumably it
# comes from the loader's metadata; confirm and define it before this cell.
filenames[10]

In [None]:
# End-to-end check on a single file: waveform -> YAMNet -> classification head.
# NOTE(review): `filenames` is not defined in any visible cell — confirm where
# it is supposed to come from.
testing_wav_data = load_wav_16k_mono(filenames[10])

# The YAMNet wrapper is bound to `yamnet_model` in this notebook; the bare
# name `model` is never defined and raised a NameError on a fresh run.
# NOTE(review): assumes the wrapper exposes the loaded network as `.model`,
# as the original `model.model` call implies — confirm.
scores, embeddings, spectrogram = yamnet_model.model(testing_wav_data)
result = my_model(embeddings).numpy()

# Average the classifier outputs over axis 0 (presumably one row per audio
# frame — confirm) and pick the highest-scoring class for the whole clip.
inferred_class = my_classes[result.mean(axis=0).argmax()]
print(f'The main sound is: {inferred_class}')

---

In [None]:
# Test file names as returned by the metadata loader.
test_filenames = loader.get_test_filenames()

---

In [None]:
import librosa
import numpy as np
import matplotlib.pyplot as plt


def plot_mel_spectrogram(S: np.ndarray, sampling_rate: float) -> None:
    """Plot a mel spectrogram on a decibel scale.

    Args:
        S: Power mel spectrogram, e.g. the output of
            ``librosa.feature.melspectrogram``.
        sampling_rate: Sampling rate of the underlying signal, forwarded to
            ``librosa.display.specshow`` as ``sr``.
    """
    # A bare `import librosa` does not expose the `display` submodule on
    # librosa releases without lazy submodule loading; import it explicitly
    # so the `specshow` call below works on any version.
    import librosa.display

    fig, ax = plt.subplots()

    # Actual rendering: convert power to dB relative to the spectrogram's peak
    S_db = librosa.power_to_db(S, ref=np.max)
    img = librosa.display.specshow(
        S_db, x_axis="time", y_axis="mel", sr=sampling_rate, ax=ax
    )

    # Image formatting
    fig.colorbar(img, ax=ax, format="%+2.f dB")
    ax.set(title="Mel-frequency spectrogram")

    fig.show()

In [None]:
# Compute a mel spectrogram of the signal and plot it.
# NOTE(review): `y` and `sampling_rate` are not defined in any visible cell —
# presumably the waveform and rate returned by an audio load; confirm and
# define them before this cell.
S = librosa.feature.melspectrogram(y=y, sr=sampling_rate)
plot_mel_spectrogram(S, sampling_rate)