In [1]:
import time
from datetime import timedelta
import os
import argparse
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModel,
    WhisperProcessor
)
from sentence_transformers import SentenceTransformer
from collections import OrderedDict
import pandas as pd
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from pprint import pprint

from config import WhiSBERTConfig, CACHE_DIR, CHECKPOINT_DIR
from model import WhiSBERTModel
from data import AudioDataset, collate_train
from train import load_models
from utils import (
    mean_pooling,
    cos_sim_loss,
    sim_clr_loss,
    norm_temp_ce_loss
)

# Turn off HuggingFace `tokenizers` internal parallelism for this process
# (avoids the fork-after-parallelism warning when DataLoader workers are spawned).
os.environ.update(TOKENIZERS_PARALLELISM="false")



In [2]:
# Small CPU-only configuration used for interactive inspection of the WhiSBERT pipeline.
config_kwargs = dict(
    whisper_model_id='openai/whisper-base',
    pooling_mode='mean',
    n_new_dims=7,
    batch_size=8,
    shuffle=False,
    device='cpu',
)
config = WhiSBERTConfig(**config_kwargs)

# NOTE(review): the second argument to load_models is '' — presumably "no checkpoint to
# restore" (CHECKPOINT_DIR is imported above); confirm against train.load_models.
processor, whisbert, tokenizer, sbert = load_models(config, '')



In [3]:
# Build a reproducible subset of the full AudioDataset, then split that subset into
# train/validation. A single seeded generator drives both random_split calls, so the
# selected samples are identical across kernel restarts (Restart & Run All).
SPLIT_SEED = 42
SUBSET_FRACTION = 0.1  # share of the full dataset kept for this notebook
TRAIN_FRACTION = 0.8   # share of the subset used for training; the rest is validation

print('Preprocessing AudioDataset...')
dataset = AudioDataset(processor)
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

mini_size = int(SUBSET_FRACTION * len(dataset))
drop_size = len(dataset) - mini_size
mini_dataset, _ = torch.utils.data.random_split(
    dataset, [mini_size, drop_size], generator=split_generator
)

# Calculate lengths for the train/val split (80:20 by default).
# NOTE: "Total" below is the size of the subset, not of the full AudioDataset.
total_size = len(mini_dataset)
train_size = int(TRAIN_FRACTION * total_size)
val_size = total_size - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    mini_dataset, [train_size, val_size], generator=split_generator
)
print(f'\tTotal dataset size (N): {total_size}')
print(f'\tTraining dataset size (N): {train_size}')
print(f'\tValidation dataset size (N): {val_size}')

Preprocessing AudioDataset...
	Total dataset size (N): 50352
	Training dataset size (N): 40281
	Validation dataset size (N): 10071


In [4]:
# Loader over the validation split; batching/shuffling/worker settings come from `config`.
loader_kwargs = dict(
    batch_size=config.batch_size,
    shuffle=config.shuffle,
    num_workers=config.num_workers,
    collate_fn=collate_train,
)
data_loader = torch.utils.data.DataLoader(dataset=val_dataset, **loader_kwargs)

In [6]:
# Fetch the first batch from data_loader and report each collated field: tensor fields
# print their shape, while `message` (the text passed to the tokenizers) prints its length.
batch = next(iter(data_loader))
for field_name in ('audio_inputs', 'message', 'outcomes'):
    field_value = batch[field_name]
    print(len(field_value) if field_name == 'message' else field_value.shape)

torch.Size([8, 80, 3000])
8
torch.Size([8, 7])


In [7]:
# Tokenize the batch messages for the SBERT text encoder and move them to the configured device.
encoded_input = tokenizer(
    batch['message'],
    return_tensors='pt',
    truncation=True,
    padding=True,
)
encoded_input = encoded_input.to(config.device)
encoded_input['input_ids'].shape

torch.Size([8, 26])

In [8]:
# Encode the tokenized messages with SBERT (no autograd), then mean-pool the token
# states into one sentence embedding per message using the attention mask.
with torch.no_grad():
    sbert_output = sbert(**encoded_input)
sbert_token_states = sbert_output.last_hidden_state
print(sbert_token_states.shape)
sentence_embeddings = mean_pooling(sbert_token_states, encoded_input['attention_mask'])
print(sentence_embeddings.shape)

torch.Size([8, 26, 384])
torch.Size([8, 384])


In [9]:
# Whisper-based tokenization of the same messages; these ids/masks are what whisbert
# receives in the next cell. Tokenization is not an autograd op, so no torch.no_grad()
# context is needed.
#
# Cap the length at the Whisper decoder's context instead of 512: openai/whisper-*
# checkpoints have max_target_positions=448, so longer sequences would index past the
# decoder's positional embeddings. NOTE(review): assumes WhiSBERTModel feeds these ids
# through the Whisper decoder unchanged; confirm in model.WhiSBERTModel.
WHISPER_MAX_TEXT_TOKENS = 448
outputs = processor.tokenizer(
    batch['message'],
    padding=True,
    truncation=True,
    max_length=WHISPER_MAX_TEXT_TOKENS,
    return_tensors='pt'
).to(config.device)
print(outputs['input_ids'].shape)
print(outputs['attention_mask'].shape)

torch.Size([8, 23])
torch.Size([8, 23])


In [10]:
# Run WhiSBERT on the batch audio features plus the Whisper-tokenized text to get the
# pooled (MEAN/LAST token, presumably per config.pooling_mode) embedding for each example.
# NOTE(review): this forward pass runs outside torch.no_grad(), so autograd state is kept;
# the recorded output width (391) equals the SBERT dim (384) + n_new_dims (7) — confirm.
audio_features = batch['audio_inputs'].to(config.device)
whis_embs = whisbert(audio_features, outputs['input_ids'], outputs['attention_mask'])
whis_embs.shape

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


torch.Size([8, 391])