In [1]:
# Install required libraries: PyTorch and torchaudio, Hugging Face transformers,
# and Hugging Face datasets with its `audio` extra.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# The kernel may need to be restarted before freshly installed packages can be imported.
# NOTE(review): versions are unpinned, so a later re-run may resolve to different
# releases than the ones this notebook was written against — consider pinning.
%pip install torch torchaudio transformers datasets[audio]

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Import required libraries
import torch
from datasets import Audio, load_dataset
from transformers import ClapModel, ClapProcessor

In [3]:
# Load dataset
audio_dataset = load_dataset('TrainingDataPro/speech-emotion-recognition-dataset', split='train')

# Use the first GPU when one is visible, otherwise fall back to CPU so the
# cell still runs (slowly) on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load CLAP model
model = ClapModel.from_pretrained('laion/larger_clap_music').to(DEVICE)
processor = ClapProcessor.from_pretrained('laion/larger_clap_music')

# Preprocess audio samples
# Each entry in `emotions` is the name of a dataset column holding one audio
# recording per row.
emotions = ['euphoric', 'joyfully', 'sad', 'surprised']
sample_inputs = []
audio_embeddings = []

# Sampling rate (Hz) the audio is decoded at and declared to the processor.
SAMPLING_RATE = 48000

# Passing `sampling_rate` to the processor only *declares* the rate of the
# waveform; it does not resample. Casting the columns makes `datasets`
# resample on decode, so the arrays really are at SAMPLING_RATE.
# NOTE(review): the native rate of the recordings is not visible here — if they
# are already 48 kHz this cast is a no-op.
for emotion in emotions:
    audio_dataset = audio_dataset.cast_column(emotion, Audio(sampling_rate=SAMPLING_RATE))

# Build one processor output per recording and move it to the model's device.
for emotion in emotions:
    for sample in audio_dataset[emotion]:
        sample_inputs.append(
            processor(
                audios=sample['array'],
                sampling_rate=SAMPLING_RATE,
                return_tensors='pt',
            ).to(DEVICE)
        )

# Inference only: without no_grad every embedding would keep its autograd
# graph (and the activations it references) alive for as long as it is stored.
with torch.no_grad():
    for inputs in sample_inputs:
        audio_embeddings.append(model.get_audio_features(**inputs))

In [10]:
# Number of audio embeddings collected so far. Left as the cell's last
# expression so the notebook displays it as the cell output.
len(audio_embeddings)

[tensor([[-0.0197, -0.0025,  0.1133, -0.0487, -0.0336,  0.0309, -0.0527, -0.0409,
         -0.0400,  0.0530,  0.0382, -0.0160, -0.0892,  0.0099,  0.0168, -0.0374,
          0.0693,  0.1020, -0.0434, -0.0180,  0.0379,  0.0314,  0.0399, -0.0525,
         -0.0670,  0.0361, -0.0136,  0.0019, -0.0383,  0.0466,  0.0200,  0.1150,
         -0.0810, -0.0564, -0.0013, -0.0519, -0.0469,  0.0269, -0.0407,  0.0268,
          0.0948,  0.0118,  0.0049, -0.0013, -0.0642, -0.0298, -0.0566,  0.0477,
         -0.0628, -0.0311, -0.0386,  0.0205,  0.0002,  0.0010, -0.0334,  0.0164,
          0.0053, -0.0440,  0.0240,  0.0347,  0.0721, -0.0308, -0.0108, -0.0599,
         -0.0632, -0.0261, -0.0123,  0.0508, -0.0534, -0.0206,  0.0115,  0.0196,
          0.0337,  0.0279, -0.0052,  0.0801,  0.0253,  0.0027, -0.0572,  0.0235,
          0.0050,  0.0133,  0.0346, -0.0588,  0.0151, -0.1037, -0.0492, -0.0265,
         -0.0821, -0.0077,  0.0843, -0.0015,  0.0155, -0.0064,  0.0309,  0.0489,
         -0.0121, -0.0268, 