In [51]:
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer, AutoModel, AutoFeatureExtractor
from IPython.display import Audio, display
import librosa
import soundfile as sf
import os

In [45]:
# Load the Wav2Vec2.0 tokenizer and model
tokenizer = Wav2Vec2Tokenizer.from_pretrained('facebook/wav2vec2-base-960h')
model = AutoModel.from_pretrained('facebook/wav2vec2-base-960h')
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'Wav2Vec2CTCTokenizer'. 
The class this function is called from is 'Wav2Vec2Tokenizer'.
Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading (…)rocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

## Resampling the track to 16000 and splitting the track into 10 second small tracks

In [7]:
sample_rate = 16000

In [8]:
x, sr = librosa.load('BorisBrejcha.wav', sr=sample_rate)

In [25]:
len(x)

53207876

In [9]:
number_of_seconds = 10

In [39]:
for i in range(0, len(x), number_of_seconds * sr):
    y = x[i: number_of_seconds * sr + (i)]
    sf.write("dest_audio"+str(i)+".wav", y, sr)

In [40]:
waveform, sample_rate = torchaudio.load('dest_audio160000.wav')
waveform = waveform.squeeze().numpy()

waveform.shape

(160000,)

In [41]:
display(Audio(waveform, rate=sample_rate))
print()




## Embeddings generator

In [46]:
input_values = feature_extractor(waveform, return_tensors="pt")

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [64]:
with torch.no_grad():
    embeddings = model(**input_values).last_hidden_state
    embeddings = embeddings.mean(dim=1)
    print(embeddings.shape)

torch.Size([1, 768])


array([[-6.35363311e-02,  2.18565725e-02, -3.93589847e-02,
        -3.07763238e-02, -2.64905635e-02, -9.55875516e-02,
         4.59514447e-02, -2.32393071e-02, -2.71134917e-02,
        -3.86615992e-01,  4.97092605e-02, -2.69201957e-02,
         4.70911190e-02,  7.03526363e-02, -7.70962834e-02,
        -4.04429249e-03, -2.44204447e-01,  3.14419657e-01,
         1.45180384e-02,  2.09609456e-02, -1.68319151e-01,
         1.24339640e-01,  8.73123556e-02,  1.01502491e-02,
         1.34945884e-01,  7.13615306e-03, -3.55252266e-01,
         4.78337891e-02,  2.07347888e-02, -1.29381657e-01,
         1.28535450e-01, -1.15841217e-02, -2.20219791e-02,
        -6.84324726e-02, -1.95385009e-01,  4.56412844e-02,
         8.45474377e-02, -2.70121187e-01, -1.10124044e-01,
         8.81849602e-02, -1.21737674e-01, -1.13990650e-01,
        -9.80903283e-02,  1.71214595e-01, -1.01497218e-01,
         9.20501724e-02, -4.13665213e-02, -1.06576808e-01,
         1.34882918e-02,  3.81322135e-03, -1.28232196e-0

In [66]:
embeddings.shape[1]

768

## Generate embeddings for all the small portions of the track and store it in FAISS

In [67]:
import faiss                   # make faiss available
index = faiss.IndexFlatL2(embeddings.shape[1])

In [None]:
for files in os.listdir():
    if 'dest_audio' in files:
        waveform, sample_rate = torchaudio.load(files)
        waveform = waveform.squeeze().numpy()
        input_values = feature_extractor(waveform, return_tensors="pt")
        embeddings = model(**input_values).last_hidden_state
        embeddings = embeddings.mean(dim=1)
        index.add(embeddings.detach().numpy())

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_ra

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_ra