# Mass Embedding of Bioacoustic Audio

This notebook facilitates pre-computing embeddings of audio data for subsequent
use with search, classification, and analysis.

## Configuration and Imports.

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from etils import epath
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tqdm
from chirp import audio_utils
from chirp.inference import embed_lib
from chirp.inference import tf_examples

In [2]:
# Working dir for saving products
working_dir = "/workspaces/2023_ECCC4_Biodiv/data/"
embeddings_path = epath.Path(working_dir) / 'embeddings'

# Path to data
#data_to_embed = working_dir + "external_data_drive/" # if path to data is outside repo "/workspaces/data_folder/audio/Anas platyrhynchos/". Depending on the mount path in devcontainer.json
#data_to_embed_extention = ["*.WAV"]
data_to_embed = working_dir + "audio/" # if path to data is outside repo "/workspaces/data_folder/audio/Anas platyrhynchos/". Depending on the mount path in devcontainer.json
data_to_embed_extention = ["*.WAV"]

## Embed Audio

In [3]:
# Set up configs to be used by model
config = config_dict.ConfigDict()
config.embed_fn_config = config_dict.ConfigDict()
config.embed_fn_config.model_config = config_dict.ConfigDict()

config.source_file_patterns = data_to_embed_extention
config.source_file_root = data_to_embed
config.output_dir = embeddings_path.as_posix()

config.embed_fn_config.model_key = 'taxonomy_model_tf'
config.embed_fn_config.model_config.window_size_s = 5.0
config.embed_fn_config.model_config.hop_size_s = 5.0
config.embed_fn_config.model_config.sample_rate = 32000
config.embed_fn_config.model_config.tfhub_version = 8
config.embed_fn_config.model_config.model_path = ''

# Only write embeddings to reduce size.
config.embed_fn_config.write_embeddings = True
config.embed_fn_config.write_logits = False
config.embed_fn_config.write_separated_audio = False
config.embed_fn_config.write_raw_audio = False

# Use sharding
use_file_sharding = True
if use_file_sharding:
  config.shard_len_s = 60.0

# Number of parent directories to include in the filename.
config.embed_fn_config.file_id_depth = 1

In [4]:
# Set up the embedding function, including loading models.
embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)
print('\nLoading model(s)...')
embed_fn.setup()

# Create output directory and write the configuration.
output_dir = epath.Path(config.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
embed_lib.maybe_write_config(config, output_dir)

# Create SourceInfos.
source_infos = embed_lib.create_source_infos(
    config.source_file_patterns,
    config.source_file_root,
    num_shards_per_file=config.get('num_shards_per_file', -1),
    shard_len_s=config.get('shard_len_s', -1))
print(f'Found {len(source_infos)} source infos.')

# Set up
print('\n\nTest-run of model...')
window_size_s = config.embed_fn_config.model_config.window_size_s
sr = config.embed_fn_config.model_config.sample_rate
z = np.zeros([int(sr * window_size_s)], dtype=np.float32)
embed_fn.embedding_model.embed(z)
print('Setup complete!')


Loading model(s)...
Found 60 source infos.


Test-run of model...


I0000 00:00:1742402020.270868    4514 service.cc:145] XLA service 0x55e82f305680 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1742402020.270926    4514 service.cc:153]   StreamExecutor device (0): NVIDIA GeForce RTX 3070 Ti Laptop GPU, Compute Capability 8.6
W0000 00:00:1742402020.485330    4514 assert_op.cc:38] Ignoring Assert operator jax2tf_infer_fn_/assert_equal_1/Assert/AssertGuard/Assert


Setup complete!


I0000 00:00:1742402029.696755    4514 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


In [5]:
# Uses multiple threads to load audio before embedding.
# This tends to be faster, but can fail if any audio files are corrupt.
embed_fn.min_audio_s = 1.0
record_file = (output_dir / 'embeddings.tfrecord').as_posix()
succ, fail = 0, 0

existing_embedding_ids = embed_lib.get_existing_source_ids(
    output_dir, 'embeddings-*')

new_source_infos = embed_lib.get_new_source_infos(
    source_infos, 
    existing_embedding_ids,
    config.embed_fn_config.file_id_depth)

print(f'Found {len(existing_embedding_ids)} existing embedding ids. \n'
      f'Processing {len(new_source_infos)} new source infos. ')

try:
  audio_loader = lambda fp, offset: audio_utils.load_audio_window(
      fp,
      offset,
      sample_rate=config.embed_fn_config.model_config.sample_rate,
      window_size_s=config.get('shard_len_s', -1.0))
  audio_iterator = audio_utils.multi_load_audio_window(
      filepaths=[s.filepath for s in new_source_infos],
      offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],
      audio_loader=audio_loader,
  )
  with tf_examples.EmbeddingsTFRecordMultiWriter(
      output_dir=output_dir,
      num_files=config.get('tf_record_shards', 1)) as file_writer:
    for source_info, audio in tqdm.tqdm(
        zip(new_source_infos, audio_iterator),
        total=len(new_source_infos)):
      if not embed_fn.validate_audio(source_info, audio):
        continue
      file_id = source_info.filepath
      offset_s = source_info.shard_num * source_info.shard_len_s
      example = embed_fn.audio_to_example(file_id, offset_s, audio)
      if example is None:
        fail += 1
        continue
      file_writer.write(example.SerializeToString())
      succ += 1
    file_writer.flush()
finally:
  del(audio_iterator)
print(f'\n\nSuccessfully processed {succ} source_infos, failed {fail} times.')

fns = [fn for fn in output_dir.glob('embeddings-*')]
ds = tf.data.TFRecordDataset(fns)
parser = tf_examples.get_example_parser()
ds = ds.map(parser)
for ex in ds.as_numpy_iterator():
  print(ex['filename'])
  print(ex['embedding'].shape, flush=True)
  break

Found 0 existing embedding ids. 
Processing 60 new source infos. 


  0%|          | 0/60 [00:00<?, ?it/s]W0000 00:00:1742402033.966301    4514 assert_op.cc:38] Ignoring Assert operator jax2tf_infer_fn_/assert_equal_1/Assert/AssertGuard/Assert



100%|██████████| 60/60 [00:14<00:00,  4.08it/s]



Successfully processed 50 source_infos, failed 0 times.
b'/workspaces/2023_ECCC4_Biodiv/data/audio/20240429_074100_done.WAV'
(12, 1, 1280)



