In [None]:
!apt update
!apt install -y sox
!apt install -y nodejs
!apt install -y npm

!pip install --upgrade pip torch torchvision torchaudio pandas numpy~=1.19.2 sklearn transformers datasets ipywidgets matplotlib jiwer seaborn unidecode librosa soundfile torchaudio tqdm

In [2]:
!jupyter nbextension enable --py widgetsnbextension
!jupyter labextension install @jupyter-widgets/jupyterlab-manager

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m
Building jupyterlab assets (build:prod:minimize)


In [5]:
!export HF_DATASETS_CACHE="/workspace/persistent/ASR/cache"
!export TRANSFORMERS_CACHE="/workspace/persistent/ASR/cache"

import os
import tempfile

cache_dir = "/workspace/persistent/ASR/cache"
temp_dir = "/workspace/persistent/ASR/temp"

os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_DATASETS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['TMPDIR'] = temp_dir
os.environ['TEMP'] = temp_dir
os.environ['TMP'] = temp_dir

print(tempfile.gettempdir()) # prints the current temporary directory

/workspace/persistent/ASR/temp


In [6]:
from datasets import load_dataset, load_metric, ClassLabel, concatenate_datasets, Dataset
import torch
import gc
import random
import pandas as pd
from IPython.display import display, HTML
import IPython.display as ipd
import numpy as np
import random
import json
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
import torchaudio
import librosa

gc.collect()
torch.cuda.empty_cache()

In [7]:
data = load_dataset('csv', data_files={
    'train': 'data/train_data_strip.csv', 
    'valid': 'data/dev_data_strip.csv',
    'test': 'data/test_data_strip.csv'
}, cache_dir=cache_dir)

data = data.remove_columns(["wav_filesize"])
data = data.rename_column("transcript", "sentence")
data = data.rename_column("wav_filename", "path")

def correct_path(batch):
    batch['path'] = batch['path'].replace("../", "")
    return batch

data = data.map(correct_path)

common_voice_train = data['train']
common_voice_dev = data['valid']
common_voice_test = data['test']

mcv_train = load_dataset("common_voice", "nl", split="train+validation", cache_dir=cache_dir)
mcv_test = load_dataset("common_voice", "nl", split="test", cache_dir=cache_dir)

mcv_train = mcv_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
mcv_test = mcv_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

display(data['train'].features)
display(mcv_train.features)
display(data['test'].features)
display(mcv_test.features)

common_voice_train = concatenate_datasets([mcv_train, data['train']])
common_voice_dev = data['valid']
common_voice_test = concatenate_datasets([mcv_test, data['test']])

display(common_voice_train)
display(common_voice_dev)
display(common_voice_test)

Using custom data configuration default-e5deadd8e45b4fa8
Reusing dataset csv (/workspace/persistent/ASR/cache/csv/default-e5deadd8e45b4fa8/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /workspace/persistent/ASR/cache/csv/default-e5deadd8e45b4fa8/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff/cache-6ce973b1c2a91f00.arrow
Loading cached processed dataset at /workspace/persistent/ASR/cache/csv/default-e5deadd8e45b4fa8/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff/cache-b4f67845cc524309.arrow
Loading cached processed dataset at /workspace/persistent/ASR/cache/csv/default-e5deadd8e45b4fa8/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff/cache-b1b48607cce2c527.arrow
Reusing dataset common_voice (/workspace/persistent/ASR/cache/common_voice/nl/6.1.0/078d412587e9efeb0ae2e574da99c31e18844c496008d53dc5c60f4159ed639b)
Reusing dataset common_voice (/workspace/persistent/ASR/cache/common_voice/nl/6.1.0/078d412587e9efeb0ae2e574da99c31e18844c496008d53dc5c60f4159ed639b)


{'path': Value(dtype='string', id=None),
 'sentence': Value(dtype='string', id=None)}

{'path': Value(dtype='string', id=None),
 'sentence': Value(dtype='string', id=None)}

{'path': Value(dtype='string', id=None),
 'sentence': Value(dtype='string', id=None)}

{'path': Value(dtype='string', id=None),
 'sentence': Value(dtype='string', id=None)}

Dataset({
    features: ['path', 'sentence'],
    num_rows: 82461
})

Dataset({
    features: ['path', 'sentence'],
    num_rows: 17016
})

Dataset({
    features: ['path', 'sentence'],
    num_rows: 26978
})

In [8]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

show_random_elements(common_voice_train)

Unnamed: 0,path,sentence
0,CGN_2.0.3/data/audio/wav/comp-k/nl/fn002004(004).wav,op de grachten is een nieuw vaarregime ingesteld alles mag maar n kant op dobberen
1,CGN_2.0.3/data/audio/wav/comp-o/nl/fn001497(003).wav,dit werd gedefinieerd als families waarvan minstens drie generaties voor zeventien vijfennegentig hoge functies hadden bekleed
2,/workspace/persistent/ASR/cache/downloads/extracted/361bdee0982c582d7b4ccf5271db67240ab004ad9eae17e2e463d1d7ddafe9c3/cv-corpus-6.1-2020-12-11/nl/clips/common_voice_nl_19503026.mp3,De sfeer op hun departement was vijandig.
3,CGN_2.0.3/data/audio/wav/comp-h/nl/fn009060(018).wav,en betogen schrijven da 's lastig vervelend
4,CGN_2.0.3/data/audio/wav/comp-o/vl/fv800558(010).wav,ik zet dan heel dynamische muziek keihard op de gipsy kings bijvoorbeeld
5,CGN_2.0.3/data/audio/wav/comp-a/vl/fv400667(046).wav,da 's van het n naar het ander en gaan kieken kijken en keuren
6,CGN_2.0.3/data/audio/wav/comp-g/nl/fn000200(058).wav,en dat blijkt uit ook uit overwegingen die door in die die door de juristen naar voren gebracht worden
7,/workspace/persistent/ASR/cache/downloads/extracted/361bdee0982c582d7b4ccf5271db67240ab004ad9eae17e2e463d1d7ddafe9c3/cv-corpus-6.1-2020-12-11/nl/clips/common_voice_nl_23934575.mp3,"Voorzitter, ik heb in mijn verslag een catalogus van te nemen maatregelen opgesomd."
8,CGN_2.0.3/data/audio/wav/comp-g/nl/fn000008(074).wav,wij vinden dus dat daar een harmonisering h dat dat gewenst is
9,/workspace/persistent/ASR/cache/downloads/extracted/361bdee0982c582d7b4ccf5271db67240ab004ad9eae17e2e463d1d7ddafe9c3/cv-corpus-6.1-2020-12-11/nl/clips/common_voice_nl_22916873.mp3,De tekst is unaniem goedgekeurd in de bevoegde commissie.


In [9]:
import re
import unidecode

chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�=&\(\)]'

def remove_special_characters(batch):
    batch["sentence"] = unidecode.unidecode(batch["sentence"])
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).upper() + " "
    return batch

common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

show_random_elements(common_voice_train)

Loading cached processed dataset at /workspace/persistent/ASR/cache/common_voice/nl/6.1.0/078d412587e9efeb0ae2e574da99c31e18844c496008d53dc5c60f4159ed639b/cache-b09b0975e4f03cd8.arrow
Loading cached processed dataset at /workspace/persistent/ASR/cache/common_voice/nl/6.1.0/078d412587e9efeb0ae2e574da99c31e18844c496008d53dc5c60f4159ed639b/cache-b31e40c21e4f4768.arrow


Unnamed: 0,path,sentence
0,CGN_2.0.3/data/audio/wav/comp-o/nl/fn001032(012).wav,VOOR DE JEUGD WAS HET CARNAVALSTIJD OOK DE MEDIA DEELDEN IN DEZE OVERWINNINGSROES WANT ER HEERSTE VOOR HET EERST WEER EEN AANZIENLIJKE MATE VAN PERSVRIJHEID
1,/workspace/persistent/ASR/cache/downloads/extracted/361bdee0982c582d7b4ccf5271db67240ab004ad9eae17e2e463d1d7ddafe9c3/cv-corpus-6.1-2020-12-11/nl/clips/common_voice_nl_24018514.mp3,HET IS MOEILIJK OP DIT PUNT CONTROLE UIT TE OEFENEN
2,CGN_2.0.3/data/audio/wav/comp-f/nl/fn007620(025).wav,UH EN ALS HET NOU ECHT HELEMAAL SCHEEF VERDEELD IS ZOALS WE DAT DAN NOEMEN DAN ZULLEN WAT GAAN VERPLAATSEN
3,CGN_2.0.3/data/audio/wav/comp-c/nl/fn008173(050).wav,V UH WAT IK OP TELETEKST UH HEB GELEZEN DAT BLIJKT DUS OOK ZO TE ZIJN
4,/workspace/persistent/ASR/cache/downloads/extracted/361bdee0982c582d7b4ccf5271db67240ab004ad9eae17e2e463d1d7ddafe9c3/cv-corpus-6.1-2020-12-11/nl/clips/common_voice_nl_23985459.mp3,IK NEEM AAN DAT U DAT PRECIES ZO ZIET ALS IK
5,/workspace/persistent/ASR/cache/downloads/extracted/361bdee0982c582d7b4ccf5271db67240ab004ad9eae17e2e463d1d7ddafe9c3/cv-corpus-6.1-2020-12-11/nl/clips/common_voice_nl_23985337.mp3,WAAR LEIDT DIT ARGUMENT TOE
6,/workspace/persistent/ASR/cache/downloads/extracted/361bdee0982c582d7b4ccf5271db67240ab004ad9eae17e2e463d1d7ddafe9c3/cv-corpus-6.1-2020-12-11/nl/clips/common_voice_nl_23970366.mp3,GEEN VAN DEZE ZAKEN WORDT BEOOGD
7,CGN_2.0.3/data/audio/wav/comp-c/nl/fn008246(029).wav,EN UH NOU MOCHTEN ZE VAN DE WEEK BIEDEN ALS N VAN DE TWAALF UH AANNEMERS OF N VAN DE TWAALF PROJECTONTWIKKELAARS
8,CGN_2.0.3/data/audio/wav/comp-b/nl/fn000255(113).wav,'T UH IK VIND HET 'T VAK TE SAAI OM TE GEVEN IK HEB 'T OVER NEDERLANDS MARCO NIET OVER MUZIEK VOORDAT JE DENKT GAAT IE NOU VERTELLEN OVER MUZIEK DAT DAT SAAI IS
9,CGN_2.0.3/data/audio/wav/comp-o/nl/fn001021(025).wav,DAT IS MOOI ANTWOORDT MA ZE KIJKT MACHTELD ALLEEN EVEN AAN IK ZAL ZO NAAR SCHOOL BELLEN MAAR HET MOET AFGELOPEN ZIJN MACHTELD


In [13]:
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/hubert-large-ls960-ft")



In [14]:
speech_array, sampling_rate = torchaudio.load(common_voice_train[0]["path"])
display(speech_array.numpy())
common_voice_train

array([[0.        , 0.        , 0.        , ..., 0.00230056, 0.00247365,
        0.00260854]], dtype=float32)

Dataset({
    features: ['path', 'sentence'],
    num_rows: 82461
})

In [15]:
from tqdm.notebook import tqdm, trange
import warnings
warnings.filterwarnings('ignore')

def speech_file_to_array_fn(batch):
    batch["target_text"] = batch["sentence"]
    
    try:
        speech_array, sampling_rate = torchaudio.load(batch["path"])
        batch["speech"] = speech_array[0].numpy()
        batch["sampling_rate"] = sampling_rate
        return batch
    except:
        batch["speech"] = np.array([], dtype = np.float32)
        batch["sampling_rate"] = 1
        
    return batch


def speech_map_dataset(dataset):
    target_text = []
    speech = []
    sampling_rate = []
    
    p_bar = tqdm(range(len(dataset)))
    
    for i in p_bar:
        result = speech_file_to_array_fn(common_voice_train[i])
        target_text = target_text + [result["target_text"]]
        speech = speech + [result["speech"]]
        sampling_rate = sampling_rate + [result["sampling_rate"]]

        p_bar.update()
        
    out_ds = Dataset.from_dict({
        "target_text": target_text,
        "speech": speech,
        "sampling_rate": sampling_rate
    })
    
    out_ds.features = dataset.features
    
    return out_ds

# common_voice_train = speech_map_dataset(common_voice_train)

common_voice_train = common_voice_train.map(speech_file_to_array_fn, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)

Loading cached processed dataset at /workspace/persistent/ASR/cache/common_voice/nl/6.1.0/078d412587e9efeb0ae2e574da99c31e18844c496008d53dc5c60f4159ed639b/cache-838e55d3643c53c4.arrow
Loading cached processed dataset at /workspace/persistent/ASR/cache/common_voice/nl/6.1.0/078d412587e9efeb0ae2e574da99c31e18844c496008d53dc5c60f4159ed639b/cache-80edfe99ef465de1.arrow


In [16]:
display(common_voice_train)

common_voice_train = common_voice_train.filter(lambda batch: batch['sampling_rate'] > 1)
common_voice_test = common_voice_test.filter(lambda batch: batch['sampling_rate'] > 1)

Dataset({
    features: ['target_text', 'speech', 'sampling_rate'],
    num_rows: 82461
})

Loading cached processed dataset at /workspace/persistent/ASR/cache/common_voice/nl/6.1.0/078d412587e9efeb0ae2e574da99c31e18844c496008d53dc5c60f4159ed639b/cache-98760aae189af925.arrow
Loading cached processed dataset at /workspace/persistent/ASR/cache/common_voice/nl/6.1.0/078d412587e9efeb0ae2e574da99c31e18844c496008d53dc5c60f4159ed639b/cache-1d394fefeda38c41.arrow


In [17]:
display(common_voice_train)
display(common_voice_train.features)

Dataset({
    features: ['target_text', 'speech', 'sampling_rate'],
    num_rows: 82461
})

{'target_text': Value(dtype='string', id=None),
 'speech': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
 'sampling_rate': Value(dtype='int64', id=None)}

In [18]:
def resample(batch):
    if batch["sampling_rate"] == 16_000:
        return batch
    
    batch["speech"] = librosa.resample(np.asarray(batch["speech"]), batch["sampling_rate"], 16_000)
    batch["sampling_rate"] = 16_000
    return batch


def speech_map_dataset(dataset):
    target_text = []
    speech = []
    sampling_rate = []
    
    p_bar = tqdm(range(len(dataset)))
    
    for i in p_bar:
        result = resample(common_voice_train[i])
        target_text = target_text + [result["target_text"]]
        speech = speech + [result["speech"]]
        sampling_rate = sampling_rate + [result["sampling_rate"]]

        p_bar.update()
    
    dic = {
        "target_text": target_text,
        "speech": speech,
        "sampling_rate": sampling_rate
    }
    
    df = pd.DataFrame.from_dict(dic)
    return Dataset.from_pandas(df, features=common_voice_train.features)


common_voice_test = speech_map_dataset(common_voice_test)
common_voice_train = speech_map_dataset(common_voice_train)

# common_voice_test = common_voice_test.map(resample, num_proc=1)
# common_voice_train = common_voice_train.map(resample, num_proc=1)

  0%|          | 0/26978 [00:00<?, ?it/s]

  0%|          | 0/82461 [00:00<?, ?it/s]

In [19]:
# common_voice_test.to_json("/workspace/persistent/ASR/Training/dataset/test.json")
# common_voice_train.to_json("/workspace/persistent/ASR/Training/dataset/train.json")

In [20]:
# common_voice_train = load_dataset("/workspace/persistent/ASR/Training/dataset/train", cache_dir="/workspace/persistent/ASR/cache")
# common_voice_test = load_dataset("/workspace/persistent/ASR/Training/dataset/test", cache_dir="/workspace/persistent/ASR/cache")

In [21]:
from datasets import IterableDataset

common_voice_test
# tokenized_dataset = common_voice_test.map(lambda x: tokenizer(x["target_text"]))

Dataset({
    features: ['target_text', 'speech', 'sampling_rate'],
    num_rows: 26978
})

In [22]:
rand_int = random.randint(0, len(common_voice_train))

ipd.Audio(data=np.asarray(common_voice_train[rand_int]["speech"]), autoplay=False, rate=16000)

In [23]:
rand_int = random.randint(0, len(common_voice_train))

print("Target text:", common_voice_train[rand_int]["target_text"])
print("Input array shape:", np.asarray(common_voice_train[rand_int]["speech"]).shape)
print("Sampling rate:", common_voice_train[rand_int]["sampling_rate"])

Target text: CRUIJFF IS VOOR HET EERST SINDS DRIE JAAR BIJ DE OPENINGSWEDSTRIJD VAN HET VOETBALSEIZOEN AANWEZIG 'T DUEL IN DE ARENA IS AL BIJNA UITVERKOCHT 
Input array shape: (102000,)
Sampling rate: 16000


In [24]:
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

In [25]:
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [26]:
import tempfile
print(tempfile.gettempdir()) # prints the current temporary directory

/workspace/persistent/ASR/temp


In [27]:
def prepare_dataset(batch):
    # check that all files have the correct sampling rate
    assert (
        len(set(batch["sampling_rate"])) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values

    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, batch_size=8, num_proc=4, batched=True, cache_file_name=cache_dir + "/tokenised_final_train.arrow")
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=4, batched=True, cache_file_name=cache_dir + "/tokenised_final_test.arrow")

In [28]:
from data_collator import DataCollatorCTCWithPadding

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [29]:
wer_metric = load_metric("wer")

Downloading:   0%|          | 0.00/1.95k [00:00<?, ?B/s]

In [30]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [31]:
from transformers import HubertForCTC

model = HubertForCTC.from_pretrained(
    "facebook/hubert-large-ls960-ft",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

Downloading:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

In [32]:
model.freeze_feature_extractor()

In [34]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./hubert-nl-cgn-v2",
  group_by_length=True,
  per_device_train_batch_size=1,
  gradient_accumulation_steps=2,
  evaluation_strategy="epoch",
  save_strategy="epoch",
  num_train_epochs=30,
  fp16=True,
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
)

In [35]:
from transformers.optimization import Adafactor, AdafactorSchedule
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)

In [42]:
from transformers import Trainer

torch.cuda.empty_cache()

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor
)

Using amp fp16 backend


In [None]:
trainer.train()

***** Running training *****
  Num examples = 82461
  Num Epochs = 30
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 2
  Total optimization steps = 1236900


Epoch,Training Loss,Validation Loss,Wer
0,0.9887,0.545425,0.373725
1,0.9028,0.407082,0.310886
2,0.7538,0.364186,0.271829


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

***** Running Evaluation *****
  Num examples = 26978
  Batch size = 8
Saving model checkpoint to ./hubert-nl-cgn-v2/checkpoint-41230
Configuration saved in ./hubert-nl-cgn-v2/checkpoint-41230/config.json
Model weights saved in ./hubert-nl-cgn-v2/checkpoint-41230/pytorch_model.bin
Configuration saved in ./hubert-nl-cgn-v2/checkpoint-41230/preprocessor_config.json
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_win