In [1]:
import os, json
import regex as re
import numpy as np
import pandas as pd
import torch
from transformers import (
    WhisperProcessor,
    # WhisperForConditionalGeneration,
    WhisperModel,
)
from pprint import pprint

from train import load_models
from utils import WhisBERTConfig, preprocess_audio

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the Whisper processor (its `.tokenizer` is used for the transcripts) and
# the WhisperModel encoder-decoder for `whisper_name`. Both downloads share one
# cache directory. It defaults to the original shared path and can be
# overridden with the WHISPER_CACHE_DIR environment variable, so the notebook
# runs on machines without that mount.
whisper_name = "openai/whisper-small"
CACHE_DIR = os.environ.get('WHISPER_CACHE_DIR', '/cronus_data/rrao/cache/')
whisper_processor = WhisperProcessor.from_pretrained(whisper_name, cache_dir=CACHE_DIR)
whisper_model = WhisperModel.from_pretrained(whisper_name, cache_dir=CACHE_DIR)

In [3]:
# Fixed four-example batch. Each audio segment filename is paired positionally
# with its transcript. All strings are kept verbatim, including the leading
# space on every transcript.
# NOTE(review): 'PP636_segment.wav' breaks the 'P<digits>_segment.wav' pattern
# of the other names. Confirm it is not a typo for 'P636_segment.wav'.
segment_text_pairs = [
    ('P209_segment.wav', ' Nothing different, the same.'),
    ('P360_segment.wav', ' Yeah, like I like men.'),
    ('P443_segment.wav', " Biology doesn't change what the mind says."),
    ('PP636_segment.wav', ' esta man, bisexual.'),
]
batch = dict(
    segment_filename=[filename for filename, _text in segment_text_pairs],
    text=[text for _filename, text in segment_text_pairs],
)

In [5]:
# Preprocess every segment listed in `batch` and concatenate the per-file
# feature tensors along dim 0 into one batch tensor. The cell displays the
# resulting shape. Each file presumably contributes one row along dim 0; confirm
# against preprocess_audio.
SAMPLES_DIR = '/cronus_data/rrao/samples'
segment_paths = [os.path.join(SAMPLES_DIR, name) for name in batch['segment_filename']]
inputs = torch.cat([preprocess_audio(whisper_processor, path) for path in segment_paths], dim=0)
inputs.shape

torch.Size([4, 80, 3000])

In [6]:
# Pooling strategy used to turn per-token decoder states into one vector per
# transcript. 'cls' takes the state at the last real (non-padding) token; any
# other value averages over the real tokens.
token_type = 'cls'


def pool_hidden_states(hidden_states, attention_mask, token_type='cls'):
    """Reduce per-token hidden states to one vector per sequence.

    Args:
        hidden_states: float tensor of shape (batch, seq, dim).
        attention_mask: 0/1 tensor of shape (batch, seq); 1 marks real tokens.
        token_type: 'cls' selects the hidden state at the last position whose
            mask is 1. Any other value mean-pools over positions whose mask is 1.

    Returns:
        Tensor of shape (batch, dim). The batch axis is kept even when batch == 1.
    """
    if token_type == 'cls':
        positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
        # Largest position with mask == 1. This is correct for both right- and
        # left-padded batches. A row with no real tokens falls back to position 0.
        last_idx = (attention_mask.long() * positions).max(dim=1).values
        rows = torch.arange(hidden_states.size(0), device=hidden_states.device)
        # Two 1-D index tensors give (batch, dim) directly; no squeeze() needed,
        # so a batch of size 1 is not collapsed to (dim,).
        return hidden_states[rows, last_idx]

    mask = attention_mask.to(hidden_states.dtype).unsqueeze(-1)  # (batch, seq, 1)
    summed = (hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)  # avoid division by zero on empty rows
    return summed / counts


whisper_model.eval()

with torch.no_grad():
    outputs = whisper_processor.tokenizer(
        batch['text'],
        padding=True,
        truncation=True,
        # The decoder's learned position table has only
        # config.max_target_positions rows. Truncating at a larger fixed value
        # (previously 512) could produce decoder inputs the model cannot embed.
        max_length=whisper_model.config.max_target_positions,
        return_tensors='pt',
    )
    embs = whisper_model(
        inputs,
        decoder_input_ids=outputs['input_ids'],
        decoder_attention_mask=outputs['attention_mask'],
    ).last_hidden_state
    embs = pool_hidden_states(embs, outputs['attention_mask'], token_type)

print(embs.shape)

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


torch.Size([4, 768])


In [5]:
# Build a WhisBERT configuration ('cls' pooling, without the new encoder layers)
# and load the model. load_models returns a 4-tuple; only the second element
# (the WhisBERT model) is used here.
# Unpack into named throwaways rather than `_`: in IPython, assigning to `_`
# shadows the automatic last-output variable and stops it from updating.
# NOTE(review): the meaning of the '' second argument is not visible here
# (checkpoint path?). Confirm against train.load_models.
config = WhisBERTConfig(token_type='cls', use_new_encoder_layers=False)
_ret0, whisbert_model, _ret2, _ret3 = load_models(config, '')

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
# Run WhisBERT in eval mode without gradients on the `inputs` audio features
# and the `batch['text']` transcripts (tokenized with the Whisper tokenizer),
# then print the shape of the result.
# NOTE(review): max_length=512 is not tied to any WhisBERT config value here.
# Confirm it fits the model's text branch.
tokenizer_kwargs = dict(padding=True, truncation=True, max_length=512, return_tensors='pt')

whisbert_model.eval()

with torch.no_grad():
    outputs = whisper_processor.tokenizer(batch['text'], **tokenizer_kwargs)
    text_ids = outputs['input_ids']
    text_mask = outputs['attention_mask']
    whis_embs = whisbert_model(inputs, text_input_ids=text_ids, text_attention_mask=text_mask)

print(whis_embs.shape)

torch.Size([4, 768])
