In [1]:
from datasets import load_from_disk, load_dataset, Audio

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pretrained wav2vec 2.0 checkpoint; used for both the feature extractor and the classifier backbone.
model_name = "facebook/wav2vec2-base"
# Alternative (larger) checkpoint: "facebook/wav2vec2-large"

In [3]:
from transformers import AutoFeatureExtractor

# Feature extractor matching the checkpoint; its sampling_rate is used when encoding audio in preprocess_function.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)



In [4]:
# Load the DatasetDict stored under ./data (splits: train, test, new_unseen, drift).
# NOTE(review): the variable name `datasets` shadows the Hugging Face package name; kept because later cells use it.
datasets = load_from_disk("data")

In [5]:
# Inspect split names, sizes and columns.
datasets

DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 70578
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 9951
    })
    new_unseen: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 29556
    })
    drift: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 42697
    })
})

In [6]:
labels = ["male", "female"]
# Class name <-> class id lookups; ids are kept as strings because they are passed to the model config.
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {idx: name for name, idx in label2id.items()}

In [7]:
# Class name -> string id.
label2id

{'male': '0', 'female': '1'}

In [8]:
# String id -> class name.
id2label

{'0': 'male', '1': 'female'}

In [9]:
# Pull the train and test splits out of the DatasetDict as separate Dataset objects.
train_dataset, test_dataset = datasets['train'], datasets['test']

In [10]:
# Cast the audio column so clips are decoded at 16 kHz, then inspect one training example.
# NOTE(review): presumably 16 kHz matches feature_extractor.sampling_rate — confirm.
train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16_000))
train_dataset[0]

{'client_id': '29b1e5a58d1667d4ac45832ec195356598a69f66680877b0d5ee465ce2404c0186affc81ffe4a29df35203fc07a0fc5714c60d914a88aa36d7f84c94dc381d2f',
 'path': '/home/students/s289159/.cache/huggingface/datasets/downloads/extracted/bab7205fb7eb744fb5f8ca6c7fc1a1096a0632122d952545fad5b84bb21ff6e4/cv-corpus-6.1-2020-12-11/en/clips/common_voice_en_122577.mp3',
 'audio': {'path': '/home/students/s289159/.cache/huggingface/datasets/downloads/extracted/bab7205fb7eb744fb5f8ca6c7fc1a1096a0632122d952545fad5b84bb21ff6e4/cv-corpus-6.1-2020-12-11/en/clips/common_voice_en_122577.mp3',
  'array': array([ 0.        ,  0.        ,  0.        , ..., -0.00863345,
          0.00414815, -0.00097276], dtype=float32),
  'sampling_rate': 16000},
 'sentence': 'Two women are smiling next to a microphone on a stage.',
 'up_votes': 2,
 'down_votes': 0,
 'age': 'teens',
 'gender': 'male',
 'accent': 'us',
 'locale': 'en',
 'segment': "''"}

In [11]:
# Same 16 kHz audio cast for the test split, then inspect one example.
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16_000))
test_dataset[0]

{'client_id': 'f148bbf4cd30561010300193263d00b4b009118933da4c5cc7c8cb166f24e9a1cd232f8073c7574055f8dbb373fb0d69b28b5f5e9659d011feff4345e160044f',
 'path': '/home/students/s289159/.cache/huggingface/datasets/downloads/extracted/bab7205fb7eb744fb5f8ca6c7fc1a1096a0632122d952545fad5b84bb21ff6e4/cv-corpus-6.1-2020-12-11/en/clips/common_voice_en_162540.mp3',
 'audio': {'path': '/home/students/s289159/.cache/huggingface/datasets/downloads/extracted/bab7205fb7eb744fb5f8ca6c7fc1a1096a0632122d952545fad5b84bb21ff6e4/cv-corpus-6.1-2020-12-11/en/clips/common_voice_en_162540.mp3',
  'array': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         -8.9534951e-06,  5.4251259e-06, -2.7791944e-05], dtype=float32),
  'sampling_rate': 16000},
 'sentence': 'Two young, White males are outside near many bushes.',
 'up_votes': 3,
 'down_votes': 0,
 'age': 'seventies',
 'gender': 'male',
 'accent': 'us',
 'locale': 'en',
 'segment': "''"}

In [12]:
def preprocess_function(examples, max_duration_s=1.0):
    """Encode a batch of decoded audio clips into wav2vec2 model inputs.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["audio"]`` is a
    list of dicts with ``"array"`` and ``"sampling_rate"`` keys, as produced by
    the ``Audio`` feature.

    Args:
        examples: Batch dict from ``datasets``.
        max_duration_s: Clips are truncated to this many seconds. The default
            of 1.0 reproduces the previous hard-coded ``max_length=16000`` at
            16 kHz.

    Returns:
        The feature extractor's output for the batch.

    Raises:
        ValueError: if any clip's sampling rate differs from the feature
            extractor's; otherwise the audio would silently be encoded at the
            wrong rate.
    """
    expected_sr = feature_extractor.sampling_rate
    audio_arrays = []
    for audio in examples["audio"]:
        if audio["sampling_rate"] != expected_sr:
            raise ValueError(
                f"Audio sampled at {audio['sampling_rate']} Hz but the feature extractor "
                f"expects {expected_sr} Hz; cast the column with Audio(sampling_rate={expected_sr})."
            )
        audio_arrays.append(audio["array"])
    return feature_extractor(
        audio_arrays,
        sampling_rate=expected_sr,
        # max_length is measured in samples, so derive it from the sampling rate.
        max_length=int(expected_sr * max_duration_s),
        truncation=True,
    )

def convert_label(example):
    """Turn ``example['gender']`` from a class name into its integer class id.

    The example dict is modified in place and also returned, which is what
    ``Dataset.map`` expects. A gender missing from ``label2id`` raises KeyError.
    """
    gender_name = example['gender']
    gender_id = label2id[gender_name]  # ids are stored as strings, e.g. '0'
    example['gender'] = int(gender_id)
    return example

    

In [13]:
# Raw string labels; the split has not been shuffled yet at this point.
train_dataset['gender'][:10]

['male',
 'male',
 'male',
 'male',
 'male',
 'male',
 'male',
 'male',
 'male',
 'male']

In [14]:
# Fixed-seed shuffle of both splits.
# NOTE(review): the names are rebound, so re-running this cell shuffles the already-shuffled
# data again and yields a different order than a fresh Restart & Run All.
train_dataset = train_dataset.shuffle(seed=42)
test_dataset = test_dataset.shuffle(seed=42)

In [15]:
# Convert string genders to integer ids, encode the audio, and expose the target as `label`.
# train_dataset is deliberately NOT rebound: the old version overwrote it with integer labels,
# so re-running the cell passed ints to convert_label and failed with a KeyError.
encoded_train_audios = (
    train_dataset
    .map(convert_label)
    .map(preprocess_function, remove_columns="audio", batched=True)
    .rename_column("gender", "label")
)

100%|██████████| 70578/70578 [00:17<00:00, 3958.81ex/s]
100%|██████████| 71/71 [22:35<00:00, 19.09s/ba]


In [16]:
# Same pipeline as for the training split: integer labels, encoded audio, target column `label`.
# test_dataset is not rebound, so re-running this cell does not feed integer labels
# back into convert_label (which would raise KeyError).
encoded_test_audios = (
    test_dataset
    .map(convert_label)
    .map(preprocess_function, remove_columns="audio", batched=True)
    .rename_column("gender", "label")
)

100%|██████████| 9951/9951 [00:02<00:00, 4116.63ex/s]
100%|██████████| 10/10 [03:06<00:00, 18.62s/ba]


In [17]:
# Encoded training split: `audio` replaced by `input_values`, `gender` renamed to `label`.
encoded_train_audios

Dataset({
    features: ['client_id', 'path', 'sentence', 'up_votes', 'down_votes', 'age', 'label', 'accent', 'locale', 'segment', 'input_values'],
    num_rows: 70578
})

In [18]:
# First integer labels after shuffling (0 = male, 1 = female per label2id).
encoded_train_audios['label'][:10]

[0, 1, 1, 1, 1, 0, 1, 1, 1, 0]

In [19]:
# Encoded test split, same schema as the training split.
encoded_test_audios

Dataset({
    features: ['client_id', 'path', 'sentence', 'up_votes', 'down_votes', 'age', 'label', 'accent', 'locale', 'segment', 'input_values'],
    num_rows: 9951
})

In [20]:
# First integer labels of the shuffled test split.
encoded_test_audios['label'][:10]

[1, 1, 1, 1, 1, 0, 0, 0, 1, 0]

In [21]:
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

In [22]:
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

num_labels = len(id2label)

# Sequence-classification head sized to the gender classes; the label maps are passed
# to from_pretrained so they end up in the model config.
model = AutoModelForAudioClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForSequenceClassification: ['project_q.weight', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'quantizer.codevectors', 'project_q.bias', 'project_hid.bias', 'project_hid.weight']
- This IS expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['projector.bias', 'classifier.weight', 'projector.

In [23]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, classification_report
import sklearn

def compute_metrics(pred):
    """Compute accuracy and macro precision/recall/F1 for a Trainer evaluation.

    Args:
        pred: ``EvalPrediction`` whose ``predictions`` are the logits (or a
            tuple/list whose first element is the logits) and whose
            ``label_ids`` are the integer ground-truth labels.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision`` and ``recall``. The
        Trainer reports them with an ``eval_`` prefix, so
        ``metric_for_best_model="f1"`` selects on macro F1.

    Also prints a per-class classification report as a training-time log.
    """
    label_ids = pred.label_ids
    logits = pred.predictions
    # Some models return extra outputs; the logits are then the first element.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    logits = np.asarray(logits)
    preds = np.argmax(logits, axis=-1)
    # Average over every class the classifier can output, not only the classes present
    # in the ground truth, so a class absent from the eval labels cannot inflate macro scores.
    class_ids = list(range(logits.shape[-1]))
    precision, recall, f1, _ = sklearn.metrics.precision_recall_fscore_support(
        label_ids, preds, average="macro", labels=class_ids, zero_division=0
    )
    print(sklearn.metrics.classification_report(
        label_ids, preds, labels=class_ids, digits=4, zero_division=0
    ))
    acc = sklearn.metrics.accuracy_score(label_ids, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

In [24]:
# NOTE(review): eval_dataset is the test split. It is used both to pick the best checkpoint
# (load_best_model_at_end / metric_for_best_model) and for the final "test" evaluation, so that
# score is optimistically biased. A validation split carved from train would avoid this — TODO.
# NOTE(review): at epoch 3 the model collapsed to predicting class 1 for almost every clip;
# a learning-rate warmup (e.g. warmup_ratio) may stabilise early training — TODO try.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    metric_for_best_model="f1",  # macro F1 returned by compute_metrics
    load_best_model_at_end=True,
    greater_is_better=True,
    # Each epoch checkpoint is a full model (pytorch_model.bin); keep only the most recent ones.
    # With load_best_model_at_end the best checkpoint is retained in addition.
    save_total_limit=2,
    seed=42,  # explicit (same as the TrainingArguments default) so the run is reproducible
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_train_audios,
    eval_dataset=encoded_test_audios,
    # Saved with every checkpoint as preprocessor_config.json.
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

In [25]:
# Fine-tune; evaluates and checkpoints every epoch, then reloads the best checkpoint by eval F1.
trainer.train()

The following columns in the training set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 70578
  Num Epochs = 10
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 44120


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.2617,0.498229,0.819315,0.819307,0.81932,0.819302
2,0.2598,0.406211,0.843935,0.841268,0.867323,0.843415
3,0.382,0.696439,0.502261,0.334875,0.626081,0.500203
4,0.244,0.643022,0.776404,0.770976,0.806931,0.777052
5,0.303,0.54003,0.835192,0.834859,0.838373,0.835392
6,0.191,0.498012,0.860919,0.860827,0.861625,0.860828
7,0.1817,0.734451,0.822631,0.821148,0.834617,0.82302
8,0.1593,0.608336,0.842126,0.841636,0.84699,0.84237
9,0.1279,0.651757,0.843131,0.842626,0.848178,0.843379
10,0.1338,0.703203,0.849965,0.849439,0.855535,0.850223


The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f662e0f1550>
              precision    recall  f1-score   support

           0     0.8200    0.8163    0.8182      4955
           1     0.8187    0.8223    0.8205      4996

    accuracy                         0.8193      9951
   macro avg     0.8193    0.8193    0.8193      9951
weighted avg     0.8193    0.8193    0.8193      9951



Saving model checkpoint to ./results/checkpoint-4412
Configuration saved in ./results/checkpoint-4412/config.json
Model weights saved in ./results/checkpoint-4412/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-4412/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f662e4bc4c0>
              precision    recall  f1-score   support

           0     0.9590    0.7173    0.8207      4955
           1     0.7757    0.9696    0.8618      4996

    accuracy                         0.8439      9951
   macro avg     0.8673    0.8434    0.8413      9951
weighted avg     0.8669    0.8439    0.8414      9951



Saving model checkpoint to ./results/checkpoint-8824
Configuration saved in ./results/checkpoint-8824/config.json
Model weights saved in ./results/checkpoint-8824/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-8824/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f662e0efbe0>
              precision    recall  f1-score   support

           0     0.7500    0.0006    0.0012      4955
           1     0.5022    0.9998    0.6685      4996

    accuracy                         0.5023      9951
   macro avg     0.6261    0.5002    0.3349      9951
weighted avg     0.6256    0.5023    0.3362      9951



Saving model checkpoint to ./results/checkpoint-13236
Configuration saved in ./results/checkpoint-13236/config.json
Model weights saved in ./results/checkpoint-13236/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-13236/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f662e0d8fa0>
              precision    recall  f1-score   support

           0     0.7091    0.9342    0.8062      4955
           1     0.9048    0.6199    0.7357      4996

    accuracy                         0.7764      9951
   macro avg     0.8069    0.7771    0.7710      9951
weighted avg     0.8073    0.7764    0.7708      9951



Saving model checkpoint to ./results/checkpoint-17648
Configuration saved in ./results/checkpoint-17648/config.json
Model weights saved in ./results/checkpoint-17648/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-17648/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f662e0e69d0>
              precision    recall  f1-score   support

           0     0.8045    0.8838    0.8423      4955
           1     0.8722    0.7870    0.8274      4996

    accuracy                         0.8352      9951
   macro avg     0.8384    0.8354    0.8349      9951
weighted avg     0.8385    0.8352    0.8348      9951



Saving model checkpoint to ./results/checkpoint-22060
Configuration saved in ./results/checkpoint-22060/config.json
Model weights saved in ./results/checkpoint-22060/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-22060/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f6634335d00>
              precision    recall  f1-score   support

           0     0.8766    0.8387    0.8573      4955
           1     0.8466    0.8829    0.8644      4996

    accuracy                         0.8609      9951
   macro avg     0.8616    0.8608    0.8608      9951
weighted avg     0.8616    0.8609    0.8608      9951



Saving model checkpoint to ./results/checkpoint-26472
Configuration saved in ./results/checkpoint-26472/config.json
Model weights saved in ./results/checkpoint-26472/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-26472/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f662e0efbe0>
              precision    recall  f1-score   support

           0     0.7702    0.9175    0.8374      4955
           1     0.8990    0.7286    0.8049      4996

    accuracy                         0.8226      9951
   macro avg     0.8346    0.8230    0.8211      9951
weighted avg     0.8349    0.8226    0.8211      9951



Saving model checkpoint to ./results/checkpoint-30884
Configuration saved in ./results/checkpoint-30884/config.json
Model weights saved in ./results/checkpoint-30884/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-30884/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f6634335f70>
              precision    recall  f1-score   support

           0     0.8049    0.9015    0.8505      4955
           1     0.8891    0.7832    0.8328      4996

    accuracy                         0.8421      9951
   macro avg     0.8470    0.8424    0.8416      9951
weighted avg     0.8472    0.8421    0.8416      9951



Saving model checkpoint to ./results/checkpoint-35296
Configuration saved in ./results/checkpoint-35296/config.json
Model weights saved in ./results/checkpoint-35296/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-35296/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f662e0ef9a0>
              precision    recall  f1-score   support

           0     0.8052    0.9035    0.8515      4955
           1     0.8911    0.7832    0.8337      4996

    accuracy                         0.8431      9951
   macro avg     0.8482    0.8434    0.8426      9951
weighted avg     0.8484    0.8431    0.8426      9951



Saving model checkpoint to ./results/checkpoint-39708
Configuration saved in ./results/checkpoint-39708/config.json
Model weights saved in ./results/checkpoint-39708/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-39708/preprocessor_config.json
The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


<transformers.trainer_utils.EvalPrediction object at 0x7f6634335f70>
              precision    recall  f1-score   support

           0     0.8100    0.9128    0.8583      4955
           1     0.9011    0.7876    0.8405      4996

    accuracy                         0.8500      9951
   macro avg     0.8555    0.8502    0.8494      9951
weighted avg     0.8557    0.8500    0.8494      9951



Saving model checkpoint to ./results/checkpoint-44120
Configuration saved in ./results/checkpoint-44120/config.json
Model weights saved in ./results/checkpoint-44120/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-44120/preprocessor_config.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from ./results/checkpoint-26472 (score: 0.8608271100478655).


TrainOutput(global_step=44120, training_loss=0.23970040963761718, metrics={'train_runtime': 13093.2625, 'train_samples_per_second': 53.904, 'train_steps_per_second': 3.37, 'total_flos': 6.4075173446592e+18, 'train_loss': 0.23970040963761718, 'epoch': 10.0})

In [26]:
# Re-evaluate the reloaded best model on the Trainer's eval_dataset.
# NOTE(review): this is the same test split used to select the best checkpoint, so the
# reported score is not an unbiased held-out estimate.
print("EVALUATION BEST MODEL ON TEST SET")
print(trainer.evaluate())

The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForSequenceClassification.forward` and have been ignored: client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes. If client_id, down_votes, path, age, segment, sentence, accent, locale, up_votes are not expected by `Wav2Vec2ForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9951
  Batch size = 16


EVALUATION BEST MODEL ON TEST SET


<transformers.trainer_utils.EvalPrediction object at 0x7f662e0efbe0>
              precision    recall  f1-score   support

           0     0.8766    0.8387    0.8573      4955
           1     0.8466    0.8829    0.8644      4996

    accuracy                         0.8609      9951
   macro avg     0.8616    0.8608    0.8608      9951
weighted avg     0.8616    0.8609    0.8608      9951

{'eval_loss': 0.49801209568977356, 'eval_accuracy': 0.8609185006532006, 'eval_f1': 0.8608271100478655, 'eval_precision': 0.8616246926695332, 'eval_recall': 0.8608275318539393, 'eval_runtime': 73.831, 'eval_samples_per_second': 134.781, 'eval_steps_per_second': 8.425, 'epoch': 10.0}


In [27]:
import os
# Persist the reloaded best model together with its feature extractor config.
trainer.save_model(os.path.join("saved_model", "best_model_wav2vec_base"))

Saving model checkpoint to saved_model/best_model_wav2vec_base
Configuration saved in saved_model/best_model_wav2vec_base/config.json
Model weights saved in saved_model/best_model_wav2vec_base/pytorch_model.bin
Feature extractor saved in saved_model/best_model_wav2vec_base/preprocessor_config.json
