In [1]:
import os, time
from collections import OrderedDict
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModel,
    WhisperProcessor,
    WhisperModel
)
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from pprint import pprint

from config import WhiSBERTConfig
from utils import (
    mean_pooling,
    last_pooling,
    cos_sim_loss,
    clr_cos_loss,
    sim_clr_loss
)
from data import AudioDataset, collate
from train import load_models, train

CACHE_DIR = '/cronus_data/rrao/cache'
CHECKPOINT_DIR = '/cronus_data/rrao/WhisBERT/models/'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# a batch of data
v1 = np.array([0.5, 0.6, 0.5, 0.6])
v2 = np.array([0.1, 0.1, 0.2, 0.2])
v3 = np.array([0.9, 0.8, 0.9, 0.8])
v4 = np.array([0.3, 0.7, 0.7, 0.3])

print("\nBatch of unlabeled data, v1 to v4: ")
print(v1); print(v2); print(v3); print(v4)

# augmented data
v5 = np.array([0.55, 0.65, 0.50, 0.60])  # from v1
v6 = np.array([0.10, 0.15, 0.25, 0.20])  # from v2
v7 = np.array([0.90, 0.85, 0.95, 0.80])  # from v3
v8 = np.array([0.35, 0.70, 0.75, 0.30])  # from v4

print("\nAugmented data, v5 to v8: ")
print(v5); print(v6); print(v7); print(v8)


Batch of unlabeled data, v1 to v4: 
[0.5 0.6 0.5 0.6]
[0.1 0.1 0.2 0.2]
[0.9 0.8 0.9 0.8]
[0.3 0.7 0.7 0.3]

Augmented data, v5 to v8: 
[0.55 0.65 0.5  0.6 ]
[0.1  0.15 0.25 0.2 ]
[0.9  0.85 0.95 0.8 ]
[0.35 0.7  0.75 0.3 ]


In [3]:
whis_embs = torch.from_numpy(np.stack([v1, v2, v3, v4]))
sbert_embs = torch.from_numpy(np.stack([v5, v6, v7, v8]))

In [None]:
def a(whis_embs, sbert_embs, tau=0.10):
    # Helpful link I used for reference:
    # https://jamesmccaffrey.wordpress.com/2022/04/11/an-example-of-normalized-temperature-scaled-cross-entropy-loss/
    z_audio = F.normalize(whis_embs, dim=1)
    z_text = F.normalize(sbert_embs, dim=1)
    
    # Compute cosine similarity for all pairs in the batch
    similarity_matrix = torch.matmul(z_audio, z_text.T) / tau
    similarity_matrix = torch.exp(similarity_matrix)

    # Sum over each row, excluding the diagonal (self-similarity terms)
    pos_sims = torch.diag(similarity_matrix)
    neg_sims_sum = similarity_matrix.sum(dim=1) - pos_sims
    print(neg_sims_sum)
    
    losses = -torch.log(pos_sims / neg_sims_sum)
    return losses.sum()


def b(whis_embs, sbert_embs, tau=0.10):
    # Helpful link I used for reference:
    # https://jamesmccaffrey.wordpress.com/2022/04/11/an-example-of-normalized-temperature-scaled-cross-entropy-loss/
    combined = torch.cat([whis_embs, sbert_embs], dim=0)  # shape (2 * batch_size, emb_dim)
    combined = F.normalize(combined, dim=1)

    # Define positive pairs (each original data with its corresponding augmented data)
    batch_size = whis_embs.shape[0]
    pos_pairs = torch.arange(batch_size)
    pos_indices = pos_pairs + batch_size  # offset by batch_size to point to the sbert_embs
    
    # Compute cosine similarity for all pairs in the batch
    similarity_matrix = torch.matmul(combined, combined.T) / tau
    similarity_matrix = torch.exp(similarity_matrix)
    
    pos_sims = similarity_matrix[pos_pairs, pos_indices]
    neg_sims_sum = similarity_matrix[:batch_size].sum(dim=1) - torch.diag(similarity_matrix[:batch_size])
    print(neg_sims_sum)
    
    losses = -torch.log(pos_sims / neg_sims_sum)
    return losses.sum()

In [6]:
start_time = time.time()
print(clr_cos_loss(whis_embs, sbert_embs))
end_time = time.time()

elapsed_time = end_time - start_time
print("Elapsed time:", elapsed_time, "seconds")

tensor(4.6910, dtype=torch.float64)
Elapsed time: 0.03089737892150879 seconds


In [47]:
start_time = time.time()
print(a(whis_embs, sbert_embs))
end_time = time.time()

elapsed_time = end_time - start_time
print("Elapsed time:", elapsed_time, "seconds")

tensor([21819.4160, 19883.0742, 21936.7833, 21780.0413], dtype=torch.float64)
tensor([43334.7961, 30951.2489, 44569.1498, 33562.6894], dtype=torch.float64)
tensor(2.2700, dtype=torch.float64)
Elapsed time: 0.0026175975799560547 seconds


In [48]:
start_time = time.time()
print(b(whis_embs, sbert_embs))
end_time = time.time()

elapsed_time = end_time - start_time
print("Elapsed time:", elapsed_time, "seconds")

tensor([21819.4160, 19883.0742, 21936.7833, 21780.0413], dtype=torch.float64)
tensor([108116.1506,  85319.3133, 109872.0390,  83249.7785],
       dtype=torch.float64)
tensor(6.0089, dtype=torch.float64)
Elapsed time: 0.003621339797973633 seconds


In [9]:
def sim(v1, v2):
  v1_normed = v1 / np.linalg.norm(v1)
  v2_normed = v2 / np.linalg.norm(v2)
  return np.dot(v1_normed, v2_normed)  # normalized dot prod

np.set_printoptions(precision=2)

tau = 0.10  # temperature

print("Computing loss for positive pair v1,v5 ")
# loss for positive pair (v1, v5)
v1v1 = np.exp(sim(v1,v1)/tau)  # not used
v1v2 = np.exp(sim(v1,v2)/tau)
v1v3 = np.exp(sim(v1,v3)/tau)
v1v4 = np.exp(sim(v1,v4)/tau)
v1v5 = np.exp(sim(v1,v5)/tau)  # should be small
v1v6 = np.exp(sim(v1,v6)/tau)
v1v7 = np.exp(sim(v1,v7)/tau)
v1v8 = np.exp(sim(v1,v8)/tau)

numerator = v1v5
denom = v1v2 + v1v3 + v1v4 + v1v5 + v1v6 + v1v7 + v1v8

loss_v1v5 = -np.log(numerator / denom)
print("\n%0.6f" % loss_v1v5)

Computing loss for positive pair v1,v5 

1.598489


In [68]:
sentences = ["This is an example sentence", "Each sentence is converted"]

model = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v1', cache_folder=CACHE_DIR)
# model = torch.nn.DataParallel(model, device_ids=[1,2,3])
embeddings = torch.from_numpy(model.encode(sentences))
print(embeddings.shape)
model



torch.Size([2, 512])


SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: DistilBertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 512, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
)

In [69]:
isinstance(model, torch.nn.DataParallel)

False

In [53]:
model[0].auto_model.config.num_hidden_layers

6

In [55]:
model[0].auto_model.transformer

Transformer(
  (layer): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention): MultiHeadSelfAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (q_lin): Linear(in_features=768, out_features=768, bias=True)
        (k_lin): Linear(in_features=768, out_features=768, bias=True)
        (v_lin): Linear(in_features=768, out_features=768, bias=True)
        (out_lin): Linear(in_features=768, out_features=768, bias=True)
      )
      (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (ffn): FFN(
        (dropout): Dropout(p=0.1, inplace=False)
        (lin1): Linear(in_features=768, out_features=3072, bias=True)
        (lin2): Linear(in_features=3072, out_features=768, bias=True)
        (activation): GELUActivation()
      )
      (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
  )
)

In [58]:
model[2].__dict__

{'training': False,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('activation_function', Tanh()),
              ('linear',
               Linear(in_features=768, out_features=512, bias=True))]),
 'in_features': 768,
 'out_features': 512,
 'bias': True}

In [42]:
model[2].linear

Linear(in_features=768, out_features=512, bias=True)

In [45]:
model[2].activation_function

Tanh()

In [None]:
config = WhiSBERTConfig(
    whisper_model_id='openai/whisper-base',
    pooling_mode='mean',
    use_sbert_encoder=True,
    batch_size=8,
    shuffle=False,
    device='cuda'
)
processor, _, _, _ = load_models(config, '')



DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): L

In [3]:
print('Preprocessing AudioDataset...')
dataset = AudioDataset(processor)
mini_size = int(0.1 * len(dataset))
drop_size = len(dataset) - mini_size
mini_dataset, _ = torch.utils.data.random_split(dataset, [mini_size, drop_size])

# Calculate lengths for the train/val split (80:20)
total_size = len(mini_dataset)
train_size = int(0.8 * total_size)  # 80% for training
val_size = total_size - train_size  # 20% for validation
# Perform the split
train_dataset, val_dataset = torch.utils.data.random_split(mini_dataset, [train_size, val_size])
print(f'\tTotal dataset size (N): {total_size}')
print(f'\tTraining dataset size (N): {train_size}')
print(f'\tValidation dataset size (N): {val_size}')

Preprocessing AudioDataset...
	Total dataset size (N): 50369
	Training dataset size (N): 40295
	Validation dataset size (N): 10074


In [4]:
data_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=config.shuffle,
    collate_fn=collate
)

In [5]:
batch = next(iter(data_loader))
print(batch['audio_inputs'].shape)
print(len(batch['text']))

torch.Size([8, 80, 3000])
8


In [6]:
encoded_input = tokenizer(batch['text'], padding=True, truncation=True, return_tensors='pt').to(config.device)
encoded_input['input_ids'].shape

torch.Size([8, 28])

In [7]:
with torch.no_grad():
    sbert_output = sbert(**encoded_input)
print(sbert_output.last_hidden_state.shape)
sentence_embeddings = mean_pooling(sbert_output.last_hidden_state, encoded_input['attention_mask'])
print(sentence_embeddings.shape)

torch.Size([8, 28, 768])
torch.Size([8, 768])


In [None]:
embedding_output = sbert.embeddings(input_ids=encoded_input['input_ids'])#, attn_mask=encoded_input['attention_mask'])
embedding_output.shape

torch.Size([8, 22, 768])

In [None]:
head_mask = [None] * sbert.config.num_hidden_layers
encoder_output = sbert.transformer(
    embedding_output,
    attn_mask=torch.ones(encoded_input['input_ids'].size(), device=config.device),
    head_mask=head_mask
)[0]
encoder_output.shape

torch.Size([8, 22, 768])

In [9]:
whisper_model = WhisperModel.from_pretrained(
    config.whisper_model_id,
    cache_dir=CACHE_DIR
).to(config.device)

In [10]:
# Whisper-based tokenization
with torch.no_grad():
    outputs = processor.tokenizer(
        batch['text'],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt'
    ).to(config.device)
print(outputs['input_ids'].shape)
print(outputs['attention_mask'].shape)

torch.Size([8, 27])
torch.Size([8, 27])


In [11]:
embs = whisper_model(
    batch['audio_inputs'].to(config.device),
    decoder_input_ids=outputs['input_ids'],
    decoder_attention_mask=outputs['attention_mask']
).last_hidden_state
embs.shape

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


torch.Size([8, 27, 512])

In [None]:
# extended_attention_mask = sbert_model.get_extended_attention_mask(outputs['attention_mask'], whisper_embs.size()[:-1])
# extended_attention_mask = sbert_model.get_extended_attention_mask(outputs['attention_mask'], outputs['attention_mask'].size())
encoder_output = sbert_model.transformer(embedding_output, attn_mask=torch.ones(encoded_input['input_ids'].size(), device=config.device), head_mask=head_mask)[0]
encoder_output.shape