In [None]:
# Log in to the Hugging Face Hub interactively; the push_to_hub calls further
# down in this notebook upload to the Hub.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# `!export ...` runs in a throwaway subshell, so the variable never reaches the
# kernel process. Set it on os.environ instead; this cell runs before the cell
# that imports `datasets`.
import os

os.environ["HF_DATASETS_CACHE"] = '/path/to/save/cache'

In [None]:
# Used as the Hub repository name for the tokenizer and as the Trainer output_dir.
repo_name = "wav2vec2-xls-r-300m_phone-mfa_korean"

### Import related libraries

In [None]:
import glob
import json
import jiwer
import torch
import librosa
import numpy as np

import transformers
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2PhonemeCTCTokenizer, Wav2Vec2ForCTC, \
TrainingArguments, Trainer, EarlyStoppingCallback

import datasets
from datasets import load_dataset, load_metric

### Data Preparation in JSON Format

In [None]:
from tqdm import tqdm

def prepare_json_format(data_name, data_path):
    """Write a JSON-lines manifest describing every wav file matched by a glob.

    Files are processed in sorted path order. Each output line is a JSON object
    with the keys ``fname`` (file name without directory or extension), ``path``
    (the path as returned by ``glob``), ``sampling_rate`` (always 16000) and
    ``text`` (first line of the sibling ``<name>.mfa.txt`` file, stripped; an
    empty transcript file yields an empty string).

    Args:
        data_name: Path of the manifest file to create (overwritten if present).
        data_path: Glob pattern selecting the wav files.

    Raises:
        OSError: If the manifest or one of the transcript files cannot be opened.
    """
    wav_list = sorted(glob.glob(data_path))
    # Context managers guarantee the handles are closed even when a transcript
    # is missing half-way through the loop.
    with open(data_name, 'w', encoding='utf-8') as manifest:
        for wav in tqdm(wav_list, ncols=40):
            fname = wav.rsplit('/', 1)[-1].rsplit('.', 1)[0]
            transcript_path = wav.rsplit('.', 1)[0] + '.mfa.txt'
            with open(transcript_path, 'r', encoding='utf-8') as transcript:
                text = transcript.readline().strip()
            # json.dumps escapes quotes and backslashes that hand-built
            # formatting would emit verbatim, producing an unparsable line.
            record = {"fname": fname, "path": wav, "sampling_rate": 16000, "text": text}
            manifest.write(json.dumps(record, ensure_ascii=False) + "\n")

In [None]:
# Manifest files: written by prepare_json_format and read back by load_dataset.
tr_data_name = './path/to/data/train.json'
te_data_name = './path/to/data/test.json'

In [None]:
# Glob patterns selecting the training and test recordings.
tr_data_path = "/path/to/data/train/*.wav"
te_data_path = "/path/to/data/test/*.wav"

# Write the train manifest first, then the test manifest.
for manifest_name, wav_pattern in ((tr_data_name, tr_data_path), (te_data_name, te_data_path)):
    prepare_json_format(manifest_name, wav_pattern)

In [None]:
# Read both manifests with the 'json' loader; the splits are named 'train' and 'test'.
ds = load_dataset(
    'json',
    data_files={'train': tr_data_name, 'test': te_data_name},
    cache_dir='/path/to/save/cache',
)

In [None]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Display a random selection of distinct rows from ``dataset`` as an HTML table.

    Args:
        dataset: Object supporting ``len()`` and indexing with a list of row
            indices; the indexed result is passed to ``pd.DataFrame``.
        num_examples: Number of distinct rows to show.

    Raises:
        AssertionError: If ``num_examples`` exceeds the number of rows.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws without replacement, so no retry-on-duplicate loop is needed.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

# Preview random training rows with the fname, path and sampling_rate columns dropped.
show_random_elements(ds['train'].remove_columns(['fname','path', 'sampling_rate']))

In [None]:
def extract_all_chars(batch):
    """Collect the distinct whitespace-separated tokens of a batch of transcripts.

    The entries of ``batch['text']`` are joined with single spaces. The result
    holds the list of unique tokens under ``"vocab"`` and the joined string
    under ``"all_text"``, each wrapped in a one-element list.
    """
    joined_text = " ".join(batch['text'])
    unique_tokens = list(set(joined_text.split()))
    return {"vocab": [unique_tokens], "all_text": [joined_text]}

# batch_size=-1 is passed so that the whole split is handled in one batch
# (extract_all_chars returns one vocab list per call).
vocabs = ds.map(
    extract_all_chars,
    batched=True,
    batch_size=-1,
    keep_in_memory=True,
    remove_columns=ds.column_names['train'],
)
# Merge the train and test symbol sets and number the symbols in sorted order.
vocab_list = list(set(vocabs['train']['vocab'][0]) | set(vocabs['test']['vocab'][0]))
vocab_dict = {symbol: index for index, symbol in enumerate(sorted(vocab_list))}
# Append the "|" delimiter and the two special tokens after the sorted symbols.
for special_token in ("|", "[UNK]", "[PAD]"):
    vocab_dict[special_token] = len(vocab_dict)

In [None]:
import os

vocab_name = "./vocab/vocab.json"

# open(..., 'w') does not create missing parent directories, so make sure the
# target directory exists before writing the vocabulary.
os.makedirs(os.path.dirname(vocab_name), exist_ok=True)
with open(vocab_name, 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [None]:
from transformers import Wav2Vec2PhonemeCTCTokenizer
# Build the tokenizer from the vocab.json written to ./vocab/ above.
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained(
    "./vocab/",
    pad_token="[PAD]",
    unk_token="[UNK]",
    phone_delimiter_token="|",
    do_phonemize=False,
    cache_dir='/path/to/save/cache',
)

In [None]:
# Upload the tokenizer files to the Hub repository named by repo_name.
tokenizer.push_to_hub(repo_name)

In [None]:
# No cache_dir here: this constructor builds the feature extractor from the
# given values and downloads nothing, so a cache path has no use. Passing the
# machine-specific path risks it being stored on the object and written out
# with the saved feature-extractor config.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

In [None]:
# Bundle the feature extractor and the tokenizer into a single processor object.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
def prep_dataset(batch):
    """Turn one manifest row into model inputs and label ids.

    The audio at ``batch['path']`` is loaded with librosa at 16 kHz and passed
    through ``processor`` to obtain ``input_values``; ``input_length`` is the
    length of that result and ``labels`` are the ids the processor returns for
    ``batch['text']`` inside ``as_target_processor``. The row is updated in
    place and returned.
    """
    waveform, _ = librosa.load(batch['path'], sr=16000)
    features = processor(waveform, sampling_rate=batch['sampling_rate'])
    batch['input_values'] = features.input_values[0]
    batch['input_length'] = len(batch['input_values'])

    # The transcript is encoded while the processor is in target mode.
    with processor.as_target_processor():
        encoded_text = processor(batch["text"])
        batch["labels"] = encoded_text.input_ids

    return batch

In [None]:
# Apply prep_dataset with 4 worker processes; the original manifest columns are
# dropped, leaving what prep_dataset adds (input_values, input_length, labels).
ds = ds.map(prep_dataset, remove_columns=ds.column_names['train'], num_proc=4)

In [None]:
max_input_length_in_sec = 12.0
# Keep only training examples whose length in samples is below the limit
# (seconds multiplied by the feature extractor's sampling rate).
ds['train'] = ds['train'].filter(
    lambda n_samples: n_samples < max_input_length_in_sec * processor.feature_extractor.sampling_rate,
    input_columns=["input_length"],
)

In [None]:
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """Pad a list of examples into one batch of audio inputs and labels.

    Inputs and labels are padded by two separate ``processor.pad`` calls, each
    with its own ``max_length`` / ``pad_to_multiple_of`` setting; the label call
    is made inside ``processor.as_target_processor()``. Padded label positions
    are replaced by -100.
    """

    # Processor whose ``pad`` method is used for both inputs and labels.
    processor: Wav2Vec2Processor
    # Padding strategy forwarded to both ``pad`` calls.
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` (dicts with ``input_values`` and ``labels``) into a padded batch."""
        audio_items = []
        label_items = []
        for example in features:
            audio_items.append({"input_values": example["input_values"]})
            label_items.append({"input_ids": example["labels"]})

        batch = self.processor.pad(
            audio_items,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(
                label_items,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Positions whose attention mask is not 1 are padding; mark them with -100.
        is_padding = padded_labels.attention_mask.ne(1)
        batch['labels'] = padded_labels["input_ids"].masked_fill(is_padding, -100)

        return batch

In [None]:
# Collator instance handed to the trainer below; padding=True is forwarded to processor.pad.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [None]:
# Load the 'wer' metric; compute_metrics reports its value under the key "PER".
wer_metric = load_metric('wer', cache_dir='/path/to/save/cache')

In [None]:
def compute_metrics(pred):
    """Score greedy predictions against the reference labels.

    Args:
        pred: Object with ``predictions`` (logits; the argmax is taken over the
            last axis) and ``label_ids`` (reference ids, where -100 marks
            positions to be replaced by the pad token before decoding).

    Returns:
        dict: ``{"PER": value}`` where ``value`` is the result of
        ``wer_metric.compute`` on the decoded predictions and references.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Build a new array rather than assigning into pred.label_ids, so the
    # caller's labels keep their -100 markers.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # References are decoded with token grouping disabled.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"PER": wer}

In [None]:
# Load the pretrained XLS-R checkpoint as a CTC model; the vocabulary size and
# pad token id are taken from the tokenizer built above.
model = Wav2Vec2ForCTC.from_pretrained(
    'facebook/wav2vec2-xls-r-300m',
    ctc_loss_reduction='mean',
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    cache_dir='/path/to/save/cache'
)

In [None]:
# Freeze the model's feature encoder before training (presumably so that only
# the remaining layers are updated; confirm against the model documentation).
model.freeze_feature_encoder()

In [None]:
# Training configuration. Evaluation, logging and checkpointing all happen once
# per epoch. With a per-device batch of 8 and 2 accumulation steps, each
# optimizer step covers 16 examples per device.
training_args = TrainingArguments(
    output_dir=repo_name,
    group_by_length=True,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    # evaluation_strategy="steps", ## comment out when only training
    evaluation_strategy="epoch",  ## Choose either "steps" or "epoch"
    # logging_strategy="steps",
    logging_strategy="epoch",  ## Choose either "steps" or "epoch"
    num_train_epochs=20,
    fp16=True,
    gradient_checkpointing=False,
    # save_strategy="steps",
    save_strategy="epoch",
    # save_steps=1000,  ## Only when save_strategy is "steps"
    # eval_steps=1000, ## Only when evaluation_strategy is "steps" ## comment out when only training
    # logging_steps=1000, ## Only when logging_strategy is "steps"
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_ratio=0.2,
    # warmup_steps=750, ##5000,
    save_total_limit=10,
    load_best_model_at_end=True,
    # no_cuda=True,
)

In [None]:
class MyTrainer(Trainer):
    """Trainer that adds the current learning rate to every logged record."""

    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        """Insert ``learning_rate`` into ``logs`` and delegate to ``Trainer.log``.

        Extra positional and keyword arguments are forwarded unchanged, so the
        override keeps working if the base class calls ``log`` with more
        arguments than ``logs``.
        """
        logs['learning_rate'] = self._get_learning_rate()
        super().log(logs, *args, **kwargs)

In [None]:
# Assemble the trainer. Training stops early through EarlyStoppingCallback with
# a patience of 5.
trainer = MyTrainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=ds['train'],
    eval_dataset=ds['test'], ## comment out when only training
    # The feature extractor (not the text tokenizer) is passed in the tokenizer slot.
    tokenizer=processor.feature_extractor,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=5)]
)

In [None]:
# Start fine-tuning with the configuration in training_args.
trainer.train()

In [None]:
# Upload the training result to the Hugging Face Hub.
trainer.push_to_hub()

In [None]:
# Save the tokenizer files into a checkpoint-<step> directory.
# NOTE(review): 20250 is presumably the step of a checkpoint from a specific
# run; set it to match your own checkpoint.
checkpoint_step_num = 20250
tokenizer.save_pretrained('/path/to/tokenizer/model/{}/checkpoint-{}'.format(repo_name, checkpoint_step_num))

In [None]:
# push_to_hub needs the target repository; use the same repo_name as the
# earlier tokenizer upload.
tokenizer.push_to_hub(repo_name)