In [None]:
from datasets import load_metric, load_dataset, Audio
from transformers import Wav2Vec2ForCTC, AutoProcessor, TrainingArguments, Trainer
import torch

from tqdm import tqdm

import re

from finetuning_util_hausa import preprocess_texts, preprocess_text, create_vocab_dict, create_data_collator, compute_metrics, ASRDataset

import json

# Local cache directory handed to load_dataset(..., cache_dir=...) for the FLEURS downloads.
cache_dir_fleurs ="/data/users/kashrest/lrl-asr-experiments/data/fleurs"

In [None]:
# Hausa: stream a single example from the FLEURS test split and print it.
cache_dir = "/data/users/kashrest/lrl-asr-experiments/data/fleurs"
stream_data = load_dataset(
    "google/fleurs",
    "ha_ng",
    split="test",
    cache_dir=cache_dir,
    streaming=True,
)
sample = next(iter(stream_data))
print(sample)
# Keep the waveform and its transcription of that one example for ad-hoc use.
ha_sample, ha_sample_transcription = sample["audio"]["array"], sample["transcription"]

In [None]:
# Hausa
# The string literal below is a disabled Hugging Face Hub login kept for
# reference; as a bare expression in the middle of the cell it has no effect.
"""from huggingface_hub import notebook_login

notebook_login()"""

cache_dir = "/data/users/kashrest/lrl-asr-experiments/data/fleurs"

# Non-streaming load of the FLEURS Hausa test split; `data` is what the
# evaluation / inference loops further down iterate over.
data = load_dataset(
    "google/fleurs",
    "ha_ng",
    split="test",
    cache_dir=cache_dir,
)

# Finetuning - MMS-1b-all

## Finetuning 

In [None]:
def _collect_fleurs_split(split_name):
    """Return (raw transcriptions, audio arrays) for one FLEURS Hausa split.

    Both lists are filled in a single pass over the split, in dataset order,
    so index i of one list corresponds to index i of the other.
    """
    texts, waveforms = [], []
    for example in load_dataset("google/fleurs", "ha_ng", split=split_name, cache_dir=cache_dir_fleurs):
        texts.append(example["raw_transcription"])
        waveforms.append(example["audio"]["array"])
    return texts, waveforms


fleurs_hausa_train_transcriptions, fleurs_hausa_train_audio = _collect_fleurs_split("train")
fleurs_hausa_val_transcriptions, fleurs_hausa_val_audio = _collect_fleurs_split("validation")
fleurs_hausa_test_transcriptions, fleurs_hausa_test_audio = _collect_fleurs_split("test")

# Run the project's text preprocessing over each split's transcriptions.
fleurs_hausa_train_transcriptions = preprocess_texts(fleurs_hausa_train_transcriptions)
fleurs_hausa_val_transcriptions = preprocess_texts(fleurs_hausa_val_transcriptions)
fleurs_hausa_test_transcriptions = preprocess_texts(fleurs_hausa_test_transcriptions)

In [None]:
# ISO-639-3 for Hausa = "hau"
import os  # local import: the notebook's import cell does not bring `os` into scope

# The vocabulary is built from the train, validation and test transcriptions.
vocab_dict = create_vocab_dict(
    fleurs_hausa_train_transcriptions,
    fleurs_hausa_val_transcriptions,
    fleurs_hausa_test_transcriptions,
)
target_lang = "hau"
# NOTE(review): this language-keyed vocabulary is built but never written or
# used below -- the flat `vocab_dict` is what ends up in vocab.json. Confirm
# which layout the tokenizer is supposed to receive.
new_vocab_dict = {target_lang: vocab_dict}
root = "/data/users/kashrest/lrl-asr-experiments/"
out_dir = "facebook_mms-1b-all/poc-1/"

# Create the directory the vocab file is actually written to (root + out_dir).
# Previously `os.mkdir(out_dir)` targeted a relative path, and because `os` was
# not imported the resulting NameError was swallowed by a bare `except`, which
# then reported that the folder "already exists".
experiment_dir = root + out_dir
if os.path.isdir(experiment_dir):
    print("Experiment folder already exists")
else:
    os.makedirs(experiment_dir, exist_ok=True)

with open(experiment_dir + 'vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [None]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

model_sampling_rate = 16000

# Tokenizer is loaded from the experiment directory (root + out_dir).
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
    root + out_dir,
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=model_sampling_rate,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# Bundle audio feature extraction and text tokenization behind one object.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# One ASRDataset per split, built in train / validation / test order.
train_dataset, val_dataset, test_dataset = (
    ASRDataset(split_audio, split_texts, model_sampling_rate, processor)
    for split_audio, split_texts in (
        (fleurs_hausa_train_audio, fleurs_hausa_train_transcriptions),
        (fleurs_hausa_val_audio, fleurs_hausa_val_transcriptions),
        (fleurs_hausa_test_audio, fleurs_hausa_test_transcriptions),
    )
)

data_collator = create_data_collator(processor)

model_id = "facebook/mms-1b-all"

# All dropout variants are disabled; the output layer is sized to the new
# tokenizer, hence ignore_mismatched_sizes=True when loading the hub weights.
ctc_model_kwargs = dict(
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True,
)
model = Wav2Vec2ForCTC.from_pretrained(model_id, **ctc_model_kwargs)

In [None]:
# Initialise the adapter layers, freeze the base model, then switch gradients
# back on for the adapter weights only.
model.init_adapter_layers()
model.freeze_base_model()

adapter_weights = model._get_adapters()
for adapter_param in adapter_weights.values():
    adapter_param.requires_grad_(True)

training_args = TrainingArguments(
    output_dir=root + out_dir,
    # optimisation
    learning_rate=2e-4,
    warmup_steps=100,
    num_train_epochs=4,
    per_device_train_batch_size=16,
    group_by_length=True,
    gradient_checkpointing=True,
    fp16=True,
    # evaluation / logging / checkpointing cadence
    evaluation_strategy="steps",
    eval_steps=100,
    logging_steps=100,
    save_steps=200,
    save_total_limit=2,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # The feature extractor (not the full processor) is what gets passed here.
    tokenizer=processor.feature_extractor,
)

trainer.train()

## Evaluation of fine-tuned checkpoint 

In [None]:
# Base model identifier, kept for reference; the objects below are loaded from
# the fine-tuned checkpoint instead.
model_id = "facebook/mms-1b"
best_model_checkpoint = "./facebook_mms-1b/hausa-finetuning-2-script-basic-example/checkpoint-5200/"

# Load from the fine-tuned checkpoint. Previously `best_model_checkpoint` was
# defined but unused and both objects came from `model_id`, so this section
# evaluated the base hub model rather than the fine-tuned weights.
# NOTE(review): assumes the processor files (tokenizer vocab + feature-extractor
# config) were saved inside the checkpoint directory; if they only exist in the
# run's output directory, load the processor from there instead.
processor = AutoProcessor.from_pretrained(best_model_checkpoint)
model = Wav2Vec2ForCTC.from_pretrained(best_model_checkpoint)
#processor.tokenizer.get_vocab()["hau"]

In [None]:
def _basic_preprocessing(transcription):
    chars_to_ignore = [",", "?", ".", "!", "-", ";", ":", "\\", '"', '“',"%", "‘", '”', "�"]
    chars_to_ignore_regex = (f'[{"".join(chars_to_ignore)}]' if chars_to_ignore is not None else None)
    transcription = re.sub(chars_to_ignore_regex, "", transcription.lower())
    return transcription
    
processor.tokenizer.set_target_lang("hau")
#model.load_adapter("hau")
transcriptions, gold_transcriptions = [], []
for elem in tqdm(data):
    inputs = processor(elem["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs).logits
    # Greedy decoding: highest-scoring token id per frame, first (only) batch item.
    ids = outputs.argmax(dim=-1)[0]
    transcriptions.append(processor.decode(ids))
    # References get the same basic normalisation (lower-case, punctuation stripped).
    gold_transcriptions.append(_basic_preprocessing(elem["transcription"]))

# Persist the predictions as JSON Lines: one JSON-encoded string per line.
with open("./facebook_mms-1b/hausa-finetuning-2-script-basic-example/fluers_test_output.jsonl", "w") as f:
    f.writelines(json.dumps(transcription) + "\n" for transcription in transcriptions)

In [None]:
import json
# Write the current `transcriptions` list as JSON Lines (one JSON string per line).
output_path = "./facebook_mms-1b/hausa-finetuning-2-script-basic-example/fluers_test_output.jsonl"
with open(output_path, "w") as f:
    f.writelines(json.dumps(transcription) + "\n" for transcription in transcriptions)

In [None]:
# Word and character error rate of the predictions against the references.
wer_metric, cer_metric = load_metric("wer"), load_metric("cer")
scoring_inputs = dict(predictions=transcriptions, references=gold_transcriptions)
wer = wer_metric.compute(**scoring_inputs)
cer = cer_metric.compute(**scoring_inputs)
(wer, cer)

# Inference - MMS-1b-all

In [None]:
# Display the first example of the loaded test split.
data[0]

In [None]:
# Inference setup: load the processor and the CTC model from the hub
# checkpoint named by model_id (no fine-tuned weights involved here).
model_id = "facebook/mms-1b-all"

processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

In [None]:
# Display the "hau" (Hausa, ISO-639-3) entry of the tokenizer's vocabulary.
processor.tokenizer.get_vocab()["hau"]

In [None]:
# Select Hausa on both sides: tokenizer target language and model adapter.
processor.tokenizer.set_target_lang("hau")
model.load_adapter("hau")
transcriptions, gold_transcriptions = [], []
for elem in tqdm(data):
    inputs = processor(elem["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs).logits
    # Greedy decoding: highest-scoring token id per frame, first (only) batch item.
    ids = outputs.argmax(dim=-1)[0]
    transcriptions.append(processor.decode(ids))
    # References come from the raw transcription, run through the project's preprocess_text.
    gold_transcriptions.append(preprocess_text(elem["raw_transcription"]))

# Persist the predictions as JSON Lines: one JSON-encoded string per line.
with open("facebook_mms-1b-all/zero-shot/fluers_customized_preprocessing_test_output.jsonl", "w") as f:
    f.writelines(json.dumps(transcription) + "\n" for transcription in transcriptions)

In [None]:
# Sanity check: the loop appends to both lists once per example, so the lengths should match.
len(transcriptions), len(gold_transcriptions)

In [None]:
# Word and character error rate of the predictions against the references.
wer_metric, cer_metric = load_metric("wer"), load_metric("cer")
scoring_inputs = dict(predictions=transcriptions, references=gold_transcriptions)
wer = wer_metric.compute(**scoring_inputs)
cer = cer_metric.compute(**scoring_inputs)
(wer, cer)

In [None]:
# Display the first predicted transcription.
transcriptions[0]

In [None]:
# Display the first reference transcription (after preprocessing) for comparison.
gold_transcriptions[0]