In [1]:
import argparse
import os
import time
from collections import OrderedDict
from datetime import timedelta
from pprint import pprint

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModel,
    WhisperModel,
    WhisperProcessor
)

from config import WhiSBERTConfig, CACHE_DIR, CHECKPOINT_DIR
from model import WhiSBERTModel
from data import AudioDataset, collate_train
from train import load_models
from utils import (
    mean_pooling,
    cos_sim_loss,
    sim_clr_loss,
    norm_temp_ce_loss
)


# Set the TOKENIZERS_PARALLELISM environment variable to "false" for this kernel
# (presumably to stop the Hugging Face tokenizers library from warning about
# forking once DataLoader workers start — TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Read the segment outcomes table from disk.
wtc_segments_df = pd.read_csv('/cronus_data/rrao/wtc_clinic/segment_outcomes.csv')

# Spot check: print the user_id stored for segment '417-Q5-0'.
print(wtc_segments_df.loc[wtc_segments_df['segment_id'] == '417-Q5-0', 'user_id'].values[0])
wtc_segments_df

417


Unnamed: 0,user_id,segment_id,segment_filename,segment_message,valence,arousal,ope,agr,ext,con,neu
0,417,417-Q5-0,LS - Q5 - 0417-converted - 0000.wav,"Over the last five years, what are the three ...",5.492248,2.434264,1.078491,3.538978,2.419177,-1.841604,-0.317128
1,417,417-Q5-1,LS - Q5 - 0417-converted - 0001.wav,"Over the last five years, I've had four grand...",5.345472,2.215654,5.080293,0.166491,-2.165736,-2.827166,1.538005
2,417,417-Q5-2,LS - Q5 - 0417-converted - 0002.wav,"In the last six years, I've had five grandchi...",5.345472,2.215654,5.080293,0.166491,-2.165736,-2.827166,1.538005
3,417,417-Q5-3,LS - Q5 - 0417-converted - 0003.wav,Those are the nicest things that have happene...,5.862403,2.205863,1.145501,6.618364,4.583606,-1.471796,-0.623336
4,417,417-Q5-4,LS - Q5 - 0417-converted - 0004.wav,Plus we've stayed together.,4.469452,1.632236,2.591754,1.388986,-0.087364,-7.546930,5.881711
...,...,...,...,...,...,...,...,...,...,...,...
154598,1766,1766--99-130,WTCHP Open Ended Linguistics - 1766-converted ...,"Yes, lower the camera.",4.597764,2.151000,-1.472627,-1.503602,-2.542642,1.420707,0.526871
154599,1766,1766--99-131,WTCHP Open Ended Linguistics - 1766-converted ...,Keep the camera.,5.060868,2.280629,-1.724761,-0.399285,-0.700788,-2.833173,-0.502207
154600,1766,1766--99-132,WTCHP Open Ended Linguistics - 1766-converted ...,Maybe have a camera in the screen.,5.020720,2.168814,-1.478302,0.586663,-2.069125,-3.993055,0.031983
154601,1766,1766--99-133,WTCHP Open Ended Linguistics - 1766-converted ...,I think it would be a lot easier to answer th...,5.026930,2.495549,1.534325,3.015803,1.425991,2.239760,-3.279220


In [31]:
# Read the precomputed embedding table (one row per segment_id, followed by
# the feature columns) and display it.
feature_df = pd.read_csv(
    '/cronus_data/rrao/WhiSBERT/embeddings/wtc_all-MiniLM-L12-v2.csv'
)
feature_df

Unnamed: 0,segment_id,f000,f001,f002,f003,f004,f005,f006,f007,f008,...,f374,f375,f376,f377,f378,f379,f380,f381,f382,f383
0,417-Q5-0,-0.075584,0.040513,0.087918,-0.061818,0.016908,0.054135,-0.043687,-0.032497,-0.080257,...,0.031678,-0.036200,0.033863,0.004600,0.063568,0.063140,0.052456,0.024611,0.053543,0.059069
1,417-Q5-1,0.010263,0.028935,0.052418,-0.054661,-0.052739,0.056436,-0.048300,-0.039810,-0.043792,...,0.000716,0.005423,0.080140,0.005369,0.030937,0.011622,0.058191,0.037399,0.026069,0.044603
2,417-Q5-2,0.015640,0.026896,0.055683,-0.039167,-0.044886,0.059167,-0.046924,-0.064981,-0.040909,...,-0.019591,-0.001969,0.043371,0.010914,0.013996,0.022936,0.057911,0.041148,0.008372,0.041009
3,417-Q5-3,-0.031706,0.083616,0.106976,-0.072090,-0.040272,-0.016612,0.007439,0.031125,-0.003797,...,0.045717,-0.090909,0.034390,0.020780,0.058275,0.102914,0.070823,0.082695,-0.009661,0.046030
4,417-Q5-4,0.011252,0.017688,0.103457,-0.013963,-0.037668,0.025187,-0.000924,-0.027211,0.025800,...,-0.080296,-0.086024,0.017136,0.019137,0.040391,-0.072536,0.018324,0.019036,0.066409,-0.091498
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
154598,1766--99-130,0.097473,0.091618,0.018798,0.020551,0.004536,0.019850,-0.027589,0.020975,0.079535,...,-0.026897,0.008701,0.066491,0.025070,-0.140160,0.039958,0.008782,-0.094962,-0.065360,0.065812
154599,1766--99-131,0.033504,0.096108,0.022334,0.002139,0.110668,0.023967,0.010173,0.026386,0.072415,...,0.044142,0.008134,0.115293,0.045375,-0.089254,0.018317,-0.010786,-0.040346,0.000643,0.046477
154600,1766--99-132,-0.034302,0.091456,0.080520,-0.039176,0.015758,-0.023656,0.001591,-0.009069,0.061610,...,0.031829,0.013746,0.113839,0.012622,-0.056482,0.069467,-0.053189,-0.016472,-0.035657,0.024322
154601,1766--99-133,-0.015464,0.014041,0.052497,0.033327,0.076290,-0.024053,-0.065092,-0.006300,0.003020,...,-0.065374,0.023941,0.036343,0.006751,-0.013511,0.067152,0.055123,0.009439,-0.089196,0.039081


In [40]:
# Attach each segment's user_id to its feature row. The left join keeps every
# row of the segment table, even when no feature row matches.
merged_df = pd.merge(
    wtc_segments_df[['user_id', 'segment_id']],
    feature_df,
    on='segment_id',
    how='left',
)

In [41]:
# Display the merged table (user_id, segment_id, then the feature columns).
merged_df

Unnamed: 0,user_id,segment_id,f000,f001,f002,f003,f004,f005,f006,f007,...,f374,f375,f376,f377,f378,f379,f380,f381,f382,f383
0,417,417-Q5-0,-0.075584,0.040513,0.087918,-0.061818,0.016908,0.054135,-0.043687,-0.032497,...,0.031678,-0.036200,0.033863,0.004600,0.063568,0.063140,0.052456,0.024611,0.053543,0.059069
1,417,417-Q5-1,0.010263,0.028935,0.052418,-0.054661,-0.052739,0.056436,-0.048300,-0.039810,...,0.000716,0.005423,0.080140,0.005369,0.030937,0.011622,0.058191,0.037399,0.026069,0.044603
2,417,417-Q5-2,0.015640,0.026896,0.055683,-0.039167,-0.044886,0.059167,-0.046924,-0.064981,...,-0.019591,-0.001969,0.043371,0.010914,0.013996,0.022936,0.057911,0.041148,0.008372,0.041009
3,417,417-Q5-3,-0.031706,0.083616,0.106976,-0.072090,-0.040272,-0.016612,0.007439,0.031125,...,0.045717,-0.090909,0.034390,0.020780,0.058275,0.102914,0.070823,0.082695,-0.009661,0.046030
4,417,417-Q5-4,0.011252,0.017688,0.103457,-0.013963,-0.037668,0.025187,-0.000924,-0.027211,...,-0.080296,-0.086024,0.017136,0.019137,0.040391,-0.072536,0.018324,0.019036,0.066409,-0.091498
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
154598,1766,1766--99-130,0.097473,0.091618,0.018798,0.020551,0.004536,0.019850,-0.027589,0.020975,...,-0.026897,0.008701,0.066491,0.025070,-0.140160,0.039958,0.008782,-0.094962,-0.065360,0.065812
154599,1766,1766--99-131,0.033504,0.096108,0.022334,0.002139,0.110668,0.023967,0.010173,0.026386,...,0.044142,0.008134,0.115293,0.045375,-0.089254,0.018317,-0.010786,-0.040346,0.000643,0.046477
154600,1766,1766--99-132,-0.034302,0.091456,0.080520,-0.039176,0.015758,-0.023656,0.001591,-0.009069,...,0.031829,0.013746,0.113839,0.012622,-0.056482,0.069467,-0.053189,-0.016472,-0.035657,0.024322
154601,1766,1766--99-133,-0.015464,0.014041,0.052497,0.033327,0.076290,-0.024053,-0.065092,-0.006300,...,-0.065374,0.023941,0.036343,0.006751,-0.013511,0.067152,0.055123,0.009439,-0.089196,0.039081


In [52]:
# Prototype of a per-user aggregation: average every feature column over one
# user's segments and pack each mean into a (user_id, feature, value, value)
# tuple. numpy is imported once in the notebook's top import cell rather than
# here, so this cell no longer carries its own import.
#
# NOTE(review): `values` is overwritten on every inner iteration and the outer
# loop stops after the first user, so only the last feature of the first user
# survives. The fourth slot simply repeats `value`.
for user_id in np.unique(merged_df['user_id']):
    # Columns 0-1 are user_id and segment_id (the left side of the merge);
    # everything from column 2 onward comes from the feature table.
    user_rows = merged_df.loc[merged_df['user_id'] == user_id]
    mean_feats = user_rows.iloc[:, 2:].mean()
    for feat_name, value in mean_feats.items():
        values = (user_id, feat_name, value, value)
    break  # only the first user is processed
values

(2, 'f383', 0.0013391561822598491, 0.0013391561822598491)

In [None]:
# Split the combined embedding file by position: the trailing 154603 rows go
# to `wtc_df` (re-indexed from 0) and everything before them to `hitop_df`.
df = pd.read_csv('/cronus_data/rrao/WhiSBERT/embeddings/all-MiniLM-L12-v2.csv')
hitop_df = df.iloc[:-154603]
wtc_df = df.iloc[-154603:].reset_index(drop=True)
# One-off export of the two halves, left disabled:
# hitop_df.to_csv('/cronus_data/rrao/WhiSBERT/embeddings/hitop_all-MiniLM-L12-v2.csv', index=False)
# wtc_df.to_csv('/cronus_data/rrao/WhiSBERT/embeddings/wtc_all-MiniLM-L12-v2.csv', index=False)

In [None]:
# Encode two toy sentences with a pretrained SentenceTransformer and inspect
# the embedding shape and the module pipeline.
sentences = ["This is an example sentence.", "Each sentence is converted"]

model = SentenceTransformer(
    'sentence-transformers/distiluse-base-multilingual-cased-v1',
    cache_folder=CACHE_DIR,
)
# model = torch.nn.DataParallel(model, device_ids=[1,2,3])
sentence_vectors = model.encode(sentences)  # numpy array, converted below
embeddings = torch.from_numpy(sentence_vectors)
print(embeddings.shape)
model



torch.Size([2, 512])


SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: DistilBertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 512, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
)

In [9]:
# Vector norm of the first sentence embedding.
embeddings[0].norm()

tensor(0.8855)

In [53]:
# Number of hidden layers reported by the config of the first module's wrapped model.
model[0].auto_model.config.num_hidden_layers

6

In [55]:
# Show the `transformer` submodule (the encoder stack) of the first module's wrapped model.
model[0].auto_model.transformer

Transformer(
  (layer): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention): MultiHeadSelfAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (q_lin): Linear(in_features=768, out_features=768, bias=True)
        (k_lin): Linear(in_features=768, out_features=768, bias=True)
        (v_lin): Linear(in_features=768, out_features=768, bias=True)
        (out_lin): Linear(in_features=768, out_features=768, bias=True)
      )
      (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (ffn): FFN(
        (dropout): Dropout(p=0.1, inplace=False)
        (lin1): Linear(in_features=768, out_features=3072, bias=True)
        (lin2): Linear(in_features=3072, out_features=768, bias=True)
        (activation): GELUActivation()
      )
      (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
  )
)

In [58]:
# Dump the instance attributes of the third pipeline module for inspection.
vars(model[2])

{'training': False,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('activation_function', Tanh()),
              ('linear',
               Linear(in_features=768, out_features=512, bias=True))]),
 'in_features': 768,
 'out_features': 512,
 'bias': True}

In [42]:
# Linear layer held by the third pipeline module.
model[2].linear

Linear(in_features=768, out_features=512, bias=True)

In [45]:
# Activation function held by the third pipeline module.
model[2].activation_function

Tanh()

In [None]:
# Build the experiment configuration and call `load_models` with it.
# NOTE(review): only the first of the four values returned by `load_models`
# is kept. Later cells use `tokenizer` and `sbert`, which this cell does not
# bind — presumably they are among the discarded values. TODO confirm the
# return order of `load_models` and unpack those names here.
config = WhiSBERTConfig(
    whisper_model_id='openai/whisper-base',
    pooling_mode='mean',
    use_sbert_encoder=True,
    batch_size=8,
    shuffle=False,
    device='cuda'
)
processor, _, _, _ = load_models(config, '')



DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): L

In [3]:
print('Preprocessing AudioDataset...')
dataset = AudioDataset(processor)

# Seeded generator shared by both splits so the 10% subset and the train/val
# partition are identical after every Restart & Run All. Without it the
# splits (and every downstream output) change on each kernel restart.
split_generator = torch.Generator().manual_seed(42)

# Keep a random 10% of the full dataset; the remaining 90% is discarded.
mini_size = int(0.1 * len(dataset))
drop_size = len(dataset) - mini_size
mini_dataset, _ = torch.utils.data.random_split(
    dataset, [mini_size, drop_size], generator=split_generator
)

# Calculate lengths for the train/val split (80:20)
total_size = len(mini_dataset)
train_size = int(0.8 * total_size)  # 80% for training
val_size = total_size - train_size  # 20% for validation
# Perform the split
train_dataset, val_dataset = torch.utils.data.random_split(
    mini_dataset, [train_size, val_size], generator=split_generator
)
print(f'\tTotal dataset size (N): {total_size}')
print(f'\tTraining dataset size (N): {train_size}')
print(f'\tValidation dataset size (N): {val_size}')

Preprocessing AudioDataset...
	Total dataset size (N): 50369
	Training dataset size (N): 40295
	Validation dataset size (N): 10074


In [4]:
# Loader over the validation subset; batch size, worker count and shuffling
# all come from `config`.
data_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=config.shuffle,
    # `collate_train` is the collate function this notebook imports from
    # `data`; the previously referenced name `collate` is never defined here
    # and raised NameError on a fresh kernel.
    collate_fn=collate_train
)

In [5]:
# Pull a single batch from the loader and check the audio tensor shape and
# the number of text entries it carries.
batch = next(iter(data_loader))
print(batch['audio_inputs'].shape)
print(len(batch['text']))

torch.Size([8, 80, 3000])
8


In [6]:
# Tokenize the batch text (padded, truncated, returned as PyTorch tensors),
# then move the encoding to the configured device.
encoded_input = tokenizer(
    batch['text'],
    padding=True,
    truncation=True,
    return_tensors='pt',
)
encoded_input = encoded_input.to(config.device)
encoded_input['input_ids'].shape

torch.Size([8, 28])

In [7]:
# Forward pass through `sbert` without tracking gradients, then pool the
# token-level hidden states with `mean_pooling` using the attention mask.
with torch.no_grad():
    sbert_output = sbert(**encoded_input)
token_embeddings = sbert_output.last_hidden_state
print(token_embeddings.shape)
sentence_embeddings = mean_pooling(token_embeddings, encoded_input['attention_mask'])
print(sentence_embeddings.shape)

torch.Size([8, 28, 768])
torch.Size([8, 768])


In [None]:
# Run only the embedding layer of `sbert` on the token ids. The call takes
# the input ids alone; no attention mask is passed at this stage.
embedding_output = sbert.embeddings(
    input_ids=encoded_input['input_ids']
)
embedding_output.shape

torch.Size([8, 22, 768])

In [None]:
# Manually run the encoder stack of `sbert` on the embedding-layer output.
# One `None` head mask per hidden layer means no attention heads are masked.
head_mask = [None] * sbert.config.num_hidden_layers
encoder_output = sbert.transformer(
    embedding_output,
    # Use the tokenizer's real padding mask. An all-ones mask lets padded
    # positions take part in attention, so the result would not match
    # `sbert(**encoded_input)` for batches containing padding.
    attn_mask=encoded_input['attention_mask'],
    head_mask=head_mask
)[0]
encoder_output.shape

torch.Size([8, 22, 768])

In [9]:
# Load the pretrained Whisper model named in the config (weights cached under
# CACHE_DIR), then move it to the configured device.
whisper_model = WhisperModel.from_pretrained(config.whisper_model_id, cache_dir=CACHE_DIR)
whisper_model = whisper_model.to(config.device)

In [10]:
# Tokenize the batch text with the tokenizer bundled in the Whisper processor.
# The no_grad context is kept as written; tokenization itself builds no
# autograd graph.
with torch.no_grad():
    outputs = processor.tokenizer(
        batch['text'], padding=True, truncation=True, max_length=512, return_tensors='pt'
    )
    outputs = outputs.to(config.device)
print(outputs['input_ids'].shape)
print(outputs['attention_mask'].shape)

torch.Size([8, 27])
torch.Size([8, 27])


In [11]:
# Run Whisper on the batch audio, feeding the tokenized text as decoder
# input, and keep the last hidden state.
# This is inspection-only, so the forward pass runs under no_grad (as the
# SBERT forward cell does) to avoid building and holding an autograd graph
# on the device.
with torch.no_grad():
    embs = whisper_model(
        batch['audio_inputs'].to(config.device),
        decoder_input_ids=outputs['input_ids'],
        decoder_attention_mask=outputs['attention_mask']
    ).last_hidden_state
embs.shape

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


torch.Size([8, 27, 512])

In [None]:
# extended_attention_mask = sbert_model.get_extended_attention_mask(outputs['attention_mask'], whisper_embs.size()[:-1])
# extended_attention_mask = sbert_model.get_extended_attention_mask(outputs['attention_mask'], outputs['attention_mask'].size())
# Run the encoder stack on the embedding-layer output, passing the
# tokenizer's real padding mask instead of an all-ones mask so padded
# positions are not attended to.
# NOTE(review): this cell uses `sbert_model` while the other cells use
# `sbert`; neither name is bound in the visible cells — confirm which is intended.
encoder_output = sbert_model.transformer(
    embedding_output,
    attn_mask=encoded_input['attention_mask'],
    head_mask=head_mask
)[0]
encoder_output.shape