# Whisper Embedding Model

In [13]:
from transformers import WhisperForAudioClassification, WhisperConfig
import torch
import evaluate
import librosa
import wandb
from torch.utils.data import DataLoader
from datasets import load_dataset
from datasets import Dataset
from tqdm import tqdm

In [2]:
class KWS_dataset(torch.utils.data.Dataset):
    def __init__(self, input_data, output_data):
        self.input_data = input_data
        self.output_data = output_data
        
    def __len__(self):
        return len(self.input_data)
    
    def __getitem__(self, index):
        keyword = self.output_data[index]
        audio_features = self.input_data[index]
        # return audio_features, keyword
        return {'audio': audio_features,
                'keyword': keyword
               }

In [3]:
data_path = '../data/'

In [4]:
train_dataloader = torch.load('../data/en_splits_30.trainloader')
dev_dataloader = torch.load('../data/en_splits_30.devloader')
test_dataloader = torch.load('../data/en_splits_30.testloader')

In [5]:
metric = evaluate.load("accuracy")

In [6]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
whisper_model = WhisperForAudioClassification(WhisperConfig(
    num_mel_bins=80,
    vocab_size=30,
    num_labels=31,
    max_source_positions=50,
    classifier_proj_size=512,
    encoder_layer=8,
    decoder_layer=8,
    dropout=0.2
))

In [7]:
whisper_model = WhisperForAudioClassification(WhisperConfig(
    num_mel_bins=80,
    vocab_size=30,
    num_labels=31,
    max_source_positions=50,
    classifier_proj_size=512,
    encoder_layer=8,
    decoder_layer=8,
    dropout=0.2
))

whisper_model.to(device)
whisper_model.float()
whisper_model.train()
optim = torch.optim.Adam(whisper_model.parameters(), lr=5e-5)



#start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="whisper",
    
    # track hyperparameters and run metadata
    config={
    "architecture": "whisper",
    "dataset": "en_30",
    "epochs": "10",
    "channels": "80",
    "encoder_attention_heads": "6",
    "decoder_attention_heads": "6",

    }
)


epochs = 10

for epoch in range(epochs):
    whisper_model.train()
    for batch in tqdm(train_dataloader):
        optim.zero_grad()
        audio = batch['audio'].to(device)
        labels = batch['keyword'].to(device)
        outputs = whisper_model(audio, labels=labels)
        loss = outputs['loss']
        loss.backward()
        optim.step()
        
    torch.save({
            'epoch': epoch,
            'model_state_dict': whisper_model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'loss': loss,
            }, f'../model/whisper/epoch_{epoch+1}')
        
    whisper_model.eval()
        
    for batch in tqdm(dev_dataloader):
        audio = batch['audio'].to(device)
        labels = batch['keyword'].to(device)
        outputs = whisper_model(audio, labels=labels)
        
        metric.add_batch(predictions=outputs.logits.argmax(-1), references=labels)

    wandb.log({"acc": metric.compute()['accuracy'], "loss": loss})
    
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mjaeihn[0m ([33mthe-wild-bunch[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|█████████████████████████████| 3292/3292 [05:53<00:00,  9.31it/s]
100%|█████████████████████████████████| 14/14 [00:11<00:00,  1.23it/s]
100%|█████████████████████████████| 3292/3292 [06:03<00:00,  9.06it/s]
100%|█████████████████████████████████| 14/14 [00:11<00:00,  1.23it/s]
100%|█████████████████████████████| 3292/3292 [06:08<00:00,  8.93it/s]
100%|█████████████████████████████████| 14/14 [00:10<00:00,  1.30it/s]
100%|█████████████████████████████| 3292/3292 [06:13<00:00,  8.81it/s]
100%|█████████████████████████████████| 14/14 [00:11<00:00,  1.26it/s]
100%|█████████████████████████████| 3292/3292 [06:09<00:00,  8.92it/s]
100%|█████████████████████████████████| 14/14 [00:11<00:00,  1.18it/s]
100%|█████████████████████████████| 3292/3292 [06:03<00:00,  9.06it/s]
100%|█████████████████████████████████| 14/14 [00:10<00:00,  1.29it/s]
100%|█████████████████████████████| 3292/3292 [06:10<00:00,  8.88it/s]
100%|█████████████████████████████████| 14/14 [00:11<00:00,  1.23it/s]
100%|█

VBox(children=(Label(value='0.001 MB of 0.026 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.034497…

0,1
acc,▁▄▅▆▆▇█▇██
loss,▆▅▄▁▁▅▆█▄▂

0,1
acc,0.66426
loss,0.60512


In [7]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [11]:
whisper_model.push_to_hub("jaeihn/kws_embedding")

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/27.0M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/jaeihn/kws_embedding/commit/31db4c2b2adc7c5cf7dd62cf57855ac3d372fe34', commit_message='Upload WhisperForAudioClassification', commit_description='', oid='31db4c2b2adc7c5cf7dd62cf57855ac3d372fe34', pr_url=None, pr_revision=None, pr_num=None)