In [None]:
import time
from datetime import timedelta
import os
import argparse
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModel,
    WhisperProcessor
)
from sentence_transformers import SentenceTransformer
from collections import OrderedDict
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
from pprint import pprint

from config import WhiSBERTConfig, CACHE_DIR, CHECKPOINT_DIR
from model import WhiSBERTModel
from data import AudioDataset, collate_train
from train import load_models
from utils import (
    mean_pooling,
    cos_sim_loss,
    sim_clr_loss,
    norm_temp_ce_loss
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


In [60]:
hitop_segments_df = pd.read_csv('/cronus_data/rrao/hitop/segment_outcomes.csv')
hitop_segments_df

Unnamed: 0,user_id,segment_id,segment_filename,segment_message,valence,arousal,ope,agr,ext,con,neu
0,PP572,PP572_0_1,HiTOP_PP572_07062023_AA - 0001.wav,"I've never been an inpatient, but I've had ou...",5.141840,2.158170,9.027416,1.830520,-2.399522,-5.423723,2.477061
1,PP572,PP572_0_2,HiTOP_PP572_07062023_AA - 0002.wav,Okay.,,,-0.006014,0.209157,0.531530,0.181332,-0.120239
2,PP572,PP572_0_4,HiTOP_PP572_07062023_AA - 0004.wav,Yes.,,,-0.006014,0.209157,0.531530,0.181332,-0.120239
3,PP572,PP572_0_8,HiTOP_PP572_07062023_AA - 0008.wav,I like to go out with friends.,5.589249,2.401895,0.018631,1.967006,1.614968,0.604015,0.321846
4,PP572,PP572_0_9,HiTOP_PP572_07062023_AA - 0009.wav,I'm a music major and I don't really get much...,4.117138,2.018602,5.752272,0.132441,-0.953059,-5.101045,5.423269
...,...,...,...,...,...,...,...,...,...,...,...
571415,P692,P692_0_1186,P692_08062024_BB_12760_S - 1186.wav,"No, I know when enough's enough.",,,-0.006014,0.209157,0.531530,0.181332,-0.120239
571416,P692,P692_0_1187,P692_08062024_BB_12760_S - 1187.wav,This is some really strange weather will happen.,5.143191,2.198984,9.212935,2.315006,-7.801025,-7.491872,3.579342
571417,P692,P692_0_1188,P692_08062024_BB_12760_S - 1188.wav,The sun is shining and it's thundering.,5.014759,2.366074,11.030973,1.020247,-1.806953,-4.181623,-5.101836
571418,P692,P692_0_1191,P692_08062024_BB_12760_S - 1191.wav,are you in New York too?,5.862231,2.522707,-0.677049,-2.860874,1.331570,1.301145,0.250222


In [61]:
feature_df = pd.read_csv('/cronus_data/rrao/WhiSBERT/embeddings/hitop_all-MiniLM-L12-v2.csv')
feature_df

Unnamed: 0,segment_id,f000,f001,f002,f003,f004,f005,f006,f007,f008,...,f374,f375,f376,f377,f378,f379,f380,f381,f382,f383
0,PP572_0_1,0.009571,0.019323,0.007478,0.041222,-0.104351,-0.034637,-0.040082,0.048785,-0.049720,...,0.016217,0.035958,0.039362,-0.010150,-0.011600,0.040785,0.031120,0.030377,0.008324,-0.058333
1,PP572_0_2,0.063735,-0.090145,0.076856,0.016580,-0.015392,0.004488,-0.014066,-0.021347,-0.007709,...,-0.104106,0.071222,0.020406,0.048025,0.031107,0.021189,0.102623,0.039707,0.047844,0.039239
2,PP572_0_4,-0.009147,0.044874,0.095438,0.088864,0.026287,0.003342,0.003761,-0.001537,0.002872,...,-0.008186,-0.041032,-0.010027,0.012494,0.034509,0.038356,0.064894,0.012516,0.084423,0.010010
3,PP572_0_8,0.019812,-0.029717,0.070069,0.034346,0.003685,0.001872,0.164560,0.050357,-0.008711,...,0.014106,0.020565,-0.031883,0.069381,-0.025189,-0.059564,-0.057068,-0.010954,0.017886,-0.061186
4,PP572_0_9,0.073630,-0.007963,0.084720,0.063671,0.041942,0.036758,0.003557,0.054233,0.024699,...,-0.025823,-0.039665,0.044497,0.060275,-0.090286,-0.095973,0.009361,-0.100035,0.005244,-0.068029
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
571415,P692_0_1186,0.021623,-0.037660,0.036766,-0.066743,-0.012067,-0.032461,0.029833,0.037347,0.053613,...,-0.062468,0.008806,-0.059389,0.014321,-0.085285,0.089797,-0.038306,0.023783,-0.013505,0.037073
571416,P692_0_1187,-0.013937,-0.014916,0.041315,0.013619,0.048789,-0.042261,0.006820,-0.012841,0.032349,...,0.046992,-0.001953,-0.058689,-0.005668,0.034324,0.031395,-0.075359,0.046120,-0.029264,0.046044
571417,P692_0_1188,-0.041092,0.032440,0.028295,0.031878,0.046165,-0.005409,0.135404,-0.052996,-0.018056,...,-0.127967,-0.000660,-0.006330,-0.017517,0.008965,-0.011897,0.035845,0.022179,-0.117585,0.055960
571418,P692_0_1191,0.065206,-0.039234,0.128294,0.013154,-0.005981,0.036254,0.030744,0.081652,-0.068541,...,-0.000537,0.016490,-0.003721,-0.015859,-0.056227,0.027022,0.054575,-0.036430,-0.086609,-0.061887


In [None]:
sentences = ["This is an example sentence.", "Each sentence is converted"]

model = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v1', cache_folder=CACHE_DIR)
# model = torch.nn.DataParallel(model, device_ids=[1,2,3])
embeddings = torch.from_numpy(model.encode(sentences))
print(embeddings.shape)
model



torch.Size([2, 512])


SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: DistilBertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 512, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
)

In [9]:
embeddings[0].norm()

tensor(0.8855)

In [53]:
model[0].auto_model.config.num_hidden_layers

6

In [55]:
model[0].auto_model.transformer

Transformer(
  (layer): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention): MultiHeadSelfAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (q_lin): Linear(in_features=768, out_features=768, bias=True)
        (k_lin): Linear(in_features=768, out_features=768, bias=True)
        (v_lin): Linear(in_features=768, out_features=768, bias=True)
        (out_lin): Linear(in_features=768, out_features=768, bias=True)
      )
      (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (ffn): FFN(
        (dropout): Dropout(p=0.1, inplace=False)
        (lin1): Linear(in_features=768, out_features=3072, bias=True)
        (lin2): Linear(in_features=3072, out_features=768, bias=True)
        (activation): GELUActivation()
      )
      (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
  )
)

In [58]:
model[2].__dict__

{'training': False,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('activation_function', Tanh()),
              ('linear',
               Linear(in_features=768, out_features=512, bias=True))]),
 'in_features': 768,
 'out_features': 512,
 'bias': True}

In [42]:
model[2].linear

Linear(in_features=768, out_features=512, bias=True)

In [45]:
model[2].activation_function

Tanh()

In [None]:
config = WhiSBERTConfig(
    whisper_model_id='openai/whisper-base',
    pooling_mode='mean',
    use_sbert_encoder=True,
    batch_size=8,
    shuffle=False,
    device='cuda'
)
processor, _, _, _ = load_models(config, '')



DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): L

In [3]:
print('Preprocessing AudioDataset...')
dataset = AudioDataset(processor)
mini_size = int(0.1 * len(dataset))
drop_size = len(dataset) - mini_size
mini_dataset, _ = torch.utils.data.random_split(dataset, [mini_size, drop_size])

# Calculate lengths for the train/val split (80:20)
total_size = len(mini_dataset)
train_size = int(0.8 * total_size)  # 80% for training
val_size = total_size - train_size  # 20% for validation
# Perform the split
train_dataset, val_dataset = torch.utils.data.random_split(mini_dataset, [train_size, val_size])
print(f'\tTotal dataset size (N): {total_size}')
print(f'\tTraining dataset size (N): {train_size}')
print(f'\tValidation dataset size (N): {val_size}')

Preprocessing AudioDataset...
	Total dataset size (N): 50369
	Training dataset size (N): 40295
	Validation dataset size (N): 10074


In [4]:
data_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=config.shuffle,
    collate_fn=collate
)

In [5]:
batch = next(iter(data_loader))
print(batch['audio_inputs'].shape)
print(len(batch['text']))

torch.Size([8, 80, 3000])
8


In [6]:
encoded_input = tokenizer(batch['text'], padding=True, truncation=True, return_tensors='pt').to(config.device)
encoded_input['input_ids'].shape

torch.Size([8, 28])

In [7]:
with torch.no_grad():
    sbert_output = sbert(**encoded_input)
print(sbert_output.last_hidden_state.shape)
sentence_embeddings = mean_pooling(sbert_output.last_hidden_state, encoded_input['attention_mask'])
print(sentence_embeddings.shape)

torch.Size([8, 28, 768])
torch.Size([8, 768])


In [None]:
embedding_output = sbert.embeddings(input_ids=encoded_input['input_ids'])#, attn_mask=encoded_input['attention_mask'])
embedding_output.shape

torch.Size([8, 22, 768])

In [None]:
head_mask = [None] * sbert.config.num_hidden_layers
encoder_output = sbert.transformer(
    embedding_output,
    attn_mask=torch.ones(encoded_input['input_ids'].size(), device=config.device),
    head_mask=head_mask
)[0]
encoder_output.shape

torch.Size([8, 22, 768])

In [9]:
whisper_model = WhisperModel.from_pretrained(
    config.whisper_model_id,
    cache_dir=CACHE_DIR
).to(config.device)

In [10]:
# Whisper-based tokenization
with torch.no_grad():
    outputs = processor.tokenizer(
        batch['text'],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt'
    ).to(config.device)
print(outputs['input_ids'].shape)
print(outputs['attention_mask'].shape)

torch.Size([8, 27])
torch.Size([8, 27])


In [11]:
embs = whisper_model(
    batch['audio_inputs'].to(config.device),
    decoder_input_ids=outputs['input_ids'],
    decoder_attention_mask=outputs['attention_mask']
).last_hidden_state
embs.shape

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


torch.Size([8, 27, 512])

In [None]:
# extended_attention_mask = sbert_model.get_extended_attention_mask(outputs['attention_mask'], whisper_embs.size()[:-1])
# extended_attention_mask = sbert_model.get_extended_attention_mask(outputs['attention_mask'], outputs['attention_mask'].size())
encoder_output = sbert_model.transformer(embedding_output, attn_mask=torch.ones(encoded_input['input_ids'].size(), device=config.device), head_mask=head_mask)[0]
encoder_output.shape