# Agile Modeling for Bioacoustics.

This notebook provides a single-machine workflow for using pre-trained models to embed raw audio files, search, and create classifiers for target signals. This notebook is ideal for a single machine with a GPU for accelerated embedding.

## Configuration and Imports.

In [0]:
 #@title Imports. { vertical-output: true }

import collections
from etils import epath
from ml_collections import config_dict
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tqdm
from chirp.inference import colab_utils
colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)

from chirp import audio_utils
from chirp.inference import embed_lib
from chirp.inference import tf_examples
from chirp.inference import models
from chirp.models import metrics
from chirp.inference.search import bootstrap
from chirp.inference.search import search
from chirp.inference.search import display
from chirp.inference.classify import classify
from chirp.inference.classify import data_lib


In [0]:
#@title Basic Configuration. { vertical-output: true }

# Define the model: Usually perch or birdnet.
# (The embedding configuration cell below only branches on these two values.)
model_choice = 'perch'  #@param
# Set the base directory for the project.
working_dir = '/tmp/agile'  #@param

# Set the embedding and labeled data directories.
# Both live directly under working_dir.
embeddings_path = epath.Path(working_dir) / 'embeddings'
labeled_data_path = epath.Path(working_dir) / 'labeled'
# Glob pattern for the embedding files inside embeddings_path.
embeddings_glob = embeddings_path / 'embeddings-*'

# OPTIONAL: Set up separation model.
# separation_model_key is the key looked up in models.model_class_map() when
# the separator is loaded; an empty separation_model_path means no separation
# model is loaded.
separation_model_key = 'separator_model_tf'  #@param
separation_model_path = ''  #@param


In [0]:
#@title Load Existing Project State and Models. { vertical-output: true }

# If you have already computed embeddings, run this cell to load models
# and find existing data.

# Normalize the user-supplied path once, so a whitespace-only value counts as
# "not set" both for the fallback below and for the load decision.
separation_model_path = separation_model_path.strip()

if (embeddings_path / 'config.json').exists():
  # Get relevant info from the embedding configuration.
  bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(
      embeddings_path=embeddings_path,
      annotated_path=labeled_data_path)
  project_state = bootstrap.BootstrapState(bootstrap_config)

  # When the embeddings were made with the 'separate_embed_model' and no
  # separator path was given, fall back to the separator path recorded in the
  # embedding configuration.
  if (bootstrap_config.model_key == 'separate_embed_model'
      and not separation_model_path):
    separation_model_key = 'separator_model_tf'
    separation_model_path = bootstrap_config.model_config.separator_model_tf_config.model_path

# Load separation model.
if separation_model_path:
  separation_config = config_dict.ConfigDict({
      'model_path': separation_model_path,
      'frame_size': 32000,
      'sample_rate': 32000,
  })
  separator = models.model_class_map()[
      separation_model_key].from_config(separation_config)
  print('Loaded separator model at {}'.format(separation_model_path))
else:
  # Later cells check `separator is None` to skip the separation step.
  print('No separation model loaded.')
  separator = None

## Embed Audio

In [0]:
#@title Embedding Configuration. { vertical-output: true }

config = config_dict.ConfigDict()
config.embed_fn_config = config_dict.ConfigDict()
config.embed_fn_config.model_config = config_dict.ConfigDict()

# IMPORTANT: Select the target audio files.
# source_file_patterns should contain a list of globs of audio files, like:
# ['/home/me/*.wav', '/home/me/other/*.flac']
config.source_file_patterns = ['']  #@param
config.output_dir = embeddings_path.as_posix()

# For Perch, set perch_tfhub_version, and the model will load
# automagically from TFHub. Alternatively, set the model path for a local
# copy of the model.
# Note that only one of perch_model_path and perch_tfhub_version should be set.
perch_tfhub_version = 4  #@param
perch_model_path = ''  #@param

# For BirdNET, point to the specific tflite file.
birdnet_model_path = ''  #@param
if model_choice == 'perch':
  config.embed_fn_config.model_key = 'taxonomy_model_tf'
  config.embed_fn_config.model_config.window_size_s = 5.0
  config.embed_fn_config.model_config.hop_size_s = 5.0
  config.embed_fn_config.model_config.sample_rate = 32000
  config.embed_fn_config.model_config.tfhub_version = perch_tfhub_version
  config.embed_fn_config.model_config.model_path = perch_model_path
elif model_choice == 'birdnet':
  config.embed_fn_config.model_key = 'birdnet'
  config.embed_fn_config.model_config.window_size_s = 3.0
  config.embed_fn_config.model_config.hop_size_s = 3.0
  config.embed_fn_config.model_config.sample_rate = 48000
  config.embed_fn_config.model_config.model_path = birdnet_model_path
  # Note: The v2_1 class list is appropriate for Birdnet 2.1, 2.2, and 2.3.
  config.embed_fn_config.model_config.class_list_name = 'birdnet_v2_1'
  config.embed_fn_config.model_config.num_tflite_threads = 4
else:
  # Fail here rather than later with a half-populated configuration that has
  # no model_key, window size or sample rate.
  raise ValueError(
      f"Unsupported model_choice {model_choice!r}; "
      "expected 'perch' or 'birdnet'.")

# Only write embeddings to reduce size.
config.embed_fn_config.write_embeddings = True
config.embed_fn_config.write_logits = False
config.embed_fn_config.write_separated_audio = False
config.embed_fn_config.write_raw_audio = False

# Number of parent directories to include in the filename.
config.embed_fn_config.file_id_depth = 1

In [0]:
#@title Set up. { vertical-output: true }

# Build the embedding function from the configuration; setup() loads the
# model(s).
embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)
print('\n\nLoading model(s)...')
embed_fn.setup()

# Make sure the output directory exists, then record the configuration in it.
output_dir = epath.Path(config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
embed_lib.maybe_write_config(config, output_dir)

# Expand the source file patterns into SourceInfos. The sharding options
# default to -1 when they are absent from the config.
source_infos = embed_lib.create_source_infos(
    config.source_file_patterns,
    shard_len_s=config.get('shard_len_s', -1),
    num_shards_per_file=config.get('num_shards_per_file', -1),
)
print(f'Found {len(source_infos)} source infos.')

# Smoke test: embed one window of silence at the model's sample rate.
print('\n\nTest-run of model...')
model_config = config.embed_fn_config.model_config
window_size_s = model_config.window_size_s
sr = model_config.sample_rate
z = np.zeros([int(sr * window_size_s)])
embed_fn.embedding_model.embed(z)
print('Setup complete!')

In [0]:
#@title Run embedding. { vertical-output: true }

# Uses multiple threads to load audio before embedding.
# This tends to be faster, but can fail if any audio files are corrupt.

embed_fn.min_audio_s = 1.0
succ, fail = 0, 0

# Work out which sources already have embeddings in the output directory, so
# that only the remaining ones are processed.
existing_embedding_ids = embed_lib.get_existing_source_ids(
    output_dir, 'embeddings-*')

new_source_infos = embed_lib.get_new_source_infos(
    source_infos, existing_embedding_ids, config.embed_fn_config.file_id_depth)

print(f'Found {len(existing_embedding_ids)} existing embedding ids. '
      f'Processing {len(new_source_infos)} new source infos. ')

# Audio is loaded in the same order as new_source_infos, so the iterator can
# be zipped with the source infos below.
audio_iterator = audio_utils.multi_load_audio_window(
    filepaths=[s.filepath for s in new_source_infos],
    offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],
    sample_rate=config.embed_fn_config.model_config.sample_rate,
    window_size_s=config.get('shard_len_s', -1.0),
)
with tf_examples.EmbeddingsTFRecordMultiWriter(
    output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:
  for source_info, audio in tqdm.tqdm(
      zip(new_source_infos, audio_iterator), total=len(new_source_infos)):
    file_id = source_info.file_id(config.embed_fn_config.file_id_depth)
    offset_s = source_info.shard_num * source_info.shard_len_s
    example = embed_fn.audio_to_example(file_id, offset_s, audio)
    # A None example is counted as a failure and skipped.
    if example is None:
      fail += 1
      continue
    file_writer.write(example.SerializeToString())
    succ += 1
  file_writer.flush()
print(f'\n\nSuccessfully processed {succ} source_infos, failed {fail} times.')

# Sanity check: read back the first written example and show its filename and
# embedding shape.
fns = list(output_dir.glob('embeddings-*'))
ds = tf.data.TFRecordDataset(fns)
parser = tf_examples.get_example_parser()
ds = ds.map(parser)
for ex in ds.as_numpy_iterator():
  print(ex['filename'])
  print(ex['embedding'].shape, flush=True)
  break

# Load/refresh bootstrap_config for subsequent steps.
print('\nRefreshing bootstrap_config.', flush=True)
bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(
    embeddings_path=embeddings_path,
    annotated_path=labeled_data_path)

project_state = bootstrap.BootstrapState(bootstrap_config)


## Search Embeddings

### Query Creation

In [0]:
#@title Load query audio. { vertical-output: true }

# Point to an audio file, Xeno-Canto id (like 'xc12345') or audio file URL.
audio_path = 'xc871667'  #@param
# Start time (in seconds) of the query window within the file.
start_s = 0  #@param

# The query window must match what the embedding model was configured with.
sample_rate = bootstrap_config.model_config['sample_rate']
window_s = bootstrap_config.model_config['window_size_s']
audio = audio_utils.load_audio(audio_path, sample_rate)

# Display the full file.
display.plot_audio_melspec(audio, sample_rate)

# Display the selected window.
print('-' * 80)
print('Selected audio window:')
num_samples = audio.shape[0]
st = int(start_s * sample_rate)
end = int(st + window_s * sample_rate)
if end > num_samples:
  # The window runs past the end of the file: pin it to the end and move the
  # start back, without going below zero.
  end = num_samples
  st = max(0, int(end - window_s * sample_rate))
audio_window = audio[st:end]
display.plot_audio_melspec(audio_window, sample_rate)

# Default query is the unseparated window; no separation has been run yet.
sep_outputs = None
query_audio = audio_window

In [0]:
#@title Separate the target audio window { vertical-output: true }

if separator is None:
  # No separator was loaded; leave the outputs empty.
  sep_outputs = None
  print('No separation model loaded.')
else:
  sep_outputs = separator.embed(audio_window)

  # Plot one spectrogram per separated channel.
  num_channels = sep_outputs.separated_audio.shape[0]
  for c in range(num_channels):
    print(f'Channel {c}')
    display.plot_audio_melspec(sep_outputs.separated_audio[c, :], sample_rate)
    print('-' * 80)

In [0]:
#@title Select the query channel. { vertical-output: true }

query_label = 'some_audio'  #@param
query_channel = 0  #@param

# A separated channel is used only when separation outputs exist and a
# non-negative channel was chosen; otherwise the raw window is the query.
use_separated = sep_outputs is not None and query_channel >= 0
query_audio = (
    sep_outputs.separated_audio[query_channel].copy()
    if use_separated
    else audio_window
)

display.plot_audio_melspec(query_audio, sample_rate)

# Embed the query and pool it down to a single query embedding.
outputs = project_state.embedding_model.embed(query_audio)
query = outputs.pooled_embeddings('first', 'first')


### Execute Search

In [0]:
#@title Run Top-K Search. { vertical-output: true }

# Number of search results to capture.
top_k = 25  #@param

# Target distance for search results.
# This lets us try to hone in on a 'classifier boundary' instead of just
# looking at the closest matches.
# Set to 'None' for raw 'best results' search.
target_score = None  #@param

metric = 'mip'  #@param['euclidean', 'mip', 'cosine']

random_sample = False  #@param

ds = project_state.create_embeddings_dataset()
results, all_scores = search.search_embeddings_parallel(
    ds,
    query,
    hop_size_s=bootstrap_config.embedding_hop_size_s,
    top_k=top_k,
    target_score=target_score,
    score_fn=metric,
    random_sample=random_sample,
)

# Histogram of all scores, with the returned hits marked along the x-axis.
ys, _, _ = plt.hist(all_scores, bins=128, density=True)
peak_density = np.max(ys)
hit_scores = [r.score for r in results.search_results]
plt.scatter(
    hit_scores, np.zeros_like(hit_scores), marker='|', color='r', alpha=0.5)

plt.xlabel(metric)
plt.ylabel('density')
if target_score is not None:
  # Red dotted line at the requested target score.
  plt.plot([target_score, target_score], [0.0, peak_density], 'r:')
  # Compute the proportion of scores < target_score
  hit_percentage = (all_scores < target_score).mean()
  print(f'score < target_score percentage : {hit_percentage:5.3f}')
# Green dotted line at the smallest observed score.
min_score = np.min(all_scores)
plt.plot([min_score, min_score], [0.0, peak_density], 'g:')

plt.show()


In [0]:
#@title Display results. { vertical-output: true }

# Offer the query label and 'unknown' as the checkbox choices for each result.
search_checkbox_labels = [query_label, 'unknown']
display.display_search_results(
    results,
    sample_rate,
    project_state.source_map,
    checkbox_labels=search_checkbox_labels,
    max_workers=5,
)

In [0]:
#@title Write annotated examples. { vertical-output: true }

# Write the search results to the annotated-data directory at the embedding
# model's sample rate.
annotated_dir = bootstrap_config.annotated_path
model_sample_rate = project_state.embedding_model.sample_rate
results.write_labeled_data(annotated_dir, model_sample_rate)

## Active Learning for a Target Class

In [0]:
# @title Load+Embed the Labeled Dataset. { vertical-output: true }

# Time-pooling strategy for examples longer than the model's window size.
time_pooling = 'mean'  # @param

merged = data_lib.MergedDataset.from_folder_of_folders(
    base_dir=labeled_data_path,
    embedding_model=project_state.embedding_model,
    time_pooling=time_pooling,
    load_audio=False,
    target_sample_rate=-2,
    audio_file_pattern='*',
    embedding_config_hash=bootstrap_config.embedding_config_hash(),
)

# Summarize the label distribution.
lbl_counts = np.sum(merged.data['label_hot'], axis=0)
num_nonempty_classes = (lbl_counts > 0).sum()
print('num classes :', num_nonempty_classes)
print('mean ex / class :', lbl_counts.sum() / num_nonempty_classes)
# Classes with zero examples get a large offset so they cannot be the minimum.
print('min ex / class :', (lbl_counts + (lbl_counts == 0) * 1e6).min())

In [0]:
#@title Train small model over embeddings. { vertical-output: true }

# Number of random training examples to choose from each class.
# Set exactly one of train_ratio and train_examples_per_class
train_ratio = None  #@param
train_examples_per_class = 2  #@param

# Number of random re-trainings. Allows judging model stability.
num_seeds = 1  #@param

# Classifier training hyperparams.
# These should be good defaults.
batch_size = 32
num_epochs = 128
# A positive value selects the two-layer model; otherwise a linear model.
num_hiddens = -1
learning_rate = 1e-3

# The summary below needs at least one training run.
if num_seeds < 1:
  raise ValueError(f'num_seeds must be at least 1, got {num_seeds}.')

# Named `seed_metrics` (not `metrics`) so that the `chirp.models.metrics`
# module imported at the top of the notebook is not shadowed.
seed_metrics = collections.defaultdict(list)
for seed in tqdm.tqdm(range(num_seeds)):
  if num_hiddens > 0:
    model = classify.get_two_layer_model(
        num_hiddens, merged.embedding_dim, merged.num_classes)
  else:
    model = classify.get_linear_model(
        merged.embedding_dim, merged.num_classes)
  run_metrics = classify.train_embedding_model(
      model, merged, train_ratio, train_examples_per_class,
      num_epochs, seed, batch_size, learning_rate)
  seed_metrics['acc'].append(run_metrics.top1_accuracy)
  seed_metrics['auc_roc'].append(run_metrics.auc_roc)
  seed_metrics['cmap'].append(run_metrics.cmap_value)
  seed_metrics['maps'].append(run_metrics.class_maps)
  seed_metrics['test_logits'].append(run_metrics.test_logits)

# Average the scalar metrics over all seeds.
mean_acc = np.mean(seed_metrics['acc'])
mean_auc = np.mean(seed_metrics['auc_roc'])
mean_cmap = np.mean(seed_metrics['cmap'])
# Merge the test_logits into a single array.
test_logits = {
    k: np.concatenate([logits[k] for logits in seed_metrics['test_logits']])
    for k in seed_metrics['test_logits'][0].keys()
}

print(f'acc:{mean_acc:5.2f}, auc_roc:{mean_auc:5.2f}, cmap:{mean_cmap:5.2f}')
# Per-class values come from the final seed's run only (`run_metrics` is the
# last loop value). After the loop, `model` is likewise the last trained model.
# NOTE(review): these values are read from `class_maps` but printed under the
# label 'auc_roc' -- confirm which metric this is meant to report.
for lbl, class_map in zip(merged.labels, run_metrics.class_maps):
  if np.isnan(class_map):
    continue
  print(f'\n{lbl:8s}, auc_roc:{class_map:5.2f}')
  colab_utils.prstats(f'test_logits({lbl})',
                      test_logits[merged.labels.index(lbl)])

In [0]:
#@title Run model on target unlabeled data. { vertical-output: true }

# Choose the target class to work with.
target_class = 'some_audio'  #@param
# Choose a target logit; will display results close to the target.
# Set to None to get the highest-logit examples.
target_logit = None  #@param
# Number of results to display.
num_results = 25  #@param

# Create the embeddings dataset.
embeddings_ds = tf_examples.create_embeddings_dataset(
    embeddings_path, file_glob='embeddings-*')
target_class_idx = merged.labels.index(target_class)
results, all_logits = search.classifer_search_embeddings_parallel(
    embeddings_classifier=model,
    target_index=target_class_idx,
    embeddings_dataset=embeddings_ds,
    hop_size_s=bootstrap_config.embedding_hop_size_s,
    target_score=target_logit,
    top_k=num_results
)

# Plot the histogram of logits.
# plt.hist returns (bin heights, bin edges, patches); the heights are what
# the marker line below must be scaled to.
ys, _, _ = plt.hist(all_logits, bins=128, density=True)
plt.xlabel(f'{target_class} logit')
plt.ylabel('density')
# plt.yscale('log')
# Only mark the target logit when one was actually requested.
if target_logit is not None:
  plt.plot([target_logit, target_logit], [0.0, np.max(ys)], 'r:')
plt.show()


In [0]:
#@title Display results for the target label. { vertical-output: true }

# Copy into a fresh tuple so that extending the display labels can never
# modify merged.labels itself.
display_labels = tuple(merged.labels)

extra_labels = []  #@param
# Check against the growing tuple so that no label is offered twice, e.g.
# repeated entries in extra_labels or an explicit 'unknown'.
for label in extra_labels:
  if label not in display_labels:
    display_labels += (label,)
if 'unknown' not in display_labels:
  display_labels += ('unknown',)

display.display_search_results(
    results, project_state.embedding_model.sample_rate,
    project_state.source_map,
    checkbox_labels=display_labels,
    max_workers=5)

In [0]:
#@title Add selected results to the labeled data. { vertical-output: true }

# Write the classifier search results to the annotated-data directory at the
# embedding model's sample rate.
annotated_dir = bootstrap_config.annotated_path
model_sample_rate = project_state.embedding_model.sample_rate
results.write_labeled_data(annotated_dir, model_sample_rate)

## Inference

In [0]:
#@title Write classifier inference CSV. { vertical-output: true }

# A (time step, label) pair is written to the CSV only when its logit is
# strictly greater than this threshold.
threshold = 1.0  #@param
# Path of the CSV file to write.
output_filepath = '/tmp/inference.csv'  #@param

# Create the embeddings dataset.
embeddings_ds = tf_examples.create_embeddings_dataset(
    embeddings_path, file_glob='embeddings-*')

def classify_batch(batch):
  """Adds classifier logits to an embeddings example.

  The embeddings are flattened to rows of the last dimension, passed through
  the global `model`, reshaped back to the two leading dimensions of the
  embeddings, and max-reduced over the second-to-last axis (channels).

  Args:
    batch: Example dict holding the embeddings under `tf_examples.EMBEDDING`.

  Returns:
    The same dict, with the reduced logits stored under the 'logits' key.
  """
  embeddings = batch[tf_examples.EMBEDDING]
  dims = tf.shape(embeddings)
  # Collapse all leading axes so the classifier sees one embedding per row.
  flat_logits = model(tf.reshape(embeddings, [-1, dims[-1]]))
  num_outputs = tf.shape(flat_logits)[-1]
  restored = tf.reshape(flat_logits, [dims[0], dims[1], num_outputs])
  # Take the maximum logit over channels.
  batch['logits'] = tf.reduce_max(restored, axis=-2)
  return batch

# Classify the embeddings dataset created above instead of building the same
# dataset a second time.
inference_ds = embeddings_ds.map(
    classify_batch, num_parallel_calls=tf.data.AUTOTUNE
)

# Loop-invariant hop size between consecutive embedding windows.
hop_size_s = bootstrap_config.embedding_hop_size_s

# Filenames are decoded as UTF-8, so write the file as UTF-8 too rather than
# relying on the platform default encoding.
with open(output_filepath, 'w', encoding='utf-8') as f:
  # Write column headers.
  headers = ['filename', 'timestamp_s', 'label', 'logit']
  f.write(', '.join(headers) + '\n')
  for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
    filename = ex['filename'].decode('utf-8')
    logits = ex['logits']
    for t in range(logits.shape[0]):
      # Start time of window t: example start plus t hops.
      offset = ex['timestamp_s'] + t * hop_size_s
      for i, label in enumerate(merged.labels):
        # One row per (window, label) whose logit exceeds the threshold.
        if logits[t, i] > threshold:
          row = [filename,
                 '{:.2f}'.format(offset),
                 label,
                 '{:.2f}'.format(logits[t, i])]
          f.write(', '.join(row) + '\n')
