# Call density estimation

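This notebook provides tools for estimating call density with an agile classifier (e.g. one trained in `02_agile_modeling.ipynb`). The classifier's scores are split into quantile bins, and a sample of windows from each bin is validated by hand. The validated samples are then used to estimate call density and ROC-AUC.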

See the "All Thresholds Barred" paper for our call density methodology: https://arxiv.org/abs/2402.15360

In [None]:
# @title Installation {vertical-output: true}

# @markdown You will likely need to work with `01_embed_audio.ipynb` and/or
# @markdown `02_agile_modeling.ipynb` before working with this notebook.
# @markdown
# @markdown Run this notebook in Google Colab by following
# @markdown [this link](https://github.com/google-research/perch-hoplite/blob/main/perch_hoplite/agile/03_call_density.ipynb).
# @markdown
# @markdown Run this cell to install the project dependencies.

!pip install git+https://github.com/google-research/perch-hoplite.git


In [None]:
# @title Imports

from matplotlib import pyplot as plt
from ml_collections import config_dict
import numpy as np

from perch_hoplite.agile import audio_loader
from perch_hoplite.agile import call_density
from perch_hoplite.agile import classifier
from perch_hoplite.agile import embed
from perch_hoplite.agile import embedding_display
from perch_hoplite.agile import source_info
from perch_hoplite.db import interface
from perch_hoplite.db import sqlite_usearch_impl

In [None]:
# @title Connect to database {vertical-output: true}

# @markdown Location of database containing audio embeddings:
db_path = "/tmp/hoplite"  # @param {type: "string"}
db = sqlite_usearch_impl.SQLiteUSearchDB.create(db_path)

# List every label stored in the database with its 0-based index, laid out
# five per row (tab-separated within a row).
all_classes = db.get_all_labels()
print("Existing db classes:\n")
for position, label in enumerate(all_classes, start=1):
  separator = "\n" if position % 5 == 0 else "\t"
  print(f"{position - 1:3d}: {label}", end=separator)

In [None]:
# @title Load agile classifier {vertical-output: true}

# @markdown Location of agile classifier:
agile_classifier_path = "/tmp/hoplite/agile_classifier_v2.pt"  # @param {type: "string"}
agile_classifier = classifier.LinearClassifier.load(agile_classifier_path)

# Rebuild the audio source description stored in the database metadata.
embed_config = db.get_metadata("audio_sources")
audio_sources = source_info.AudioSources.from_config_dict(embed_config)

# Window length (default 5.0 s when absent) and sample rate come from the
# embedding model config that is stored alongside the classifier.
model_config = agile_classifier.embedding_model_config.model_config
window_size_s = model_config.get("window_size_s", 5.0)
sample_rate = model_config.sample_rate

# Loader used later to fetch raw audio for the windows shown for validation.
audio_filepath_loader = audio_loader.make_filepath_loader(
    audio_sources=audio_sources,
    window_size_s=window_size_s,
    sample_rate_hz=sample_rate,
)

In [None]:
# @title Set up new call density study {vertical-output: true}

# @markdown Pick a study name:
study_name = "all_thresholds_barred"  # @param {type: "string"}

# @markdown Pick a target class:
target_class = "amerob"  # @param {type: "string"}

# The classifier's logit columns follow the classifier's own class list. That
# list need not match the database label list (`all_classes`) in content or
# order, so the column index must be looked up in the classifier's classes.
if target_class not in agile_classifier.classes:
  raise ValueError(
      f"Target class {target_class!r} is not predicted by the classifier; "
      f"available classes: {list(agile_classifier.classes)}"
  )
target_class_idx = agile_classifier.classes.index(target_class)

# @markdown Pick some quantile bounds for validation. Should be an ordered list,
# @markdown beginning with 0.0 and ending with 1.0.
quantile_bounds = [0.0, 0.5, 0.75, 0.875, 1.0]  # @param

# Enforce the requirement stated in the form above so that a typo in the
# bounds fails here instead of silently producing a malformed study.
if (
    len(quantile_bounds) < 2
    or quantile_bounds[0] != 0.0
    or quantile_bounds[-1] != 1.0
    or any(lo >= hi for lo, hi in zip(quantile_bounds, quantile_bounds[1:]))
):
  raise ValueError(
      "quantile_bounds must be strictly increasing, start at 0.0 and end at "
      f"1.0; got {quantile_bounds}"
  )

# @markdown Pick the number of samples to validate per bin:
samples_per_bin = 25  # @param

# @markdown Pick a random seed for shuffling:
random_seed = 42  # @param

# Select and shuffle window ids. Every filter is disabled, so all windows in
# the database are selected, to "simulate" classifier mistakes.
deployments_filter = recordings_filter = windows_filter = annotations_filter = None
window_ids = db.match_window_ids(
    deployments_filter=deployments_filter,
    recordings_filter=recordings_filter,
    windows_filter=windows_filter,
    annotations_filter=annotations_filter,
)
# Shuffle in place with a seeded generator so the order is reproducible.
rng = np.random.default_rng(random_seed)
rng.shuffle(window_ids)

# @markdown Pick an optional sample size. If set to `0`, all matching windows are binned.
sample_size = 0  # @param {type: "number"}

# Keep only the first `sample_size` shuffled ids; non-positive values keep all.
window_ids = window_ids[:sample_size] if sample_size > 0 else window_ids

# Score the selected windows with the agile classifier, 256 at a time, and keep
# only the target-class column of each batch of logits.
logits = []
for id_batch in embed.batched(window_ids, 256):
  batch_logits = agile_classifier(db.get_embeddings_batch(id_batch))
  logits.extend(batch_logits[..., target_class_idx])

# Get existing annotations for validation examples from the database.
# For each window, query target-class annotations on the same recording whose
# offsets match the window's offsets under the `approx` filter. If more than
# one annotation matches, the last one is kept; windows without a match get
# `None`, so `annotations` stays aligned with `window_ids`.
annotations = []
for window_id in window_ids:
  window = db.get_window(window_id)
  matching_annotations = db.get_all_annotations(
      filter=config_dict.create(
          eq=dict(recording_id=window.recording_id, label=target_class),
          approx=dict(offsets=window.offsets),
      )
  )
  annotations.append(matching_annotations[-1] if matching_annotations else None)

# Bundle everything into a call density study object; it can be serialized to
# the database and reloaded later to keep iterating on the validation.
study = call_density.CallDensityConfig(
    # Study identity and the classifier under evaluation.
    study_name=study_name,
    target_class=target_class,
    classifier=agile_classifier.to_config_dict(),
    # Binning / sampling setup.
    quantile_bounds=quantile_bounds,
    samples_per_bin=samples_per_bin,
    # Per-window data: ids, target-class logits and any existing annotation.
    window_ids=window_ids,
    logits=logits,
    annotations=annotations,
    # Filters that were used to select the windows, recorded with the study.
    deployments_filter=deployments_filter,
    recordings_filter=recordings_filter,
    windows_filter=windows_filter,
    annotations_filter=annotations_filter,
)

# Plot the (density-normalized) logit histogram and draw a dotted vertical
# line at each of the study's value bounds.
fig, ax = plt.subplots()
ax.set_title(
    f"Logits distribution ({study.target_class})"
    f"\nquantile bounds: {study.quantile_bounds}"
    f"\nvalue bounds: {[round(float(v), 2) for v in study.value_bounds]}"
)
ax.set_xlabel("logits")
hist_heights = ax.hist(logits, bins=100, density=True)[0]
line_top = np.max(hist_heights)
for bound_value in study.value_bounds:
  ax.plot([bound_value, bound_value], [0.0, line_top], "k:", alpha=0.75)
plt.show()

# Bar chart of how many entries each of the study's bins holds.
fig, ax = plt.subplots()
ax.set_title(f"Logits distribution by bin ({study.target_class})")
ax.set_xlabel("bins")
ax.set_ylabel("bin sizes")
bin_names, bin_sizes = [], []
for bin_key, bin_members in study.bins_dict.items():
  bin_names.append(str(bin_key))
  bin_sizes.append(len(bin_members))
ax.bar(bin_names, bin_sizes)
plt.show()

In [None]:
# @title Save study in the database {vertical-output: true}

# All studies live under one metadata entry, keyed by "<study>/<class>".
# Start a fresh collection if the database does not have one yet.
study_db_key = f"{study_name}/{target_class}"
try:
  call_density_configs = db.get_metadata("call_density_configs")
except KeyError:
  call_density_configs = config_dict.create()
call_density_configs[study_db_key] = study.to_config_dict()

db.insert_metadata("call_density_configs", call_density_configs)
db.commit()
print("Call density study saved to database:", study_db_key)

In [None]:
# @title Display results for validation {vertical-output: true}

# Turn the study's bins into search results and render them for annotation.
# `audio_filepath_loader` loads audio at `sample_rate` (taken from the
# classifier's embedding model config), so the display is given the same rate
# instead of a hard-coded 32 kHz.
results = study.convert_bins_to_search_results()
display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(
    results,
    db,
    sample_rate_hz=sample_rate,
    frame_rate=100,
    audio_loader=audio_filepath_loader,
)
display_results.display(positive_labels=[target_class])

In [None]:
# @title Save validation data {vertical-output: true}

# @markdown Choose an annotator name:
annotator_id = "linnaeus"  # @param {type: "string"}

# Collect the labels entered in the display above; uncertain labels are
# harvested as well (`skip_uncertain=False`).
annotations_dict = display_results.harvest_annotated_windows(annotator_id, skip_uncertain=False)

# Save annotations in the database (duplicates handled with "skip").
for ann_list in annotations_dict.values():
  for ann in ann_list:
    db.insert_annotation(
        recording_id=ann.recording_id,
        offsets=ann.offsets,
        label=ann.label,
        label_type=ann.label_type,
        provenance=ann.provenance,
        handle_duplicates="skip",
    )
    print("Saved new annotation:", ann)

# Update validation examples from annotations, then store the updated study.
# The stored studies are re-read here rather than reusing variables from the
# "Save study" cell, so this cell also works after a kernel restart or when
# that cell was not run.
study.update_from_annotated_windows(annotations_dict)
study_db_key = f"{study_name}/{target_class}"
try:
  call_density_configs = db.get_metadata("call_density_configs")
except KeyError:
  call_density_configs = config_dict.create()
call_density_configs[study_db_key] = study.to_config_dict()
db.insert_metadata("call_density_configs", call_density_configs)

# Commit db changes.
db.commit()

In [None]:
# @title Estimate call density and ROC-AUC {vertical-output: true}

# Only windows labeled POSITIVE or NEGATIVE are used for the estimates.
validation_examples = study.select_validation_examples(
    label_types=[interface.LabelType.POSITIVE, interface.LabelType.NEGATIVE]
)
density_ev, density_samples = call_density.estimate_call_density(
    validation_examples
)

# 5th and 95th percentiles of the sampled densities (a 90% interval).
low, high = np.quantile(density_samples, [0.05, 0.95])

# Plot call density estimate. `hist` returns (bin heights, bin edges, patches);
# the tallest bin sets the height of the vertical marker lines.
fig, ax = plt.subplots(figsize=(10, 5))
bin_heights, _, _ = ax.hist(density_samples, density=True, bins=25, alpha=0.25)
max_height = np.max(bin_heights)
ax.plot([density_ev, density_ev], [0.0, max_height], "k:", alpha=0.75,
        label="density_ev")
ax.plot([low, low], [0.0, max_height], "g", alpha=0.75, label="low conf")
ax.plot([high, high], [0.0, max_height], "g", alpha=0.75, label="high conf")

ax.set_xlim(0.0, 1.0)
ax.set_xlabel("Call Rate (q)")
ax.set_ylabel("P(q)")
ax.set_title(f"Call Density Estimation ({target_class})")
ax.legend()
plt.show()

print(f"EV Call Density: {density_ev:.4f}")
print(f"(Low/EV/High) Call Density Estimate: ({low:5.4f} / {density_ev:5.4f} / {high:5.4f})")

roc_auc_estimate = call_density.estimate_roc_auc(validation_examples)
print(f"Estimated ROC-AUC : {roc_auc_estimate:5.4f}")
