In [188]:
from transformers import Wav2Vec2Config, Wav2Vec2ForSequenceClassification
import torch
import librosa
from torch.utils.data import DataLoader
from datasets import load_dataset
from datasets import Dataset
from tqdm import tqdm
import numpy as np
import evaluate

In [189]:
class KWS_dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each audio input with its keyword target.

    ``input_data`` and ``output_data`` are indexed with the same position, so
    they are presumably aligned and of equal length (verify against the code
    that builds them). Only ``input_data`` determines the reported size.
    """

    def __init__(self, input_data, output_data):
        # Store references only; nothing is copied or converted here.
        self.input_data, self.output_data = input_data, output_data

    def __len__(self):
        """Return the number of samples, taken from the input container."""
        return len(self.input_data)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict with 'audio' and 'keyword'."""
        target, features = self.output_data[index], self.input_data[index]
        return dict(audio=features, keyword=target)

In [190]:
# Root directory of the dataset files, relative to the notebook's working directory.
data_path = '../data/'

In [191]:
# NOTE(security): torch.load unpickles arbitrary Python objects, so only load
# files that come from a trusted source.
#
# The files appear to hold whole pickled DataLoader objects rather than plain
# tensors, so `weights_only=False` is passed explicitly: PyTorch >= 2.6 defaults
# to `weights_only=True`, which refuses to unpickle such objects.
# Presumably the pickled loaders wrap `KWS_dataset`, in which case that class
# must already be defined in this kernel before loading (TODO confirm).
#
# Paths are built from `data_path` so the data location is configured in one place.
train_dataloader = torch.load(f'{data_path}en_splits_10_trainloader', weights_only=False)
dev_dataloader = torch.load(f'{data_path}en_splits_10_devloader', weights_only=False)
test_dataloader = torch.load(f'{data_path}en_splits_10_testloader', weights_only=False)

In [192]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [193]:
# Small Wav2Vec2 sequence classifier built from a config (no pretrained
# weights are loaded here), with ten output labels.
wav2vec_model = Wav2Vec2ForSequenceClassification(
    config=Wav2Vec2Config(
        # output / vocabulary sizes
        num_labels=10,
        vocab_size=10,
        # transformer encoder
        hidden_size=512,
        num_hidden_layers=2,
        num_attention_heads=2,
        # convolutional feature encoder: one entry per conv layer
        conv_dim=[512, 512, 512],
        conv_kernel=[2, 2, 2],
        conv_stride=[5, 2, 2],
        # span length used by time masking
        mask_time_length=1,
    )
)

In [194]:
# Move the model to the selected device, cast its parameters to float64 and
# put it in training mode. Each call returns the module, so they can be chained.
wav2vec_model.to(device).double().train()

# Adam over every model parameter with a fixed learning rate.
optim = torch.optim.Adam(params=wav2vec_model.parameters(), lr=5e-5)

In [196]:
# Accuracy metric from the `evaluate` library; the training cell feeds it
# predictions batch by batch and computes the score once per epoch.
metric = evaluate.load("accuracy")

In [197]:
import os
import random  # not used in this cell; kept in case other cells rely on it

import wandb

# Single source of truth for the run length: drives the loop and is logged to wandb.
epochs = 10

# A checkpoint is written after every epoch. Create the target directory up
# front so a missing folder cannot abort the run after a full training epoch.
checkpoint_dir = '../model/wav2vec'
os.makedirs(checkpoint_dir, exist_ok=True)

# Read the architecture hyperparameters back from the instantiated model so the
# values recorded in wandb cannot drift from what is actually being trained.
model_config = wav2vec_model.config

# start a new wandb run to track this training session
wandb.init(
    # set the wandb project where this run will be logged
    project="wav2vec",

    # track hyperparameters and run metadata
    config={
        "architecture": "wav2vec2",
        "dataset": "en_10_1600",
        "epochs": epochs,
        'hidden_size': model_config.hidden_size,
        'num_hidden_layers': model_config.num_hidden_layers,
        'num_attention_heads': model_config.num_attention_heads,
        'conv_dim': list(model_config.conv_dim),
        'conv_kernel': list(model_config.conv_kernel),
        'conv_stride': list(model_config.conv_stride),
        'mask_time_length': model_config.mask_time_length,
    }
)

try:
    for epoch in range(epochs):
        # ---- training pass ----
        wav2vec_model.train()
        loss_sum = 0.0
        num_batches = 0
        for batch in tqdm(train_dataloader):
            optim.zero_grad()
            audio = batch['audio'].to(device)
            labels = batch['keyword'].to(device)
            outputs = wav2vec_model(audio, labels=labels)
            loss = outputs['loss']
            loss.backward()
            optim.step()

            loss_sum += loss.item()
            num_batches += 1

        if num_batches == 0:
            # Without this guard the code below would fail with an unrelated
            # NameError on `loss`.
            raise RuntimeError("train_dataloader produced no batches")

        # Plain float of the last training batch's loss (the value the "loss"
        # series has always tracked), free of any autograd graph.
        last_loss = loss.item()

        # Note: 'epoch' stores the zero-based index, while the file name is one-based.
        torch.save({
                'epoch': epoch,
                'model_state_dict': wav2vec_model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
                'loss': loss.detach(),
                }, f'{checkpoint_dir}/epoch_{epoch+1}')

        # ---- evaluation pass on the dev split ----
        wav2vec_model.eval()

        # No gradients are needed for evaluation; disabling autograd avoids
        # building the graph and saves memory and time.
        with torch.no_grad():
            for batch in tqdm(dev_dataloader):
                audio = batch['audio'].to(device)
                labels = batch['keyword'].to(device)
                outputs = wav2vec_model(audio, labels=labels)

                metric.add_batch(predictions=outputs.logits.argmax(-1), references=labels)

        wandb.log({
            "acc": metric.compute()['accuracy'],
            "loss": last_loss,
            "train_loss_mean": loss_sum / num_batches,
            "epoch": epoch + 1,
        })
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
acc,▃▁▇▆█▅
loss,█▁▇▇▄▁

0,1
acc,0.13124
loss,2.18014


100%|█████████████████████████████████████████| 998/998 [04:58<00:00,  3.34it/s]
100%|█████████████████████████████████████████████| 4/4 [00:17<00:00,  4.41s/it]
100%|█████████████████████████████████████████| 998/998 [04:56<00:00,  3.37it/s]
100%|█████████████████████████████████████████████| 4/4 [00:21<00:00,  5.26s/it]
100%|█████████████████████████████████████████| 998/998 [04:54<00:00,  3.39it/s]
100%|█████████████████████████████████████████████| 4/4 [00:18<00:00,  4.53s/it]
100%|█████████████████████████████████████████| 998/998 [04:54<00:00,  3.39it/s]
100%|█████████████████████████████████████████████| 4/4 [00:16<00:00,  4.11s/it]
100%|█████████████████████████████████████████| 998/998 [05:03<00:00,  3.29it/s]
100%|█████████████████████████████████████████████| 4/4 [00:20<00:00,  5.21s/it]
100%|█████████████████████████████████████████| 998/998 [05:24<00:00,  3.07it/s]
100%|█████████████████████████████████████████████| 4/4 [00:18<00:00,  4.52s/it]
100%|███████████████████████

0,1
acc,▁▅▄▇▆▆█▇▇▇
loss,▆▆█▃▇▄▁▃▁▂

0,1
acc,0.32905
loss,1.3794
