In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

from etils import epath
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tqdm
import os

from chirp import audio_utils
from chirp.inference import embed_lib
from chirp.inference import tf_examples
import pandas as pd
import csv
import re

2024-06-20 15:49:09.900161: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-20 15:49:09.900226: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-20 15:49:09.900238: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-20 15:49:09.937197: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
model_choice = 'perch'

working_dir = '.'  


embeddings_path = epath.Path(working_dir) / 'embeddings_all'
labeled_data_path = epath.Path(working_dir) / 'labeled'
embeddings_glob = embeddings_path / 'embeddings-*'


In [8]:
config = config_dict.ConfigDict()
config.embed_fn_config = config_dict.ConfigDict()
config.embed_fn_config.model_config = config_dict.ConfigDict()


config.source_file_patterns = ['/home/mschulist/caples_sound/ARU_data_all/*/*.wav']
config.output_dir = embeddings_path.as_posix()


perch_tfhub_version = 8
perch_model_path = ''

config.embed_fn_config.model_key = 'taxonomy_model_tf'
config.embed_fn_config.model_config.window_size_s = 5.0
config.embed_fn_config.model_config.hop_size_s = 5.0
config.embed_fn_config.model_config.sample_rate = 32000
config.embed_fn_config.model_config.tfhub_version = perch_tfhub_version
config.embed_fn_config.model_config.model_path = perch_model_path

# Only write embeddings to reduce size.
config.embed_fn_config.write_embeddings = True
config.embed_fn_config.write_logits = False
config.embed_fn_config.write_separated_audio = False
config.embed_fn_config.write_raw_audio = False

# Number of parent directories to include in the filename.
config.embed_fn_config.file_id_depth = 1

In [9]:
# Set up the embedding function, including loading models.
embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)
print('\n\nLoading model(s)...')
embed_fn.setup()

# Create output directory and write the configuration.
output_dir = epath.Path(config.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
embed_lib.maybe_write_config(config, output_dir)

# Create SourceInfos.
source_infos = embed_lib.create_source_infos(
    config.source_file_patterns,
    num_shards_per_file=config.get('num_shards_per_file', -1),
    shard_len_s=config.get('shard_len_s', -1))
print(f'Found {len(source_infos)} source infos.')

print('\n\nTest-run of model...')
window_size_s = config.embed_fn_config.model_config.window_size_s
sr = config.embed_fn_config.model_config.sample_rate
z = np.zeros([int(sr * window_size_s)])
embed_fn.embedding_model.embed(z)
print('Setup complete!')



Loading model(s)...
Found 87135 source infos.


Test-run of model...


2024-06-20 15:55:38.802812: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator jax2tf_infer_fn_/assert_equal_1/Assert/AssertGuard/Assert


Setup complete!


In [10]:
# Uses multiple threads to load audio before embedding.
# This tends to be faster, but can fail if any audio files are corrupt.

min_file_size = 10_000_000 # 10 MB


embed_fn.min_audio_s = 1.0
record_file = (output_dir / 'embeddings.tfrecord').as_posix()
succ, fail = 0, 0

existing_embedding_ids = embed_lib.get_existing_source_ids(
    output_dir, 'embeddings-*')

new_source_infos = embed_lib.get_new_source_infos(
    source_infos, existing_embedding_ids, config.embed_fn_config.file_id_depth)

filtered_source_infos = []

for s in new_source_infos:
    size = os.stat(s.filepath).st_size
    if size < min_file_size:
        continue
    filtered_source_infos.append(s)

new_source_infos = filtered_source_infos

print(f'Found {len(new_source_infos)} existing embedding ids.'
      f'Processing {len(new_source_infos)} new source infos. ')

try:
  audio_loader = lambda fp, offset: audio_utils.load_audio_window(
      fp, offset, sample_rate=config.embed_fn_config.model_config.sample_rate,
      window_size_s=config.get('shard_len_s', -1.0))
  audio_iterator = audio_utils.multi_load_audio_window(
      filepaths=[s.filepath for s in new_source_infos],
      offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],
      audio_loader=audio_loader,
  )
  with tf_examples.EmbeddingsTFRecordMultiWriter(
      output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:
    for source_info, audio in tqdm.tqdm(
        zip(new_source_infos, audio_iterator), total=len(new_source_infos)):
      file_id = source_info.file_id(config.embed_fn_config.file_id_depth)
      offset_s = source_info.shard_num * source_info.shard_len_s
      example = embed_fn.audio_to_example(file_id, offset_s, audio)
      if example is None:
        fail += 1
        continue
      file_writer.write(example.SerializeToString())
      succ += 1
    file_writer.flush()
finally:
  del(audio_iterator)
print(f'\n\nSuccessfully processed {succ} source_infos, failed {fail} times.')

fns = [fn for fn in output_dir.glob('embeddings-*')]
ds = tf.data.TFRecordDataset(fns)
parser = tf_examples.get_example_parser()
ds = ds.map(parser)
for ex in ds.as_numpy_iterator():
  print(ex['filename'])
  print(ex['embedding'].shape, flush=True)
  break

Found 87135 existing embedding ids.Processing 87135 new source infos. 


  0%|                                                                                                 | 0/87135 [00:00<?, ?it/s]2024-06-20 15:56:21.214690: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator jax2tf_infer_fn_/assert_equal_1/Assert/AssertGuard/Assert
  0%|                                                                                    | 111/87135 [04:20<63:20:55,  2.62s/it]2024-06-20 16:00:29.909328: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator jax2tf_infer_fn_/assert_equal_1/Assert/AssertGuard/Assert
  2%|█▌                                                                                 | 1604/87135 [47:09<38:35:36,  1.62s/it]2024-06-20 16:43:17.789020: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator jax2tf_infer_fn_/assert_equal_1/Assert/AssertGuard/Assert
  2%|█▋                                                                                 | 1769/87135 [50:40<35:10:09,  1.