In [1]:
import os
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
from collections import OrderedDict
from pprint import pprint

from train import load_models, train
from utils import WhiSBERTConfig, AudioDataset, collate

# Filesystem locations. Defaults are the original cluster paths; set the environment
# variables to run the notebook elsewhere without editing code.
CACHE_DIR = os.environ.get('WHISBERT_CACHE_DIR', '/cronus_data/rrao/cache')
# Root for checkpoints: the save cell writes to CHECKPOINT_DIR/<save_name>/{best,last}.pth.
CHECKPOINT_DIR = os.environ.get('WHISBERT_CHECKPOINT_DIR', '/cronus_data/rrao/WhisBERT/models/')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Empty string: load_models is called without a checkpoint path.
load_path = ''
# Run name; the explicit save cell writes under CHECKPOINT_DIR/<save_name> and is skipped when empty.
save_name = ''

config = WhiSBERTConfig(
    device='cuda',
    batch_size=8,
    pooling_mode='cls',
    use_sbert_layers=False,
)
processor, whisbert, sbert = load_models(config, load_path)




Available GPU IDs: [0, 1, 2, 3]
	GPU 0: NVIDIA RTX A6000
	GPU 1: NVIDIA RTX A6000
	GPU 2: NVIDIA RTX A6000
	GPU 3: NVIDIA RTX A6000



In [3]:
print('Preprocessing AudioDataset...')
SPLIT_SEED = 42         # fixed seed so the subsample and train/val split survive kernel restarts
SUBSET_FRACTION = 0.1   # share of the full dataset kept for this experiment
TRAIN_FRACTION = 0.8    # train share of the subset; the remainder is validation

dataset = AudioDataset('/cronus_data/rrao/wtc_clinic/whisper_segments_transripts.csv', processor)
# One seeded generator drives both random_split calls, making the whole split deterministic.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

# Keep a random SUBSET_FRACTION of the data; the dropped remainder is discarded.
mini_size = int(SUBSET_FRACTION * len(dataset))
drop_size = len(dataset) - mini_size
mini_dataset, _ = torch.utils.data.random_split(
    dataset, [mini_size, drop_size], generator=split_generator
)

# Calculate lengths for the train/val split (80:20) of the subset
total_size = len(mini_dataset)
train_size = int(TRAIN_FRACTION * total_size)
val_size = total_size - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    mini_dataset, [train_size, val_size], generator=split_generator
)
# "Total" below refers to the subsample; report the full size too so the numbers aren't misread.
print(f'\tFull dataset size before subsampling (N): {len(dataset)}')
print(f'\tTotal dataset size (N): {total_size}')
print(f'\tTraining dataset size (N): {train_size}')
print(f'\tValidation dataset size (N): {val_size}')

Preprocessing AudioDataset...
	Total dataset size (N): 15518
	Training dataset size (N): 12414
	Validation dataset size (N): 3104


In [4]:
# Loader over the validation split; the next cell pulls a single batch from it as a sanity check.
data_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    collate_fn=collate,
    batch_size=config.batch_size,
    shuffle=config.shuffle,
    num_workers=config.num_workers,
)

In [5]:
batch = next(iter(data_loader))
# Audio feature tensor shape, then number of transcripts in the batch (one value per line).
print(batch['audio_inputs'].shape, len(batch['text']), sep='\n')

torch.Size([8, 80, 3000])
8


In [14]:
# Get SBERT sentence embeddings for the batch transcripts. When the model is wrapped in
# nn.DataParallel, `.encode` lives on the wrapped `.module`, so unwrap first.
sbert_embs = torch.tensor(
    (sbert.module if isinstance(sbert, torch.nn.DataParallel) else sbert).encode(batch['text']),
    device=config.device,
)
sbert_embs

here


tensor([[ 0.0650, -0.0318,  0.0308,  ..., -0.0022, -0.0055, -0.0334],
        [ 0.0121,  0.0264,  0.0174,  ...,  0.0414, -0.0145, -0.0559],
        [ 0.0674,  0.0181,  0.0026,  ...,  0.0015,  0.0401, -0.0405],
        ...,
        [-0.0227, -0.0337,  0.0067,  ..., -0.0138,  0.0184, -0.0350],
        [ 0.0160, -0.0176, -0.0094,  ...,  0.0034, -0.0264, -0.0017],
        [ 0.0576, -0.0430, -0.0356,  ..., -0.0485, -0.0295,  0.0043]],
       device='cuda:0')

In [7]:
def compare_model_states(before, after):
    """Return the keys whose tensors differ between two state-dict snapshots.

    Args:
        before: Mapping of name -> tensor captured earlier (e.g. a cloned ``state_dict()``).
        after: Mapping of name -> tensor captured later.

    Returns:
        A list of keys. Keys of ``before`` come first, in ``before`` order. A key is listed
        when it is missing from ``after`` or its tensors are not ``torch.equal`` (different
        shape or values). Keys that exist only in ``after`` are appended afterwards, in
        ``after`` order. Mismatched key sets are reported rather than raising ``KeyError``.
    """
    changes = [
        key for key in before
        if key not in after or not torch.equal(before[key], after[key])
    ]
    # Keys that only appear after (e.g. newly registered parameters/buffers) are also changes.
    changes.extend(key for key in after if key not in before)
    return changes

In [8]:
# Clone every state_dict tensor so the snapshot is unaffected by in-place optimizer updates.
before_state = OrderedDict(
    (name, tensor.clone()) for name, tensor in whisbert.state_dict().items()
)

print('\nStarting Training...')
train(train_dataset, val_dataset, processor, whisbert, sbert, config, save_name)

# Second snapshot, diffed against the first in the next cells.
after_state = OrderedDict(
    (name, tensor.clone()) for name, tensor in whisbert.state_dict().items()
)




Starting Training...


Epoch 1/1 - Training:   0%|          | 0/64 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Epoch 1/1 - Training: 100%|██████████| 64/64 [01:52<00:00,  1.75s/it]
Epoch 1/1 - Validation: 100%|██████████| 16/16 [00:16<00:00,  1.05s/it]

Epoch 1/1
	Training Loss: 9989.1831
	Validation Loss: 351.5404
	Learning Rate: 5e-05





In [9]:
# Count state_dict entries whose tensors were modified by training.
changed_parameters = compare_model_states(before_state, after_state)
print("Parameter types changed:", len(changed_parameters))

Parameter types changed: 478


In [10]:
# Spot-check one weight tensor: the first print is before training, the second is after.
print(
    before_state['module.whisper_model.decoder.layers.0.encoder_attn.k_proj.weight'],
    after_state['module.whisper_model.decoder.layers.0.encoder_attn.k_proj.weight'],
    sep='\n',
)

tensor([[ 0.0107,  0.0136, -0.0111,  ...,  0.0122, -0.0311,  0.0051],
        [ 0.0085, -0.0050,  0.0219,  ...,  0.0007,  0.0348,  0.0010],
        [-0.0233,  0.0094,  0.0208,  ...,  0.0222, -0.0020,  0.0298],
        ...,
        [ 0.0050, -0.0530, -0.0053,  ...,  0.0317,  0.0212, -0.0889],
        [ 0.0296,  0.0020, -0.0041,  ..., -0.0171,  0.0079, -0.0177],
        [ 0.0023,  0.0196, -0.0061,  ..., -0.0279,  0.0132, -0.0052]],
       device='cuda:0')
tensor([[ 0.0110,  0.0135, -0.0110,  ...,  0.0121, -0.0312,  0.0046],
        [ 0.0082, -0.0052,  0.0219,  ...,  0.0006,  0.0348,  0.0017],
        [-0.0236,  0.0092,  0.0208,  ...,  0.0221, -0.0019,  0.0306],
        ...,
        [ 0.0048, -0.0530, -0.0050,  ...,  0.0315,  0.0210, -0.0887],
        [ 0.0292,  0.0021, -0.0037,  ..., -0.0174,  0.0076, -0.0174],
        [ 0.0020,  0.0197, -0.0058,  ..., -0.0281,  0.0130, -0.0050]],
       device='cuda:0')


In [11]:
if save_name:
    print(f'\nSaving `best/last` WhisBERT state_dict...')
    save_dir = os.path.join(CHECKPOINT_DIR, save_name)
    os.makedirs(save_dir, exist_ok=True)

    # No visible cell defines `best_state` (train()'s return value is discarded), so the
    # unconditional reference raised NameError on a fresh kernel. Save it only if present.
    # NOTE(review): the best.pth reloaded later printed the *pre-training* weights, so a
    # leftover `best_state` was stale. train() also receives save_name and may already
    # write checkpoints itself; confirm to avoid duplicate/conflicting saves.
    if 'best_state' in globals():
        best_path = os.path.join(save_dir, 'best.pth')
        torch.save(best_state, best_path)
        print(f'\tDone.\t`{best_path}`\n')
    else:
        print('\tSkipped `best.pth`: no `best_state` defined in this kernel.\n')

    # The trained model is `whisbert` (returned by load_models); `whisbert_model` was never defined.
    last_path = os.path.join(save_dir, 'last.pth')
    torch.save(whisbert.state_dict(), last_path)
    print(f'\tDone.\t`{last_path}`\n')

torch.cuda.empty_cache()


Saving `best/last` WhisBERT state_dict...
	Done.	`/cronus_data/rrao/WhisBERT/models/test/best.pth`

	Done.	`/cronus_data/rrao/WhisBERT/models/test/last.pth`



In [12]:
# Reload the saved last.pth to verify it reproduces the post-training weights printed earlier.
# Uses the imported `WhiSBERTConfig` (the old `WhisBERTConfig` spelling is not defined) and the
# same architecture flags as the training config, so the rebuilt model matches the checkpoint.
# NOTE(review): the old cell passed `use_new_encoder_layers=False`; confirm that flag is obsolete.
config = WhiSBERTConfig(pooling_mode='cls', use_sbert_layers=False, batch_size=196, num_workers=12, device='cuda')
# The model is load_models' second return value; indexing avoids depending on the tuple length
# (the training cell unpacks three values, this cell previously unpacked four).
last_model = load_models(config, os.path.join(CHECKPOINT_DIR, 'test', 'last.pth'))[1]
print(last_model.state_dict()['module.whisper_model.decoder.layers.0.encoder_attn.k_proj.weight'])

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Available GPU IDs: [0, 1, 2, 3]
	GPU 0: NVIDIA RTX A6000
	GPU 1: NVIDIA RTX A6000
	GPU 2: NVIDIA RTX A6000
	GPU 3: NVIDIA RTX A6000

Instantiating WhisBERT with loaded state dict...


  state_dict = torch.load(load_path)


tensor([[ 0.0110,  0.0135, -0.0110,  ...,  0.0121, -0.0312,  0.0046],
        [ 0.0082, -0.0052,  0.0219,  ...,  0.0006,  0.0348,  0.0017],
        [-0.0236,  0.0092,  0.0208,  ...,  0.0221, -0.0019,  0.0306],
        ...,
        [ 0.0048, -0.0530, -0.0050,  ...,  0.0315,  0.0210, -0.0887],
        [ 0.0292,  0.0021, -0.0037,  ..., -0.0174,  0.0076, -0.0174],
        [ 0.0020,  0.0197, -0.0058,  ..., -0.0281,  0.0130, -0.0050]],
       device='cuda:0')


In [13]:
# Reload the saved best.pth. Uses the imported `WhiSBERTConfig` (the old `WhisBERTConfig`
# spelling is not defined) and the training config's architecture flags.
# NOTE(review): the old cell passed `use_new_encoder_layers=False`; confirm that flag is obsolete.
# NOTE(review): in the recorded run this printed the *pre-training* weights (identical to
# before_state), i.e. best.pth did not hold the trained model; re-check after re-saving.
config = WhiSBERTConfig(pooling_mode='cls', use_sbert_layers=False, batch_size=196, num_workers=12, device='cuda')
# The model is load_models' second return value; indexing avoids depending on the tuple length.
best_model = load_models(config, os.path.join(CHECKPOINT_DIR, 'test', 'best.pth'))[1]
print(best_model.state_dict()['module.whisper_model.decoder.layers.0.encoder_attn.k_proj.weight'])

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Available GPU IDs: [0, 1, 2, 3]
	GPU 0: NVIDIA RTX A6000
	GPU 1: NVIDIA RTX A6000
	GPU 2: NVIDIA RTX A6000
	GPU 3: NVIDIA RTX A6000

Instantiating WhisBERT with loaded state dict...
tensor([[ 0.0107,  0.0136, -0.0111,  ...,  0.0122, -0.0311,  0.0051],
        [ 0.0085, -0.0050,  0.0219,  ...,  0.0007,  0.0348,  0.0010],
        [-0.0233,  0.0094,  0.0208,  ...,  0.0222, -0.0020,  0.0298],
        ...,
        [ 0.0050, -0.0530, -0.0053,  ...,  0.0317,  0.0212, -0.0889],
        [ 0.0296,  0.0020, -0.0041,  ..., -0.0171,  0.0079, -0.0177],
        [ 0.0023,  0.0196, -0.0061,  ..., -0.0279,  0.0132, -0.0052]],
       device='cuda:0')
