In [None]:
#| eval: false

from datasets import load_dataset, Features, Value, Audio, ClassLabel
import json

# Output schema for the `map` below: "audio" is an Audio feature with a 16 kHz
# target sampling rate, and "label" is a binary ClassLabel
# (0 = "not found", 1 = "found").
feats = Features(
    {
        "path": Value("string"),
        "audio": Audio(sampling_rate=16_000),
        "label": ClassLabel(names=["not found", "found"]),
    }
)
def _generate_examples(example, tag):
        example['label'] = 1 if example['label'] in tag else 0
        example['audio'] = example['path']
        return example

# Tag that defines the positive ("found") class.
TARGET_TAG = 'chow mein'

# JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
with open('tags_data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

data_files = {'train': 'dataset/slices_train.csv', 'test': 'dataset/slices_test.csv', 'val': 'dataset/slices_val.csv'}
dataset = load_dataset("csv", data_files=data_files)
# 'Unnamed: 0' is presumably a pandas index written into the CSVs, and 'split'
# duplicates the split already encoded by the file name — confirm if reused.
dataset = dataset.remove_columns(column_names=['Unnamed: 0', 'split'])
# Keys of tags_data.json whose 'tags' entry contains TARGET_TAG; rows whose
# original 'label' is one of these keys become class 1, all others class 0.
tags_pool = [k for k, v in data.items() if TARGET_TAG in v['tags']]
dataset = dataset.map(_generate_examples, fn_kwargs={'tag': tags_pool}, features=feats)
dataset = dataset.rename_column('path', 'file')
# Derive the id/label mappings from the ClassLabel so they cannot drift from `feats`.
id2label = dict(enumerate(feats['label'].names))
label2id = {v: k for k, v in id2label.items()}

Using custom data configuration default-c58ed15a5d5a3dac
Reusing dataset csv (/home/jovyan/.cache/huggingface/datasets/csv/default-c58ed15a5d5a3dac/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/36993 [00:00<?, ?ex/s]

  0%|          | 0/4648 [00:00<?, ?ex/s]

  0%|          | 0/4586 [00:00<?, ?ex/s]

In [None]:
#| eval: false

# Class balance of the training split (0 = not found, 1 = found).
dataset['train'].to_pandas()['label'].value_counts()

0    35787
1     1206
Name: label, dtype: int64

In [None]:
#| eval: false

def _filter_by_duration(example, duration):
    return len(example['audio']['array']) < duration * example['audio']['sampling_rate']

# Keep only clips shorter than 1 second, in every split.
dataset = dataset.filter(_filter_by_duration, fn_kwargs=dict(duration=1))


  0%|          | 0/37 [00:00<?, ?ba/s]

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/5 [00:00<?, ?ba/s]

## Similarity measure

In [None]:
#| eval: false

from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
from datasets import load_dataset
import torch

# Sampling rate (Hz) declared by the dataset's Audio feature; passed to the
# feature extractor in the embedding loop.
sampling_rate = dataset['train'].features["audio"].sampling_rate

# Feature extractor plus the bare Wav2Vec2 encoder from the same checkpoint.
# The encoder has no LM head, so the "lm_head weights not used" warning is expected.
feature_extractor, model = (
    Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h"),
    Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h"),
)


Downloading preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/1.56k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/360M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#| eval: false

# Split sizes after the duration filter (columns: file, audio, label).
dataset

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'label'],
        num_rows: 35716
    })
    test: Dataset({
        features: ['file', 'audio', 'label'],
        num_rows: 4482
    })
    val: Dataset({
        features: ['file', 'audio', 'label'],
        num_rows: 4429
    })
})

In [None]:
#| eval: false

from collections import defaultdict
from tqdm import tqdm
from pathlib import Path

# Embed every training clip with the wav2vec2 encoder and save one L2-normalised
# tensor per clip to embeddings/<prefix>/<clip stem>.pt, where <prefix> is the clip
# file name up to the first '_'. Audio is decoded on the fly whenever rows are read.
max_duration = 1  # seconds; clips were already filtered to be shorter than this
batch_size = 32
train_ds = dataset['train']
max_samples = int(max_duration * sampling_rate)


def _n_frames(n_samples):
    """Number of encoder frames produced for a clip of ``n_samples`` unpadded samples.

    Applies floor((L - kernel) / stride) + 1 for each conv layer in the model config.
    Assumes the feature-encoder convolutions are unpadded; this matches max(shapes) == 49
    for < 1 s clips at 16 kHz. The result is clamped at 0 for clips too short to yield
    a frame.
    """
    for kernel, stride in zip(model.config.conv_kernel, model.config.conv_stride):
        n_samples = (n_samples - kernel) // stride + 1
    return max(n_samples, 0)


# Step over the real split length so the final partial batch is not dropped.
# The old range(32, 35716, 32) stopped at row 35712 and skipped the last 4 clips.
for start in tqdm(range(0, len(train_ds), batch_size)):
    # Read the slice once: each `train_ds[a:b]` access decodes every audio file in it.
    rows = train_ds[start:start + batch_size]
    arrays = [audio["array"] for audio in rows["audio"]]
    files = [Path(path).name for path in rows['file']]
    trans = [name.split('_')[0] for name in files]
    # `max_duration` is not a feature-extractor argument (it was silently ignored);
    # express the cap via max_length + truncation. Clips are zero-padded to the
    # longest one in the batch.
    inputs = feature_extractor(
        arrays,
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding=True,
        max_length=max_samples,
        truncation=True,
    )
    with torch.no_grad():
        embeddings = model(**inputs).last_hidden_state
    embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
    for i, (file_name, prefix) in enumerate(zip(files, trans)):
        out_dir = Path("embeddings") / prefix
        out_dir.mkdir(parents=True, exist_ok=True)
        # Keep only the frames covering this clip's real samples. Frames from batch
        # padding would make the embedding depend on its batch-mates.
        # NOTE(review): padding may still perturb the real frames through the
        # encoder; use batch_size = 1 if exact per-clip embeddings matter.
        n_frames = _n_frames(min(len(arrays[i]), max_samples))
        # clone(): torch.save on a view serializes the whole batch's storage.
        clip_embedding = embeddings[i, :n_frames].clone()
        torch.save(clip_embedding, out_dir / (Path(file_name).stem + '.pt'))


100%|██████████| 1116/1116 [50:45<00:00,  2.73s/it] 


In [None]:
#| eval: false

from glob import glob
import torch
from tqdm import tqdm
from collections import Counter

# Frame count (size of dim 0) of every saved embedding tensor.
# NOTE(review): this globs `embeddings_base/`, while the extraction cell writes to
# `embeddings/`. Confirm which directory is intended.
# `file` stays bound to the last path visited; a later cell loads it.
shapes = list()
for file in tqdm(glob('embeddings_base/*/*.pt')):
    shapes += [torch.load(file).size(0)]

In [None]:
#| eval: false

import numpy as np

# Deciles (0th..90th percentile) of the per-clip frame counts. linspace gives the
# same ten points without np.arange's float-step pitfall, whose output length is
# not guaranteed for non-integer steps.
np.quantile(shapes, np.linspace(0, 0.9, 10))

array([19. , 33. , 38. , 40.7, 42. , 43. , 45. , 46. , 47. , 48. ])

In [None]:
#| eval: false

# Longest embedding, in frames. Presumably this is the target length when
# padding embeddings to a fixed size in the next cell.
max(shapes)

49

In [None]:
#| eval: false

# Demo: zero-pad one embedding along its frame axis (dim 0; its length varies per
# clip, see `shapes`) up to the longest embedding, so clips can be stacked into one
# fixed-size tensor.
# NOTE(review): `file` is the loop variable left over from the shape-collection cell
# (the last path it visited). This cell fails on a fresh kernel if that cell hasn't run.
emb = torch.load(file)
print(emb.shape)
# F.pad reads (left, right) pairs starting from the LAST dim: (0, 0) leaves the
# feature dim alone, and (0, k) appends k zero rows to the end of the frame dim.
padded = torch.nn.functional.pad(emb, (0, 0, 0, max(shapes) - emb.shape[0]), "constant", 0)
# Show the resulting shape rather than dumping the full tensor repr.
padded.shape

torch.Size([46, 768])


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 