In [1]:
import glob
import json
from pathlib import Path

from pymilvus import MilvusClient
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

import embeddings_util

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local database file that will hold the training-clip embeddings.
MILVUS_DATABASE = 'esc50.db'
# Hugging Face checkpoint used for both preprocessing and encoding.
MODEL_NAME = 'facebook/wav2vec2-large'

# Build the feature extractor and the encoder from the same checkpoint, so the
# preprocessing configuration and the model weights always come from one place.
# (Tuple elements are evaluated left to right: extractor first, then model.)
feature_extractor, model = (
    Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME),
    Wav2Vec2Model.from_pretrained(MODEL_NAME),
)

In [3]:
# Open a client on the database file configured above (`uri` is the
# MilvusClient's first parameter), then let the helper module prepare it.
# NOTE(review): what init_milvus sets up (collection, schema, reset of existing
# data?) is defined in embeddings_util — confirm there.
milvus_client = MilvusClient(uri=MILVUS_DATABASE)
embeddings_util.init_milvus(milvus_client)

In [4]:
# Embed every training clip and store its vector in Milvus together with its
# file path; the stored path is what the evaluation cell uses to recover the
# clip's category.
# glob's result order is filesystem-dependent, so sort it to make the
# insertion order reproducible across machines and runs.
train_files = sorted(glob.glob('esc50/train/**/*.wav', recursive=True))
# Fail fast: an empty index would make every later retrieval meaningless.
assert train_files, 'no training .wav files found under esc50/train'

# NOTE(review): re-running this cell inserts every clip again unless
# embeddings_util.init_milvus resets the collection — confirm, since duplicate
# entries could crowd the top-3 neighbours and skew the evaluation.
for file in train_files:
    feature_vector = embeddings_util.retrieve_embeddings_for_audiofile(file, feature_extractor, model)
    embeddings_util.insert_embeddings_into_db(feature_vector, file, milvus_client)

In [5]:
# Nearest-neighbour evaluation on the validation split: each validation clip is
# embedded, the closest training clips are retrieved from Milvus, and the
# categories of the retrieved clips are compared with the clip's own category.
val_files = sorted(glob.glob('esc50/val/**/*.wav', recursive=True))
# Fail with a clear message instead of a ZeroDivisionError when computing accuracy.
assert val_files, 'no validation .wav files found under esc50/val'


def category_from_path(path):
    """Return the category directory of a clip path laid out as ``esc50/<split>/<category>/...``.

    ``Path.parts`` splits on the platform's separator(s), so this also works for
    the backslash-separated paths ``glob`` returns on Windows, where a plain
    ``split('/')`` would pick the wrong component.
    """
    return Path(path).parts[2]


top_1_scores_list = []
top_3_scores_list = []

for file in val_files:
    target_category = category_from_path(file)
    feature_vector = embeddings_util.retrieve_embeddings_for_audiofile(
        file, feature_extractor, model)
    result_json = embeddings_util.retrieve_by_sample(feature_vector, milvus_client)

    # Hits for the single query vector, assumed ordered best-first — TODO confirm
    # against embeddings_util.retrieve_by_sample.
    hits = result_json[0]
    top_3_categories = [category_from_path(hit['entity']['filename']) for hit in hits[:3]]

    # A query that returns no hits counts as a miss rather than raising IndexError.
    top_1_hit = bool(top_3_categories) and top_3_categories[0] == target_category
    top_1_scores_list.append(int(top_1_hit))
    top_3_scores_list.append(int(target_category in top_3_categories))


print('Top 1 accuracy: {}'.format(sum(top_1_scores_list) / len(top_1_scores_list)))
print('Top 3 accuracy: {}'.format(sum(top_3_scores_list) / len(top_3_scores_list)))


Top 1 accuracy: 0.22
Top 3 accuracy: 0.36
