# Small-Model Training

Train a small linear model over fixed embeddings, using a curated set of training data.



# Imports and Configuration.

In [None]:
#@title Imports. { vertical-output: true }

# Global imports
import collections
import os
import numpy as np
import tensorflow as tf
from etils import epath
import matplotlib.pyplot as plt
import tqdm

# Set to False to hide all GPUs from TensorFlow.
use_tf_gpu = True #@param
if not use_tf_gpu:
  # An empty device list leaves TensorFlow with no visible GPU devices.
  tf.config.experimental.set_visible_devices([], "GPU")

# Chirp imports
from chirp import audio_utils
from chirp import path_utils
from chirp.inference import models
from chirp.inference import tf_examples
from chirp.models import frontend
from chirp.models import metrics
from chirp.preprocessing import pipeline
from chirp.projects.bootstrap import bootstrap
from chirp.projects.bootstrap import display
from chirp.projects.bootstrap import search
from chirp.projects.multicluster import classify
from chirp.projects.multicluster import data_lib


In [None]:
#@title Configure data locations and load model. { vertical-output: true }

# Path to TFRecords of unlabeled embeddings.
unlabeled_embeddings_path = '' #@param
# Glob matching every entry directly inside the embeddings folder.
embeddings_glob = epath.Path(unlabeled_embeddings_path) / '*'

# Hop-size (in seconds) used when creating the embeddings dataset.
embedding_hop_size_s = 5.0 #@param
# Number of folder name levels in embedding file ids.
file_id_depth = 1 #@param

# Globs for source audio files represented in the unlabeled embeddings.
# e.g., /data/project_audio/*/*.wav
audio_globs = [] #@param

# Path to the labeled wav data.
# Should be in 'folder-of-folders' format - a folder with sub-folders for
# each class of interest.
# Audio in sub-folders should be wav files.
# Audio should ideally be 5s audio clips, but the system is quite forgiving.
labeled_data_path = '' #@param

# Which embedding model to load; must be 'perch' or 'birdnet'.
model_choice = 'perch' #@param['perch', 'birdnet']
# Path to the folder containing the perch model, which you can get at:
# https://tfhub.dev/google/bird-vocalization-classifier
perch_path = '' #@param
# Path to a local copy of a BirdNet TFLite file.
birdnet_path = '' #@param

# Pick the model key and model settings for the chosen embedding model.
# Window size and sample rate are fixed per model; the hop size follows the
# value used when the unlabeled embeddings were created.
if model_choice == 'perch':
  model_key = 'taxonomy_model_tf'
  model_config = {
      'model_path': perch_path,
      'window_size_s': 5.0,
      'hop_size_s': embedding_hop_size_s,
      'sample_rate': 32000,
  }
elif model_choice == 'birdnet':
  model_key = 'birdnet'
  model_config = {
      # Pass the user-supplied TFLite file; without this entry the
      # `birdnet_path` setting would be silently ignored.
      'model_path': birdnet_path,
      'window_size_s': 3.0,
      'hop_size_s': embedding_hop_size_s,
      'sample_rate': 48000,
  }
else:
  # Only the two choices above are supported.
  raise ValueError(f'unknown model choice {model_choice=}')

# Bundle the settings chosen above into a single bootstrap configuration.
config = bootstrap.BootstrapConfig(
    # Embedding model info.
    # This has to be the same model that produced the embeddings files.
    model_key=model_key,
    model_config=model_config,
    # Pre-generated embeddings TFRecord files, and how they were created.
    embeddings_glob=embeddings_glob,
    embedding_hop_size_s=embedding_hop_size_s,
    file_id_depth=file_id_depth,
    # Globs for the audio files represented in the embeddings.
    audio_globs=audio_globs,
    # Folder where annotated examples are stored.
    annotated_path=labeled_data_path,
)

# Build the project state from the config and take the embedding model from it.
project_state = bootstrap.BootstrapState(config)
embedding_model = project_state.embedding_model

# Supervised Learning.

In [None]:
#@title Load+Embed the Labeled Dataset. { vertical-output: true }

# How embeddings are pooled over time when an example is longer than the
# model's window size.
time_pooling = 'mean' #@param

merged = data_lib.MergedDataset(
    config.annotated_path, embedding_model, time_pooling=time_pooling)

# Summarize the label distribution from the summed 'label_hot' rows.
lbl_counts = np.sum(merged.data['label_hot'], axis=0)
num_nonempty = (lbl_counts > 0).sum()
print('num classes :', num_nonempty)
print('mean ex / class :', lbl_counts.sum() / num_nonempty)
# Classes with no examples are bumped to 1e6 so they cannot win the minimum.
print('min ex / class :', ((lbl_counts == 0) * 1e6 + lbl_counts).min())


In [None]:
#@title Train linear model over embeddings. { vertical-output: true }

# Number of random training examples to choose from each class.
example_per_class = 128 #@param

# Number of random re-trainings. Allows judging model stability.
num_seeds = 1 #@param

# Classifier training hyperparams.
# These should be good defaults.
batch_size = 32
num_epochs = 128
# A positive value selects the two-layer model and is passed to it;
# any other value selects the linear model.
num_hiddens = -1
learning_rate = 1e-3

# At least one run is needed: the summary below averages over runs and the
# per-class report reads the last run's results.
if num_seeds < 1:
  raise ValueError(f'num_seeds must be at least 1, got {num_seeds=}')

# Per-run metrics, keyed by metric name. Deliberately not called `metrics`,
# which would overwrite the imported `chirp.models.metrics` module.
seed_metrics = collections.defaultdict(list)
for seed in range(num_seeds):
  # Build a fresh model for every seed so that runs are independent.
  if num_hiddens > 0:
    model = classify.get_two_layer_model(
        num_hiddens, merged.embedding_dim, merged.num_classes)
  else:
    model = classify.get_linear_model(
        merged.embedding_dim, merged.num_classes)
  run_metrics = classify.train_embedding_model(
      model, merged, example_per_class, num_epochs, seed, batch_size,
      learning_rate)
  seed_metrics['acc'].append(run_metrics.top1_accuracy)
  seed_metrics['auc_roc'].append(run_metrics.auc_roc)
  seed_metrics['cmap'].append(run_metrics.cmap_value)
  seed_metrics['maps'].append(run_metrics.class_maps)
  seed_metrics['recall'].append(run_metrics.recall)

# Summary line: each metric averaged over all runs.
mean_acc = np.mean(seed_metrics['acc'])
mean_auc_roc = np.mean(seed_metrics['auc_roc'])
mean_cmap = np.mean(seed_metrics['cmap'])
mean_recall = np.mean(seed_metrics['recall'])
print(f'{example_per_class:d},  acc:{mean_acc:5.2f},  '
      f'auc_roc:{mean_auc_roc:5.2f},  cmap:{mean_cmap:5.2f},  '
      f'recall:{mean_recall:5.2f}')

# Per-class report, taken from the last run only. The values come from
# `class_maps`, so they are labeled `map` rather than `auc_roc`
# (presumably per-class mean average precision - confirm against classify).
for lbl, class_map in zip(merged.labels, run_metrics.class_maps):
  if np.isnan(class_map):
    # No score is available for this class.
    continue
  print(f'{lbl:8s}, map:{class_map:5.2f}')

# Evaluation on Unlabeled Data

In [None]:
#@title Run model on target unlabeled data. { vertical-output: true }

# Choose the target class to work with.
target_class = '' #@param
# Choose a target logit; will display results close to the target.
target_logit = 2.0 #@param
# Number of results to display.
num_results = 25 #@param

# Fail early with a readable message when the class is not a known label.
if target_class not in merged.labels:
  raise ValueError(
      f'unknown target class {target_class=}; '
      f'expected one of {tuple(merged.labels)}')

# Create the embeddings dataset.
embeddings_ds = tf_examples.create_embeddings_dataset(unlabeled_embeddings_path)
target_class_idx = merged.labels.index(target_class)
# Use the hop size the embeddings were created with, not a fixed 5.0, so the
# search stays consistent with the configured embeddings.
results, all_logits = search.classifer_search_embeddings_parallel(
    embeddings_ds, model, target_class_idx,
    hop_size_s=embedding_hop_size_s,
    target_logit=target_logit, top_k=num_results
)

# Plot the histogram of logits.
# plt.hist returns (bin values, bin edges, patches). The marker height must
# come from the bin values (densities), not from the bin edges.
ys, _, _ = plt.hist(all_logits, bins=128, density=True)
plt.xlabel(f'{target_class} logit')
plt.ylabel('density')
# Dotted red vertical line marking the target logit.
plt.plot([target_logit, target_logit], [0.0, np.max(ys)], 'r:')
plt.show()


In [None]:
#@title Display results for the target label. { vertical-output: true }

# Start from a copy so that adding labels can never modify merged.labels.
display_labels = tuple(merged.labels)

extra_labels = [] #@param
# Append the extra labels, then 'unknown'. Each candidate is checked against
# the labels collected so far, so no checkbox label appears twice (e.g. when
# 'unknown' is also listed in extra_labels).
for label in (*extra_labels, 'unknown'):
  if label not in display_labels:
    display_labels += (label,)

# Show the search results with one checkbox per display label.
display.display_search_results(
    results, embedding_model.sample_rate,
    project_state.source_map,
    checkbox_labels=display_labels,
    max_workers=5)

In [None]:
#@title Add selected results to the labeled data. { vertical-output: true }

# Write the search results into the annotated-data folder, passing the
# embedding model's sample rate.
output_sample_rate = embedding_model.sample_rate
results.write_labeled_data(config.annotated_path, output_sample_rate)