In [50]:
import os

import evaluate
import librosa
import torch
import wandb
from datasets import Dataset
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import WhisperForAudioClassification, WhisperConfig

In [15]:
class KWS_dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each input item with its keyword label.

    Args:
        input_data: indexable collection of audio feature items.
        output_data: indexable collection of keyword labels, addressed with
            the same indices as ``input_data``.

    Indexing returns a dict with the keys ``'audio'`` and ``'keyword'``.
    """

    def __init__(self, input_data, output_data):
        self.input_data, self.output_data = input_data, output_data

    def __len__(self):
        # The dataset length is defined by the inputs alone.
        return len(self.input_data)

    def __getitem__(self, index):
        label = self.output_data[index]
        features = self.input_data[index]
        return dict(audio=features, keyword=label)

In [16]:
# Directory (relative to this notebook) where the dataset split files live.
data_path = "../data/"

In [54]:
# Load the train / dev / test splits from `data_path`. Each file holds a pickled
# Python object that the training cell iterates over as a DataLoader, so full
# unpickling is required: PyTorch >= 2.6 defaults to weights_only=True and would
# refuse these files, hence the explicit weights_only=False.
# SECURITY: full unpickling can execute arbitrary code -- only load files that
# you created yourself or otherwise trust.
# NOTE(review): the pickles presumably reference KWS_dataset, so that class must
# be defined before this cell runs -- confirm.
train_dataloader = torch.load(f'{data_path}en_splits_10_trainloader', weights_only=False)
dev_dataloader = torch.load(f'{data_path}en_splits_10_devloader', weights_only=False)
test_dataloader = torch.load(f'{data_path}en_splits_10_testloader', weights_only=False)

In [55]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [71]:
# Build the classifier straight from a config object (no from_pretrained call),
# with 10 output labels and an 8-layer encoder.
# NOTE(review): max_source_positions=50 presumably matches the length of the
# audio features served by the dataloaders -- confirm.
whisper_config = WhisperConfig(
    num_mel_bins=80,
    vocab_size=10,
    num_labels=10,
    max_source_positions=50,
    classifier_proj_size=512,
    encoder_layers=8,
)
whisper_model = WhisperForAudioClassification(whisper_config)

In [72]:
# Move the model to the selected device, cast it with .float() and switch it to
# training mode; each of these module calls returns the module, so they chain.
whisper_model.to(device).float().train()

# Adam over all model parameters with a fixed learning rate.
optim = torch.optim.Adam(params=whisper_model.parameters(), lr=5e-5)

In [73]:
# Accuracy metric loaded through the `evaluate` library. The training cell feeds
# it dev-set predictions with add_batch() and reads the score with compute().
metric = evaluate.load("accuracy")

In [None]:
# Training cell: for each epoch, train on `train_dataloader`, save a checkpoint,
# evaluate accuracy on `dev_dataloader`, and log both to Weights & Biases.

epochs = 10
checkpoint_dir = '../model/whisper'

# Start a new wandb run to track this script. Hyperparameters are logged as
# numbers taken from the objects actually used for training, so the run
# metadata cannot drift from the real configuration.
wandb.init(
    # set the wandb project where this run will be logged
    project="whisper",

    # track hyperparameters and run metadata
    config={
        "architecture": "whisper",
        "dataset": "en_10",
        "epochs": epochs,
        "channels": whisper_model.config.num_mel_bins,
        "encoder_layers": whisper_model.config.encoder_layers,
    },
)

# torch.save does not create missing parent directories.
os.makedirs(checkpoint_dir, exist_ok=True)

try:
    for epoch in range(epochs):
        # ---- training pass ----
        whisper_model.train()
        for batch in tqdm(train_dataloader):
            optim.zero_grad()
            audio = batch['audio'].to(device)
            labels = batch['keyword'].to(device)
            outputs = whisper_model(audio, labels=labels)
            loss = outputs['loss']
            loss.backward()
            optim.step()

        # Loss of the last training batch as a plain float, for logging.
        last_train_loss = loss.item()

        # Checkpoint after every epoch; files are numbered from 1.
        torch.save({
            'epoch': epoch,
            'model_state_dict': whisper_model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            # detached so the checkpoint does not carry autograd metadata
            'loss': loss.detach(),
        }, f'{checkpoint_dir}/epoch_{epoch+1}')

        # ---- evaluation pass ----
        whisper_model.eval()
        # No gradients are needed for evaluation; skipping them avoids building
        # an autograd graph for every dev batch.
        with torch.no_grad():
            for batch in tqdm(dev_dataloader):
                audio = batch['audio'].to(device)
                labels = batch['keyword'].to(device)
                outputs = whisper_model(audio, labels=labels)

                metric.add_batch(predictions=outputs.logits.argmax(-1), references=labels)

        # "loss" is the last training-batch loss of this epoch, "acc" the dev accuracy.
        wandb.log({"acc": metric.compute()['accuracy'], "loss": last_train_loss})
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()

100%|█████████████████████████████████████████| 997/997 [01:33<00:00, 10.67it/s]
100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.16s/it]
100%|█████████████████████████████████████████| 997/997 [01:32<00:00, 10.81it/s]
100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.13s/it]
100%|█████████████████████████████████████████| 997/997 [01:34<00:00, 10.51it/s]
100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.16s/it]
100%|█████████████████████████████████████████| 997/997 [01:34<00:00, 10.58it/s]
100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.16s/it]
100%|█████████████████████████████████████████| 997/997 [01:34<00:00, 10.50it/s]
100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.14s/it]
100%|█████████████████████████████████████████| 997/997 [01:32<00:00, 10.75it/s]
100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.16s/it]
100%|███████████████████████