In [None]:
#| default_exp inference
# nbdev: cells marked `#| export` in this notebook are written to the library's `inference` module.

In [None]:
#| hide
#| eval: false
# %%capture

!pip install git+https://github.com/huggingface/transformers.git



In [None]:
#| hide
#| export

from typing import List, Callable
from nbdev.showdoc import *
from IPython.display import display, Audio
import numpy as np
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
from datasets import load_metric
from datasets.dataset_dict import DatasetDict
from wav2keyword.datasets import dataloader_pipeline
from wav2keyword.preprocesses import Preprocessor
import torch

# Code

In [None]:
#| export

class W2KInference(object):
    """Predict keyword labels for raw audio with a fine-tuned audio-classification checkpoint.

    Bundles three things:
    - a `transformers` `AutoModelForAudioClassification` model;
    - the project's `Preprocessor`, whose `FEATURE_EXTRACTOR` turns raw audio into model inputs;
    - a `datasets` metric.

    Predicted class ids are mapped back to label names.
    """

    def __init__(self, model_checkpoint: str, id2label: dict, label2id: dict, metric: str = 'accuracy'):
        """Load the metric, the preprocessor and the model for `model_checkpoint`.

        Args:
            model_checkpoint: name or local path of the checkpoint to load.
            id2label: mapping from class id to label name. Keys may be ints or
                their string form; `predict` accepts either.
            label2id: inverse mapping from label name to class id.
            metric: name of the `datasets` metric stored in `self.metric`
                (`predict` does not use it).
        """
        self.model_checkpoint = model_checkpoint
        print("loading metric")
        self.metric = load_metric(metric)
        self.preprocessor = Preprocessor(self.model_checkpoint)
        self.model = self._get_model(id2label, label2id)

    def _get_model(self, id2label: dict, label2id: dict):
        """Load the classification model from the checkpoint, overriding its label mappings.

        `num_labels` is derived from `id2label`, so the classification head matches the label set.
        """
        num_labels = len(id2label)
        model = AutoModelForAudioClassification.from_pretrained(
            self.model_checkpoint,
            num_labels=num_labels,
            label2id=label2id,
            id2label=id2label,
        )
        return model

    def _id_to_label(self, class_id) -> str:
        """Map one predicted class id to its label name.

        The keys of `config.id2label` may be ints or their string form, depending
        on how the mapping was supplied. The int key is tried first, then its
        string form. A KeyError is raised if neither is present.
        """
        id2label = self.model.config.id2label
        key = int(class_id)
        if key in id2label:
            return id2label[key]
        return id2label[str(key)]

    def predict(self, datapoint, sampling_rate=None):
        """Predict the label of one audio clip or of a batch of clips.

        Args:
            datapoint: raw waveform(s) passed to the feature extractor, either a
                single array or a list of arrays.
            sampling_rate: sampling rate of `datapoint` in Hz. If given, it is
                forwarded to the feature extractor, which otherwise warns that
                `sampling_rate` should be passed. Defaults to `None`, in which
                case it is not forwarded.

        Returns:
            A list with one label name per clip. A single clip yields a
            one-element list.
        """
        extractor_kwargs = {"return_tensors": "pt"}
        if sampling_rate is not None:
            extractor_kwargs["sampling_rate"] = sampling_rate
        encoded_dataset = self.preprocessor.FEATURE_EXTRACTOR(datapoint, **extractor_kwargs)
        with torch.no_grad():  # inference only: no autograd graph is needed
            logits = self.model(**encoded_dataset).logits
        # Argmax over the last dimension gives one class id per clip. Flattening keeps
        # the result iterable even if the argmax collapses to a scalar.
        predicted_class_ids = torch.argmax(logits, dim=-1).reshape(-1).tolist()
        return [self._id_to_label(c) for c in predicted_class_ids]

In [None]:
# Render the API documentation of `W2KInference.predict`.
show_doc(W2KInference.predict)

---

### W2KInference.get_model

>      W2KInference.get_model (id2label:dict, label2id:dict)

# Examples

First, we load the SUPERB keyword-spotting (`ks`) dataset together with its label mappings:

In [None]:
#|filter_stream Reusing
#|filter_stream UserWarning
#| eval: false

# SUPERB keyword-spotting ("ks") configuration for the project's data pipeline.
dataset_config = {'path': "superb", 'name': "ks"}
data = dataloader_pipeline(dataset_config)
# The pipeline returns the splits under 'dataset' alongside the label maps used below.
dataset = data['dataset']

Reusing dataset superb (/home/jovyan/.cache/huggingface/datasets/superb/ks/1.9.0/ce836692657f82230c16b3bbcb93eaacdbfd7de4def3be90016f112d68683481)


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
#|filter_stream UserWarning|_preprocess_function|VisibleDeprecationWarning
#| eval: false

# Best checkpoint saved under the fine-tuning output directory (`wav2vec2-base-finetuned-ks`).
CHECKPOINT_DIR = 'wav2vec2-base-finetuned-ks/best_checkpoint'
w2kinference = W2KInference(
    CHECKPOINT_DIR,
    id2label=data['id2label'],
    label2id=data['label2id'],
)
w2kinference.model

loading metric


Wav2Vec2ForSequenceClassification(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (2): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (3): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (4): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), strid

In [None]:
#| eval: false

i = 720  # index of the test example to classify
# Index the split once; every `dataset['test'][i]` access re-reads and re-decodes the row's audio.
example = dataset['test'][i]
datapoint = example["audio"]["array"]
# `predict` already runs the model under `torch.no_grad()`.
predicted_label = w2kinference.predict(datapoint)
expected_label = w2kinference.model.config.id2label[str(example['label'])]
print(f"predicted label: {predicted_label} - expected label: {expected_label}")
display(Audio(datapoint, rate=example['audio']['sampling_rate']))

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


predicted label: ['down'] - expected label: down


In [None]:
#| eval: false

s, e = 750, 770  # half-open range [s, e) of test examples to classify
# Slicing the split decodes only these rows. The original
# `dataset['test']['audio']` decoded the whole test audio column first.
test_batch = dataset['test'][s:e]
audios = test_batch['audio']
datapoint = [a['array'] for a in audios]
predicted_label = w2kinference.predict(datapoint)

# Reuse the already-decoded clips instead of indexing the split again for each example.
for label, audio in zip(predicted_label, audios):
    print(f'Label: {label}')
    print(f'Shape: {audio["array"].shape}, sampling rate: {audio["sampling_rate"]}')
    display(Audio(audio["array"], rate=audio["sampling_rate"]))
    print()

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: down
Shape: (16000,), sampling rate: 16000



Label: go
Shape: (16000,), sampling rate: 16000



Label: go
Shape: (16000,), sampling rate: 16000



Label: go
Shape: (16000,), sampling rate: 16000





In [None]:
#| eval: false

# The metric loaded by `W2KInference.__init__` ('accuracy' by default).
# The original referenced `w2ktrainer`, which is not defined in this notebook.
w2kinference.metric

Metric(name: "accuracy", features: {'predictions': Value(dtype='int32', id=None), 'references': Value(dtype='int32', id=None)}, usage: """
Args:
    predictions: Predicted labels, as returned by a model.
    references: Ground truth labels.
    normalize: If False, return the number of correctly classified samples.
        Otherwise, return the fraction of correctly classified samples.
    sample_weight: Sample weights.
Returns:
    accuracy: Accuracy score.
Examples:

    >>> accuracy_metric = datasets.load_metric("accuracy")
    >>> results = accuracy_metric.compute(references=[0, 1], predictions=[0, 1])
    >>> print(results)
    {'accuracy': 1.0}
""", stored examples: 0)

In [None]:
#| eval: false

# NOTE(review): `W2KTrainer` is neither defined nor imported in this notebook, so this cell
# raises NameError on a fresh kernel. It appears to be carried over from the training
# notebook. Import `W2KTrainer` from its module (TODO confirm path) or drop this cell.
show_doc(W2KTrainer._compute_metrics)

---

### W2KTrainer._compute_metrics

>      W2KTrainer._compute_metrics (eval_pred)

Computes accuracy on a batch of predictions

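The following cells come from the training workflow and need the `trainer` object created there. `trainer` is not defined in this inference notebook. With that trainer in scope, we can fine-tune the model by calling its `train` method: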

In [None]:
#| eval: false

# NOTE(review): `trainer` is not created anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. It needs the Trainer built in the training notebook
# (TODO confirm), or should be removed from the inference docs.
trainer.train()

***** Running training *****
  Num examples = 51094
  Num Epochs = 5
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 4
  Total optimization steps = 1995


Epoch,Training Loss,Validation Loss,Accuracy
0,0.6597,0.567452,0.953074
1,0.292,0.175072,0.976317
2,0.1881,0.116128,0.980141
3,0.1761,0.094171,0.979847
4,0.1321,0.09012,0.981906


***** Running Evaluation *****
  Num examples = 6798
  Batch size = 32
Saving model checkpoint to wav2vec2-base-finetuned-ks/checkpoint-399
Configuration saved in wav2vec2-base-finetuned-ks/checkpoint-399/config.json
Model weights saved in wav2vec2-base-finetuned-ks/checkpoint-399/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-ks/checkpoint-399/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 6798
  Batch size = 32
Saving model checkpoint to wav2vec2-base-finetuned-ks/checkpoint-798
Configuration saved in wav2vec2-base-finetuned-ks/checkpoint-798/config.json
Model weights saved in wav2vec2-base-finetuned-ks/checkpoint-798/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-ks/checkpoint-798/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 6798
  Batch size = 32
Saving model checkpoint to wav2vec2-base-finetuned-ks/checkpoint-1197
Configuration saved in wav2vec2-base-finetuned-ks/checkpoint-1197/confi

TrainOutput(global_step=1995, training_loss=0.4566893815097952, metrics={'train_runtime': 4864.0932, 'train_samples_per_second': 52.522, 'train_steps_per_second': 0.41, 'total_flos': 2.31918157475328e+18, 'train_loss': 0.4566893815097952, 'epoch': 5.0})

With the `evaluate` method we can check that the `Trainer` correctly reloaded the best checkpoint. This matters when the best checkpoint is not the last one saved:

In [None]:
#| eval: false

# NOTE(review): like the `trainer.train()` cell, this depends on a `trainer` object that is
# not defined in this notebook and raises NameError on a fresh kernel (TODO confirm source).
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 6798
  Batch size = 32


{'eval_loss': 0.09011975675821304,
 'eval_accuracy': 0.9819064430714917,
 'eval_runtime': 62.8135,
 'eval_samples_per_second': 108.225,
 'eval_steps_per_second': 3.391,
 'epoch': 5.0}

In [None]:
#| hide
import nbdev

# Write the `#| export` cells out to the library's Python modules.
nbdev.nbdev_export()