In [None]:
# !pip uninstall -y xlsr_finetune
# !pip install -Uqqq git+https://github.com/morganmcg1/xlsr_finetune.git
# !pip install -Uqqq git+https://github.com/huggingface/transformers.git

In [None]:
# Auto-reload edited modules (e.g. a local xlsr_finetune checkout) before each cell runs.
%load_ext autoreload
%autoreload 2

# Train Demo

In [None]:
from xlsr_finetune.data import *
from xlsr_finetune.training import *
from xlsr_finetune.wandbutils import *

In [None]:
import os
import wandb
import random
import numpy as np
from functools import partial
from datasets import load_dataset, load_from_disk
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import Trainer, TrainingArguments

## Load data

In [None]:
# Common Voice Irish: train+validation are used for training, and the official
# test split serves as the Trainer's eval set AND the final test set.
# NOTE(review): the best checkpoint is selected on the same data that is later
# reported as test WER, so that number is optimistic.
CV_CACHE_DIR = '../../data'
train_ds, valid_ds = (
    load_dataset("common_voice", "ga-IE", split=split_spec, cache_dir=CV_CACHE_DIR)
    for split_spec in ("train+validation", "test")
)
test_ds = valid_ds

Reusing dataset common_voice (../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)
Reusing dataset common_voice (../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)


In [None]:
# Display both splits (features and row counts) before any cleaning.
train_ds, valid_ds

(Dataset({
     features: ['client_id', 'path', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
     num_rows: 1038
 }),
 Dataset({
     features: ['client_id', 'path', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
     num_rows: 506
 }))

Drop any rows whose `path` does not point to an existing file.

In [None]:
# Remove rows whose audio file is missing on disk (train first, then valid).
train_ds, valid_ds = (drop_missing_files(split_ds) for split_ds in (train_ds, valid_ds))

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-f5da9e5da686b7b7.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-9947222ea9d3c8e6.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-0a293c5a0be6b07f.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-76374acea6f14da3.arrow


All files found
All files found


[Optional] Merge another dataset into your training dataset and shuffle the result

In [None]:
# train_ds = merge_ds(train_ds, new_ds)

In [None]:
# new_ds = load_dataset("common_voice", "en", split="test[20:40%]", cache_dir='../../data')
# en_train_sample = load_dataset("common_voice", "en", split="test", cache_dir='../../data')

Reusing dataset common_voice (../../data/common_voice/en/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)
Reusing dataset common_voice (../../data/common_voice/en/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)


In [None]:
# new_ds = drop_missing_files(new_ds)
# train_ds = merge_ds(train_ds, new_ds)

Loading cached processed dataset at ../../data/common_voice/en/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-217fc91e450410ec.arrow
Loading cached processed dataset at ../../data/common_voice/en/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-bfd96deffd340d2f.arrow


All files found


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Clean data and create Vocab

In [None]:
# First cleaning pass using remove_special_characters with its default settings.
train_ds, valid_ds = (split_ds.map(remove_special_characters) for split_ds in (train_ds, valid_ds))

HBox(children=(FloatProgress(value=0.0, max=4271.0), HTML(value='')))

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-f4552cc1a8c9143b.arrow





In [None]:
# Build the character vocab from both splits and save it under data/
# (presumably as data/vocab.json, which the tokenizer reads later).
vocab = extract_vocab(train_ds, valid_ds, save_dir='data', save=True)
vocab

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




{'m': 0,
 'u': 1,
 'k': 2,
 'ú': 4,
 'ó': 5,
 'v': 6,
 'o': 7,
 'd': 8,
 's': 9,
 'e': 10,
 'p': 11,
 'w': 12,
 'n': 13,
 'í': 14,
 'x': 15,
 'h': 16,
 'r': 17,
 "'": 18,
 'c': 19,
 '-': 20,
 'b': 21,
 't': 22,
 'q': 23,
 'l': 24,
 'f': 25,
 'j': 26,
 'z': 27,
 'g': 28,
 'i': 29,
 'é': 30,
 'y': 31,
 'a': 32,
 'á': 33,
 '|': 3,
 '[UNK]': 34,
 '[PAD]': 35}

Check your vocab and add additional characters to ignore to the `chars_to_ignore_regex` string if needed

In [None]:
# Extend the character class with an extra character to strip.
# - [:-1] drops the closing ']' so the new entry goes inside the class.
# - The raw string keeps '\𝓧' from being parsed as an (invalid) Python escape sequence.
# - The guard keeps this cell idempotent: re-running it does not append the character again.
if '𝓧' not in chars_to_ignore_regex:
    chars_to_ignore_regex = chars_to_ignore_regex[:-1] + r'\𝓧]'
chars_to_ignore_regex

'[\\,\\?\\.\\!\\-\\;\\:"\\“\\%\\‘\\”\\�\\(\\)\\-\\*\\/\\\\\\𝓧]'

In [None]:
# Second cleaning pass with the extended regex bound into the cleaner.
new_remove_special_characters = partial(
    remove_special_characters,
    chars_to_ignore_regex=chars_to_ignore_regex,
)
train_ds, valid_ds = (
    split_ds.map(new_remove_special_characters) for split_ds in (train_ds, valid_ds)
)

HBox(children=(FloatProgress(value=0.0, max=4271.0), HTML(value='')))

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-6b27d62a4e69d4a6.arrow





In [None]:
# Rebuild (and re-save) the vocab after the second cleaning pass; this is the
# version the tokenizer below is built from.
vocab = extract_vocab(train_ds, valid_ds, save_dir='data', save=True)
vocab

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




{'m': 0,
 'u': 1,
 'k': 2,
 'ú': 4,
 'ó': 5,
 'v': 6,
 'o': 7,
 'd': 8,
 's': 9,
 'e': 10,
 'p': 11,
 'w': 12,
 'n': 13,
 'í': 14,
 'x': 15,
 'h': 16,
 'r': 17,
 "'": 18,
 'c': 19,
 'b': 20,
 't': 21,
 'q': 22,
 'l': 23,
 'f': 24,
 'j': 25,
 'z': 26,
 'g': 27,
 'i': 28,
 'é': 29,
 'y': 30,
 'a': 31,
 'á': 32,
 '|': 3,
 '[UNK]': 33,
 '[PAD]': 34}

## Convert Audio to Array

In [None]:
# Audio loader bound to resample every clip to 16 kHz.
sp2a = partial(speech_file_to_array, new_sr=16_000, resample=True)

In [None]:
# Load and resample the audio for both splits with 8 worker processes.
train_ds, valid_ds = (split_ds.map(sp2a, num_proc=8) for split_ds in (train_ds, valid_ds))

     

HBox(children=(FloatProgress(value=0.0, description='#0', max=534.0, style=ProgressStyle(description_width='in…

 

HBox(children=(FloatProgress(value=0.0, description='#1', max=534.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='#3', max=534.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='#2', max=534.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='#5', max=534.0, style=ProgressStyle(description_width='in…

 

HBox(children=(FloatProgress(value=0.0, description='#4', max=534.0, style=ProgressStyle(description_width='in…

 

HBox(children=(FloatProgress(value=0.0, description='#6', max=534.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='#7', max=533.0, style=ProgressStyle(description_width='in…









  

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-7dfc6472e2a227fc.arrow


  

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-fa2bace49b0890f8.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-06ac96092ae2ebb7.arrow


 

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-79f9222581a150a8.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-5fa8395867bdaf8e.arrow


 

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-c894a44da9ff5b3a.arrow


  

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-3e0b17b4d205e675.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-3753555aeb19f897.arrow


## Filter Out Files That Could Not Be Read

`speech_file_to_array` puts a placeholder 0 in `speech` for items whose path could not be read. Let's remove these.

In [None]:
# `speech_file_to_array` leaves a placeholder 0 in `speech` when a file could not be
# read, so anything of length <= 1 is a failed read. Drop those rows from BOTH splits:
# otherwise unreadable clips in the validation set would be fed to the model and
# counted in the eval WER.
# NOTE(review): batch_size=1 is kept from the original. It appears intended to keep
# the amount of `speech` data handled per internal filter batch small — confirm.
def has_readable_speech(example):
    """Return True when `example['speech']` holds more than the single placeholder value."""
    return len(example['speech']) > 1

prev_train_l, prev_valid_l = len(train_ds), len(valid_ds)
train_ds = train_ds.filter(has_readable_speech, batch_size=1)
valid_ds = valid_ds.filter(has_readable_speech, batch_size=1)
print(f'{prev_train_l - len(train_ds)} train samples removed')
print(f'{prev_valid_l - len(valid_ds)} valid samples removed')

HBox(children=(FloatProgress(value=0.0, max=4271.0), HTML(value='')))


0 samples removed


## Filter out Long Audio in Training Set
Longer audio can cause CUDA out-of-memory errors; 112k samples at a 16 kHz sampling rate is 7 s of audio.

In [None]:
# Keep only training clips of at most 7 s at 16 kHz (7 * 16_000 == 112_000 samples).
MAX_TRAIN_SECONDS = 7
TARGET_SR = 16_000
max_train_samples = MAX_TRAIN_SECONDS * TARGET_SR

n_before = len(train_ds)
train_ds = train_ds.filter(
    lambda example: len(example['speech']) <= max_train_samples,
    batch_size=1,
)
print(f'{n_before - len(train_ds)} samples removed')

HBox(children=(FloatProgress(value=0.0, max=4271.0), HTML(value='')))


850 samples removed


## Quick Data Check

In [None]:
# Spot-check one random training example: transcript, waveform length, sampling rate.
rand_int = random.randint(0, len(train_ds) - 1)
sample = train_ds[rand_int]

print("Target text:", sample["sentence"])
print("Input array shape:", np.asarray(sample["speech"]).shape)
print("Sampling rate:", sample["sampling_rate"])

Target text: cosc ar thobac  
Input array shape: (24192,)
Sampling rate: 16000


## PreProcess to Create Model Input Values

In [None]:
# Character-level CTC tokenizer over the saved vocab; "|" is the word delimiter token.
tokenizer = Wav2Vec2CTCTokenizer(
    "data/vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

# Waveform feature extractor: 16 kHz input, normalized, padded with 0.0,
# returning an attention mask.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# One object wrapping both: audio -> input_values, text -> label ids.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
def prepare_dataset(batch):
    """Turn a batch of loaded audio + transcripts into model inputs.

    Intended for a batched ``datasets`` map: ``batch`` maps column names to lists
    and must contain ``speech``, ``sampling_rate`` and ``sentence``.

    Adds ``input_values`` (processor output for ``speech``) and ``labels``
    (tokenizer ids for ``sentence``) to ``batch`` and returns it.

    Raises:
        AssertionError: if the batch mixes sampling rates, or if its rate differs
            from the one the processor's feature extractor is configured for.
    """
    expected_sr = processor.feature_extractor.sampling_rate
    batch_srs = set(batch["sampling_rate"])
    # Check both that the rate is uniform AND that it is the expected one; the old
    # check only verified uniformity although its message promised the latter.
    assert batch_srs == {expected_sr}, (
        f"Make sure all inputs have the same sampling rate of {expected_sr}; "
        f"got {sorted(batch_srs)}."
    )

    batch["input_values"] = processor(batch["speech"], sampling_rate=expected_sr).input_values

    # Inside this context the processor tokenizes text (targets) instead of audio.
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch

In [None]:
#hide_output
# Feature-extract and tokenize both splits. Each split removes ITS OWN current
# columns, so only the prepare_dataset outputs (input_values, labels) remain.
# The validation map previously used test_ds.column_names (the raw Common Voice
# columns), which left `speech`/`sampling_rate` in valid_ds. Those columns were then
# saved to disk and logged to W&B.
train_ds = train_ds.map(prepare_dataset, remove_columns=train_ds.column_names, num_proc=8,
                        batch_size=16, batched=True)

valid_ds = valid_ds.map(prepare_dataset, remove_columns=valid_ds.column_names, num_proc=8,
                        batch_size=16, batched=True)

 

HBox(children=(FloatProgress(value=0.0, description='#0', max=27.0, style=ProgressStyle(description_width='ini…

 

HBox(children=(FloatProgress(value=0.0, description='#1', max=27.0, style=ProgressStyle(description_width='ini…


 

HBox(children=(FloatProgress(value=0.0, description='#2', max=27.0, style=ProgressStyle(description_width='ini…


 

HBox(children=(FloatProgress(value=0.0, description='#3', max=27.0, style=ProgressStyle(description_width='ini…


 

HBox(children=(FloatProgress(value=0.0, description='#4', max=27.0, style=ProgressStyle(description_width='ini…


 

HBox(children=(FloatProgress(value=0.0, description='#5', max=27.0, style=ProgressStyle(description_width='ini…


 

HBox(children=(FloatProgress(value=0.0, description='#6', max=27.0, style=ProgressStyle(description_width='ini…


 

HBox(children=(FloatProgress(value=0.0, description='#7', max=27.0, style=ProgressStyle(description_width='ini…



        

HBox(children=(FloatProgress(value=0.0, description='#4', max=4.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#1', max=4.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#6', max=4.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#3', max=4.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#2', max=4.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#5', max=4.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#0', max=4.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='#7', max=4.0, style=ProgressStyle(description_width='init…











## [Optional] Store Data in W&B

We will create a W&B Run. You can now log any data you'd like to W&B Artifacts, and it will be tied to this Run. When we use `Trainer` this run will also be used.

In [None]:
# Configure W&B and start the run that the Trainer will also log to.
entity, project = setup_wandb(entity='wandb', project_name='xlsr-irish', log_model=True)

# Reuse the entity/project returned by setup_wandb instead of repeating the literals,
# so they only need to be changed in one place.
# NOTE(review): assumes setup_wandb returns the entity/project it configured.
wandb_run = wandb.init(name='ie-en_baseline', project=project, entity=entity,
                       tags=['ie-en', 'baseline'], group='baseline', reinit=True)



VBox(children=(Label(value=' 3627.44MB of 3627.44MB uploaded (309.15MB deduped)\r'), FloatProgress(value=1.0, …

In [None]:
# Save each dataset to disk and log that folder as a W&B 'dataset' artifact on this run.
ds = {
    'train_ready_train_ds': train_ds,
    'train_ready_valid_ds': valid_ds,
    'raw_test_ds': test_ds,
    # 'raw_en_train_ds': en_train_sample,
}

for name, dataset in ds.items():
    f_path = f'../../data/{name}'
    dataset.save_to_disk(f_path)
    artifact = wandb.Artifact(
        name=name,
        type='dataset',
        description='My dataset',
        metadata={'dataset_length': len(dataset)},
    )
    artifact.add_dir(f_path)
    wandb_run.log_artifact(artifact)

[34m[1mwandb[0m: Adding directory to artifact (./../../data/train_ready_train_ds)... Done. 3.4s
[34m[1mwandb[0m: Adding directory to artifact (./../../data/train_ready_valid_ds)... Done. 1.1s
[34m[1mwandb[0m: Adding directory to artifact (./../../data/raw_test_ds)... Done. 0.1s
[34m[1mwandb[0m: Adding directory to artifact (./../../data/raw_en_train_ds)... Done. 0.1s


In [None]:
# f_path = f'../../data/train_ready_valid_ds'
# valid_ds.save_to_disk(f_path)
# artifact = wandb.Artifact(name=name, type='dataset',
#                          description='My dataset',
#                          metadata={'dataset_length':len(valid_ds)})
# artifact.add_dir(f_path)
# wandb_run.log_artifact(artifact)

In [None]:
# Log the vocab file as its own W&B artifact. The description is derived from the
# current vocab instead of being hard-coded (it previously claimed "ie and en, len 36").
vcb_f_path = 'data/vocab.json'
artifact = wandb.Artifact(name='vocab', type='vocab',
                          description=f'Vocab extracted from train_ds + valid_ds, len {len(vocab)}',
                          metadata={'vocab_length': len(vocab)})
artifact.add_file(vcb_f_path)
wandb_run.log_artifact(artifact)

<wandb.sdk.wandb_artifacts.Artifact at 0x7f7221ca8490>

In [None]:
# Save the processor into ./data (alongside the vocab) and log that folder as an artifact.
processor_path = './data'
processor.save_pretrained(processor_path)
processor_artifact = wandb.Artifact(name='processor', type='processor')
processor_artifact.add_dir(processor_path)
wandb_run.log_artifact(processor_artifact)

[34m[1mwandb[0m: Adding directory to artifact (./data)... Done. 1.8s


<wandb.sdk.wandb_artifacts.Artifact at 0x7f7222808b50>

## Load Saved Data

In [None]:
# with wandb.init(project='xlsr-irish', entity='wandb') as run:
#     # Connect an Artifact to your run
#     train_ds_artifact = run.use_artifact('train_ready_train_ds:v0')
# #     valid_ds_artifact = run.use_artifact('raw_en_train_ds:v0')

#     # Download model weights to a folder and return the path
#     train_ds_dir = train_ds_artifact.download()
# #     valid_ds_dir = valid_ds_artifact.download()

# # Load your Hugging Face model from that folder, e.g. SequenceClassification model
# train_ds = load_from_disk(train_ds_dir)

## Prep Training

In [None]:
# Dynamic per-batch padding of inputs and labels, each rounded up to a multiple of 8.
data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
    pad_to_multiple_of=8,
    pad_to_multiple_of_labels=8,
)

In [None]:
# Config overrides for fine-tuning the pretrained XLSR-53 checkpoint.
# The pad id and vocab size come from the tokenizer, so the freshly initialized
# lm_head matches our character vocab.
xlsr_config_overrides = dict(
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    ctc_zero_infinity=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53", **xlsr_config_overrides)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Set Training arguments

In [None]:
# Checkpointing must line up with evaluation. With load_best_model_at_end=True,
# transformers requires save_steps to be a round multiple of eval_steps, and the
# previous values (save_steps=96, eval_steps=64) violate that. Checkpoints are now
# written at every second evaluation.
EVAL_STEPS = 64
SAVE_STEPS = 2 * EVAL_STEPS  # 128

training_args = TrainingArguments(
    output_dir="../../data/my_xlsr",
    group_by_length=True,            # group similar-length samples to reduce padding
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,   # effective train batch per device: 8 * 4
    evaluation_strategy="steps",
    num_train_epochs=15,
    fp16=True,
    save_steps=SAVE_STEPS,
    eval_steps=EVAL_STEPS,
    logging_steps=8,
    learning_rate=3e-4,
    warmup_steps=96,
    save_total_limit=1,

    # WANDB LOGGING:
    report_to='wandb',               # enable logging to W&B
    run_name='ie-en_baseline_15e',   # name your run, optional
    load_best_model_at_end=True,     # ensures the best model is the one uploaded to W&B
    metric_for_best_model='wer',     # select the best model by "wer", not eval loss
    greater_is_better=False,         # lower WER is better
)

Create trainer

In [None]:
# WER metric from xlsr_finetune (compute_wer_metric), bound to our processor.
wer_compute_metrics = partial(compute_wer_metric, processor=processor)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    compute_metrics=wer_compute_metrics,
    tokenizer=processor.feature_extractor,
)

## Set up Monitoring [optional]
Log in to Weights and Biases and set your entity (username) and project name, or else use the publicly available entity and project below

In [None]:
# Apply the W&B settings (entity, project, model logging) before training.
# NOTE(review): this repeats the setup done when the W&B run was created.
setup_wandb(log_model=True, project_name='xlsr-irish', entity='wandb')

## Train

In [None]:
# Fine-tune; metrics go to W&B because training_args sets report_to='wandb'.
trainer.train()

# If using W&B and not doing any further evaluation, call wandb.finish() here.
# wandb.finish()



Step,Training Loss,Validation Loss,Wer,Runtime,Samples Per Second
64,3.1934,3.089865,1.0,46.056,10.987
128,3.0053,3.093399,1.0,43.739,11.569
192,2.9268,2.945194,1.0,43.5632,11.615
256,2.883,2.88296,1.0,44.1315,11.466
320,2.7605,2.714602,1.0,44.0936,11.476
384,1.9728,1.862079,0.988294,44.5143,11.367
448,1.5002,1.505159,0.925666,43.9864,11.504
512,1.2846,1.410872,0.913667,44.2648,11.431
576,1.0073,1.242821,0.832602,44.2364,11.439
640,1.031,1.080747,0.797191,43.7183,11.574


**wandb.finish** - If you are using W&B and not doing any further training or evaluation, call `wandb.finish()`.

In [None]:
# wandb.finish()  

## Evaluate

In [None]:
from xlsr_finetune.evaluation import *

In [None]:
# Rebind the audio loader in evaluation mode (still resampling to 16 kHz) and
# apply it to the raw test split.
sp2a = partial(speech_file_to_array, evaluate=True, new_sr=16_000, resample=True)
test_ds = test_ds.map(sp2a)

In [None]:
# Batched inference over the test split (8 clips per call); the predictions
# are read back from result["pred_strings"].
ev = partial(evaluate_xlsr, processor=processor, model=model)
result = test_ds.map(ev, batch_size=8, batched=True)

In [None]:
# Test-set word error rate, in percent.
wer_true = 100 * wer_metric.compute(predictions=result["pred_strings"], references=result["sentence"])
# `.2f` = two decimal places. The previous `2f` meant a minimum width of 2 with the default 6 decimals.
print(f"WER: {wer_true:.2f}")

In [None]:
# Attach the test-set WER to this W&B run, then close the run.
test_metrics = {'test/wer_true': wer_true}
wandb_run.log(test_metrics)
wandb_run.finish()

In [None]:
# import wandb
# wandb.log({'test/wer_true': wer_true})
# wandb.finish()