In [33]:
from transformers import Wav2Vec2Processor, Wav2Vec2Model, AutoProcessor, WavLMModel, Wav2Vec2FeatureExtractor
import torch
import librosa
import os
import yaml
import importlib

In [34]:
def get_class_by_name(module_name, class_name):
    """Import ``module_name`` and return its attribute named ``class_name``.

    Raises:
        ModuleNotFoundError: if ``module_name`` cannot be imported.
        AttributeError: if the module has no attribute ``class_name``.
    """
    return getattr(importlib.import_module(module_name), class_name)



get_class_by_name(module_name='transformers', class_name='Wav2Vec2Model')

transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model

In [28]:
def load_yaml(file_path: str) -> dict:
    """Load a YAML file with ``yaml.safe_load``.

    ``safe_load`` only constructs plain YAML types (no arbitrary Python
    objects), so it is safe to use on config files.

    Args:
        file_path: Path to the YAML file.

    Returns:
        The parsed document, or an empty dict if the file has no content.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = yaml.safe_load(file)
    # safe_load returns None for an empty document; keep the dict contract.
    return {} if data is None else data


# Project configuration; later cells read VARS['emotion_mapping'] and VARS['audio_models'].
YAML_PATH = 'params.yaml'
VARS = load_yaml(file_path=YAML_PATH)

In [1]:
# Scratch check: print every (i, j) pair drawn from {0, 1} x {0, 1}, row-major.
for i in (0, 1):
    for j in (0, 1):
        print(i, j)

0 0
0 1
1 0
1 1


In [17]:
MODEL_NAME = 'microsoft/wavlm-large'
# Feature extractor and backbone are both loaded from the same checkpoint.
PROCESSOR, MODEL = (
    Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME),
    WavLMModel.from_pretrained(MODEL_NAME),
)

In [30]:
# Candidate audio backbones from params.yaml (processor class, model class, checkpoint name).
audio_models = VARS['audio_models']
audio_models

[{'processor': 'Wav2Vec2FeatureExtractor',
  'model': 'Wav2Vec2Model',
  'name': 'jonatasgrosman/wav2vec2-large-xlsr-53-english'},
 {'processor': 'Wav2Vec2FeatureExtractor',
  'model': 'WavLMModel',
  'name': 'microsoft/wavlm-large'},
 {'processor': 'Wav2Vec2FeatureExtractor',
  'model': 'HubertModel',
  'name': 'facebook/hubert-large-ls960-ft'}]

In [18]:






def get_single_audio_embedding(
    file_path: str,
    *,
    target_sr: int = 16000,
    processor=None,
    model=None,
    emotion_mapping=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Embed one audio file and look up its emotion label.

    The waveform is loaded with librosa at ``target_sr``, run through the
    feature extractor and model, and the model's ``last_hidden_state`` is
    mean-pooled over dim 1 to give one vector per file.

    The emotion alias is the third ``_``-separated field of the file name
    (e.g. ``1001_DFA_ANG_XX.wav`` -> ``ANG``) and is mapped to a label via
    ``emotion_mapping``.

    Args:
        file_path: Path to the audio file.
        target_sr: Sampling rate passed to ``librosa.load`` (default 16000).
        processor: Feature extractor; defaults to the global ``PROCESSOR``.
        model: Audio model; defaults to the global ``MODEL``.
        emotion_mapping: Alias -> label mapping; defaults to
            ``VARS['emotion_mapping']``.

    Returns:
        ``(embedding, label)``: the pooled hidden state with the leading
        size-1 dim squeezed out, and ``torch.tensor(label)``.
    """
    # Resolve defaults at call time so re-assigning the globals (e.g. loading
    # another checkpoint from VARS['audio_models']) is picked up.
    processor = PROCESSOR if processor is None else processor
    model = MODEL if model is None else model
    if emotion_mapping is None:
        emotion_mapping = VARS['emotion_mapping']

    waveform, sample_rate = librosa.load(file_path, sr=target_sr)
    # basename is OS-independent, unlike splitting the path on '/'.
    file_emo_alias = os.path.basename(file_path).split('_')[2]
    label = emotion_mapping[file_emo_alias]
    inputs = processor(waveform, sampling_rate=sample_rate, return_tensors='pt', padding=True)

    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state
    # Mean over dim 1 — presumably the frame/time axis of (batch, frames, hidden); TODO confirm.
    global_embedding = torch.mean(last_hidden_state, dim=1)

    return global_embedding.squeeze(0), torch.tensor(label)


def get_all_audio_embeddings(
    root: str,
    exclude: tuple[str, ...] = ('1076_MTI_SAD_XX.wav',),
) -> dict[torch.Tensor, torch.Tensor]:
    """Embed every ``.wav`` file directly inside ``root``.

    Files are processed in sorted path order with ``get_single_audio_embedding``.

    Args:
        root: Directory to scan (not recursive).
        exclude: File names to skip. The default keeps the single file the
            original code hard-coded as excluded (presumably a bad clip —
            TODO document why). Pass ``()`` to include everything.

    Returns:
        Dict mapping each embedding tensor to its label tensor. NOTE:
        ``torch.Tensor`` hashes by object identity, so the entries can be
        iterated (``.items()``) but not looked up with an equal-valued tensor.
    """
    skipped = frozenset(exclude)
    paths = sorted(
        os.path.join(root, name)
        for name in os.listdir(root)
        if name.endswith('.wav') and name not in skipped
    )

    embeddings = {}
    for path in paths:
        embedding, label = get_single_audio_embedding(path)
        embeddings[embedding] = label
    return embeddings

In [20]:
sample = get_single_audio_embedding('data/audio_data/1001_DFA_ANG_XX.wav')
# sample is (embedding, label); show the pooled embedding.
sample_embedding, sample_label = sample
sample_embedding

tensor([-0.1441,  0.1018, -0.0264,  ...,  0.0402, -0.0409, -0.1608])

In [None]:
from transformers import WavLMModel  # transformers has no `WavLM` symbol; the model class is WavLMModel