In [None]:
#| default_exp training

In [None]:
#| hide
#| eval: false
# %%capture

# !pip install git+https://github.com/huggingface/transformers.git
# !pip uninstall transformers -y



Found existing installation: transformers 4.22.0.dev0
Uninstalling transformers-4.22.0.dev0:
  Successfully uninstalled transformers-4.22.0.dev0


In [None]:
#| hide
#| export

from typing import List, Callable, Optional
from nbdev.showdoc import *
from IPython.display import display,SVG
import numpy as np
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
from datasets import load_metric
from datasets.dataset_dict import DatasetDict
from wav2keyword.datasets_tools import dataloader_pipeline
from wav2keyword.preprocesses import Preprocessor

# Code

In [None]:
#| export

# Default keyword arguments for `transformers.TrainingArguments`.
# `output_dir` is not set here: it depends on the model checkpoint and is added
# per `W2KTrainer` instance.
TRAINING_ARGS = dict(
    # Evaluate and checkpoint once per epoch.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Optimisation schedule: per-device batch of 32 with 4 accumulation steps.
    learning_rate=3e-5,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=32,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_steps=10,
    # Reload the checkpoint that scored best on `metric_for_best_model` once training ends.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

class W2KTrainer:
    """Build a Hugging Face `Trainer` that fine-tunes a pretrained audio classification model.

    Args:
        model_checkpoint: Checkpoint identifier passed to `from_pretrained`
            (e.g. "facebook/wav2vec2-base"); it is also handed to the `Preprocessor`.
        metric: Name of the metric loaded with `load_metric` and reported by
            `_compute_metrics` at evaluation time.

    Attributes:
        training_args: Per-instance copy of `TRAINING_ARGS` plus an `output_dir` of the
            form "<model name>-finetuned-ks". Editing it does not affect `TRAINING_ARGS`
            or other instances.
    """

    def __init__(self, model_checkpoint: str = "facebook/wav2vec2-base", metric: str = 'accuracy'):
        self.model_checkpoint = model_checkpoint
        self.model_name = self._model_name_from_checkpoint(model_checkpoint)
        # Copy the module-level defaults: assigning `output_dir` on the shared dict
        # would leak into TRAINING_ARGS and into every other W2KTrainer instance.
        self.training_args = self._default_training_args(self.model_name, TRAINING_ARGS)
        print("loading metric")
        self.metric = load_metric(metric)
        self.preprocessor = Preprocessor(self.model_checkpoint)

    @staticmethod
    def _model_name_from_checkpoint(model_checkpoint: str) -> str:
        """Return the last "/"-separated component of `model_checkpoint`.

        A trailing "/" (common for local directories) is ignored, so it does not
        produce an empty name.
        """
        return model_checkpoint.rstrip("/").split("/")[-1]

    @staticmethod
    def _default_training_args(model_name: str, base_args: dict) -> dict:
        """Return a new dict with `base_args` plus `output_dir="<model_name>-finetuned-ks"`.

        `base_args` itself is never modified.
        """
        return {**base_args, 'output_dir': f"{model_name}-finetuned-ks"}

    def _get_model(self, id2label: dict, label2id: dict):
        """Load `self.model_checkpoint` with a classification head sized to the label set.

        `num_labels` is derived from `len(id2label)`; both label mappings are forwarded
        to `AutoModelForAudioClassification.from_pretrained`.
        """
        return AutoModelForAudioClassification.from_pretrained(
            self.model_checkpoint,
            num_labels=len(id2label),
            label2id=label2id,
            id2label=id2label,
        )

    def get_training_args(self, training_kwargs: Optional[dict] = None):
        """Build `TrainingArguments` from the instance defaults.

        Args:
            training_kwargs: Optional `TrainingArguments` keyword arguments that override
                (or extend) `self.training_args` for this call only.

        Returns:
            A `TrainingArguments` instance. `self.training_args` is left unchanged.
        """
        merged = {**self.training_args, **(training_kwargs or {})}
        return TrainingArguments(**merged)

    def _compute_metrics(self, eval_pred):
        """Compute the configured metric (`self.metric`) on a batch of predictions.

        The predicted class is the argmax over axis 1 of `eval_pred.predictions`
        (presumably logits of shape (n_samples, n_labels)). It is compared against
        `eval_pred.label_ids`.
        """
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return self.metric.compute(predictions=predictions, references=eval_pred.label_ids)

    def build_trainer(self, dataset, id2label: dict, label2id: dict, args: Optional[dict] = None,
                      preprocess_kwargs: Optional[dict] = None):
        """Preprocess `dataset` and wrap model, training arguments and metric in a `Trainer`.

        Args:
            dataset: Dataset to preprocess (presumably a `DatasetDict`); its "train" and
                "validation" splits are used after preprocessing.
            id2label: Class-id -> label mapping forwarded to the model.
            label2id: Label -> class-id mapping forwarded to the model.
            args: Optional `TrainingArguments` overrides, see `get_training_args`.
            preprocess_kwargs: Forwarded to the preprocessor as `fn_kwargs`; defaults to
                ``{'max_duration': 1.0}``. A fresh dict is built per call, so a shared
                default can never be mutated across calls.

        Returns:
            A `transformers.Trainer` ready for `train()` / `evaluate()`.
        """
        if preprocess_kwargs is None:
            preprocess_kwargs = {'max_duration': 1.0}
        encoded_dataset = self.preprocessor.preprocess(dataset, fn_kwargs=preprocess_kwargs)
        return Trainer(
            self._get_model(id2label, label2id),
            self.get_training_args(args),
            train_dataset=encoded_dataset["train"],
            eval_dataset=encoded_dataset["validation"],
            tokenizer=self.preprocessor.FEATURE_EXTRACTOR,
            compute_metrics=self._compute_metrics,
        )

In [None]:
# Render the API documentation for `W2KTrainer.get_training_args`.
show_doc(W2KTrainer.get_training_args)

---

### W2KTrainer.get_training_args

>      W2KTrainer.get_training_args (training_kwargs=None)

In [None]:
# Render the API documentation for `W2KTrainer.build_trainer`.
show_doc(W2KTrainer.build_trainer)

---

### W2KTrainer.build_trainer

>      W2KTrainer.build_trainer (dataset, id2label, label2id, args=None)

# Examples

First, we load the data:

In [None]:
#|filter_stream Reusing
#|filter_stream UserWarning
#| eval: false

# Load the `ks` configuration of the `superb` dataset; the returned dict also carries
# the label mappings (`id2label`, `label2id`) used further below.
data = dataloader_pipeline(dict(path="superb", name="ks"))
dataset = data["dataset"]

Reusing dataset superb (/home/jovyan/.cache/huggingface/datasets/superb/ks/1.9.0/ce836692657f82230c16b3bbcb93eaacdbfd7de4def3be90016f112d68683481)


  0%|          | 0/3 [00:00<?, ?it/s]

We can now download the pretrained model and fine-tune it. `W2KTrainer` uses the `AutoModelForAudioClassification` class and, as with the feature extractor, calls its `from_pretrained` method, which downloads and caches the model for us. Since the label ids and the number of labels are dataset dependent, we pass `id2label` and `label2id` alongside the dataset; `W2KTrainer` derives `num_labels` from `id2label`.

To instantiate a `Trainer`, we need to define the [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments), a class that contains all the attributes used to customize training. `W2KTrainer` has a default `TrainingArguments` setup, but you can override any of those parameters by passing a dictionary of them through the `args` argument.  
Since we are using the default `TrainingArguments`, we do not pass any custom args.

In [None]:
#|filter_stream UserWarning|_preprocess_function|VisibleDeprecationWarning
#| eval: false

# Spell out the constructor defaults (checkpoint and metric) for readability.
w2ktrainer = W2KTrainer(model_checkpoint="facebook/wav2vec2-base", metric="accuracy")
# No `args` overrides: the default TrainingArguments are used.
trainer = w2ktrainer.build_trainer(dataset, id2label=data['id2label'], label2id=data['label2id'])

loading metric




  0%|          | 0/52 [00:00<?, ?ba/s]

  tensor = as_tensor(value)


  0%|          | 0/7 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForSequenceClassification: ['project_hid.weight', 'project_q.weight', 'quantizer.weight_proj.weight', 'project_hid.bias', 'quantizer.weight_proj.bias', 'quantizer.codevectors', 'project_q.bias']
- This IS expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.weight', 'classifier.bias', 'projector

The warning tells us that some weights are being thrown away (the `quantizer`, `project_q` and `project_hid` layers) and some others are randomly initialized (the `projector` and `classifier` layers). This is expected here: we are removing the head used to pretrain the model on its self-supervised (vector-quantization based) objective and replacing it with a new classification head for which we don't have pretrained weights. The library therefore warns us that we should fine-tune this model before using it for inference, which is exactly what we are going to do.

Let's review the default `TrainingArguments` before continuing. Evaluation and checkpointing happen at the end of each epoch. The learning rate is 3e-5, with warmup over the first 10% of training steps. The per-device batch size is 32 with 4 gradient-accumulation steps, giving an effective batch size of 128 on a single device, and training runs for 5 epochs. Since the best model might not be the one at the end of training, we ask the `Trainer` to load the best checkpoint it saved (according to `metric_for_best_model`, here `accuracy`) at the end of training.

In [None]:
#| eval: false

# Default training arguments of this instance; `output_dir` is derived from the checkpoint name.
w2ktrainer.training_args

{'evaluation_strategy': 'epoch',
 'save_strategy': 'epoch',
 'learning_rate': 3e-05,
 'per_device_train_batch_size': 32,
 'gradient_accumulation_steps': 4,
 'per_device_eval_batch_size': 32,
 'num_train_epochs': 5,
 'warmup_ratio': 0.1,
 'logging_steps': 10,
 'load_best_model_at_end': True,
 'metric_for_best_model': 'accuracy',
 'push_to_hub': False,
 'output_dir': 'wav2vec2-base-finetuned-ks'}

`W2KTrainer` defines an internal method, `_compute_metrics`, that computes the metric from the predictions. It uses the metric passed at instantiation, which defaults to `accuracy`.  
The only preprocessing it needs is to take the argmax of the predicted logits. All of this is wired up internally when the `Trainer` is instantiated, but we show it here for completeness:

In [None]:
#| eval: false

# The metric loaded at construction time (the `metric` argument, "accuracy" by default).
w2ktrainer.metric

Metric(name: "accuracy", features: {'predictions': Value(dtype='int32', id=None), 'references': Value(dtype='int32', id=None)}, usage: """
Args:
    predictions: Predicted labels, as returned by a model.
    references: Ground truth labels.
    normalize: If False, return the number of correctly classified samples.
        Otherwise, return the fraction of correctly classified samples.
    sample_weight: Sample weights.
Returns:
    accuracy: Accuracy score.
Examples:

    >>> accuracy_metric = datasets.load_metric("accuracy")
    >>> results = accuracy_metric.compute(references=[0, 1], predictions=[0, 1])
    >>> print(results)
    {'accuracy': 1.0}
""", stored examples: 0)

In [None]:
#| eval: false

# Render the API documentation for the internal `_compute_metrics` callback.
show_doc(W2KTrainer._compute_metrics)

---

### W2KTrainer._compute_metrics

>      W2KTrainer._compute_metrics (eval_pred)

Computes accuracy on a batch of predictions

Now we can fine-tune our model by calling the `train` method:

In [None]:
#| eval: false

# Fine-tune; with the default arguments the model is evaluated and checkpointed every epoch.
trainer.train()

***** Running training *****
  Num examples = 51094
  Num Epochs = 5
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 4
  Total optimization steps = 1995


Epoch,Training Loss,Validation Loss,Accuracy
0,0.6597,0.567452,0.953074
1,0.292,0.175072,0.976317
2,0.1881,0.116128,0.980141
3,0.1761,0.094171,0.979847
4,0.1321,0.09012,0.981906


***** Running Evaluation *****
  Num examples = 6798
  Batch size = 32
Saving model checkpoint to wav2vec2-base-finetuned-ks/checkpoint-399
Configuration saved in wav2vec2-base-finetuned-ks/checkpoint-399/config.json
Model weights saved in wav2vec2-base-finetuned-ks/checkpoint-399/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-ks/checkpoint-399/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 6798
  Batch size = 32
Saving model checkpoint to wav2vec2-base-finetuned-ks/checkpoint-798
Configuration saved in wav2vec2-base-finetuned-ks/checkpoint-798/config.json
Model weights saved in wav2vec2-base-finetuned-ks/checkpoint-798/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-ks/checkpoint-798/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 6798
  Batch size = 32
Saving model checkpoint to wav2vec2-base-finetuned-ks/checkpoint-1197
Configuration saved in wav2vec2-base-finetuned-ks/checkpoint-1197/confi

TrainOutput(global_step=1995, training_loss=0.4566893815097952, metrics={'train_runtime': 4864.0932, 'train_samples_per_second': 52.522, 'train_steps_per_second': 0.41, 'total_flos': 2.31918157475328e+18, 'train_loss': 0.4566893815097952, 'epoch': 5.0})

We can check with the `evaluate` method that our `Trainer` did reload the best model properly (if it was not the last one):

In [None]:
#| eval: false

# Evaluate on the validation split; with `load_best_model_at_end=True` this scores the best checkpoint.
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 6798
  Batch size = 32


{'eval_loss': 0.09011975675821304,
 'eval_accuracy': 0.9819064430714917,
 'eval_runtime': 62.8135,
 'eval_samples_per_second': 108.225,
 'eval_steps_per_second': 3.391,
 'epoch': 5.0}

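Because `load_best_model_at_end` is enabled, the model held by the `Trainer` is now the best checkpoint, so we save it to its own directory.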

In [None]:
#| eval: false

# Save under `<output_dir>/best_checkpoint`. `trainer.args.output_dir` is the directory the
# Trainer actually used, including any `output_dir` override passed via `build_trainer(args=...)`,
# so the export always lands next to the per-epoch checkpoints.
trainer.save_model(f"{trainer.args.output_dir}/best_checkpoint")

Saving model checkpoint to wav2vec2-base-finetuned-ks/best_checkpoint
Configuration saved in wav2vec2-base-finetuned-ks/best_checkpoint/config.json
Model weights saved in wav2vec2-base-finetuned-ks/best_checkpoint/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-ks/best_checkpoint/preprocessor_config.json


In [None]:
#| hide
# Export the cells marked with `#| export` to the library module.
from nbdev import nbdev_export
nbdev_export()