### Import packages

In [None]:
import pandas as pd
from nightingale.model.classifier_head import ClassifierHead
from nightingale.data.wav_loader import load_wav_16k_mono
import os
import glob
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import tensorflow_hub as hub
import tensorflow as tf
import numpy as np

### Load and explore BirdCLEF-2024 data (pre-conversion)

In [None]:
# Load the BirdCLEF-2024 training metadata and preview its first rows.
train_metadata_path = "../data/birdclef-2024/train_metadata.csv"
train_df = pd.read_csv(filepath_or_buffer=train_metadata_path)
train_df.head(n=5)

In [None]:
# Summary statistics of the metadata (pandas describes the numeric columns by default).
train_df.describe()

### Prepare dataframe pointing to bird call audio data in wav format

In [None]:
# Build the working dataframe: the metadata rows whose WAV file exists under
# train_audio_16/, extended with an integer class id ('target') and with the
# 'filename' column rewritten to the full path of the WAV file.
base_data_path = "../data/birdclef-2024"
bird_metadata_path = os.path.join(base_data_path, "train_metadata.csv")
bird_df = pd.read_csv(bird_metadata_path)

# The metadata refers to .ogg files; the converted copies use the same
# relative path with a .wav ending.
bird_df['filename'] = bird_df['filename'].str.replace('.ogg', '.wav', regex=False)

# Collect the WAV files that are actually present, expressed relative to the
# audio directory so they can be compared with the metadata's filename column.
train_audio_prefix = base_data_path + "/train_audio_16/"
wav_files = glob.glob(train_audio_prefix + "**/*.wav", recursive=True)
wav_files = [f.replace(train_audio_prefix, "") for f in wav_files]

# Keep only the metadata rows that have a matching WAV file on disk.
filtered_bird_df = bird_df[bird_df['filename'].isin(wav_files)]

# Sort the class names so that the name -> id mapping is identical on every
# run. Iterating a set of strings has no stable order across interpreter
# sessions, which would make the target ids (and the output indices of any
# saved model) irreproducible.
bird_classes = sorted(filtered_bird_df['common_name'].unique())

map_class_to_id = {name: idx for idx, name in enumerate(bird_classes)}

# Integer label for every row, looked up from its common name.
class_id = filtered_bird_df['common_name'].map(map_class_to_id)
filtered_bird_df = filtered_bird_df.assign(target=class_id)

# Replace the relative filename with the path of the WAV file on disk.
full_path = filtered_bird_df['filename'].apply(lambda row: os.path.join(train_audio_prefix, row))
filtered_bird_df = filtered_bird_df.assign(filename=full_path)

# filtered_bird_df.head(10)

### Split data: Training, Validation and Test

In [None]:
# Step 1: Stratified split of the row index into training (60%), validation (20%)
# and test (20%) sets: 40% is held out first and then divided in half.
train_df_idx, temp_df_idx = train_test_split(
    filtered_bird_df.index,
    test_size=0.4,
    random_state=42,
    stratify=filtered_bird_df['target'],
)
val_df_idx, test_df_idx = train_test_split(
    temp_df_idx,
    test_size=0.5,
    random_state=42,
    stratify=filtered_bird_df.loc[temp_df_idx, 'target'],
)

# Step 2: Record the assignment in a 'fold' column (1 = train, 2 = validation, 3 = test)
filtered_bird_df['fold'] = ''  # placeholder until each row receives its fold id
for fold_id, fold_idx in enumerate((train_df_idx, val_df_idx, test_df_idx), start=1):
    filtered_bird_df.loc[fold_idx, 'fold'] = fold_id

# filtered_bird_df.head(10)
# plt.hist(filtered_bird_df[filtered_bird_df['fold'] == 1]['target'], bins=len(bird_classes), alpha=0.7, label='Train')
# plt.hist(filtered_bird_df[filtered_bird_df['fold'] == 2]['target'], bins=len(bird_classes), alpha=0.7, label='Val')
# plt.hist(filtered_bird_df[filtered_bird_df['fold'] == 3]['target'], bins=len(bird_classes), alpha=0.7, label='Test')
# plt.xlabel('Bird Classes')
# plt.ylabel('Count')
# plt.title('Distribution of Bird Classes in Train, Val, and Test Sets')
# plt.legend()
# plt.show()

### Modelling
* Load YAMNet
* Create audio/bird call embeddings using the training data with YAMNet
* Create a custom classifier for bird call classification
* Train classifier with created YAMNet embeddings (as inputs) and bird classes (as outputs)
* Concatenate YAMNet and classifier and measure performance

#### Load YAMNet

In [None]:
# Load the YAMNet model from TensorFlow Hub.
yamnet_model_handle = 'https://tfhub.dev/google/yamnet/1'
model = hub.load(yamnet_model_handle)

#### Use bird call audio to extract embeddings

In [None]:
# Boolean row masks for the three folds stored in the 'fold' column.
is_train_row = filtered_bird_df['fold'] == 1
is_val_row = filtered_bird_df['fold'] == 2
is_test_row = filtered_bird_df['fold'] == 3

# File paths and class ids of each fold.
filenames_train = filtered_bird_df.loc[is_train_row, 'filename']
targets_train = filtered_bird_df.loc[is_train_row, 'target']

filenames_val = filtered_bird_df.loc[is_val_row, 'filename']
targets_val = filtered_bird_df.loc[is_val_row, 'target']

filenames_test = filtered_bird_df.loc[is_test_row, 'filename']
targets_test = filtered_bird_df.loc[is_test_row, 'target']


# One (file path, class id) element per dataframe row.
train_ds = tf.data.Dataset.from_tensor_slices((filenames_train, targets_train))
val_ds = tf.data.Dataset.from_tensor_slices((filenames_val, targets_val))
test_ds = tf.data.Dataset.from_tensor_slices((filenames_test, targets_test))

def load_wav_for_map(filename, label):
    """Load the audio for ``filename`` and pass ``label`` through unchanged.

    The two-argument shape matches the (filename, label) elements this
    function is mapped over with ``tf.data.Dataset.map``.
    """
    waveform = load_wav_16k_mono(filename)
    return waveform, label

# Replace each file path with its loaded waveform. The files are independent
# of each other, so let tf.data load several in parallel; with tf.data's
# default deterministic setting the element order is unchanged.
train_ds = train_ds.map(load_wav_for_map, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.map(load_wav_for_map, num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.map(load_wav_for_map, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
def extract_embedding(wav_data, label):
    """Run YAMNet on ``wav_data`` and pair every embedding with ``label``.

    Returns a tuple of the embeddings and ``label`` repeated once per
    embedding (the size of the embeddings' first dimension), so that the
    result can be unbatched into individual (embedding, label) elements.
    """
    # The model also returns scores and a spectrogram; only the embeddings are used.
    _, embeddings, _ = model(wav_data)
    embedding_count = tf.shape(embeddings)[0]
    repeated_labels = tf.repeat(label, embedding_count)
    return embeddings, repeated_labels

# Turn each (waveform, label) element into its embeddings and flatten them
# into individual (embedding, label) elements.
train_ds = train_ds.map(extract_embedding).unbatch()
val_ds = val_ds.map(extract_embedding).unbatch()
test_ds = test_ds.map(extract_embedding).unbatch()

# Show the element structure. A bare expression that is not the last line of
# a cell is evaluated and discarded, so it has to be printed explicitly.
print(train_ds.element_spec)

# Cache the computed embeddings, then batch and prefetch.
# Only the training data is shuffled.
train_ds = train_ds.cache().shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.cache().batch(32).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.cache().batch(32).prefetch(tf.data.AUTOTUNE)

#### Model bird call classifier

In [None]:
# Instantiate the project's ClassifierHead, sized by the number of bird classes
# found in the filtered metadata.
num_bird_classes = len(bird_classes)
bird_class_model = ClassifierHead(num_classes=num_bird_classes)

# Print the model structure.
bird_class_model.summary()

In [None]:
# Configure training: sparse categorical cross-entropy on integer class ids, Adam
# optimizer, accuracy as the reported metric. from_logits=True tells the loss that
# the model outputs are unnormalised scores — presumably ClassifierHead ends
# without a softmax; TODO confirm against its definition.
bird_class_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                 optimizer="adam",
                 metrics=['accuracy'])

# Stop training after 3 epochs without improvement of the monitored quantity and
# restore the weights of the best epoch.
# NOTE(review): 'loss' is the training loss. Since validation data is passed to
# fit(), 'val_loss' may be the intended quantity — if this is changed, keep the
# hyperparameters logged to MLflow in sync.
callback = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                            patience=3,
                                            restore_best_weights=True)

#### Configure MLflow experiment

In [None]:
import mlflow
from mlflow import MlflowClient

# Address of the MLflow tracking server, shared by the client and the fluent API.
tracking_uri = "http://127.0.0.1:8080"

client = MlflowClient(tracking_uri=tracking_uri)

experiment_description = "Nightingale is a bird call classification project."

# Tags attached to the experiment when it is created (see the one-off command below).
experiment_tags = {"project_name": "nightingale"}
experiment_tags["mlflow.note.content"] = experiment_description

# Run the following command only once, to create the experiment after the server
# has been started for the first time:
# client.create_experiment(name="Nightingale_Bird_Call_Classification", tags=experiment_tags)

# Point the fluent API at the tracking server.
mlflow.set_tracking_uri(tracking_uri)

# Make "Nightingale_Bird_Call_Classification" the active experiment; the call
# returns the Experiment metadata.
experiment_name = "Nightingale_Bird_Call_Classification"
nightingale_experiment = mlflow.set_experiment(experiment_name)

# Name of this training iteration's run.
# Without it, MLflow auto-generates a unique run name.
run_name = "nightingale_classifier_test"

# Artifact path under which the model will be saved.
artifact_path = "classifier_nightingale"

#### Train classifier

In [None]:
# Train the classifier on the embeddings for at most 20 epochs, evaluating on
# the validation set after every epoch. Keras documents `callbacks` as a list
# of Callback instances, so the early-stopping callback is passed as a list.
history = bird_class_model.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    callbacks=[callback],
)

#### Evaluate classifier

In [None]:
# Evaluate the trained classifier on the held-out test embeddings; with one
# compiled metric, evaluate() yields the loss followed by the accuracy.
test_results = bird_class_model.evaluate(test_ds)
loss, accuracy = test_results

print("Loss: ", loss)
print("Accuracy: ", accuracy)

In [None]:
# Assemble the test-set metrics and the training configuration to log.
metrics = {"Loss": loss, "Accuracy": accuracy}
params = {
    "num_bird_classes": num_bird_classes,
    "optimizer": "adam",
    "loss_function": "SparseCategoricalCrossentropy",
    "loss_from_logits": True,
    # Number of epochs actually run; early stopping can end training before
    # the configured maximum is reached.
    "epochs": len(history.epoch),
    "batch_size": 32,
    # Read from the callback object so the logged values cannot drift from
    # the configuration that was really used for training.
    "early_stopping_monitor": callback.monitor,
    "early_stopping_patience": callback.patience,
}

# Initiate the MLflow run context
with mlflow.start_run(run_name=run_name) as run:
    # Log the parameters used for the model fit
    mlflow.log_params(params)

    # Log the loss and accuracy measured on the test set
    mlflow.log_metrics(metrics)

    # Take the inputs of one batch as an example of what the model consumes.
    # MLflow expects numpy or tensor-like input, not a tf.data.Dataset.
    x_batch = next(iter(train_ds))[0]
    input_example = x_batch.numpy()

    print("Shape of input_example:", input_example.shape)
    # Log an instance of the trained model for later use
    mlflow.tensorflow.log_model(model=bird_class_model, input_example=input_example, name=artifact_path)
    

#### Run inference on a bird call audio sample (YAMNet + classifier head)

In [None]:
# Take the second file of the test fold, embed it with YAMNet and run the
# classifier head on every embedding.
test_fold_filenames = filtered_bird_df.loc[filtered_bird_df['fold'] == 3, 'filename']
wav = load_wav_16k_mono(test_fold_filenames.iloc[1])
scores, embeddings, spectrogram = model(wav)
result = bird_class_model(embeddings).numpy()

# Average the classifier outputs over all embeddings and report the class
# with the highest mean output.
mean_output = result.mean(axis=0)
inferred_class = bird_classes[mean_output.argmax()]
print(f'The main sound is: {inferred_class}')