In [None]:
# !pip uninstall -y xlsr_finetune
# !pip install -Uqqq git+https://github.com/huggingface/transformers.git
# !pip install -Uqqq git+https://github.com/huggingface/datasets.git
# !pip install -Uqqq git+https://github.com/morganmcg1/xlsr_finetune.git
# !pip install wandb --upgrade

In [None]:
%load_ext autoreload
%autoreload 2

# Train Demo

In [None]:
from xlsr_finetune.data import *
from xlsr_finetune.training import *
from xlsr_finetune.wandbutils import *

In [None]:
import os
import wandb
import random
import numpy as np
from functools import partial
from datasets import load_dataset, load_from_disk
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import Trainer, TrainingArguments

## Load data

In [None]:
# Directory passed as `cache_dir` when loading Common Voice below (relative to the notebook).
data_dir = '../../data'

In [None]:
# Common Voice Irish (ga-IE): train on train+validation, validate on the test split.
train_ds = load_dataset("common_voice", "ga-IE", split="train+validation", cache_dir=data_dir)
valid_ds = load_dataset("common_voice", "ga-IE", split="test", cache_dir=data_dir)
# NOTE(review): test_ds is the same 'test' split as valid_ds, so the split used for checkpoint
# selection (load_best_model_at_end on WER) is also the one reported as the test score — confirm intended.
test_ds = valid_ds

Reusing dataset common_voice (../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)
Reusing dataset common_voice (../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f)


In [None]:
# First 20 raw transcripts, before any normalisation.
train_ds['sentence'][:20]

['"D\'ísligh Máire ar a glúine le hais an chliabháin"',
 'An Phoblacht Dhoiminiceach',
 'Gabhaim buíochas libh a chairde táimid ag tnúth le tuilleadh uaibh ar ball',
 'Ar mhaith leat lón?',
 'An Naoú Bliain',
 'Go raibh maith agaibh agus bainigí taitneamh as an bhfilíocht agus as an gceol.',
 'Guím gach rath agus beannacht oraibh don todhchaí',
 'An mbeidh Dónall Ó Laoire anseo',
 'An Cheardchomhairle.',
 'Comhghairdeas libh go léir agus guím gach rath oraibh sa todhchaí.',
 'A dhaoine córa, a chairde dílse,',
 'Gura fada buan sibh i mbun cheol bhinn na hÉireann.',
 'Is í sin an obair atá á ceiliúradh againn anocht',
 'An Clár',
 'Mar a dúirt sé féin',
 'Titanic, Béal Feirste, an seachtú lá de mhí an Mhárta dhá mhíle a cúig déag',
 'An mbeidh siad ag obair anocht',
 'Éirí as; Comhaltas a Fhionraí; Oibríochtaí a Fhionraí.',
 'D’fhéach sé ar an bhfrog',
 'Airleacain ó údaráis tithe chun tithe a athfhoirgniú, a dheisiú agus a fheabhsú.']

Drop any rows whose "path" does not point to an existing audio file

In [None]:
# # if files have moved location between sessions, remap the path location
# def remap_data_dir(e):
#     e['path'] = f'{data_dir}/' + '/'.join(e['path'].split('/')[1:])
#     return e

# train_ds = train_ds.map(remap_data_dir)
# valid_ds = valid_ds.map(remap_data_dir)

In [None]:
# Remove rows whose audio file is missing on disk (helper presumably from xlsr_finetune.data).
train_ds = drop_missing_files(train_ds)
valid_ds = drop_missing_files(valid_ds)

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-ac60caf50be0ec56.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-f7ccf59d7eefb8c0.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-29f7764b1cebe257.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-3c5e525f3a6edcad.arrow


All files found
All files found


[Optional] Merge another dataset into your training dataset and shuffle the result

In [None]:
# new_ds = load_dataset("common_voice", "en", split="test[20:40%]", cache_dir=data_dir)
# en_train_sample = load_dataset("common_voice", "en", split="test", cache_dir=data_dir)

In [None]:
# new_ds = new_ds.map(remap_data_dir)
# en_train_sample = en_train_sample.map(remap_data_dir)

In [None]:
# new_ds = drop_missing_files(new_ds)
# train_ds = merge_ds(train_ds, new_ds)

## Jim O'Regan Processing

In [None]:
import re

# So, tolower() for Irish is a bit complicated: tAthar -> t-athair
# toupper() is non-deterministic :)
def is_upper_vowel(letter):
    """Return True if ``letter`` is an upper-case Irish vowel, with or without a fada."""
    upper_vowels = ('A', 'E', 'I', 'O', 'U', 'Á', 'É', 'Í', 'Ó', 'Ú')
    return letter in upper_vowels
    
def irish_lower(word):
    """Lower-case one word while keeping Irish t-/n- prothesis visible.

    A word that starts with a lower-case 'n' or 't' directly followed by an
    upper-case vowel gets a hyphen after that first letter before the rest is
    lower-cased ('tAthair' -> 't-athair'); any other word is simply lower-cased.
    """
    has_prothesis = len(word) > 1 and word[0] in ('n', 't') and is_upper_vowel(word[1])
    if not has_prothesis:
        return word.lower()
    return word[0] + '-' + word[1:].lower()
    
def irish_lower_sentence(sentence):
    """Apply ``irish_lower`` to every single-space-separated token, keeping the spacing as is."""
    tokens = sentence.split(" ")
    return " ".join(map(irish_lower, tokens))

# Punctuation stripped from transcripts. A raw string keeps the regex escapes without
# triggering invalid-escape warnings (DeprecationWarning, SyntaxWarning on Python 3.12+).
chars_to_ignore_regex = r'[,\?\.\!\;\:\"\“\%\‘\”\(\)\*]'

# def remove_special_characters(sentence):
#     tmp = re.sub('’ ', ' ', sentence)
#     tmp = re.sub("’{{%htmlContent%}}quot;", '', tmp)
#     tmp = re.sub('’', '\'', tmp)
#     tmp = re.sub(chars_to_ignore_regex, '', tmp)
#     sentence = irish_lower_sentence(tmp) + ' '
#     return sentence

def remove_special_characters(batch):
    """Normalise the ``sentence`` of one (non-batched) example, Irish-aware.

    Steps, in order:
      1. '’ ' (right single quote followed by a space) -> ' ';
      2. a '’' at the very end of the sentence is dropped;
      3. any remaining '’' -> ASCII apostrophe (');
      4. en dash (–) and em dash (—) -> ASCII hyphen (-);
      5. characters matched by the global ``chars_to_ignore_regex`` are removed;
      6. words are lower-cased with ``irish_lower_sentence`` and one trailing
         space is appended.

    Args:
        batch: a ``datasets`` example containing a ``'sentence'`` string.

    Returns:
        The same ``batch`` with ``'sentence'`` replaced by the cleaned text.
    """
    tmp = re.sub('’ ', ' ', batch['sentence'])
    # Drop a closing ’ at the very end of the sentence.
    tmp = re.sub("’$", '', tmp)
    tmp = re.sub('’', '\'', tmp)

    # MORGAN ADDED: map en/em dashes to a plain hyphen (one character class, no
    # invalid string escapes).
    tmp = re.sub('[–—]', '-', tmp)

    tmp = re.sub(chars_to_ignore_regex, '', tmp)
    batch['sentence'] = irish_lower_sentence(tmp) + ' '
    return batch

# # MINE!!
# chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\(\)\-\*\/\\\]' 
# def remove_special_characters(batch, evaluate:bool=False, chars_to_ignore_regex:str=chars_to_ignore_regex):
#     if evaluate: batch["sentence"] = re.sub(chars_to_ignore_regex, '', 
#                                             batch["sentence"]).lower()
#     else: batch["sentence"] = re.sub(chars_to_ignore_regex, '', 
#                                             batch["sentence"]).lower() + " "
        
#     batch["sentence"] = re.sub('[\’]', '\'', batch["sentence"])
#     batch["sentence"] = re.sub('[\’]', '\'', batch["sentence"])
#     batch["sentence"] = re.sub('[\–]', '-', batch["sentence"])
#     batch["sentence"] = re.sub('[\—]', '-', batch["sentence"])
#     batch["sentence"] = re.sub('[&]', ' and ', batch["sentence"])
#     return batch

Clean the transcripts and create the vocab

In [None]:
# Normalise each transcript (map is not batched, so remove_special_characters receives one example at a time).
train_ds = train_ds.map(remove_special_characters)
valid_ds = valid_ds.map(remove_special_characters)

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-447b03b5d4e9b362.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-62261d1474eea0fe.arrow


In [None]:
# The same 20 transcripts after normalisation, for comparison with the raw versions.
train_ds['sentence'][:20]

["d'ísligh máire ar a glúine le hais an chliabháin ",
 'an phoblacht dhoiminiceach ',
 'gabhaim buíochas libh a chairde táimid ag tnúth le tuilleadh uaibh ar ball ',
 'ar mhaith leat lón ',
 'an naoú bliain ',
 'go raibh maith agaibh agus bainigí taitneamh as an bhfilíocht agus as an gceol ',
 'guím gach rath agus beannacht oraibh don todhchaí ',
 'an mbeidh dónall ó laoire anseo ',
 'an cheardchomhairle ',
 'comhghairdeas libh go léir agus guím gach rath oraibh sa todhchaí ',
 'a dhaoine córa a chairde dílse ',
 'gura fada buan sibh i mbun cheol bhinn na héireann ',
 'is í sin an obair atá á ceiliúradh againn anocht ',
 'an clár ',
 'mar a dúirt sé féin ',
 'titanic béal feirste an seachtú lá de mhí an mhárta dhá mhíle a cúig déag ',
 'an mbeidh siad ag obair anocht ',
 'éirí as comhaltas a fhionraí oibríochtaí a fhionraí ',
 "d'fhéach sé ar an bhfrog ",
 'airleacain ó údaráis tithe chun tithe a athfhoirgniú a dheisiú agus a fheabhsú ']

In [None]:
# Build the character vocabulary from both splits and save it under data/
# (the tokenizer below is created from data/vocab.json).
vocab = extract_vocab(train_ds, valid_ds, save=True, save_dir='data')
vocab

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




{'j': 0,
 'k': 1,
 't': 2,
 'r': 3,
 'g': 4,
 'é': 5,
 'f': 6,
 'ó': 7,
 's': 8,
 'e': 9,
 'a': 11,
 'v': 12,
 "'": 13,
 'o': 14,
 'h': 15,
 'á': 16,
 'c': 17,
 'x': 18,
 'w': 19,
 'u': 20,
 'b': 21,
 'd': 22,
 'l': 23,
 'i': 24,
 'm': 25,
 'y': 26,
 'n': 27,
 'ú': 28,
 'p': 29,
 'í': 30,
 '-': 31,
 '|': 10,
 '[UNK]': 32,
 '[PAD]': 33}

Check your vocab and, if needed, add any unwanted characters to the `chars_to_ignore_regex` string

In [None]:
# chars_to_ignore_regex = chars_to_ignore_regex[:-1] + '\𝓧]'
# chars_to_ignore_regex

In [None]:
# new_remove_special_characters = partial(remove_special_characters, 
#                                         chars_to_ignore_regex=chars_to_ignore_regex)
# train_ds = train_ds.map(new_remove_special_characters)
# valid_ds = valid_ds.map(new_remove_special_characters)

In [None]:
# vocab = extract_vocab(train_ds, valid_ds, save=True, save_dir='data')
# vocab

## Convert Audio to Array

In [None]:
# Audio loader (speech_file_to_array from xlsr_finetune) configured to resample every clip to 16 kHz.
sp2a = partial(speech_file_to_array, resample=True, new_sr=16_000)

In [None]:
# Decode and resample audio with 8 worker processes; this provides the `speech` and
# `sampling_rate` columns used in the following cells.
train_ds = train_ds.map(sp2a, num_proc=8)
valid_ds = valid_ds.map(sp2a, num_proc=8)

  

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-540b0b5f20a00d7e.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-e36f8c4ef762ef4e.arrow


  

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-3487b66236ee960a.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-1a78b14c1c75739a.arrow


 

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-b0e98eafe455de15.arrow


 

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-4ab87bb1d26ec961.arrow


 

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-9c5f5f0533a1626e.arrow


 

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-cba1000c04901986.arrow


     

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-c48fb5fe6f3b74fa.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-5576375484537ddb.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-075427426b013895.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-b29ef75aaff84fde.arrow
Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-7ce9e902e505f2f2.arrow


 

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-a2c08ada0e8ce856.arrow


 

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-12a9d3d125caaaf0.arrow


 

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-330e8b8cbf360325.arrow


## Filter Out Files That Could Not Be Read

`speech_file_to_array` adds a `0` to the `speech` items whose path could not be read. Let's remove these.

In [None]:
def has_readable_audio(example):
    """True unless ``speech`` is the short placeholder left for a file that could not be read."""
    return len(example['speech']) > 1

prev_l = len(train_ds)
train_ds = train_ds.filter(has_readable_audio)
print(f'{prev_l - len(train_ds)} samples removed')

# valid_ds went through the same audio loading, so filter it as well; otherwise
# an unreadable clip would be fed to the model during evaluation.
prev_valid_l = len(valid_ds)
valid_ds = valid_ds.filter(has_readable_audio)
print(f'{prev_valid_l - len(valid_ds)} validation samples removed')

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-96236c3cc8b36975.arrow


0 samples removed


## Filter out Long Audio in Training Set
Longer audio can cause CUDA out-of-memory (OOM) errors, so training clips longer than 160k samples (10 s at a 16 kHz sample rate) are removed.

In [None]:
# Longest training clip kept, in samples: 160_000 samples = 10 s at 16 kHz.
# Longer clips risk CUDA out-of-memory errors.
MAX_TRAIN_SPEECH_LEN = 160_000

prev_l = len(train_ds)
train_ds = train_ds.filter(lambda example: len(example['speech']) <= MAX_TRAIN_SPEECH_LEN)
print(f'{prev_l - len(train_ds)} samples out of {prev_l} removed')

Loading cached processed dataset at ../../data/common_voice/ga-IE/6.1.0/0041e06ab061b91d0a23234a2221e87970a19cf3a81b20901474cffffeb7869f/cache-0b3fb300b1905b1f.arrow


1 samples out of 1038 removed


## Quick Data Check

In [None]:
# Spot-check one random training example (no seed is set, so the example changes between runs).
rand_int = random.randint(0, len(train_ds)-1)

print(f"Number of train samples: {len(train_ds)}")
print("Target text:", train_ds[rand_int]["sentence"])
print("Input array shape:", np.asarray(train_ds[rand_int]["speech"]).shape)
print("Sampling rate:", train_ds[rand_int]["sampling_rate"])

Number of train samples: 1037
Target text: tá siad sa bhaile 
Input array shape: (41472,)
Sampling rate: 16000


In [None]:
# Re-inspect the first 20 transcripts after the audio filtering steps.
train_ds['sentence'][:20]

["d'ísligh máire ar a glúine le hais an chliabháin ",
 'an phoblacht dhoiminiceach ',
 'gabhaim buíochas libh a chairde táimid ag tnúth le tuilleadh uaibh ar ball ',
 'ar mhaith leat lón ',
 'an naoú bliain ',
 'go raibh maith agaibh agus bainigí taitneamh as an bhfilíocht agus as an gceol ',
 'guím gach rath agus beannacht oraibh don todhchaí ',
 'an mbeidh dónall ó laoire anseo ',
 'an cheardchomhairle ',
 'comhghairdeas libh go léir agus guím gach rath oraibh sa todhchaí ',
 'a dhaoine córa a chairde dílse ',
 'gura fada buan sibh i mbun cheol bhinn na héireann ',
 'is í sin an obair atá á ceiliúradh againn anocht ',
 'an clár ',
 'mar a dúirt sé féin ',
 'titanic béal feirste an seachtú lá de mhí an mhárta dhá mhíle a cúig déag ',
 'an mbeidh siad ag obair anocht ',
 'éirí as comhaltas a fhionraí oibríochtaí a fhionraí ',
 "d'fhéach sé ar an bhfrog ",
 'airleacain ó údaráis tithe chun tithe a athfhoirgniú a dheisiú agus a fheabhsú ']

## PreProcess to Create Model Input Values

In [None]:
# Character-level CTC tokenizer built from the vocab saved above; '|' is the word delimiter token.
tokenizer = Wav2Vec2CTCTokenizer("data/vocab.json", unk_token="[UNK]", 
                                 pad_token="[PAD]", word_delimiter_token="|")

# Feature extractor for 16 kHz waveforms: pads with 0.0, normalises (do_normalize) and returns attention masks.
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, 
                                          padding_value=0.0, do_normalize=True, return_attention_mask=True)

# One object wrapping both: audio -> input_values, text -> label ids.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
def prepare_dataset(batch):
    """Turn a batch of examples into model inputs (``input_values``) and CTC targets (``labels``).

    Intended for ``Dataset.map(..., batched=True)``: ``batch`` maps column names to
    lists and must contain ``speech``, ``sampling_rate`` and ``sentence``.

    Raises:
        AssertionError: if any sampling rate in the batch differs from the
            feature extractor's ``sampling_rate``.
    """
    expected_sr = processor.feature_extractor.sampling_rate
    # Every clip must use the feature extractor's rate, not merely the same rate as
    # the other clips in the batch (which is all the message-less check used to verify).
    assert set(batch["sampling_rate"]) == {expected_sr}, (
        f"Make sure all inputs have the same sampling rate of {expected_sr}."
    )

    batch["input_values"] = processor(batch["speech"], sampling_rate=expected_sr).input_values

    # Encode the transcripts into label ids with the tokenizer side of the processor.
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch

In [None]:
#hide_output
# n_cpus = os.cpu_count() - 2

# Replace every raw column with the model inputs. Each split drops its *own* columns, so
# no raw audio ('speech', 'sampling_rate') is left behind in the processed validation set.
train_ds = train_ds.map(prepare_dataset, remove_columns=train_ds.column_names, batch_size=32, batched=True)
valid_ds = valid_ds.map(prepare_dataset, remove_columns=valid_ds.column_names, batch_size=32, batched=True)

HBox(children=(FloatProgress(value=0.0, max=33.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))




## [Optional] Store Data in W&B

We will create a W&B run. Any data you log to W&B Artifacts from now on will be tied to this run, and `Trainer` will also log to it.

In [None]:
import wandb
# setup_wandb (from xlsr_finetune) returns the entity and project it configured;
# reuse them for wandb.init instead of repeating the literals.
entity, project = setup_wandb(entity='wandb', project_name='xlsr-irish', log_model=True)

wandb_run = wandb.init(name='ie_grad-clip-0-05_jim-data_boris-hyp', project=project, entity=entity,
                       tags=['ie-en','baseline'], group='baseline', 
                       notes='Using the same data pre-processing and Jim O Regan, with Boris hyperparameters',
                       reinit=True)

[34m[1mwandb[0m: Currently logged in as: [33mwandb[0m (use `wandb login --relogin` to force relogin)


In [None]:
# Log the datasets
ds = {'train_ready_train_ds':train_ds,
     'train_ready_valid_ds':valid_ds,
     'raw_test_ds': test_ds,
#      'raw_en_train_ds': en_train_sample
     }

for name in ds.keys():
    f_path = f'../../data/{name}'
    ds[name].save_to_disk(f_path)
    artifact = wandb.Artifact(name=name, type='dataset',
                             description='Same data pre-processing and Jim O Regan',
                             metadata={'dataset_length':len(ds[name])})
    artifact.add_dir(f_path)
    wandb_run.log_artifact(artifact)

[34m[1mwandb[0m: Adding directory to artifact (./../../data/train_ready_train_ds)... Done. 6.2s
[34m[1mwandb[0m: Adding directory to artifact (./../../data/train_ready_valid_ds)... Done. 3.7s
[34m[1mwandb[0m: Adding directory to artifact (./../../data/raw_test_ds)... Done. 0.1s


In [None]:
# Log vocab file
vcb_f_path = 'data/vocab.json'
artifact = wandb.Artifact(name='vocab', type='vocab', 
                          description='Vocab for combined ie and en, len 36',
                          metadata={'vocab_length':len(vocab.keys())})
artifact.add_file(vcb_f_path)
wandb_run.log_artifact(artifact)

<wandb.sdk.wandb_artifacts.Artifact at 0x7fca49c48190>

In [None]:
# Log processor
processor_path = './data'
processor.save_pretrained(processor_path)
artifact = wandb.Artifact(name='processor', type='processor')
artifact.add_dir(processor_path)
wandb_run.log_artifact(artifact)

[34m[1mwandb[0m: Adding directory to artifact (./data)... Done. 0.1s


<wandb.sdk.wandb_artifacts.Artifact at 0x7fca75fe8390>

## Load Saved Data

In [None]:
# with wandb.init(project='xlsr-irish', entity='wandb') as run:
#     # Connect an Artifact to your run
#     train_ds_artifact = run.use_artifact('train_ready_train_ds:v0')
# #     valid_ds_artifact = run.use_artifact('raw_en_train_ds:v0')

#     # Download model weights to a folder and return the path
#     train_ds_dir = train_ds_artifact.download()
# #     valid_ds_dir = valid_ds_artifact.download()

# # Load your Hugging Face model from that folder, e.g. SequenceClassification model
# train_ds = load_from_disk(train_ds_dir)

## Prep Training

In [None]:
# Pads each batch (padding=True); input and label lengths are rounded up to a multiple of 8,
# presumably for fp16 efficiency (fp16=True in the training arguments) — TODO confirm.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True,
                                          pad_to_multiple_of=8, pad_to_multiple_of_labels=8)

In [None]:
# Start from the multilingual XLSR-53 checkpoint. The CTC head is sized to our vocab and is
# newly initialised. The dropout/masking values are the "Boris" hyperparameters named in the run.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    activation_dropout= 0.055,
    attention_dropout= 0.094,
    hidden_dropout=0.047,
    feat_proj_dropout= 0.04,
    layerdrop=0.041,
    mask_time_prob=0.082,
    gradient_checkpointing=True,  # trade extra compute for lower activation memory
    ctc_loss_reduction="mean", 
    ctc_zero_infinity=True,       # zero infinite CTC losses (e.g. inputs too short for their labels)
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Freeze the pretrained convolutional feature encoder; the remaining layers are fine-tuned.
model.freeze_feature_extractor()

Set Training arguments

In [None]:
training_args = TrainingArguments(
  output_dir="../../data/my_xlsr",
  group_by_length=True,               # batch clips of similar length to reduce padding
  max_grad_norm=0.05,                 # strong gradient clipping ("grad-clip-0-05" in the run name)
  per_device_train_batch_size=32,
  per_device_eval_batch_size=64,
  gradient_accumulation_steps=2,      # effective train batch = 32 x 2 = 64 per device
  evaluation_strategy="steps",
  num_train_epochs=50,
  fp16=True,
  save_steps=64,                      # equal to eval_steps so every saved checkpoint has a WER
  eval_steps=64,                      #   for load_best_model_at_end to compare
  logging_steps=8,
  learning_rate=3e-4,
  warmup_steps=96,
  save_total_limit=1,
  dataloader_num_workers=16,
    
  # WANDB LOGGING: 
  report_to = 'wandb',  # enable logging to W&B
#   run_name = 'ie-en_baseline_15e',   # Name your run, optional
  load_best_model_at_end = True,  # This will ensure your best model will be uploaded to W&B
  metric_for_best_model='wer',    # Load best model based on "wer", not eval loss
  greater_is_better=False,    # Define "best" wer score as the lowest score
)

Create trainer

In [None]:
# Wire the model, collator, datasets and WER metric into a Trainer. The feature extractor is
# passed as `tokenizer` (presumably so it is saved alongside checkpoints).
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=partial(compute_wer_metric, processor=processor),   # compute_wer_metric imported from xlsr_finetune
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=processor.feature_extractor,
)

## Set up Monitoring [optional]
Log in to Weights and Biases and set your entity (username) and project name, or use the publicly available entity and project below.

In [None]:
# NOTE(review): setup_wandb was already called with these same arguments before wandb.init
# in the "Store Data in W&B" section; this repeat call looks redundant.
setup_wandb(entity='wandb', project_name='xlsr-irish', log_model=True)



('wandb', 'xlsr-irish')

## Train

In [None]:
trainer.train()

# If using W&B and not doing any further evaluation, then use wandb.finish()
# wandb.finish()



Step,Training Loss,Validation Loss,Wer,Runtime,Samples Per Second
64,3.2655,3.16297,1.0,23.3681,21.653
128,2.9159,2.9116,1.0,22.5576,22.431
192,2.4403,2.278715,0.99883,22.2436,22.748


**wandb.finish** - If you are using W&B and are not doing any further training or evaluation, call `wandb.finish()`.

In [None]:
# wandb.finish()  

## Evaluate

In [None]:
from xlsr_finetune.evaluation import *

In [None]:
# Reload the raw ga-IE test split for the final evaluation (its transcripts are not normalised here).
test_ds = load_dataset("common_voice", "ga-IE", split="test", cache_dir=data_dir)
# test_ds = test_ds.map(remap_data_dir)



In [None]:
# Load and resample the test audio to 16 kHz. NOTE(review): whatever evaluate=True does inside
# xlsr_finetune, the sentences shown later are still raw, so it does not normalise transcripts.
sp2a = partial(speech_file_to_array, resample=True, new_sr=16_000, evaluate=True)
test_ds = test_ds.map(sp2a)

HBox(children=(FloatProgress(value=0.0, max=506.0), HTML(value='')))




In [None]:
# Batched inference with the fine-tuned model via xlsr_finetune's evaluate_xlsr;
# the result has a `pred_strings` column with the decoded text.
ev = partial(evaluate_xlsr, model=model, processor=processor)
result = test_ds.map(ev, batched=True, batch_size=16)



HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))




In [None]:
# Compare the first 10 predictions with their reference transcripts.
result[:10]["pred_strings"], result[:10]["sentence"]

(['bhí tara bmhio scéirí agus meamh óruóirc ag seáint í náirdinn nlá',
  'tá sé anseo alais',
  'bui deoar d ar naoionán na thircim cia súin',
  'an raibh na cailíne ag obair',
  'ní raibh ac hí uachtar aggam',
  'go rabh mbeagaibh',
  'ar dheis leabh dé go roumhad anan dílis',
  'bhío na tuonta i builh a goinne na gcairigithe',
  'tá lá is troca i mínea nollaig',
  'agus thug sé a chúlair goimí cheadachagus dimí sé'],
 ['"Bhí Tara Viscardi agus Meadhbh O\'Rourke ag seinnt i ngairdín na mbláth"',
  'Tá sé anseo anois',
  'Ba ghearr go raibh an naíonán ina thoirchim suain',
  'An raibh na cailíní ag obair',
  'Ní raibh, ach bhí uachtar agam',
  'Go raibh maith agaibh',
  'Ar dheis lámh Dé go raibh a anam dílis',
  'Bhí na tonnta ag bualadh i gcoinne na gcarraigeacha.',
  'Tá lá is tríocha i mí na Nollag',
  'Agus thug sé a chúl air go mí-cheadtach agus d’imigh sé'])

In [None]:
# The raw test references still contain capitals and punctuation, while predictions are
# lower-case and punctuation-free. Normalise the references with the same
# remove_special_characters used on the training data, so WER compares like with like.
references = [remove_special_characters({'sentence': s})['sentence'] for s in result["sentence"]]
wer_true = 100 * wer_metric.compute(predictions=result["pred_strings"], references=references)
print(f"WER: {wer_true:.2f}")  # two decimal places

WER: 74.370977


In [None]:
# import wandb
# wandb_run.log({'test/wer_true': wer_true})
# wandb_run.finish()

## Jim O'Regan Evaluation

In [None]:
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import re
# Use the same cache directory as the rest of the notebook so the split is not re-downloaded.
test_dataset = load_dataset("common_voice", "ga-IE", split="test", cache_dir=data_dir)
wer = load_metric("wer")
# processor = Wav2Vec2Processor.from_pretrained("jimregan/wav2vec2-large-xlsr-irish-basic")
# model = Wav2Vec2ForCTC.from_pretrained("jimregan/wav2vec2-large-xlsr-irish-basic") 
# model.to("cuda")

# So, tolower() for Irish is a bit complicated: tAthar -> t-athair
# toupper() is non-deterministic :)
# NOTE: redefines is_upper_vowel from the "Jim O'Regan Processing" section
# (presumably so this evaluation cell can run on its own).
def is_upper_vowel(letter):
    """True when ``letter`` is one of A E I O U or their fada forms Á É Í Ó Ú."""
    return letter in ('A', 'E', 'I', 'O', 'U', 'Á', 'É', 'Í', 'Ó', 'Ú')
def irish_lower(word):
    """Lower-case ``word``; a leading 'n'/'t' before a capital vowel becomes 'n-'/'t-' (tAthair -> t-athair)."""
    if len(word) <= 1 or word[0] not in ('n', 't') or not is_upper_vowel(word[1]):
        return word.lower()
    return f"{word[0]}-{word[1:].lower()}"
    
def irish_lower_sentence(sentence):
    """Lower-case a sentence word by word with ``irish_lower``; runs of spaces are preserved."""
    lowered = [irish_lower(token) for token in sentence.split(" ")]
    return " ".join(lowered)

# Punctuation stripped from transcripts; a raw string keeps the regex escapes without
# invalid-escape warnings (DeprecationWarning, SyntaxWarning on Python 3.12+).
chars_to_ignore_regex = r'[,\?\.\!\;\:\"\“\%\‘\”\(\)\*]'

def remove_special_characters(sentence):
    """Normalise one raw transcript string for WER scoring (Jim O'Regan's protocol).

    '’ ' becomes ' ', a sentence-final '’' is dropped, any other '’' becomes "'",
    characters in ``chars_to_ignore_regex`` are removed, words are lower-cased with
    ``irish_lower_sentence``, and a trailing space is appended.

    NOTE: unlike the batch version used for training, en/em dashes are not mapped to '-'.
    """
    tmp = re.sub('’ ', ' ', sentence)
    # Drop a closing ’ at the very end of the sentence.
    tmp = re.sub("’$", '', tmp)
    tmp = re.sub('’', '\'', tmp)
    tmp = re.sub(chars_to_ignore_regex, '', tmp)
    sentence = irish_lower_sentence(tmp) + ' '
    return sentence

# Shared resampler for clips recorded at 48 kHz (the rate this cell assumes Common Voice uses).
resampler = torchaudio.transforms.Resample(48_000, 16_000)

# Preprocessing the datasets.
# We need to read the audio files as arrays
def speech_file_to_array_fn(batch):
    """Normalise the transcript and load the clip as a 16 kHz numpy array (non-batched map)."""
    batch["sentence"] = remove_special_characters(batch["sentence"])
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    if sampling_rate == 48_000:
        speech_array = resampler(speech_array)
    else:
        # The shared resampler only fits 48 kHz input; build one for the file's actual rate.
        speech_array = torchaudio.transforms.Resample(sampling_rate, 16_000)(speech_array)
    batch["speech"] = speech_array.squeeze().numpy()
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)

# Preprocessing the datasets.
# We need to read the audio files as arrays
def evaluate(batch):
    """Greedy CTC decoding for a batch of examples (for ``Dataset.map(..., batched=True)``).

    Reads ``batch["speech"]`` (16 kHz arrays), runs the global ``model`` and
    ``processor``, and stores the decoded text in ``batch["pred_strings"]``.
    """
    # `model` comes straight from trainer.train() and may still be in training mode;
    # eval() disables dropout/layerdrop for inference and is idempotent.
    model.eval()
    # Run on whatever device the model lives on instead of hard-coding "cuda".
    device = model.device
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to(device), attention_mask=inputs.attention_mask.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch
    
result = test_dataset.map(evaluate, batched=True, batch_size=16)
# Print WER (%) with two decimal places.
print("WER: {:.2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))

In [None]:
# First 10 predictions next to their normalised references.
result[:10]["pred_strings"], result[:10]["sentence"]

In [None]:
# Keep the test WER (%) in a variable for logging to W&B.
wer_true = 100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])

In [None]:
# Test WER in percent, using the Jim O'Regan evaluation protocol.
wer_true

In [None]:
import wandb
# Log the final test WER to the currently active W&B run (the one from wandb.init above), then close it.
wandb.log({'test/wer_true': wer_true})
wandb.finish()