# Agile Modeling for Bioacoustics.

This notebook provides a workflow for creating custom classifiers for target signals, by first **searching** for training data, and then engaging in an **active learning** loop.

We assume that embeddings have been pre-computed using `embed.ipynb`.

# ATTENTION: 

There is a new version of this workflow available [here](https://github.com/google-research/perch-hoplite/blob/main/perch_hoplite/agile/1_embed_audio_v2.ipynb), in the new [Perch-Hoplite](https://github.com/google-research/perch-hoplite/blob/main/perch_hoplite) repository.

## Configuration and Imports.

In [0]:
#@title Installation. { vertical-output: true }
#@markdown Run this notebook in Google Colab by following [this link](https://colab.research.google.com/github/google-research/perch/blob/main/agile_modeling.ipynb).
#@markdown
#@markdown Run this cell to install the project dependencies.
# NOTE: this installs whatever the repository's default branch currently
# contains, so results can change over time. For reproducible runs, pin a
# tag or commit (e.g. `...perch.git@<commit-sha>`).
%pip install git+https://github.com/google-research/perch.git


In [0]:
 #@title Imports. { vertical-output: true }

import collections
from etils import epath
from ml_collections import config_dict
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tqdm
from chirp.inference import colab_utils
# Colab/runtime initialization, with TF GPU use enabled and warnings disabled
# (per the keyword arguments). It runs before the remaining chirp imports
# below. NOTE(review): that ordering looks intentional; confirm before moving it.
colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)

from chirp import audio_utils
from chirp.inference import baw_utils
from chirp.inference import tf_examples
from chirp.models import metrics
from perch_hoplite.taxonomy import namespace
from chirp.inference.search import bootstrap
from chirp.inference.search import search
from chirp.inference.search import display
from chirp.inference.classify import classify
from chirp.inference.classify import data_lib
from perch_hoplite.zoo import model_configs
from perch_hoplite.zoo import zoo_interface


In [0]:
#@title Basic Configuration. { vertical-output: true }

#@markdown Choose what data to work with.
#@markdown * For local data (most cases), choose 'filesystem'.
#@markdown * For Australian Acoustic Observatory, select 'a2o'.
#@markdown This will cause many options (such as the embedding configuration)
#@markdown to be overridden.
#@markdown Note that you will need an Authentication Token from:
#@markdown https://data.acousticobservatory.org/my_account
data_source = 'filesystem' #@param['filesystem', 'a2o']
# NOTE: the token is stored in the notebook as plain text. Clear it before
# saving or sharing the notebook.
baw_auth_token = '' #@param {type:'string'}

#@markdown Set the base directory for the project.
working_dir = '/tmp/agile'  #@param {type:'string'}

# Labeled examples and the saved custom classifier live under working_dir.
labeled_data_path = epath.Path(working_dir) / 'labeled'
custom_classifier_path = epath.Path(working_dir) / 'custom_classifier'

#@markdown The embeddings_path should be detected automatically, but can be
#@markdown overridden. Leave it empty to auto-detect `<working_dir>/embeddings`.
embeddings_path = ''  #@param {type:'string'}

#@markdown OPTIONAL: Set up separation model.
separation_model_key = 'separator_model_tf'  #@param {type:'string'}
separation_model_path = ''  #@param {type:'string'}


In [0]:
#@title Load Project State and Models. { vertical-output: true }

# Normalize once so that a whitespace-only path counts as "not set" both in the
# auto-detection below and in the separator-loading check at the end of the
# cell. (Previously the first check stripped it and the second did not.)
separation_model_path = separation_model_path.strip()

if data_source == 'a2o':
  # Embedding configuration for the Australian Acoustic Observatory source.
  embedding_config = baw_utils.get_a2o_embeddings_config()
  bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_config(
      embedding_config=embedding_config,
      annotated_path=labeled_data_path,
      embeddings_glob = '*/embeddings-*')
  embeddings_path = embedding_config.output_dir
elif (embeddings_path
      or (epath.Path(working_dir) / 'embeddings/config.json').exists()):
  if not embeddings_path:
    # Use the default embeddings path, as it seems we found a config there.
    embeddings_path = epath.Path(working_dir) / 'embeddings'
  # Get relevant info from the embedding configuration.
  bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_path(
      embeddings_path=embeddings_path,
      annotated_path=labeled_data_path)
  # If the embeddings came from a separate+embed model and no separator was
  # configured by hand, reuse the separator path from the embedding config.
  if (bootstrap_config.model_key == 'separate_embed_model'
      and not separation_model_path):
    separation_model_key = 'separator_model_tf'
    separation_model_path = bootstrap_config.model_config.separator_model_tf_config.model_path
  # The A2O token is cleared when reading local embeddings.
  baw_auth_token = ''
else:
  raise ValueError('No embedding configuration found.')

project_state = bootstrap.BootstrapState(
    bootstrap_config, baw_auth_token=baw_auth_token)

# Load separation model.
if separation_model_path:
  # NOTE(review): frame_size == sample_rate == 32000 means 1 s frames at
  # 32 kHz; confirm this matches the separation model being loaded.
  separation_config = config_dict.ConfigDict({
      'model_path': separation_model_path,
      'frame_size': 32000,
      'sample_rate': 32000,
  })
  separator = model_configs.MODEL_CLASS_MAP[
      separation_model_key].from_config(separation_config)
  print('Loaded separator model at {}'.format(separation_model_path))
else:
  print('No separation model loaded.')
  separator = None

## Search Embeddings

### Query Creation

In [0]:
#@title Load query audio. { vertical-output: true }

#@markdown You may specify:
#@markdown * an audio filepath (like `/home/me/audio/example.wav`),
#@markdown * a Xeno-Canto id (like `xc12345`), or
#@markdown * an audio file URL (like
#@markdown https://upload.wikimedia.org/wikipedia/commons/7/7c/Turdus_merula_2.ogg).
audio_path = 'xc692557'  #@param
#@markdown Choose the start time for the audio window within the file.
#@markdown We will focus on the model's `window_size_s` seconds of audio,
#@markdown starting from `start_s` seconds into the file.
start_s = 0  #@param

window_s = bootstrap_config.model_config['window_size_s']
sample_rate = bootstrap_config.model_config['sample_rate']
audio = audio_utils.load_audio(audio_path, sample_rate)

# Display the full file.
display.plot_audio_melspec(audio, sample_rate)

# Display the selected window.
print('-' * 80)
print('Selected audio window:')
# Clamp the start at 0. A negative slice start would index from the END of
# the array and silently select the wrong audio.
st = max(0, int(start_s * sample_rate))
end = int(st + window_s * sample_rate)
if end > audio.shape[0]:
  # The window runs past the end of the file, so slide it back to end on
  # the last sample, without moving it before the start of the file.
  end = audio.shape[0]
  st = max([0, int(end - window_s * sample_rate)])
audio_window = audio[st:end]
display.plot_audio_melspec(audio_window, sample_rate)

# Default query: the raw (unseparated) window. The channel-selection cell
# below may replace it with a separated channel.
query_audio = audio_window
sep_outputs = None

In [0]:
#@title Separate the target audio window { vertical-output: true }

if separator is None:
  # Nothing to separate; the query channel cell falls back to the raw window.
  sep_outputs = None
  print('No separation model loaded.')
else:
  # Run the separator on the selected window and plot every output channel,
  # so that one can be picked as the query in the next cell.
  sep_outputs = separator.embed(audio_window)
  for channel in range(sep_outputs.separated_audio.shape[0]):
    print(f'Channel {channel}')
    display.plot_audio_melspec(
        sep_outputs.separated_audio[channel, :], sample_rate)
    print('-' * 80)

In [0]:
#@title Select the query channel. { vertical-output: true }

#@markdown Choose a name for the class.
query_label = 'my_class'  #@param
#@markdown If you have applied separation, choose a channel.
#@markdown Ignored if no separation model is being used.
query_channel = 0  #@param

# Use the raw window when there is no separation output or a negative channel
# was requested. Otherwise take a copy of the chosen separated channel.
query_audio = (
    audio_window
    if query_channel < 0 or sep_outputs is None
    else sep_outputs.separated_audio[query_channel].copy()
)

display.plot_audio_melspec(query_audio, sample_rate)

# Embed the query audio, then pool the outputs with the 'first' strategy for
# both pooling arguments. NOTE(review): confirm which axes these refer to.
outputs = project_state.embedding_model.embed(query_audio)
query = outputs.pooled_embeddings('first', 'first')


### Execute Search

In [0]:
#@title Run Top-K Search. { vertical-output: true }

#@markdown Number of search results to capture.
top_k = 50  #@param

#@markdown Target score for search results.
#@markdown This lets us try to home in on a 'classifier boundary' instead of
#@markdown just looking at the closest matches.
#@markdown Set to 'None' for raw 'best results' search.
target_score = None  #@param

#@markdown Maximum Inner-Product (mip) generally gives best results.
metric = 'mip'  #@param['euclidean', 'mip', 'cosine']

#@markdown If True, produce a fully-random sample of data, ignoring similarity.
random_sample = False  #@param

ds = project_state.create_embeddings_dataset(shuffle_files=True)
results, all_scores = search.search_embeddings_parallel(
    ds, query,
    hop_size_s=bootstrap_config.embedding_hop_size_s,
    top_k=top_k, target_score=target_score, score_fn=metric,
    random_sample=random_sample)

# Histogram of all scores. The scores of the retained hits are marked as red
# ticks along the x-axis.
ys, _, _ = plt.hist(all_scores, bins=128, density=True)
hit_scores = [r.score for r in results.search_results]
plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',
            color='r', alpha=0.5)

plt.xlabel(metric)
plt.ylabel('density')
if target_score is not None:
  # Red dotted line at the target score.
  plt.plot([target_score, target_score], [0.0, np.max(ys)], 'r:')
  # Percentage (0-100) of all scores that fall below target_score.
  hit_percentage = 100.0 * np.mean(np.asarray(all_scores) < target_score)
  print(f'score < target_score percentage : {hit_percentage:5.1f}%')
# Green dotted line at the minimum observed score.
min_score = np.min(all_scores)
plt.plot([min_score, min_score], [0.0, np.max(ys)], 'g:')

plt.show()


In [0]:
#@title Display results. { vertical-output: true }

# Page through the search hits. Each hit gets checkboxes for the query label
# and 'unknown'. exclusive_labels=False presumably allows ticking several.
samples_per_page = 25
page_state = display.PageState(
    np.ceil(len(results.search_results) / samples_per_page))

display.display_paged_results(
    results,
    page_state,
    samples_per_page,
    checkbox_labels=[query_label, 'unknown'],
    exclusive_labels=False,
    embedding_sample_rate=project_state.embedding_model.sample_rate,
    project_state=project_state,
    max_workers=5,
)

In [0]:
#@title Write annotated examples. { vertical-output: true }

# Write the search results (presumably the hits labeled via the checkboxes
# above) into the annotated-data directory from the bootstrap config.
results.write_labeled_data(
    bootstrap_config.annotated_path,
    project_state.embedding_model.sample_rate,
)

## Active Learning for a Target Class

In [0]:
# @title Load+Embed the Labeled Dataset. { vertical-output: true }

#@markdown Time-pooling strategy for audio longer than the model's window size.
time_pooling = 'mean'  # @param

merged = data_lib.MergedDataset.from_folder_of_folders(
    base_dir=labeled_data_path,
    embedding_model=project_state.embedding_model,
    time_pooling=time_pooling,
    load_audio=False,
    # NOTE(review): -2 appears to be a sentinel value interpreted by data_lib;
    # confirm its meaning there before changing it.
    target_sample_rate=-2,
    audio_file_pattern='*',
    embedding_config_hash=bootstrap_config.embedding_config_hash(),
)

# Label distribution: per-class example counts, summarized over the classes
# that have at least one example.
lbl_counts = np.sum(merged.data['label_hot'], axis=0)
nonzero_counts = lbl_counts[lbl_counts > 0]
if nonzero_counts.size == 0:
  print('No labeled examples found in', labeled_data_path)
else:
  print('num classes :', nonzero_counts.size)
  print('mean ex / class :', nonzero_counts.mean())
  print('min ex / class :', nonzero_counts.min())

In [0]:
#@title Train small model over embeddings. { vertical-output: true }

#@markdown Number of random training examples to choose from each class.
#@markdown Set exactly one of `train_ratio` and `train_examples_per_class`.
train_ratio = 0.9  #@param
train_examples_per_class = None  #@param

#@markdown Number of random re-trainings. Allows judging model stability.
num_seeds = 3  #@param

# Classifier training hyperparams.
# These should be good defaults.
batch_size = 32
num_epochs = 128
# num_hiddens > 0 uses classify.get_two_layer_model with this many hidden
# units. Any other value uses classify.get_linear_model.
num_hiddens = -1
learning_rate = 1e-3

# Per-seed results. Deliberately not named `metrics`, which would shadow the
# `chirp.models.metrics` module imported at the top of the notebook.
seed_metrics = collections.defaultdict(list)
for seed in tqdm.tqdm(range(num_seeds)):
  if num_hiddens > 0:
    model = classify.get_two_layer_model(
        num_hiddens, merged.embedding_dim, merged.num_classes)
  else:
    model = classify.get_linear_model(
        merged.embedding_dim, merged.num_classes)
  run_metrics = classify.train_embedding_model(
      model, merged, train_ratio, train_examples_per_class,
      num_epochs, seed, batch_size, learning_rate)
  seed_metrics['acc'].append(run_metrics.top1_accuracy)
  seed_metrics['auc_roc'].append(run_metrics.auc_roc)
  seed_metrics['cmap'].append(run_metrics.cmap_value)
  seed_metrics['maps'].append(run_metrics.class_maps)
  seed_metrics['test_logits'].append(run_metrics.test_logits)
# After the loop, `model` is the classifier trained with the last seed. The
# later search, save and inference cells all use this model.

mean_acc = np.mean(seed_metrics['acc'])
mean_auc = np.mean(seed_metrics['auc_roc'])
mean_cmap = np.mean(seed_metrics['cmap'])
# Per-class scores from run_metrics.class_maps (presumably per-class average
# precision), averaged over seeds so they match the pooled test logits below.
# A class that is NaN in any seed stays NaN and is skipped in the printout.
mean_class_maps = np.mean(seed_metrics['maps'], axis=0)
# Merge each class's test_logits across seeds into a single array.
test_logits = {
    k: np.concatenate([logits[k] for logits in seed_metrics['test_logits']])
    for k in seed_metrics['test_logits'][0].keys()
}

print(f'acc:{mean_acc:5.2f}, auc_roc:{mean_auc:5.2f}, cmap:{mean_cmap:5.2f}')
for lbl, class_map in zip(merged.labels, mean_class_maps):
  if np.isnan(class_map):
    continue
  print(f'\n{lbl:8s}, map:{class_map:5.2f}')
  colab_utils.prstats(f'test_logits({lbl})',
                      test_logits[merged.labels.index(lbl)])

In [0]:
#@title Run model on target unlabeled data. { vertical-output: true }

#@markdown Choose the target class to work with.
target_class = 'my_class'  #@param
#@markdown Choose a target logit; will display results close to the target.
#@markdown Set to None to get the highest-logit examples.
target_logit = 0.0  #@param
#@markdown Number of results to display.
num_results = 50  #@param

# Fail early with a readable message instead of a bare `.index` ValueError.
if target_class not in merged.labels:
  raise ValueError(
      f'target_class {target_class!r} is not one of the trained labels: '
      f'{list(merged.labels)}')
target_class_idx = merged.labels.index(target_class)

embeddings_ds = project_state.create_embeddings_dataset(
    shuffle_files=True)
results, all_logits = search.classifer_search_embeddings_parallel(
    embeddings_classifier=model,
    target_index=target_class_idx,
    embeddings_dataset=embeddings_ds,
    hop_size_s=bootstrap_config.embedding_hop_size_s,
    target_score=target_logit,
    top_k=num_results
)

# Plot the histogram of logits. Tip: plt.yscale('log') helps when a few
# bins dominate the plot.
ys, _, _ = plt.hist(all_logits, bins=128, density=True)
plt.xlabel(f'{target_class} logit')
plt.ylabel('density')
# Only draw the target marker when a target logit was given. None (documented
# above as "highest-logit examples") has no position on the axis.
if target_logit is not None:
  plt.plot([target_logit, target_logit], [0.0, np.max(ys)], 'r:')
plt.show()


In [0]:
#@title Display results for the target label. { vertical-output: true }

# Start from a fresh tuple. Extending `merged.labels` directly with `+=` would
# mutate it in place if it is a list, corrupting the label list that the save
# and inference cells rely on.
display_labels = tuple(merged.labels)

#@markdown Specify any extra labels you would like displayed.
extra_labels = []  #@param
# Check membership against the labels collected so far, so that neither
# repeated extra labels nor 'unknown' end up duplicated.
for label in extra_labels:
  if label not in display_labels:
    display_labels += (label,)
if 'unknown' not in display_labels:
  display_labels += ('unknown',)

samples_per_page = 25
page_state = display.PageState(
    np.ceil(len(results.search_results) / samples_per_page))

display.display_paged_results(
    results, page_state, samples_per_page,
    project_state=project_state,
    embedding_sample_rate=project_state.embedding_model.sample_rate,
    exclusive_labels=False,
    checkbox_labels=display_labels,
    max_workers=5,
)

In [0]:
#@title Add selected results to the labeled data. { vertical-output: true }

# Same writer as in the search section. It writes the classifier-search
# results into the annotated-data directory for the next training round.
results.write_labeled_data(bootstrap_config.annotated_path,
                           project_state.embedding_model.sample_rate)

In [0]:
#@title Save the Custom Classifier. { vertical-output: true }

# Wrap the trained `model` together with its label list (namespace 'custom')
# in a zoo_interface.LogitsOutputHead, then save it to custom_classifier_path.
# save_model also receives embeddings_path.
wrapped_model = zoo_interface.LogitsOutputHead(
    logits_model=model,
    class_list=namespace.ClassList('custom', merged.labels),
    logits_key='logits',
    model_path=custom_classifier_path.as_posix(),
)
wrapped_model.save_model(custom_classifier_path, embeddings_path)

## Inference

In [0]:
#@title Write classifier inference CSV. { vertical-output: true }

#@markdown This cell writes detections (locations of audio windows where
#@markdown the logit was greater than a threshold) to a CSV file.

output_filepath = epath.Path(working_dir) / 'inference.csv'  #@param

#@markdown Set the default detection thresholds, used for all classes.
#@markdown To set per-class detection thresholds, modify the code below.
#@markdown Keep in mind that thresholds are on the logit scale, so 0.0
#@markdown corresponds to a 50% model confidence.
default_threshold = 0.0  #@param
if default_threshold is None:
  # In this case, all logits are written. This can lead to very large CSV files.
  class_thresholds = None
else:
  # Every class not listed explicitly falls back to default_threshold.
  class_thresholds = collections.defaultdict(lambda: default_threshold)
  # Add any per-class thresholds here, for example:
  #   class_thresholds['my_class'] = 1.0
  # The example is left commented out. 'my_class' is this notebook's default
  # class name, so a live override would silently contradict the form's
  # "used for all classes" default.

#@markdown Classes to ignore when counting detections.
exclude_classes = ['unknown']  #@param

#@markdown The `include_classes` list is ignored if empty.
#@markdown If non-empty, only scores for these classes will be written.
include_classes = []  #@param

embeddings_ds = project_state.create_embeddings_dataset(
    shuffle_files=True)
classify.write_inference_csv(
    embeddings_ds=embeddings_ds,
    model=model,
    labels=merged.labels,
    output_filepath=output_filepath,
    threshold=class_thresholds,
    embedding_hop_size_s=bootstrap_config.embedding_hop_size_s,
    include_classes=include_classes,
    exclude_classes=exclude_classes)
