In [1]:
!pip3 install -r requirements.txt

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://packagecloud.io/github/git-lfs/pypi/simple


[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m

In [5]:
!echo "awe" > password.txt
!sudo -S apt-get install git-lfs < password.txt

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.2).
The following packages were automatically installed and are no longer required:
  libasound2-dev libportaudiocpp0 uuid-dev
Use 'sudo apt autoremove' to remove them.
0 upgraded, 0 newly installed, 0 to remove and 3 not upgraded.


In [None]:
# Standard imports
import os
import json
import torch
import evaluate
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Dict, List, Union

# My own imports
from load_nchlt import load_nchlt
from load_fleurs_asr import load_fleurs_asr
from load_high_quality_tts import load_high_quality_tts
from utils import (
    SR, WRITE_ACCESS_TOKEN,
    clear_cache,
    remove_special_characters_batch
)

# HuggingFace imports
from datasets import Audio
from datasets import load_dataset
from huggingface_hub import login
from transformers import Wav2Vec2ForCTC
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

In [None]:
class ASR_MODEL:
    def __init__(self, repo_name, dataset_name, load_from_hf, push_dataset, push_repo, write_audio):
        self.repo_name = repo_name
        self.dataset_name = dataset_name
        self.load_from_hf = load_from_hf
        self.push_dataset = push_dataset
        self.push_repo = push_repo
        self.write_audio = write_audio
        os.makedirs(repo_name, exist_ok=True)

        self.train_set = None
        self.val_set = None
        self.test_set = None
        self.tokenizer = None
        self.feature_extractor = None
        self.processor = None
        self.data_collator = None
        self.xlsr_model = None
        self.trainer = None

    def load_datasets(self):
        if not self.load_from_hf:
            if not os.path.exists(self.dataset_name):
                # Create dataset by combining 3 datasets into an audiofolder
                csv_entries = []
                if (self.dataset_name == "asr_af"):
                    csv_entries += load_fleurs_asr(only_af=True, write_audio=self.write_audio)
                    csv_entries += load_high_quality_tts(only_af=True, write_audio=self.write_audio)
                    csv_entries += load_nchlt(only_af=True, write_audio=self.write_audio)
                elif (self.dataset_name == "asr_xh"):
                    csv_entries += load_fleurs_asr(only_xh=True, write_audio=self.write_audio)
                    csv_entries += load_high_quality_tts(only_xh=True, write_audio=self.write_audio)
                    csv_entries += load_nchlt(only_xh=True, write_audio=self.write_audio)
                elif (self.dataset_name == "asr_af_xh"):
                    csv_entries += load_fleurs_asr(write_audio=self.write_audio)
                    csv_entries += load_high_quality_tts(write_audio=self.write_audio)
                    csv_entries += load_nchlt(write_audio=self.write_audio)
                metadata = pd.DataFrame(csv_entries, columns=['file_name', 'transcription'])
                metadata.to_csv(path_or_buf=os.path.join(self.dataset_name, "metadata.csv"), sep=",", index=False)

                # Load dataset from audiofolder that you created
                dataset = load_dataset("audiofolder", data_dir=self.dataset_name)

                # Push dataset to huggingface hub
                if self.push_dataset:
                    done = False
                    num_restarts = 0
                    while not done:
                        try:
                            dataset.push_to_hub(f"lucas-meyer/{self.dataset_name}")
                            done = True
                        except Exception as e:
                            num_restarts += 1
                            print(f"Restarting (num restarts: {num_restarts})")
            else:
                # Load dataset from audiofolder that you created
                dataset = load_dataset("audiofolder", data_dir=self.dataset_name)
        else:
            # Load dataset from huggingface hub
            dataset = load_dataset(f"lucas-meyer/{self.dataset_name}") # 31 Minutes !!!

        # Downsample audio to SR = 16000 and init train/val/test sets
        self.train_set = dataset["train"].cast_column("audio", Audio(sampling_rate=SR)).rename_column("transcription", "sentence")
        self.val_set = dataset["validation"].cast_column("audio", Audio(sampling_rate=SR)).rename_column("transcription", "sentence")
        self.test_set = dataset["test"].cast_column("audio", Audio(sampling_rate=SR)).rename_column("transcription", "sentence")
        self.train_set = self.train_set.map(remove_special_characters_batch)
        self.val_set = self.val_set.map(remove_special_characters_batch)
        self.test_set = self.test_set.map(remove_special_characters_batch)
        torch.cuda.empty_cache()

    def extract_all_chars(self, batch):
        all_text = " ".join(batch["sentence"])
        vocab = list(set(all_text))
        return {"vocab": [vocab], "all_text": [all_text]}

    def create_tokenizer(self):
        vocab_train = self.train_set.map(self.extract_all_chars,
                                         batched=True, batch_size=-1,
                                         keep_in_memory=True,
                                         remove_columns=self.train_set.column_names)

        vocab_val = self.val_set.map(self.extract_all_chars,
                                     batched=True, batch_size=-1,
                                     keep_in_memory=True,
                                     remove_columns=self.val_set.column_names)

        vocab_test = self.test_set.map(self.extract_all_chars,
                                       batched=True, batch_size=-1,
                                       keep_in_memory=True,
                                       remove_columns=self.test_set.column_names)

        # Get list for vocab of train/val/test
        vocab_list = list(set(vocab_train["vocab"][0]) |
                        set(vocab_test["vocab"][0]) |
                        set(vocab_val["vocab"][0]))

        # Get dict for vocab of train/val/test
        vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
        vocab_dict["|"] = vocab_dict[" "]
        del vocab_dict[" "]
        vocab_dict["[UNK]"] = len(vocab_dict)
        vocab_dict["[PAD]"] = len(vocab_dict)

        # Save vocabulary file
        with open(os.path.join(self.repo_name, 'vocab.json'), 'w') as vocab_file:
            json.dump(vocab_dict, vocab_file)

        self.tokenizer = Wav2Vec2CTCTokenizer(os.path.join(self.repo_name, 'vocab.json'),
                                              unk_token="[UNK]",
                                              pad_token="[PAD]",
                                              bos_token=None,
                                              eos_token=None,
                                              word_delimiter_token="|")
        if self.push_repo:
            self.tokenizer.push_to_hub(f"lucas-meyer/{self.repo_name}")

    def prepare_dataset(self, batch):
        audio = batch["audio"]
        batch["input_values"] = self.processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
        batch["input_length"] = len(batch["input_values"])
        batch["labels"] = self.processor(text=batch["sentence"]).input_ids
        return batch

    def extract_features(self):
        self.feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1,
                                                     sampling_rate=16000,
                                                     padding_value=0.0,
                                                     do_normalize=True,
                                                     return_attention_mask=True)

        self.processor = Wav2Vec2Processor(feature_extractor=self.feature_extractor,
                                      tokenizer=self.tokenizer)

        self.train_set = self.train_set.map(self.prepare_dataset, remove_columns=self.train_set.column_names)
        self.val_set = self.val_set.map(self.prepare_dataset, remove_columns=self.val_set.column_names)
        self.test_set = self.test_set.map(self.prepare_dataset, remove_columns=self.test_set.column_names)

    def create_data_collator(self):
        self.data_collator = DataCollatorCTCWithPadding(processor=self.processor, padding=True)

    def download_xlsr(self):
        self.xlsr_model = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-xls-r-300m",
            attention_dropout=0.0,
            hidden_dropout=0.0,
            feat_proj_dropout=0.0,
            mask_time_prob=0.05,
            layerdrop=0.0,
            ctc_loss_reduction="mean",
            pad_token_id=self.processor.tokenizer.pad_token_id,
            vocab_size=len(self.processor.tokenizer),
        )
        self.xlsr_model.freeze_feature_encoder()          # Freeze feature exctraction weights
        self.xlsr_model.gradient_checkpointing_enable()   # Enable gradient checkpointing

    def compute_metrics(self, pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)
        pred.label_ids[pred.label_ids == -100] = self.processor.tokenizer.pad_token_id
        pred_str = self.processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = self.processor.batch_decode(pred.label_ids, group_tokens=False)

        wer_metric = evaluate.load("wer")
        wer = wer_metric.compute(predictions=pred_str, references=label_str)
        return {"wer": wer}

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
            sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
            maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
            different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # labels_batch = self.processor.pad(
        #     text=label_features,
        #     padding=self.padding,
        #     return_tensors="pt",
        # )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

In [None]:
login(token=WRITE_ACCESS_TOKEN)

model_name = "wav2vec2-xls-r-300m"
dataset_name = "asr_xh"
model = ASR_MODEL(repo_name=f"{model_name}-{dataset_name}",
                    dataset_name=dataset_name,
                    load_from_hf=True,
                    push_dataset=False,
                    push_repo=True,
                    write_audio=False)

model.load_datasets()
model.create_tokenizer()
model.extract_features()
model.create_data_collator()
model.download_xlsr()

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


Downloading builder script:   0%|          | 0.00/12.6k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/13.3k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/2.27G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/227M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/557M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/2.01M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/254k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/617k [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Downloading af_za.tar.gz ...


100%|██████████| 951M/951M [01:47<00:00, 8.84MiB/s]


File af_za.tar.gz downloaded successfully!

Downloading xh_za.tar.gz ...


100%|██████████| 907M/907M [01:41<00:00, 8.95MiB/s]


File xh_za.tar.gz downloaded successfully!



Resolving data files:   0%|          | 0/5202 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/699 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1472 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/5203 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/700 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Downloading data files:   0%|          | 0/1473 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/19 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/273 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/273 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/273 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/273 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/350 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Map:   0%|          | 0/349 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Restarting (num restarts: 1)


Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/19 [00:00<?, ?it/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/274 [00:00<?, ? examples/s]

Map:   0%|          | 0/273 [00:00<?, ? examples/s]

Map:   0%|          | 0/273 [00:00<?, ? examples/s]

Map:   0%|          | 0/273 [00:00<?, ? examples/s]

Map:   0%|          | 0/273 [00:00<?, ? examples/s]

Map:   0%|          | 0/350 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/349 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Map:   0%|          | 0/295 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/5 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/295 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/294 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/294 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/294 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Map:   0%|          | 0/5202 [00:00<?, ? examples/s]

Map:   0%|          | 0/699 [00:00<?, ? examples/s]

Map:   0%|          | 0/1472 [00:00<?, ? examples/s]

Map:   0%|          | 0/5202 [00:00<?, ? examples/s]

Map:   0%|          | 0/699 [00:00<?, ? examples/s]

Map:   0%|          | 0/1472 [00:00<?, ? examples/s]

Map:   0%|          | 0/5202 [00:00<?, ? examples/s]

Map:   0%|          | 0/699 [00:00<?, ? examples/s]

Map:   0%|          | 0/1472 [00:00<?, ? examples/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.57k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
clear_cache()

batch_size = 4
grad_accum_steps = 3

training_args = TrainingArguments(
    # Output directory (repo name)
    output_dir=model.repo_name,
    overwrite_output_dir=True,

    # do_train=False,
    # do_eval=False,
    # do_predict=False,

    # Evaluate after every eval_steps
    evaluation_strategy="steps",

    # prediction_loss_only = False,
    # eval_delay=0,

    # Batch sizes and num accumulation steps
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size//4,
    gradient_accumulation_steps=grad_accum_steps,
    eval_accumulation_steps=grad_accum_steps,

    # Use small learning rate (since we halved the batch size)
    learning_rate=0.0001,

    # weight_decay = 0.0,
    # adam_beta1 = 0.9,
    # adam_beta2 = 0.999,
    # adam_epsilon = 1e-08,
    # max_grad_norm = 1.0,

    # Train for 20 epochs max
    num_train_epochs=20,
    # max_steps = 100,

    # Learning rate schedule and warmup
    lr_scheduler_type = 'linear',
    warmup_ratio = 0.1,
    warmup_steps = 0,

    # Logging settings
    # log_level = 'passive',
    # log_level_replica = 'warning',
    # log_on_each_node = True,
    # logging_dir = None,
    # logging_strategy = 'steps',
    # logging_first_step = False,
    logging_steps=100,
    # logging_nan_inf_filter = True,
    # save_strategy = 'steps',

    # Save steps
    save_steps=100,
    save_total_limit = 2,
    # save_safetensors = False,
    # save_on_each_node = False,

    # no_cuda = False,
    # use_cpu = False,
    # use_mps_device = False,

    # Set seed
    seed = 42,

    # data_seed = None,
    # jit_mode_eval = False,
    # use_ipex = False,
    # bf16 = False,

    # Use fp16
    fp16=True,

    # fp16_opt_level = 'O1',
    # half_precision_backend = 'auto',
    # bf16_full_eval = False,
    # fp16_full_eval = False,
    # tf32 = None,
    # local_rank = -1,
    # ddp_backend = None,
    # tpu_num_cores = None,
    # tpu_metrics_debug = False,
    # debug = '',
    # dataloader_drop_last = False,

    # Evaluation steps
    eval_steps=100,

    # Num processes for loading data
    # dataloader_num_workers = 0,

    # past_index = -1,
    # run_name = None,
    # disable_tqdm = None,
    # remove_unused_columns = True,
    # label_names = None,

    # Load best model at end
    load_best_model_at_end = True,
    # metric_for_best_model = None, # default is loss
    # greater_is_better = None,

    # ignore_data_skip = False,
    # sharded_ddp = '',
    # fsdp = '',
    # fsdp_min_num_params = 0,
    # fsdp_config = None,
    # fsdp_transformer_layer_cls_to_wrap = None,
    # deepspeed = None,
    # label_smoothing_factor = 0.0,
    # optim = 'adamw_torch',
    # optim_args = None,
    # adafactor = False,

    # Group by length
    group_by_length = False,
    length_column_name = 'length',

    # report_to = None,
    # ddp_find_unused_parameters = None,
    # ddp_bucket_cap_mb = None,
    # ddp_broadcast_buffers = None,
    # dataloader_pin_memory = True,
    # skip_memory_metrics = True,
    # use_legacy_prediction_loop = False,

    # Push to hub
    push_to_hub = True,

    # resume_from_checkpoint = None,
    # hub_model_id = None,
    # hub_strategy = 'every_save',
    # hub_token = None,
    # hub_private_repo = False,
    # hub_always_push = False,

    # Enable gradient checkpointing
    gradient_checkpointing=True,

    # include_inputs_for_metrics = False,
    # fp16_backend = 'auto',
    # push_to_hub_model_id = None,
    # push_to_hub_organization = None,
    # push_to_hub_token = None,
    # mp_parameters = '',
    # auto_find_batch_size = False,
    # full_determinism = False,
    # torchdynamo = None,
    # ray_scope = 'last',
    # ddp_timeout = 1800,

    # torch_compile = False,
    # torch_compile_backend = None,
    # torch_compile_mode = None,
    # dispatch_batches = None,
)

model.trainer = Trainer(
    model=model.xlsr_model,
    data_collator=model.data_collator,
    args=training_args,
    compute_metrics=model.compute_metrics,
    train_dataset=model.train_set,
    eval_dataset=model.val_set,
    tokenizer=model.processor.feature_extractor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

model.trainer.train()
model.trainer.push_to_hub()



Step,Training Loss,Validation Loss,Wer
100,14.8497,7.599672,1.0


Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]



Step,Training Loss,Validation Loss,Wer
100,14.8497,7.599672,1.0
200,5.6526,4.587895,1.0
300,4.1313,3.584044,1.0
400,3.3296,3.118462,1.0
500,3.0911,3.005561,1.0
600,2.7494,2.011232,0.999878
700,1.3887,0.91609,0.866007
800,0.9307,0.719424,0.783193
900,0.771,0.59445,0.720533
1000,0.6298,0.53664,0.691218


