In [None]:
#| eval: false

from pathlib import Path
from datasets import load_dataset, Features, Value, Audio, ClassLabel
import json
import pandas as pd
from collections import Counter

# Build the label vocabulary from the training split.
# The label of a clip is the part of its file name before the first underscore.
MIN_CLIPS_PER_LABEL = 20  # labels need strictly more than this many training clips

paths = pd.read_csv('dataset/slices_train.csv')["path"].values
labels = [Path(p).name.split('_')[0] for p in paths]
n = Counter(labels)
# Drop rare labels; clips carrying them are mapped to 'unk' when the dataset is built.
labels = [label for label in labels if n[label] > MIN_CLIPS_PER_LABEL]
# Sort before enumerating: iteration order of a set of strings changes between
# interpreter runs (hash randomisation), which would silently reshuffle label ids
# every time the kernel is restarted.
label2id = {label: ix for ix, label in enumerate(sorted(set(labels)))}
# Reserve the next free id for the catch-all class. `setdefault` keeps the ids
# contiguous even if a real label called 'unk' survived the frequency filter.
label2id.setdefault('unk', len(label2id))
id2label = {ix: label for label, ix in label2id.items()}
# Class names ordered by id, as expected by ClassLabel(names=...).
names = [id2label[ix] for ix in range(len(id2label))]


In [None]:
#| eval: false

# Schema of the dataset after `_generate_examples` has run: the clip path as a
# string, an audio feature declared with a 16 kHz sampling rate, and an integer
# class label whose names come from `names`.
feats = Features({"path": Value("string"),
                  "audio": Audio(sampling_rate=16_000),
                  "label": ClassLabel(names=names)}
                  )

def _generate_examples(example, label2id: dict = label2id):
    """Attach a class id and an audio reference to one dataset row.

    The label is taken from the file name in ``example['path']`` (the text
    before the first underscore). Labels missing from ``label2id`` fall back
    to the id stored under ``'unk'``. The row is updated in place and returned.
    """
    file_name = Path(example['path']).name
    prefix = file_name.split('_')[0]
    fallback_id = label2id['unk']
    example['label'] = label2id.get(prefix, fallback_id)
    # The audio column is given the path of the clip.
    example['audio'] = example['path']
    return example

# Load the tag metadata.
# NOTE(review): `data` is not referenced by any other cell visible here — confirm it is still needed.
with open('tags_data.json', 'r') as f:
    data = json.load(f)

# One CSV of clip paths per split.
data_files = {'train': 'dataset/slices_train.csv', 'test': 'dataset/slices_test.csv', 'val': 'dataset/slices_val.csv'}
dataset = load_dataset("csv", data_files=data_files)
# Drop the columns that are not needed downstream.
dataset = dataset.remove_columns(column_names=['Unnamed: 0', 'split'])


# Add the `label` and `audio` columns to every row, using the `feats` schema.
dataset = dataset.map(_generate_examples, features=feats)
# The path column is called 'file' from here on.
dataset = dataset.rename_column('path', 'file')

Using custom data configuration default-c58ed15a5d5a3dac
Reusing dataset csv (/home/jovyan/.cache/huggingface/datasets/csv/default-c58ed15a5d5a3dac/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/36993 [00:00<?, ?ex/s]

  0%|          | 0/4648 [00:00<?, ?ex/s]

  0%|          | 0/4586 [00:00<?, ?ex/s]

In [None]:
#| eval: false

from transformers import AutoFeatureExtractor
from transformers import Wav2Vec2ForXVector, TrainingArguments, Trainer
from datasets import load_dataset, load_metric

# Fine-tuning configuration.
model_checkpoint = "facebook/wav2vec2-base"  # pretrained checkpoint to start from
batch_size = 32  # used for both the per-device train and eval batch sizes
max_duration = 1  # seconds; multiplied by the sampling rate to get the truncation length

# Feature extractor that matches the chosen checkpoint.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

def preprocess_function(examples):
    """Run the feature extractor over a batch of decoded audio clips.

    ``examples["audio"]`` is a list of dicts, each with an ``"array"`` entry.
    Clips are truncated to ``max_duration`` seconds at the extractor's
    sampling rate. Returns whatever the feature extractor returns.
    """
    waveforms = [clip["array"] for clip in examples["audio"]]
    sample_rate = feature_extractor.sampling_rate
    # Longest clip length kept, in samples.
    max_samples = int(feature_extractor.sampling_rate * max_duration)
    return feature_extractor(
        waveforms,
        sampling_rate=sample_rate,
        max_length=max_samples,
        truncation=True,
    )

# Apply the feature extractor batch-wise to every split; the "audio" and "file"
# columns are dropped from the result.
encoded_dataset = dataset.map(preprocess_function, remove_columns=["audio", "file"], batched=True)



  0%|          | 0/37 [00:00<?, ?ba/s]

  tensor = as_tensor(value)


  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/5 [00:00<?, ?ba/s]

In [None]:
#| eval: false

# Number of classes, including the catch-all 'unk' entry of the label mapping.
num_labels = len(id2label)

# x-vector model initialised from the pretrained checkpoint, with the label
# mapping passed in alongside the class count.
model = Wav2Vec2ForXVector.from_pretrained(
    model_checkpoint, 
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)

# Last path component of the checkpoint id, e.g. "wav2vec2-base".
model_name = model_checkpoint.split("/")[-1]

# Evaluate and save once per epoch; at the end, reload the checkpoint with the
# best "f1" value (the key returned by `compute_metrics`).
args = TrainingArguments(
    f"{model_name}-finetuned-xvector",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,  # 4 x batch_size examples per optimisation step and device
    per_device_eval_batch_size=batch_size,
    num_train_epochs=35,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    push_to_hub=False,
)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForXVector: ['quantizer.weight_proj.bias', 'project_hid.weight', 'quantizer.codevectors', 'project_q.weight', 'project_q.bias', 'project_hid.bias', 'quantizer.weight_proj.weight']
- This IS expected if you are initializing Wav2Vec2ForXVector from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForXVector from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForXVector were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['tdnn.0.kernel.weight', 'tdnn.3.kernel.bias', 'feature_extractor.weight', 'objective.weight', 'feature_extract

In [None]:
#| eval: false

# Load the F1 metric. The cache directory is resolved relative to the current
# user's home directory instead of a hard-coded absolute path, so the notebook
# also runs for users other than `jovyan`.
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package — confirm against the pinned version.
metric = load_metric("f1", cache_dir=str(Path.home() / ".cache" / "huggingface" / "metrics"))
metric

Metric(name: "f1", features: {'predictions': Value(dtype='int32', id=None), 'references': Value(dtype='int32', id=None)}, usage: """
Args:
    predictions (`list` of `int`): Predicted labels.
    references (`list` of `int`): Ground truth labels.
    labels (`list` of `int`): The set of labels to include when `average` is not set to `'binary'`, and the order of the labels if `average` is `None`. Labels present in the data can be excluded, for example to calculate a multiclass average ignoring a majority negative class. Labels not present in the data will result in 0 components in a macro average. For multilabel targets, labels are column indices. By default, all labels in `predictions` and `references` are used in sorted order. Defaults to None.
    pos_label (`int`): The class to be considered the positive class, in the case where `average` is set to `binary`. Defaults to 1.
    average (`string`): This parameter is required for multiclass/multilabel targets. If set to `None`, the sco

In [None]:
#| eval: false

import numpy as np
import torch
from torch import nn

def compute_metrics(eval_pred):
    """Compute the weighted F1 score for one evaluation pass.

    Args:
        eval_pred: object with ``predictions`` (a tuple whose first element is
            the model's ``logits`` output, one row per example) and
            ``label_ids`` (the reference class ids).

    Returns:
        The dict produced by ``metric.compute`` (``{'f1': ...}``); it is also
        printed so the score shows up in the training log.

    Class scores are obtained by projecting the first model output through the
    weight matrix of ``model.objective``. The weight columns are L2-normalised
    first: the AM-Softmax objective used by ``Wav2Vec2ForXVector`` compares
    against unit-norm class vectors, so without this step a class with a
    large-norm column can win the argmax even when another class is closer in
    cosine terms. (Normalising the rows of the model output is unnecessary, as
    it does not change the per-row argmax.)
    """
    hidden = eval_pred.predictions[0]
    proj = model.objective.weight.detach().cpu().numpy()
    # Guard against a zero-norm column with a small floor.
    col_norms = np.maximum(np.linalg.norm(proj, axis=0, keepdims=True), 1e-12)
    predicted_labels = np.argmax(np.dot(hidden, proj / col_norms), axis=1)
    res = metric.compute(predictions=predicted_labels, references=eval_pred.label_ids, average='weighted')
    print(res)
    return res

In [None]:
#| eval: false

# Train on the 'train' split and evaluate on the 'val' split; the 'test' split
# is not passed to the trainer here. `compute_metrics` supplies the "f1" value
# that the training arguments use to select the best checkpoint.
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset["val"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics
)


In [None]:
#| eval: false

# Run the fine-tuning. The recorded run below reports a train_runtime of about
# 25,125 seconds for 35 epochs.
trainer.train()

***** Running training *****
  Num examples = 36993
  Num Epochs = 35
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 4
  Total optimization steps = 10115


Epoch,Training Loss,Validation Loss,F1
0,9.972,9.282754,0.494746
1,7.7301,6.484269,0.689371
2,5.968,4.822059,0.795504
3,5.1825,3.685891,0.832956
4,4.3529,3.03828,0.865561
5,3.8573,2.449143,0.896701
6,2.9675,2.071967,0.9097
7,3.0391,1.890948,0.91462
8,2.6654,1.711604,0.927954
9,2.3606,1.603896,0.932025


***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.49474644990778804}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-289
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-289/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-289/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-289/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.6893705687016398}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-578
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-578/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-578/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-578/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.7955035044782507}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-867
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-867/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-867/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-867/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.8329559775436511}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-1156
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-1156/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-1156/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-1156/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.865561052739612}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-1445
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-1445/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-1445/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-1445/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.8967009390545976}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-1734
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-1734/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-1734/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-1734/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9097000463410372}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-2023
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-2023/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-2023/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-2023/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9146199620187198}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-2312
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-2312/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-2312/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-2312/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9279538241747127}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-2601
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-2601/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-2601/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-2601/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9320251208426955}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-2890
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-2890/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-2890/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-2890/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9368229765369066}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-3179
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-3179/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-3179/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-3179/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9364810793693422}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-3468
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-3468/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-3468/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-3468/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9388092404231434}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-3757
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-3757/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-3757/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-3757/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9481100323068246}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-4046
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-4046/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-4046/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-4046/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9535473362353735}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-4335
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-4335/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-4335/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-4335/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9566449639826673}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-4624
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-4624/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-4624/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-4624/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9602031926260322}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-4913
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-4913/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-4913/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-4913/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9603232290773355}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-5202
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-5202/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-5202/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-5202/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9644543911689891}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-5491
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-5491/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-5491/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-5491/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9645954077460133}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-5780
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-5780/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-5780/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-5780/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9645197916695856}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-6069
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-6069/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-6069/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-6069/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9691239255185394}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-6358
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-6358/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-6358/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-6358/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9661176529231834}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-6647
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-6647/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-6647/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-6647/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9665197187638916}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-6936
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-6936/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-6936/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-6936/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9710602936387585}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-7225
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-7225/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-7225/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-7225/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9713701299682312}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-7514
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-7514/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-7514/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-7514/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9708238623533595}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-7803
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-7803/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-7803/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-7803/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9744048074888573}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-8092
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-8092/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-8092/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-8092/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.972901115518863}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-8381
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-8381/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-8381/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-8381/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9723351189232675}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-8670
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-8670/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-8670/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-8670/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9735086360418925}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-8959
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-8959/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-8959/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-8959/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9741139097461785}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-9248
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-9248/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-9248/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-9248/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9731876791817936}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-9537
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-9537/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-9537/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-9537/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9732213803231736}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-9826
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-9826/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-9826/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-9826/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 4586
  Batch size = 32


{'f1': 0.9728109377883826}


Saving model checkpoint to wav2vec2-base-finetuned-xvector/checkpoint-10115
Configuration saved in wav2vec2-base-finetuned-xvector/checkpoint-10115/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/checkpoint-10115/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/checkpoint-10115/preprocessor_config.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from wav2vec2-base-finetuned-xvector/checkpoint-8092 (score: 0.9744048074888573).


TrainOutput(global_step=10115, training_loss=2.4172881723452013, metrics={'train_runtime': 25124.955, 'train_samples_per_second': 51.533, 'train_steps_per_second': 0.403, 'total_flos': 1.2024582268257595e+19, 'train_loss': 2.4172881723452013, 'epoch': 35.0})

In [None]:
#| eval: false

# Save the trainer's current model under a fixed "best_checkpoint" directory.
trainer.save_model(f"{model_name}-finetuned-xvector/best_checkpoint")

Saving model checkpoint to wav2vec2-base-finetuned-xvector/best_checkpoint
Configuration saved in wav2vec2-base-finetuned-xvector/best_checkpoint/config.json
Model weights saved in wav2vec2-base-finetuned-xvector/best_checkpoint/pytorch_model.bin
Feature extractor saved in wav2vec2-base-finetuned-xvector/best_checkpoint/preprocessor_config.json


In [None]:
#| eval: false

# Reload the weights saved under "best_checkpoint" into the trainer.
# NOTE(review): `_load_from_checkpoint` is an underscore-prefixed (private)
# Trainer method. The training log already reports "Loading best model from ...",
# so this call looks redundant within a single session — confirm before relying on it.
trainer._load_from_checkpoint(f"{model_name}-finetuned-xvector/best_checkpoint")

Loading model from wav2vec2-base-finetuned-xvector/best_checkpoint.


In [None]:
#| eval: false

# Score the test split.
inputs = encoded_dataset['test']

# `result` is reused by the analysis cells below (predictions and label ids).
with torch.no_grad():
        result = trainer.predict(test_dataset = inputs)
# Display the test metrics as the cell output.
result.metrics

***** Running Prediction *****
  Num examples = 4648
  Batch size = 32


{'f1': 0.9685338664124005}


{'test_loss': 0.8347758054733276,
 'test_f1': 0.9685338664124005,
 'test_runtime': 37.9534,
 'test_samples_per_second': 122.466,
 'test_steps_per_second': 3.847}

In [None]:
#| eval: false

from collections import Counter
def get_predicted_labels(logits):
    """Return the predicted class id for each row of ``logits``.

    Args:
        logits: 2-D array with one row per example (the first element of the
            model's prediction tuple).

    Returns:
        1-D array of class ids (argmax over the projected class scores).

    The rows are projected through the weight matrix of ``model.objective``
    after L2-normalising its columns: the AM-Softmax objective used by
    ``Wav2Vec2ForXVector`` compares against unit-norm class vectors, so using
    the raw weights can let a large-norm column win the argmax incorrectly.
    """
    proj = model.objective.weight.detach().cpu().numpy()
    # Guard against a zero-norm column with a small floor.
    col_norms = np.maximum(np.linalg.norm(proj, axis=0, keepdims=True), 1e-12)
    return np.argmax(np.dot(logits, proj / col_norms), axis=1)
    
# Per-label accuracy of the classifier on the test split.
predicted_labels = get_predicted_labels(result.predictions[0])
# (true label id, whether the prediction matched it) for every test example.
hits = [(t, p==t) for p, t in zip(predicted_labels, result.label_ids)]

# Count the examples per label once instead of rebuilding the Counter for
# every label inside the loop.
label_totals = Counter(result.label_ids)
# Number of correct predictions per true label id (missing ids count as 0).
per_label_acc = Counter(t for t, h in hits if h)
# Iterate over every label present in the test split, so that labels with no
# correct prediction appear with accuracy 0.0 instead of vanishing from the table.
accs = [
    (id2label[label_id], round(per_label_acc[label_id] / total, 2), total)
    for label_id, total in label_totals.items()
]
# (label, accuracy, number of test examples), most frequent labels first.
sorted(accs, key=lambda x: x[2], reverse=True)

[('chicken', 1.0, 1811),
 ('orange', 1.0, 346),
 ('rice', 0.9970149253731343, 335),
 ('entrees', 1.0, 204),
 ('entree', 1.0, 199),
 ('can', 0.9259259259259259, 162),
 ('honey', 1.0, 102),
 ('shrimp', 1.0, 98),
 ('steak', 1.0, 94),
 ('drinks', 1.0, 82),
 ('steamed', 1.0, 70),
 ('unk', 0.7246376811594203, 69),
 ('chow', 0.8382352941176471, 68),
 ('tea', 1.0, 67),
 ('drink', 1.0, 66),
 ('plate', 0.8032786885245902, 61),
 ('side', 1.0, 58),
 ('mein', 0.9814814814814815, 54),
 ('beef', 0.8958333333333334, 48),
 ('lemonade', 1.0, 47),
 ('bowl', 0.9333333333333333, 45),
 ('large', 1.0, 41),
 ('small', 0.9411764705882353, 34),
 ('one', 0.6451612903225806, 31),
 ('greens', 1.0, 28),
 ('two', 0.6538461538461539, 26),
 ('fried', 0.76, 25),
 ('medium', 1.0, 24),
 ('walnut', 0.9090909090909091, 22),
 ('mushroom', 0.9473684210526315, 19),
 ('teriyaki', 0.9444444444444444, 18),
 ('kung', 0.5, 16),
 ('broccoli', 1.0, 16),
 ('pao', 0.8571428571428571, 14),
 ('meal', 0.9285714285714286, 14),
 ('instead'

In [None]:
#| eval: false

from numpy import dot
from numpy.linalg import norm
from scipy import spatial
from tqdm import tqdm

# Kept from the original cell; the computation below does not use it.
cosine_sim = torch.nn.CosineSimilarity(dim=-1)

# For every test clip, find the label of its nearest neighbour (by cosine
# similarity of the second model output) among all *other* test clips.
# The search is done with matrix products on unit-normalised embeddings rather
# than one scipy call per pair, which took about 14 minutes in the recorded run.
embeddings = np.asarray(result.predictions[1], dtype=np.float64)
label_ids = np.asarray(result.label_ids)
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
# A zero vector has no direction; leave it as zeros so its similarity to
# everything is 0 and it is never selected as a neighbour.
unit_embeddings = embeddings / np.where(norms == 0, 1.0, norms)

ROW_CHUNK = 1024  # rows per block, bounds the size of the similarity matrix in memory
similarities = []
for start in range(0, len(unit_embeddings), ROW_CHUNK):
    sims = unit_embeddings[start:start + ROW_CHUNK] @ unit_embeddings.T
    rows = np.arange(sims.shape[0])
    # A clip must not be matched with itself.
    sims[rows, start + rows] = -np.inf
    nearest = sims.argmax(axis=1)
    nearest_sims = sims[rows, nearest]
    for offset in rows:
        true_label = id2label[label_ids[start + offset]]
        # As before, a neighbour only counts if its similarity is positive.
        if nearest_sims[offset] > 0:
            nearest_label = id2label[label_ids[nearest[offset]]]
        else:
            nearest_label = None
        similarities.append((true_label, nearest_label))
# (true label, label of the nearest neighbour) for the first few clips.
similarities[:5]


100%|██████████| 4648/4648 [14:14<00:00,  5.44it/s]


[('water', 'water'),
 ('water', 'water'),
 ('water', 'water'),
 ('small', 'small'),
 ('small', 'small')]

In [None]:
#| eval: false

# Share of test clips whose nearest neighbour carries the same label.
n_matching = sum(1 for true_label, nearest_label in similarities if true_label == nearest_label)
n_matching / len(similarities)


0.9677280550774526

In [None]:
#| eval: false

# Per-label accuracy of the nearest-neighbour lookup on the test split.
# Count the examples per label once instead of rebuilding the Counter for
# every label inside the loop.
label_totals = Counter(result.label_ids)
# Number of clips per true label whose nearest neighbour has the same label
# (missing labels count as 0).
per_label_acc = Counter(t for t, p in similarities if t == p)
# Iterate over every label present in the test split, so that labels with no
# matching neighbour appear with accuracy 0.0 instead of vanishing from the table.
accs = [
    (id2label[label_id], round(per_label_acc[id2label[label_id]] / total, 2), total)
    for label_id, total in label_totals.items()
]
# (label, accuracy, number of test examples), most frequent labels first.
sorted(accs, key=lambda x: x[2], reverse=True)

[('chicken', 1.0, 1811),
 ('orange', 1.0, 346),
 ('rice', 0.99, 335),
 ('entrees', 1.0, 204),
 ('entree', 1.0, 199),
 ('can', 0.9, 162),
 ('honey', 1.0, 102),
 ('shrimp', 1.0, 98),
 ('steak', 1.0, 94),
 ('drinks', 1.0, 82),
 ('steamed', 1.0, 70),
 ('unk', 0.64, 69),
 ('chow', 0.9, 68),
 ('tea', 1.0, 67),
 ('drink', 1.0, 66),
 ('plate', 0.84, 61),
 ('side', 1.0, 58),
 ('mein', 0.94, 54),
 ('beef', 0.88, 48),
 ('lemonade', 1.0, 47),
 ('bowl', 0.93, 45),
 ('large', 1.0, 41),
 ('small', 0.94, 34),
 ('one', 0.61, 31),
 ('greens', 0.96, 28),
 ('two', 0.73, 26),
 ('fried', 0.8, 25),
 ('medium', 0.96, 24),
 ('walnut', 0.91, 22),
 ('mushroom', 0.89, 19),
 ('teriyaki', 0.94, 18),
 ('kung', 0.56, 16),
 ('broccoli', 1.0, 16),
 ('pao', 0.86, 14),
 ('meal', 0.93, 14),
 ('instead', 1.0, 14),
 ('pepper', 0.71, 14),
 ('coke', 1.0, 13),
 ('strawberry', 1.0, 13),
 ('black', 0.92, 12),
 ('chickens', 1.0, 11),
 ('bigger', 0.9, 10),
 ('bowls', 0.8, 10),
 ('green', 0.78, 9),
 ('veggie', 1.0, 8),
 ('raspberry