---
title: Xvector embeddings generation
output-file: xvector_embeddings.html
description: Generates xvector embeddings from audio streams
---

In [None]:
#| default_exp xvector_embeddings

In [None]:
import json
import os
from annoy import AnnoyIndex
import random
from pathlib import Path
import torch

class AnnoyHandler(object):
    """Wrapper around an ``AnnoyIndex`` that keeps one label per indexed item.

    The label mapping ``id2label`` is persisted next to the index file as a
    JSON sidecar with the same stem (``foo.ann`` -> ``foo.json``). Because JSON
    object keys are always strings, the mapping is keyed by ``str(item_id)``
    both for freshly added items and for mappings loaded from disk.

    Attributes:
        dimentions: Length of the vectors stored in the index.
        index: The underlying ``AnnoyIndex`` (built with the 'angular' metric).
        id2label: Mapping ``str(item_id) -> label``.
        items: Number of labelled items currently tracked.
    """

    def __init__(self, dimentions: int, index_path: str = None):
        """Create an empty index, or load an existing one from ``index_path``.

        Args:
            dimentions: Vector length; must be truthy (non-zero).
            index_path: Optional path of a saved index. When given, the index
                and its JSON label sidecar are both loaded.

        Raises:
            ValueError: If ``dimentions`` is falsy.
        """
        if not dimentions:
            raise ValueError(f"Need dimentions, got '{dimentions}'")
        self.dimentions = dimentions
        self.index = AnnoyIndex(self.dimentions, 'angular')
        self.id2label = {}
        if index_path:
            self.index.load(index_path)
            with open(self._labels_path(index_path), 'r') as f:
                self.id2label = json.load(f)
        # Always defined, also for a brand-new (not loaded) index.
        self.items = len(self.id2label)
        self.index.set_seed(1991)

    @staticmethod
    def _labels_path(index_path: str) -> Path:
        """Return the JSON sidecar path that belongs to ``index_path``."""
        location = Path(index_path)
        return location.parent / (location.stem + '.json')

    def add_item(self, vector, label) -> None:
        """Append ``vector`` to the index and remember its ``label``.

        Item ids are assigned sequentially from the current number of labels.
        """
        i = len(self.id2label)
        # Insert into the index first: if that fails, no dangling label is left.
        self.index.add_item(i, vector)
        # String keys keep this consistent with a mapping reloaded from JSON,
        # which is what ``get_nns`` looks up.
        # NOTE(review): any falsy label (0, '') is stored as None - confirm
        # that this is intended.
        self.id2label[str(i)] = label if label else None
        self.items = len(self.id2label)

    def build(self, trees: int = 10, n_jobs: int = 4):
        """Build the index with ``trees`` trees using ``n_jobs`` threads."""
        return self.index.build(trees, n_jobs=n_jobs)

    def unbuild(self):
        """Unbuild the index (delegates to ``AnnoyIndex.unbuild``)."""
        self.index.unbuild()

    def save(self, path: str):
        """Save the index to ``path`` and the labels to the JSON sidecar."""
        self.index.save(path)
        with open(self._labels_path(path), 'w') as f:
            json.dump(self.id2label, f)

    def get_nns(self, vector, n: int):
        """Return the ``n`` nearest neighbours of ``vector``.

        Returns:
            A tuple ``(indexes, similarities, labels)``. ``similarities`` holds
            the values Annoy returns with ``include_distances=True`` (i.e.
            distances under the index metric), in the same order as
            ``indexes``.
        """
        indexes, similarities = self.index.get_nns_by_vector(vector, n, include_distances=True)
        labels = [self.id2label[str(i)] for i in indexes]
        return indexes, similarities, labels

In [None]:
#| export

from transformers import Wav2Vec2ForXVector
from wav2keyword.audio_processor import AudioProcessor
from pydantic import BaseModel
import numpy as np
from glob import glob
from pathlib import Path
from random import choices
import torch

class AudioArray(BaseModel):
    """Pydantic model holding a single NumPy array in its ``array`` field."""
    # presumably decoded audio samples; shape and dtype are not constrained here
    array: np.ndarray
    class Config:
        # np.ndarray is not a pydantic-native type, so arbitrary types must be
        # explicitly allowed for the field above.
        arbitrary_types_allowed = True

class XvectorModel(object):
    """Computes x-vector embeddings with a ``Wav2Vec2ForXVector`` checkpoint.

    Bundles the pretrained model, the matching ``AudioProcessor`` and an
    ``AnnoyHandler`` sized to the model's embedding dimension.
    """

    def __init__(self, model_checkpoint: str, annoy_index_path: str = None) -> None:
        """Load the model and processor and set up the Annoy handler.

        Args:
            model_checkpoint: Path or identifier passed to ``from_pretrained``
                and to ``AudioProcessor``.
            annoy_index_path: Optional path of a saved Annoy index to load.
        """
        self.model_checkpoint = model_checkpoint
        self.model = Wav2Vec2ForXVector.from_pretrained(self.model_checkpoint)
        self.audio_processor = AudioProcessor(self.model_checkpoint)
        # Take the embedding size from the checkpoint's config instead of
        # assuming it; 512 (the previous hard-coded value) is the fallback.
        self.embeddings_dimention = getattr(self.model.config, 'xvector_output_dim', 512)
        self.annoy_handler = AnnoyHandler(self.embeddings_dimention, annoy_index_path)

    def prepare_raw_audio(self, raw_data: bytes, sample_width: int, channels: int, frame_rate: int):
        """Encode raw audio bytes into model inputs via the audio processor."""
        return self.audio_processor.encode_raw_audio(raw_data, sample_width, channels, frame_rate)

    def get_embeddings(self, inputs):
        """Run the model on ``inputs`` and return its ``embeddings`` output.

        ``inputs`` is unpacked as keyword arguments to the model; gradients
        are disabled for the forward pass.
        """
        with torch.no_grad():
            result = self.model(**inputs).embeddings
        return result

    def get_predicted_labels(self, logits):
        """Return, per row of ``logits``, the argmax of ``logits @ weight``.

        ``weight`` is the ``weight`` parameter of the model's ``objective``
        module, converted to a NumPy array.
        """
        proj = self.model.objective._parameters['weight'].cpu().detach().numpy()
        return np.argmax(np.dot(logits, proj), axis=1)


In [None]:
#| eval: false

# Demo: embed one randomly chosen cached recording with the fine-tuned model.
model_checkpoint = 'data/panda/wav2vec2-base-finetuned-xvector/best_checkpoint/'
xvector_model = XvectorModel(model_checkpoint)

files = glob(f'{Path.home()}/.cache/panda/audios/*.wav')
file = choices(files, k=1)[0]
with open(file, "rb") as f:
    # Read the whole file in one call instead of one byte at a time.
    audio_bytes = bytearray(f.read())
# NOTE(review): the entire .wav file (header included) is passed as raw audio
# with sample_width=2, channels=2, frame_rate=16000 - confirm this matches the data.
encoded_data = xvector_model.prepare_raw_audio(audio_bytes, 2, 2, 16000)
embeddings = xvector_model.get_embeddings(encoded_data)
print(embeddings)

tensor([[-3.5255e-04, -3.4512e-03, -3.3339e-04,  ...,  9.5668e-04,
          2.7784e-03, -9.6302e-04],
        [ 8.0288e-04, -1.0105e-03, -2.0560e-04,  ..., -4.4576e-04,
         -4.0969e-04,  2.0239e-03],
        [ 1.3787e-04, -4.7847e-04,  6.4930e-04,  ...,  9.0482e-05,
          2.8542e-04,  9.5453e-04],
        [ 8.3139e-04, -1.1354e-03,  1.1984e-03,  ..., -9.5183e-05,
          3.6603e-04,  8.8071e-05]])
