<a href="https://colab.research.google.com/github/dlmacedo/starter-academic/blob/master/content/courses/deeplearning/notebooks/pytorch/example_fine_tuning_hf_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning a deep-learning model with HuggingFace `🤗Datasets` library and PyTorch

Based on the 🤗 Datasets [Quick tour example](https://huggingface.co/docs/datasets/quicktour.html).

Colab author: Manuel Romero / [@mrm8488](https://twitter.com/mrm8488)

In [1]:
!nvidia-smi

Sun Oct  4 02:06:46 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.05    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   62C    P8    12W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install -q datasets
!pip install -q transformers


In [3]:
# Make sure the session has a recent pyarrow (0.16 or newer); otherwise kill the Colab process so the runtime restarts with the freshly installed version
import os
import pyarrow
major, minor = (int(part) for part in pyarrow.__version__.split('.')[:2])
if major == 0 and minor < 16:
    os.kill(os.getpid(), 9)
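
The check above restarts the runtime only when the loaded pyarrow is a 0.x release older than 0.16. As a minimal sketch, the same comparison can be written as a small function that is easy to try on a few version strings; the name `needs_restart` is hypothetical and not part of the notebook.

```python
def needs_restart(version: str) -> bool:
    """Return True when `version` is a 0.x pyarrow release older than 0.16.

    Hypothetical helper mirroring the notebook's inline check; only the
    major and minor components are compared.
    """
    major, minor = (int(part) for part in version.split('.')[:2])
    return major == 0 and minor < 16


print(needs_restart('0.14.1'))  # True: too old, the runtime would be killed
print(needs_restart('0.16.0'))  # False
print(needs_restart('1.0.1'))   # False: any 1.x release passes
```

Comparing the major component first matters: a plain `minor < 16` test would wrongly flag a release such as 1.0.1.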

### Load the model and the tokenizer

In [4]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b


### Load and process the dataset

In [5]:
from datasets import load_dataset
dataset = load_dataset('glue', 'mrpc', split='train')

Downloading and preparing dataset glue/mrpc (download: 1.43 MiB, generated: 1.43 MiB, post-processed: Unknown size, total: 2.85 MiB) to /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4...


Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4. Subsequent calls will reuse this data.


In [6]:
len(dataset)

3668

In [7]:
dataset[0]

{'idx': 0,
 'label': 1,
 'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .',
 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .'}

In [8]:
dataset.features

{'idx': Value(dtype='int32', id=None),
 'label': ClassLabel(num_classes=2, names=['not_equivalent', 'equivalent'], names_file=None, id=None),
 'sentence1': Value(dtype='string', id=None),
 'sentence2': Value(dtype='string', id=None)}

In [9]:
# dataset.filter(lambda example: example['label'] == dataset.features['label'].str2int('equivalent'))[0]

In [10]:
# dataset.filter(lambda example: example['label'] == dataset.features['label'].str2int('not_equivalent'))[0]
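
The commented-out filters rely on `str2int` to turn a label name into its integer id. The sketch below imitates that lookup in plain Python using the `names` list printed by `dataset.features`. The functions `str2int` and `int2str` here are stand-ins written for illustration, not the library's implementation, and the label list is the one that appears in the first batch printed in the formatting step.

```python
# Label names copied from the ClassLabel feature printed above.
NAMES = ['not_equivalent', 'equivalent']


def str2int(name: str) -> int:
    """Stand-in for ClassLabel.str2int: position of the name in NAMES."""
    return NAMES.index(name)


def int2str(label: int) -> str:
    """Stand-in for ClassLabel.int2str: name stored at that position."""
    return NAMES[label]


# Labels of the first eight training examples (first batch of the DataLoader).
labels = [1, 0, 1, 0, 1, 1, 0, 1]

# Same predicate as the commented-out filter, applied to a plain list.
equivalent_rows = [i for i, label in enumerate(labels)
                   if label == str2int('equivalent')]
print(equivalent_rows)  # [0, 2, 4, 5, 7]
print(int2str(0))       # not_equivalent
```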

#### Tokenizing the dataset

In [11]:
def encode(examples):
    return tokenizer(examples['sentence1'], examples['sentence2'], truncation=True, padding='max_length')
dataset = dataset.map(encode, batched=True)
dataset[0]


{'attention_mask': [1, 1, 1, ..., 0, 0, 0, ...
(output truncated: 52 ones for the real tokens, followed by zeros for the padding positions)
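
With `padding='max_length'`, every encoded sentence pair is padded to a fixed length, and the attention mask marks real tokens with 1 and padding with 0. Below is a minimal sketch of that bookkeeping in plain Python. It assumes a pad id of 0 (consistent with the trailing zeros in the `input_ids` printed in the formatting step) and uses a hypothetical helper name, `pad_to_max_length`; it is not the tokenizer's own code.

```python
def pad_to_max_length(token_ids, max_length, pad_id=0):
    """Pad (or truncate) a list of token ids and build its attention mask."""
    kept = token_ids[:max_length]          # rough analogue of truncation=True
    n_pad = max_length - len(kept)
    input_ids = kept + [pad_id] * n_pad
    attention_mask = [1] * len(kept) + [0] * n_pad
    return input_ids, attention_mask


# Toy input: the first three ids come from the first row of the batch shown
# in the formatting step; 102 is an assumed separator id for illustration.
ids, mask = pad_to_max_length([101, 7277, 2180, 102], max_length=8)
print(ids)   # [101, 7277, 2180, 102, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

The mask lets the model ignore the padded positions, which is why it is passed along with `input_ids` during training.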
  

#### Formatting the dataset

In [12]:
dataset = dataset.map(lambda examples: {'labels': examples['label']}, batched=True)
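
With `batched=True`, the function passed to `map` receives a dictionary of column lists rather than a single example, and returns new columns of matching length. The following is a simplified plain-Python imitation of that contract, not the library's implementation; `batched_map` and its `batch_size` argument are illustrative names.

```python
def batched_map(columns, fn, batch_size=1000):
    """Apply `fn` to slices of a dict-of-lists table and merge the new columns."""
    n_rows = len(next(iter(columns.values())))
    result = {name: list(values) for name, values in columns.items()}
    for start in range(0, n_rows, batch_size):
        # Each batch is itself a dict of lists, one list per column.
        batch = {name: values[start:start + batch_size]
                 for name, values in columns.items()}
        for name, values in fn(batch).items():
            if start == 0:
                result[name] = []          # create (or overwrite) the column
            result[name].extend(values)
    return result


table = {'label': [1, 0, 1, 0, 1]}
table = batched_map(table, lambda examples: {'labels': examples['label']},
                    batch_size=2)
print(table)  # {'label': [1, 0, 1, 0, 1], 'labels': [1, 0, 1, 0, 1]}
```

The copied `labels` column matters because the training loop calls `model(**batch)`, so the selected column names become keyword arguments of the model.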


In [13]:
import torch
dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
next(iter(dataloader))


{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[  101,  7277,  2180,  ...,     0,     0,     0],
         [  101, 10684,  2599,  ...,     0,     0,     0],
         [  101,  1220,  1125,  ...,     0,     0,     0],
         ...,
         [  101, 16944,  1107,  ...,     0,     0,     0],
         [  101,  1109, 11896,  ...,     0,     0,     0],
         [  101,  1109,  4173,  ...,     0,     0,     0]]),
 'labels': tensor([1, 0, 1, 0, 1, 1, 0, 1]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]])}
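
The progress bars in the training step count 459 iterations per epoch. That figure follows from the dataset size and the batch size, assuming the `DataLoader` keeps the final, smaller batch (its default behaviour). A quick check:

```python
import math

num_examples = 3668   # len(dataset) reported above
batch_size = 8        # batch_size passed to the DataLoader above

num_batches = math.ceil(num_examples / batch_size)
last_batch = num_examples - (num_batches - 1) * batch_size

print(num_batches)  # 459
print(last_batch)   # 4
```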

#### Training

In [14]:
from tqdm import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.train().to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

In [15]:
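for epoch in range(3):
    for i, batch in enumerate(tqdm(dataloader)):
        # Move every tensor in the batch to the same device as the model
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]  # the loss is the first output because `labels` is passed
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % 10 == 0:
            print(f"loss: {loss.item()}")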

  0%|          | 1/459 [00:00<07:23,  1.03it/s]

loss: 0.7172756195068359


  2%|▏         | 11/459 [00:08<05:54,  1.26it/s]

loss: 0.7296287417411804


  5%|▍         | 21/459 [00:16<05:51,  1.25it/s]

loss: 0.5364887118339539


  7%|▋         | 31/459 [00:24<05:48,  1.23it/s]

loss: 0.5236343145370483


  9%|▉         | 41/459 [00:32<05:45,  1.21it/s]

loss: 0.625078558921814


 11%|█         | 51/459 [00:41<05:45,  1.18it/s]

loss: 0.5091382265090942


 13%|█▎        | 61/459 [00:49<05:44,  1.15it/s]

loss: 0.6293755769729614


 15%|█▌        | 71/459 [00:58<05:44,  1.13it/s]

loss: 0.711475133895874


 18%|█▊        | 81/459 [01:07<05:37,  1.12it/s]

loss: 0.7263116240501404


 20%|█▉        | 91/459 [01:16<05:23,  1.14it/s]

loss: 0.6979309320449829


 22%|██▏       | 101/459 [01:25<05:07,  1.16it/s]

loss: 0.4606996178627014


 24%|██▍       | 111/459 [01:33<04:55,  1.18it/s]

loss: 0.49716347455978394


 26%|██▋       | 121/459 [01:41<04:44,  1.19it/s]

loss: 0.5128161311149597


 29%|██▊       | 131/459 [01:50<04:35,  1.19it/s]

loss: 0.6707112789154053


 31%|███       | 141/459 [01:58<04:28,  1.19it/s]

loss: 0.49556922912597656


 33%|███▎      | 151/459 [02:07<04:21,  1.18it/s]

loss: 0.5270639657974243


 35%|███▌      | 161/459 [02:15<04:16,  1.16it/s]

loss: 0.7838438749313354


 37%|███▋      | 171/459 [02:24<04:09,  1.15it/s]

loss: 0.6588656902313232


 39%|███▉      | 181/459 [02:32<04:01,  1.15it/s]

loss: 0.6913950443267822


 42%|████▏     | 191/459 [02:41<03:53,  1.15it/s]

loss: 0.7369705438613892


 44%|████▍     | 201/459 [02:50<03:44,  1.15it/s]

loss: 0.3223009705543518


 46%|████▌     | 211/459 [02:58<03:34,  1.16it/s]

loss: 0.2625649571418762


 48%|████▊     | 221/459 [03:07<03:24,  1.16it/s]

loss: 0.48325517773628235


 50%|█████     | 231/459 [03:15<03:15,  1.17it/s]

loss: 0.30013909935951233


 53%|█████▎    | 241/459 [03:24<03:06,  1.17it/s]

loss: 0.7361812591552734


 55%|█████▍    | 251/459 [03:33<02:58,  1.17it/s]

loss: 0.7147536277770996


 57%|█████▋    | 261/459 [03:41<02:49,  1.17it/s]

loss: 0.7271447777748108


 59%|█████▉    | 271/459 [03:50<02:41,  1.16it/s]

loss: 0.6186606287956238


 61%|██████    | 281/459 [03:58<02:34,  1.15it/s]

loss: 0.37796056270599365


 63%|██████▎   | 291/459 [04:07<02:26,  1.15it/s]

loss: 0.32348760962486267


 66%|██████▌   | 301/459 [04:16<02:16,  1.16it/s]

loss: 0.6743075847625732


 68%|██████▊   | 311/459 [04:24<02:08,  1.15it/s]

loss: 0.3901532292366028


 70%|██████▉   | 321/459 [04:33<01:58,  1.16it/s]

loss: 0.3191572427749634


 72%|███████▏  | 331/459 [04:41<01:49,  1.16it/s]

loss: 0.30092623829841614


 74%|███████▍  | 341/459 [04:50<01:41,  1.17it/s]

loss: 0.24516957998275757


 76%|███████▋  | 351/459 [04:58<01:32,  1.17it/s]

loss: 0.5593866109848022


 79%|███████▊  | 361/459 [05:07<01:23,  1.17it/s]

loss: 0.1414974182844162


 81%|████████  | 371/459 [05:15<01:15,  1.17it/s]

loss: 0.4049012362957001


 83%|████████▎ | 381/459 [05:24<01:06,  1.17it/s]

loss: 0.38165226578712463


 85%|████████▌ | 391/459 [05:32<00:58,  1.16it/s]

loss: 0.4822346866130829


 87%|████████▋ | 401/459 [05:41<00:49,  1.17it/s]

loss: 0.7872997522354126


 90%|████████▉ | 411/459 [05:50<00:41,  1.17it/s]

loss: 0.44195812940597534


 92%|█████████▏| 421/459 [05:58<00:32,  1.16it/s]

loss: 0.3639391362667084


 94%|█████████▍| 431/459 [06:07<00:24,  1.16it/s]

loss: 0.2064320296049118


 96%|█████████▌| 441/459 [06:15<00:15,  1.16it/s]

loss: 0.6424264311790466


 98%|█████████▊| 451/459 [06:24<00:06,  1.16it/s]

loss: 0.5213192105293274


100%|██████████| 459/459 [06:30<00:00,  1.17it/s]
  0%|          | 1/459 [00:00<06:39,  1.15it/s]

loss: 0.8455681800842285


  2%|▏         | 11/459 [00:09<06:26,  1.16it/s]

loss: 0.7503973245620728


  5%|▍         | 21/459 [00:18<06:17,  1.16it/s]

loss: 0.38280510902404785


  7%|▋         | 31/459 [00:26<06:09,  1.16it/s]

loss: 0.22022800147533417


  9%|▉         | 41/459 [00:35<06:01,  1.16it/s]

loss: 0.4108743667602539


 11%|█         | 51/459 [00:43<05:53,  1.15it/s]

loss: 0.20817667245864868


 13%|█▎        | 61/459 [00:52<05:44,  1.16it/s]

loss: 0.8382773995399475


 15%|█▌        | 71/459 [01:01<05:36,  1.15it/s]

loss: 0.5222634077072144


 18%|█▊        | 81/459 [01:09<05:26,  1.16it/s]

loss: 0.44198328256607056


 20%|█▉        | 91/459 [01:18<05:17,  1.16it/s]

loss: 0.5540165305137634


 22%|██▏       | 101/459 [01:26<05:09,  1.16it/s]

loss: 0.6011371612548828


 24%|██▍       | 111/459 [01:35<05:00,  1.16it/s]

loss: 0.2989555895328522


 26%|██▋       | 121/459 [01:44<04:50,  1.16it/s]

loss: 0.2866935133934021


 29%|██▊       | 131/459 [01:52<04:43,  1.16it/s]

loss: 0.3875500559806824


 31%|███       | 141/459 [02:01<04:34,  1.16it/s]

loss: 0.502983570098877


 33%|███▎      | 151/459 [02:09<04:25,  1.16it/s]

loss: 0.2505875825881958


 35%|███▌      | 161/459 [02:18<04:17,  1.16it/s]

loss: 0.7445882558822632


 37%|███▋      | 171/459 [02:26<04:08,  1.16it/s]

loss: 0.2723318934440613


 39%|███▉      | 181/459 [02:35<03:59,  1.16it/s]

loss: 0.6344541311264038


 42%|████▏     | 191/459 [02:44<03:51,  1.16it/s]

loss: 0.4550623595714569


 44%|████▍     | 201/459 [02:52<03:43,  1.15it/s]

loss: 0.31510889530181885


 46%|████▌     | 211/459 [03:01<03:34,  1.16it/s]

loss: 0.13264907896518707


 48%|████▊     | 221/459 [03:09<03:25,  1.16it/s]

loss: 0.24630028009414673


 50%|█████     | 231/459 [03:18<03:16,  1.16it/s]

loss: 0.46723228693008423


 53%|█████▎    | 241/459 [03:27<03:07,  1.16it/s]

loss: 0.44160282611846924


 55%|█████▍    | 251/459 [03:35<02:58,  1.17it/s]

loss: 0.2970876097679138


 57%|█████▋    | 261/459 [03:44<02:50,  1.16it/s]

loss: 0.08937611430883408


 59%|█████▉    | 271/459 [03:52<02:41,  1.16it/s]

loss: 0.3178607225418091


 61%|██████    | 281/459 [04:01<02:32,  1.16it/s]

loss: 0.1356995701789856


 63%|██████▎   | 291/459 [04:09<02:24,  1.16it/s]

loss: 0.16187003254890442


 66%|██████▌   | 301/459 [04:18<02:16,  1.16it/s]

loss: 0.24016517400741577


 68%|██████▊   | 311/459 [04:27<02:07,  1.16it/s]

loss: 0.094879150390625


 70%|██████▉   | 321/459 [04:35<01:59,  1.16it/s]

loss: 0.3848307132720947


 72%|███████▏  | 331/459 [04:44<01:50,  1.16it/s]

loss: 0.08922835439443588


 74%|███████▍  | 341/459 [04:52<01:41,  1.16it/s]

loss: 0.13090011477470398


 76%|███████▋  | 351/459 [05:01<01:33,  1.16it/s]

loss: 0.3492284417152405


 79%|███████▊  | 361/459 [05:09<01:24,  1.16it/s]

loss: 0.045762527734041214


 81%|████████  | 371/459 [05:18<01:15,  1.17it/s]

loss: 0.11055445671081543


 83%|████████▎ | 381/459 [05:27<01:07,  1.16it/s]

loss: 0.07099258899688721


 85%|████████▌ | 391/459 [05:35<00:58,  1.16it/s]

loss: 0.1774512529373169


 87%|████████▋ | 401/459 [05:44<00:49,  1.16it/s]

loss: 0.5944624543190002


 90%|████████▉ | 411/459 [05:52<00:41,  1.16it/s]

loss: 0.4088588058948517


 92%|█████████▏| 421/459 [06:01<00:32,  1.16it/s]

loss: 0.07203858345746994


 94%|█████████▍| 431/459 [06:09<00:24,  1.16it/s]

loss: 0.09172537177801132


 96%|█████████▌| 441/459 [06:18<00:15,  1.16it/s]

loss: 0.3740120530128479


 98%|█████████▊| 451/459 [06:27<00:06,  1.15it/s]

loss: 0.4133208990097046


100%|██████████| 459/459 [06:33<00:00,  1.17it/s]
  0%|          | 1/459 [00:00<06:38,  1.15it/s]

loss: 0.5766723155975342


  2%|▏         | 11/459 [00:09<06:28,  1.15it/s]

loss: 0.4334522783756256


  5%|▍         | 21/459 [00:18<06:16,  1.16it/s]

loss: 0.18608692288398743


  7%|▋         | 31/459 [00:26<06:09,  1.16it/s]

loss: 0.025812732055783272


  9%|▉         | 41/459 [00:35<05:59,  1.16it/s]

loss: 0.09308965504169464


 11%|█         | 51/459 [00:43<05:50,  1.16it/s]

loss: 0.0573507584631443


 13%|█▎        | 61/459 [00:52<05:41,  1.17it/s]

loss: 0.7015887498855591


 15%|█▌        | 71/459 [01:00<05:33,  1.16it/s]

loss: 0.06488388031721115


 18%|█▊        | 81/459 [01:09<05:25,  1.16it/s]

loss: 0.10910284519195557


 20%|█▉        | 91/459 [01:17<05:16,  1.16it/s]

loss: 0.35144883394241333


 22%|██▏       | 101/459 [01:26<05:08,  1.16it/s]

loss: 0.5147315263748169


 24%|██▍       | 111/459 [01:35<04:58,  1.16it/s]

loss: 0.10290461033582687


 26%|██▋       | 121/459 [01:43<04:50,  1.16it/s]

loss: 0.0894022285938263


 29%|██▊       | 131/459 [01:52<04:42,  1.16it/s]

loss: 0.7623075246810913


 31%|███       | 141/459 [02:00<04:33,  1.16it/s]

loss: 0.2123534083366394


 33%|███▎      | 151/459 [02:09<04:24,  1.16it/s]

loss: 0.04857476055622101


 35%|███▌      | 161/459 [02:17<04:16,  1.16it/s]

loss: 0.12228884547948837


 37%|███▋      | 171/459 [02:26<04:08,  1.16it/s]

loss: 0.041495226323604584


 39%|███▉      | 181/459 [02:35<04:00,  1.16it/s]

loss: 0.09209441393613815


 42%|████▏     | 191/459 [02:43<03:50,  1.16it/s]

loss: 0.1499522179365158


 44%|████▍     | 201/459 [02:52<03:42,  1.16it/s]

loss: 0.04426620900630951


 46%|████▌     | 211/459 [03:00<03:33,  1.16it/s]

loss: 0.06029752641916275


 48%|████▊     | 221/459 [03:09<03:25,  1.16it/s]

loss: 0.03180631995201111


 50%|█████     | 231/459 [03:17<03:16,  1.16it/s]

loss: 0.02761862426996231


 53%|█████▎    | 241/459 [03:26<03:08,  1.16it/s]

loss: 0.042334601283073425


 55%|█████▍    | 251/459 [03:35<02:59,  1.16it/s]

loss: 0.09980802983045578


 57%|█████▋    | 261/459 [03:43<02:51,  1.16it/s]

loss: 0.020990146324038506


 59%|█████▉    | 271/459 [03:52<02:41,  1.16it/s]

loss: 0.03239567205309868


 61%|██████    | 281/459 [04:00<02:33,  1.16it/s]

loss: 0.00700672622770071


 63%|██████▎   | 291/459 [04:09<02:24,  1.16it/s]

loss: 0.021283205598592758


 66%|██████▌   | 301/459 [04:17<02:15,  1.16it/s]

loss: 0.2726689279079437


 68%|██████▊   | 311/459 [04:26<02:07,  1.16it/s]

loss: 0.7619442343711853


 70%|██████▉   | 321/459 [04:35<01:59,  1.16it/s]

loss: 0.02863924391567707


 72%|███████▏  | 331/459 [04:43<01:50,  1.16it/s]

loss: 0.04163336381316185


 74%|███████▍  | 341/459 [04:52<01:41,  1.16it/s]

loss: 0.22699590027332306


 76%|███████▋  | 351/459 [05:00<01:32,  1.16it/s]

loss: 0.04255276918411255


 79%|███████▊  | 361/459 [05:09<01:24,  1.16it/s]

loss: 0.021186182275414467


 81%|████████  | 371/459 [05:18<01:15,  1.16it/s]

loss: 0.029161306098103523


 83%|████████▎ | 381/459 [05:26<01:07,  1.16it/s]

loss: 0.379173219203949


 85%|████████▌ | 391/459 [05:35<00:58,  1.16it/s]

loss: 0.025724340230226517


 87%|████████▋ | 401/459 [05:43<00:50,  1.15it/s]

loss: 0.9944639801979065


 90%|████████▉ | 411/459 [05:52<00:41,  1.16it/s]

loss: 0.03986262530088425


 92%|█████████▏| 421/459 [06:01<00:32,  1.16it/s]

loss: 0.02317579835653305


 94%|█████████▍| 431/459 [06:09<00:24,  1.17it/s]

loss: 0.009572735987603664


 96%|█████████▌| 441/459 [06:18<00:15,  1.16it/s]

loss: 0.3584897816181183


 98%|█████████▊| 451/459 [06:26<00:06,  1.16it/s]

loss: 0.1070757806301117


100%|██████████| 459/459 [06:33<00:00,  1.17it/s]
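
Losses printed for single batches are noisy, so averaging them can make the trend easier to read. Here is a minimal sketch that parses `loss:` lines from captured output and averages them. The sample strings hold only the first three logged losses of the first and third epochs from the output above, and `mean_logged_loss` is a hypothetical helper, not part of the notebook.

```python
import re


def mean_logged_loss(log_text):
    """Average every value found on a line of the form 'loss: <float>'."""
    values = [float(v) for v in
              re.findall(r'^loss: ([0-9.eE+-]+)$', log_text, flags=re.M)]
    return sum(values) / len(values)


epoch_1 = """loss: 0.7172756195068359
loss: 0.7296287417411804
loss: 0.5364887118339539"""

epoch_3 = """loss: 0.5766723155975342
loss: 0.4334522783756256
loss: 0.18608692288398743"""

print(mean_logged_loss(epoch_1) > mean_logged_loss(epoch_3))  # True
```

The helper raises `ZeroDivisionError` when the text contains no `loss:` lines, so a real script would want to guard against empty logs.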
