In [1]:
import os
from collections import OrderedDict
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModel,
    WhisperProcessor,
    WhisperModel
)
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from pprint import pprint

from train import load_models, train
from utils import WhiSBERTConfig, AudioDataset, collate, mean_pooling

CACHE_DIR = '/cronus_data/rrao/cache'
CHECKPOINT_DIR = '/cronus_data/rrao/WhisBERT/models/'

In [2]:
config = WhiSBERTConfig(pooling_mode='cls', use_sbert_layers=False, batch_size=8, shuffle=False, device='cuda')
processor, whisbert, tokenizer, sbert = load_models(config, '')




Available GPU IDs: [0, 1, 2, 3]
	GPU 0: NVIDIA RTX A6000
	GPU 1: NVIDIA RTX A6000
	GPU 2: NVIDIA RTX A6000
	GPU 3: NVIDIA RTX A6000



In [3]:
print('Preprocessing AudioDataset...')
dataset = AudioDataset('/cronus_data/rrao/wtc_clinic/whisper_segments_transripts.csv', processor)
mini_size = int(0.1 * len(dataset))
drop_size = len(dataset) - mini_size
mini_dataset, _ = torch.utils.data.random_split(dataset, [mini_size, drop_size])

# Calculate lengths for the train/val split (80:20)
total_size = len(mini_dataset)
train_size = int(0.8 * total_size)  # 80% for training
val_size = total_size - train_size  # 20% for validation
# Perform the split
train_dataset, val_dataset = torch.utils.data.random_split(mini_dataset, [train_size, val_size])
print(f'\tTotal dataset size (N): {total_size}')
print(f'\tTraining dataset size (N): {train_size}')
print(f'\tValidation dataset size (N): {val_size}')

Preprocessing AudioDataset...
	Total dataset size (N): 15518
	Training dataset size (N): 12414
	Validation dataset size (N): 3104


In [4]:
data_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=config.shuffle,
    collate_fn=collate
)

In [5]:
batch = next(iter(data_loader))
print(batch['audio_inputs'].shape)
print(len(batch['text']))

torch.Size([8, 80, 3000])
8


In [6]:
# Whisper-based tokenization
with torch.no_grad():
    outputs = processor.tokenizer(
        batch['text'],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt'
    ).to(config.device)

In [7]:
whis_embs = whisbert(
    batch['audio_inputs'].to(config.device),
    outputs['input_ids'],
    outputs['attention_mask']
)
whis_embs.shape

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


torch.Size([8, 768])

In [9]:
sbert_model_id = 'sentence-transformers/all-mpnet-base-v2'
tokenizer = AutoTokenizer.from_pretrained(sbert_model_id, cache_dir=CACHE_DIR)
sbert_model = AutoModel.from_pretrained(sbert_model_id, cache_dir=CACHE_DIR).to(config.device)



In [10]:
encoded_input = tokenizer(batch['text'], padding=True, truncation=True, return_tensors='pt').to(config.device)
encoded_input['input_ids'].shape

torch.Size([8, 32])

In [18]:
with torch.no_grad():
    sbert_output = sbert_model(**encoded_input)
sentence_embeddings = mean_pooling(sbert_output.last_hidden_state, encoded_input['attention_mask'])
print(sentence_embeddings.shape)

RuntimeError: expand(torch.cuda.LongTensor{[8, 32, 1]}, size=[32, 768]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)

In [24]:
embedding_output = sbert_model.embeddings(input_ids=encoded_input['input_ids'], attention_mask=encoded_input['attention_mask'])
embedding_output.shape

torch.Size([8, 17, 768])

In [41]:
outputs['attention_mask'].shape

torch.Size([8, 18])

In [42]:
whisper_embs.shape

torch.Size([8, 18, 768])

In [39]:
extended_attention_mask = sbert_model.get_extended_attention_mask(outputs['attention_mask'], whisper_embs.size()[:-1])
head_mask = [None] * sbert_model.config.num_hidden_layers

In [43]:
encoder_output = sbert_model.encoder(whisper_embs, attention_mask=extended_attention_mask, head_mask=head_mask)[0]
encoder_output.shape

torch.Size([8, 18, 768])

In [44]:
pooled_output = sbert_model.pooler(encoder_output)
pooled_output.shape

torch.Size([8, 768])

In [2]:
# Load the processor and model
whisper_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_name, cache_dir='/cronus_data/rrao/cache/')
# whisper_generator = WhisperForConditionalGeneration.from_pretrained(whisper_name, cache_dir='/cronus_data/rrao/cache/')
whisper_model = WhisperModel.from_pretrained(whisper_name, cache_dir='/cronus_data/rrao/cache/')

In [3]:
batch = {
    'segment_filename': [
        'P209_segment.wav',
        'P360_segment.wav',
        'P443_segment.wav',
        'PP636_segment.wav'
    ],
    'text': [
        ' Nothing different, the same.',
        ' Yeah, like I like men.',
        " Biology doesn't change what the mind says.",
        ' esta man, bisexual.'
    ]
}

In [5]:
inputs = torch.cat([preprocess_audio(whisper_processor, os.path.join('/cronus_data/rrao/samples', audio_filename)) for audio_filename in batch['segment_filename']], dim=0)
inputs.shape

torch.Size([4, 80, 3000])

In [6]:
token_type = 'cls'

# Put models in evaluation mode
# whisper_generator.eval()
whisper_model.eval()

with torch.no_grad():
    # generated_ids = whisper_generator.generate(inputs)
    outputs = whisper_processor.tokenizer(
        batch['text'],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt'
    )
    embs = whisper_model(inputs, decoder_input_ids=outputs['input_ids'], decoder_attention_mask=outputs['attention_mask']).last_hidden_state
    
    if token_type == 'cls':
        non_padding_indices = outputs['attention_mask'].cumsum(dim=1) - 1
        last_non_padding_indices = non_padding_indices.gather(1, (outputs['attention_mask'].sum(dim=1, keepdim=True) - 1).clamp(min=0).long())
        embs = embs[torch.arange(outputs['attention_mask'].size(0)).unsqueeze(1), last_non_padding_indices].squeeze()
    else:
        sum_embs = (embs * outputs['attention_mask'].unsqueeze(-1).expand(embs.size())).sum(dim=1)
        non_padding_counts = outputs['attention_mask'].sum(dim=1).unsqueeze(-1).clamp(min=1)
        embs = sum_embs / non_padding_counts
    
print(embs.shape)

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


torch.Size([4, 768])


In [5]:
config = WhisBERTConfig(token_type='cls', use_new_encoder_layers=False)
_, whisbert_model, _, _ = load_models(config, '')

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
whisbert_model.eval()

with torch.no_grad():
    outputs = whisper_processor.tokenizer(
        batch['text'],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt'
    )
    whis_embs = whisbert_model(inputs, text_input_ids=outputs['input_ids'], text_attention_mask=outputs['attention_mask'])

print(whis_embs.shape)

torch.Size([4, 768])
