In [0]:
#@title Imports. { vertical-output: true }
import json
from ml_collections import config_dict
import numpy as np
from etils import epath
import matplotlib.pyplot as plt

from chirp.inference import colab_utils
colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)

from chirp import audio_utils
from chirp import path_utils
from chirp.inference import models
from chirp.projects.bootstrap import bootstrap
from chirp.projects.bootstrap import search
from chirp.projects.bootstrap import display


In [0]:
#@title Configuration and Setup. { vertical-output: true }
# User-editable settings for the notebook. Each `#@param` line is a form field;
# the values assigned here are read by the cells that follow.

# Path to embeddings of unlabeled data.
# A `config.json` file is read from this directory by the project-state cell.
embeddings_path = ''  #@param

# Path for storing annotated examples.
labeled_data_path = ''  #@param

# Key used to look up the separation model class in
# `models.model_class_map()`, and the path handed to that model as its
# `model_path`. An empty path means no separation model is loaded.
# NOTE(review): the separation-model cell assigns these two names again, so
# the values entered in that cell are the ones that take effect.
separation_model_key = 'separator_model_tf'  #@param
separation_model_path = ''  #@param


In [0]:
#@title Load Project State and Config. { vertical-output: true }

# Turn the configured location into a path object so it can be joined below.
embeddings_path = epath.Path(embeddings_path)

# The embedding run stores its settings in `config.json` next to the
# embedding files; parse it into a ConfigDict.
config_file = embeddings_path / 'config.json'
with config_file.open() as f:
  embedding_config = config_dict.ConfigDict(json.load(f))

# Glob pattern for the embedding files in the same directory.
embeddings_glob = embeddings_path / 'embeddings-*'

# Build the bootstrap configuration and the project state derived from it.
config = bootstrap.BootstrapConfig.load_from_embedding_config(
    embeddings_path=embeddings_path, annotated_path=labeled_data_path
)
project_state = bootstrap.BootstrapState(config)

# Default to "no separator" so later cells can test `separator` even when the
# optional separation-model cell is skipped.
separator = None

In [0]:
#@title Load Separation Model (Optional) { vertical-output: true }

# NOTE: these form fields assign the separation settings again, so the values
# entered in this cell are the ones used below.
separation_model_key = 'separator_model_tf'  #@param
separation_model_path = ''  #@param

# When the embeddings come from a 'separate_embed_model' and no explicit
# separator path was entered, reuse the separator recorded in the embedding
# config. The fallback must be keyed on the *path* being blank: the model key
# has a non-empty default, so testing the key would never trigger the fallback
# (and, with a blank key, would overwrite a path the user did supply).
if (config.model_key == 'separate_embed_model'
    and not separation_model_path.strip()):
  separation_model_key = 'separator_model_tf'
  separation_model_path = (
      config.model_config.separator_model_tf_config.model_path)

if separation_model_path:
  separation_config = config_dict.ConfigDict({
      'model_path': separation_model_path,
      'frame_size': 32000,
      'sample_rate': 32000,
  })
  # Look up the model class stored under the chosen key and build it from the
  # separation config.
  separator = models.model_class_map()[
      separation_model_key].from_config(separation_config)
  print("Loaded separator model at {}".format(separation_model_path))
else:
  # Nothing to load; later cells check `separator is not None`.
  print('No separation model loaded.')
  separator = None

## Query Creation

In [0]:
#@title Load query audio. { vertical-output: true }

# Source of the query: an audio file path, a Xeno-Canto id (like 'xc12345'),
# or an audio file URL.
audio_path = 'xc12345'  #@param
# Offset (in seconds) at which the query window starts; adjust by hand.
start_s = 0  #@param

# Window length and sample rate come from the embedding model configuration.
window_s = config.model_config['window_size_s']
sample_rate = config.model_config['sample_rate']
audio = audio_utils.load_audio(audio_path, sample_rate)

# Show the whole recording first.
display.plot_audio_melspec(audio, sample_rate)

# Then show just the selected window.
print('-' * 80)
print('Selected audio window:')

# Convert the requested start time to sample indices.
st = int(start_s * sample_rate)
end = int(st + window_s * sample_rate)

# If the window runs past the end of the recording, slide it back so that it
# ends at the last sample (never starting before sample 0).
num_samples = audio.shape[0]
if end > num_samples:
  end = num_samples
  st = max(0, int(end - window_s * sample_rate))

audio_window = audio[st:end]
display.plot_audio_melspec(audio_window, sample_rate)

# By default the query is the raw window; no separation has been run yet.
query_audio = audio_window
sep_outputs = None

In [0]:
#@title Separate the target audio window { vertical-output: true }

if separator is None:
  # No separation model available: leave the outputs unset.
  sep_outputs = None
  print('No separation model loaded.')
else:
  # Run the separator on the selected window and plot every output channel.
  sep_outputs = separator.embed(audio_window)

  separated = sep_outputs.separated_audio
  num_channels = separated.shape[0]
  for c in range(num_channels):
    print(f'Channel {c}')
    display.plot_audio_melspec(separated[c, :], sample_rate)
    print('-' * 80)

In [0]:
#@title Select the query channel. { vertical-output: true }

# Label attached to the query (offered as a checkbox label in the results).
query_label = 'some_audio'  #@param
# Index of the separated channel to use; a negative value keeps the
# unseparated window.
query_channel = -1  #@param

# Use a separated channel only when one was requested and separation outputs
# exist; otherwise fall back to the original window.
if query_channel >= 0 and sep_outputs is not None:
  query_audio = sep_outputs.separated_audio[query_channel].copy()
else:
  query_audio = audio_window

display.plot_audio_melspec(query_audio, sample_rate)

# Embed the chosen audio and pool the embeddings into the search query.
outputs = project_state.embedding_model.embed(query_audio)
query = outputs.pooled_embeddings('first', 'first')


In [0]:
#@title Run Top-K Search. { vertical-output: true }

# How many search results to keep.
top_k = 25  #@param

# Score the search should aim for. Targeting a score lets us home in on a
# 'classifier boundary' rather than only inspecting the closest matches.
# Use 'None' for a plain 'best results' search.
target_score = None  #@param

metric = 'euclidean'  #@param['euclidean', 'mip', 'cosine']

random_sample = False  #@param

# Search the embeddings dataset for the query.
ds = project_state.create_embeddings_dataset()
results, all_scores = search.search_embeddings_parallel(
    ds,
    query,
    hop_size_s=config.embedding_hop_size_s,
    top_k=top_k,
    target_score=target_score,
    score_fn=metric,
    random_sample=random_sample,
)

# Histogram of all scores, with the returned hits marked along the x-axis.
ys, _, _ = plt.hist(all_scores, bins=128, density=True)
peak_density = np.max(ys)
hit_scores = [hit.score for hit in results.search_results]
plt.scatter(
    hit_scores, np.zeros_like(hit_scores), marker='|', color='r', alpha=0.5
)

plt.xlabel(metric)
plt.ylabel('density')

if target_score is not None:
  # Dotted red line at the target score.
  plt.plot([target_score] * 2, [0.0, peak_density], 'r:')
  # Fraction of all scores that fall below the target.
  hit_percentage = (all_scores < target_score).mean()
  print(f'score < target_score percentage : {hit_percentage:5.3f}')

# Dotted green line at the smallest observed score.
min_score = np.min(all_scores)
plt.plot([min_score] * 2, [0.0, peak_density], 'g:')

plt.show()


In [0]:
#@title Display results. { vertical-output: true }

# Render each search result with checkboxes for the query label and 'unknown'.
display.display_search_results(
    results,
    sample_rate,
    project_state.source_map,
    checkbox_labels=[query_label, 'unknown'],
    max_workers=5,
)

In [0]:
#@title Write annotated examples. { vertical-output: true }

# Save the labeled results under the annotated-data path, passing the
# embedding model's sample rate.
results.write_labeled_data(
    config.annotated_path,
    project_state.embedding_model.sample_rate,
)