# Transfer learning using YAMNet

> YAMNet expects 16 kHz mono audio, so the recordings must be downsampled to 16 kHz. TODO: evaluate the impact of this downsampling.

## Data exploration

In [None]:
from pathlib import Path

# Dataset root: <parent of the current working directory>/data/data_small
DATA_DIR = Path.cwd().parent.joinpath("data", "data_small")

In [None]:
# Path.glob yields files in arbitrary, filesystem-dependent order; sort so that
# indexing (e.g. audio_paths[0]) picks the same file on every machine and run.
audio_paths = sorted(DATA_DIR.glob("train/*.wav"))

In [None]:
# Fail with an actionable message instead of a bare IndexError when the glob
# above matched nothing (wrong DATA_DIR or empty train/ directory).
if not audio_paths:
    raise FileNotFoundError(f"No .wav files found in {DATA_DIR / 'train'}")
example_audio_path = audio_paths[0]
print(example_audio_path)

In [None]:
import librosa

# sr=None keeps the file's native sampling rate instead of letting librosa
# resample to its default rate.
y, sampling_rate = librosa.load(
    example_audio_path,
    sr=None,
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def plot_mel_spectrogram(
    S: np.ndarray,
    sampling_rate: float,
    title: str = "Mel-frequency spectrogram",
) -> None:
    """Plot a power mel spectrogram on a decibel scale.

    Args:
        S: Power mel spectrogram, e.g. the output of
            ``librosa.feature.melspectrogram``.
        sampling_rate: Sampling rate (Hz) of the audio ``S`` was computed
            from; used to label the time and mel-frequency axes.
        title: Title drawn above the plot.
    """
    # ``librosa.display`` is a submodule that older librosa releases do not
    # load with ``import librosa``, so import it explicitly before use.
    import librosa.display

    fig, ax = plt.subplots()

    # Convert power to dB relative to the peak, so the loudest bin is 0 dB.
    S_db = librosa.power_to_db(S, ref=np.max)
    img = librosa.display.specshow(
        S_db, x_axis="time", y_axis="mel", sr=sampling_rate, ax=ax
    )

    # Image formatting
    fig.colorbar(img, ax=ax, format="%+2.f dB")
    ax.set(title=title)

    # Figure.show() only works with GUI backends (it warns under the notebook
    # inline backend); pyplot.show() is the portable way to display it.
    plt.show()

In [None]:
# Power mel spectrogram of the example clip at its native sampling rate.
S = librosa.feature.melspectrogram(y=y, sr=sampling_rate)
plot_mel_spectrogram(S=S, sampling_rate=sampling_rate)

## Inference using YAMNet

### Preprocess audio

In [None]:
import tensorflow as tf
import tensorflow_io as tfio


@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
def _decode_and_resample_wav(filepath: tf.Tensor) -> tf.Tensor:
    """Decode a WAV file (scalar string tensor path) to 16 kHz mono float32."""
    file_content = tf.io.read_file(filepath)
    audio, sample_rate = tf.audio.decode_wav(file_content, desired_channels=1)
    # Drop the single channel axis: (samples, 1) -> (samples,)
    audio = tf.squeeze(audio, axis=-1)
    # decode_wav reports the rate as int32; cast to int64 for tfio's resample.
    sample_rate = tf.cast(sample_rate, dtype=tf.int64)
    return tfio.audio.resample(audio, rate_in=sample_rate, rate_out=16000)


def load_audio_file(filepath: Path) -> tf.Tensor:
    """Load audio file as a tensor and resample it to 16kHz single channel audio.

    Args:
        filepath: Path to a WAV file, as a ``Path``/``str``, or a scalar
            ``tf.string`` tensor (e.g. inside a ``tf.data`` pipeline).

    Returns:
        1-D float32 tensor of mono samples at 16 kHz.
    """
    # Convert Python paths to a string tensor *outside* the tf.function: passing
    # a Python object into a tf.function retraces the graph for every distinct
    # file, whereas a fixed tensor input_signature traces it once.
    if not isinstance(filepath, tf.Tensor):
        filepath = tf.constant(str(filepath))
    return _decode_and_resample_wav(filepath)

### Run inference

In [None]:
import pandas as pd
import tensorflow_hub as hub


class YamnetModel:
    """Wrapper around the pretrained YAMNet audio classifier from TF Hub.

    Call ``load()`` once before ``predict()``. Loading downloads the model and
    reads its class-index -> display-name mapping.
    """

    def __init__(
        self, model_handle: str = "https://tfhub.dev/google/yamnet/1"
    ) -> None:
        self.model_handle = model_handle
        # Set by load(); None means the model has not been loaded yet.
        self.model = None
        self.classes = []

    def load(self) -> None:
        """Download the model from TF Hub and load its class-name mapping."""
        # Download the model from Tensorflow Hub
        self.model = hub.load(self.model_handle)
        # The model exposes the path of a CSV mapping class index -> name.
        class_map_path = self.model.class_map_path().numpy().decode("utf-8")
        self.classes = pd.read_csv(class_map_path)["display_name"].tolist()

    def predict(self, audio_data: tf.Tensor) -> str:
        """Return the display name of the most likely class for a clip.

        Args:
            audio_data: Mono waveform already resampled to 16 kHz, e.g. the
                output of ``load_audio_file``.

        Returns:
            Display name of the class with the highest clip-averaged score.

        Raises:
            RuntimeError: If ``load()`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("YamnetModel.load() must be called before predict().")
        # The model returns (scores, embeddings, spectrogram); average the
        # frame-level scores over axis 0 to get one score per class.
        scores, _, _ = self.model(audio_data)
        class_scores = tf.reduce_mean(scores, axis=0)
        top_class = int(tf.math.argmax(class_scores))
        return self.classes[top_class]

In [None]:
# Instantiate the wrapper and fetch YAMNet plus its class-name mapping from
# TensorFlow Hub (network access required).
model = YamnetModel()
model.load()

In [None]:
# Decode the example clip to 16 kHz mono and classify it; as the last
# expression in the cell, the predicted class name is shown as cell output.
wav_data = load_audio_file(example_audio_path)
model.predict(wav_data)