In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Train Demo

In [None]:
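# If running from Colab or the OVH notebook, uncomment the line below to install the package
# (%pip installs into the environment of the running kernel).
# %pip install git+https://github.com/morganmcg1/xlsr_finetune.git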

In [None]:
from xlsr_finetune.data import *
from xlsr_finetune.training import *

In [None]:
import os
import random
import numpy as np
from functools import partial
from datasets import load_dataset
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import Trainer, TrainingArguments

## Load data

In [None]:
# Load the ga-IE Common Voice "train" and "test" splits, using "data" as the cache directory.
train_ds, test_ds = (
    load_dataset("common_voice", "ga-IE", split=split_name, cache_dir='data')
    for split_name in ("train", "test")
)

Reusing dataset common_voice (data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)
Reusing dataset common_voice (data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)


Drop any rows whose "path" does not point to an existing file

In [None]:
# Pass both splits through drop_missing_files (train first, then test).
train_ds, test_ds = map(drop_missing_files, (train_ds, test_ds))

Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-f36e299c0e5cc39d.arrow
Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-653c4b1b4f5e347a.arrow
Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-4b4c74dbd55255a8.arrow
Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-864aceba1c436f68.arrow


All files found
All files found


[Optional] Merge another dataset into the training set

In [None]:
# Uncomment to merge an additional dataset into the training split (`new_ds` must be loaded first)
# train_ds = merge_ds(train_ds, new_ds)

Clean data and create Vocab

In [None]:
# Apply remove_special_characters to every example of both splits.
train_ds, test_ds = (
    split.map(remove_special_characters) for split in (train_ds, test_ds)
)

Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-c35fef9f5e36dcd4.arrow
Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-2d91e1c27f4f525d.arrow


In [None]:
# Build the vocabulary from both splits and save it under "data".
# NOTE(review): presumably written as data/vocab.json, which the tokenizer cell reads — confirm in extract_vocab.
vocab = extract_vocab(
    train_ds,
    test_ds,
    save=True,
    save_dir='data',
)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




## PreProcess

In [None]:
# CTC tokenizer built from the vocabulary file, with explicit unknown, padding and
# word-delimiter tokens.
tokenizer = Wav2Vec2CTCTokenizer(
    "data/vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

# Feature extractor configured for 16 kHz input, normalising and returning attention masks.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# Bundle the two so a single object handles both the audio and the text side.
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)

## Convert Audio to Array

In [None]:
# Pre-bind the resampling options (resample to 16 kHz) so the loader can be
# handed directly to Dataset.map as a one-argument callable.
sp2a = partial(
    speech_file_to_array,
    resample=True,
    new_sr=16_000,
)

In [None]:
# Run the audio loader over every row of the train split, then the test split.
train_ds, test_ds = (split.map(sp2a) for split in (train_ds, test_ds))

HBox(children=(FloatProgress(value=0.0, max=541.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=506.0), HTML(value='')))




## Quick Check

In [None]:
# Spot-check one randomly chosen training example: transcript, audio array shape
# and sampling rate.
rand_int = random.randrange(len(train_ds))
example = train_ds[rand_int]

print("Target text:", example["sentence"])
print("Input array shape:", np.asarray(example["speech"]).shape)
print("Sampling rate:", example["sampling_rate"])

Target text: ní raibh a thuairimí radacacha inghlactha ag muintir na cathrach ag an am sin 
Input array shape: (82560,)
Sampling rate: 16000


## Create Model Input Values

In [None]:
def prepare_dataset(batch):
    """Turn a batch of raw examples into model inputs and CTC labels.

    Reads ``batch["speech"]``, ``batch["sampling_rate"]`` and ``batch["sentence"]``,
    writes ``batch["input_values"]`` and ``batch["labels"]``, and returns the same
    ``batch`` object. Relies on the notebook-level ``processor``.

    Raises:
        AssertionError: unless the batch contains exactly one distinct sampling rate.
    """
    # Every clip in the batch must share a single sampling rate.
    distinct_rates = set(batch["sampling_rate"])
    assert (
        len(distinct_rates) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    # Audio -> input values, using the batch's (uniform) sampling rate.
    batch_rate = batch["sampling_rate"][0]
    audio_features = processor(batch["speech"], sampling_rate=batch_rate)
    batch["input_values"] = audio_features.input_values

    # Within this context the processor is applied to the transcripts instead.
    with processor.as_target_processor():
        text_encoding = processor(batch["sentence"])
        batch["labels"] = text_encoding.input_ids
    return batch

In [None]:
# Number of worker processes for Dataset.map: whatever the OS reports as the CPU count.
n_cpus = os.cpu_count()

# Run prepare_dataset over each split in batches of 8, removing the columns that
# existed before the call.
train_ds = train_ds.map(
    prepare_dataset,
    batched=True,
    batch_size=8,
    num_proc=n_cpus,
    remove_columns=train_ds.column_names,
)
test_ds = test_ds.map(
    prepare_dataset,
    batched=True,
    batch_size=8,
    num_proc=n_cpus,
    remove_columns=test_ds.column_names,
)

  return array(a, dtype, copy=False, order=order)




















## Prep Training

In [None]:
# Collator that pads each batch via the processor; both inputs and labels are
# padded to a multiple of 8.
data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
    pad_to_multiple_of=8,
    pad_to_multiple_of_labels=8,
)

In [None]:
# Load the pretrained XLSR-53 checkpoint as a CTC model configured for this
# notebook's tokenizer.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    # Vocabulary size and padding id come from the processor's tokenizer.
    vocab_size=len(processor.tokenizer),
    pad_token_id=processor.tokenizer.pad_token_id,
    ctc_loss_reduction="mean",
    # Dropout / masking settings.
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    layerdrop=0.1,
    mask_time_prob=0.05,
    # NOTE(review): presumably enabled to reduce GPU memory use — confirm.
    gradient_checkpointing=True,
)

Set Training arguments

In [None]:
training_args = TrainingArguments(
    # Checkpoint location and retention.
    output_dir="data/my_xlsr",
    save_steps=25,
    save_total_limit=1,

    # Batching.
    group_by_length=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,

    # Optimisation schedule.
    num_train_epochs=50,
    learning_rate=3e-4,
    warmup_steps=200,
    fp16=True,

    # Evaluate every 25 steps (same interval as checkpoint saving); log every 5.
    evaluation_strategy="steps",
    eval_steps=25,
    logging_steps=5,

    # Weights & Biases logging.
    report_to='wandb',                   # enable logging to W&B
    run_name='baseline_model_3e-4',      # optional name for the run
    load_best_model_at_end=True,         # so that the best model is the one uploaded to W&B
    metric_for_best_model='wer',         # choose the best model by "wer", not eval loss
    greater_is_better=False,             # the lowest "wer" counts as best
)

Create trainer

In [None]:
# Assemble the Trainer from the model, collator, training arguments and metric
# function set up earlier in this notebook.
# The datasets are the preprocessed `train_ds` / `test_ds` splits; the names
# `common_voice_train` / `common_voice_test` used here previously are never
# defined in this notebook and raised a NameError.
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    # The feature extractor is passed in the `tokenizer` slot.
    tokenizer=processor.feature_extractor,
)

## Set up Monitoring [optional]
Log in to Weights and Biases and set your entity (username) and project name, or else use the publicly available entity and project below

In [None]:
# Configure Weights & Biases; replace the entity / project below with your own
# to log to a private workspace.
entity, project_name = setup_wandb(
    entity='wandb',
    project_name='xlsr',
    log_model=True,
)

[34m[1mwandb[0m: Currently logged in as: [33mwandb[0m (use `wandb login --relogin` to force relogin)


('wandb', 'xlsr')

## Train

In [None]:
# Fine-tune the model with the Trainer created earlier in this notebook.
trainer.train()
# Optionally close the W&B run once training has finished (needs `wandb` in scope):
# wandb.finish()

## Evaluate

In [None]:
# Evaluate the trained model; the value returned by the call is displayed as the cell output.
# `trainer` must already exist, i.e. the Trainer cell earlier in this notebook must have run.
trainer.evaluate()

NameError: name 'trainer' is not defined

In [None]:
from nbdev.export2html import create_default_sidebar

In [None]:
# nbdev documentation housekeeping (not part of the training workflow):
# calls nbdev's create_default_sidebar helper imported in the previous cell.
create_default_sidebar()