# Overview

This is an end-to-end demo of using Perch Hoplite to create a custom classifier from pre-trained embeddings and to search for a particular sound, focused on marine passive acoustic data and whale vocalization classification. For this demo, we use publicly available data from NOAA and NCEI's [Passive Acoustic Monitoring Data Archive](https://console.cloud.google.com/marketplace/details/noaa-public/passive_acoustic_monitoring?pli=1) hosted on Google Cloud. This tutorial uses the newer perch-hoplite infrastructure; an earlier version with additional explanations, based on the prior (now out-of-date) codebase, can be found in the NeurIPS [2023 Climate Change AI Tutorial](https://colab.research.google.com/github/climatechange-ai-tutorials/bioacoustic-monitoring/blob/main/Agile_Modeling_for_Bioacoustic_Monitoring.ipynb).


## Note on Agile Modeling process
This notebook combines the two agile modeling notebooks in perch-hoplite into a single, more detailed end-to-end walkthrough. The first notebook, [1_embed_audio_v2.ipynb](https://github.com/google-research/perch-hoplite/blob/main/perch_hoplite/agile/1_embed_audio_v2.ipynb), is the first step of the agile modeling process: it uses a pre-trained model to create a database of embeddings for the audio we wish to search. The second notebook ([2_agile_modeling_v2.ipynb](https://github.com/google-research/perch-hoplite/blob/main/perch_hoplite/agile/2_agile_modeling_v2.ipynb)) then searches that database and builds a classifier for a particular call type. You can read more about the agile modeling process in [this paper](https://arxiv.org/abs/2505.03071).



## Set-up notes
For this demo, we recommend using a hosted Google Colab runtime and saving the output to Google Drive. To load this notebook in Google Colab, go to https://colab.research.google.com/, open the "GitHub" tab under "Open Notebook", and paste the URL of this notebook there.

To connect to a hosted runtime, in the top right corner, click "Connect."
Then select "Connect to a hosted runtime." You should be automatically connected to a colab runtime. We recommend that you also select a GPU runtime - if you click "Change runtime type" you can confirm or switch the runtime type.


## [Optional] perch-hoplite installation for hosted runtimes

If you have not already installed perch-hoplite (particularly if you are using a hosted Colab runtime), install it from the GitHub source to get the most recent version. After installation, you will need to restart your runtime before running anything else: in the top menu, select "Runtime", then "Restart Session".

In [None]:
#@title Only run this code if you need to install perch-hoplite
#@markdown You will likely be asked to restart your runtime, but after restarting, don't need to rerun this block.
!pip install git+https://github.com/google-research/perch-hoplite.git

In [None]:
# @title Imports
from etils import epath
from IPython.display import display
import ipywidgets as widgets
import numpy as np
from perch_hoplite.agile import colab_utils
from perch_hoplite.agile import embed
from perch_hoplite.agile import source_info
from perch_hoplite.db import brutalism
from perch_hoplite.db import interface

## For saving data in this example, we're going to use Google Drive.

This example assumes you are running Colab on a hosted runtime: the code runs on a cloud service, not on your local machine, so it cannot access your local file system. The raw audio files come from a public URL, so we need a place to save the embeddings database and any results. You can save these files to a temporary folder (e.g. '/tmp/...'), but if your runtime crashes, you may lose them.



In [None]:
#@title Mount Google Drive for saving data
#@markdown When you run this, you will be prompted to authenticate to grant permissions for colab to save to your Google Drive.

#@markdown After you grant permissions, you will see a code that you will need to copy and paste in the form output that will be generated below.
import os
from google.colab import drive

drive.mount('/content/drive')

In [None]:
#@title Create a new folder in your Google Drive (if it doesn't already exist).
base_dir = '/content/drive/My Drive/'
#@markdown Name of your new folder in Drive
new_folder_name = 'noaa_demo' #@param

drive_output_directory = os.path.join(base_dir, new_folder_name)

try:
  if not os.path.exists(drive_output_directory):
    os.makedirs(drive_output_directory, exist_ok=True)
    print(f'Directory {drive_output_directory} created successfully.')
  else:
    print(f'Directory {drive_output_directory} already exists.')
except OSError as e:
  print('Error:', e)

# Embed the audio data

## Example data - NOAA PIFSC Saipan selection for Bryde's whale biotwangs
For this example, we load audio data from the [Passive Acoustic Data](https://www.ncei.noaa.gov/products/passive-acoustic-data) archive by NOAA and NCEI; the files are stored in a [Google Cloud bucket](https://console.cloud.google.com/marketplace/details/noaa-public/passive_acoustic_monitoring?pli=1). Note that the files are stored as .flac files, which need to be decompressed to access the full metadata of the original .wav files and recover correct timestamps (see the [README](https://storage.googleapis.com/noaa-passive-bioacoustic/pifsc/README.md)). Because these files are very large (generally over 8 hours of audio and over 1 GB each), we shard them for embedding. The files are all accessible via a public URL; if you apply this to a different set of data, you'll need to specify the path where those data are stored.
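To give a sense of scale, the sketch below works out how many shards and embedding windows a single long file produces, and roughly how much memory the raw audio takes with and without sharding. It assumes an 8-hour recording at 10 kHz, the 75 s shard length used in the configuration below, and 5 s non-overlapping embedding windows; the helper names are hypothetical and not part of perch-hoplite.

```python
import math


def shard_count(duration_s: float, shard_len_s: float) -> int:
  """Number of shards needed to cover a file (the last shard may be shorter)."""
  return math.ceil(duration_s / shard_len_s)


def windows_per_shard(shard_len_s: float, window_s: float, hop_s: float) -> int:
  """Number of complete embedding windows that fit inside one shard."""
  if shard_len_s < window_s:
    return 0
  return 1 + int((shard_len_s - window_s) // hop_s)


def float32_megabytes(duration_s: float, sample_rate_hz: int) -> float:
  """Approximate memory for mono float32 audio of the given duration."""
  return duration_s * sample_rate_hz * 4 / 1e6


duration_s = 8 * 3600     # assumed: an 8-hour recording
shard_len_s = 75.0        # shard_length_in_seconds in the Configuration cell
sample_rate_hz = 10_000   # sample rate of the NOAA audio used here

n_shards = shard_count(duration_s, shard_len_s)
# Assumed: 5 s windows with no overlap; check your model's actual window and hop.
n_windows = n_shards * windows_per_shard(shard_len_s, window_s=5.0, hop_s=5.0)
print(n_shards, n_windows)  # 384 5760
print(f'whole file: {float32_megabytes(duration_s, sample_rate_hz):.0f} MB, '
      f'one shard: {float32_megabytes(shard_len_s, sample_rate_hz):.0f} MB')
```

Holding one 3 MB shard in memory instead of a roughly 1.1 GB decoded file is the main reason sharding keeps system and GPU memory usage manageable.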

The file chosen for this example was annotated specifically for the Bryde's whale biotwang and contains a large number of detections of this sound (see [paper](https://www.frontiersin.org/journals/marine-science/articles/10.3389/fmars.2024.1394695/full)). You can repeat this process or swap in a different file to try a new selection of audio.

## Starting model choices
Because the [multispecies whale model](https://www.kaggle.com/models/google/multispecies-whale) was trained to detect Bryde's whale biotwangs (see [blog post for more details](https://research.google/blog/whistles-songs-boings-and-biotwangs-recognizing-whale-vocalizations-with-ai)), it is a good starting point for a pre-trained model. You can also select other models: [SurfPerch](https://www.kaggle.com/models/google/surfperch) is another solid option, and it is the `model_choice` set by default in the configuration below.

In [None]:
# @title Configuration { vertical-output: true }

# @markdown Configure the raw dataset and output location(s).  The format is a mapping from
# @markdown a dataset_name to a (base_path, fileglob) pair.  Note that the file
# @markdown globs are case sensitive.  The dataset name can be anything you want.
#
# @markdown This structure allows you to move your data around without having to
# @markdown re-embed the dataset.  The generated embedding database will be
# @markdown placed in the base path. This allows you to simply swap out
# @markdown the base path here if you ever move your dataset.

# @markdown By default we only process one dataset at a time.  Re-run this entire portion [Embed] of the notebook
# @markdown once per dataset.

# @markdown For example, we might set dataset_base_path to '/home/me/myproject',
# @markdown and use the glob '\*/\*.wav' if all of the audio files have filepaths
# @markdown like '/home/me/myproject/site_XYZ/audio_ABC.wav' (e.g. audio files are contained in subfolders of the base directory).


# @markdown 1. Create a unique name for the database that will store the embeddings for the target data.
# @markdown For this example, we use the name of the large audio file, but you can use a different name here.
dataset_name = 'Saipan_A_06_151006_091215'  # @param {type:'string'}
# @markdown 2. Input the path of the folder that contains the input audio files.
dataset_base_path = 'gs://noaa-passive-bioacoustic/pifsc/audio/pipan/saipan/pipan_saipan_06/audio'  #@param {type:'string'}
# @markdown 3. Input the file pattern for the audio files within that folder that you want to embed. Some examples for how to input:
# @markdown - All files in the base directory of a specific type (not subdirectories): e.g. `*.wav` (or `*.flac` etc) will generate embeddings for all .wav files (or whichever format) in the dataset_base_path
# @markdown - All files in one level of subdirectories within the base directory: `*/*.flac` will generate embeddings for all .flac files in the immediate subdirectories of dataset_base_path
# @markdown - Single file: `myfile.wav` will only embed the audio from that specific file.
dataset_fileglob = 'Saipan_A_06_151006_091215.df20.*.flac'  # @param {type:'string'}

# @markdown 4. [Optional] If saving the embeddings database to a new directory, specify here.
# @markdown Otherwise, leave blank - by default the embeddings database output will be saved within
# @markdown dataset_base_path where the audio is located. You do not need to specify db_path unless you want to maintain multiple
# @markdown distinct embedding databases, or if you would like to save the output
# @markdown in a different folder. If your input audio data is accessed
# @markdown from a public URL, we recommend specifying a separate output directory here.
db_subdir = '/agile_Saipan_A_06_151006_091215'  # @param {type:'string'}
# Treat an empty field or the literal string 'None' as "no separate output directory".
if db_subdir and db_subdir != 'None':
  db_path = drive_output_directory + db_subdir
else:
  db_path = None


# @markdown 5. Choose a supported model to generate embeddings: `perch_8` or `birdnet_V2.3` are most common
# @markdown for birds. Other choices include `surfperch` for coral reefs or
# @markdown `multispecies_whale` for marine mammals.
model_choice = 'surfperch'  #@param['perch_8', 'humpback', 'multispecies_whale', 'surfperch', 'birdnet_V2.3']

# @markdown 6. [Optional] Shard the audio for embeddings. File sharding automatically splits audio files into smaller chunks
# @markdown for creating embeddings. This limits both system and GPU memory usage,
# @markdown especially useful when working with long files (>1 hour).
use_file_sharding = True  # @param {type:'boolean'}
# @markdown If you want to change the length in seconds for the shards, specify here.
shard_length_in_seconds = 75  # @param {type:'number'}

# @markdown We also need to specify the target sample rate: -2 uses the sample rate of the model,
# @markdown -1 keeps the native sample rate of the source audio, and any other number >0
# @markdown resamples to that rate.
target_sample_rate_hz = -1  # @param {type:'number'}

audio_glob = source_info.AudioSourceConfig(
    dataset_name=dataset_name,
    base_path=dataset_base_path,
    file_glob=dataset_fileglob,
    min_audio_len_s=1.0,
    target_sample_rate_hz=target_sample_rate_hz,
    shard_len_s=float(shard_length_in_seconds) if use_file_sharding else None,
)

configs = colab_utils.load_configs(
    source_info.AudioSources((audio_glob,)),
    db_path,
    model_config_key=model_choice,
    db_key='sqlite_usearch',
)
configs
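Two settings in the Configuration cell are easy to get wrong: the case-sensitive file glob and the sentinel values of `target_sample_rate_hz`. The sketch below illustrates both under stated assumptions. It uses Python's `fnmatch.fnmatchcase` as a stand-in for the glob matching (unlike a path glob, `fnmatch`'s `*` also matches `/`, so it is only a fair stand-in for patterns without subdirectories), and `resolve_target_sample_rate` is a hypothetical helper that restates the documented meaning of -2, -1, and positive values; it is not perch-hoplite's code.

```python
from fnmatch import fnmatchcase


def resolve_target_sample_rate(target_sample_rate_hz: int,
                               model_rate_hz: int,
                               source_rate_hz: int) -> int:
  """Hypothetical helper: interprets target_sample_rate_hz as documented above."""
  if target_sample_rate_hz == -2:
    return model_rate_hz           # use the embedding model's sample rate
  if target_sample_rate_hz == -1:
    return source_rate_hz          # keep the native rate of the source audio
  if target_sample_rate_hz > 0:
    return target_sample_rate_hz   # resample to an explicit rate
  raise ValueError(f'Unsupported target_sample_rate_hz: {target_sample_rate_hz}')


pattern = 'Saipan_A_06_151006_091215.df20.*.flac'  # dataset_fileglob above
names = [
    'Saipan_A_06_151006_091215.df20.x.flac',
    'Saipan_A_06_151006_091215.df20.x.FLAC',  # different case: not matched
    'Saipan_A_06_151007_163630.df20.x.flac',  # different recording: not matched
]
matched = [name for name in names if fnmatchcase(name, pattern)]
print(matched)

# Illustrative rates only, not taken from any specific model.
print(resolve_target_sample_rate(-1, model_rate_hz=32_000, source_rate_hz=10_000))
```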

In [None]:
#@title Initialize the hoplite database (DB) { vertical-output: true }
db = configs.db_config.load_db()
num_embeddings = db.count_embeddings()

print('Initialized DB located at ', configs.db_config.db_config.db_path)

def drop_and_reload_db(_) -> interface.HopliteDBInterface:
  # Rebind the notebook-level `db` so later cells use the fresh database.
  global db
  db_path = epath.Path(configs.db_config.db_config.db_path)
  for fp in db_path.glob('hoplite.sqlite*'):
    fp.unlink()
  (db_path / 'usearch.index').unlink()
  print('\n Deleted previous db at: ', configs.db_config.db_config.db_path)
  db = configs.db_config.load_db()
  return db

#@markdown If `drop_existing_db` set to True, when the database already exists and contains embeddings,
#@markdown then those existing embeddings will be erased. You will be prompted to confirm you wish to delete those existing
#@markdown embeddings. If you want to keep existing embeddings in the database, then set to False, which will append the new
#@markdown embeddings to the database.
drop_existing_db = False  #@param {type:'boolean'}

if num_embeddings > 0 and drop_existing_db:
  print('Existing DB contains datasets: ', db.get_dataset_names())
  print('num embeddings: ', num_embeddings)
  print('\n\nClick the button below to confirm you really want to drop the database at ')
  print(f'{configs.db_config.db_config.db_path}\n')
  print(f'This will permanently delete all {num_embeddings} embeddings from the existing database.\n')
  print('If you do NOT want to delete this data, set `drop_existing_db` above to `False` and re-run this cell.\n')

  button = widgets.Button(description='Delete database?')
  button.on_click(drop_and_reload_db)
  display(button)

In [None]:
#@title Run the embedding { vertical-output: true }
#@markdown This may take approximately 15 minutes to run.

print(f'Embedding dataset: {audio_glob.dataset_name}')

worker = embed.EmbedWorker(
    audio_sources=configs.audio_sources_config,
    db=db,
    model_config=configs.model_config)

worker.process_all(target_dataset_name=audio_glob.dataset_name)

print('\n\nEmbedding complete, total embeddings: ', db.count_embeddings())

If you already have a database saved from an earlier run, you may get the following error when the above cell tries to add more embeddings to it:

```
AssertionError: The configured model key does not match the model key that is already in the DB.
```
For a given database name and save location, the same embedding model and settings must be used. If you want to use a different embedding model, create a new database with a different save location in the "Configuration" cell.


In [None]:
#@title Per dataset statistics { vertical-output: true }
#@markdown This tells us how many unique segments are embedded in the database.

for dataset in db.get_dataset_names():
  print(f'\nDataset \'{dataset}\':')
  print('\tnum embeddings: ', db.get_embeddings_by_source(dataset, source_id=None).shape[0])

In [None]:
#@title Show example embedding search
#@markdown As an example (and to show that the embedding process worked), this
#@markdown selects a single embedding from the database and outputs the embedding ids of the
#@markdown top-K (k = 128) nearest neighbors in the database.

q = db.get_embedding(db.get_one_embedding_id())
%time results, scores = brutalism.brute_search(worker.db, query_embedding=q, search_list_size=128, score_fn=np.dot)
print([int(r.embedding_id) for r in results])
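`brutalism.brute_search` performs an exact search: it scores the query against every embedding in the database and keeps the best matches. A minimal NumPy sketch of that idea with dot-product scoring follows; the random unit-normalized vectors are stand-ins for real embeddings, and this is an illustration rather than the brutalism implementation.

```python
import numpy as np


def brute_force_top_k(embeddings: np.ndarray, query: np.ndarray, k: int):
  """Exact top-k search: score the query against every stored embedding."""
  scores = embeddings @ query                   # dot-product score per row
  k = min(k, len(scores))
  top = np.argpartition(-scores, k - 1)[:k]     # the k best, in arbitrary order
  top = top[np.argsort(-scores[top])]           # sort them best-first
  return top, scores[top]


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(1000, 8)).astype(np.float32)
# Unit-normalize so that each vector's best dot-product match is itself.
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
query = embeddings[42]
ids, top_scores = brute_force_top_k(embeddings, query, k=5)
print(ids, top_scores)
```

Exact search costs one dot product per stored embedding, which is why the notebook also offers approximate nearest-neighbor search for larger databases.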

# Agile Modeling - Search and classify

For this example, we are going to search for "Biotwang" calls produced by Bryde's whales. The audio selection above was chosen because it contains a large number of biotwang detections.

In [None]:
#@title Imports
import os

from matplotlib import pyplot as plt
import numpy as np

from perch_hoplite.agile import audio_loader
from perch_hoplite.agile import classifier
from perch_hoplite.agile import classifier_data
from perch_hoplite.agile import embedding_display
from perch_hoplite.agile import source_info
from perch_hoplite.db import brutalism
from perch_hoplite.db import score_functions
from perch_hoplite.db import search_results
from perch_hoplite.db import sqlite_usearch_impl
from perch_hoplite.zoo import model_configs

In [None]:
#@title Load model and connect to database. { vertical-output: true }

#@markdown Location of database containing audio embeddings - if you are running this
#@markdown in the same session as the embeddings (e.g. you haven't had to restart your runtime),
#@markdown then you can leave this blank and it will fill in with the same db_path
#@markdown as the embeddings defined above. However, you can fill out the path
#@markdown if you are running this in a new session or want to load a different saved database.
load_db_path = ''  #@param {type:'string'}
# An empty field falls back to the db_path defined in the Configuration cell.
if not load_db_path:
  load_db_path = db_path
#@markdown Identifier (e.g., the annotator's name or unique ID) to attach to labels produced during validation.
annotator_id = 'laurenharrell'  #@param {type:'string'}
#@markdown Sample rate for loading audio. For the NOAA raw data this is 10_000;
#@markdown note that the model's sample rate may differ from this rate.
#@markdown If left blank (None), the model's sample rate is used.
audio_loader_sample_rate_hz = 10_000  #@param {type:'number'}

db = sqlite_usearch_impl.SQLiteUsearchDB.create(load_db_path)
db_model_config = db.get_metadata('model_config')
embed_config = db.get_metadata('audio_sources')
model_class = model_configs.get_model_class(db_model_config.model_key)
embedding_model = model_class.from_config(db_model_config.model_config)
audio_sources = source_info.AudioSources.from_config_dict(embed_config)

if audio_loader_sample_rate_hz is None:
  audio_loader_sample_rate_hz = embedding_model.sample_rate

if hasattr(embedding_model, 'window_size_s'):
  window_size_s = embedding_model.window_size_s
else:
  window_size_s = 5.0
audio_filepath_loader = audio_loader.make_filepath_loader(
    audio_sources=audio_sources,
    window_size_s=window_size_s,
    sample_rate_hz=audio_loader_sample_rate_hz,
)
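The audio loader maps an embedding's time offset back to a span of samples in the source file, which is why `audio_loader_sample_rate_hz` and `window_size_s` matter: the same offset lands on different sample indices at different rates. A minimal NumPy sketch of that mapping follows. `window_bounds` and `load_window` are hypothetical helpers, and zero-padding a window that runs past the end of the file is an assumption of this sketch, not necessarily what `audio_loader.make_filepath_loader` does.

```python
import numpy as np


def window_bounds(offset_s: float, window_size_s: float, sample_rate_hz: int):
  """Convert a time window in seconds into [start, end) sample indices."""
  start = int(round(offset_s * sample_rate_hz))
  end = start + int(round(window_size_s * sample_rate_hz))
  return start, end


def load_window(waveform: np.ndarray, offset_s: float, window_size_s: float,
                sample_rate_hz: int) -> np.ndarray:
  """Slice one window from a waveform, zero-padding past the end of the file."""
  start, end = window_bounds(offset_s, window_size_s, sample_rate_hz)
  window = waveform[start:end]
  if len(window) < end - start:
    window = np.pad(window, (0, end - start - len(window)))
  return window


sample_rate_hz = 10_000                               # audio_loader_sample_rate_hz
waveform = np.zeros(60 * sample_rate_hz, np.float32)  # stand-in: 60 s of silence
print(window_bounds(12.5, 5.0, sample_rate_hz))       # (125000, 175000)
print(len(load_window(waveform, 58.0, 5.0, sample_rate_hz)))  # 50000, last 3 s padded
```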

In [None]:
#@title Load query audio. { vertical-output: true }

#@markdown The `query_uri` can be a URL, filepath, or Xeno-Canto ID
#@markdown (like `xc777802`, containing an Eastern Whipbird (`easwhi1`)).
#@markdown We have a few pre-selected examples of Bryde's whale biotwangs in a
#@markdown public folder on Google Cloud; to use a different example, replace
#@markdown the number 3 with any digit from 1 to 5.
query_uri = 'gs://bioacoustics-www1/multispecies_blog_media/Be_example3.wav'  #@param {type:'string'}
query_label = 'Be_biotwang'  #@param {type:'string'}


query = embedding_display.QueryDisplay(
    uri=query_uri, offset_s=0.0, window_size_s=5.0)
_ = query.display_interactive()

# Search

In [None]:
#@title Embed the Query and Search. { vertical-output: true }

#@markdown Number of results to find and display.
num_results = 100  #@param
query_embedding = embedding_model.embed(
    query.get_audio_window()).embeddings[0, 0]

#@markdown If checked, search for examples
#@markdown near a particular target score.
target_sampling = False  #@param {type: 'boolean'}

#@markdown When target sampling, target this score.
target_score = -1.0  #@param
if not target_sampling:
  target_score = None

#@markdown If True, search the full DB. Otherwise, use approximate
#@markdown nearest-neighbor search.
exact_search = True  #@param {type: 'boolean'}

score_function_name = 'dot' #@param['neg_euclidean','dot','cos']

if exact_search:
  score_fn = score_functions.get_score_fn(score_function_name, target_score=target_score)
  results, all_scores = brutalism.threaded_brute_search(
      db, query_embedding, num_results, score_fn=score_fn)
  # TODO(tomdenton): Better histogram when target sampling.
  _ = plt.hist(all_scores, bins=100)
  hit_scores = [r.sort_score for r in results.search_results]
  plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',
              color='r', alpha=0.5)
else:
  ann_matches = db.ui.search(query_embedding, count=num_results)
  results = search_results.TopKSearchResults(top_k=num_results)
  for k, d in zip(ann_matches.keys, ann_matches.distances):
    results.update(search_results.SearchResult(k, d))

#@markdown Note: the query results will always be ordered where higher values indicate a stronger match for the target query/target score.
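The three `score_function_name` options and the target-sampling switch can be sketched in NumPy as follows. The formulas for `dot`, `cos`, and `neg_euclidean` are the standard definitions; the target-sampling transform, which re-scores each result as `-|score - target_score|`, is an assumption consistent with the note above that higher values always indicate a stronger match for the target query or target score. None of this is perch-hoplite's `score_functions` code.

```python
import numpy as np


def dot(x, q):
  return x @ q


def cos(x, q):
  return (x @ q) / (np.linalg.norm(x, axis=-1) * np.linalg.norm(q))


def neg_euclidean(x, q):
  return -np.linalg.norm(x - q, axis=-1)


def with_target(score_fn, target_score=None):
  """Assumed target sampling: rank by closeness to target_score instead."""
  if target_score is None:
    return score_fn
  return lambda x, q: -np.abs(score_fn(x, q) - target_score)


embeddings = np.array([[1.0, 0.0], [0.1, 0.0], [-2.0, 0.0]])
query = np.array([1.0, 0.0])

plain = dot(embeddings, query)                       # [1.0, 0.1, -2.0]
targeted = with_target(dot, 0.0)(embeddings, query)  # [-1.0, -0.1, -2.0]
print(np.argsort(-plain).tolist())     # [0, 1, 2]: highest dot product first
print(np.argsort(-targeted).tolist())  # [1, 0, 2]: score closest to 0.0 first
```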

In [None]:
#@title Display Results. { vertical-output: true }

#@markdown Click a result's button once to mark it as a positive label (the button turns green);
#@markdown click it a second time to mark it as a negative label (the button turns orange).
#@markdown To undo a label, click a third time: the button returns to its default
#@markdown background and the result is treated as unlabeled.

display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(
    results, db, sample_rate_hz=audio_loader_sample_rate_hz, frame_rate=100,
    audio_loader=audio_filepath_loader)
display_results.display(positive_labels=[query_label])

In [None]:
#@title Save data labels. { vertical-output: true }
#@markdown Counts new labels added to the database.

prev_lbls, new_lbls = 0, 0
for lbl in display_results.harvest_labels(annotator_id):
  check = db.insert_label(lbl, skip_duplicates=True)
  new_lbls += check
  prev_lbls += (1 - check)
print('\nNew labels added: ', new_lbls)
print('\nLabeled query results that already existed: ', prev_lbls)

In [None]:
#@title Check how many labels of each class exist in the data
print('\nTotal positive labels per class: ', db.get_class_counts())
print('\nTotal negative labels per class: ', db.get_class_counts(label_type = interface.LabelType.NEGATIVE))

## If you don't have enough examples to start a classifier, try a different query!
If you have fewer than ~10-20 examples of your target class after saving the labels you just created, try altering your query: use a new target example, or experiment with target sampling (a new target score) or exact search. You can repeat the above search steps as many times as you want.

# Classify
Once you think you have sufficient examples to train a model (hopefully at least 5), you can try training a classifier on the data you embedded. Note that if you don't have any explicit negative labels, the model treats unlabeled data as "weak" negatives, and the computed metrics may be unreliable. This won't affect your ability to keep iterating on the model, only the measurement of its performance.
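To make the role of `weak_neg_weight` concrete, here is a minimal NumPy sketch of a weighted binary cross-entropy in which explicitly labeled examples get weight 1 and unlabeled examples are treated as negatives with a small weight. It illustrates the idea of weak negatives under that assumption; it is not perch-hoplite's training loss, and the logits are made up.

```python
import numpy as np


def weighted_bce(logits, targets, weights):
  """Weighted mean binary cross-entropy computed directly from logits."""
  # max(z, 0) - z * t + log(1 + exp(-|z|)) is the numerically stable form.
  per_example = (np.maximum(logits, 0) - logits * targets
                 + np.log1p(np.exp(-np.abs(logits))))
  return np.sum(weights * per_example) / np.sum(weights)


logits = np.array([2.0, -1.0, 0.5, 0.5])
targets = np.array([1.0, 0.0, 0.0, 0.0])        # unlabeled rows count as negatives
is_weak = np.array([False, False, True, True])  # last two rows are unlabeled
weak_neg_weight = 0.05                          # same value as the training cell

weights = np.where(is_weak, weak_neg_weight, 1.0)
loss_weak = weighted_bce(logits, targets, weights)
loss_full = weighted_bce(logits, targets, np.ones_like(logits))
print(f'{loss_weak:.3f} {loss_full:.3f}')
```

Down-weighting keeps the many unlabeled segments (mostly true negatives, but possibly containing unlabeled calls) from dominating the loss.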

In [None]:
#@title Classifier training. { vertical-output: true }

#@markdown Set of labels to classify. If None, auto-populated from the DB.
target_labels = None  #@param

#@markdown Classifier training hyperparams. These should not require tuning.
learning_rate = 1e-3  #@param
weak_neg_weight = 0.05  #@param
l2_mu = 0.000  #@param
num_steps = 128  #@param

train_ratio = 0.9  #@param
batch_size = 128  #@param
weak_negatives_batch_size = 128  #@param
loss_fn_name = 'bce'  #@param ['hinge', 'bce']

data_manager = classifier_data.AgileDataManager(
    target_labels=target_labels,
    db=db,
    train_ratio=train_ratio,
    min_eval_examples=1,
    batch_size=batch_size,
    weak_negatives_batch_size=weak_negatives_batch_size,
    rng=np.random.default_rng(seed=5))
print('Training for target labels : ')
print(data_manager.get_target_labels())
linear_classifier, eval_scores = classifier.train_linear_classifier(
    data_manager=data_manager,
    learning_rate=learning_rate,
    weak_neg_weight=weak_neg_weight,
    num_train_steps=num_steps,
)
print('\n' + '-' * 80)
top1 = eval_scores['top1_acc']
print(f'top-1      {top1:.3f}')
rocauc = eval_scores['roc_auc']
print(f'roc_auc    {rocauc:.3f}')
cmap = eval_scores['cmap']
print(f'cmap       {cmap:.3f}')

#@markdown Save the linear classifier to the database folder
classifier_filename = 'Biotwang_agile_classifier_v2.pt'  #@param
linear_classifier.save(os.path.join(db_path, classifier_filename))
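The `roc_auc` value printed above measures how well the classifier's scores rank positives above negatives. One standard way to compute it, sketched below with made-up scores, is the probability that a randomly chosen positive outscores a randomly chosen negative. This illustrates the metric and is not necessarily how perch-hoplite computes it; note that it is undefined without any negatives, which is one reason metrics are less reliable when there are no explicit negative labels.

```python
import numpy as np


def roc_auc(scores, labels) -> float:
  """Probability that a random positive outscores a random negative (ties = 0.5)."""
  scores = np.asarray(scores, dtype=float)
  labels = np.asarray(labels, dtype=bool)
  pos, neg = scores[labels], scores[~labels]
  if len(pos) == 0 or len(neg) == 0:
    raise ValueError('ROC-AUC needs at least one positive and one negative.')
  # Compare every positive score against every negative score.
  greater = np.sum(pos[:, None] > neg[None, :])
  ties = np.sum(pos[:, None] == neg[None, :])
  return float((greater + 0.5 * ties) / (len(pos) * len(neg)))


print(roc_auc([2.1, 0.3, -0.5, 1.2, -1.0], [1, 0, 0, 1, 0]))  # 1.0
print(roc_auc([0.2, 0.9, -0.5, 1.2], [1, 0, 0, 1]))           # 0.75
```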

In [None]:
#@title Review Classifier Results. { vertical-output: true }
#@markdown Not only can we examine outputs from the classifier to understand the
#@markdown classifier performance, we can also generate additional labels for
#@markdown the data and use them for another round of training.

#@markdown Label to review, and number of results to find and display.
target_label = 'Be_biotwang'  #@param {type:'string'}
num_results = 50  #@param

target_label_idx = data_manager.get_target_labels().index(target_label)
class_query = linear_classifier.beta[:, target_label_idx]
bias = linear_classifier.beta_bias[target_label_idx]

#@markdown Number of (randomly selected) database entries to search over.
sample_size = 1_000  #@param

#@markdown Whether to use margin sampling. If checked, search for examples
#@markdown with logits near a particular target score. A typical target score for
#@markdown margin sampling is 0.0, which finds examples where the model is most
#@markdown uncertain. Targeting higher (positive) scores returns results that are more
#@markdown likely to be positives, while targeting lower (negative) scores returns
#@markdown results that are more likely to be negatives.
margin_sampling = False  #@param {type: 'boolean'}

#@markdown When margin sampling, target this logit.
margin_target_score = 0.0  #@param
if not margin_sampling:
  margin_target_score = None
score_fn = score_functions.get_score_fn(
    'dot', bias=bias, target_score=margin_target_score)
results, all_scores = brutalism.threaded_brute_search(
    db, class_query, num_results, score_fn=score_fn,
    sample_size=sample_size)

# TODO(tomdenton): Better histogram when margin sampling.
_ = plt.hist(all_scores, bins=100)
hit_scores = [r.sort_score for r in results.search_results]
plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',
            color='r', alpha=0.5)
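The cell above reuses the search machinery with the classifier itself as the query: for a linear classifier, a segment's logit is the dot product of its embedding with `linear_classifier.beta[:, target_label_idx]` plus `beta_bias[target_label_idx]`, which is why `class_query` and `bias` are passed to a `'dot'` score function. The NumPy sketch below restates that scoring, the random `sample_size` subset, and margin sampling on random stand-in data. `review_candidates` is a hypothetical helper, and treating margin sampling as ranking by `-|logit - margin_target_score|` is an assumption.

```python
import numpy as np


def review_candidates(embeddings, beta, bias, num_results, sample_size=None,
                      margin_target=None, rng=None):
  """Score embeddings with a linear classifier and return the top candidates."""
  if rng is None:
    rng = np.random.default_rng()
  ids = np.arange(len(embeddings))
  if sample_size is not None and sample_size < len(ids):
    ids = rng.choice(ids, size=sample_size, replace=False)  # random subset
  logits = embeddings[ids] @ beta + bias  # one logit per sampled embedding
  # Assumed margin sampling: prefer logits closest to margin_target.
  scores = logits if margin_target is None else -np.abs(logits - margin_target)
  order = np.argsort(-scores)[:num_results]
  return ids[order], logits[order]


rng = np.random.default_rng(1)
embeddings = rng.normal(size=(5000, 16))  # stand-in for the embedding DB
beta = rng.normal(size=16)                # stand-in for beta[:, target_label_idx]
bias = -0.5                               # stand-in for beta_bias[target_label_idx]

top_ids, top_logits = review_candidates(
    embeddings, beta, bias, num_results=50, sample_size=1000, rng=rng)
margin_ids, margin_logits = review_candidates(
    embeddings, beta, bias, num_results=50, sample_size=1000,
    margin_target=0.0, rng=rng)
print(top_logits[:3].round(2), margin_logits[:3].round(2))
```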



In [None]:
#@title Display results and annotate the output (make sure to save by running the next cell) { vertical-output: true }
#@markdown Reminder to click the label button once to mark as a positive label,
#@markdown twice for a negative label, and a third time to reset (assumed unlabeled/weak negative).
display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(
    results, db, sample_rate_hz=audio_loader_sample_rate_hz, frame_rate=100,
    audio_loader=audio_filepath_loader)
display_results.display(positive_labels=[target_label])

In [None]:
#@title Save data labels. { vertical-output: true }
#@markdown This will save the labels to the database, attached to the embedded examples.

prev_lbls, new_lbls = 0, 0
for lbl in display_results.harvest_labels(annotator_id):
  check = db.insert_label(lbl, skip_duplicates=True)
  new_lbls += check
  prev_lbls += (1 - check)
print('\nNew labels added: ', new_lbls)
print('\nQuery examples that already existed: ', prev_lbls)

In [None]:
#@title Check how many labels of each class exist in the data
print('\nTotal positive labels per class: ', db.get_class_counts())
print('\nTotal negative labels per class: ', db.get_class_counts(label_type = interface.LabelType.NEGATIVE))

## Repeat and iterate!
You can repeat the search process with a given classifier by rerunning the "Review Classifier Results" cell (optionally changing the query settings), reviewing the output, and saving the new labels. We also recommend searching for more negative examples (scores around -1 or lower) and for examples with scores near 0, not just the highest-scoring results, which is the default setting.

Once you have more labels, retrain the classifier by rerunning this section starting from the "Classifier training" cell. Training will now include the labels you just added from the initial classifier results.



# When you are done with your classifier

In [None]:
#@title Run inference with trained classifier and save results to a .csv { vertical-output: true }
#@markdown This will save the results of the classifier to a csv file.

output_csv_filepath = '/content/drive/My Drive/noaa_demo/Biotwang_agile_classifier_v2_results.csv' #@param {type:'string'}
# @markdown The threshold sets the minimum value for which the results will be saved for a given class.
# @markdown If set to 1.0, for example, only scores above 1.0 for a class will be saved to the csv.
logit_threshold = 0.0  #@param
# @markdown Set labels to a tuple of desired labels if you want to run inference on a
# @markdown subset of the labels. If None, then all labels in the data will be included.
labels = None  #@param

classifier.write_inference_csv(
    linear_classifier, db, output_csv_filepath, logit_threshold, labels=labels)
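Conceptually, writing the inference CSV means scoring every embedding with the linear classifier and keeping only the (embedding, label) pairs whose logit is above `logit_threshold`. A minimal sketch with the standard `csv` module and made-up data follows. `write_thresholded_csv` is hypothetical, and apart from the `label` and `logits` columns read back in the cells below, its column layout is an assumption rather than the format `classifier.write_inference_csv` actually produces.

```python
import csv
import io

import numpy as np


def write_thresholded_csv(out, embedding_ids, embeddings, beta, bias,
                          label_names, logit_threshold, labels=None):
  """Write one row per (embedding, label) whose logit exceeds the threshold."""
  logits = embeddings @ beta + bias  # shape [num_embeddings, num_classes]
  wanted = set(label_names if labels is None else labels)
  writer = csv.writer(out)
  writer.writerow(['embedding_id', 'label', 'logits'])
  num_rows = 0
  for row, embedding_id in enumerate(embedding_ids):
    for col, name in enumerate(label_names):
      if name in wanted and logits[row, col] > logit_threshold:
        writer.writerow([embedding_id, name, f'{logits[row, col]:.4f}'])
        num_rows += 1
  return num_rows


embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
beta = np.array([[2.0], [-1.0]])  # weights for a single class
bias = np.array([-0.5])
buffer = io.StringIO()
n = write_thresholded_csv(buffer, [10, 11, 12], embeddings, beta, bias,
                          ['Be_biotwang'], logit_threshold=0.0)
print(n)  # 1: only embedding 10 (logit 1.5) is above 0.0; embedding 12 sits exactly at 0.0
print(buffer.getvalue())
```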


In [None]:
#@title [Optional] Read the saved csv and examine the results
import pandas as pd

results_df = pd.read_csv(output_csv_filepath)
display(results_df.head())

In [None]:
results_df.label.value_counts()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
sns.histplot(data=results_df, x='logits', hue='label', multiple='stack', bins=30)
plt.title('Distribution of Logits by Label')
plt.xlabel('Logits')
plt.ylabel('Frequency')
plt.show()

# What to try next?


1. Pick another file from the NOAA data to embed and add to the existing database. The file 'Saipan_A_06_151005_015500.df20.x.flac' is the recording from the time period before the example file we used, and 'Saipan_A_06_151007_163630.df20.x.flac' is the next recording in time.

2. Play around with the classifier model training settings.

3. Create a new database with embeddings from an entirely different set of audio files, for example your own data saved in a Google Drive folder.