# training ver2.0

### import & setup

In [None]:
# Log in to the Hugging Face Hub; later cells call push_to_hub on the tokenizer.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import yaml
# Parse the experiment settings once; the rest of the notebook reads them
# through `config[...]` lookups.
with open('./config.yml', 'rb') as yml:
    config = yaml.safe_load(yml)

In [None]:
# https://note.mjunya.com/posts/2021-12-13-multi-gpu-order/
import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]=config['CUDA_VISIBLE_DEVICES']
!echo ${CUDA_VISIBLE_DEVICES}

import torch
for i in range(torch.cuda.device_count()):
    info = torch.cuda.get_device_properties(i)
    print(f"CUDA:{i} {info.name}, {info.total_memory / 1024 ** 2}MB")

print("------------------------------")
print(f"version: {torch.__version__}")
print(f"available: {torch.cuda.is_available()}")
print(f"count: {torch.cuda.device_count()}")
for i in range(0,torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"GPU {i}: {torch.cuda.get_device_capability(i)}")
print(f"default: {torch.cuda.current_device()}")

In [None]:
import torch.nn as nn
import torchaudio, datasets, warnings
from datasets import load_dataset, load_metric, Audio
import pandas as pd
import numpy as np
warnings.filterwarnings('ignore')

In [None]:
# Pull the settings used throughout the notebook out of the config mapping.
repo_name=config['repo_name']
# `target` names the label column of the CSV files and is also part of the
# CSV / token file names below.
target=config['target']
lr=config['lr']
TRAIN_ALL_WEIGHTS=config['TRAIN_ALL_WEIGHTS']
num_train_epochs=config['num_train_epochs']
per_device_train_batch_size=config['per_device_train_batch_size']
# Side effect: sets the global cuDNN benchmark flag from the config.
torch.backends.cudnn.benchmark=config['torch.backends.cudnn.benchmark']
# Sampling rate handed to the feature extractor and processor further down.
sr=config['sr']
train_csv='./datasets/train_'+target+'.csv'
val_csv='./datasets/val_'+target+'.csv'

In [None]:
# Read only the 'path' and label columns of each CSV.
# NOTE: both calls register their file under the split name "train", so the
# validation rows are accessed as val['train'] in later cells.
train= datasets.load_dataset("csv", data_files={"train":[train_csv]},usecols=['path',target],num_proc=config['num_proc'])
val=datasets.load_dataset("csv", data_files={"train":[val_csv]},usecols=['path',target],num_proc=config['num_proc'])

In [None]:
import random
from IPython.display import display, HTML

def show_random(dataset, num_examples=10):
    """Display ``num_examples`` distinct, randomly chosen rows of ``dataset``.

    Args:
        dataset: Object supporting ``len()`` and indexing with a list of
            integer positions (here: a split of the loaded dataset).
        num_examples: Number of different rows to show.

    Raises:
        AssertionError: If ``num_examples`` exceeds ``len(dataset)``.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws without replacement in a single call, replacing the
    # retry-until-unique loop (which slows down as num_examples nears len(dataset)).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

show_random(train['train'],2)

### make token list

In [None]:
# https://engineers.ntt.com/entry/2021/12/20/172148

def extract_token(batch):
    """Collect the distinct characters found in a batch of labels.

    All entries of ``batch[target]`` are joined with single spaces; the
    result is returned together with the list of its unique characters.
    Both values are wrapped in one-element lists.
    """
    joined = " ".join(batch[target])
    unique_chars = [*set(joined)]
    return {"vocab": [unique_chars], "all_text": [joined]}

In [None]:
# batch_size=-1 passes each split to extract_token as one single batch, so
# row 0 of the result holds the character list of the whole split.
vocab_train=train.map(extract_token,batched=True,batch_size=-1,keep_in_memory=True,remove_columns=train.column_names['train'])
vocab_val= val.map(extract_token, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=val.column_names['train'])
# Sort the merged character set: iteration order of a set of strings varies
# between interpreter runs (hash randomisation), which would otherwise assign
# different token ids every time the notebook is re-executed.
vocab_list= sorted(set(vocab_train["train"]["vocab"][0]) | set(vocab_val["train"]["vocab"][0]))
# Map each character to its index in the sorted list.
vocab_dict= {v: k for k, v in enumerate(vocab_list)}
print(len(vocab_dict))
vocab_dict

In [None]:
# Re-key the space character as "|" while keeping its id; "|" is passed as the
# word delimiter token when the tokenizer is built.
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
# Append the unknown token and "pau" (used as the pad token) at the next free ids.
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["pau"] = len(vocab_dict) # must match the notation used by the dataset / OpenJTalk
len(vocab_dict)

In [None]:
import json
# Persist the token -> id mapping; the tokenizer cell reads this file back.
with open(f'./token_{target}.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

### ASR setup

In [None]:
#from transformers import Wav2Vec2PhonemeCTCTokenizer
#tokenizer = Wav2Vec2PhonemeCTCTokenizer(vocab_file=f'{exp_dir}token_{tgt}.json', unk_token="[UNK]", pad_token="pau",do_phonemize=False, word_delimiter_token="|", phone_delimiter_token="|", phonemizer_lang="ja", phonemizer_backend='espeak')
#tokenizer.push_to_hub(repo_name)

In [None]:
from transformers import Wav2Vec2CTCTokenizer
# Build the CTC tokenizer from the vocabulary file written above; "pau" is the
# padding token and "|" the word delimiter.
tokenizer = Wav2Vec2CTCTokenizer(vocab_file=f'./token_{target}.json', unk_token="[UNK]", pad_token="pau", word_delimiter_token="|")
# Upload the tokenizer to the Hub repository named in the config.
tokenizer.push_to_hub(repo_name)

In [None]:
from transformers import Wav2Vec2FeatureExtractor
# Single-channel feature extractor at the configured sampling rate, padding
# with 0.0, with normalisation and attention masks enabled.
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=sr, padding_value=0.0, do_normalize=True, return_attention_mask=True)
feature_extractor
# Author's note: return_attention_mask=True for "large" checkpoints, False for "base" ones.

In [None]:
from transformers import Wav2Vec2Processor
# Bundle feature extractor and tokenizer; later cells use `processor` for both
# audio inputs and label text.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor

### preprocess audio data

In [None]:
def path2array(batch):
    """Load the WAV file referenced by ``batch['path']`` into the example.

    The waveform is stored under ``"audio_array"`` and its sampling rate under
    ``"sampling_rate"``. Audio whose rate differs from the notebook-wide ``sr``
    is resampled to ``sr`` first, because the processor is later always told
    that the data is sampled at ``sr``.

    Args:
        batch: A single dataset example containing a ``'path'`` entry.

    Returns:
        The same example with ``"audio_array"`` and ``"sampling_rate"`` added.
    """
    array, rate = torchaudio.load(filepath=batch['path'],format='wav')
    if rate != sr:
        # Without this, a file recorded at another rate would be fed to the
        # model unchanged while being declared as `sr` Hz.
        array = torchaudio.functional.resample(array, orig_freq=rate, new_freq=sr)
        rate = sr
    batch["audio_array"]= array
    batch["sampling_rate"] =rate
    return batch

In [None]:
%%time
# Decode every audio file into the dataset, using the configured number of
# worker processes.
train=train.map(path2array,num_proc=config['num_proc'])
val=val.map(path2array,num_proc=config['num_proc'])

In [None]:
# The file path and per-file sampling rate are not used after decoding.
train=train.remove_columns(['path','sampling_rate'])
val=val.remove_columns(['path','sampling_rate'])

In [None]:
# import itertools
# def array_dim(batch):
#     batch["array_dim"]=len(batch["audio_array"])
#     batch["1Darray"]= list(itertools.chain.from_iterable(batch["audio_array"]))
#     batch["sample"]=len(batch["1Darray"])
#     return batch

In [None]:
# b=train['train'][3]['audio_array'] #list,1
# c=processor(b, sampling_rate=sr)   #transformers.feature_extraction_utils.BatchFeature,2
# d=processor(b, sampling_rate=sr).input_values[0] #numpy.ndarray,8661
# e=processor(b, sampling_rate=sr).input_values    #list,1
# f=processor(b, sampling_rate=sr).input_values[-1] #numpy.ndarray,8661
# g=processor(b, sampling_rate=sr).attention_mask #list,1 all=1

In [None]:
# no change name ["input_values"],["labels"]

def prepare_dataset(batch):
    """Turn one example into model inputs and CTC label ids.

    The waveform in ``batch["audio_array"]`` is run through the processor and
    the first entry of the resulting ``input_values`` is stored as
    ``"input_values"``; the text in ``batch[target]`` is encoded in target
    mode and stored as ``"labels"``.
    """
    audio_features = processor(batch["audio_array"], sampling_rate=sr)
    batch["input_values"] = audio_features.input_values[0]

    # Inside this context the processor encodes label text.
    with processor.as_target_processor():
        encoded_text = processor(batch[target])
        batch["labels"] = encoded_text.input_ids

    return batch

In [None]:
%%time
# Convert both splits to model-ready features; the original columns are
# dropped so only what prepare_dataset adds remains.
train_prepared=train.map(prepare_dataset,remove_columns=train.column_names["train"],num_proc=config['num_proc'])
val_prepared=val.map(prepare_dataset,remove_columns=val.column_names["train"],num_proc=config['num_proc'])

In [None]:
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Collate CTC training examples into a single padded batch.

    Audio inputs and label ids have different lengths and are padded by
    different parts of the processor, so the two are padded separately and
    merged afterwards.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            Processor whose ``pad`` method is used for inputs and labels.
        padding (:obj:`bool` or :obj:`str`, `optional`, defaults to :obj:`True`):
            Forwarded as ``padding`` to both ``pad`` calls.
        max_length (:obj:`int`, `optional`):
            Forwarded as ``max_length`` when padding ``input_values``.
        max_length_labels (:obj:`int`, `optional`):
            Forwarded as ``max_length`` when padding the labels.
        pad_to_multiple_of (:obj:`int`, `optional`):
            Forwarded as ``pad_to_multiple_of`` when padding ``input_values``.
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            Forwarded as ``pad_to_multiple_of`` when padding the labels.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Separate audio and labels: each side goes through its own pad call.
        audio_items = []
        label_items = []
        for example in features:
            audio_items.append({"input_values": example["input_values"]})
            label_items.append({"input_ids": example["labels"]})

        batch = self.processor.pad(
            audio_items,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(
                label_items,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Positions whose attention mask is not 1 are padding; mark them with
        # -100 so they are ignored by the loss.
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [None]:
from evaluate import load
import jiwer
# Load the metrics through `evaluate.load`, which this cell already imports,
# instead of the deprecated `datasets.load_metric`.
wer_metric = load("wer")
cer_metric = load('cer')

def compute_metrics(pred):
    """Compute word and character error rate for one evaluation pass.

    Args:
        pred: Prediction object exposing ``predictions`` (logits) and
            ``label_ids`` (reference ids, with -100 marking ignored positions).

    Returns:
        dict with the keys ``"wer"`` and ``"cer"``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Swap the -100 loss mask for the pad token id so the labels can be decoded.
    # np.where builds a new array, leaving the caller's pred.label_ids untouched.
    label_ids = np.where(
        pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids
    )

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}

In [None]:
from transformers import Wav2Vec2ForCTC

# Load the pretrained checkpoint named in the config. Regularisation settings
# come from the config; pad token id and vocabulary size come from the
# tokenizer built earlier in this notebook.
model = Wav2Vec2ForCTC.from_pretrained(
    pretrained_model_name_or_path=config['pretrained_name'],
    attention_dropout=config['attention_dropout'],
    hidden_dropout=config['hidden_dropout'],
    feat_proj_dropout=config['feat_proj_dropout'],
    mask_time_prob=config['mask_time_prob'],
    layerdrop=config['layerdrop'],
    ctc_loss_reduction=config['ctc_loss_reduction'], 
    pad_token_id=processor.tokenizer.pad_token_id,
    diversity_loss_weight=config['diversity_loss_weight'],
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=config['ignore_mismatched_sizes'],
)
#model.lm_head = nn.Linear(1024,len(processor.tokenizer))

In [None]:
if TRAIN_ALL_WEIGHTS:
    # Make every parameter trainable.
    for param in model.parameters():
        param.requires_grad = True
else:
    # `freeze_feature_extractor` is the deprecated name of
    # `freeze_feature_encoder`; prefer the new name when the installed
    # transformers version provides it and fall back otherwise.
    if hasattr(model, "freeze_feature_encoder"):
        model.freeze_feature_encoder()
    else:
        model.freeze_feature_extractor()
#model.freeze_feature_extractor()    # use this one for jdrt?

In [None]:
# Number of scalar parameters that will receive gradients.
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
pytorch_total_params

In [None]:
# Per-device number of examples per optimizer step that gradient accumulation
# below aims for.
effective_batch_size = 32
# `train` holds named splits, so len(train) would count splits, not examples;
# the example count has to come from the 'train' split itself.
num_train_examples = len(train['train'])
# NOTE: these two values are informational only; TrainingArguments below takes
# its warmup from config['warmup_steps'].
warmup_steps = 5 * num_train_examples // effective_batch_size
num_total_steps = num_train_epochs * num_train_examples // effective_batch_size

from transformers import TrainingArguments
training_args = TrainingArguments(
#input-output
  output_dir='./'+repo_name,
  logging_dir="./"+"logs",
  push_to_hub=config['push_to_hub'],
  save_total_limit=config['save_total_limit'],
  seed=config['seed'],
#batch
  per_device_train_batch_size=per_device_train_batch_size,
  per_device_eval_batch_size=per_device_train_batch_size,
  evaluation_strategy=config['evaluation_strategy'],
  save_strategy=config['save_strategy'],
  logging_steps=config['logging_steps'],
  num_train_epochs=num_train_epochs,
  #eval_steps=config['eval_steps],
#lr
  learning_rate=lr,
  lr_scheduler_type=config['lr_scheduler_type'],
  weight_decay=config['weight_decay'],
  warmup_steps=config['warmup_steps'], # the scheduler raises the LR for this many steps, then lowers it
#tokens
  group_by_length=config['group_by_length'],
  prediction_loss_only=config['prediction_loss_only'],

#faster
  dataloader_num_workers=os.cpu_count(),
  fp16=config['fp16'],
  fp16_full_eval=config['fp16_full_eval'],
  gradient_checkpointing=config['gradient_checkpointing'],
  # Integer division yields 0 once the per-device batch exceeds the target;
  # clamp so the accumulation count is always at least 1.
  gradient_accumulation_steps=max(1, effective_batch_size // per_device_train_batch_size),
)
training_args

In [None]:
from transformers import Trainer
from transformers import EarlyStoppingCallback

# Wire model, collator, metrics and the prepared splits together.
# The validation rows live under the 'train' key of val_prepared.
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_prepared['train'],
    eval_dataset=val_prepared['train'],
    tokenizer=processor.feature_extractor,
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],  # early stop after 3 epochs without improvement
)
trainer

In [None]:
import gc
# Run a garbage-collection pass before training starts.
gc.collect()

In [None]:
%%time
# Run the fine-tuning loop configured by training_args.
trainer.train()

In [None]:
# Persist the trainer state and the trained model.
trainer.save_state()
trainer.save_model()