In [1]:
import os

import evaluate
import librosa
import torch
import wandb
from datasets import Dataset
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import WhisperForAudioClassification, WhisperConfig

In [2]:
class KWS_dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each input item with its keyword label.

    The two containers are indexed in parallel: item ``i`` is built from
    ``input_data[i]`` and ``output_data[i]``.

    NOTE(review): this class is not instantiated in the visible cells; it is
    presumably required so that ``torch.load`` can unpickle the saved
    dataloaders -- confirm before renaming the class or its attributes.
    """

    def __init__(self, input_data, output_data):
        """Store the parallel containers of inputs and keyword labels."""
        self.input_data, self.output_data = input_data, output_data

    def __len__(self):
        """Return the number of items, taken from ``input_data``."""
        return len(self.input_data)

    def __getitem__(self, index):
        """Return ``{'audio': ..., 'keyword': ...}`` for position ``index``."""
        label = self.output_data[index]
        return dict(audio=self.input_data[index], keyword=label)

In [3]:
# Directory holding the dataset files (relative path, with trailing slash).
data_path = "../data/"

In [4]:
# Load the saved train / dev / test loaders for the "en_splits_30" data.
# The training cells iterate over these objects and read batch['audio'] and
# batch['keyword'] from every batch.
#
# SECURITY: torch.load() unpickles arbitrary Python objects, which can execute
# code on load. Only load files that you created yourself or otherwise trust.
#
# weights_only=False is passed explicitly: these files hold whole Python
# objects rather than plain tensors, and PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to unpickle such objects.
train_dataloader = torch.load(f'{data_path}en_splits_30.trainloader', weights_only=False)
dev_dataloader = torch.load(f'{data_path}en_splits_30.devloader', weights_only=False)
test_dataloader = torch.load(f'{data_path}en_splits_30.testloader', weights_only=False)

In [5]:
# Accuracy metric from the `evaluate` library. The training cells feed it
# with add_batch() and read the result with compute().
metric = evaluate.load('accuracy')

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Experiment: WhisperForAudioClassification built from a fresh WhisperConfig
# (no pretrained weights are loaded), classifier_proj_size=1024.
#
# Each epoch: train on train_dataloader, save a checkpoint, then measure
# accuracy on dev_dataloader and log it to wandb.

# Attention head counts are not overridden, so WhisperConfig defaults apply.
model_config = WhisperConfig(
    num_mel_bins=80,
    vocab_size=30,
    num_labels=31,
    max_source_positions=50,
    classifier_proj_size=1024,
    dropout=0.2,
)

whisper_model = WhisperForAudioClassification(model_config)
whisper_model.to(device)
whisper_model.float()

learning_rate = 5e-5
epochs = 10
optim = torch.optim.Adam(whisper_model.parameters(), lr=learning_rate)

# NOTE: checkpoints are written as <checkpoint_dir>/epoch_N. Any run that
# saves into the same directory overwrites the files of earlier runs.
checkpoint_dir = '../model/whisper'
# Create the directory up front so a missing folder does not abort the run
# after a full epoch of training.
os.makedirs(checkpoint_dir, exist_ok=True)

# Start a new wandb run. The tracked hyperparameters are read from the
# objects that are actually used, so the logged values cannot drift from the
# real configuration.
wandb.init(
    project="whisper",
    config={
        "architecture": "whisper",
        "dataset": "en_30",
        "epochs": epochs,
        "channels": model_config.num_mel_bins,
        "encoder_attention_heads": model_config.encoder_attention_heads,
        "decoder_attention_heads": model_config.decoder_attention_heads,
        "classifier_proj_size": model_config.classifier_proj_size,
        "lr": learning_rate,
        "dropout": model_config.dropout,
    },
)

try:
    for epoch in range(epochs):
        # --- training pass ---
        whisper_model.train()
        for batch in tqdm(train_dataloader):
            optim.zero_grad()
            audio = batch['audio'].to(device)
            labels = batch['keyword'].to(device)
            outputs = whisper_model(audio, labels=labels)
            loss = outputs['loss']
            loss.backward()
            optim.step()

        # `loss` is the loss of the last training batch of this epoch.
        # detach() keeps it a tensor but drops the autograd history.
        torch.save({
            'epoch': epoch,
            'model_state_dict': whisper_model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'loss': loss.detach(),
        }, f'{checkpoint_dir}/epoch_{epoch+1}')

        # --- dev evaluation ---
        whisper_model.eval()
        # No gradients are needed for evaluation; skipping them avoids
        # building an autograd graph for every dev batch.
        with torch.no_grad():
            for batch in tqdm(dev_dataloader):
                audio = batch['audio'].to(device)
                labels = batch['keyword'].to(device)
                logits = whisper_model(audio).logits
                metric.add_batch(predictions=logits.argmax(-1), references=labels)

        # Log plain Python numbers rather than tensors.
        wandb.log({"acc": metric.compute()['accuracy'], "loss": loss.item()})
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mjaeihn[0m ([33mthe-wild-bunch[0m). Use [1m`wandb login --relogin`[0m to force relogin


  2%|▋                              | 75/3292 [00:05<04:00, 13.36it/s]

In [None]:
# Experiment: WhisperForAudioClassification built from a fresh WhisperConfig
# (no pretrained weights are loaded), classifier_proj_size=512 and
# encoder_ffn_dim=3072.
#
# Each epoch: train on train_dataloader, save a checkpoint, then measure
# accuracy on dev_dataloader and log it to wandb.

# Attention head counts are not overridden, so WhisperConfig defaults apply.
model_config = WhisperConfig(
    num_mel_bins=80,
    vocab_size=30,
    num_labels=31,
    max_source_positions=50,
    classifier_proj_size=512,
    encoder_ffn_dim=3072,
    dropout=0.2,
)

whisper_model = WhisperForAudioClassification(model_config)
whisper_model.to(device)
whisper_model.float()

learning_rate = 5e-5
epochs = 10
optim = torch.optim.Adam(whisper_model.parameters(), lr=learning_rate)

# NOTE: checkpoints are written as <checkpoint_dir>/epoch_N. Any run that
# saves into the same directory overwrites the files of earlier runs.
checkpoint_dir = '../model/whisper'
# Create the directory up front so a missing folder does not abort the run
# after a full epoch of training.
os.makedirs(checkpoint_dir, exist_ok=True)

# Start a new wandb run. The tracked hyperparameters are read from the
# objects that are actually used, so the logged values cannot drift from the
# real configuration.
wandb.init(
    project="whisper",
    config={
        "architecture": "whisper",
        "dataset": "en_30",
        "epochs": epochs,
        "channels": model_config.num_mel_bins,
        "encoder_attention_heads": model_config.encoder_attention_heads,
        "decoder_attention_heads": model_config.decoder_attention_heads,
        "encoder_ffn_dim": model_config.encoder_ffn_dim,
        "classifier_proj_size": model_config.classifier_proj_size,
        "lr": learning_rate,
        "dropout": model_config.dropout,
    },
)

try:
    for epoch in range(epochs):
        # --- training pass ---
        whisper_model.train()
        for batch in tqdm(train_dataloader):
            optim.zero_grad()
            audio = batch['audio'].to(device)
            labels = batch['keyword'].to(device)
            outputs = whisper_model(audio, labels=labels)
            loss = outputs['loss']
            loss.backward()
            optim.step()

        # `loss` is the loss of the last training batch of this epoch.
        # detach() keeps it a tensor but drops the autograd history.
        torch.save({
            'epoch': epoch,
            'model_state_dict': whisper_model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'loss': loss.detach(),
        }, f'{checkpoint_dir}/epoch_{epoch+1}')

        # --- dev evaluation ---
        whisper_model.eval()
        # No gradients are needed for evaluation; skipping them avoids
        # building an autograd graph for every dev batch.
        with torch.no_grad():
            for batch in tqdm(dev_dataloader):
                audio = batch['audio'].to(device)
                labels = batch['keyword'].to(device)
                logits = whisper_model(audio).logits
                metric.add_batch(predictions=logits.argmax(-1), references=labels)

        # Log plain Python numbers rather than tensors.
        wandb.log({"acc": metric.compute()['accuracy'], "loss": loss.item()})
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()