*Copyright (c) Microsoft Corporation. All rights reserved.*

*Licensed under the MIT License.*

# Text Classification of Multi-Language Datasets Using a Transformer Model

In [1]:
import scrapbook as sb
import pandas as pd
import torch

from tempfile import TemporaryDirectory
from utils_nlp.common.timer import Timer
from sklearn.metrics import classification_report
from utils_nlp.models.transformers.sequence_classification import *

from utils_nlp.dataset import multinli
from utils_nlp.dataset import dac
from utils_nlp.dataset import bbc_hindi

I1107 18:36:55.496321 140684753213184 file_utils.py:39] PyTorch version 1.2.0 available.
I1107 18:36:55.541030 140684753213184 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .


## Introduction

In this notebook, we fine-tune and evaluate a pretrained Transformer model with the BERT architecture on three datasets in different languages:

- [MultiNLI dataset](https://www.nyu.edu/projects/bowman/multinli/): The Multi-Genre NLI corpus, in English
- [DAC dataset](https://data.mendeley.com/datasets/v524p5dhpj/2): DataSet for Arabic Classification corpus, in Arabic
- [BBC Hindi dataset](https://github.com/NirantK/hindi2vec/releases/tag/bbc-hindi-v0.1): BBC Hindi News corpus, in Hindi

If you want to run through the notebook quickly, set the **`QUICK_RUN`** flag in the cell below to **`True`**. This runs the notebook on a small subset of the data with fewer epochs. You can also choose which of the three datasets (**`MultiNLI`**, **`DAC`**, and **`BBC Hindi`**) to experiment with by setting `USE_DATASET`.

### Running Time

The table below gives reference running times for each dataset.

|Dataset|QUICK_RUN|Machine Configuration|Running Time|
|:------|:---------|:----------------------|:------------|
|MultiNLI|True|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 8 minutes |
|MultiNLI|False|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 5.7 hours |
|DAC|True|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 13 minutes |
|DAC|False|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 5.6 hours |
|BBC Hindi|True|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 1 minute |
|BBC Hindi|False|2 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 14 minutes |

If you run into CUDA out-of-memory errors or the Jupyter kernel keeps dying, try reducing `batch_size` and `max_len` in `CONFIG`; note that model performance may suffer as a result.
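The batch-size part of this advice can be automated. The following is a minimal sketch, not part of utils_nlp: it assumes a hypothetical `train_fn(batch_size, max_len)` callable that would wrap `model.fit`, and relies on PyTorch reporting GPU out-of-memory as a `RuntimeError` whose message contains "out of memory". On such an error it halves `batch_size` and retries. In a real GPU session you would probably also free cached memory (e.g. `torch.cuda.empty_cache()`) between attempts; that is omitted here so the sketch runs without PyTorch.

```python
def fit_with_backoff(train_fn, batch_size, max_len, min_batch_size=1):
    """Call train_fn(batch_size, max_len); after each CUDA out-of-memory
    error, halve batch_size and retry. Returns (result, batch_size_used)."""
    while True:
        try:
            return train_fn(batch_size, max_len), batch_size
        except RuntimeError as err:
            # PyTorch reports GPU OOM as a RuntimeError mentioning "out of memory".
            if "out of memory" not in str(err) or batch_size <= min_batch_size:
                raise  # a different error, or nothing left to shrink
            batch_size = max(min_batch_size, batch_size // 2)


# Illustrative stand-in for model.fit: "runs out of memory" above a fixed budget.
def fake_train(batch_size, max_len):
    if batch_size * max_len > 1500:
        raise RuntimeError("CUDA out of memory. Tried to allocate 20.00 MiB")
    return "trained"


result, used = fit_with_backoff(fake_train, batch_size=16, max_len=150)
print(result, used)  # trained 8
```

Only `batch_size` is reduced automatically; shrinking `max_len` truncates the input text, so it is better left as a deliberate choice.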

In [2]:
# Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = True

# The dataset to use; valid values are "multinli", "dac", and "bbc-hindi"
USE_DATASET = "dac"

Several pretrained models have been made available by [Hugging Face](https://github.com/huggingface/transformers). For text classification, the following pretrained models are supported.

In [3]:
pd.DataFrame({"model_name": SequenceClassifier.list_supported_models()})

| |model_name|
|:--|:--------|
|0|bert-base-uncased|
|1|bert-large-uncased|
|2|bert-base-cased|
|3|bert-large-cased|
|4|bert-base-multilingual-uncased|
|5|bert-base-multilingual-cased|
|6|bert-base-chinese|
|7|bert-base-german-cased|
|8|bert-large-uncased-whole-word-masking|
|9|bert-large-cased-whole-word-masking|


To demonstrate the multilingual capability of Transformer models, this notebook uses **`bert-base-multilingual-cased`** by default.
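Of the supported checkpoints, only two are multilingual, and they can be picked out by name. The list below is copied from the `SequenceClassifier.list_supported_models()` output so that the sketch runs without utils_nlp; in the notebook you would filter the return value of that call instead.

```python
# Names copied from the SequenceClassifier.list_supported_models() output.
supported = [
    "bert-base-uncased",
    "bert-large-uncased",
    "bert-base-cased",
    "bert-large-cased",
    "bert-base-multilingual-uncased",
    "bert-base-multilingual-cased",
    "bert-base-chinese",
    "bert-base-german-cased",
    "bert-large-uncased-whole-word-masking",
    "bert-large-cased-whole-word-masking",
]

multilingual = [name for name in supported if "multilingual" in name]
# "-cased" does not match "-uncased", so this keeps only the cased variant.
cased_multilingual = [name for name in multilingual if name.endswith("-cased")]

print(multilingual)        # ['bert-base-multilingual-uncased', 'bert-base-multilingual-cased']
print(cased_multilingual)  # ['bert-base-multilingual-cased']
```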

## Configuration

In [4]:
CONFIG = {
    'local_path': "./temp",
    'test_fraction': 0.2,
    'random_seed': 100,
    'train_sample_ratio': 1.0,
    'test_sample_ratio': 1.0,
    'model_name': 'bert-base-multilingual-cased',
    'to_lower': False,
    'cache_dir': './temp',
    'max_len': 150,
    'num_train_epochs': 5,
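    'device': 'cuda' if torch.cuda.is_available() else 'cpu',  # fall back to CPU without a GPU
    'batch_size': 8,
    'verbose': True,
    'load_dataset': None
}

# Quick run: sub-sample both splits and train for a single epoch
if QUICK_RUN:
    CONFIG['train_sample_ratio'] = 0.2
    CONFIG['test_sample_ratio'] = 0.2
    CONFIG['num_train_epochs'] = 1

# Seed PyTorch's random number generator for reproducibility
torch.manual_seed(CONFIG['random_seed'])

# Use a larger batch size when a GPU is available
if torch.cuda.is_available():
    CONFIG['batch_size'] = 16

# Dataset-specific settings, including the matching load_dataset helper
if USE_DATASET == "multinli":
    CONFIG['to_lower'] = True
    CONFIG['load_dataset'] = multinli.load_dataset

    if QUICK_RUN:
        # MultiNLI uses a smaller sampling fraction in quick-run mode
        CONFIG['train_sample_ratio'] = 0.1
        CONFIG['test_sample_ratio'] = 0.1
elif USE_DATASET == "dac":
    CONFIG['load_dataset'] = dac.load_dataset
elif USE_DATASET == "bbc-hindi":
    CONFIG['load_dataset'] = bbc_hindi.load_dataset
else:
    raise ValueError("Supported datasets are: 'multinli', 'dac', and 'bbc-hindi'")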
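The configuration cell combines three switches (`QUICK_RUN`, CUDA availability, and `USE_DATASET`), and later assignments override earlier ones: for example, MultiNLI's quick-run sampling ratio of 0.1 replaces the generic 0.2. The hypothetical pure function `resolve_config` below restates that logic so the final values can be checked without downloading anything; it mirrors the cell above and is not part of utils_nlp.

```python
def resolve_config(dataset, quick_run, cuda_available):
    """Return the settings the configuration cell ends up with for a given
    dataset, QUICK_RUN flag, and GPU availability."""
    cfg = {
        "train_sample_ratio": 1.0,
        "test_sample_ratio": 1.0,
        "num_train_epochs": 5,
        "batch_size": 8,
        "to_lower": False,
    }
    if quick_run:
        cfg.update(train_sample_ratio=0.2, test_sample_ratio=0.2, num_train_epochs=1)
    if cuda_available:
        cfg["batch_size"] = 16
    if dataset == "multinli":
        cfg["to_lower"] = True
        if quick_run:
            # Overrides the generic quick-run ratio set above.
            cfg.update(train_sample_ratio=0.1, test_sample_ratio=0.1)
    elif dataset not in ("dac", "bbc-hindi"):
        raise ValueError("Supported datasets are: 'multinli', 'dac', and 'bbc-hindi'")
    return cfg


print(resolve_config("dac", quick_run=True, cuda_available=True))
# {'train_sample_ratio': 0.2, 'test_sample_ratio': 0.2, 'num_train_epochs': 1, 'batch_size': 16, 'to_lower': False}
```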

## Load Dataset

Depending on the dataset you choose, the configuration cell above selects the matching **`load_dataset`** helper function. The helper downloads the raw data, splits it into training and test sets (sub-sampling them when the sampling ratio is smaller than 1.0), and preprocesses the text for the Transformer model. Everything happens in one function call, which returns PyTorch training and test datasets for fine-tuning and evaluating the model, along with a label encoder.
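The internals of the `load_dataset` helpers are not shown in this notebook. As an illustration only, here is a minimal stdlib sketch of the split, sub-sample, and label-encoding steps, assuming the data is shuffled with the seed before splitting; `split_and_subsample` and `encode_labels` are hypothetical names, and tokenization is left out. The returned `label_encoder` exposes `classes_` like scikit-learn's `LabelEncoder`, so the sketch assumes the same sorted-class ordering.

```python
import random


def split_and_subsample(examples, test_fraction, random_seed,
                        train_sample_ratio=1.0, test_sample_ratio=1.0):
    """Shuffle with a fixed seed, hold out test_fraction for testing, then
    keep only the requested fraction of each split."""
    rng = random.Random(random_seed)  # local RNG: reproducible, no global state
    shuffled = list(examples)
    rng.shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    test, train = shuffled[:n_test], shuffled[n_test:]
    # The data is already shuffled, so a prefix is a random subsample.
    train = train[:int(len(train) * train_sample_ratio)]
    test = test[:int(len(test) * test_sample_ratio)]
    return train, test


def encode_labels(labels):
    """Map string labels to integer ids, ordering classes alphabetically."""
    classes = sorted(set(labels))
    index = {label: i for i, label in enumerate(classes)}
    return classes, [index[label] for label in labels]


# Toy data standing in for a real corpus.
examples = [(f"text {i}", "sports" if i % 2 else "politics") for i in range(100)]
train, test = split_and_subsample(examples, test_fraction=0.2, random_seed=100,
                                  train_sample_ratio=0.2, test_sample_ratio=0.2)
# Fit the label mapping on all labels so both splits share one encoding.
classes, label_ids = encode_labels([label for _, label in examples])
print(len(train), len(test), classes)  # 16 4 ['politics', 'sports']
```

With 100 examples, a test fraction of 0.2 leaves 80 for training and 20 for testing, and a sampling ratio of 0.2 then keeps 16 and 4 of them.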

In [5]:
train_dataset, test_dataset, label_encoder = CONFIG['load_dataset'](
    local_path=CONFIG['local_path'],
    test_fraction=CONFIG['test_fraction'],
    random_seed=CONFIG['random_seed'],
    train_sample_ratio=CONFIG['train_sample_ratio'],
    test_sample_ratio=CONFIG['test_sample_ratio'],
    model_name=CONFIG['model_name'],
    to_lower=CONFIG['to_lower'],
    cache_dir=CONFIG['cache_dir'],
    max_len=CONFIG['max_len']
)

I1107 18:36:59.460421 140684753213184 tokenization_utils.py:374] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt from cache at ./temp/96435fa287fbf7e469185f1062386e05a075cadbf6838b74da22bf64b080bc32.99bcd55fc66f4f3360bc49ba472b940b8dcf223ea6a345deb969d607ca900729


## Fine Tune

Fine-tuning a Transformer model for text classification takes two steps: (1) instantiate the `SequenceClassifier` class, a wrapper around the Transformer model, and (2) call its `fit` method with the preprocessed training dataset to fine-tune the model.
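The `Iteration` progress bar that `fit` prints (`.../544` in the output below) presumably counts mini-batches in one epoch. Assuming one optimizer step per mini-batch, and that a final partial batch is kept rather than dropped, the step count follows directly from the dataset size and `batch_size`. `training_steps` is a hypothetical helper for this arithmetic, and the inputs are illustrative rather than taken from the run.

```python
import math


def training_steps(num_examples, batch_size, num_epochs):
    """Return (steps per epoch, total steps), assuming one optimizer step per
    mini-batch and that the last, possibly smaller, batch is kept."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch, steps_per_epoch * num_epochs


# Illustrative numbers only: 1000 examples, batches of 16, 5 epochs.
print(training_steps(num_examples=1000, batch_size=16, num_epochs=5))  # (63, 315)
```

Doubling `batch_size` roughly halves the number of steps per epoch, which is one reason the notebook raises it to 16 when a GPU is available.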

In [6]:
model = SequenceClassifier(
    model_name=CONFIG['model_name'],
    num_labels=len(label_encoder.classes_),
    cache_dir=CONFIG['cache_dir']
)

# Fine tune the model using the training dataset
with Timer() as t:
    model.fit(
        train_dataset=train_dataset,
        device=CONFIG['device'],
        num_epochs=CONFIG['num_train_epochs'],
        batch_size=CONFIG['batch_size'],
        verbose=CONFIG['verbose'],
        seed=CONFIG['random_seed']
    )

print("Training time : {:.3f} hrs".format(t.interval / 3600))

I1107 18:38:51.565629 140684753213184 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json from cache at ./temp/45629519f3117b89d89fd9c740073d8e4c1f0a70f9842476185100a8afe715d1.83b0fa3d7f1ac0e113ad300189a938c6f14d0588a4200f30eef109d0a047c484
I1107 18:38:51.566921 140684753213184 configuration_utils.py:168] Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 5,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_t

Loss:0.051570
Loss:0.033994
Loss:0.017920
Loss:0.012660
Loss:0.013456
Loss:0.023816
Loss:0.015286
Loss:0.003237
Loss:0.008068
Loss:0.004577
Loss:0.007717
Loss:0.008509
Loss:0.017689
Loss:0.005474
Loss:0.017315
Loss:0.009913
Loss:0.007046
Loss:0.010506
Loss:0.004304
Loss:0.004650
Loss:0.011627
Loss:0.013395
Loss:0.015131
Loss:0.011105
Loss:0.008719
Loss:0.008591
Loss:0.014355
Loss:0.011691
Loss:0.007512
Loss:0.013523
Loss:0.004399
Loss:0.007418
Loss:0.005675
Loss:0.011921
Iteration:  62%|██████▏   | 339/544 [08:20<04:56,  1.44s/it]
Epoch:   0%|          | 0/1 [08:22<?, ?it/s]                
Iteration:  62%|██████▎   | 340/544 [08:22<04:57,  1.46s/it][A

Loss:0.006466



Iteration:  63%|██████▎   | 341/544 [08:22<04:52,  1.44s/it][A
Iteration:  63%|██████▎   | 342/544 [08:24<04:54,  1.46s/it][A
Iteration:  63%|██████▎   | 343/544 [08:26<05:00,  1.49s/it][A
Iteration:  63%|██████▎   | 344/544 [08:27<05:08,  1.54s/it][A
Iteration:  63%|██████▎   | 345/544 [08:29<04:58,  1.50s/it][A
Iteration:  64%|██████▎   | 346/544 [08:30<04:57,  1.50s/it][A
Iteration:  64%|██████▍   | 347/544 [08:32<04:55,  1.50s/it][A
Iteration:  64%|██████▍   | 348/544 [08:33<04:58,  1.52s/it][A
Iteration:  64%|██████▍   | 349/544 [08:35<04:59,  1.54s/it][A
                                            4:51,  1.50s/it][A
Epoch:   0%|          | 0/1 [08:37<?, ?it/s]                
Iteration:  64%|██████▍   | 350/544 [08:37<04:51,  1.50s/it][A

Loss:0.009106



Iteration:  65%|██████▍   | 351/544 [08:38<04:45,  1.48s/it][A
Iteration:  65%|██████▍   | 352/544 [08:39<04:40,  1.46s/it][A
Iteration:  65%|██████▍   | 353/544 [08:41<04:44,  1.49s/it][A
Iteration:  65%|██████▌   | 354/544 [08:42<04:38,  1.47s/it][A
Iteration:  65%|██████▌   | 355/544 [08:43<04:39,  1.48s/it][A
Iteration:  65%|██████▌   | 356/544 [08:45<04:33,  1.45s/it][A
Iteration:  66%|██████▌   | 357/544 [08:46<04:29,  1.44s/it][A
Iteration:  66%|██████▌   | 358/544 [08:48<04:26,  1.43s/it][A
Iteration:  66%|██████▌   | 359/544 [08:49<04:32,  1.48s/it][A
                                            4:32,  1.48s/it][A
Epoch:   0%|          | 0/1 [08:51<?, ?it/s]                
Iteration:  66%|██████▌   | 360/544 [08:51<04:32,  1.48s/it][A

Loss:0.007258



Iteration:  66%|██████▋   | 361/544 [08:52<04:27,  1.46s/it][A
Iteration:  67%|██████▋   | 362/544 [08:54<04:36,  1.52s/it][A
Iteration:  67%|██████▋   | 363/544 [08:55<04:33,  1.51s/it][A
Iteration:  67%|██████▋   | 364/544 [08:57<04:26,  1.48s/it][A
Iteration:  67%|██████▋   | 365/544 [08:58<04:25,  1.48s/it][A
Iteration:  67%|██████▋   | 366/544 [09:00<04:25,  1.49s/it][A
Iteration:  67%|██████▋   | 367/544 [09:01<04:14,  1.44s/it][A
Iteration:  68%|██████▊   | 368/544 [09:03<04:24,  1.50s/it][A
Iteration:  68%|██████▊   | 369/544 [09:04<04:22,  1.50s/it][A
                                            4:20,  1.50s/it][A
Epoch:   0%|          | 0/1 [09:06<?, ?it/s]                
Iteration:  68%|██████▊   | 370/544 [09:06<04:20,  1.50s/it][A

Loss:0.010385



Iteration:  68%|██████▊   | 371/544 [09:07<04:14,  1.47s/it][A
Iteration:  68%|██████▊   | 372/544 [09:08<04:09,  1.45s/it][A
Iteration:  69%|██████▊   | 373/544 [09:10<04:14,  1.49s/it][A
Iteration:  69%|██████▉   | 374/544 [09:11<04:08,  1.46s/it][A
Iteration:  69%|██████▉   | 375/544 [09:13<04:08,  1.47s/it][A
Iteration:  69%|██████▉   | 376/544 [09:14<04:04,  1.46s/it][A
Iteration:  69%|██████▉   | 377/544 [09:16<04:00,  1.44s/it][A
Iteration:  69%|██████▉   | 378/544 [09:17<04:05,  1.48s/it][A
Iteration:  70%|██████▉   | 379/544 [09:19<04:00,  1.46s/it][A
                                            3:57,  1.45s/it][A
Epoch:   0%|          | 0/1 [09:21<?, ?it/s]                
Iteration:  70%|██████▉   | 380/544 [09:21<03:57,  1.45s/it][A

Loss:0.009564



Iteration:  70%|███████   | 381/544 [09:22<04:02,  1.49s/it][A
Iteration:  70%|███████   | 382/544 [09:23<03:56,  1.46s/it][A
Iteration:  70%|███████   | 383/544 [09:25<03:57,  1.47s/it][A
Iteration:  71%|███████   | 384/544 [09:26<04:01,  1.51s/it][A
Iteration:  71%|███████   | 385/544 [09:28<03:59,  1.50s/it][A
Iteration:  71%|███████   | 386/544 [09:29<03:48,  1.45s/it][A
Iteration:  71%|███████   | 387/544 [09:31<03:53,  1.49s/it][A
Iteration:  71%|███████▏  | 388/544 [09:32<03:56,  1.51s/it][A
Iteration:  72%|███████▏  | 389/544 [09:34<03:49,  1.48s/it][A
                                            3:52,  1.51s/it][A
Epoch:   0%|          | 0/1 [09:36<?, ?it/s]                
Iteration:  72%|███████▏  | 390/544 [09:36<03:52,  1.51s/it][A

Loss:0.011374



Iteration:  72%|███████▏  | 391/544 [09:37<03:47,  1.48s/it][A
Iteration:  72%|███████▏  | 392/544 [09:38<03:41,  1.46s/it][A
Iteration:  72%|███████▏  | 393/544 [09:40<03:42,  1.47s/it][A
Iteration:  72%|███████▏  | 394/544 [09:41<03:41,  1.48s/it][A
Iteration:  73%|███████▎  | 395/544 [09:42<03:37,  1.46s/it][A
Iteration:  73%|███████▎  | 396/544 [09:44<03:38,  1.47s/it][A
Iteration:  73%|███████▎  | 397/544 [09:45<03:37,  1.48s/it][A
Iteration:  73%|███████▎  | 398/544 [09:47<03:33,  1.46s/it][A
Iteration:  73%|███████▎  | 399/544 [09:48<03:29,  1.45s/it][A
                                            3:26,  1.43s/it][A
Epoch:   0%|          | 0/1 [09:50<?, ?it/s]                
Iteration:  74%|███████▎  | 400/544 [09:50<03:26,  1.43s/it][A

Loss:0.001646



Iteration:  74%|███████▎  | 401/544 [09:51<03:28,  1.46s/it][A
Iteration:  74%|███████▍  | 402/544 [09:53<03:24,  1.44s/it][A
Iteration:  74%|███████▍  | 403/544 [09:54<03:21,  1.43s/it][A
Iteration:  74%|███████▍  | 404/544 [09:55<03:22,  1.45s/it][A
Iteration:  74%|███████▍  | 405/544 [09:57<03:19,  1.44s/it][A
Iteration:  75%|███████▍  | 406/544 [09:58<03:20,  1.45s/it][A
Iteration:  75%|███████▍  | 407/544 [10:00<03:23,  1.49s/it][A
Iteration:  75%|███████▌  | 408/544 [10:01<03:19,  1.46s/it][A
Iteration:  75%|███████▌  | 409/544 [10:03<03:11,  1.42s/it][A
                                            3:13,  1.44s/it][A
Epoch:   0%|          | 0/1 [10:05<?, ?it/s]                
Iteration:  75%|███████▌  | 410/544 [10:05<03:13,  1.44s/it][A

Loss:0.011223



Iteration:  76%|███████▌  | 411/544 [10:06<03:14,  1.46s/it][A
Iteration:  76%|███████▌  | 412/544 [10:07<03:10,  1.44s/it][A
Iteration:  76%|███████▌  | 413/544 [10:08<03:07,  1.43s/it][A
Iteration:  76%|███████▌  | 414/544 [10:10<03:04,  1.42s/it][A
Iteration:  76%|███████▋  | 415/544 [10:11<03:02,  1.42s/it][A
Iteration:  76%|███████▋  | 416/544 [10:13<03:07,  1.46s/it][A
Iteration:  77%|███████▋  | 417/544 [10:14<03:12,  1.52s/it][A
Iteration:  77%|███████▋  | 418/544 [10:16<03:10,  1.51s/it][A
Iteration:  77%|███████▋  | 419/544 [10:18<03:14,  1.55s/it][A
                                            3:13,  1.56s/it][A
Epoch:   0%|          | 0/1 [10:20<?, ?it/s]                
Iteration:  77%|███████▋  | 420/544 [10:20<03:13,  1.56s/it][A

Loss:0.011658



Iteration:  77%|███████▋  | 421/544 [10:21<03:06,  1.52s/it][A
Iteration:  78%|███████▊  | 422/544 [10:22<03:04,  1.51s/it][A
Iteration:  78%|███████▊  | 423/544 [10:24<03:02,  1.51s/it][A
Iteration:  78%|███████▊  | 424/544 [10:25<02:57,  1.48s/it][A
Iteration:  78%|███████▊  | 425/544 [10:27<02:56,  1.48s/it][A
Iteration:  78%|███████▊  | 426/544 [10:28<02:52,  1.46s/it][A
Iteration:  78%|███████▊  | 427/544 [10:29<02:49,  1.45s/it][A
Iteration:  79%|███████▊  | 428/544 [10:31<02:49,  1.46s/it][A
Iteration:  79%|███████▉  | 429/544 [10:32<02:54,  1.52s/it][A
                                            2:52,  1.51s/it][A
Epoch:   0%|          | 0/1 [10:35<?, ?it/s]                
Iteration:  79%|███████▉  | 430/544 [10:35<02:52,  1.51s/it][A

Loss:0.006936



Iteration:  79%|███████▉  | 431/544 [10:36<02:55,  1.55s/it][A
Iteration:  79%|███████▉  | 432/544 [10:37<02:55,  1.56s/it][A
Iteration:  80%|███████▉  | 433/544 [10:39<02:51,  1.54s/it][A
Iteration:  80%|███████▉  | 434/544 [10:40<02:48,  1.53s/it][A
Iteration:  80%|███████▉  | 435/544 [10:42<02:43,  1.50s/it][A
Iteration:  80%|████████  | 436/544 [10:43<02:43,  1.52s/it][A
Iteration:  80%|████████  | 437/544 [10:45<02:38,  1.48s/it][A
Iteration:  81%|████████  | 438/544 [10:46<02:35,  1.46s/it][A
Iteration:  81%|████████  | 439/544 [10:48<02:37,  1.50s/it][A
                                            2:33,  1.47s/it][A
Epoch:   0%|          | 0/1 [10:50<?, ?it/s]                
Iteration:  81%|████████  | 440/544 [10:50<02:33,  1.47s/it][A

Loss:0.010197



Iteration:  81%|████████  | 441/544 [10:51<02:32,  1.48s/it][A
Iteration:  81%|████████▏ | 442/544 [10:52<02:31,  1.48s/it][A
Iteration:  81%|████████▏ | 443/544 [10:53<02:30,  1.49s/it][A
Iteration:  82%|████████▏ | 444/544 [10:55<02:26,  1.47s/it][A
Iteration:  82%|████████▏ | 445/544 [10:56<02:25,  1.47s/it][A
Iteration:  82%|████████▏ | 446/544 [10:58<02:22,  1.45s/it][A
Iteration:  82%|████████▏ | 447/544 [10:59<02:19,  1.44s/it][A
Iteration:  82%|████████▏ | 448/544 [11:01<02:22,  1.48s/it][A
Iteration:  83%|████████▎ | 449/544 [11:02<02:20,  1.48s/it][A
                                            2:17,  1.46s/it][A
Epoch:   0%|          | 0/1 [11:04<?, ?it/s]                
Iteration:  83%|████████▎ | 450/544 [11:04<02:17,  1.46s/it][A

Loss:0.017916



Iteration:  83%|████████▎ | 451/544 [11:05<02:14,  1.45s/it][A
Iteration:  83%|████████▎ | 452/544 [11:07<02:16,  1.49s/it][A
Iteration:  83%|████████▎ | 453/544 [11:08<02:19,  1.54s/it][A
Iteration:  83%|████████▎ | 454/544 [11:10<02:16,  1.52s/it][A
Iteration:  84%|████████▎ | 455/544 [11:11<02:12,  1.49s/it][A
Iteration:  84%|████████▍ | 456/544 [11:13<02:17,  1.56s/it][A
Iteration:  84%|████████▍ | 457/544 [11:14<02:13,  1.54s/it][A
Iteration:  84%|████████▍ | 458/544 [11:16<02:08,  1.50s/it][A
Iteration:  84%|████████▍ | 459/544 [11:18<02:13,  1.57s/it][A
                                            2:11,  1.57s/it][A
Epoch:   0%|          | 0/1 [11:20<?, ?it/s]                
Iteration:  85%|████████▍ | 460/544 [11:20<02:11,  1.57s/it][A

Loss:0.006646



Iteration:  85%|████████▍ | 461/544 [11:21<02:06,  1.52s/it][A
Iteration:  85%|████████▍ | 462/544 [11:22<02:04,  1.52s/it][A
Iteration:  85%|████████▌ | 463/544 [11:24<02:03,  1.53s/it][A
Iteration:  85%|████████▌ | 464/544 [11:25<01:59,  1.49s/it][A
Iteration:  85%|████████▌ | 465/544 [11:26<01:55,  1.46s/it][A
Iteration:  86%|████████▌ | 466/544 [11:28<01:52,  1.45s/it][A
Iteration:  86%|████████▌ | 467/544 [11:29<01:54,  1.49s/it][A
Iteration:  86%|████████▌ | 468/544 [11:31<01:52,  1.49s/it][A
Iteration:  86%|████████▌ | 469/544 [11:32<01:53,  1.51s/it][A
                                            1:49,  1.48s/it][A
Epoch:   0%|          | 0/1 [11:35<?, ?it/s]                
Iteration:  86%|████████▋ | 470/544 [11:35<01:49,  1.48s/it][A

Loss:0.003486



Iteration:  87%|████████▋ | 471/544 [11:35<01:48,  1.49s/it][A
Iteration:  87%|████████▋ | 472/544 [11:37<01:45,  1.46s/it][A
Iteration:  87%|████████▋ | 473/544 [11:38<01:49,  1.54s/it][A
Iteration:  87%|████████▋ | 474/544 [11:40<01:47,  1.53s/it][A
Iteration:  87%|████████▋ | 475/544 [11:41<01:42,  1.49s/it][A
Iteration:  88%|████████▊ | 476/544 [11:43<01:39,  1.47s/it][A
Iteration:  88%|████████▊ | 477/544 [11:44<01:38,  1.48s/it][A
Iteration:  88%|████████▊ | 478/544 [11:46<01:39,  1.50s/it][A
Iteration:  88%|████████▊ | 479/544 [11:47<01:37,  1.50s/it][A
                                            1:36,  1.50s/it][A
Epoch:   0%|          | 0/1 [11:50<?, ?it/s]                
Iteration:  88%|████████▊ | 480/544 [11:50<01:36,  1.50s/it][A

Loss:0.008434



Iteration:  88%|████████▊ | 481/544 [11:50<01:35,  1.52s/it][A
Iteration:  89%|████████▊ | 482/544 [11:52<01:35,  1.54s/it][A
Iteration:  89%|████████▉ | 483/544 [11:53<01:31,  1.50s/it][A
Iteration:  89%|████████▉ | 484/544 [11:55<01:29,  1.49s/it][A
Iteration:  89%|████████▉ | 485/544 [11:56<01:26,  1.47s/it][A
Iteration:  89%|████████▉ | 486/544 [11:58<01:28,  1.52s/it][A
Iteration:  90%|████████▉ | 487/544 [11:59<01:24,  1.49s/it][A
Iteration:  90%|████████▉ | 488/544 [12:01<01:22,  1.47s/it][A
Iteration:  90%|████████▉ | 489/544 [12:02<01:21,  1.48s/it][A
                                            1:22,  1.53s/it][A
Epoch:   0%|          | 0/1 [12:05<?, ?it/s]                
Iteration:  90%|█████████ | 490/544 [12:05<01:22,  1.53s/it][A

Loss:0.015265



Iteration:  90%|█████████ | 491/544 [12:05<01:20,  1.52s/it][A
Iteration:  90%|█████████ | 492/544 [12:07<01:18,  1.51s/it][A
Iteration:  91%|█████████ | 493/544 [12:08<01:15,  1.48s/it][A
Iteration:  91%|█████████ | 494/544 [12:10<01:12,  1.45s/it][A
Iteration:  91%|█████████ | 495/544 [12:11<01:12,  1.47s/it][A
Iteration:  91%|█████████ | 496/544 [12:13<01:14,  1.55s/it][A
Iteration:  91%|█████████▏| 497/544 [12:15<01:14,  1.58s/it][A
Iteration:  92%|█████████▏| 498/544 [12:16<01:09,  1.50s/it][A
Iteration:  92%|█████████▏| 499/544 [12:17<01:07,  1.50s/it][A
                                            1:05,  1.50s/it][A
Epoch:   0%|          | 0/1 [12:19<?, ?it/s]                
Iteration:  92%|█████████▏| 500/544 [12:19<01:05,  1.50s/it][A

Loss:0.001886



Iteration:  92%|█████████▏| 501/544 [12:20<01:03,  1.47s/it][A
Iteration:  92%|█████████▏| 502/544 [12:22<01:00,  1.45s/it][A
Iteration:  92%|█████████▏| 503/544 [12:23<01:00,  1.47s/it][A
Iteration:  93%|█████████▎| 504/544 [12:25<00:58,  1.47s/it][A
Iteration:  93%|█████████▎| 505/544 [12:26<00:58,  1.51s/it][A
Iteration:  93%|█████████▎| 506/544 [12:28<00:56,  1.48s/it][A
Iteration:  93%|█████████▎| 507/544 [12:29<00:56,  1.53s/it][A
Iteration:  93%|█████████▎| 508/544 [12:31<00:55,  1.55s/it][A
Iteration:  94%|█████████▎| 509/544 [12:32<00:52,  1.51s/it][A
                                            0:52,  1.55s/it][A
Epoch:   0%|          | 0/1 [12:35<?, ?it/s]                
Iteration:  94%|█████████▍| 510/544 [12:35<00:52,  1.55s/it][A

Loss:0.004752



Iteration:  94%|█████████▍| 511/544 [12:35<00:49,  1.51s/it][A
Iteration:  94%|█████████▍| 512/544 [12:37<00:49,  1.55s/it][A
Iteration:  94%|█████████▍| 513/544 [12:39<00:46,  1.51s/it][A
Iteration:  94%|█████████▍| 514/544 [12:40<00:44,  1.48s/it][A
Iteration:  95%|█████████▍| 515/544 [12:41<00:42,  1.48s/it][A
Iteration:  95%|█████████▍| 516/544 [12:43<00:41,  1.48s/it][A
Iteration:  95%|█████████▌| 517/544 [12:44<00:39,  1.46s/it][A
Iteration:  95%|█████████▌| 518/544 [12:46<00:37,  1.45s/it][A
Iteration:  95%|█████████▌| 519/544 [12:47<00:36,  1.46s/it][A
                                            0:34,  1.44s/it][A
Epoch:   0%|          | 0/1 [12:49<?, ?it/s]                
Iteration:  96%|█████████▌| 520/544 [12:49<00:34,  1.44s/it][A

Loss:0.009206



Iteration:  96%|█████████▌| 521/544 [12:50<00:32,  1.41s/it][A
Iteration:  96%|█████████▌| 522/544 [12:51<00:30,  1.41s/it][A
Iteration:  96%|█████████▌| 523/544 [12:53<00:31,  1.48s/it][A
Iteration:  96%|█████████▋| 524/544 [12:55<00:30,  1.51s/it][A
Iteration:  97%|█████████▋| 525/544 [12:56<00:28,  1.48s/it][A
Iteration:  97%|█████████▋| 526/544 [12:57<00:26,  1.46s/it][A
Iteration:  97%|█████████▋| 527/544 [12:59<00:25,  1.51s/it][A
Iteration:  97%|█████████▋| 528/544 [13:01<00:24,  1.51s/it][A
Iteration:  97%|█████████▋| 529/544 [13:02<00:22,  1.50s/it][A
                                            0:20,  1.50s/it][A
Epoch:   0%|          | 0/1 [13:04<?, ?it/s]                
Iteration:  97%|█████████▋| 530/544 [13:04<00:20,  1.50s/it][A

Loss:0.004459



Iteration:  98%|█████████▊| 531/544 [13:05<00:19,  1.47s/it][A
Iteration:  98%|█████████▊| 532/544 [13:06<00:17,  1.42s/it][A
Iteration:  98%|█████████▊| 533/544 [13:08<00:16,  1.49s/it][A
Iteration:  98%|█████████▊| 534/544 [13:09<00:15,  1.52s/it][A
Iteration:  98%|█████████▊| 535/544 [13:11<00:13,  1.48s/it][A
Iteration:  99%|█████████▊| 536/544 [13:12<00:11,  1.43s/it][A
Iteration:  99%|█████████▊| 537/544 [13:14<00:10,  1.47s/it][A
Iteration:  99%|█████████▉| 538/544 [13:15<00:08,  1.45s/it][A
Iteration:  99%|█████████▉| 539/544 [13:17<00:07,  1.49s/it][A
                                            0:05,  1.46s/it][A
Epoch:   0%|          | 0/1 [13:19<?, ?it/s]                
Iteration:  99%|█████████▉| 540/544 [13:19<00:05,  1.46s/it][A

Loss:0.001346



Iteration:  99%|█████████▉| 541/544 [13:20<00:04,  1.48s/it][A
Iteration: 100%|█████████▉| 542/544 [13:21<00:02,  1.48s/it][A
Iteration: 100%|█████████▉| 543/544 [13:23<00:01,  1.49s/it][A
Epoch: 100%|██████████| 1/1 [13:25<00:00, 805.55s/it]54s/it][A

Training time : 0.224 hrs





## Evaluate on Testing Dataset

The `predict` method of the `SequenceClassifier` returns a NumPy ndarray of predictions in which each value is a label ID. To get the corresponding label values, call the `get_label_values` function from the dataset package. The sklearn `LabelEncoder` instance returned when loading the dataset can also be used to map between label IDs and label values.
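As a minimal sketch of that mapping, the snippet below relies only on scikit-learn's `LabelEncoder`; the label strings and IDs are made up for illustration, and the notebook's `get_label_values` helper is not used here.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Fit an encoder on made-up label strings; LabelEncoder assigns IDs in sorted order.
label_encoder = LabelEncoder()
label_encoder.fit(["sports", "culture", "economy", "politics", "diverse"])

# Made-up label IDs standing in for the ndarray returned by `predict`.
preds = np.array([4, 0, 3, 3, 1])

# Map each predicted ID back to its label value.
pred_labels = label_encoder.inverse_transform(preds).tolist()

# Full ID -> label value lookup table.
id_to_label = dict(enumerate(label_encoder.classes_.tolist()))

print(pred_labels)  # ['sports', 'culture', 'politics', 'politics', 'diverse']
print(id_to_label)
```

Because `classes_` is sorted, ID 0 is the alphabetically first label; `label_encoder.transform` goes in the opposite direction, from label values to IDs.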

In [7]:
# Predict label IDs for the test set and time the run
with Timer() as t:
    preds = model.predict(
        eval_dataset=test_dataset,
        device=CONFIG['device'],
        batch_size=CONFIG['batch_size'],
        verbose=CONFIG['verbose']
    )

print("Prediction time : {:.3f} hrs".format(t.interval / 3600))

Evaluating: 100%|██████████| 272/272 [01:36<00:00,  3.18it/s]

Prediction time : 0.027 hrs
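The cell above prints `t.interval / 3600`, which converts elapsed seconds to hours: the 01:36 evaluation bar (96 s) corresponds to 96 / 3600 ≈ 0.027 hrs, and the 13:25 training epoch (805.55 s) to ≈ 0.224 hrs. The sketch below is a hypothetical stand-in that assumes the timer stores wall-clock seconds in an `interval` attribute on exit; the actual `utils_nlp.common.timer.Timer` may be implemented differently.

```python
import time


class SketchTimer:
    """Hypothetical context-manager timer; not the utils_nlp implementation."""

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Elapsed wall-clock time in seconds.
        self.interval = time.perf_counter() - self._start
        return False  # do not suppress exceptions raised inside the block


def seconds_to_hours(seconds):
    return seconds / 3600


with SketchTimer() as t:
    time.sleep(0.01)  # stand-in for model.predict(...)

print("Elapsed : {:.3f} hrs".format(seconds_to_hours(t.interval)))
print("{:.3f}".format(seconds_to_hours(96)))  # 96 s of evaluation -> 0.027
```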





Finally, we compute precision, recall, and F1 score on the test set for each class, along with their averages.

In [8]:
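import numpy as np

# Compare the ground-truth label IDs stored in the test dataset with the predictions
report = classification_report(
    test_dataset.tensors[2],
    preds,
    digits=2,
    labels=np.unique(test_dataset.tensors[2]),
    target_names=label_encoder.classes_
)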

print(report)

              precision    recall  f1-score   support

     culture       0.94      0.93      0.94       554
     diverse       0.95      0.95      0.95       679
     economy       0.87      0.89      0.88       543
    politics       0.88      0.89      0.89       796
      sports       0.99      0.98      0.99      1780

   micro avg       0.94      0.94      0.94      4352
   macro avg       0.93      0.93      0.93      4352
weighted avg       0.94      0.94      0.94      4352
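Each class row reports precision TP / (TP + FP), recall TP / (TP + FN), and F1 (their harmonic mean), with support being the number of true examples of that class. The macro average weights every class equally, the weighted average weights each class by its support, and for single-label multiclass data the micro-averaged precision, recall, and F1 all equal accuracy (which is why the `micro avg` row shows 0.94 three times). The sketch below recomputes these quantities from a small, made-up set of labels (not the notebook's data) so they can be checked against scikit-learn's `precision_recall_fscore_support`.

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Made-up ground-truth and predicted label IDs (illustrative only).
y_true = np.array([0, 0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 0, 1, 1, 2, 2])
classes = np.unique(y_true)

precision, recall, f1, support = [], [], [], []
for c in classes:
    tp = np.sum((y_pred == c) & (y_true == c))  # predicted c, truly c
    fp = np.sum((y_pred == c) & (y_true != c))  # predicted c, truly something else
    fn = np.sum((y_pred != c) & (y_true == c))  # truly c, predicted something else
    p = tp / (tp + fp)
    r = tp / (tp + fn)
    precision.append(p)
    recall.append(r)
    f1.append(2 * p * r / (p + r))
    support.append(np.sum(y_true == c))

f1 = np.array(f1)
support = np.array(support)

macro_f1 = f1.mean()                                # every class counts equally
weighted_f1 = np.sum(f1 * support) / support.sum()  # classes weighted by support
accuracy = np.mean(y_true == y_pred)                # equals micro-averaged P, R, F1 here

print("per-class F1:", np.round(f1, 3))
print("macro F1: {:.3f}, weighted F1: {:.3f}, accuracy: {:.3f}".format(
    macro_f1, weighted_f1, accuracy))
# macro F1: 0.675, weighted F1: 0.728, accuracy: 0.714
```

Here the frequent class 0 has the highest F1, so the weighted average comes out above the macro average; in the report above the same effect shows up through the large `sports` class.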



In [9]:
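# For automated testing: record the weighted-average metrics with scrapbook.
# The last non-empty line of the report is the "weighted avg" row.
report_splits = report.split('\n')[-2].split()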

sb.glue("precision", float(report_splits[2]))
sb.glue("recall", float(report_splits[3]))
sb.glue("f1", float(report_splits[4]))
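Reading the metrics by position in the printed report depends on its exact text layout, which has changed across scikit-learn releases (recent versions, for instance, print an `accuracy` row instead of `micro avg` for single-label multiclass data). As a hedged alternative, `classification_report` in scikit-learn 0.20 and later accepts `output_dict=True` and returns the same numbers as a nested dictionary. The labels and class names below are made-up stand-ins for `test_dataset.tensors[2]`, `preds`, and `label_encoder.classes_`.

```python
from sklearn.metrics import classification_report

# Made-up label IDs and class names standing in for the notebook's test data.
y_true = [0, 0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 0, 1, 1, 2, 2]
class_names = ["culture", "economy", "sports"]

report_dict = classification_report(
    y_true, y_pred, target_names=class_names, output_dict=True
)

# The same weighted-average row the cell above parses from text, but unrounded;
# round to 2 digits to mirror `digits=2`.
weighted = report_dict["weighted avg"]
precision = round(weighted["precision"], 2)
recall = round(weighted["recall"], 2)
f1 = round(weighted["f1-score"], 2)
print(precision, recall, f1)
```

The rounded values could then be passed to `sb.glue` exactly as in the cell above, without depending on how the text report is laid out.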