In [None]:
#@title Imports. { vertical-output: true }

import dataclasses
import functools
import sqlite3
import os
from typing import Callable
import numpy as np
from concurrent import futures
import tqdm
import time
from scipy import stats
from matplotlib import pyplot as plt
from ml_collections import config_dict
from IPython.display import display
import ipywidgets as widgets
from chirp import audio_utils
from chirp.projects.agile2 import colab_utils
from chirp.projects.agile2 import embed
from chirp.projects.hoplite import interface

## Embed

In [None]:
#@title Configuration { vertical-output: true }

# Where the raw audio lives.  Each entry maps a dataset name of your choosing
# to a (base path, file glob) pair.  File globs are case sensitive.
#
# The generated embedding database is placed next to the audio files, so if
# you ever move a dataset you only need to swap out its base path here -- there
# is no need to re-embed.
audio_globs = {
    'dataset_1': ('/path/to/dataset/1', '*.WAV'),
    'dataset_2': ('/path/to/dataset/2', '*/*.mp4'),
}

# Only one dataset is processed per run; re-run this entire notebook once for
# each additional dataset.  The embeddings database will be located in the same
# directory as the raw audio.
dataset_name = 'dataset_1'  #@param

# Fail early on a dataset name that has no entry in `audio_globs`.
if dataset_name not in audio_globs:
  raise ValueError(f'Dataset {dataset_name} not found in audio_globs')

globs_to_process = {dataset_name: audio_globs[dataset_name]}

# This only needs changing if you want to maintain multiple distinct embedding
# databases.
db_path = None
configs = colab_utils.load_configs(globs_to_process, db_path)
configs

In [None]:
#@title Initialize the DB { vertical-output: true }

# Load the database described by the configs, run its setup step, and report
# what it currently holds.  `db` is a notebook-level name used by the cells
# below; no `global` declaration is needed for an assignment at this level.
db = configs.db_config.load_db()
db.setup()
num_embeddings = db.count_embeddings()

print('Initialized DB located at ', configs.db_config.db_config.db_path)
print('Existing DB contains datasets: ', db.get_dataset_names())
print('num embeddings: ', num_embeddings)

def drop_and_reload_db(_) -> 'interface.GraphSearchDBInterface':
    """Deletes the database file and replaces the notebook-level `db`.

    Written to be used as a button `on_click` callback, which is why it takes
    a single positional argument and ignores it.

    Args:
      _: Unused; whatever the caller (e.g. the clicked widget) passes in.

    Returns:
      The freshly loaded database, which is also bound to the global `db`.

    Raises:
      OSError: If the database file cannot be removed (for example, it does
        not exist).
    """
    # Rebind the notebook-level `db`.  Without this declaration the assignment
    # below would only create a function-local name, and later cells would keep
    # using the handle to the database that was just deleted.
    global db
    db_path = configs.db_config.db_config.db_path
    os.unlink(db_path)
    print('\n Deleted previous db at: ', db_path)
    db = configs.db_config.load_db()
    db.setup()
    return db

drop_existing_db = 'True'  #@param['True', 'False']

# Deletion is only offered when the DB actually holds embeddings, and it is
# never performed by running this cell: the user has to click the button, which
# invokes `drop_and_reload_db`.
if drop_existing_db == 'True' and num_embeddings > 0:
  print('\n\nClick the button below to confirm you really want to drop the database at ')
  print(f'{configs.db_config.db_config.db_path}\n')
  print(f'This will permanently delete all {num_embeddings} embeddings from the existing database.\n')
  print('If you do NOT want to delete this data, set `drop_existing_db` above to `False` and re-run this cell.\n')

  button = widgets.Button(description='Delete database?')
  button.on_click(drop_and_reload_db)
  display(button)

In [None]:
#@title Run the embedding { vertical-output: true }

# If the DB already exists, we need to make sure that the current model_config
# is compatible with the model_config that was used previously.
colab_utils.validate_and_save_configs(configs, db)

print(f'Datasets requested to embed: {list(globs_to_process)}')

# Avoid re-embedding datasets that are already present in the DB
# TODO(roblaber) Make this filtering more granular, ie, avoid re-embedding
#  (dataset, filename) pairs
# NOTE(review): the worker below is constructed from
# `configs.audio_sources_config`, which was created before this filter runs --
# confirm that removing entries from `globs_to_process` actually changes what
# gets embedded.
for dataset in db.get_dataset_names():
  if dataset not in globs_to_process:
    continue
  del globs_to_process[dataset]
  print(f"\nDataset '{dataset}' already present in DB, not re-embedding")

new_datasets = list(globs_to_process)

print(f'\nNew datasets to embed: {new_datasets}')
print(f'\nPreparing to embed {len(new_datasets)} datasets...\n')

worker = embed.EmbedWorker(
    embed_config=configs.audio_sources_config,
    db=db,
    model_config=configs.model_config,
)
worker.process_all()

print('\n\nEmbedding complete, total embeddings: ', db.count_embeddings())

In [None]:
#@title Per dataset statistics { vertical-output: true }

# For every dataset recorded in the DB, print the size of the first axis of
# what `get_embeddings_by_source` returns for that dataset (no source filter).
for dataset in db.get_dataset_names():
  print(f"\nDataset '{dataset}':")
  print(
      '\tnum embeddings: ',
      db.get_embeddings_by_source(dataset, source_id=None).shape[0],
  )