In [2]:
import time
from datetime import timedelta
import os
import argparse
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModel,
    WhisperProcessor
)
from sentence_transformers import SentenceTransformer
from collections import OrderedDict
import pandas as pd
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from pprint import pprint

from config import WhiSBERTConfig, CACHE_DIR, CHECKPOINT_DIR
from model import WhiSBERTModel
from data import AudioDataset, collate_train
from train import load_models
from utils import (
    mean_pooling,
    cos_sim_loss,
    sim_clr_loss,
    norm_temp_ce_loss
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
import torch
from config import CACHE_DIR
from model import (
    expand_conv1d_layer,
    expand_embedding_layer,
    expand_linear_layer,
    expand_layer_norm,
    expand_positional_embedding
)
from transformers import WhisperModel

whisper_model = WhisperModel.from_pretrained(
    'openai/whisper-tiny',
    cache_dir=CACHE_DIR,
)
n_new_dims = 7
whisper_model.decoder

WhisperDecoder(
  (embed_tokens): Embedding(51865, 384, padding_idx=50257)
  (embed_positions): WhisperPositionalEmbedding(448, 384)
  (layers): ModuleList(
    (0-3): 4 x WhisperDecoderLayer(
      (self_attn): WhisperSdpaAttention(
        (k_proj): Linear(in_features=384, out_features=384, bias=False)
        (v_proj): Linear(in_features=384, out_features=384, bias=True)
        (q_proj): Linear(in_features=384, out_features=384, bias=True)
        (out_proj): Linear(in_features=384, out_features=384, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): WhisperSdpaAttention(
        (k_proj): Linear(in_features=384, out_features=384, bias=False)
        (v_proj): Linear(in_features=384, out_features=384, bias=True)
        (q_proj): Linear(in_features=384, out_features=384, bias=True)
        (out_proj): Linear(in_features=384, out_features=384, bias=True)
      )
      (e

In [4]:
print(whisper_model.decoder.embed_positions)
print(whisper_model.decoder.embed_positions.weight)
# print(whisper_model.encoder.embed_positions.weight.shape)
# print(whisper_model.encoder.embed_positions.weight.dtype)
# print(whisper_model.encoder.embed_positions.weight[0,:])

WhisperPositionalEmbedding(448, 384)
Parameter containing:
tensor([[ 0.0041, -0.0136,  0.0009,  ..., -0.0028, -0.0451,  0.0170],
        [ 0.0212, -0.0119,  0.0017,  ...,  0.0055, -0.0307, -0.0371],
        [ 0.0233, -0.0196,  0.0041,  ...,  0.0133, -0.0091, -0.0507],
        ...,
        [ 0.0039,  0.0008,  0.0019,  ..., -0.0101,  0.0593, -0.0149],
        [-0.0064, -0.0149,  0.0009,  ..., -0.0015,  0.0635, -0.0135],
        [-0.0016, -0.0101,  0.0026,  ..., -0.0035,  0.0626, -0.0055]],
       requires_grad=True)


In [5]:
whisper_model.decoder.embed_tokens = expand_embedding_layer(whisper_model.decoder.embed_tokens, n_new_dims, distribution='normal')
whisper_model.decoder.embed_positions = expand_positional_embedding(whisper_model.decoder.embed_positions, n_new_dims)

for layer in whisper_model.decoder.layers:
    layer.self_attn.k_proj = expand_linear_layer(layer.self_attn.k_proj, n_new_dims, n_new_dims)
    layer.self_attn.v_proj = expand_linear_layer(layer.self_attn.v_proj, n_new_dims, n_new_dims)
    layer.self_attn.q_proj = expand_linear_layer(layer.self_attn.q_proj, n_new_dims, n_new_dims)
    layer.self_attn.out_proj = expand_linear_layer(layer.self_attn.out_proj, n_new_dims, n_new_dims)
    layer.self_attn_layer_norm = expand_layer_norm(layer.self_attn_layer_norm, n_new_dims)
    layer.encoder_attn.k_proj = expand_linear_layer(layer.encoder_attn.k_proj, n_new_dims, n_new_dims)
    layer.encoder_attn.v_proj = expand_linear_layer(layer.encoder_attn.v_proj, n_new_dims, n_new_dims)
    layer.encoder_attn.q_proj = expand_linear_layer(layer.encoder_attn.q_proj, n_new_dims, n_new_dims)
    layer.encoder_attn.out_proj = expand_linear_layer(layer.encoder_attn.out_proj, n_new_dims, n_new_dims)
    layer.encoder_attn_layer_norm = expand_layer_norm(layer.encoder_attn_layer_norm, n_new_dims)
    layer.fc1 = expand_linear_layer(layer.fc1, added_in_features=n_new_dims)
    layer.fc2 = expand_linear_layer(layer.fc2, added_out_features=n_new_dims)
    layer.final_layer_norm = expand_layer_norm(layer.final_layer_norm, n_new_dims)

whisper_model.decoder.layer_norm = expand_layer_norm(whisper_model.decoder.layer_norm, n_new_dims)

In [6]:
print(whisper_model.decoder.embed_positions)
print(whisper_model.decoder.embed_positions.weight)
# print(whisper_model.encoder.embed_positions.weight.shape)
# print(whisper_model.encoder.embed_positions.weight.dtype)
# print(whisper_model.encoder.embed_positions.weight[0,:])

WhisperPositionalEmbedding(448, 391)
Parameter containing:
tensor([[ 0.0041, -0.0136,  0.0009,  ..., -0.0155,  0.0111,  0.0084],
        [ 0.0212, -0.0119,  0.0017,  ...,  0.0139,  0.0174,  0.0117],
        [ 0.0233, -0.0196,  0.0041,  ..., -0.0025, -0.0043,  0.0102],
        ...,
        [ 0.0039,  0.0008,  0.0019,  ..., -0.0011, -0.0084,  0.0098],
        [-0.0064, -0.0149,  0.0009,  ..., -0.0041, -0.0003, -0.0071],
        [-0.0016, -0.0101,  0.0026,  ...,  0.0124, -0.0141,  0.0017]],
       requires_grad=True)


In [7]:
whisper_model.decoder

WhisperDecoder(
  (embed_tokens): Embedding(51865, 391, padding_idx=50257)
  (embed_positions): WhisperPositionalEmbedding(448, 391)
  (layers): ModuleList(
    (0-3): 4 x WhisperDecoderLayer(
      (self_attn): WhisperSdpaAttention(
        (k_proj): Linear(in_features=391, out_features=391, bias=False)
        (v_proj): Linear(in_features=391, out_features=391, bias=True)
        (q_proj): Linear(in_features=391, out_features=391, bias=True)
        (out_proj): Linear(in_features=391, out_features=391, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((391,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): WhisperSdpaAttention(
        (k_proj): Linear(in_features=391, out_features=391, bias=False)
        (v_proj): Linear(in_features=391, out_features=391, bias=True)
        (q_proj): Linear(in_features=391, out_features=391, bias=True)
        (out_proj): Linear(in_features=391, out_features=391, bias=True)
      )
      (e

In [None]:
whisper_model.encoder.conv1 = expand_conv1d_layer(whisper_model.encoder.conv1, added_out_channels=n_new_dims)
whisper_model.encoder.conv2 = expand_conv1d_layer(whisper_model.encoder.conv2, added_in_channels=n_new_dims, added_out_channels=n_new_dims)

whisper_model.encoder.embed_positions = expand_embedding_layer(whisper_model.encoder.embed_positions, n_new_dims, distribution='zeros')
whisper_model.encoder.embed_positions.weight.requires_grad = False

for layer in whisper_model.encoder.layers:
    layer.self_attn.k_proj = expand_linear_layer(layer.self_attn.k_proj, n_new_dims, n_new_dims)
    layer.self_attn.v_proj = expand_linear_layer(layer.self_attn.v_proj, n_new_dims, n_new_dims)
    layer.self_attn.q_proj = expand_linear_layer(layer.self_attn.q_proj, n_new_dims, n_new_dims)
    layer.self_attn.out_proj = expand_linear_layer(layer.self_attn.out_proj, n_new_dims, n_new_dims)
    layer.self_attn_layer_norm = expand_layer_norm(layer.self_attn_layer_norm, n_new_dims)
    layer.fc1 = expand_linear_layer(layer.fc1, added_in_features=n_new_dims)
    layer.fc2 = expand_linear_layer(layer.fc2, added_out_features=n_new_dims)
    layer.final_layer_norm = expand_layer_norm(layer.final_layer_norm, n_new_dims)

whisper_model.encoder.layer_norm = expand_layer_norm(whisper_model.encoder.layer_norm, n_new_dims)

In [None]:
config = WhiSBERTConfig(
    whisper_model_id='openai/whisper-base',
    pooling_mode='mean',
    use_sbert_encoder=True,
    batch_size=8,
    shuffle=False,
    device='cuda'
)
processor, _, _, _ = load_models(config, '')



DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): L

In [3]:
print('Preprocessing AudioDataset...')
dataset = AudioDataset(processor)
mini_size = int(0.1 * len(dataset))
drop_size = len(dataset) - mini_size
mini_dataset, _ = torch.utils.data.random_split(dataset, [mini_size, drop_size])

# Calculate lengths for the train/val split (80:20)
total_size = len(mini_dataset)
train_size = int(0.8 * total_size)  # 80% for training
val_size = total_size - train_size  # 20% for validation
# Perform the split
train_dataset, val_dataset = torch.utils.data.random_split(mini_dataset, [train_size, val_size])
print(f'\tTotal dataset size (N): {total_size}')
print(f'\tTraining dataset size (N): {train_size}')
print(f'\tValidation dataset size (N): {val_size}')

Preprocessing AudioDataset...
	Total dataset size (N): 50369
	Training dataset size (N): 40295
	Validation dataset size (N): 10074


In [4]:
data_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=config.shuffle,
    collate_fn=collate
)

In [5]:
batch = next(iter(data_loader))
print(batch['audio_inputs'].shape)
print(len(batch['text']))

torch.Size([8, 80, 3000])
8


In [6]:
encoded_input = tokenizer(batch['text'], padding=True, truncation=True, return_tensors='pt').to(config.device)
encoded_input['input_ids'].shape

torch.Size([8, 28])

In [7]:
with torch.no_grad():
    sbert_output = sbert(**encoded_input)
print(sbert_output.last_hidden_state.shape)
sentence_embeddings = mean_pooling(sbert_output.last_hidden_state, encoded_input['attention_mask'])
print(sentence_embeddings.shape)

torch.Size([8, 28, 768])
torch.Size([8, 768])


In [None]:
embedding_output = sbert.embeddings(input_ids=encoded_input['input_ids'])#, attn_mask=encoded_input['attention_mask'])
embedding_output.shape

torch.Size([8, 22, 768])

In [None]:
head_mask = [None] * sbert.config.num_hidden_layers
encoder_output = sbert.transformer(
    embedding_output,
    attn_mask=torch.ones(encoded_input['input_ids'].size(), device=config.device),
    head_mask=head_mask
)[0]
encoder_output.shape

torch.Size([8, 22, 768])

In [9]:
whisper_model = WhisperModel.from_pretrained(
    config.whisper_model_id,
    cache_dir=CACHE_DIR
).to(config.device)

In [10]:
# Whisper-based tokenization
with torch.no_grad():
    outputs = processor.tokenizer(
        batch['text'],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt'
    ).to(config.device)
print(outputs['input_ids'].shape)
print(outputs['attention_mask'].shape)

torch.Size([8, 27])
torch.Size([8, 27])


In [11]:
embs = whisper_model(
    batch['audio_inputs'].to(config.device),
    decoder_input_ids=outputs['input_ids'],
    decoder_attention_mask=outputs['attention_mask']
).last_hidden_state
embs.shape

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


torch.Size([8, 27, 512])

In [None]:
# extended_attention_mask = sbert_model.get_extended_attention_mask(outputs['attention_mask'], whisper_embs.size()[:-1])
# extended_attention_mask = sbert_model.get_extended_attention_mask(outputs['attention_mask'], outputs['attention_mask'].size())
encoder_output = sbert_model.transformer(embedding_output, attn_mask=torch.ones(encoded_input['input_ids'].size(), device=config.device), head_mask=head_mask)[0]
encoder_output.shape