In [None]:
import os

import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

import torch
import torch.nn.functional as F
from torchcodec.decoders import AudioDecoder
from torchinfo import summary

from vox_profile_release.src.model.emotion.whisper_emotion import WhisperWrapper

In [None]:
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', 600)

# CUDA caching-allocator options. Binding a Python variable has no effect on its own:
# PyTorch only reads the PYTORCH_CUDA_ALLOC_CONF *environment variable*, in
# "key:value" string form, and only if it is set before the first CUDA allocation.
# setdefault() keeps any value already exported from the shell.
PYTORCH_CUDA_ALLOC_CONF = 'expandable_segments:True'
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', PYTORCH_CUDA_ALLOC_CONF)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE(review): torch_dtype is not applied to the model in the visible cells (it is only
# moved with .to(device)) — confirm whether half precision was intended.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Last expression of the cell, so the selected device is actually displayed.
device

In [None]:
# Audio front-end settings: mono input at 16 kHz, each segment fitted to a fixed window.
SAMPLE_RATE = 16_000                  # Hz
DURATION = 15                         # seconds per model input window
CHANNELS = 1                          # mono
MAX_LENGTH = DURATION * SAMPLE_RATE   # samples per model input window

In [None]:
# Source recording, and the transcript whose rows are unpacked below as (start, end, text).
rec_path, seg_path = (
    'inputs/rec_pre.mp3',
    'outputs/transcript.parquet',
)

In [None]:
# Transcript segments; the decoding cell iterates them positionally as (start, end, text).
segments = pd.read_parquet(path=seg_path)
segments

In [None]:
# Load the fine-tuned emotion checkpoint, move it to the selected device and switch to
# eval mode explicitly, so dropout and similar training-only layers are disabled for
# inference. eval() returns the module itself, so it chains.
model = WhisperWrapper.from_pretrained('tiantiaf/whisper-large-v3-msp-podcast-emotion').to(device).eval()

In [None]:
# Layer-by-layer parameter overview of the loaded model (no forward pass is run).
summary(model=model)

In [None]:
def fit_to_length(waveform: torch.Tensor, length: int = MAX_LENGTH) -> torch.Tensor:
    """Truncate or right-pad ``waveform`` with zeros to exactly ``length`` samples.

    Args:
        waveform: audio tensor whose last dimension is time.
        length: target number of samples (defaults to ``MAX_LENGTH``).

    Returns:
        A tensor with the same leading dims, device and dtype, and ``length`` samples.

    ``F.pad`` creates the padding on the waveform's own device and dtype. Concatenating
    a default ``torch.zeros(...)``, which is always CPU, onto a CUDA tensor raises a
    device-mismatch error.
    """
    clipped = waveform[..., :length]
    return F.pad(clipped, (0, length - clipped.shape[-1]))


decoder = AudioDecoder(rec_path, sample_rate=SAMPLE_RATE, num_channels=CHANNELS)

samples = []
for start, end, text in segments.to_numpy():
    # .data is (channels, samples); with CHANNELS == 1, squeeze(0) leaves a 1-D waveform.
    audio = decoder.get_samples_played_in_range(start, end).data.squeeze(0)
    samples.append({'audio': fit_to_length(audio).to(device), 'start': start, 'end': end, 'text': text})

assert len(samples) == len(segments)

In [None]:
# Spot-check the final decoded entry: fixed-length audio tensor plus its start/end/text.
samples[len(samples) - 1]

In [None]:
# NOTE(review): the last segment is excluded ([:-1]) as in the original — confirm intent.
audio_batch_list = [s['audio'] for s in samples[:-1]]
assert audio_batch_list, 'need at least one segment to score'

# Every waveform was fitted to MAX_LENGTH, so they stack into one (batch, MAX_LENGTH)
# tensor rather than being passed as a Python list.
# NOTE(review): assumes WhisperWrapper.forward takes a (batch, samples) tensor, as in
# the vox-profile usage example, and scores every row — confirm.
audio_batch = torch.stack(audio_batch_list)

# Inference only: without no_grad, autograd keeps every activation of the large Whisper
# encoder alive for the whole batch, which inflates GPU memory substantially.
with torch.no_grad():
    logits, embedding, _, _, _, _ = model(audio_batch, return_feature=True)

In [None]:
# Class names, indexed by logit column.
# NOTE(review): order assumed to match the checkpoint's label order — confirm against
# the 'tiantiaf/whisper-large-v3-msp-podcast-emotion' model card.
emotion_list = [
    'Anger',
    'Contempt',
    'Disgust',
    'Fear',
    'Happiness',
    'Neutral',
    'Sadness',
    'Surprise',
    'Other',
]

# Per-segment class probabilities: logits are (batch, num_classes), normalised over dim=1.
emotion_prob = F.softmax(logits, dim=1)
assert emotion_prob.shape[1] == len(emotion_list), 'logit width does not match emotion_list'

# argmax must be taken per row. Without `dim`, torch.argmax flattens the whole batch and
# returns an index into batch * num_classes: a wrong label, or an IndexError, once more
# than one segment is scored.
predicted_idx = emotion_prob.argmax(dim=1).cpu().tolist()
predicted_emotions = [emotion_list[i] for i in predicted_idx]
predicted_emotions