In [None]:
# !pip uninstall -y xlsr_finetune

Found existing installation: xlsr-finetune 0.0.1
Uninstalling xlsr-finetune-0.0.1:
  Successfully uninstalled xlsr-finetune-0.0.1


In [None]:
!pip install -Uqqq git+https://github.com/morganmcg1/xlsr_finetune.git

In [None]:
# Reload imported modules before each cell runs, so edits to library code
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Train Demo

In [None]:
from xlsr_finetune.data import *
from xlsr_finetune.training import *

In [None]:
import os
import random
import numpy as np
from functools import partial
from datasets import load_dataset
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import Trainer, TrainingArguments

## Load data

In [None]:
# Load the "ga-IE" Common Voice splits; downloads are cached under data/.
train_ds = load_dataset("common_voice", "ga-IE", split="train", cache_dir='data')
# NOTE: the "test" split is what gets used as the validation set during training.
valid_ds = load_dataset("common_voice", "ga-IE", split="test", cache_dir='data')
# Keep a reference to the test split as loaded; later cells rebind valid_ds,
# while test_ds is only processed again in the Evaluate section.
test_ds = valid_ds

Reusing dataset common_voice (data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)
Reusing dataset common_voice (data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)


Drop any rows where "path" doesn't contain a file

In [None]:
# Rebind both splits to whatever drop_missing_files returns; test_ds is not
# passed through this step.
train_ds = drop_missing_files(train_ds)
valid_ds = drop_missing_files(valid_ds)

Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-a42182362ebe8af9.arrow
Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-261ae281cd2b2161.arrow
Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-3cba840606e8833f.arrow
Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-bb2eb43e52438c8b.arrow


All files found
All files found


[Optional] Merge another dataset into the training set

In [None]:
# train_ds = merge_ds(train_ds, new_ds)

Clean data and create Vocab

In [None]:
# Apply remove_special_characters to every example of both splits and rebind
# the names to the mapped datasets.
train_ds = train_ds.map(remove_special_characters)
valid_ds = valid_ds.map(remove_special_characters)

Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-1e10abde901ba23d.arrow
Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-fee29f6a1533a028.arrow


In [None]:
# Build the vocabulary from both splits. With save=True and save_dir='data'
# this presumably writes data/vocab.json, which the tokenizer cell later reads
# (TODO confirm in extract_vocab).
vocab = extract_vocab(train_ds, valid_ds, save=True, save_dir='data')
vocab

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Check your vocab and add any additional characters to exclude from it, like so:

In [None]:
# from xlsr_finetune.data import chars_to_ignore_regex 

# chars_to_ignore_regex = chars_to_ignore_regex[:-1] + '\/\\]'

# chars_to_ignore_regex

In [None]:
# vocab = extract_vocab(train_ds, valid_ds, save=True, save_dir='data')

## Convert Audio to Array

In [None]:
# Bind the resampling options once so the same callable can be mapped over
# every split (16_000 matches the feature extractor's sampling_rate below).
sp2a = partial(speech_file_to_array, resample=True, new_sr=16_000)

In [None]:
# Map sp2a over both splits; presumably this is what adds the "speech" and
# "sampling_rate" columns that the Quick Check below reads.
train_ds = train_ds.map(sp2a)
valid_ds = valid_ds.map(sp2a)

Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-377fbe0656230030.arrow
Loading cached processed dataset at data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-b82ea0a89d072ca4.arrow


## Quick Check

In [None]:
# Pick one random training example and fetch its row a single time: every
# train_ds[i] lookup builds the whole row, including the full "speech" array.
rand_int = random.randint(0, len(train_ds)-1)
sample = train_ds[rand_int]

print("Target text:", sample["sentence"])
print("Input array shape:", np.asarray(sample["speech"]).shape)
print("Sampling rate:", sample["sampling_rate"])

Target text: eolas fúinn 
Input array shape: (40320,)
Sampling rate: 16000


## PreProcess to Create Model Input Values

In [None]:
# CTC tokenizer that reads its vocabulary from data/vocab.json.
tokenizer = Wav2Vec2CTCTokenizer(
    "data/vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

# Feature extractor configured for 16 kHz input; it normalizes the inputs and
# returns an attention mask.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# Single object wrapping both the feature extractor and the tokenizer.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
def prepare_dataset(batch):
    """Convert a batch of raw examples into model inputs and label ids.

    Reads ``batch["speech"]``, ``batch["sampling_rate"]`` and
    ``batch["sentence"]``, and writes ``batch["input_values"]`` and
    ``batch["labels"]`` using the notebook-level ``processor``.

    Returns the same batch object with the two new keys added.
    Raises AssertionError if the batch mixes different sampling rates.
    """
    rates = batch["sampling_rate"]
    # every example in the batch must share one sampling rate
    assert (
        len(set(rates)) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    audio_features = processor(batch["speech"], sampling_rate=rates[0])
    batch["input_values"] = audio_features.input_values

    # inside this context the processor call encodes the target text
    with processor.as_target_processor():
        encoded_text = processor(batch["sentence"])
    batch["labels"] = encoded_text.input_ids
    return batch

In [None]:
#hide_output
n_cpus = os.cpu_count()  # Num cpus in case you'd like to set num_proc

# Replace the existing columns with the "input_values"/"labels" produced by
# prepare_dataset. Each split drops *its own* current columns: test_ds still
# refers to the test split as loaded, so its column list would miss any columns
# added to valid_ds by the earlier map steps.
train_ds = train_ds.map(prepare_dataset, remove_columns=train_ds.column_names,
                        batch_size=8, batched=True)  # num_proc=n_cpus,
valid_ds = valid_ds.map(prepare_dataset, remove_columns=valid_ds.column_names,
                        batch_size=8, batched=True)  # num_proc=n_cpus,

           

HBox(children=(FloatProgress(value=0.0, description='#3', max=2.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#1', max=2.0, style=ProgressStyle(description_width='init…

     

HBox(children=(FloatProgress(value=0.0, description='#4', max=2.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#2', max=2.0, style=ProgressStyle(description_width='init…

   

HBox(children=(FloatProgress(value=0.0, description='#0', max=2.0, style=ProgressStyle(description_width='init…

          

HBox(children=(FloatProgress(value=0.0, description='#11', max=2.0, style=ProgressStyle(description_width='ini…

  

HBox(children=(FloatProgress(value=0.0, description='#5', max=2.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#14', max=2.0, style=ProgressStyle(description_width='ini…

   

HBox(children=(FloatProgress(value=0.0, description='#15', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#13', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#20', max=2.0, style=ProgressStyle(description_width='ini…

     

HBox(children=(FloatProgress(value=0.0, description='#22', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#21', max=2.0, style=ProgressStyle(description_width='ini…

   

HBox(children=(FloatProgress(value=0.0, description='#6', max=2.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#10', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#18', max=2.0, style=ProgressStyle(description_width='ini…

  

HBox(children=(FloatProgress(value=0.0, description='#16', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#30', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#36', max=2.0, style=ProgressStyle(description_width='ini…

  

HBox(children=(FloatProgress(value=0.0, description='#9', max=2.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#31', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#19', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#12', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#17', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#27', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#34', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#39', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#8', max=2.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#28', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#43', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#23', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#32', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#7', max=2.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#25', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#44', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#26', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#24', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#29', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#33', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#38', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#37', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#35', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#41', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#47', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#40', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#48', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#42', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#49', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#50', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#52', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#45', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#54', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#53', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#55', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#46', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#51', max=2.0, style=ProgressStyle(description_width='ini…

























































      

HBox(children=(FloatProgress(value=0.0, description='#0', max=2.0, style=ProgressStyle(description_width='init…

   

HBox(children=(FloatProgress(value=0.0, description='#1', max=2.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#3', max=2.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#4', max=2.0, style=ProgressStyle(description_width='init…

   

HBox(children=(FloatProgress(value=0.0, description='#5', max=2.0, style=ProgressStyle(description_width='init…

   

HBox(children=(FloatProgress(value=0.0, description='#2', max=2.0, style=ProgressStyle(description_width='init…

        

HBox(children=(FloatProgress(value=0.0, description='#10', max=2.0, style=ProgressStyle(description_width='ini…

  

HBox(children=(FloatProgress(value=0.0, description='#13', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#6', max=2.0, style=ProgressStyle(description_width='init…

        

HBox(children=(FloatProgress(value=0.0, description='#21', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#12', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#19', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#9', max=2.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#7', max=2.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#8', max=2.0, style=ProgressStyle(description_width='init…

 

HBox(children=(FloatProgress(value=0.0, description='#15', max=2.0, style=ProgressStyle(description_width='ini…

  

HBox(children=(FloatProgress(value=0.0, description='#23', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#24', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#17', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#25', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#18', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#20', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#27', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#14', max=2.0, style=ProgressStyle(description_width='ini…

  

HBox(children=(FloatProgress(value=0.0, description='#11', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#16', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#26', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#32', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#38', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#31', max=2.0, style=ProgressStyle(description_width='ini…

  

HBox(children=(FloatProgress(value=0.0, description='#36', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#28', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#39', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#22', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#41', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#34', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#29', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#33', max=2.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#42', max=2.0, style=ProgressStyle(description_width='ini…

  

HBox(children=(FloatProgress(value=0.0, description='#43', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#30', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#45', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#40', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#37', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#35', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#46', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#44', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#48', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#53', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#50', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#55', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#49', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#47', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#51', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#52', max=2.0, style=ProgressStyle(description_width='ini…

HBox(children=(FloatProgress(value=0.0, description='#54', max=2.0, style=ProgressStyle(description_width='ini…



























































## Prep Training

In [None]:
# Collator that pads each batch; both inputs and labels are padded to a
# multiple of 8.
data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
    pad_to_multiple_of=8,
    pad_to_multiple_of_labels=8,
)

In [None]:
# Load the pretrained XLSR-53 checkpoint with a CTC head sized to the tokenizer.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    # regularisation
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    layerdrop=0.1,
    mask_time_prob=0.05,
    # CTC loss configuration
    ctc_loss_reduction="mean",
    ctc_zero_infinity=True,
    # padding id and output size follow the processor's tokenizer
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    # trade extra compute for lower memory use during training
    gradient_checkpointing=True,
)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Set Training arguments

In [None]:
training_args = TrainingArguments(
    output_dir="data/my_xlsr",
    # batching
    group_by_length=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    # optimisation schedule
    num_train_epochs=2,
    learning_rate=3e-4,
    warmup_steps=200,
    fp16=True,
    # evaluation, checkpointing and logging cadence (in steps)
    evaluation_strategy="steps",
    eval_steps=25,
    save_steps=25,
    logging_steps=5,
    save_total_limit=1,
    # WANDB LOGGING:
    report_to='wandb',  # enable logging to W&B
    run_name='baseline_model_3e-4',  # Name your run, optional
    load_best_model_at_end=True,  # This will ensure your best model will be uploaded to W&B
    metric_for_best_model='wer',  # Load best model based on "wer", not eval loss
    greater_is_better=False,  # Define "best" wer score as the lowest score
)

Create trainer

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    # compute_wer_metric imported from xlsr_finetune; processor is bound here
    compute_metrics=partial(compute_wer_metric, processor=processor),
)

## Set up Monitoring [optional]
Log in to Weights and Biases and set your entity (username) and project name, or else use the publicly available entity and project below

In [None]:
# Replace entity/project_name with your own W&B account and project if desired.
# log_model=True presumably enables uploading the model to W&B (TODO confirm
# in setup_wandb).
setup_wandb(entity='wandb', project_name='xlsr', log_model=True)

[34m[1mwandb[0m: Currently logged in as: [33mwandb[0m (use `wandb login --relogin` to force relogin)


('wandb', 'xlsr')

## Train

In [None]:
trainer.train()

# If using W&B and not doing any further evaluation, then use wandb.finish()
# wandb.finish()  

## Evaluate

In [None]:
# Rebind sp2a to its evaluation variant (evaluate=True) and apply it to test_ds,
# which still refers to the test split as loaded (it did not go through the
# training-time preprocessing cells).
sp2a = partial(speech_file_to_array, resample=True, new_sr=16_000, evaluate=True)
test_ds = test_ds.map(sp2a)

In [None]:
# Run evaluate_xlsr over the test set in batches of 8, with the model and
# processor bound in; the next cell reads "pred_strings" from the result.
ev = partial(evaluate_xlsr, model=model, processor=processor)
result = test_ds.map(ev, batched=True, batch_size=8)

In [None]:
# Word error rate of the predictions against the reference sentences,
# scaled by 100 to express it as a percentage.
wer_true = 100 * wer_metric.compute(predictions=result["pred_strings"], references=result["sentence"])
# ".2f" prints two decimal places; the earlier "2f" only set a minimum width of 2.
print(f"WER: {wer_true:.2f}")

In [None]:
import wandb
# Record the test-set WER on the active W&B run, then close the run.
wandb.log({'test/wer_true': wer_true})
wandb.finish()