# Model Training for ASR Based Speech Emotion Recogntion

## Import Libraries

In [39]:
import os
import json
import librosa
import numpy as np
from collections import defaultdict

import torch 
from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

## Prepare and Perform ASR on Datasets

### Define Filepaths

In [40]:
train_dir = './Train'
valid_dir = './Valid'
emotion2idx_filename = './emotion2idx.json'
idx2emotion_filename = './idx2emotion.json'

### Prepare Label Dictionaries

In [41]:
emotion2idx_dict = defaultdict(list)
idx2emotion_dict = defaultdict(list)

idx = 0

for _, filename in enumerate(os.listdir(train_dir)):
    emotion = filename.split('_')[2]
    
    if emotion not in emotion2idx_dict:
        emotion2idx_dict[emotion] = idx
        idx2emotion_dict[idx] = emotion
        idx += 1

with open(emotion2idx_filename, 'w') as json_file:
    json.dump(emotion2idx_dict, json_file)

with open(idx2emotion_filename, 'w') as json_file:
    json.dump(idx2emotion_dict, json_file)

### Load Label Dictionaries

In [42]:
with open(emotion2idx_filename, 'r') as json_file:
    emotion2idx_dict = json.load(json_file)

with open(idx2emotion_filename, 'r') as json_file:
    idx2emotion_dict = json.load(json_file)

### Load ASR Model and Processor

In [43]:
asr_model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
asr_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")

Some weights of Speech2TextForConditionalGeneration were not initialized from the model checkpoint at facebook/s2t-small-librispeech-asr and are newly initialized: ['model.encoder.embed_positions.weights', 'model.decoder.embed_positions.weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Perform ASR on Train and Validation Data

In [45]:
def speech_to_text(data_dir, emotion2idx_dict, sample_rate, asr_model, asr_processor):
    filenames = os.listdir(data_dir)
    inputs = []
    labels = []

    for filename in filenames:
        # load the speech file
        speech_audio, _ = librosa.load(os.path.join(data_dir, filename), sr = sample_rate)

        # perform ASR
        input_features = asr_processor(speech_audio, sampling_rate=sample_rate, return_tensors="pt").input_features
        generated_ids = asr_model.generate(input_features=input_features)
        transcription = asr_processor.batch_decode(generated_ids)
        
        # append to inputs and labels
        inputs.append(transcription)
        labels.append(emotion2idx_dict[filename.split('_')[2]])
        
    return filenames, inputs, labels

In [46]:
train_filenames, train_inputs, train_labels = speech_to_text(train_dir, emotion2idx_dict, 16_000, asr_model, asr_processor)
valid_filenames, valid_inputs, valid_labels = speech_to_text(valid_dir, emotion2idx_dict, 16_000, asr_model, asr_processor)

  x = np.divide(x, std)


In [48]:
idx = 100

print(train_filenames[idx])
print(train_inputs[idx])
print(train_labels[idx])

1033_ITS_SAD_XX.wav
["</s> i think i've seen this before</s>"]
3


### Define Dataset

In [None]:
class Speech_Text_Dataset(Dataset):
    def __init__(self, data_dir, emotion2idx_dict, sample_rate, asr_model, asr_processor):
        self.data_dir = data_dir
        self.emotion2idx_dict = emotion2idx_dict
        self.sample_rate = sample_rate
        self.asr_model = asr_model
        self.asr_processor = asr_processor

        self.filenames = os.listdir(data_dir)
        self.labels = [self.emotion2idx_dict[filename.split('_')[2]] for filename in self.filenames]

    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        filename = self.filenames[idx]
        label = self.labels[idx]

        # load the speech file
        speech_audio, _ = librosa.load(os.path.join(self.data_dir, filename), sr = self.sample_rate)

        # perform ASR
        input_features = asr_processor(speech_audio, sampling_rate=self.sample_rate, return_tensors="pt").input_features
        generated_ids = asr_model.generate(input_features=input_features)
        transcription = asr_processor.batch_decode(generated_ids)
        
        return speech_audio , label