In [28]:
from transformers import WhisperForAudioClassification, WhisperConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
import evaluate
import librosa
import wandb
from torch.utils.data import DataLoader
from datasets import load_dataset
from datasets import Dataset
from tqdm import tqdm

In [2]:
class KWS_dataset(torch.utils.data.Dataset):
    def __init__(self, input_data, output_data):
        self.input_data = input_data
        self.output_data = output_data
        
    def __len__(self):
        return len(self.input_data)
    
    def __getitem__(self, index):
        keyword = self.output_data[index]
        audio_features = self.input_data[index]
        # return audio_features, keyword
        return {'audio': audio_features,
                'keyword': keyword
               }

In [3]:
data_path = '../data/'

In [4]:
train_dataloader = torch.load('../data/en_splits_30.trainloader')
dev_dataloader = torch.load('../data/en_splits_30.devloader')
test_dataloader = torch.load('../data/en_splits_30.testloader')

In [5]:
metric = evaluate.load("accuracy")

In [6]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
whisper_model = WhisperForAudioClassification(WhisperConfig(
    num_mel_bins=80,
    vocab_size=30,
    num_labels=31,
    max_source_positions=50,
    classifier_proj_size=512,
    encoder_layer=8,
    decoder_layer=8,
    dropout=0.2
))

In [9]:
whisper_model.from_pretrained("jaeihn/kws_embedding")

Downloading (…)lve/main/config.json:   0%|          | 0.00/2.50k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/27.0M [00:00<?, ?B/s]

WhisperForAudioClassification(
  (encoder): WhisperEncoder(
    (conv1): Conv1d(80, 256, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(256, 256, kernel_size=(3,), stride=(2,), padding=(1,))
    (embed_positions): Embedding(50, 256)
    (layers): ModuleList(
      (0-5): 6 x WhisperEncoderLayer(
        (self_attn): WhisperAttention(
          (k_proj): Linear(in_features=256, out_features=256, bias=False)
          (v_proj): Linear(in_features=256, out_features=256, bias=True)
          (q_proj): Linear(in_features=256, out_features=256, bias=True)
          (out_proj): Linear(in_features=256, out_features=256, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=256, out_features=1536, bias=True)
        (fc2): Linear(in_features=1536, out_features=256, bias=True)
        (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine

In [11]:
class KWS_classifier(nn.Module):
    def __init__(self, input_size=31, output_size=3):
        super(KWS_classifier, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Freeze embedding weights
        x = self.linear(x)
        x = self.softmax(x)
        return x

In [27]:
with open(data_path + 'keywords_en_50.txt') as f:
    keywords = [word.strip() for word in f.readlines()][30:]
print(" ".join(keywords))

got school way name work city however little right found may four much known years called alchemist make world come


```bash
for word in got school way name work city however little right found may four much known years called alchemist make world come; 
do 
    python kws_preparation.py en/en_splits.csv keywords_en_30.txt $word; done
```

In [32]:
def keyword_spotting(word):
    kws_train_dataloader = torch.load(data_path+word+'_128_kws.trainloader')
    kws_dev_dataloader = torch.load(data_path+word+'_128_kws.devloader')

    kws_model = KWS_classifier(input_size=31)
    whisper_model.to(device)
    optim = torch.optim.Adam(kws_model.parameters(),lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    
    epochs = 100

    for epoch in tqdm(range(epochs)):
        kws_model.train()
        whisper_model.eval()
        for batch in kws_train_dataloader:
            optim.zero_grad()
            audio = batch['audio'].to(device)
            labels = batch['keyword'].to(device)
            outputs = whisper_model(audio)
            outputs = kws_model(outputs.logits)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optim.step()

        kws_model.eval()

        for batch in kws_dev_dataloader:
            audio = batch['audio'].to(device)
            labels = batch['keyword'].to(device)
            outputs = whisper_model(audio)
            outputs = kws_model(outputs.logits)

            metric.add_batch(predictions=outputs.argmax(-1), references=labels)
    return metric.compute()['accuracy']

In [None]:
mono_en_accuracy = {}
for word in keywords:
    mono_en_accuracy[word] = keyword_spotting(word)
print("COMPLETE")

100%|█████████████████████████████████████████| 100/100 [03:20<00:00,  2.01s/it]
100%|█████████████████████████████████████████| 100/100 [03:19<00:00,  2.00s/it]
100%|█████████████████████████████████████████| 100/100 [03:20<00:00,  2.00s/it]
100%|█████████████████████████████████████████| 100/100 [04:30<00:00,  2.70s/it]
100%|█████████████████████████████████████████| 100/100 [05:19<00:00,  3.20s/it]
100%|█████████████████████████████████████████| 100/100 [05:18<00:00,  3.19s/it]
100%|█████████████████████████████████████████| 100/100 [05:15<00:00,  3.16s/it]
100%|█████████████████████████████████████████| 100/100 [05:16<00:00,  3.17s/it]
100%|█████████████████████████████████████████| 100/100 [05:16<00:00,  3.17s/it]
100%|█████████████████████████████████████████| 100/100 [05:18<00:00,  3.18s/it]
100%|█████████████████████████████████████████| 100/100 [05:18<00:00,  3.19s/it]
100%|█████████████████████████████████████████| 100/100 [05:20<00:00,  3.20s/it]
100%|███████████████████████

In [85]:
kws_train_dataloader = torch.load('../data/people_128_kws.trainloader')
kws_dev_dataloader = torch.load('../data/people_128_kws.trainloader-1')


kws_model = KWS_classifier(input_size=31)
whisper_model.to(device)
optim = torch.optim.Adam(kws_model.parameters(),lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# wandb.init(
#     # set the wandb project where this run will be logged
#     project="kws",
#     config= {
#     "architecture": "softmax",
#     "dataset": "people",
#     "epochs": "10", 
#     }
    
# )

epochs = 50

for epoch in tqdm(range(epochs)):
    kws_model.train()
    whisper_model.eval()
    for batch in kws_train_dataloader:
        optim.zero_grad()
        audio = batch['audio'].to(device)
        labels = batch['keyword'].to(device)
        outputs = whisper_model(audio)
        outputs = kws_model(outputs.logits)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optim.step()

    kws_model.eval()
        
    for batch in kws_dev_dataloader:
        audio = batch['audio'].to(device)
        labels = batch['keyword'].to(device)
        outputs = whisper_model(audio)
        outputs = kws_model(outputs.logits)
        
        metric.add_batch(predictions=outputs.argmax(-1), references=labels)

print(metric.compute()['accuracy'])
    # wandb.log({"acc": metric.compute()['accuracy'], "loss": loss})
    
# wandb.finish()

 10%|███▍                              | 5/50 [00:20<03:07,  4.16s/it]wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
 16%|█████▍                            | 8/50 [00:33<02:55,  4.17s/it]wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
 24%|███████▉                         | 12/50 [00:49<02:38,  4.17s/it]wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
 30%|█████████▉                       | 15/50 [01:02<02:26,  4.20s/it]wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
 38%|████████████▌                    | 19/50 [01:19<02:11,  4.25s/it]wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
 44%|██████████████▌                  | 22/50 [01:32<01:57,  4.20s/it]wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
 52%|█████████████████▏               | 26/50 [01:48<01:41,  4.21s/it]wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)

0.7992957746478874



