In [1]:
from typing import Dict, List
import numpy as np
import flax
import jax.numpy as jnp
from transformers import (
    BatchEncoding,
    PreTrainedTokenizerBase,
)

from transformers.models.t5.modeling_flax_t5 import shift_tokens_right


@flax.struct.dataclass
class FlaxDataCollatorForT5MLM:
    """
    Data collator used for T5 span-masked language modeling.
    It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
    For more information on how T5 span-masked language modeling works, one can take a look
    at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
    or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        noise_density (:obj:`float`):
            The probability with which to (randomly) mask tokens in the input.
        mean_noise_span_length (:obj:`float`):
            The average span length of the masked tokens.
        input_length (:obj:`int`):
            The expected input length after masking.
        target_length (:obj:`int`):
            The expected target length after masking.
        pad_token_id: (:obj:`int`):
            The pad token id of the model
        decoder_start_token_id: (:obj:`int):
            The decoder start token id of the model
    """

    tokenizer: PreTrainedTokenizerBase
    noise_density: float
    mean_noise_span_length: float
    input_length: int
    target_length: int
    pad_token_id: int
    decoder_start_token_id: int

    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:

        # convert list to dict and tensorize input
        max_len_array = max([len(list(i.values())[0]) for i in examples])
        batch = BatchEncoding(
            {k: np.array([examples[i][k]+[self.pad_token_id] * (max_len_array - len(examples[i][k])) for i in range(len(examples))]) for k, v in examples[0].items()}
        )

        input_ids = batch["input_ids"]
        batch_size, expandend_input_length = input_ids.shape

        mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
        labels_mask = ~mask_indices

        input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
        labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))

        batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
        batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], self.pad_token_id, self.decoder_start_token_id
        )

        return batch

    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
        start_indices[:, 0] = mask_indices[:, 0]

        sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
        sentinel_ids = np.where(sentinel_ids != 0, (self.tokenizer.vocab_size - 1 - sentinel_ids), 0)
        sentinel_ids -= mask_indices - start_indices

        return sentinel_ids

    def filter_input_ids(self, input_ids, sentinel_ids):
        """
        Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
        This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
        """
        batch_size = input_ids.shape[0]

        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
        input_ids = input_ids_full[(input_ids_full >= 0)].reshape((batch_size, -1))
        input_ids = np.concatenate(
            [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
        )
        return input_ids

    def random_spans_noise_mask(self, length):

        """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.

        Args:
            length: an int32 scalar (length of the incoming token sequence)
            noise_density: a float - approximate density of output mask
            mean_noise_span_length: a number

        Returns:
            a boolean tensor with shape [length]
        """

        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add
                up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of sub segments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]


def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size

    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
    batch_idx = np.split(samples_idx, sections_split)
    return batch_idx


def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
    """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .

    Training parameters to avoid padding with random_spans_noise_mask.
    When training a model with random_spans_noise_mask, we would like to set the other
    training hyperparmeters in a way that avoids padding.
    This function helps us compute these hyperparameters.
    We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
    and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
    This function tells us the required number of tokens in the raw example (for split_tokens())
    as well as the length of the encoded targets. Note that this function assumes
    the inputs and targets will have EOS appended and includes that in the reported length.

    Args:
        inputs_length: an integer - desired length of the tokenized inputs sequence
        noise_density: a float
        mean_noise_span_length: a float
    Returns:
        tokens_length: length of original text in tokens
        targets_length: an integer - length in tokens of encoded targets sequence
    """

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        _input_length = num_nonnoise_tokens + num_noise_spans + 1
        _output_length = num_noise_tokens + num_noise_spans + 1
        return _input_length, _output_length

    tokens_length = inputs_length

    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1

    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)

    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    return tokens_length, targets_length


def tokenize_function(examples, tokenizer, text_column_name):
    return tokenizer(examples[text_column_name], return_attention_mask=False)


# Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
def group_texts(examples, expanded_inputs_length):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= expanded_inputs_length:
        total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
        for k, t in concatenated_examples.items()
    }
    return result


In [2]:
# from transformers import (
#     MT5ForConditionalGeneration,
#     T5Tokenizer,
#     AdamW,
#     get_linear_schedule_with_warmup
# )
# import torch
# from datasets import load_dataset
# import os
# import json
# from pathlib import Path
# import numpy as np
# from tqdm import tqdm, trange
# from flax.training.common_utils import shard
# from enum import Enum
# from typing import Optional, Tuple
# from fire import Fire
# from math import floor
# import uuid

# #import mt5_utils


# class mt5PerplexityExperiments:

#     def __init__(
#         self,
#         model_id: Enum = 'google/mt5-base',
#         device: Enum = 'cuda:0',
#     ):
#         self.device = device
#         self.model = MT5ForConditionalGeneration.from_pretrained(model_id).to(device)
#         self.tokenizer = T5Tokenizer.from_pretrained(model_id)
        
#         self.log_dict = {}

#     def get_tokenized_dataset(self, datasets, column_name):
#         max_seq_length = min(self.max_seq_length, self.tokenizer.model_max_length)
#         column_names = datasets[column_name].column_names
#         text_column_name = "text" if "text" in column_names else column_names[0]
        
#         tokenized_datasets = datasets.map(
#             lambda x: tokenize_function(x, tokenizer=self.tokenizer, text_column_name=text_column_name),
#             batched=True,
#             num_proc=self.num_proc,
#             remove_columns=column_names
#         )
#         expanded_inputs_length, targets_length = compute_input_and_target_lengths(
#             inputs_length=self.max_seq_length,
#             noise_density=self.mlm_probability,
#             mean_noise_span_length=self.mean_noise_span_length,
#         )

#         data_collator = FlaxDataCollatorForT5MLM(
#             tokenizer=self.tokenizer,
#             noise_density=self.mlm_probability,
#             mean_noise_span_length=self.mean_noise_span_length,
#             input_length=max_seq_length,
#             target_length=targets_length,
#             pad_token_id=self.model.config.pad_token_id,
#             decoder_start_token_id=self.model.config.decoder_start_token_id,
#         )

#         tokenized_datasets = tokenized_datasets.map(
#             lambda x: group_texts(x, expanded_inputs_length=expanded_inputs_length),
#             batched=True,
#             num_proc=self.num_proc,
#         )
#         return tokenized_datasets, data_collator


#     def training(
#         self,
#         train_valid_dir: os.PathLike,
#         max_dataset_len: int = 500000,
#         train_size: float = 0.9,
#         n_epochs: int = 5,
#         learning_rate: float = 0.005,
#         num_warmup_steps: int = 2000,
#         weight_decay: float = 0.001,
#         betas: Tuple[float, float] = [0.9, 0.999],
#         max_seq_length: int = 256,
#         per_device_batch_size: int = 64,
#         mlm_probability: float = 0.15,
#         mean_noise_span_length: int = 3,
#         num_proc: Optional[int] = None,
#     ):
#         self.max_seq_length = max_seq_length
#         self.per_device_batch_size = per_device_batch_size
#         self.mlm_probability = mlm_probability
#         self.mean_noise_span_length = mean_noise_span_length
#         self.num_proc = num_proc

#         log_params = {
#             "train_valid_dir":train_valid_dir,
#             "train_size":train_size,
#             "n_epochs":n_epochs,
#             "learning_rate":learning_rate,
#             "num_warmup_steps":num_warmup_steps,
#             "weight_decay":weight_decay,
#             "betas":betas,
#             "max_seq_length":max_seq_length,
#             "per_device_batch_size":per_device_batch_size,
#             "mlm_probability":mlm_probability,
#             "mean_noise_span_length":mean_noise_span_length,
#             "num_proc":num_proc
#         }
#         random_seed = uuid.uuid4()
#         save_folder = f'mt5_experiments/training_on_{Path(train_valid_dir).name}/{random_seed}'

#         if not os.path.exists(save_folder):
#             os.makedirs(save_folder)
        
#         params_filename = Path(save_folder, "params.json")
#         log_filename = Path(save_folder, "log_results.txt")
#         with open(params_filename, "w") as outfile:
#             json.dump(log_params, outfile, indent=4)

#         train_val_paths = [str(Path(train_valid_dir, i)) for i in os.listdir(train_valid_dir)]
#         dataset = load_dataset('text', data_files=train_val_paths, split='train')

#         dataset_limit = min(len(dataset), max_dataset_len)
#         data_indices = np.random.choice(len(dataset), dataset_limit)
#         cutted_dataset = dataset.select(data_indices)
#         datasets = cutted_dataset.train_test_split(test_size=1-train_size)
#         column_name = 'train'

#         train_tokenized_datasets, train_data_collator = self.get_tokenized_dataset(datasets, column_name)
#         num_train_samples = len(train_tokenized_datasets[column_name])
#         train_batch_idx = generate_batch_splits(
#             np.arange(num_train_samples),
#             self.per_device_batch_size
#             )
        
#         num_train_steps = len(train_tokenized_datasets["train"]) // self.per_device_batch_size * n_epochs
        
#         optimizer = AdamW(
#             self.model.parameters(),
#             lr=learning_rate,
#             weight_decay = weight_decay,
#             betas = betas
#             )

#         scheduler = get_linear_schedule_with_warmup(
#                 optimizer,
#                 num_warmup_steps=num_warmup_steps,
#                 num_training_steps=num_train_steps
#                 )
        
#         self.log_dict = {"train": [], "val": []}
#         for epoch in trange(n_epochs):
#             # ======================== Training ================================
#             train_losses_epoch = []

#             step = int(len(train_batch_idx) * 0.05)
#             for i, batch_idx in tqdm(enumerate(train_batch_idx), desc='Training...', leave=True, total=len(train_batch_idx)):
#                 self.model.train()
#                 f = open(log_filename, 'a+')
       
#                 samples = [train_tokenized_datasets["train"][int(idx)] for idx in batch_idx]
#                 model_inputs = train_data_collator(samples)
#                 model_inputs = shard(model_inputs.data)

#                 input_ids = torch.LongTensor(model_inputs['input_ids']).to(self.device)
#                 # decoder_input_ids = torch.LongTensor(model_inputs['decoder_input_ids']).to(self.device)
#                 labels = torch.LongTensor(model_inputs['labels']).to(self.device)
                
#                 input_ids_size = input_ids.size()
#                 labels_size = labels.size()
#                 input_ids = input_ids.reshape([input_ids_size[0], input_ids_size[1] * input_ids_size[2]])
#                 labels = labels.reshape([labels_size[0], labels_size[1] * labels_size[2]])
                
#                 optimizer.zero_grad()
                
#                 loss = self.model(
#                     input_ids=torch.squeeze(input_ids, 0),
#                     labels=torch.squeeze(labels, 0)
#                 )
#                 train_losses_epoch.append(loss.loss.item())
#                 loss.loss.backward()
#                 optimizer.step()
#                 scheduler.step()

#                 # ======================== Evaluating ==============================
#                 if i % step == 0 and i > 0:
#                     perp_train = np.exp(np.mean(train_losses_epoch))
#                     train_msg = f'TRAIN ITERATION: {i}\t FOR {train_valid_dir} \t Perplexity = {perp_train}\n'
#                     print(train_msg)
#                     f.write(train_msg)
                    
#                     self.log_dict["train"].append(perp_train)

#                     self.model.eval()

#                     with torch.no_grad():
#                         column_name = 'test'
#                         val_tokenized_datasets, val_data_collator = self.get_tokenized_dataset(datasets, column_name)
#                         num_val_samples = len(val_tokenized_datasets[column_name])
#                         val_batch_idx = generate_batch_splits(
#                             np.arange(num_val_samples),
#                             self.per_device_batch_size
#                             )
#                         val_losses_epoch = []
#                         for batch_idx in tqdm(val_batch_idx, desc='Validation...', leave=True):
#                             samples = [val_tokenized_datasets[column_name][int(idx)] for idx in batch_idx]
#                             model_inputs = val_data_collator(samples)
#                             model_inputs = shard(model_inputs.data)

#                             input_ids = torch.LongTensor(model_inputs['input_ids']).to(self.device)
#                             # decoder_input_ids = torch.LongTensor(model_inputs['decoder_input_ids']).to(self.device)
#                             labels = torch.LongTensor(model_inputs['labels']).to(self.device)

#                             input_ids_size = input_ids.size()
#                             labels_size = labels.size()
#                             input_ids = input_ids.reshape([input_ids_size[0], input_ids_size[1] * input_ids_size[2]])
#                             labels = labels.reshape([labels_size[0], labels_size[1] * labels_size[2]])
#                             loss = self.model(
#                                 input_ids=torch.squeeze(input_ids, 0),
#                                 labels=torch.squeeze(labels, 0)
#                             )
#                             val_losses_epoch.append(loss.loss.item())
                        
#                         perp_val = np.exp(np.mean(val_losses_epoch))
#                         val_msg = f'VALIDATION ITERATION: {i}\t FOR {train_valid_dir} \t Perplexity = {perp_val}\n'
#                         print(val_msg)
#                         f.write(val_msg)
                        
#                         self.log_dict["val"].append(perp_val)
#                         f.close()
#                         #torch.save(self.model.state_dict(), Path(save_folder, f'epoch_{epoch}_iteration_{i}.pt'))

#     def testing(
#         self,
#         test_dir: os.PathLike,
#         max_seq_length: int = 256,
#         per_device_batch_size: int = 64,
#         mlm_probability: float = 0.15,
#         mean_noise_span_length: int = 3,
#         num_proc: Optional[int] = None,
#         checkpoint_path: Optional[str] = None
#     ):
#         self.max_seq_length = max_seq_length
#         self.per_device_batch_size = per_device_batch_size
#         self.mlm_probability = mlm_probability
#         self.mean_noise_span_length = mean_noise_span_length
#         self.num_proc = num_proc
        
#         if checkpoint_path is not None:
#             self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
    
#         test_paths = [str(Path(test_dir, i)) for i in os.listdir(test_dir)]
#         datasets = load_dataset('text', data_files=test_paths)

#         self.model.eval()

#         with torch.no_grad():
#             column_name = 'train'
#             test_tokenized_datasets, test_data_collator = self.get_tokenized_dataset(datasets, column_name)
#             num_test_samples = len(test_tokenized_datasets[column_name])
#             test_batch_idx = generate_batch_splits(
#                 np.arange(num_test_samples),
#                 self.per_device_batch_size
#                 )
#             test_losses = []
#             for batch_idx in tqdm(test_batch_idx, desc='Testing...', leave=True):
#                 samples = [test_tokenized_datasets[column_name][int(idx)] for idx in batch_idx]
#                 model_inputs = test_data_collator(samples)
#                 model_inputs = shard(model_inputs.data)

#                 input_ids = torch.LongTensor(model_inputs['input_ids']).to(self.device)
#                 # decoder_input_ids = torch.LongTensor(model_inputs['decoder_input_ids']).to(self.device)
#                 labels = torch.LongTensor(model_inputs['labels']).to(self.device)

                
#                 input_ids_size = input_ids.size()
#                 labels_size = labels.size()
#                 input_ids = input_ids.reshape([input_ids_size[0], input_ids_size[1] * input_ids_size[2]])
#                 labels = labels.reshape([labels_size[0], labels_size[1] * labels_size[2]])
#                 loss = self.model(
#                     input_ids=torch.squeeze(input_ids, 0),
#                     labels=torch.squeeze(labels, 0)
#                 )
#                 test_losses.append(loss.loss.item())
            
#             test_msg = (f'TEST: For {test_dir} \t Perplexity = {np.exp(np.mean(test_losses))}\n')
#             print(test_msg)
#             return np.exp(np.mean(test_losses))


# def main(
#     train_valid_dir: Optional[os.PathLike] = None,
#     max_dataset_len: int = 500000,
#     train_size: float = 0.9,
#     n_epochs: int = 5,
#     learning_rate: float = 0.005,
#     num_warmup_steps: int = 2000,
#     weight_decay: float = 0.001,
#     betas: Tuple[float, float] = [0.9, 0.999],
#     test_dir: Optional[os.PathLike] = None,
#     model_id: Enum = 'google/mt5-base',
#     device: Enum = 'cuda:0',
#     max_seq_length: int = 256,
#     per_device_batch_size: int = 64,
#     mlm_probability: float = 0.15,
#     mean_noise_span_length: int = 3,
#     num_proc: Optional[int] = None
# ):
#     initialize_experiments = mt5PerplexityExperiments(
#         model_id,
#         device,
#     )
#     if train_valid_dir is not None:
#         initialize_experiments.training(
#             train_valid_dir,
#             max_dataset_len,
#             train_size,
#             n_epochs,
#             learning_rate,
#             num_warmup_steps,
#             weight_decay,
#             betas,
#             max_seq_length,
#             per_device_batch_size,
#             mlm_probability,
#             mean_noise_span_length,
#             num_proc
#         )

#     if test_dir is not None:
#         initialize_experiments.testing(
#             test_dir,
#             max_seq_length,
#             per_device_batch_size,
#             mlm_probability,
#             mean_noise_span_length,
#             num_proc
#         )



In [8]:
from transformers import (
    MT5ForConditionalGeneration,
    T5Tokenizer,
    AdamW,
    get_linear_schedule_with_warmup
)
import torch
from tqdm import tqdm
from datasets import load_dataset
import os
import json
from pathlib import Path
import numpy as np
from tqdm import tqdm, trange
from flax.training.common_utils import shard
from enum import Enum
from typing import Optional, Tuple
from fire import Fire
from math import floor
import uuid


class mt5PerplexityExperiments:

    def __init__(
        self,
        model_id: Enum = 'google/mt5-base',
        device: Enum = 'cuda:0',
    ):
        self.device = device
        self.model = MT5ForConditionalGeneration.from_pretrained(model_id).to(device)
        self.tokenizer = T5Tokenizer.from_pretrained(model_id)
        
        self.log_dict = {}

    def get_tokenized_dataset(self, datasets, column_name):
        max_seq_length = min(self.max_seq_length, self.tokenizer.model_max_length)
        column_names = datasets[column_name].column_names
        text_column_name = "text" if "text" in column_names else column_names[0]
        
        tokenized_datasets = datasets.map(
            lambda x: tokenize_function(x, tokenizer=self.tokenizer, text_column_name=text_column_name),
            batched=True,
            num_proc=self.num_proc,
            remove_columns=column_names
        )
        expanded_inputs_length, targets_length = compute_input_and_target_lengths(
            inputs_length=self.max_seq_length,
            noise_density=self.mlm_probability,
            mean_noise_span_length=self.mean_noise_span_length,
        )

        data_collator = FlaxDataCollatorForT5MLM(
            tokenizer=self.tokenizer,
            noise_density=self.mlm_probability,
            mean_noise_span_length=self.mean_noise_span_length,
            input_length=max_seq_length,
            target_length=targets_length,
            pad_token_id=self.model.config.pad_token_id,
            decoder_start_token_id=self.model.config.decoder_start_token_id,
        )

        tokenized_datasets = tokenized_datasets.map(
            lambda x: group_texts(x, expanded_inputs_length=expanded_inputs_length),
            batched=True,
            num_proc=self.num_proc,
        )
        return tokenized_datasets, data_collator


    def training(
        self,
        train_valid_dir: os.PathLike,
        max_dataset_len: int = 500000,
        train_size: float = 0.9,
        n_epochs: int = 5,
        learning_rate: float = 0.005,
        num_warmup_steps: int = 2000,
        weight_decay: float = 0.001,
        betas: Tuple[float, float] = [0.9, 0.999],
        max_seq_length: int = 256,
        per_device_batch_size: int = 64,
        mlm_probability: float = 0.15,
        mean_noise_span_length: int = 3,
        num_proc: Optional[int] = None,
        lr_languages_to_test = None
    ):
        self.max_seq_length = max_seq_length
        self.per_device_batch_size = per_device_batch_size
        self.mlm_probability = mlm_probability
        self.mean_noise_span_length = mean_noise_span_length
        self.num_proc = num_proc

        log_params = {
            "train_valid_dir":train_valid_dir,
            "train_size":train_size,
            "n_epochs":n_epochs,
            "learning_rate":learning_rate,
            "num_warmup_steps":num_warmup_steps,
            "weight_decay":weight_decay,
            "betas":betas,
            "max_seq_length":max_seq_length,
            "per_device_batch_size":per_device_batch_size,
            "mlm_probability":mlm_probability,
            "mean_noise_span_length":mean_noise_span_length,
            "num_proc":num_proc
        }
        random_seed = uuid.uuid4()
        self.save_folder = f'mt5_experiments/training_on_{Path(train_valid_dir).name}/{random_seed}'

        if not os.path.exists(self.save_folder):
            os.makedirs(self.save_folder)
        
        params_filename = Path(self.save_folder, "params.json")
        log_filename = Path(self.save_folder, "log_results.txt")
        log_errors = Path(self.save_folder, "log_errors.txt")
        new_log_path = Path(self.save_folder, "new_log.json")
        with open(params_filename, "w+") as outfile:
            json.dump(log_params, outfile, indent=4)

        train_val_paths = [str(Path(train_valid_dir, i)) for i in os.listdir(train_valid_dir)]
        dataset = load_dataset('text', data_files=train_val_paths, split='train')

        dataset_limit = min(len(dataset), max_dataset_len)
        data_indices = np.random.choice(len(dataset), dataset_limit)
        cutted_dataset = dataset.select(data_indices)
        datasets = cutted_dataset.train_test_split(test_size=1-train_size)
        column_name = 'train'

        train_tokenized_datasets, train_data_collator = self.get_tokenized_dataset(datasets, column_name)
        num_train_samples = len(train_tokenized_datasets[column_name])
        train_batch_idx = generate_batch_splits(
            np.arange(num_train_samples),
            self.per_device_batch_size
            )
        
        num_train_steps = len(train_tokenized_datasets["train"]) // self.per_device_batch_size * n_epochs
        
        optimizer = AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay = weight_decay,
            betas = betas
            )

        scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_train_steps
                )
        
        self.log_dict = {"train": [], "val": [], "test": {}}
        for epoch in trange(n_epochs):
            # ======================== Training ================================
            train_losses_epoch = []

            step = int(len(train_batch_idx) * 0.05)
            for i, batch_idx in tqdm(enumerate(train_batch_idx), desc='Training...', leave=True, total=len(train_batch_idx)):
                with open(str(new_log_path), "w") as outfile:
                    json.dump(self.log_dict, outfile)
                
                
                torch.cuda.empty_cache()
                gc.collect()
                self.model.train()
                f = open(log_filename, 'a+')
                f_error = open(log_errors, 'a+')
       
                samples = [train_tokenized_datasets["train"][int(idx)] for idx in batch_idx]
                model_inputs = train_data_collator(samples)
                model_inputs = shard(model_inputs.data)

                input_ids = torch.LongTensor(model_inputs['input_ids']).to(self.device)
                labels = torch.LongTensor(model_inputs['labels']).to(self.device)
                
                input_ids_size = input_ids.size()
                labels_size = labels.size()
                input_ids = input_ids.reshape([input_ids_size[0], input_ids_size[1] * input_ids_size[2]])
                labels = labels.reshape([labels_size[0], labels_size[1] * labels_size[2]])
                
                optimizer.zero_grad()
                loss = self.model(
                    input_ids=torch.squeeze(input_ids, 0),
                    labels=torch.squeeze(labels, 0)
                )
                train_losses_epoch.append(loss.loss.item())
                loss.loss.backward()
                optimizer.step()
                scheduler.step()

                # ======================== Evaluating ==============================
                if i % step == 0 and i > 0:
                    perp_train = np.exp(np.mean(train_losses_epoch))
                    train_msg = f'TRAIN ITERATION: {i}\t FOR {train_valid_dir} \t Perplexity = {perp_train}\n'
                    print(train_msg)
                    f.write(train_msg)
                    
                    self.log_dict["train"].append(perp_train)

                    self.model.eval()
                    with torch.no_grad():
                        column_name = 'test'
                        val_tokenized_datasets, val_data_collator = self.get_tokenized_dataset(datasets, column_name)
                        num_val_samples = len(val_tokenized_datasets[column_name])
                        val_batch_idx = generate_batch_splits(
                            np.arange(num_val_samples),
                            self.per_device_batch_size
                            )
                        val_losses_epoch = []
                        for batch_idx in tqdm(val_batch_idx, desc='Validation...', leave=True):
                            samples = [val_tokenized_datasets[column_name][int(idx)] for idx in batch_idx]
                            model_inputs = val_data_collator(samples)
                            model_inputs = shard(model_inputs.data)

                            input_ids = torch.LongTensor(model_inputs['input_ids']).to(self.device)
                            labels = torch.LongTensor(model_inputs['labels']).to(self.device)

                            input_ids_size = input_ids.size()
                            labels_size = labels.size()
                            input_ids = input_ids.reshape([input_ids_size[0], input_ids_size[1] * input_ids_size[2]])
                            labels = labels.reshape([labels_size[0], labels_size[1] * labels_size[2]])
                            loss = self.model(
                                input_ids=torch.squeeze(input_ids, 0),
                                labels=torch.squeeze(labels, 0)
                            )
                            val_losses_epoch.append(loss.loss.item())
                        
                        perp_val = np.exp(np.mean(val_losses_epoch))
                        val_msg = f'VALIDATION ITERATION: {i}\t FOR {train_valid_dir} \t Perplexity = {perp_val}\n'
                        print(val_msg)
                        f.write(val_msg)
                        
                        self.log_dict["val"].append(perp_val)
                        f.close()
                        #torch.save(self.model.state_dict(), Path(save_folder, f'epoch_{epoch}_iteration_{i}.pt'))
                        
                        torch.cuda.empty_cache()
                        gc.collect()
                        for lr_lang in tqdm(lr_languages_to_test):
                            lang_folder_path = str(Path(Path(train_valid_dir).parent, lr_lang))
                            try:
                                if lang_folder_path not in self.log_dict["test"]:
                                    self.log_dict["test"][lang_folder_path] = []
                                self.log_dict["test"][lang_folder_path].append(self.testing(lang_folder_path))
                            except:
                                f_error.write(f"Something went wrong during processing: {lang_folder_path}\n")


    def testing(
        self,
        test_dir: os.PathLike,
        max_seq_length: int = 256,
        per_device_batch_size: int = 64,
        mlm_probability: float = 0.15,
        mean_noise_span_length: int = 3,
        num_proc: Optional[int] = None,
        checkpoint_path: Optional[str] = None
    ):
        self.max_seq_length = max_seq_length
        self.per_device_batch_size = per_device_batch_size
        self.mlm_probability = mlm_probability
        self.mean_noise_span_length = mean_noise_span_length
        self.num_proc = num_proc
        
        if checkpoint_path is not None:
            self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
    
        test_paths = [str(Path(test_dir, i)) for i in os.listdir(test_dir)]
        datasets = load_dataset('text', data_files=test_paths)

        self.model.eval()
        with torch.no_grad():
            column_name = 'train'
            test_tokenized_datasets, test_data_collator = self.get_tokenized_dataset(datasets, column_name)
            num_test_samples = len(test_tokenized_datasets[column_name])
            test_batch_idx = generate_batch_splits(
                np.arange(num_test_samples),
                self.per_device_batch_size
                )
            test_losses = []
            for batch_idx in tqdm(test_batch_idx, desc='Testing...', leave=True):
                samples = [test_tokenized_datasets[column_name][int(idx)] for idx in batch_idx]
                model_inputs = test_data_collator(samples)
                model_inputs = shard(model_inputs.data)

                input_ids = torch.LongTensor(model_inputs['input_ids']).to(self.device)
                labels = torch.LongTensor(model_inputs['labels']).to(self.device)

                input_ids_size = input_ids.size()
                labels_size = labels.size()
                input_ids = input_ids.reshape([input_ids_size[0], input_ids_size[1] * input_ids_size[2]])
                labels = labels.reshape([labels_size[0], labels_size[1] * labels_size[2]])
                loss = self.model(
                    input_ids=torch.squeeze(input_ids, 0),
                    labels=torch.squeeze(labels, 0)
                )
                test_losses.append(loss.loss.item())
            
            # test_msg = (f'TEST: For {test_dir} \t Perplexity = {np.exp(np.mean(test_losses))}\n')
            # print(test_msg)
            perp_test = np.exp(np.mean(test_losses))
            return perp_test


# HR languages

In [17]:
import pandas as pd
import gc
import torch
from tqdm import tqdm
df = pd.read_csv("multilingual/data/Collected_langs.csv")

In [18]:
hr_languages = df[(df.N_tokens >= 350000)].Name.tolist()
lr_languages = df[(df.N_tokens > 10000) & (df.N_tokens < 350000)].Name.tolist()

In [11]:
dataset_folder = "/home/jovyan/datasets/XL_Dataset/"

In [12]:
init = mt5PerplexityExperiments(device='cuda:0')

for hr_lang in tqdm(hr_languages[:len(hr_languages) // 8], desc="HR lang training"):
    init.training(
        train_valid_dir=dataset_folder+hr_lang,
        per_device_batch_size=16,
        n_epochs=1,
        lr_languages_to_test = lr_languages
    )
    
    params_filename = Path(self.save_folder, "new_log.json")
    
    with open(str(params_filename), "w") as outfile:
        json.dump(self.log_dict, outfile)
        
    self.log_dict.clear()

KeyboardInterrupt: 

In [None]:
init = mt5PerplexityExperiments(device='cuda:0')
hr_lang = "Akan"
init.training(
    train_valid_dir=dataset_folder+hr_lang,
    per_device_batch_size=16,
    n_epochs=1,
    lr_languages_to_test = lr_languages[:5]
)



  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]
Training...:   0%|          | 0/61 [00:00<?, ?it/s][A
Training...:   2%|▏         | 1/61 [00:05<05:20,  5.35s/it][A
Training...:   3%|▎         | 2/61 [00:09<04:24,  4.49s/it][A
Training...:   5%|▍         | 3/61 [00:13<04:26,  4.59s/it][A

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]



Validation...:   0%|          | 0/6 [00:00<?, ?it/s][A[A

Validation...:  17%|█▋        | 1/6 [00:00<00:00,  7.31it/s][A[A

Validation...:  33%|███▎      | 2/6 [00:00<00:00,  7.73it/s][A[A

Validation...:  50%|█████     | 3/6 [00:00<00:00,  8.07it/s][A[A

Validation...:  67%|██████▋   | 4/6 [00:00<00:00,  8.25it/s][A[A

Validation...:  83%|████████▎ | 5/6 [00:00<00:00,  8.35it/s][A[A

Validation...: 100%|██████████| 6/6 [00:00<00:00,  8.22it/s][A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:12,  1.31it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.37it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.38it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.39it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:05<00:07,  1.40it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.40it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.40it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/6 [00:00<?, ?it/s][A[A[A


Testing...:  17%|█▋        | 1/6 [00:00<00:03,  1.36it/s][A[A[A


Testing...:  33%|███▎      | 2/6 [00:01<00:02,  1.39it/s][A[A[A


Testing...:  50%|█████     | 3/6 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  67%|██████▋   | 4/6 [00:02<00:01,  1.40it/s][A[A[A


Testing...:  83%|████████▎ | 5/6 [00:03<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 6/6 [00:04<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/4 [00:00<?, ?it/s][A[A[A


Testing...:  25%|██▌       | 1/4 [00:00<00:02,  1.38it/s][A[A[A


Testing...:  50%|█████     | 2/4 [00:01<00:01,  1.40it/s][A[A[A


Testing...:  75%|███████▌  | 3/4 [00:02<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.35it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.38it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.39it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:05<00:07,  1.41it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.41it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/8 [00:00<?, ?it/s][A[A[A


Testing...:  12%|█▎        | 1/8 [00:00<00:05,  1.37it/s][A[A[A


Testing...:  25%|██▌       | 2/8 [00:01<00:04,  1.39it/s][A[A[A


Testing...:  38%|███▊      | 3/8 [00:02<00:03,  1.40it/s][A[A[A


Testing...:  50%|█████     | 4/8 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  62%|██████▎   | 5/8 [00:03<00:02,  1.41it/s][A[A[A


Testing...:  75%|███████▌  | 6/8 [00:04<00:01,  1.41it/s][A[A[A


Testing...:  88%|████████▊ | 7/8 [00:04<00:00,  1.41it/s][A[A[A


Testing...: 100%|██████████| 8/8 [00:05<00:00,  1.40it/s][A[A[A


100%|██████████| 5/5 [00:58<00:00, 11.77s/it][A[A

Training...:   7%|▋         | 4/61 [01:27<30:06, 31.70s/it][A
Training...:   8%|▊         | 5/61 [01:30<19:59, 21.41s/it][A
Training...:  10%|▉         | 6/61 [01:34<14:20, 15.64s/it][A

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]



Validation...:   0%|          | 0/1 [00:00<?, ?it/s][A[A

Validation...: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s][A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.34it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.38it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.39it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.39it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:05<00:07,  1.40it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.40it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.40it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/6 [00:00<?, ?it/s][A[A[A


Testing...:  17%|█▋        | 1/6 [00:00<00:03,  1.38it/s][A[A[A


Testing...:  33%|███▎      | 2/6 [00:01<00:02,  1.39it/s][A[A[A


Testing...:  50%|█████     | 3/6 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  67%|██████▋   | 4/6 [00:02<00:01,  1.40it/s][A[A[A


Testing...:  83%|████████▎ | 5/6 [00:03<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 6/6 [00:04<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/4 [00:00<?, ?it/s][A[A[A


Testing...:  25%|██▌       | 1/4 [00:00<00:02,  1.36it/s][A[A[A


Testing...:  50%|█████     | 2/4 [00:01<00:01,  1.39it/s][A[A[A


Testing...:  75%|███████▌  | 3/4 [00:02<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.36it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.39it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.40it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:04<00:07,  1.41it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.41it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/8 [00:00<?, ?it/s][A[A[A


Testing...:  12%|█▎        | 1/8 [00:00<00:05,  1.38it/s][A[A[A


Testing...:  25%|██▌       | 2/8 [00:01<00:04,  1.40it/s][A[A[A


Testing...:  38%|███▊      | 3/8 [00:02<00:03,  1.40it/s][A[A[A


Testing...:  50%|█████     | 4/8 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  62%|██████▎   | 5/8 [00:03<00:02,  1.40it/s][A[A[A


Testing...:  75%|███████▌  | 6/8 [00:04<00:01,  1.40it/s][A[A[A


Testing...:  88%|████████▊ | 7/8 [00:05<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 8/8 [00:05<00:00,  1.40it/s][A[A[A


100%|██████████| 5/5 [00:59<00:00, 11.95s/it][A[A

Training...:  11%|█▏        | 7/61 [02:48<31:14, 34.71s/it][A
Training...:  13%|█▎        | 8/61 [02:52<21:58, 24.87s/it][A
Training...:  15%|█▍        | 9/61 [02:57<16:16, 18.78s/it][A

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]



Validation...:   0%|          | 0/1 [00:00<?, ?it/s][A[A

Validation...: 100%|██████████| 1/1 [00:00<00:00,  1.35it/s][A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:12,  1.27it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:11,  1.35it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.37it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.39it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.39it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:05<00:07,  1.40it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.40it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/6 [00:00<?, ?it/s][A[A[A


Testing...:  17%|█▋        | 1/6 [00:00<00:03,  1.36it/s][A[A[A


Testing...:  33%|███▎      | 2/6 [00:01<00:02,  1.39it/s][A[A[A


Testing...:  50%|█████     | 3/6 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  67%|██████▋   | 4/6 [00:02<00:01,  1.40it/s][A[A[A


Testing...:  83%|████████▎ | 5/6 [00:03<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 6/6 [00:04<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/4 [00:00<?, ?it/s][A[A[A


Testing...:  25%|██▌       | 1/4 [00:00<00:02,  1.36it/s][A[A[A


Testing...:  50%|█████     | 2/4 [00:01<00:01,  1.39it/s][A[A[A


Testing...:  75%|███████▌  | 3/4 [00:02<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.36it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.39it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.40it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:04<00:07,  1.41it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.41it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/8 [00:00<?, ?it/s][A[A[A


Testing...:  12%|█▎        | 1/8 [00:00<00:05,  1.35it/s][A[A[A


Testing...:  25%|██▌       | 2/8 [00:01<00:04,  1.38it/s][A[A[A


Testing...:  38%|███▊      | 3/8 [00:02<00:03,  1.39it/s][A[A[A


Testing...:  50%|█████     | 4/8 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  62%|██████▎   | 5/8 [00:03<00:02,  1.40it/s][A[A[A


Testing...:  75%|███████▌  | 6/8 [00:04<00:01,  1.41it/s][A[A[A


Testing...:  88%|████████▊ | 7/8 [00:05<00:00,  1.41it/s][A[A[A


Testing...: 100%|██████████| 8/8 [00:05<00:00,  1.40it/s][A[A[A


100%|██████████| 5/5 [00:59<00:00, 11.89s/it][A[A

Training...:  16%|█▋        | 10/61 [04:11<30:21, 35.71s/it][A
Training...:  18%|█▊        | 11/61 [04:15<21:38, 25.98s/it][A
Training...:  20%|█▉        | 12/61 [04:20<16:00, 19.59s/it][A

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]



Validation...:   0%|          | 0/1 [00:00<?, ?it/s][A[A

Validation...: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s][A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.34it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.38it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.39it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:05<00:07,  1.40it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.40it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/6 [00:00<?, ?it/s][A[A[A


Testing...:  17%|█▋        | 1/6 [00:00<00:03,  1.35it/s][A[A[A


Testing...:  33%|███▎      | 2/6 [00:01<00:02,  1.38it/s][A[A[A


Testing...:  50%|█████     | 3/6 [00:02<00:02,  1.39it/s][A[A[A


Testing...:  67%|██████▋   | 4/6 [00:02<00:01,  1.40it/s][A[A[A


Testing...:  83%|████████▎ | 5/6 [00:03<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 6/6 [00:04<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/4 [00:00<?, ?it/s][A[A[A


Testing...:  25%|██▌       | 1/4 [00:00<00:02,  1.35it/s][A[A[A


Testing...:  50%|█████     | 2/4 [00:01<00:01,  1.38it/s][A[A[A


Testing...:  75%|███████▌  | 3/4 [00:02<00:00,  1.39it/s][A[A[A


Testing...: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.37it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.39it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.40it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.40it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.40it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/8 [00:00<?, ?it/s][A[A[A


Testing...:  12%|█▎        | 1/8 [00:00<00:05,  1.39it/s][A[A[A


Testing...:  25%|██▌       | 2/8 [00:01<00:04,  1.40it/s][A[A[A


Testing...:  38%|███▊      | 3/8 [00:02<00:03,  1.40it/s][A[A[A


Testing...:  50%|█████     | 4/8 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  62%|██████▎   | 5/8 [00:03<00:02,  1.40it/s][A[A[A


Testing...:  75%|███████▌  | 6/8 [00:04<00:01,  1.41it/s][A[A[A


Testing...:  88%|████████▊ | 7/8 [00:04<00:00,  1.41it/s][A[A[A


Testing...: 100%|██████████| 8/8 [00:05<00:00,  1.40it/s][A[A[A


100%|██████████| 5/5 [00:59<00:00, 11.85s/it][A[A

Training...:  21%|██▏       | 13/61 [05:33<28:36, 35.76s/it][A
Training...:  23%|██▎       | 14/61 [05:36<20:18, 25.92s/it][A
Training...:  25%|██▍       | 15/61 [05:41<15:02, 19.61s/it][A

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]



Validation...:   0%|          | 0/1 [00:00<?, ?it/s][A[A

Validation...: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s][A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.34it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.37it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.39it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:05<00:07,  1.40it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.40it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/6 [00:00<?, ?it/s][A[A[A


Testing...:  17%|█▋        | 1/6 [00:00<00:03,  1.38it/s][A[A[A


Testing...:  33%|███▎      | 2/6 [00:01<00:02,  1.40it/s][A[A[A


Testing...:  50%|█████     | 3/6 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  67%|██████▋   | 4/6 [00:02<00:01,  1.40it/s][A[A[A


Testing...:  83%|████████▎ | 5/6 [00:03<00:00,  1.41it/s][A[A[A


Testing...: 100%|██████████| 6/6 [00:04<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/4 [00:00<?, ?it/s][A[A[A


Testing...:  25%|██▌       | 1/4 [00:00<00:02,  1.37it/s][A[A[A


Testing...:  50%|█████     | 2/4 [00:01<00:01,  1.38it/s][A[A[A


Testing...:  75%|███████▌  | 3/4 [00:02<00:00,  1.39it/s][A[A[A


Testing...: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.37it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.39it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.40it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:04<00:07,  1.41it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.41it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.40it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.40it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.40it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/8 [00:00<?, ?it/s][A[A[A


Testing...:  12%|█▎        | 1/8 [00:00<00:05,  1.38it/s][A[A[A


Testing...:  25%|██▌       | 2/8 [00:01<00:04,  1.40it/s][A[A[A


Testing...:  38%|███▊      | 3/8 [00:02<00:03,  1.40it/s][A[A[A


Testing...:  50%|█████     | 4/8 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  62%|██████▎   | 5/8 [00:03<00:02,  1.40it/s][A[A[A


Testing...:  75%|███████▌  | 6/8 [00:04<00:01,  1.41it/s][A[A[A


Testing...:  88%|████████▊ | 7/8 [00:04<00:00,  1.41it/s][A[A[A


Testing...: 100%|██████████| 8/8 [00:05<00:00,  1.40it/s][A[A[A


100%|██████████| 5/5 [00:59<00:00, 11.88s/it][A[A

Training...:  26%|██▌       | 16/61 [06:55<27:01, 36.04s/it][A
Training...:  28%|██▊       | 17/61 [07:00<19:25, 26.48s/it][A
Training...:  30%|██▉       | 18/61 [07:04<14:15, 19.90s/it][A

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]



Validation...:   0%|          | 0/1 [00:00<?, ?it/s][A[A

Validation...: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s][A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:12,  1.32it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.37it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.39it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:05<00:07,  1.40it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.40it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/6 [00:00<?, ?it/s][A[A[A


Testing...:  17%|█▋        | 1/6 [00:00<00:03,  1.37it/s][A[A[A


Testing...:  33%|███▎      | 2/6 [00:01<00:02,  1.39it/s][A[A[A


Testing...:  50%|█████     | 3/6 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  67%|██████▋   | 4/6 [00:02<00:01,  1.40it/s][A[A[A


Testing...:  83%|████████▎ | 5/6 [00:03<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 6/6 [00:04<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/4 [00:00<?, ?it/s][A[A[A


Testing...:  25%|██▌       | 1/4 [00:00<00:02,  1.36it/s][A[A[A


Testing...:  50%|█████     | 2/4 [00:01<00:01,  1.39it/s][A[A[A


Testing...:  75%|███████▌  | 3/4 [00:02<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.35it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.38it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.39it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:05<00:07,  1.40it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.41it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/8 [00:00<?, ?it/s][A[A[A


Testing...:  12%|█▎        | 1/8 [00:00<00:05,  1.39it/s][A[A[A


Testing...:  25%|██▌       | 2/8 [00:01<00:04,  1.40it/s][A[A[A


Testing...:  38%|███▊      | 3/8 [00:02<00:03,  1.40it/s][A[A[A


Testing...:  50%|█████     | 4/8 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  62%|██████▎   | 5/8 [00:03<00:02,  1.41it/s][A[A[A


Testing...:  75%|███████▌  | 6/8 [00:04<00:01,  1.41it/s][A[A[A


Testing...:  88%|████████▊ | 7/8 [00:04<00:00,  1.41it/s][A[A[A


Testing...: 100%|██████████| 8/8 [00:05<00:00,  1.40it/s][A[A[A


100%|██████████| 5/5 [00:59<00:00, 11.83s/it][A[A

Training...:  31%|███       | 19/61 [08:18<25:16, 36.11s/it][A
Training...:  33%|███▎      | 20/61 [08:22<18:06, 26.51s/it][A
Training...:  34%|███▍      | 21/61 [08:27<13:16, 19.91s/it][A

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]



Validation...:   0%|          | 0/1 [00:00<?, ?it/s][A[A

Validation...: 100%|██████████| 1/1 [00:00<00:00,  1.34it/s][A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.34it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.38it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.39it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A


Testing...:  41%|████      | 7/17 [00:05<00:07,  1.40it/s][A[A[A


Testing...:  47%|████▋     | 8/17 [00:05<00:06,  1.40it/s][A[A[A


Testing...:  53%|█████▎    | 9/17 [00:06<00:05,  1.41it/s][A[A[A


Testing...:  59%|█████▉    | 10/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  65%|██████▍   | 11/17 [00:07<00:04,  1.41it/s][A[A[A


Testing...:  71%|███████   | 12/17 [00:08<00:03,  1.41it/s][A[A[A


Testing...:  76%|███████▋  | 13/17 [00:09<00:02,  1.41it/s][A[A[A


Testing...:  82%|████

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/6 [00:00<?, ?it/s][A[A[A


Testing...:  17%|█▋        | 1/6 [00:00<00:03,  1.36it/s][A[A[A


Testing...:  33%|███▎      | 2/6 [00:01<00:02,  1.39it/s][A[A[A


Testing...:  50%|█████     | 3/6 [00:02<00:02,  1.40it/s][A[A[A


Testing...:  67%|██████▋   | 4/6 [00:02<00:01,  1.40it/s][A[A[A


Testing...:  83%|████████▎ | 5/6 [00:03<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 6/6 [00:04<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/4 [00:00<?, ?it/s][A[A[A


Testing...:  25%|██▌       | 1/4 [00:00<00:02,  1.37it/s][A[A[A


Testing...:  50%|█████     | 2/4 [00:01<00:01,  1.39it/s][A[A[A


Testing...:  75%|███████▌  | 3/4 [00:02<00:00,  1.40it/s][A[A[A


Testing...: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s][A[A[A




  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]




Testing...:   0%|          | 0/17 [00:00<?, ?it/s][A[A[A


Testing...:   6%|▌         | 1/17 [00:00<00:11,  1.36it/s][A[A[A


Testing...:  12%|█▏        | 2/17 [00:01<00:10,  1.39it/s][A[A[A


Testing...:  18%|█▊        | 3/17 [00:02<00:10,  1.39it/s][A[A[A


Testing...:  24%|██▎       | 4/17 [00:02<00:09,  1.40it/s][A[A[A


Testing...:  29%|██▉       | 5/17 [00:03<00:08,  1.40it/s][A[A[A


Testing...:  35%|███▌      | 6/17 [00:04<00:07,  1.40it/s][A[A[A