In [1]:
# =====================================
# Cell 1: Setup & Imports
# =====================================
#%pip install transformers datasets accelerate -q
#%pip install sentencepiece -q   # in case needed by certain models

import os
import math
import random
import torch
import pandas as pd
import json
from collections import defaultdict

# Hugging Face
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    AutoModelForCausalLM
)
from datasets import Dataset, DatasetDict


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def load_in_scope_data(json_path):

    with open(json_path, "r") as file:
        data = json.load(file)

    def process_in_scope(data_split):
        data_list = []
        for entry in data_split:
            # entry[0] is the query, entry[1] is the intent
            data_list.append({"query": entry[0], "intent": entry[1]})
        return pd.DataFrame(data_list)

    # Extract only the in-scope splits
    # Out of scope data is ignored in data synthesis
    train_df = process_in_scope(data["train"])
    val_df   = process_in_scope(data["val"])
    test_df  = process_in_scope(data["test"])
    
    return train_df, val_df, test_df

# Usage:
train_df, val_df, test_df = load_in_scope_data("data_full.json")

train_dataset = Dataset.from_pandas(train_df.reset_index(drop=True))
val_dataset = Dataset.from_pandas(val_df.reset_index(drop=True))
test_dataset  = Dataset.from_pandas(test_df.reset_index(drop=True))

print("Train shape:", train_dataset.shape)
print("Val shape:",   val_dataset.shape)
print("Test shape:",  test_dataset.shape)


Train shape: (15000, 2)
Val shape: (3000, 2)
Test shape: (4500, 2)


In [None]:
# Define the pre-trained model name from Hugging Face
model_name_g = "roberta-base"

# Initialize the tokenizer associated with the chosen model
tokenizer_g = AutoTokenizer.from_pretrained(model_name_g)

# Initialize the model for sequence classification
# The `num_labels` parameter is temporarily set to 150 and will be adjusted later based on the dataset
model_g = AutoModelForSequenceClassification.from_pretrained(model_name_g, num_labels=150)


# Gather unique intents from the training set
unique_intents = sorted(list(set(train_dataset["intent"])))
num_labels = len(unique_intents)
print(f"Number of distinct in-scope intents: {num_labels}")

# Adjust the model's classification head if the number of labels has changed
if model_g.config.num_labels != num_labels:
    # Resize the token embeddings in case the tokenizer has been updated
    model_g.resize_token_embeddings(len(tokenizer_g))
    
    # Replace the classifier with a new one that has the correct number of output labels
    # `model_g.config.dim` should be replaced with the correct hidden size if different
    # For RoBERTa, it's typically 768
    model_g.classifier = torch.nn.Linear(model_g.config.hidden_size, num_labels)

# Create mappings from labels to IDs and vice versa
label2id = {label: i for i, label in enumerate(unique_intents)}
id2label = {i: label for label, i in label2id.items()}

# Update the model configuration with these mappings
model_g.config.label2id = label2id
model_g.config.id2label = id2label

def encode_batch_g(batch):
    # Tokenize the 'query' field with truncation and padding
    enc = tokenizer_g(
        batch["query"], 
        truncation=True, 
        padding="max_length", 
        max_length=32  # Maximum sequence length; adjust based on data
    )
    
    # Map each intent label in the batch to its corresponding ID
    enc["labels"] = [label2id[intent] for intent in batch["intent"]]
    
    return enc

# Apply the encoding function to the training and validation datasets
# The `batched=True` parameter processes multiple examples at once for efficiency
# `remove_columns` removes the original columns not needed for training
train_encoded_g = train_dataset.map(encode_batch_g, batched=True, remove_columns=train_dataset.column_names)
val_encoded_g   = val_dataset.map(encode_batch_g,   batched=True, remove_columns=val_dataset.column_names)


# Define training arguments for the Hugging Face Trainer
training_args_g = TrainingArguments(
    output_dir="g_output",                # Directory to save model checkpoints and logs
    eval_strategy="epoch",                # Evaluation is performed at the end of each epoch
    per_device_train_batch_size=8,        # Batch size per device during training
    per_device_eval_batch_size=8,         # Batch size per device during evaluation
    num_train_epochs=2,                    # Total number of training epochs; adjust based on dataset size
    logging_steps=10,                     # Log training metrics every 10 steps
    seed=42                               # Random seed for reproducibility
)

# Initialize the Trainer with the model, training arguments, and datasets
trainer_g = Trainer(
    model=model_g,                        # The sequence classification model to be trained
    args=training_args_g,                 # Training configuration
    train_dataset=train_encoded_g,        # Encoded training dataset
    eval_dataset=val_encoded_g            # Encoded validation dataset
)

# ---------------------------
# Model Training
# ---------------------------

# Start the training process
trainer_g.train()


# Save the fine-tuned model to the specified directory
model_g.save_pretrained("model_g_finetuned")

# Save the tokenizer to the same directory for future use
tokenizer_g.save_pretrained("tokenizer_g_finetuned")


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Number of distinct in-scope intents: 150


Map: 100%|██████████| 15000/15000 [00:00<00:00, 52387.06 examples/s]
Map: 100%|██████████| 3000/3000 [00:00<00:00, 69684.79 examples/s]
  0%|          | 10/3750 [00:04<22:16,  2.80it/s] 

{'loss': 5.0194, 'grad_norm': 5.052273273468018, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.01}


  1%|          | 20/3750 [00:08<21:10,  2.94it/s]

{'loss': 4.9705, 'grad_norm': 6.139972686767578, 'learning_rate': 4.973333333333334e-05, 'epoch': 0.01}


  1%|          | 30/3750 [00:11<21:00,  2.95it/s]

{'loss': 5.019, 'grad_norm': 3.5842578411102295, 'learning_rate': 4.96e-05, 'epoch': 0.02}


  1%|          | 40/3750 [00:15<21:01,  2.94it/s]

{'loss': 5.0433, 'grad_norm': 4.0746259689331055, 'learning_rate': 4.9466666666666665e-05, 'epoch': 0.02}


  1%|▏         | 50/3750 [00:18<20:52,  2.95it/s]

{'loss': 5.0259, 'grad_norm': 3.5344083309173584, 'learning_rate': 4.933333333333334e-05, 'epoch': 0.03}


  2%|▏         | 60/3750 [00:21<20:46,  2.96it/s]

{'loss': 5.0086, 'grad_norm': 2.9226601123809814, 'learning_rate': 4.92e-05, 'epoch': 0.03}


  2%|▏         | 70/3750 [00:25<20:55,  2.93it/s]

{'loss': 5.0115, 'grad_norm': 3.2236642837524414, 'learning_rate': 4.906666666666667e-05, 'epoch': 0.04}


  2%|▏         | 80/3750 [00:28<20:40,  2.96it/s]

{'loss': 5.0643, 'grad_norm': 4.371707916259766, 'learning_rate': 4.8933333333333335e-05, 'epoch': 0.04}


  2%|▏         | 90/3750 [00:32<20:37,  2.96it/s]

{'loss': 5.0454, 'grad_norm': 2.7682693004608154, 'learning_rate': 4.88e-05, 'epoch': 0.05}


  3%|▎         | 100/3750 [00:35<20:37,  2.95it/s]

{'loss': 5.0014, 'grad_norm': 3.4534249305725098, 'learning_rate': 4.866666666666667e-05, 'epoch': 0.05}


  3%|▎         | 110/3750 [00:38<21:12,  2.86it/s]

{'loss': 4.9942, 'grad_norm': 3.630795955657959, 'learning_rate': 4.853333333333334e-05, 'epoch': 0.06}


  3%|▎         | 120/3750 [00:42<20:38,  2.93it/s]

{'loss': 5.0422, 'grad_norm': 4.792765140533447, 'learning_rate': 4.8400000000000004e-05, 'epoch': 0.06}


  3%|▎         | 130/3750 [00:45<20:32,  2.94it/s]

{'loss': 5.0224, 'grad_norm': 2.862114191055298, 'learning_rate': 4.826666666666667e-05, 'epoch': 0.07}


  4%|▎         | 140/3750 [00:49<20:31,  2.93it/s]

{'loss': 5.0185, 'grad_norm': 2.583104133605957, 'learning_rate': 4.8133333333333336e-05, 'epoch': 0.07}


  4%|▍         | 150/3750 [00:52<20:33,  2.92it/s]

{'loss': 4.9952, 'grad_norm': 3.785583972930908, 'learning_rate': 4.8e-05, 'epoch': 0.08}


  4%|▍         | 160/3750 [00:56<20:22,  2.94it/s]

{'loss': 5.0875, 'grad_norm': 3.4088308811187744, 'learning_rate': 4.7866666666666674e-05, 'epoch': 0.09}


  5%|▍         | 170/3750 [00:59<20:10,  2.96it/s]

{'loss': 5.0214, 'grad_norm': 4.124048233032227, 'learning_rate': 4.773333333333333e-05, 'epoch': 0.09}


  5%|▍         | 180/3750 [01:02<20:20,  2.93it/s]

{'loss': 5.0541, 'grad_norm': 3.4611361026763916, 'learning_rate': 4.76e-05, 'epoch': 0.1}


  5%|▌         | 190/3750 [01:06<20:13,  2.93it/s]

{'loss': 5.0222, 'grad_norm': 3.174656391143799, 'learning_rate': 4.746666666666667e-05, 'epoch': 0.1}


  5%|▌         | 200/3750 [01:09<20:15,  2.92it/s]

{'loss': 5.0436, 'grad_norm': 2.392371654510498, 'learning_rate': 4.7333333333333336e-05, 'epoch': 0.11}


  6%|▌         | 210/3750 [01:13<20:19,  2.90it/s]

{'loss': 5.0102, 'grad_norm': 2.2688045501708984, 'learning_rate': 4.72e-05, 'epoch': 0.11}


  6%|▌         | 220/3750 [01:16<19:58,  2.94it/s]

{'loss': 5.0348, 'grad_norm': 3.214308500289917, 'learning_rate': 4.706666666666667e-05, 'epoch': 0.12}


  6%|▌         | 230/3750 [01:19<20:01,  2.93it/s]

{'loss': 5.0362, 'grad_norm': 2.684067726135254, 'learning_rate': 4.6933333333333333e-05, 'epoch': 0.12}


  6%|▋         | 240/3750 [01:23<19:46,  2.96it/s]

{'loss': 4.9944, 'grad_norm': 3.159722328186035, 'learning_rate': 4.6800000000000006e-05, 'epoch': 0.13}


  7%|▋         | 250/3750 [01:26<19:51,  2.94it/s]

{'loss': 5.0258, 'grad_norm': 2.1722652912139893, 'learning_rate': 4.666666666666667e-05, 'epoch': 0.13}


  7%|▋         | 260/3750 [01:30<20:11,  2.88it/s]

{'loss': 5.0484, 'grad_norm': 2.1719844341278076, 'learning_rate': 4.653333333333334e-05, 'epoch': 0.14}


  7%|▋         | 270/3750 [01:33<21:19,  2.72it/s]

{'loss': 5.0299, 'grad_norm': 4.517457008361816, 'learning_rate': 4.64e-05, 'epoch': 0.14}


  7%|▋         | 280/3750 [01:37<19:51,  2.91it/s]

{'loss': 4.9836, 'grad_norm': 2.03983211517334, 'learning_rate': 4.626666666666667e-05, 'epoch': 0.15}


  8%|▊         | 290/3750 [01:40<19:40,  2.93it/s]

{'loss': 4.9969, 'grad_norm': 3.1691040992736816, 'learning_rate': 4.6133333333333334e-05, 'epoch': 0.15}


  8%|▊         | 300/3750 [01:44<19:41,  2.92it/s]

{'loss': 5.0547, 'grad_norm': 2.8082611560821533, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.16}


  8%|▊         | 310/3750 [01:47<19:34,  2.93it/s]

{'loss': 5.0721, 'grad_norm': 3.9297780990600586, 'learning_rate': 4.5866666666666666e-05, 'epoch': 0.17}


  9%|▊         | 320/3750 [01:51<19:30,  2.93it/s]

{'loss': 5.0344, 'grad_norm': 3.1770384311676025, 'learning_rate': 4.573333333333333e-05, 'epoch': 0.17}


  9%|▉         | 330/3750 [01:54<19:29,  2.92it/s]

{'loss': 5.023, 'grad_norm': 2.2063772678375244, 'learning_rate': 4.5600000000000004e-05, 'epoch': 0.18}


  9%|▉         | 340/3750 [01:58<19:24,  2.93it/s]

{'loss': 5.0175, 'grad_norm': 2.2341372966766357, 'learning_rate': 4.546666666666667e-05, 'epoch': 0.18}


  9%|▉         | 350/3750 [02:01<19:21,  2.93it/s]

{'loss': 5.0503, 'grad_norm': 2.1586201190948486, 'learning_rate': 4.5333333333333335e-05, 'epoch': 0.19}


 10%|▉         | 360/3750 [02:04<19:15,  2.93it/s]

{'loss': 5.0182, 'grad_norm': 2.128214120864868, 'learning_rate': 4.52e-05, 'epoch': 0.19}


 10%|▉         | 370/3750 [02:08<19:09,  2.94it/s]

{'loss': 5.0522, 'grad_norm': 3.0447378158569336, 'learning_rate': 4.5066666666666667e-05, 'epoch': 0.2}


 10%|█         | 380/3750 [02:11<19:07,  2.94it/s]

{'loss': 5.0097, 'grad_norm': 2.0937387943267822, 'learning_rate': 4.493333333333333e-05, 'epoch': 0.2}


 10%|█         | 390/3750 [02:15<19:05,  2.93it/s]

{'loss': 5.0146, 'grad_norm': 3.00099515914917, 'learning_rate': 4.4800000000000005e-05, 'epoch': 0.21}


 11%|█         | 400/3750 [02:18<19:00,  2.94it/s]

{'loss': 5.0036, 'grad_norm': 3.84261155128479, 'learning_rate': 4.466666666666667e-05, 'epoch': 0.21}


 11%|█         | 410/3750 [02:21<18:59,  2.93it/s]

{'loss': 5.033, 'grad_norm': 3.164228916168213, 'learning_rate': 4.4533333333333336e-05, 'epoch': 0.22}


 11%|█         | 420/3750 [02:25<19:01,  2.92it/s]

{'loss': 5.0121, 'grad_norm': 3.1757349967956543, 'learning_rate': 4.44e-05, 'epoch': 0.22}


 11%|█▏        | 430/3750 [02:28<19:00,  2.91it/s]

{'loss': 5.0265, 'grad_norm': 3.2207701206207275, 'learning_rate': 4.426666666666667e-05, 'epoch': 0.23}


 12%|█▏        | 440/3750 [02:32<18:45,  2.94it/s]

{'loss': 5.0613, 'grad_norm': 2.1988778114318848, 'learning_rate': 4.413333333333334e-05, 'epoch': 0.23}


 12%|█▏        | 450/3750 [02:35<19:09,  2.87it/s]

{'loss': 4.9777, 'grad_norm': 2.043647050857544, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.24}


 12%|█▏        | 460/3750 [02:39<18:39,  2.94it/s]

{'loss': 5.0488, 'grad_norm': 2.0327255725860596, 'learning_rate': 4.3866666666666665e-05, 'epoch': 0.25}


 13%|█▎        | 470/3750 [02:42<18:37,  2.94it/s]

{'loss': 5.0907, 'grad_norm': 1.9551926851272583, 'learning_rate': 4.373333333333334e-05, 'epoch': 0.25}


 13%|█▎        | 480/3750 [02:46<18:32,  2.94it/s]

{'loss': 5.0353, 'grad_norm': 2.188025712966919, 'learning_rate': 4.36e-05, 'epoch': 0.26}


 13%|█▎        | 490/3750 [02:49<18:32,  2.93it/s]

{'loss': 5.0092, 'grad_norm': 3.005448341369629, 'learning_rate': 4.346666666666667e-05, 'epoch': 0.26}


 13%|█▎        | 500/3750 [02:52<18:28,  2.93it/s]

{'loss': 5.0264, 'grad_norm': 1.9124590158462524, 'learning_rate': 4.3333333333333334e-05, 'epoch': 0.27}


 14%|█▎        | 510/3750 [03:00<20:51,  2.59it/s]  

{'loss': 5.0455, 'grad_norm': 3.77321720123291, 'learning_rate': 4.32e-05, 'epoch': 0.27}


 14%|█▍        | 520/3750 [03:03<18:24,  2.92it/s]

{'loss': 5.0354, 'grad_norm': 3.871476411819458, 'learning_rate': 4.3066666666666665e-05, 'epoch': 0.28}


 14%|█▍        | 530/3750 [03:07<18:17,  2.93it/s]

{'loss': 5.0198, 'grad_norm': 3.0940561294555664, 'learning_rate': 4.293333333333334e-05, 'epoch': 0.28}


 14%|█▍        | 540/3750 [03:10<18:14,  2.93it/s]

{'loss': 5.0566, 'grad_norm': 1.7907155752182007, 'learning_rate': 4.2800000000000004e-05, 'epoch': 0.29}


 15%|█▍        | 550/3750 [03:13<18:05,  2.95it/s]

{'loss': 5.0287, 'grad_norm': 4.155357837677002, 'learning_rate': 4.266666666666667e-05, 'epoch': 0.29}


 15%|█▍        | 560/3750 [03:17<18:11,  2.92it/s]

{'loss': 5.0367, 'grad_norm': 2.010209083557129, 'learning_rate': 4.2533333333333335e-05, 'epoch': 0.3}


 15%|█▌        | 570/3750 [03:20<18:00,  2.94it/s]

{'loss': 5.073, 'grad_norm': 3.271730899810791, 'learning_rate': 4.24e-05, 'epoch': 0.3}


 15%|█▌        | 580/3750 [03:24<18:01,  2.93it/s]

{'loss': 5.0416, 'grad_norm': 2.837615489959717, 'learning_rate': 4.226666666666667e-05, 'epoch': 0.31}


 16%|█▌        | 590/3750 [03:27<18:05,  2.91it/s]

{'loss': 5.0193, 'grad_norm': 2.913388967514038, 'learning_rate': 4.213333333333334e-05, 'epoch': 0.31}


 16%|█▌        | 600/3750 [03:30<17:56,  2.92it/s]

{'loss': 5.0671, 'grad_norm': 2.9814062118530273, 'learning_rate': 4.2e-05, 'epoch': 0.32}


 16%|█▋        | 610/3750 [03:34<17:49,  2.94it/s]

{'loss': 5.0257, 'grad_norm': 3.7320492267608643, 'learning_rate': 4.186666666666667e-05, 'epoch': 0.33}


 17%|█▋        | 620/3750 [03:37<17:45,  2.94it/s]

{'loss': 5.021, 'grad_norm': 3.333852529525757, 'learning_rate': 4.1733333333333336e-05, 'epoch': 0.33}


 17%|█▋        | 630/3750 [03:41<17:44,  2.93it/s]

{'loss': 5.0509, 'grad_norm': 1.850269079208374, 'learning_rate': 4.16e-05, 'epoch': 0.34}


 17%|█▋        | 640/3750 [03:44<17:36,  2.94it/s]

{'loss': 5.0321, 'grad_norm': 2.8982951641082764, 'learning_rate': 4.146666666666667e-05, 'epoch': 0.34}


 17%|█▋        | 650/3750 [03:48<17:35,  2.94it/s]

{'loss': 5.0343, 'grad_norm': 3.941437005996704, 'learning_rate': 4.133333333333333e-05, 'epoch': 0.35}


 18%|█▊        | 660/3750 [03:51<17:28,  2.95it/s]

{'loss': 5.0233, 'grad_norm': 2.8868014812469482, 'learning_rate': 4.12e-05, 'epoch': 0.35}


 18%|█▊        | 670/3750 [03:54<17:31,  2.93it/s]

{'loss': 5.0208, 'grad_norm': 2.9078056812286377, 'learning_rate': 4.106666666666667e-05, 'epoch': 0.36}


 18%|█▊        | 680/3750 [03:58<18:02,  2.84it/s]

{'loss': 5.0525, 'grad_norm': 2.8838894367218018, 'learning_rate': 4.093333333333334e-05, 'epoch': 0.36}


 18%|█▊        | 690/3750 [04:01<17:21,  2.94it/s]

{'loss': 5.0305, 'grad_norm': 2.9406943321228027, 'learning_rate': 4.08e-05, 'epoch': 0.37}


 19%|█▊        | 700/3750 [04:05<17:33,  2.89it/s]

{'loss': 5.0396, 'grad_norm': 3.047203540802002, 'learning_rate': 4.066666666666667e-05, 'epoch': 0.37}


 19%|█▉        | 710/3750 [04:08<17:23,  2.91it/s]

{'loss': 5.0176, 'grad_norm': 1.7637834548950195, 'learning_rate': 4.0533333333333334e-05, 'epoch': 0.38}


 19%|█▉        | 720/3750 [04:12<17:11,  2.94it/s]

{'loss': 5.0036, 'grad_norm': 3.702693223953247, 'learning_rate': 4.0400000000000006e-05, 'epoch': 0.38}


 19%|█▉        | 730/3750 [04:15<17:20,  2.90it/s]

{'loss': 5.0392, 'grad_norm': 3.415644645690918, 'learning_rate': 4.026666666666667e-05, 'epoch': 0.39}


 20%|█▉        | 740/3750 [04:18<17:03,  2.94it/s]

{'loss': 5.0285, 'grad_norm': 3.1675796508789062, 'learning_rate': 4.013333333333333e-05, 'epoch': 0.39}


 20%|██        | 750/3750 [04:22<17:06,  2.92it/s]

{'loss': 5.0236, 'grad_norm': 1.9785051345825195, 'learning_rate': 4e-05, 'epoch': 0.4}


 20%|██        | 760/3750 [04:25<17:05,  2.92it/s]

{'loss': 5.0195, 'grad_norm': 2.007643699645996, 'learning_rate': 3.986666666666667e-05, 'epoch': 0.41}


 21%|██        | 770/3750 [04:29<16:55,  2.94it/s]

{'loss': 5.0089, 'grad_norm': 2.03027081489563, 'learning_rate': 3.9733333333333335e-05, 'epoch': 0.41}


 21%|██        | 780/3750 [04:32<16:51,  2.94it/s]

{'loss': 5.0291, 'grad_norm': 2.409105062484741, 'learning_rate': 3.960000000000001e-05, 'epoch': 0.42}


 21%|██        | 790/3750 [04:36<16:46,  2.94it/s]

{'loss': 5.006, 'grad_norm': 3.158135175704956, 'learning_rate': 3.9466666666666666e-05, 'epoch': 0.42}


 21%|██▏       | 800/3750 [04:39<16:42,  2.94it/s]

{'loss': 5.0063, 'grad_norm': 3.0397608280181885, 'learning_rate': 3.933333333333333e-05, 'epoch': 0.43}


 22%|██▏       | 810/3750 [04:42<16:44,  2.93it/s]

{'loss': 5.0012, 'grad_norm': 2.055492401123047, 'learning_rate': 3.9200000000000004e-05, 'epoch': 0.43}


 22%|██▏       | 820/3750 [04:46<16:40,  2.93it/s]

{'loss': 5.0258, 'grad_norm': 2.1471564769744873, 'learning_rate': 3.906666666666667e-05, 'epoch': 0.44}


 22%|██▏       | 830/3750 [04:49<16:40,  2.92it/s]

{'loss': 5.0325, 'grad_norm': 2.9309756755828857, 'learning_rate': 3.8933333333333336e-05, 'epoch': 0.44}


 22%|██▏       | 840/3750 [04:53<16:30,  2.94it/s]

{'loss': 5.016, 'grad_norm': 1.8554749488830566, 'learning_rate': 3.88e-05, 'epoch': 0.45}


 23%|██▎       | 850/3750 [04:56<16:31,  2.93it/s]

{'loss': 4.994, 'grad_norm': 2.883685827255249, 'learning_rate': 3.866666666666667e-05, 'epoch': 0.45}


 23%|██▎       | 860/3750 [05:00<16:32,  2.91it/s]

{'loss': 4.9975, 'grad_norm': 2.875936508178711, 'learning_rate': 3.853333333333334e-05, 'epoch': 0.46}


 23%|██▎       | 870/3750 [05:03<16:20,  2.94it/s]

{'loss': 5.0516, 'grad_norm': 4.393168926239014, 'learning_rate': 3.8400000000000005e-05, 'epoch': 0.46}


 23%|██▎       | 880/3750 [05:06<16:23,  2.92it/s]

{'loss': 4.9971, 'grad_norm': 2.8681037425994873, 'learning_rate': 3.8266666666666664e-05, 'epoch': 0.47}


 24%|██▎       | 890/3750 [05:10<16:11,  2.94it/s]

{'loss': 5.0481, 'grad_norm': 3.0468828678131104, 'learning_rate': 3.8133333333333336e-05, 'epoch': 0.47}


 24%|██▍       | 900/3750 [05:13<16:06,  2.95it/s]

{'loss': 5.043, 'grad_norm': 2.9125401973724365, 'learning_rate': 3.8e-05, 'epoch': 0.48}


 24%|██▍       | 910/3750 [05:17<16:11,  2.92it/s]

{'loss': 5.002, 'grad_norm': 1.9000041484832764, 'learning_rate': 3.786666666666667e-05, 'epoch': 0.49}


 25%|██▍       | 920/3750 [05:20<16:11,  2.91it/s]

{'loss': 5.0284, 'grad_norm': 1.7450275421142578, 'learning_rate': 3.773333333333334e-05, 'epoch': 0.49}


 25%|██▍       | 930/3750 [05:24<16:06,  2.92it/s]

{'loss': 5.0329, 'grad_norm': 1.7631295919418335, 'learning_rate': 3.76e-05, 'epoch': 0.5}


 25%|██▌       | 940/3750 [05:27<15:57,  2.93it/s]

{'loss': 5.0436, 'grad_norm': 1.7597132921218872, 'learning_rate': 3.7466666666666665e-05, 'epoch': 0.5}


 25%|██▌       | 950/3750 [05:31<16:16,  2.87it/s]

{'loss': 5.0209, 'grad_norm': 4.250227451324463, 'learning_rate': 3.733333333333334e-05, 'epoch': 0.51}


 26%|██▌       | 960/3750 [05:34<15:49,  2.94it/s]

{'loss': 5.0553, 'grad_norm': 3.0937108993530273, 'learning_rate': 3.72e-05, 'epoch': 0.51}


 26%|██▌       | 970/3750 [05:37<15:47,  2.94it/s]

{'loss': 5.0283, 'grad_norm': 2.7537126541137695, 'learning_rate': 3.706666666666667e-05, 'epoch': 0.52}


 26%|██▌       | 980/3750 [05:41<15:41,  2.94it/s]

{'loss': 5.0203, 'grad_norm': 1.5753252506256104, 'learning_rate': 3.6933333333333334e-05, 'epoch': 0.52}


 26%|██▋       | 990/3750 [05:44<15:39,  2.94it/s]

{'loss': 5.0654, 'grad_norm': 3.56359601020813, 'learning_rate': 3.68e-05, 'epoch': 0.53}


 27%|██▋       | 1000/3750 [05:48<15:35,  2.94it/s]

{'loss': 5.0033, 'grad_norm': 1.8077939748764038, 'learning_rate': 3.6666666666666666e-05, 'epoch': 0.53}


 27%|██▋       | 1010/3750 [05:55<17:46,  2.57it/s]  

{'loss': 5.0339, 'grad_norm': 1.6155352592468262, 'learning_rate': 3.653333333333334e-05, 'epoch': 0.54}


 27%|██▋       | 1020/3750 [05:58<15:32,  2.93it/s]

{'loss': 5.0571, 'grad_norm': 4.252065181732178, 'learning_rate': 3.6400000000000004e-05, 'epoch': 0.54}


 27%|██▋       | 1030/3750 [06:02<15:23,  2.94it/s]

{'loss': 5.0489, 'grad_norm': 2.878174066543579, 'learning_rate': 3.626666666666667e-05, 'epoch': 0.55}


 28%|██▊       | 1040/3750 [06:05<15:21,  2.94it/s]

{'loss': 5.0103, 'grad_norm': 1.6197842359542847, 'learning_rate': 3.6133333333333335e-05, 'epoch': 0.55}


 28%|██▊       | 1050/3750 [06:09<15:28,  2.91it/s]

{'loss': 5.0525, 'grad_norm': 3.5670671463012695, 'learning_rate': 3.6e-05, 'epoch': 0.56}


 28%|██▊       | 1060/3750 [06:12<15:12,  2.95it/s]

{'loss': 5.0191, 'grad_norm': 1.523386001586914, 'learning_rate': 3.586666666666667e-05, 'epoch': 0.57}


 29%|██▊       | 1070/3750 [06:15<15:10,  2.94it/s]

{'loss': 5.0236, 'grad_norm': 1.663748860359192, 'learning_rate': 3.573333333333333e-05, 'epoch': 0.57}


 29%|██▉       | 1080/3750 [06:19<15:12,  2.93it/s]

{'loss': 5.0119, 'grad_norm': 3.658426523208618, 'learning_rate': 3.56e-05, 'epoch': 0.58}


 29%|██▉       | 1090/3750 [06:22<15:08,  2.93it/s]

{'loss': 5.0299, 'grad_norm': 2.664708137512207, 'learning_rate': 3.546666666666667e-05, 'epoch': 0.58}


 29%|██▉       | 1100/3750 [06:26<15:06,  2.92it/s]

{'loss': 5.0242, 'grad_norm': 4.214006423950195, 'learning_rate': 3.5333333333333336e-05, 'epoch': 0.59}


 30%|██▉       | 1110/3750 [06:29<14:57,  2.94it/s]

{'loss': 5.0272, 'grad_norm': 2.7345566749572754, 'learning_rate': 3.52e-05, 'epoch': 0.59}


 30%|██▉       | 1120/3750 [06:32<14:56,  2.93it/s]

{'loss': 5.0455, 'grad_norm': 3.048525333404541, 'learning_rate': 3.506666666666667e-05, 'epoch': 0.6}


 30%|███       | 1130/3750 [06:36<14:55,  2.93it/s]

{'loss': 5.0162, 'grad_norm': 4.2101240158081055, 'learning_rate': 3.493333333333333e-05, 'epoch': 0.6}


 30%|███       | 1140/3750 [06:39<14:46,  2.94it/s]

{'loss': 5.0161, 'grad_norm': 2.823697328567505, 'learning_rate': 3.48e-05, 'epoch': 0.61}


 31%|███       | 1150/3750 [06:43<14:44,  2.94it/s]

{'loss': 5.0227, 'grad_norm': 1.6932324171066284, 'learning_rate': 3.466666666666667e-05, 'epoch': 0.61}


 31%|███       | 1160/3750 [06:46<14:39,  2.95it/s]

{'loss': 5.0347, 'grad_norm': 1.6262977123260498, 'learning_rate': 3.453333333333334e-05, 'epoch': 0.62}


 31%|███       | 1170/3750 [06:50<14:35,  2.95it/s]

{'loss': 5.0298, 'grad_norm': 3.434870481491089, 'learning_rate': 3.4399999999999996e-05, 'epoch': 0.62}


 31%|███▏      | 1180/3750 [06:53<14:30,  2.95it/s]

{'loss': 5.0061, 'grad_norm': 1.5360748767852783, 'learning_rate': 3.426666666666667e-05, 'epoch': 0.63}


 32%|███▏      | 1190/3750 [06:56<14:27,  2.95it/s]

{'loss': 5.0324, 'grad_norm': 1.673466682434082, 'learning_rate': 3.4133333333333334e-05, 'epoch': 0.63}


 32%|███▏      | 1200/3750 [07:00<14:24,  2.95it/s]

{'loss': 5.0348, 'grad_norm': 3.455604076385498, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.64}


 32%|███▏      | 1210/3750 [07:03<14:21,  2.95it/s]

{'loss': 5.0404, 'grad_norm': 2.3490772247314453, 'learning_rate': 3.3866666666666665e-05, 'epoch': 0.65}


 33%|███▎      | 1220/3750 [07:07<14:20,  2.94it/s]

{'loss': 5.0418, 'grad_norm': 3.689936399459839, 'learning_rate': 3.373333333333333e-05, 'epoch': 0.65}


 33%|███▎      | 1230/3750 [07:10<14:14,  2.95it/s]

{'loss': 5.0239, 'grad_norm': 4.5737199783325195, 'learning_rate': 3.3600000000000004e-05, 'epoch': 0.66}


 33%|███▎      | 1240/3750 [07:13<14:09,  2.95it/s]

{'loss': 5.0243, 'grad_norm': 3.2326719760894775, 'learning_rate': 3.346666666666667e-05, 'epoch': 0.66}


 33%|███▎      | 1250/3750 [07:17<14:07,  2.95it/s]

{'loss': 5.043, 'grad_norm': 1.815353274345398, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.67}


 34%|███▎      | 1260/3750 [07:20<14:07,  2.94it/s]

{'loss': 5.0036, 'grad_norm': 3.7010622024536133, 'learning_rate': 3.32e-05, 'epoch': 0.67}


 34%|███▍      | 1270/3750 [07:24<14:03,  2.94it/s]

{'loss': 5.0186, 'grad_norm': 4.359467029571533, 'learning_rate': 3.3066666666666666e-05, 'epoch': 0.68}


 34%|███▍      | 1280/3750 [07:27<13:58,  2.95it/s]

{'loss': 5.0435, 'grad_norm': 3.6841821670532227, 'learning_rate': 3.293333333333333e-05, 'epoch': 0.68}


 34%|███▍      | 1290/3750 [07:30<13:55,  2.95it/s]

{'loss': 5.0114, 'grad_norm': 2.9178683757781982, 'learning_rate': 3.2800000000000004e-05, 'epoch': 0.69}


 35%|███▍      | 1300/3750 [07:34<13:53,  2.94it/s]

{'loss': 5.0278, 'grad_norm': 3.1213598251342773, 'learning_rate': 3.266666666666667e-05, 'epoch': 0.69}


 35%|███▍      | 1310/3750 [07:37<13:51,  2.93it/s]

{'loss': 5.0258, 'grad_norm': 1.8154670000076294, 'learning_rate': 3.253333333333333e-05, 'epoch': 0.7}


 35%|███▌      | 1320/3750 [07:41<13:44,  2.95it/s]

{'loss': 5.0242, 'grad_norm': 2.0355327129364014, 'learning_rate': 3.24e-05, 'epoch': 0.7}


 35%|███▌      | 1330/3750 [07:44<13:40,  2.95it/s]

{'loss': 5.0226, 'grad_norm': 1.678694486618042, 'learning_rate': 3.226666666666667e-05, 'epoch': 0.71}


 36%|███▌      | 1340/3750 [07:47<13:39,  2.94it/s]

{'loss': 5.0089, 'grad_norm': 2.9835855960845947, 'learning_rate': 3.213333333333334e-05, 'epoch': 0.71}


 36%|███▌      | 1350/3750 [07:51<13:36,  2.94it/s]

{'loss': 5.0126, 'grad_norm': 1.7682005167007446, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.72}


 36%|███▋      | 1360/3750 [07:54<13:33,  2.94it/s]

{'loss': 5.0035, 'grad_norm': 1.6133371591567993, 'learning_rate': 3.1866666666666664e-05, 'epoch': 0.73}


 37%|███▋      | 1370/3750 [07:58<13:34,  2.92it/s]

{'loss': 5.0514, 'grad_norm': 1.5622755289077759, 'learning_rate': 3.173333333333334e-05, 'epoch': 0.73}


 37%|███▋      | 1380/3750 [08:01<13:25,  2.94it/s]

{'loss': 5.027, 'grad_norm': 1.5535913705825806, 'learning_rate': 3.16e-05, 'epoch': 0.74}


 37%|███▋      | 1390/3750 [08:05<13:23,  2.94it/s]

{'loss': 5.0291, 'grad_norm': 1.5287758111953735, 'learning_rate': 3.146666666666667e-05, 'epoch': 0.74}


 37%|███▋      | 1400/3750 [08:08<13:22,  2.93it/s]

{'loss': 5.0318, 'grad_norm': 3.561150074005127, 'learning_rate': 3.1333333333333334e-05, 'epoch': 0.75}


 38%|███▊      | 1410/3750 [08:11<13:25,  2.90it/s]

{'loss': 5.0067, 'grad_norm': 2.8315539360046387, 'learning_rate': 3.12e-05, 'epoch': 0.75}


 38%|███▊      | 1420/3750 [08:15<13:13,  2.94it/s]

{'loss': 5.0329, 'grad_norm': 3.24347186088562, 'learning_rate': 3.1066666666666665e-05, 'epoch': 0.76}


 38%|███▊      | 1430/3750 [08:18<13:28,  2.87it/s]

{'loss': 5.0145, 'grad_norm': 2.7763285636901855, 'learning_rate': 3.093333333333334e-05, 'epoch': 0.76}


 38%|███▊      | 1440/3750 [08:22<13:08,  2.93it/s]

{'loss': 5.0191, 'grad_norm': 1.7304325103759766, 'learning_rate': 3.08e-05, 'epoch': 0.77}


 39%|███▊      | 1450/3750 [08:25<13:06,  2.92it/s]

{'loss': 5.0557, 'grad_norm': 3.453376293182373, 'learning_rate': 3.066666666666667e-05, 'epoch': 0.77}


 39%|███▉      | 1460/3750 [08:29<13:38,  2.80it/s]

{'loss': 5.007, 'grad_norm': 3.2866270542144775, 'learning_rate': 3.0533333333333335e-05, 'epoch': 0.78}


 39%|███▉      | 1470/3750 [08:32<13:00,  2.92it/s]

{'loss': 5.0238, 'grad_norm': 1.5763908624649048, 'learning_rate': 3.04e-05, 'epoch': 0.78}


 39%|███▉      | 1480/3750 [08:36<12:51,  2.94it/s]

{'loss': 5.0229, 'grad_norm': 2.752304792404175, 'learning_rate': 3.0266666666666666e-05, 'epoch': 0.79}


 40%|███▉      | 1490/3750 [08:39<12:51,  2.93it/s]

{'loss': 5.0297, 'grad_norm': 2.822134017944336, 'learning_rate': 3.0133333333333335e-05, 'epoch': 0.79}


 40%|████      | 1500/3750 [08:42<12:44,  2.94it/s]

{'loss': 5.0422, 'grad_norm': 1.6038652658462524, 'learning_rate': 3e-05, 'epoch': 0.8}


 40%|████      | 1510/3750 [08:50<14:29,  2.58it/s]

{'loss': 5.0313, 'grad_norm': 1.5670301914215088, 'learning_rate': 2.986666666666667e-05, 'epoch': 0.81}


 41%|████      | 1520/3750 [08:53<12:49,  2.90it/s]

{'loss': 5.029, 'grad_norm': 3.540975332260132, 'learning_rate': 2.9733333333333336e-05, 'epoch': 0.81}


 41%|████      | 1530/3750 [08:57<12:52,  2.87it/s]

{'loss': 5.0026, 'grad_norm': 1.4506416320800781, 'learning_rate': 2.96e-05, 'epoch': 0.82}


 41%|████      | 1540/3750 [09:00<12:30,  2.94it/s]

{'loss': 5.032, 'grad_norm': 1.4018833637237549, 'learning_rate': 2.946666666666667e-05, 'epoch': 0.82}


 41%|████▏     | 1550/3750 [09:04<12:25,  2.95it/s]

{'loss': 5.0006, 'grad_norm': 1.460127592086792, 'learning_rate': 2.9333333333333336e-05, 'epoch': 0.83}


 42%|████▏     | 1560/3750 [09:07<12:25,  2.94it/s]

{'loss': 5.0376, 'grad_norm': 4.24613618850708, 'learning_rate': 2.9199999999999998e-05, 'epoch': 0.83}


 42%|████▏     | 1570/3750 [09:10<12:21,  2.94it/s]

{'loss': 5.0602, 'grad_norm': 2.7465665340423584, 'learning_rate': 2.906666666666667e-05, 'epoch': 0.84}


 42%|████▏     | 1580/3750 [09:14<12:18,  2.94it/s]

{'loss': 5.0066, 'grad_norm': 2.847921371459961, 'learning_rate': 2.8933333333333333e-05, 'epoch': 0.84}


 42%|████▏     | 1590/3750 [09:17<12:16,  2.93it/s]

{'loss': 5.0211, 'grad_norm': 1.5605089664459229, 'learning_rate': 2.88e-05, 'epoch': 0.85}


 43%|████▎     | 1600/3750 [09:21<12:09,  2.95it/s]

{'loss': 5.0369, 'grad_norm': 2.6353423595428467, 'learning_rate': 2.8666666666666668e-05, 'epoch': 0.85}


 43%|████▎     | 1610/3750 [09:24<12:06,  2.94it/s]

{'loss': 5.0095, 'grad_norm': 2.7246992588043213, 'learning_rate': 2.8533333333333333e-05, 'epoch': 0.86}


 43%|████▎     | 1620/3750 [09:27<12:19,  2.88it/s]

{'loss': 5.0051, 'grad_norm': 1.5017975568771362, 'learning_rate': 2.84e-05, 'epoch': 0.86}


 43%|████▎     | 1630/3750 [09:31<12:27,  2.84it/s]

{'loss': 5.0494, 'grad_norm': 1.50718092918396, 'learning_rate': 2.8266666666666668e-05, 'epoch': 0.87}


 44%|████▎     | 1640/3750 [09:34<12:10,  2.89it/s]

{'loss': 5.034, 'grad_norm': 2.84578800201416, 'learning_rate': 2.8133333333333334e-05, 'epoch': 0.87}


 44%|████▍     | 1650/3750 [09:38<11:56,  2.93it/s]

{'loss': 4.9984, 'grad_norm': 1.4836530685424805, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.88}


 44%|████▍     | 1660/3750 [09:41<11:53,  2.93it/s]

{'loss': 5.0434, 'grad_norm': 1.6071457862854004, 'learning_rate': 2.786666666666667e-05, 'epoch': 0.89}


 45%|████▍     | 1670/3750 [09:45<11:54,  2.91it/s]

{'loss': 5.0211, 'grad_norm': 3.463681936264038, 'learning_rate': 2.7733333333333334e-05, 'epoch': 0.89}


 45%|████▍     | 1680/3750 [09:48<11:45,  2.93it/s]

{'loss': 5.0207, 'grad_norm': 3.505729913711548, 'learning_rate': 2.7600000000000003e-05, 'epoch': 0.9}


 45%|████▌     | 1690/3750 [09:52<11:40,  2.94it/s]

{'loss': 5.0238, 'grad_norm': 2.6609983444213867, 'learning_rate': 2.746666666666667e-05, 'epoch': 0.9}


 45%|████▌     | 1700/3750 [09:55<12:09,  2.81it/s]

{'loss': 5.0506, 'grad_norm': 1.4527714252471924, 'learning_rate': 2.733333333333333e-05, 'epoch': 0.91}


 46%|████▌     | 1710/3750 [09:59<11:49,  2.87it/s]

{'loss': 5.0421, 'grad_norm': 3.508558750152588, 'learning_rate': 2.7200000000000004e-05, 'epoch': 0.91}


 46%|████▌     | 1720/3750 [10:02<11:38,  2.91it/s]

{'loss': 5.0169, 'grad_norm': 1.4819514751434326, 'learning_rate': 2.706666666666667e-05, 'epoch': 0.92}


 46%|████▌     | 1730/3750 [10:06<11:41,  2.88it/s]

{'loss': 5.0267, 'grad_norm': 1.4155811071395874, 'learning_rate': 2.6933333333333332e-05, 'epoch': 0.92}


 46%|████▋     | 1740/3750 [10:09<11:29,  2.92it/s]

{'loss': 5.0383, 'grad_norm': 3.4625608921051025, 'learning_rate': 2.6800000000000004e-05, 'epoch': 0.93}


 47%|████▋     | 1750/3750 [10:12<11:22,  2.93it/s]

{'loss': 5.0027, 'grad_norm': 3.0015463829040527, 'learning_rate': 2.6666666666666667e-05, 'epoch': 0.93}


 47%|████▋     | 1760/3750 [10:16<11:18,  2.93it/s]

{'loss': 5.0298, 'grad_norm': 1.4773144721984863, 'learning_rate': 2.6533333333333332e-05, 'epoch': 0.94}


 47%|████▋     | 1770/3750 [10:19<11:15,  2.93it/s]

{'loss': 5.007, 'grad_norm': 1.3344950675964355, 'learning_rate': 2.64e-05, 'epoch': 0.94}


 47%|████▋     | 1780/3750 [10:23<11:11,  2.93it/s]

{'loss': 4.9808, 'grad_norm': 1.4564216136932373, 'learning_rate': 2.6266666666666667e-05, 'epoch': 0.95}


 48%|████▊     | 1790/3750 [10:26<11:25,  2.86it/s]

{'loss': 5.0385, 'grad_norm': 3.5231947898864746, 'learning_rate': 2.6133333333333333e-05, 'epoch': 0.95}


 48%|████▊     | 1800/3750 [10:30<11:06,  2.93it/s]

{'loss': 5.0312, 'grad_norm': 3.414778470993042, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.96}


 48%|████▊     | 1810/3750 [10:33<11:02,  2.93it/s]

{'loss': 5.0195, 'grad_norm': 1.4412589073181152, 'learning_rate': 2.5866666666666667e-05, 'epoch': 0.97}


 49%|████▊     | 1820/3750 [10:36<11:02,  2.91it/s]

{'loss': 5.0297, 'grad_norm': 1.3477667570114136, 'learning_rate': 2.5733333333333337e-05, 'epoch': 0.97}


 49%|████▉     | 1830/3750 [10:40<10:59,  2.91it/s]

{'loss': 5.0187, 'grad_norm': 2.4685094356536865, 'learning_rate': 2.5600000000000002e-05, 'epoch': 0.98}


 49%|████▉     | 1840/3750 [10:43<10:51,  2.93it/s]

{'loss': 5.0006, 'grad_norm': 1.4464747905731201, 'learning_rate': 2.5466666666666668e-05, 'epoch': 0.98}


 49%|████▉     | 1850/3750 [10:47<10:54,  2.90it/s]

{'loss': 5.0466, 'grad_norm': 3.3690683841705322, 'learning_rate': 2.5333333333333337e-05, 'epoch': 0.99}


 50%|████▉     | 1860/3750 [10:50<10:48,  2.91it/s]

{'loss': 5.0275, 'grad_norm': 1.3787153959274292, 'learning_rate': 2.5200000000000003e-05, 'epoch': 0.99}


 50%|████▉     | 1870/3750 [10:54<10:40,  2.93it/s]

{'loss': 4.9953, 'grad_norm': 2.6627426147460938, 'learning_rate': 2.5066666666666665e-05, 'epoch': 1.0}


                                                   
 50%|█████     | 1875/3750 [11:11<10:42,  2.92it/s]

{'eval_loss': 5.013391971588135, 'eval_runtime': 16.0637, 'eval_samples_per_second': 186.757, 'eval_steps_per_second': 23.345, 'epoch': 1.0}


 50%|█████     | 1880/3750 [11:13<46:39,  1.50s/it]  

{'loss': 5.0207, 'grad_norm': 2.6294143199920654, 'learning_rate': 2.4933333333333334e-05, 'epoch': 1.0}


 50%|█████     | 1890/3750 [11:17<11:36,  2.67it/s]

{'loss': 5.0188, 'grad_norm': 2.722485303878784, 'learning_rate': 2.48e-05, 'epoch': 1.01}


 51%|█████     | 1900/3750 [11:20<10:31,  2.93it/s]

{'loss': 5.0046, 'grad_norm': 1.4918416738510132, 'learning_rate': 2.466666666666667e-05, 'epoch': 1.01}


 51%|█████     | 1910/3750 [11:23<10:30,  2.92it/s]

{'loss': 5.0206, 'grad_norm': 1.5010364055633545, 'learning_rate': 2.4533333333333334e-05, 'epoch': 1.02}


 51%|█████     | 1920/3750 [11:27<10:25,  2.93it/s]

{'loss': 5.0315, 'grad_norm': 3.0142228603363037, 'learning_rate': 2.44e-05, 'epoch': 1.02}


 51%|█████▏    | 1930/3750 [11:30<10:20,  2.93it/s]

{'loss': 5.0153, 'grad_norm': 2.777310371398926, 'learning_rate': 2.426666666666667e-05, 'epoch': 1.03}


 52%|█████▏    | 1940/3750 [11:34<10:17,  2.93it/s]

{'loss': 5.016, 'grad_norm': 1.521807312965393, 'learning_rate': 2.4133333333333335e-05, 'epoch': 1.03}


 52%|█████▏    | 1950/3750 [11:37<10:25,  2.88it/s]

{'loss': 5.0178, 'grad_norm': 3.587449550628662, 'learning_rate': 2.4e-05, 'epoch': 1.04}


 52%|█████▏    | 1960/3750 [11:41<10:32,  2.83it/s]

{'loss': 5.0165, 'grad_norm': 1.592452049255371, 'learning_rate': 2.3866666666666666e-05, 'epoch': 1.05}


 53%|█████▎    | 1970/3750 [11:44<10:19,  2.87it/s]

{'loss': 4.99, 'grad_norm': 1.5262248516082764, 'learning_rate': 2.3733333333333335e-05, 'epoch': 1.05}


 53%|█████▎    | 1980/3750 [11:48<10:04,  2.93it/s]

{'loss': 5.0168, 'grad_norm': 1.5904980897903442, 'learning_rate': 2.36e-05, 'epoch': 1.06}


 53%|█████▎    | 1990/3750 [11:51<09:58,  2.94it/s]

{'loss': 5.0039, 'grad_norm': 2.7422423362731934, 'learning_rate': 2.3466666666666667e-05, 'epoch': 1.06}


 53%|█████▎    | 2000/3750 [11:55<09:55,  2.94it/s]

{'loss': 4.9992, 'grad_norm': 3.5454702377319336, 'learning_rate': 2.3333333333333336e-05, 'epoch': 1.07}


 54%|█████▎    | 2010/3750 [12:02<11:16,  2.57it/s]

{'loss': 5.0169, 'grad_norm': 1.5340635776519775, 'learning_rate': 2.32e-05, 'epoch': 1.07}


 54%|█████▍    | 2020/3750 [12:05<09:49,  2.94it/s]

{'loss': 5.0135, 'grad_norm': 2.7388715744018555, 'learning_rate': 2.3066666666666667e-05, 'epoch': 1.08}


 54%|█████▍    | 2030/3750 [12:09<09:45,  2.94it/s]

{'loss': 5.0021, 'grad_norm': 2.792670726776123, 'learning_rate': 2.2933333333333333e-05, 'epoch': 1.08}


 54%|█████▍    | 2040/3750 [12:12<09:41,  2.94it/s]

{'loss': 5.0366, 'grad_norm': 3.5010476112365723, 'learning_rate': 2.2800000000000002e-05, 'epoch': 1.09}


 55%|█████▍    | 2050/3750 [12:16<09:44,  2.91it/s]

{'loss': 5.0101, 'grad_norm': 1.5351109504699707, 'learning_rate': 2.2666666666666668e-05, 'epoch': 1.09}


 55%|█████▍    | 2060/3750 [12:19<09:34,  2.94it/s]

{'loss': 5.0214, 'grad_norm': 2.7447116374969482, 'learning_rate': 2.2533333333333333e-05, 'epoch': 1.1}


 55%|█████▌    | 2070/3750 [12:23<09:30,  2.95it/s]

{'loss': 5.0065, 'grad_norm': 2.8216593265533447, 'learning_rate': 2.2400000000000002e-05, 'epoch': 1.1}


 55%|█████▌    | 2080/3750 [12:26<09:27,  2.94it/s]

{'loss': 5.0378, 'grad_norm': 2.7824835777282715, 'learning_rate': 2.2266666666666668e-05, 'epoch': 1.11}


 56%|█████▌    | 2090/3750 [12:29<09:25,  2.94it/s]

{'loss': 5.018, 'grad_norm': 2.6015405654907227, 'learning_rate': 2.2133333333333334e-05, 'epoch': 1.11}


 56%|█████▌    | 2100/3750 [12:33<09:19,  2.95it/s]

{'loss': 5.0345, 'grad_norm': 2.815619468688965, 'learning_rate': 2.2000000000000003e-05, 'epoch': 1.12}


 56%|█████▋    | 2110/3750 [12:36<09:22,  2.92it/s]

{'loss': 5.0134, 'grad_norm': 3.174433946609497, 'learning_rate': 2.186666666666667e-05, 'epoch': 1.13}


 57%|█████▋    | 2120/3750 [12:40<09:14,  2.94it/s]

{'loss': 5.0275, 'grad_norm': 1.5369975566864014, 'learning_rate': 2.1733333333333334e-05, 'epoch': 1.13}


 57%|█████▋    | 2130/3750 [12:43<09:19,  2.89it/s]

{'loss': 5.0051, 'grad_norm': 1.6987415552139282, 'learning_rate': 2.16e-05, 'epoch': 1.14}


 57%|█████▋    | 2140/3750 [12:47<09:10,  2.93it/s]

{'loss': 5.0054, 'grad_norm': 2.7969515323638916, 'learning_rate': 2.146666666666667e-05, 'epoch': 1.14}


 57%|█████▋    | 2150/3750 [12:50<09:04,  2.94it/s]

{'loss': 5.0138, 'grad_norm': 3.465932607650757, 'learning_rate': 2.1333333333333335e-05, 'epoch': 1.15}


 58%|█████▊    | 2160/3750 [12:53<09:00,  2.94it/s]

{'loss': 5.0253, 'grad_norm': 1.4378010034561157, 'learning_rate': 2.12e-05, 'epoch': 1.15}


 58%|█████▊    | 2170/3750 [12:57<09:02,  2.91it/s]

{'loss': 5.017, 'grad_norm': 2.758476734161377, 'learning_rate': 2.106666666666667e-05, 'epoch': 1.16}


 58%|█████▊    | 2180/3750 [13:00<09:03,  2.89it/s]

{'loss': 5.051, 'grad_norm': 2.7837817668914795, 'learning_rate': 2.0933333333333335e-05, 'epoch': 1.16}


 58%|█████▊    | 2190/3750 [13:04<08:51,  2.94it/s]

{'loss': 4.9848, 'grad_norm': 1.4167017936706543, 'learning_rate': 2.08e-05, 'epoch': 1.17}


 59%|█████▊    | 2200/3750 [13:07<08:48,  2.93it/s]

{'loss': 5.0178, 'grad_norm': 2.921581983566284, 'learning_rate': 2.0666666666666666e-05, 'epoch': 1.17}


 59%|█████▉    | 2210/3750 [13:11<08:44,  2.94it/s]

{'loss': 5.0262, 'grad_norm': 1.6899060010910034, 'learning_rate': 2.0533333333333336e-05, 'epoch': 1.18}


 59%|█████▉    | 2220/3750 [13:14<08:41,  2.93it/s]

{'loss': 5.0351, 'grad_norm': 1.6831899881362915, 'learning_rate': 2.04e-05, 'epoch': 1.18}


 59%|█████▉    | 2230/3750 [13:17<08:36,  2.94it/s]

{'loss': 5.0106, 'grad_norm': 2.744379758834839, 'learning_rate': 2.0266666666666667e-05, 'epoch': 1.19}


 60%|█████▉    | 2240/3750 [13:21<08:33,  2.94it/s]

{'loss': 5.0201, 'grad_norm': 2.7832624912261963, 'learning_rate': 2.0133333333333336e-05, 'epoch': 1.19}


 60%|██████    | 2250/3750 [13:24<08:29,  2.94it/s]

{'loss': 4.9924, 'grad_norm': 2.71169114112854, 'learning_rate': 2e-05, 'epoch': 1.2}


 60%|██████    | 2260/3750 [13:28<08:26,  2.94it/s]

{'loss': 5.0222, 'grad_norm': 2.652489185333252, 'learning_rate': 1.9866666666666667e-05, 'epoch': 1.21}


 61%|██████    | 2270/3750 [13:31<08:26,  2.92it/s]

{'loss': 5.0152, 'grad_norm': 3.5605790615081787, 'learning_rate': 1.9733333333333333e-05, 'epoch': 1.21}


 61%|██████    | 2280/3750 [13:35<08:18,  2.95it/s]

{'loss': 5.011, 'grad_norm': 2.664022207260132, 'learning_rate': 1.9600000000000002e-05, 'epoch': 1.22}


 61%|██████    | 2290/3750 [13:38<08:15,  2.95it/s]

{'loss': 5.0237, 'grad_norm': 2.707054615020752, 'learning_rate': 1.9466666666666668e-05, 'epoch': 1.22}


 61%|██████▏   | 2300/3750 [13:41<08:13,  2.94it/s]

{'loss': 5.0269, 'grad_norm': 1.3754527568817139, 'learning_rate': 1.9333333333333333e-05, 'epoch': 1.23}


 62%|██████▏   | 2310/3750 [13:45<08:11,  2.93it/s]

{'loss': 5.0178, 'grad_norm': 2.7102837562561035, 'learning_rate': 1.9200000000000003e-05, 'epoch': 1.23}


 62%|██████▏   | 2320/3750 [13:48<08:05,  2.94it/s]

{'loss': 5.0389, 'grad_norm': 3.556898593902588, 'learning_rate': 1.9066666666666668e-05, 'epoch': 1.24}


 62%|██████▏   | 2330/3750 [13:52<08:02,  2.94it/s]

{'loss': 5.0313, 'grad_norm': 3.5588948726654053, 'learning_rate': 1.8933333333333334e-05, 'epoch': 1.24}


 62%|██████▏   | 2340/3750 [13:55<07:58,  2.95it/s]

{'loss': 5.0189, 'grad_norm': 1.5179052352905273, 'learning_rate': 1.88e-05, 'epoch': 1.25}


 63%|██████▎   | 2350/3750 [13:58<07:56,  2.94it/s]

{'loss': 5.0144, 'grad_norm': 1.444571852684021, 'learning_rate': 1.866666666666667e-05, 'epoch': 1.25}


 63%|██████▎   | 2360/3750 [14:02<07:54,  2.93it/s]

{'loss': 5.0069, 'grad_norm': 2.668718099594116, 'learning_rate': 1.8533333333333334e-05, 'epoch': 1.26}


 63%|██████▎   | 2370/3750 [14:05<07:49,  2.94it/s]

{'loss': 5.0224, 'grad_norm': 2.695526123046875, 'learning_rate': 1.84e-05, 'epoch': 1.26}


 63%|██████▎   | 2380/3750 [14:09<07:45,  2.94it/s]

{'loss': 5.0269, 'grad_norm': 1.4763048887252808, 'learning_rate': 1.826666666666667e-05, 'epoch': 1.27}


 64%|██████▎   | 2390/3750 [14:12<07:42,  2.94it/s]

{'loss': 5.0247, 'grad_norm': 2.753883123397827, 'learning_rate': 1.8133333333333335e-05, 'epoch': 1.27}


 64%|██████▍   | 2400/3750 [14:16<08:02,  2.80it/s]

{'loss': 5.0244, 'grad_norm': 2.7085249423980713, 'learning_rate': 1.8e-05, 'epoch': 1.28}


 64%|██████▍   | 2410/3750 [14:19<07:42,  2.90it/s]

{'loss': 5.0196, 'grad_norm': 1.4292210340499878, 'learning_rate': 1.7866666666666666e-05, 'epoch': 1.29}


 65%|██████▍   | 2420/3750 [14:23<07:36,  2.91it/s]

{'loss': 5.0307, 'grad_norm': 1.554763674736023, 'learning_rate': 1.7733333333333335e-05, 'epoch': 1.29}


 65%|██████▍   | 2430/3750 [14:26<07:30,  2.93it/s]

{'loss': 5.006, 'grad_norm': 2.9493093490600586, 'learning_rate': 1.76e-05, 'epoch': 1.3}


 65%|██████▌   | 2440/3750 [14:29<07:25,  2.94it/s]

{'loss': 5.0279, 'grad_norm': 1.4233521223068237, 'learning_rate': 1.7466666666666667e-05, 'epoch': 1.3}


 65%|██████▌   | 2450/3750 [14:33<07:25,  2.92it/s]

{'loss': 5.0167, 'grad_norm': 2.770817756652832, 'learning_rate': 1.7333333333333336e-05, 'epoch': 1.31}


 66%|██████▌   | 2460/3750 [14:36<07:20,  2.93it/s]

{'loss': 5.012, 'grad_norm': 2.805969476699829, 'learning_rate': 1.7199999999999998e-05, 'epoch': 1.31}


 66%|██████▌   | 2470/3750 [14:40<07:17,  2.92it/s]

{'loss': 5.0181, 'grad_norm': 3.4816734790802, 'learning_rate': 1.7066666666666667e-05, 'epoch': 1.32}


 66%|██████▌   | 2480/3750 [14:43<07:16,  2.91it/s]

{'loss': 5.0198, 'grad_norm': 2.810696601867676, 'learning_rate': 1.6933333333333333e-05, 'epoch': 1.32}


 66%|██████▋   | 2490/3750 [14:47<07:08,  2.94it/s]

{'loss': 5.0095, 'grad_norm': 2.0665881633758545, 'learning_rate': 1.6800000000000002e-05, 'epoch': 1.33}


 67%|██████▋   | 2500/3750 [14:50<07:04,  2.94it/s]

{'loss': 5.0147, 'grad_norm': 2.718904495239258, 'learning_rate': 1.6666666666666667e-05, 'epoch': 1.33}


 67%|██████▋   | 2510/3750 [14:57<07:59,  2.59it/s]

{'loss': 5.022, 'grad_norm': 4.235069274902344, 'learning_rate': 1.6533333333333333e-05, 'epoch': 1.34}


 67%|██████▋   | 2520/3750 [15:01<06:58,  2.94it/s]

{'loss': 5.0251, 'grad_norm': 1.4664664268493652, 'learning_rate': 1.6400000000000002e-05, 'epoch': 1.34}


 67%|██████▋   | 2530/3750 [15:04<06:54,  2.94it/s]

{'loss': 5.0299, 'grad_norm': 3.0237479209899902, 'learning_rate': 1.6266666666666665e-05, 'epoch': 1.35}


 68%|██████▊   | 2540/3750 [15:08<06:49,  2.95it/s]

{'loss': 5.0046, 'grad_norm': 2.801272392272949, 'learning_rate': 1.6133333333333334e-05, 'epoch': 1.35}


 68%|██████▊   | 2550/3750 [15:11<06:48,  2.94it/s]

{'loss': 5.0258, 'grad_norm': 2.744095802307129, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.36}


 68%|██████▊   | 2560/3750 [15:14<06:44,  2.94it/s]

{'loss': 5.0323, 'grad_norm': 1.4845709800720215, 'learning_rate': 1.586666666666667e-05, 'epoch': 1.37}


 69%|██████▊   | 2570/3750 [15:18<06:40,  2.95it/s]

{'loss': 5.0216, 'grad_norm': 2.3747448921203613, 'learning_rate': 1.5733333333333334e-05, 'epoch': 1.37}


 69%|██████▉   | 2580/3750 [15:21<06:36,  2.95it/s]

{'loss': 5.0249, 'grad_norm': 1.4782119989395142, 'learning_rate': 1.56e-05, 'epoch': 1.38}


 69%|██████▉   | 2590/3750 [15:25<06:34,  2.94it/s]

{'loss': 5.0165, 'grad_norm': 1.5176712274551392, 'learning_rate': 1.546666666666667e-05, 'epoch': 1.38}


 69%|██████▉   | 2600/3750 [15:28<06:30,  2.95it/s]

{'loss': 5.0311, 'grad_norm': 2.826639175415039, 'learning_rate': 1.5333333333333334e-05, 'epoch': 1.39}


 70%|██████▉   | 2610/3750 [15:31<06:25,  2.95it/s]

{'loss': 5.0361, 'grad_norm': 1.258367657661438, 'learning_rate': 1.52e-05, 'epoch': 1.39}


 70%|██████▉   | 2620/3750 [15:35<06:23,  2.95it/s]

{'loss': 5.0202, 'grad_norm': 3.553755044937134, 'learning_rate': 1.5066666666666668e-05, 'epoch': 1.4}


 70%|███████   | 2630/3750 [15:38<06:20,  2.95it/s]

{'loss': 5.0256, 'grad_norm': 1.3473525047302246, 'learning_rate': 1.4933333333333335e-05, 'epoch': 1.4}


 70%|███████   | 2640/3750 [15:42<06:17,  2.94it/s]

{'loss': 5.0324, 'grad_norm': 2.764150381088257, 'learning_rate': 1.48e-05, 'epoch': 1.41}


 71%|███████   | 2650/3750 [15:45<06:14,  2.94it/s]

{'loss': 5.0127, 'grad_norm': 2.7583751678466797, 'learning_rate': 1.4666666666666668e-05, 'epoch': 1.41}


 71%|███████   | 2660/3750 [15:48<06:09,  2.95it/s]

{'loss': 5.035, 'grad_norm': 2.83178448677063, 'learning_rate': 1.4533333333333335e-05, 'epoch': 1.42}


 71%|███████   | 2670/3750 [15:52<06:07,  2.94it/s]

{'loss': 5.0235, 'grad_norm': 1.484413743019104, 'learning_rate': 1.44e-05, 'epoch': 1.42}


 71%|███████▏  | 2680/3750 [15:55<06:02,  2.95it/s]

{'loss': 5.0011, 'grad_norm': 2.672693967819214, 'learning_rate': 1.4266666666666667e-05, 'epoch': 1.43}


 72%|███████▏  | 2690/3750 [15:59<06:00,  2.94it/s]

{'loss': 5.0092, 'grad_norm': 1.3225160837173462, 'learning_rate': 1.4133333333333334e-05, 'epoch': 1.43}


 72%|███████▏  | 2700/3750 [16:02<05:55,  2.95it/s]

{'loss': 5.0327, 'grad_norm': 1.4986344575881958, 'learning_rate': 1.4000000000000001e-05, 'epoch': 1.44}


 72%|███████▏  | 2710/3750 [16:05<05:53,  2.94it/s]

{'loss': 5.0375, 'grad_norm': 1.466361165046692, 'learning_rate': 1.3866666666666667e-05, 'epoch': 1.45}


 73%|███████▎  | 2720/3750 [16:09<05:50,  2.94it/s]

{'loss': 5.0272, 'grad_norm': 1.5197478532791138, 'learning_rate': 1.3733333333333335e-05, 'epoch': 1.45}


 73%|███████▎  | 2730/3750 [16:12<05:46,  2.94it/s]

{'loss': 5.0024, 'grad_norm': 1.4486379623413086, 'learning_rate': 1.3600000000000002e-05, 'epoch': 1.46}


 73%|███████▎  | 2740/3750 [16:16<05:42,  2.95it/s]

{'loss': 5.0192, 'grad_norm': 1.4123154878616333, 'learning_rate': 1.3466666666666666e-05, 'epoch': 1.46}


 73%|███████▎  | 2750/3750 [16:19<05:38,  2.95it/s]

{'loss': 5.029, 'grad_norm': 2.535797357559204, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.47}


 74%|███████▎  | 2760/3750 [16:22<05:35,  2.95it/s]

{'loss': 5.0256, 'grad_norm': 3.5081143379211426, 'learning_rate': 1.32e-05, 'epoch': 1.47}


 74%|███████▍  | 2770/3750 [16:26<05:33,  2.94it/s]

{'loss': 5.032, 'grad_norm': 1.384996771812439, 'learning_rate': 1.3066666666666666e-05, 'epoch': 1.48}


 74%|███████▍  | 2780/3750 [16:29<05:29,  2.95it/s]

{'loss': 5.0212, 'grad_norm': 1.6032767295837402, 'learning_rate': 1.2933333333333334e-05, 'epoch': 1.48}


 74%|███████▍  | 2790/3750 [16:33<05:25,  2.95it/s]

{'loss': 5.0156, 'grad_norm': 1.2822085618972778, 'learning_rate': 1.2800000000000001e-05, 'epoch': 1.49}


 75%|███████▍  | 2800/3750 [16:36<05:22,  2.95it/s]

{'loss': 5.0123, 'grad_norm': 3.4294097423553467, 'learning_rate': 1.2666666666666668e-05, 'epoch': 1.49}


 75%|███████▍  | 2810/3750 [16:39<05:18,  2.95it/s]

{'loss': 5.0191, 'grad_norm': 2.7531471252441406, 'learning_rate': 1.2533333333333332e-05, 'epoch': 1.5}


 75%|███████▌  | 2820/3750 [16:43<05:16,  2.94it/s]

{'loss': 5.0096, 'grad_norm': 3.378204107284546, 'learning_rate': 1.24e-05, 'epoch': 1.5}


 75%|███████▌  | 2830/3750 [16:46<05:14,  2.93it/s]

{'loss': 5.023, 'grad_norm': 1.4926892518997192, 'learning_rate': 1.2266666666666667e-05, 'epoch': 1.51}


 76%|███████▌  | 2840/3750 [16:50<05:08,  2.95it/s]

{'loss': 5.0071, 'grad_norm': 1.3952085971832275, 'learning_rate': 1.2133333333333335e-05, 'epoch': 1.51}


 76%|███████▌  | 2850/3750 [16:53<05:06,  2.94it/s]

{'loss': 5.0002, 'grad_norm': 3.2993831634521484, 'learning_rate': 1.2e-05, 'epoch': 1.52}


 76%|███████▋  | 2860/3750 [16:56<05:06,  2.91it/s]

{'loss': 5.0122, 'grad_norm': 2.5640478134155273, 'learning_rate': 1.1866666666666668e-05, 'epoch': 1.53}


 77%|███████▋  | 2870/3750 [17:00<05:00,  2.93it/s]

{'loss': 5.0155, 'grad_norm': 2.692556619644165, 'learning_rate': 1.1733333333333333e-05, 'epoch': 1.53}


 77%|███████▋  | 2880/3750 [17:03<04:55,  2.95it/s]

{'loss': 5.0139, 'grad_norm': 2.665656566619873, 'learning_rate': 1.16e-05, 'epoch': 1.54}


 77%|███████▋  | 2890/3750 [17:07<04:51,  2.95it/s]

{'loss': 5.032, 'grad_norm': 2.6540982723236084, 'learning_rate': 1.1466666666666666e-05, 'epoch': 1.54}


 77%|███████▋  | 2900/3750 [17:10<04:48,  2.95it/s]

{'loss': 5.0096, 'grad_norm': 1.4231983423233032, 'learning_rate': 1.1333333333333334e-05, 'epoch': 1.55}


 78%|███████▊  | 2910/3750 [17:13<04:44,  2.95it/s]

{'loss': 5.0129, 'grad_norm': 2.9984097480773926, 'learning_rate': 1.1200000000000001e-05, 'epoch': 1.55}


 78%|███████▊  | 2920/3750 [17:17<04:41,  2.95it/s]

{'loss': 5.017, 'grad_norm': 1.2947542667388916, 'learning_rate': 1.1066666666666667e-05, 'epoch': 1.56}


 78%|███████▊  | 2930/3750 [17:20<04:38,  2.95it/s]

{'loss': 4.9946, 'grad_norm': 1.4270726442337036, 'learning_rate': 1.0933333333333334e-05, 'epoch': 1.56}


 78%|███████▊  | 2940/3750 [17:24<04:35,  2.95it/s]

{'loss': 5.0028, 'grad_norm': 2.5134263038635254, 'learning_rate': 1.08e-05, 'epoch': 1.57}


 79%|███████▊  | 2950/3750 [17:27<04:32,  2.93it/s]

{'loss': 5.0397, 'grad_norm': 1.318304181098938, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.57}


 79%|███████▉  | 2960/3750 [17:30<04:28,  2.95it/s]

{'loss': 5.0091, 'grad_norm': 2.7043960094451904, 'learning_rate': 1.0533333333333335e-05, 'epoch': 1.58}


 79%|███████▉  | 2970/3750 [17:34<04:24,  2.95it/s]

{'loss': 5.0168, 'grad_norm': 2.7918407917022705, 'learning_rate': 1.04e-05, 'epoch': 1.58}


 79%|███████▉  | 2980/3750 [17:37<04:21,  2.94it/s]

{'loss': 5.0069, 'grad_norm': 3.6728456020355225, 'learning_rate': 1.0266666666666668e-05, 'epoch': 1.59}


 80%|███████▉  | 2990/3750 [17:41<04:18,  2.94it/s]

{'loss': 5.0101, 'grad_norm': 2.628221273422241, 'learning_rate': 1.0133333333333333e-05, 'epoch': 1.59}


 80%|████████  | 3000/3750 [17:44<04:16,  2.92it/s]

{'loss': 5.0245, 'grad_norm': 1.4444152116775513, 'learning_rate': 1e-05, 'epoch': 1.6}


 80%|████████  | 3010/3750 [17:51<04:42,  2.62it/s]

{'loss': 5.0159, 'grad_norm': 1.4512792825698853, 'learning_rate': 9.866666666666667e-06, 'epoch': 1.61}


 81%|████████  | 3020/3750 [17:54<04:09,  2.93it/s]

{'loss': 5.0084, 'grad_norm': 1.2586300373077393, 'learning_rate': 9.733333333333334e-06, 'epoch': 1.61}


 81%|████████  | 3030/3750 [17:58<04:04,  2.95it/s]

{'loss': 5.0097, 'grad_norm': 1.3922407627105713, 'learning_rate': 9.600000000000001e-06, 'epoch': 1.62}


 81%|████████  | 3040/3750 [18:01<04:00,  2.95it/s]

{'loss': 5.0077, 'grad_norm': 1.4453833103179932, 'learning_rate': 9.466666666666667e-06, 'epoch': 1.62}


 81%|████████▏ | 3050/3750 [18:04<03:57,  2.95it/s]

{'loss': 5.0156, 'grad_norm': 1.4835528135299683, 'learning_rate': 9.333333333333334e-06, 'epoch': 1.63}


 82%|████████▏ | 3060/3750 [18:08<03:54,  2.95it/s]

{'loss': 5.0257, 'grad_norm': 2.693194627761841, 'learning_rate': 9.2e-06, 'epoch': 1.63}


 82%|████████▏ | 3070/3750 [18:11<03:51,  2.94it/s]

{'loss': 5.018, 'grad_norm': 1.3444358110427856, 'learning_rate': 9.066666666666667e-06, 'epoch': 1.64}


 82%|████████▏ | 3080/3750 [18:15<03:47,  2.95it/s]

{'loss': 5.0252, 'grad_norm': 2.671643018722534, 'learning_rate': 8.933333333333333e-06, 'epoch': 1.64}


 82%|████████▏ | 3090/3750 [18:18<03:43,  2.96it/s]

{'loss': 5.011, 'grad_norm': 1.4503138065338135, 'learning_rate': 8.8e-06, 'epoch': 1.65}


 83%|████████▎ | 3100/3750 [18:21<03:40,  2.95it/s]

{'loss': 5.0205, 'grad_norm': 1.3998663425445557, 'learning_rate': 8.666666666666668e-06, 'epoch': 1.65}


 83%|████████▎ | 3110/3750 [18:25<03:36,  2.95it/s]

{'loss': 5.0155, 'grad_norm': 2.5927023887634277, 'learning_rate': 8.533333333333334e-06, 'epoch': 1.66}


 83%|████████▎ | 3120/3750 [18:28<03:34,  2.94it/s]

{'loss': 5.0213, 'grad_norm': 1.9265824556350708, 'learning_rate': 8.400000000000001e-06, 'epoch': 1.66}


 83%|████████▎ | 3130/3750 [18:32<03:30,  2.94it/s]

{'loss': 5.0197, 'grad_norm': 1.2566379308700562, 'learning_rate': 8.266666666666667e-06, 'epoch': 1.67}


 84%|████████▎ | 3140/3750 [18:35<03:27,  2.94it/s]

{'loss': 5.0337, 'grad_norm': 2.220806360244751, 'learning_rate': 8.133333333333332e-06, 'epoch': 1.67}


 84%|████████▍ | 3150/3750 [18:38<03:23,  2.95it/s]

{'loss': 5.0155, 'grad_norm': 2.610244035720825, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.68}


 84%|████████▍ | 3160/3750 [18:42<03:20,  2.94it/s]

{'loss': 5.0264, 'grad_norm': 1.3313181400299072, 'learning_rate': 7.866666666666667e-06, 'epoch': 1.69}


 85%|████████▍ | 3170/3750 [18:45<03:16,  2.95it/s]

{'loss': 5.0093, 'grad_norm': 2.712149143218994, 'learning_rate': 7.733333333333334e-06, 'epoch': 1.69}


 85%|████████▍ | 3180/3750 [18:49<03:13,  2.95it/s]

{'loss': 5.0103, 'grad_norm': 1.477975606918335, 'learning_rate': 7.6e-06, 'epoch': 1.7}


 85%|████████▌ | 3190/3750 [18:52<03:09,  2.95it/s]

{'loss': 5.0224, 'grad_norm': 1.3278754949569702, 'learning_rate': 7.4666666666666675e-06, 'epoch': 1.7}


 85%|████████▌ | 3200/3750 [18:55<03:06,  2.94it/s]

{'loss': 5.0251, 'grad_norm': 2.6501078605651855, 'learning_rate': 7.333333333333334e-06, 'epoch': 1.71}


 86%|████████▌ | 3210/3750 [18:59<03:05,  2.91it/s]

{'loss': 5.0225, 'grad_norm': 2.652306079864502, 'learning_rate': 7.2e-06, 'epoch': 1.71}


 86%|████████▌ | 3220/3750 [19:02<03:10,  2.78it/s]

{'loss': 5.001, 'grad_norm': 3.6096112728118896, 'learning_rate': 7.066666666666667e-06, 'epoch': 1.72}


 86%|████████▌ | 3230/3750 [19:06<03:04,  2.81it/s]

{'loss': 5.0132, 'grad_norm': 2.587993860244751, 'learning_rate': 6.933333333333334e-06, 'epoch': 1.72}


 86%|████████▋ | 3240/3750 [19:09<02:57,  2.87it/s]

{'loss': 5.0182, 'grad_norm': 1.3589107990264893, 'learning_rate': 6.800000000000001e-06, 'epoch': 1.73}


 87%|████████▋ | 3250/3750 [19:13<02:54,  2.86it/s]

{'loss': 5.0204, 'grad_norm': 2.6201045513153076, 'learning_rate': 6.666666666666667e-06, 'epoch': 1.73}


 87%|████████▋ | 3260/3750 [19:16<02:47,  2.93it/s]

{'loss': 5.0361, 'grad_norm': 2.676574945449829, 'learning_rate': 6.533333333333333e-06, 'epoch': 1.74}


 87%|████████▋ | 3270/3750 [19:20<02:44,  2.92it/s]

{'loss': 5.0125, 'grad_norm': 1.2272406816482544, 'learning_rate': 6.4000000000000006e-06, 'epoch': 1.74}


 87%|████████▋ | 3280/3750 [19:23<02:42,  2.89it/s]

{'loss': 5.0106, 'grad_norm': 2.5826683044433594, 'learning_rate': 6.266666666666666e-06, 'epoch': 1.75}


 88%|████████▊ | 3290/3750 [19:27<02:38,  2.90it/s]

{'loss': 5.0185, 'grad_norm': 1.3057273626327515, 'learning_rate': 6.133333333333334e-06, 'epoch': 1.75}


 88%|████████▊ | 3300/3750 [19:30<02:34,  2.91it/s]

{'loss': 5.0222, 'grad_norm': 1.2107876539230347, 'learning_rate': 6e-06, 'epoch': 1.76}


 88%|████████▊ | 3310/3750 [19:34<02:30,  2.92it/s]

{'loss': 5.0404, 'grad_norm': 2.0179905891418457, 'learning_rate': 5.866666666666667e-06, 'epoch': 1.77}


 89%|████████▊ | 3320/3750 [19:37<02:26,  2.94it/s]

{'loss': 5.0036, 'grad_norm': 2.6139132976531982, 'learning_rate': 5.733333333333333e-06, 'epoch': 1.77}


 89%|████████▉ | 3330/3750 [19:40<02:24,  2.91it/s]

{'loss': 5.0258, 'grad_norm': 4.099403381347656, 'learning_rate': 5.600000000000001e-06, 'epoch': 1.78}


 89%|████████▉ | 3340/3750 [19:44<02:20,  2.93it/s]

{'loss': 5.0218, 'grad_norm': 1.3382303714752197, 'learning_rate': 5.466666666666667e-06, 'epoch': 1.78}


 89%|████████▉ | 3350/3750 [19:47<02:16,  2.93it/s]

{'loss': 5.0198, 'grad_norm': 2.9696195125579834, 'learning_rate': 5.333333333333334e-06, 'epoch': 1.79}


 90%|████████▉ | 3360/3750 [19:51<02:13,  2.93it/s]

{'loss': 5.0206, 'grad_norm': 1.382103681564331, 'learning_rate': 5.2e-06, 'epoch': 1.79}


 90%|████████▉ | 3370/3750 [19:54<02:09,  2.94it/s]

{'loss': 5.0179, 'grad_norm': 1.3862955570220947, 'learning_rate': 5.066666666666667e-06, 'epoch': 1.8}


 90%|█████████ | 3380/3750 [19:58<02:07,  2.91it/s]

{'loss': 5.0105, 'grad_norm': 3.6016180515289307, 'learning_rate': 4.933333333333333e-06, 'epoch': 1.8}


 90%|█████████ | 3390/3750 [20:01<02:03,  2.91it/s]

{'loss': 5.0165, 'grad_norm': 1.7297353744506836, 'learning_rate': 4.800000000000001e-06, 'epoch': 1.81}


 91%|█████████ | 3400/3750 [20:04<02:01,  2.89it/s]

{'loss': 5.0218, 'grad_norm': 1.3661400079727173, 'learning_rate': 4.666666666666667e-06, 'epoch': 1.81}


 91%|█████████ | 3410/3750 [20:08<01:56,  2.93it/s]

{'loss': 5.0188, 'grad_norm': 1.5095478296279907, 'learning_rate': 4.533333333333334e-06, 'epoch': 1.82}


 91%|█████████ | 3420/3750 [20:11<01:52,  2.93it/s]

{'loss': 5.0203, 'grad_norm': 2.6098804473876953, 'learning_rate': 4.4e-06, 'epoch': 1.82}


 91%|█████████▏| 3430/3750 [20:15<01:48,  2.94it/s]

{'loss': 5.0068, 'grad_norm': 1.4485251903533936, 'learning_rate': 4.266666666666667e-06, 'epoch': 1.83}


 92%|█████████▏| 3440/3750 [20:18<01:45,  2.94it/s]

{'loss': 5.0216, 'grad_norm': 1.3307026624679565, 'learning_rate': 4.133333333333333e-06, 'epoch': 1.83}


 92%|█████████▏| 3450/3750 [20:22<01:41,  2.94it/s]

{'loss': 5.047, 'grad_norm': 1.320772409439087, 'learning_rate': 4.000000000000001e-06, 'epoch': 1.84}


 92%|█████████▏| 3460/3750 [20:25<01:38,  2.93it/s]

{'loss': 5.0293, 'grad_norm': 2.7162318229675293, 'learning_rate': 3.866666666666667e-06, 'epoch': 1.85}


 93%|█████████▎| 3470/3750 [20:28<01:36,  2.91it/s]

{'loss': 5.0269, 'grad_norm': 2.6461033821105957, 'learning_rate': 3.7333333333333337e-06, 'epoch': 1.85}


 93%|█████████▎| 3480/3750 [20:32<01:32,  2.92it/s]

{'loss': 5.0198, 'grad_norm': 1.366752028465271, 'learning_rate': 3.6e-06, 'epoch': 1.86}


 93%|█████████▎| 3490/3750 [20:35<01:28,  2.94it/s]

{'loss': 5.0193, 'grad_norm': 1.3867157697677612, 'learning_rate': 3.466666666666667e-06, 'epoch': 1.86}


 93%|█████████▎| 3500/3750 [20:39<01:25,  2.93it/s]

{'loss': 5.0107, 'grad_norm': 2.6873209476470947, 'learning_rate': 3.3333333333333333e-06, 'epoch': 1.87}


 94%|█████████▎| 3510/3750 [20:46<01:33,  2.56it/s]

{'loss': 5.0078, 'grad_norm': 2.581895589828491, 'learning_rate': 3.2000000000000003e-06, 'epoch': 1.87}


 94%|█████████▍| 3520/3750 [20:50<01:18,  2.93it/s]

{'loss': 5.0185, 'grad_norm': 2.608963966369629, 'learning_rate': 3.066666666666667e-06, 'epoch': 1.88}


 94%|█████████▍| 3530/3750 [20:53<01:14,  2.94it/s]

{'loss': 5.0214, 'grad_norm': 2.6531271934509277, 'learning_rate': 2.9333333333333333e-06, 'epoch': 1.88}


 94%|█████████▍| 3540/3750 [20:56<01:11,  2.94it/s]

{'loss': 5.0331, 'grad_norm': 2.974628210067749, 'learning_rate': 2.8000000000000003e-06, 'epoch': 1.89}


 95%|█████████▍| 3550/3750 [21:00<01:07,  2.95it/s]

{'loss': 5.0146, 'grad_norm': 2.639070749282837, 'learning_rate': 2.666666666666667e-06, 'epoch': 1.89}


 95%|█████████▍| 3560/3750 [21:03<01:04,  2.94it/s]

{'loss': 5.0093, 'grad_norm': 3.024139404296875, 'learning_rate': 2.5333333333333334e-06, 'epoch': 1.9}


 95%|█████████▌| 3570/3750 [21:07<01:01,  2.91it/s]

{'loss': 5.0242, 'grad_norm': 2.588747978210449, 'learning_rate': 2.4000000000000003e-06, 'epoch': 1.9}


 95%|█████████▌| 3580/3750 [21:10<00:57,  2.95it/s]

{'loss': 5.0214, 'grad_norm': 1.907880425453186, 'learning_rate': 2.266666666666667e-06, 'epoch': 1.91}


 96%|█████████▌| 3590/3750 [21:13<00:54,  2.93it/s]

{'loss': 5.0124, 'grad_norm': 1.38330078125, 'learning_rate': 2.1333333333333334e-06, 'epoch': 1.91}


 96%|█████████▌| 3600/3750 [21:17<00:50,  2.95it/s]

{'loss': 5.0203, 'grad_norm': 2.657989263534546, 'learning_rate': 2.0000000000000003e-06, 'epoch': 1.92}


 96%|█████████▋| 3610/3750 [21:20<00:47,  2.95it/s]

{'loss': 5.0247, 'grad_norm': 1.468631625175476, 'learning_rate': 1.8666666666666669e-06, 'epoch': 1.93}


 97%|█████████▋| 3620/3750 [21:24<00:44,  2.93it/s]

{'loss': 5.0044, 'grad_norm': 2.6334617137908936, 'learning_rate': 1.7333333333333334e-06, 'epoch': 1.93}


 97%|█████████▋| 3630/3750 [21:27<00:40,  2.94it/s]

{'loss': 5.001, 'grad_norm': 1.34549880027771, 'learning_rate': 1.6000000000000001e-06, 'epoch': 1.94}


 97%|█████████▋| 3640/3750 [21:31<00:37,  2.94it/s]

{'loss': 5.0148, 'grad_norm': 1.3513126373291016, 'learning_rate': 1.4666666666666667e-06, 'epoch': 1.94}


 97%|█████████▋| 3650/3750 [21:34<00:34,  2.94it/s]

{'loss': 5.0058, 'grad_norm': 2.656839370727539, 'learning_rate': 1.3333333333333334e-06, 'epoch': 1.95}


 98%|█████████▊| 3660/3750 [21:37<00:30,  2.94it/s]

{'loss': 5.0103, 'grad_norm': 3.4006693363189697, 'learning_rate': 1.2000000000000002e-06, 'epoch': 1.95}


 98%|█████████▊| 3670/3750 [21:41<00:27,  2.95it/s]

{'loss': 5.019, 'grad_norm': 2.617344617843628, 'learning_rate': 1.0666666666666667e-06, 'epoch': 1.96}


 98%|█████████▊| 3680/3750 [21:44<00:23,  2.94it/s]

{'loss': 5.0273, 'grad_norm': 1.3387681245803833, 'learning_rate': 9.333333333333334e-07, 'epoch': 1.96}


 98%|█████████▊| 3690/3750 [21:48<00:20,  2.95it/s]

{'loss': 5.0171, 'grad_norm': 1.360410451889038, 'learning_rate': 8.000000000000001e-07, 'epoch': 1.97}


 99%|█████████▊| 3700/3750 [21:51<00:16,  2.94it/s]

{'loss': 5.0187, 'grad_norm': 1.2667558193206787, 'learning_rate': 6.666666666666667e-07, 'epoch': 1.97}


 99%|█████████▉| 3710/3750 [21:54<00:13,  2.94it/s]

{'loss': 5.0143, 'grad_norm': 2.654963254928589, 'learning_rate': 5.333333333333333e-07, 'epoch': 1.98}


 99%|█████████▉| 3720/3750 [21:58<00:10,  2.95it/s]

{'loss': 5.0127, 'grad_norm': 1.2621878385543823, 'learning_rate': 4.0000000000000003e-07, 'epoch': 1.98}


 99%|█████████▉| 3730/3750 [22:01<00:06,  2.95it/s]

{'loss': 5.0135, 'grad_norm': 1.4247764348983765, 'learning_rate': 2.6666666666666667e-07, 'epoch': 1.99}


100%|█████████▉| 3740/3750 [22:05<00:03,  2.95it/s]

{'loss': 5.0128, 'grad_norm': 2.997262716293335, 'learning_rate': 1.3333333333333334e-07, 'epoch': 1.99}


100%|██████████| 3750/3750 [22:08<00:00,  2.94it/s]

{'loss': 5.0222, 'grad_norm': 1.918997049331665, 'learning_rate': 0.0, 'epoch': 2.0}


                                                   
100%|██████████| 3750/3750 [22:28<00:00,  2.78it/s]


{'eval_loss': 5.011055946350098, 'eval_runtime': 15.8985, 'eval_samples_per_second': 188.697, 'eval_steps_per_second': 23.587, 'epoch': 2.0}
{'train_runtime': 1348.1063, 'train_samples_per_second': 22.253, 'train_steps_per_second': 2.782, 'train_loss': 5.022470348103841, 'epoch': 2.0}


('tokenizer_g_finetuned/tokenizer_config.json',
 'tokenizer_g_finetuned/special_tokens_map.json',
 'tokenizer_g_finetuned/vocab.json',
 'tokenizer_g_finetuned/merges.txt',
 'tokenizer_g_finetuned/added_tokens.json',
 'tokenizer_g_finetuned/tokenizer.json')

In [4]:
model_name_g_star = "roberta-base"
tokenizer_g_star = AutoTokenizer.from_pretrained(model_name_g_star)
model_g_star = AutoModelForSequenceClassification.from_pretrained(model_name_g_star, num_labels=150)

# Adjust classifier head if needed
if model_g_star.config.num_labels != num_labels:
    model_g_star.resize_token_embeddings(len(tokenizer_g_star))
    model_g_star.classifier = torch.nn.Linear(model_g_star.config.dim, num_labels)

model_g_star.config.label2id = label2id
model_g_star.config.id2label = id2label

def encode_batch_g_star(batch):
    """
    We feed a special token [EMPTY] or just an empty string 
    so the model can't see the real query.
    """
    # We'll feed "[EMPTY]" for each sample
    input_texts = ["[EMPTY]" for _ in batch["query"]]
    enc = tokenizer_g_star(
        input_texts,
        truncation=True, 
        padding="max_length",
        max_length=8
    )
    enc["labels"] = [label2id[intent] for intent in batch["intent"]]
    return enc

train_encoded_g_star = train_dataset.map(encode_batch_g_star, batched=True, remove_columns=train_dataset.column_names)
val_encoded_g_star   = val_dataset.map(encode_batch_g_star,   batched=True, remove_columns=val_dataset.column_names)

training_args_g_star = TrainingArguments(
    output_dir="g_star_output",
    eval_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    logging_steps=10,
    seed=42
)

trainer_g_star = Trainer(
    model=model_g_star,
    args=training_args_g_star,
    train_dataset=train_encoded_g_star,
    eval_dataset=val_encoded_g_star
)

trainer_g_star.train()

# Save final classifier g^*
model_g_star.save_pretrained("model_gstar_finetuned")
tokenizer_g_star.save_pretrained("tokenizer_gstar_finetuned")

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 15000/15000 [00:00<00:00, 110094.81 examples/s]
Map: 100%|██████████| 3000/3000 [00:00<00:00, 125297.86 examples/s]
  0%|          | 10/3750 [00:03<17:44,  3.51it/s]

{'loss': 5.0411, 'grad_norm': 4.87498140335083, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.01}


  1%|          | 20/3750 [00:06<17:16,  3.60it/s]

{'loss': 4.999, 'grad_norm': 4.582918167114258, 'learning_rate': 4.973333333333334e-05, 'epoch': 0.01}


  1%|          | 30/3750 [00:09<17:05,  3.63it/s]

{'loss': 5.0114, 'grad_norm': 3.8668017387390137, 'learning_rate': 4.96e-05, 'epoch': 0.02}


  1%|          | 40/3750 [00:11<16:54,  3.66it/s]

{'loss': 5.0382, 'grad_norm': 4.7583465576171875, 'learning_rate': 4.9466666666666665e-05, 'epoch': 0.02}


  1%|▏         | 50/3750 [00:14<17:00,  3.63it/s]

{'loss': 5.0375, 'grad_norm': 3.218398094177246, 'learning_rate': 4.933333333333334e-05, 'epoch': 0.03}


  2%|▏         | 60/3750 [00:17<17:28,  3.52it/s]

{'loss': 5.0209, 'grad_norm': 2.6647896766662598, 'learning_rate': 4.92e-05, 'epoch': 0.03}


  2%|▏         | 70/3750 [00:20<16:51,  3.64it/s]

{'loss': 5.0358, 'grad_norm': 3.7504491806030273, 'learning_rate': 4.906666666666667e-05, 'epoch': 0.04}


  2%|▏         | 80/3750 [00:22<17:15,  3.54it/s]

{'loss': 5.0421, 'grad_norm': 4.865965843200684, 'learning_rate': 4.8933333333333335e-05, 'epoch': 0.04}


  2%|▏         | 90/3750 [00:25<17:19,  3.52it/s]

{'loss': 5.0697, 'grad_norm': 4.955928325653076, 'learning_rate': 4.88e-05, 'epoch': 0.05}


  3%|▎         | 100/3750 [00:28<17:19,  3.51it/s]

{'loss': 5.0223, 'grad_norm': 2.904815435409546, 'learning_rate': 4.866666666666667e-05, 'epoch': 0.05}


  3%|▎         | 110/3750 [00:31<17:03,  3.56it/s]

{'loss': 5.0208, 'grad_norm': 2.8797402381896973, 'learning_rate': 4.853333333333334e-05, 'epoch': 0.06}


  3%|▎         | 120/3750 [00:34<17:01,  3.55it/s]

{'loss': 5.0413, 'grad_norm': 2.8216636180877686, 'learning_rate': 4.8400000000000004e-05, 'epoch': 0.06}


  3%|▎         | 130/3750 [00:37<17:39,  3.42it/s]

{'loss': 5.0346, 'grad_norm': 3.504875898361206, 'learning_rate': 4.826666666666667e-05, 'epoch': 0.07}


  4%|▎         | 140/3750 [00:39<16:42,  3.60it/s]

{'loss': 5.0614, 'grad_norm': 3.5332157611846924, 'learning_rate': 4.8133333333333336e-05, 'epoch': 0.07}


  4%|▍         | 150/3750 [00:42<16:59,  3.53it/s]

{'loss': 5.018, 'grad_norm': 2.9488518238067627, 'learning_rate': 4.8e-05, 'epoch': 0.08}


  4%|▍         | 160/3750 [00:45<17:29,  3.42it/s]

{'loss': 5.0481, 'grad_norm': 4.232275009155273, 'learning_rate': 4.7866666666666674e-05, 'epoch': 0.09}


  5%|▍         | 170/3750 [00:48<16:58,  3.52it/s]

{'loss': 5.0368, 'grad_norm': 2.7328052520751953, 'learning_rate': 4.773333333333333e-05, 'epoch': 0.09}


  5%|▍         | 180/3750 [00:51<16:18,  3.65it/s]

{'loss': 5.0279, 'grad_norm': 3.5492217540740967, 'learning_rate': 4.76e-05, 'epoch': 0.1}


  5%|▌         | 190/3750 [00:54<16:28,  3.60it/s]

{'loss': 5.0113, 'grad_norm': 3.409756898880005, 'learning_rate': 4.746666666666667e-05, 'epoch': 0.1}


  5%|▌         | 200/3750 [00:56<16:18,  3.63it/s]

{'loss': 5.073, 'grad_norm': 3.3516128063201904, 'learning_rate': 4.7333333333333336e-05, 'epoch': 0.11}


  6%|▌         | 210/3750 [00:59<16:14,  3.63it/s]

{'loss': 5.0145, 'grad_norm': 2.4469032287597656, 'learning_rate': 4.72e-05, 'epoch': 0.11}


  6%|▌         | 220/3750 [01:02<16:15,  3.62it/s]

{'loss': 5.0398, 'grad_norm': 3.1949546337127686, 'learning_rate': 4.706666666666667e-05, 'epoch': 0.12}


  6%|▌         | 230/3750 [01:05<16:01,  3.66it/s]

{'loss': 5.0103, 'grad_norm': 3.285869598388672, 'learning_rate': 4.6933333333333333e-05, 'epoch': 0.12}


  6%|▋         | 240/3750 [01:07<16:01,  3.65it/s]

{'loss': 5.0083, 'grad_norm': 2.324749231338501, 'learning_rate': 4.6800000000000006e-05, 'epoch': 0.13}


  7%|▋         | 250/3750 [01:10<16:06,  3.62it/s]

{'loss': 5.023, 'grad_norm': 3.328219175338745, 'learning_rate': 4.666666666666667e-05, 'epoch': 0.13}


  7%|▋         | 260/3750 [01:13<17:03,  3.41it/s]

{'loss': 5.0458, 'grad_norm': 2.25479793548584, 'learning_rate': 4.653333333333334e-05, 'epoch': 0.14}


  7%|▋         | 270/3750 [01:16<16:40,  3.48it/s]

{'loss': 5.006, 'grad_norm': 2.280449390411377, 'learning_rate': 4.64e-05, 'epoch': 0.14}


  7%|▋         | 280/3750 [01:19<16:18,  3.55it/s]

{'loss': 5.0025, 'grad_norm': 2.6033685207366943, 'learning_rate': 4.626666666666667e-05, 'epoch': 0.15}


  8%|▊         | 290/3750 [01:22<16:02,  3.60it/s]

{'loss': 5.0106, 'grad_norm': 2.361663341522217, 'learning_rate': 4.6133333333333334e-05, 'epoch': 0.15}


  8%|▊         | 300/3750 [01:24<16:04,  3.58it/s]

{'loss': 5.0448, 'grad_norm': 2.4748215675354004, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.16}


  8%|▊         | 310/3750 [01:27<16:26,  3.49it/s]

{'loss': 5.0258, 'grad_norm': 2.4906957149505615, 'learning_rate': 4.5866666666666666e-05, 'epoch': 0.17}


  9%|▊         | 320/3750 [01:30<16:11,  3.53it/s]

{'loss': 5.0225, 'grad_norm': 2.521146774291992, 'learning_rate': 4.573333333333333e-05, 'epoch': 0.17}


  9%|▉         | 330/3750 [01:33<15:48,  3.61it/s]

{'loss': 5.007, 'grad_norm': 2.1917550563812256, 'learning_rate': 4.5600000000000004e-05, 'epoch': 0.18}


  9%|▉         | 340/3750 [01:36<15:55,  3.57it/s]

{'loss': 5.0093, 'grad_norm': 3.236006021499634, 'learning_rate': 4.546666666666667e-05, 'epoch': 0.18}


  9%|▉         | 350/3750 [01:39<15:32,  3.65it/s]

{'loss': 5.0676, 'grad_norm': 2.4075136184692383, 'learning_rate': 4.5333333333333335e-05, 'epoch': 0.19}


 10%|▉         | 360/3750 [01:41<16:11,  3.49it/s]

{'loss': 5.0189, 'grad_norm': 2.412642002105713, 'learning_rate': 4.52e-05, 'epoch': 0.19}


 10%|▉         | 370/3750 [01:44<15:35,  3.61it/s]

{'loss': 5.0535, 'grad_norm': 4.331594944000244, 'learning_rate': 4.5066666666666667e-05, 'epoch': 0.2}


 10%|█         | 380/3750 [01:47<15:21,  3.66it/s]

{'loss': 5.0276, 'grad_norm': 3.5147268772125244, 'learning_rate': 4.493333333333333e-05, 'epoch': 0.2}


 10%|█         | 390/3750 [01:50<15:37,  3.58it/s]

{'loss': 5.0189, 'grad_norm': 3.382849931716919, 'learning_rate': 4.4800000000000005e-05, 'epoch': 0.21}


 11%|█         | 400/3750 [01:52<15:18,  3.65it/s]

{'loss': 5.0071, 'grad_norm': 3.036137342453003, 'learning_rate': 4.466666666666667e-05, 'epoch': 0.21}


 11%|█         | 410/3750 [01:55<15:25,  3.61it/s]

{'loss': 5.012, 'grad_norm': 3.6867856979370117, 'learning_rate': 4.4533333333333336e-05, 'epoch': 0.22}


 11%|█         | 420/3750 [01:58<15:09,  3.66it/s]

{'loss': 5.0057, 'grad_norm': 3.7774863243103027, 'learning_rate': 4.44e-05, 'epoch': 0.22}


 11%|█▏        | 430/3750 [02:01<15:17,  3.62it/s]

{'loss': 5.0234, 'grad_norm': 2.1909220218658447, 'learning_rate': 4.426666666666667e-05, 'epoch': 0.23}


 12%|█▏        | 440/3750 [02:03<15:05,  3.66it/s]

{'loss': 5.0301, 'grad_norm': 3.1183536052703857, 'learning_rate': 4.413333333333334e-05, 'epoch': 0.23}


 12%|█▏        | 450/3750 [02:06<14:59,  3.67it/s]

{'loss': 4.9856, 'grad_norm': 2.2674052715301514, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.24}


 12%|█▏        | 460/3750 [02:09<15:05,  3.64it/s]

{'loss': 5.01, 'grad_norm': 2.793706178665161, 'learning_rate': 4.3866666666666665e-05, 'epoch': 0.25}


 13%|█▎        | 470/3750 [02:12<15:03,  3.63it/s]

{'loss': 5.0613, 'grad_norm': 2.8413660526275635, 'learning_rate': 4.373333333333334e-05, 'epoch': 0.25}


 13%|█▎        | 480/3750 [02:14<14:52,  3.67it/s]

{'loss': 5.0215, 'grad_norm': 5.26024055480957, 'learning_rate': 4.36e-05, 'epoch': 0.26}


 13%|█▎        | 490/3750 [02:17<14:49,  3.66it/s]

{'loss': 5.0181, 'grad_norm': 2.222818374633789, 'learning_rate': 4.346666666666667e-05, 'epoch': 0.26}


 13%|█▎        | 500/3750 [02:20<16:04,  3.37it/s]

{'loss': 5.0216, 'grad_norm': 2.178652048110962, 'learning_rate': 4.3333333333333334e-05, 'epoch': 0.27}


 14%|█▎        | 510/3750 [02:28<21:23,  2.52it/s]  

{'loss': 5.0253, 'grad_norm': 2.221745252609253, 'learning_rate': 4.32e-05, 'epoch': 0.27}


 14%|█▍        | 520/3750 [02:31<15:44,  3.42it/s]

{'loss': 5.0185, 'grad_norm': 3.916652202606201, 'learning_rate': 4.3066666666666665e-05, 'epoch': 0.28}


 14%|█▍        | 530/3750 [02:34<14:47,  3.63it/s]

{'loss': 5.0163, 'grad_norm': 2.128480911254883, 'learning_rate': 4.293333333333334e-05, 'epoch': 0.28}


 14%|█▍        | 540/3750 [02:36<16:04,  3.33it/s]

{'loss': 5.0478, 'grad_norm': 3.044049024581909, 'learning_rate': 4.2800000000000004e-05, 'epoch': 0.29}


 15%|█▍        | 550/3750 [02:39<14:44,  3.62it/s]

{'loss': 5.0897, 'grad_norm': 3.8607373237609863, 'learning_rate': 4.266666666666667e-05, 'epoch': 0.29}


 15%|█▍        | 560/3750 [02:42<14:38,  3.63it/s]

{'loss': 5.0724, 'grad_norm': 2.1478264331817627, 'learning_rate': 4.2533333333333335e-05, 'epoch': 0.3}


 15%|█▌        | 570/3750 [02:45<14:37,  3.63it/s]

{'loss': 5.068, 'grad_norm': 1.9317256212234497, 'learning_rate': 4.24e-05, 'epoch': 0.3}


 15%|█▌        | 580/3750 [02:48<14:33,  3.63it/s]

{'loss': 5.0214, 'grad_norm': 2.108262062072754, 'learning_rate': 4.226666666666667e-05, 'epoch': 0.31}


 16%|█▌        | 590/3750 [02:50<14:23,  3.66it/s]

{'loss': 5.026, 'grad_norm': 3.8005459308624268, 'learning_rate': 4.213333333333334e-05, 'epoch': 0.31}


 16%|█▌        | 600/3750 [02:53<14:28,  3.63it/s]

{'loss': 5.0411, 'grad_norm': 2.241079330444336, 'learning_rate': 4.2e-05, 'epoch': 0.32}


 16%|█▋        | 610/3750 [02:56<14:26,  3.63it/s]

{'loss': 5.0304, 'grad_norm': 2.3359005451202393, 'learning_rate': 4.186666666666667e-05, 'epoch': 0.33}


 17%|█▋        | 620/3750 [02:59<14:20,  3.64it/s]

{'loss': 5.0564, 'grad_norm': 2.9313321113586426, 'learning_rate': 4.1733333333333336e-05, 'epoch': 0.33}


 17%|█▋        | 630/3750 [03:02<14:17,  3.64it/s]

{'loss': 5.048, 'grad_norm': 2.9440274238586426, 'learning_rate': 4.16e-05, 'epoch': 0.34}


 17%|█▋        | 640/3750 [03:04<14:17,  3.63it/s]

{'loss': 5.0419, 'grad_norm': 2.948061943054199, 'learning_rate': 4.146666666666667e-05, 'epoch': 0.34}


 17%|█▋        | 650/3750 [03:07<14:12,  3.64it/s]

{'loss': 5.0552, 'grad_norm': 2.8813626766204834, 'learning_rate': 4.133333333333333e-05, 'epoch': 0.35}


 18%|█▊        | 660/3750 [03:10<14:06,  3.65it/s]

{'loss': 5.0244, 'grad_norm': 3.027540445327759, 'learning_rate': 4.12e-05, 'epoch': 0.35}


 18%|█▊        | 670/3750 [03:13<14:03,  3.65it/s]

{'loss': 5.0259, 'grad_norm': 3.0776493549346924, 'learning_rate': 4.106666666666667e-05, 'epoch': 0.36}


 18%|█▊        | 680/3750 [03:15<14:02,  3.65it/s]

{'loss': 5.0404, 'grad_norm': 3.055420398712158, 'learning_rate': 4.093333333333334e-05, 'epoch': 0.36}


 18%|█▊        | 690/3750 [03:18<13:59,  3.64it/s]

{'loss': 5.0339, 'grad_norm': 1.8678576946258545, 'learning_rate': 4.08e-05, 'epoch': 0.37}


 19%|█▊        | 700/3750 [03:21<13:59,  3.63it/s]

{'loss': 5.033, 'grad_norm': 4.639257907867432, 'learning_rate': 4.066666666666667e-05, 'epoch': 0.37}


 19%|█▉        | 710/3750 [03:24<13:54,  3.64it/s]

{'loss': 5.0271, 'grad_norm': 1.818009614944458, 'learning_rate': 4.0533333333333334e-05, 'epoch': 0.38}


 19%|█▉        | 720/3750 [03:26<13:50,  3.65it/s]

{'loss': 5.0318, 'grad_norm': 4.501953601837158, 'learning_rate': 4.0400000000000006e-05, 'epoch': 0.38}


 19%|█▉        | 730/3750 [03:29<13:50,  3.64it/s]

{'loss': 5.0411, 'grad_norm': 1.9583051204681396, 'learning_rate': 4.026666666666667e-05, 'epoch': 0.39}


 20%|█▉        | 740/3750 [03:32<13:45,  3.65it/s]

{'loss': 5.03, 'grad_norm': 2.108957290649414, 'learning_rate': 4.013333333333333e-05, 'epoch': 0.39}


 20%|██        | 750/3750 [03:35<13:44,  3.64it/s]

{'loss': 5.0328, 'grad_norm': 3.7602288722991943, 'learning_rate': 4e-05, 'epoch': 0.4}


 20%|██        | 760/3750 [03:38<13:44,  3.63it/s]

{'loss': 5.0338, 'grad_norm': 3.8576583862304688, 'learning_rate': 3.986666666666667e-05, 'epoch': 0.41}


 21%|██        | 770/3750 [03:40<13:35,  3.65it/s]

{'loss': 5.0389, 'grad_norm': 2.8658151626586914, 'learning_rate': 3.9733333333333335e-05, 'epoch': 0.41}


 21%|██        | 780/3750 [03:43<13:35,  3.64it/s]

{'loss': 5.0444, 'grad_norm': 3.2357499599456787, 'learning_rate': 3.960000000000001e-05, 'epoch': 0.42}


 21%|██        | 790/3750 [03:46<13:36,  3.63it/s]

{'loss': 4.9979, 'grad_norm': 1.9431648254394531, 'learning_rate': 3.9466666666666666e-05, 'epoch': 0.42}


 21%|██▏       | 800/3750 [03:49<13:35,  3.62it/s]

{'loss': 5.0298, 'grad_norm': 3.1789300441741943, 'learning_rate': 3.933333333333333e-05, 'epoch': 0.43}


 22%|██▏       | 810/3750 [03:51<13:25,  3.65it/s]

{'loss': 5.0294, 'grad_norm': 3.763378858566284, 'learning_rate': 3.9200000000000004e-05, 'epoch': 0.43}


 22%|██▏       | 820/3750 [03:54<13:36,  3.59it/s]

{'loss': 5.0333, 'grad_norm': 3.882399082183838, 'learning_rate': 3.906666666666667e-05, 'epoch': 0.44}


 22%|██▏       | 830/3750 [03:57<13:28,  3.61it/s]

{'loss': 5.0333, 'grad_norm': 2.079432964324951, 'learning_rate': 3.8933333333333336e-05, 'epoch': 0.44}


 22%|██▏       | 840/3750 [04:00<13:17,  3.65it/s]

{'loss': 5.0268, 'grad_norm': 2.9932804107666016, 'learning_rate': 3.88e-05, 'epoch': 0.45}


 23%|██▎       | 850/3750 [04:03<13:13,  3.65it/s]

{'loss': 5.0146, 'grad_norm': 3.8778300285339355, 'learning_rate': 3.866666666666667e-05, 'epoch': 0.45}


 23%|██▎       | 860/3750 [04:05<13:13,  3.64it/s]

{'loss': 5.0189, 'grad_norm': 2.5839827060699463, 'learning_rate': 3.853333333333334e-05, 'epoch': 0.46}


 23%|██▎       | 870/3750 [04:08<13:11,  3.64it/s]

{'loss': 5.0477, 'grad_norm': 3.1881003379821777, 'learning_rate': 3.8400000000000005e-05, 'epoch': 0.46}


 23%|██▎       | 880/3750 [04:11<13:13,  3.62it/s]

{'loss': 5.0286, 'grad_norm': 3.0367848873138428, 'learning_rate': 3.8266666666666664e-05, 'epoch': 0.47}


 24%|██▎       | 890/3750 [04:14<13:04,  3.65it/s]

{'loss': 5.0121, 'grad_norm': 3.246896266937256, 'learning_rate': 3.8133333333333336e-05, 'epoch': 0.47}


 24%|██▍       | 900/3750 [04:16<13:01,  3.65it/s]

{'loss': 5.0339, 'grad_norm': 3.7992968559265137, 'learning_rate': 3.8e-05, 'epoch': 0.48}


 24%|██▍       | 910/3750 [04:19<13:01,  3.63it/s]

{'loss': 5.0048, 'grad_norm': 3.1165215969085693, 'learning_rate': 3.786666666666667e-05, 'epoch': 0.49}


 25%|██▍       | 920/3750 [04:22<12:53,  3.66it/s]

{'loss': 5.0142, 'grad_norm': 1.8914240598678589, 'learning_rate': 3.773333333333334e-05, 'epoch': 0.49}


 25%|██▍       | 930/3750 [04:25<13:00,  3.61it/s]

{'loss': 5.0158, 'grad_norm': 1.829842209815979, 'learning_rate': 3.76e-05, 'epoch': 0.5}


 25%|██▌       | 940/3750 [04:28<12:48,  3.66it/s]

{'loss': 5.0372, 'grad_norm': 1.8350920677185059, 'learning_rate': 3.7466666666666665e-05, 'epoch': 0.5}


 25%|██▌       | 950/3750 [04:30<13:18,  3.51it/s]

{'loss': 5.0468, 'grad_norm': 3.7605443000793457, 'learning_rate': 3.733333333333334e-05, 'epoch': 0.51}


 26%|██▌       | 960/3750 [04:33<12:46,  3.64it/s]

{'loss': 5.0471, 'grad_norm': 2.027670383453369, 'learning_rate': 3.72e-05, 'epoch': 0.51}


 26%|██▌       | 970/3750 [04:36<12:43,  3.64it/s]

{'loss': 5.0382, 'grad_norm': 2.4428720474243164, 'learning_rate': 3.706666666666667e-05, 'epoch': 0.52}


 26%|██▌       | 980/3750 [04:39<12:39,  3.65it/s]

{'loss': 5.0252, 'grad_norm': 1.7761033773422241, 'learning_rate': 3.6933333333333334e-05, 'epoch': 0.52}


 26%|██▋       | 990/3750 [04:41<12:41,  3.62it/s]

{'loss': 5.0174, 'grad_norm': 4.234087944030762, 'learning_rate': 3.68e-05, 'epoch': 0.53}


 27%|██▋       | 1000/3750 [04:44<12:42,  3.61it/s]

{'loss': 5.0118, 'grad_norm': 4.067380428314209, 'learning_rate': 3.6666666666666666e-05, 'epoch': 0.53}


 27%|██▋       | 1010/3750 [04:53<16:00,  2.85it/s]  

{'loss': 5.0368, 'grad_norm': 2.910496234893799, 'learning_rate': 3.653333333333334e-05, 'epoch': 0.54}


 27%|██▋       | 1020/3750 [04:56<12:35,  3.61it/s]

{'loss': 5.0261, 'grad_norm': 2.8998894691467285, 'learning_rate': 3.6400000000000004e-05, 'epoch': 0.54}


 27%|██▋       | 1030/3750 [04:59<12:24,  3.65it/s]

{'loss': 5.0418, 'grad_norm': 3.9191319942474365, 'learning_rate': 3.626666666666667e-05, 'epoch': 0.55}


 28%|██▊       | 1040/3750 [05:01<12:28,  3.62it/s]

{'loss': 5.0008, 'grad_norm': 1.909684658050537, 'learning_rate': 3.6133333333333335e-05, 'epoch': 0.55}


 28%|██▊       | 1050/3750 [05:04<12:21,  3.64it/s]

{'loss': 5.0431, 'grad_norm': 2.8629090785980225, 'learning_rate': 3.6e-05, 'epoch': 0.56}


 28%|██▊       | 1060/3750 [05:07<12:16,  3.65it/s]

{'loss': 5.0322, 'grad_norm': 1.9099479913711548, 'learning_rate': 3.586666666666667e-05, 'epoch': 0.57}


 29%|██▊       | 1070/3750 [05:10<12:15,  3.65it/s]

{'loss': 5.0157, 'grad_norm': 1.9441359043121338, 'learning_rate': 3.573333333333333e-05, 'epoch': 0.57}


 29%|██▉       | 1080/3750 [05:12<12:11,  3.65it/s]

{'loss': 5.0336, 'grad_norm': 2.965832233428955, 'learning_rate': 3.56e-05, 'epoch': 0.58}


 29%|██▉       | 1090/3750 [05:15<12:09,  3.65it/s]

{'loss': 5.0393, 'grad_norm': 1.7189311981201172, 'learning_rate': 3.546666666666667e-05, 'epoch': 0.58}


 29%|██▉       | 1100/3750 [05:18<12:13,  3.61it/s]

{'loss': 4.9885, 'grad_norm': 2.96598744392395, 'learning_rate': 3.5333333333333336e-05, 'epoch': 0.59}


 30%|██▉       | 1110/3750 [05:21<12:03,  3.65it/s]

{'loss': 5.0243, 'grad_norm': 3.7126829624176025, 'learning_rate': 3.52e-05, 'epoch': 0.59}


 30%|██▉       | 1120/3750 [05:23<12:02,  3.64it/s]

{'loss': 5.0238, 'grad_norm': 1.8280525207519531, 'learning_rate': 3.506666666666667e-05, 'epoch': 0.6}


 30%|███       | 1130/3750 [05:26<12:00,  3.64it/s]

{'loss': 5.0288, 'grad_norm': 3.8486149311065674, 'learning_rate': 3.493333333333333e-05, 'epoch': 0.6}


 30%|███       | 1140/3750 [05:29<11:55,  3.65it/s]

{'loss': 5.0357, 'grad_norm': 4.666848182678223, 'learning_rate': 3.48e-05, 'epoch': 0.61}


 31%|███       | 1150/3750 [05:32<11:51,  3.66it/s]

{'loss': 5.0244, 'grad_norm': 1.928571343421936, 'learning_rate': 3.466666666666667e-05, 'epoch': 0.61}


 31%|███       | 1160/3750 [05:34<11:52,  3.63it/s]

{'loss': 5.0564, 'grad_norm': 1.8924102783203125, 'learning_rate': 3.453333333333334e-05, 'epoch': 0.62}


 31%|███       | 1170/3750 [05:37<11:45,  3.66it/s]

{'loss': 5.043, 'grad_norm': 1.6663519144058228, 'learning_rate': 3.4399999999999996e-05, 'epoch': 0.62}


 31%|███▏      | 1180/3750 [05:40<11:46,  3.64it/s]

{'loss': 5.0059, 'grad_norm': 2.8179192543029785, 'learning_rate': 3.426666666666667e-05, 'epoch': 0.63}


 32%|███▏      | 1190/3750 [05:43<11:42,  3.65it/s]

{'loss': 5.0242, 'grad_norm': 2.45803165435791, 'learning_rate': 3.4133333333333334e-05, 'epoch': 0.63}


 32%|███▏      | 1200/3750 [05:45<11:40,  3.64it/s]

{'loss': 5.0396, 'grad_norm': 2.845400810241699, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.64}


 32%|███▏      | 1210/3750 [05:48<11:35,  3.65it/s]

{'loss': 5.0238, 'grad_norm': 1.6909408569335938, 'learning_rate': 3.3866666666666665e-05, 'epoch': 0.65}


 33%|███▎      | 1220/3750 [05:51<11:36,  3.63it/s]

{'loss': 5.0139, 'grad_norm': 2.781573534011841, 'learning_rate': 3.373333333333333e-05, 'epoch': 0.65}


 33%|███▎      | 1230/3750 [05:54<11:33,  3.63it/s]

{'loss': 5.0088, 'grad_norm': 2.9748482704162598, 'learning_rate': 3.3600000000000004e-05, 'epoch': 0.66}


 33%|███▎      | 1240/3750 [05:56<11:29,  3.64it/s]

{'loss': 5.0112, 'grad_norm': 2.954566717147827, 'learning_rate': 3.346666666666667e-05, 'epoch': 0.66}


 33%|███▎      | 1250/3750 [05:59<11:26,  3.64it/s]

{'loss': 5.0078, 'grad_norm': 2.901113986968994, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.67}


 34%|███▎      | 1260/3750 [06:02<11:20,  3.66it/s]

{'loss': 5.0143, 'grad_norm': 2.8856334686279297, 'learning_rate': 3.32e-05, 'epoch': 0.67}


 34%|███▍      | 1270/3750 [06:05<11:19,  3.65it/s]

{'loss': 5.0413, 'grad_norm': 3.6822962760925293, 'learning_rate': 3.3066666666666666e-05, 'epoch': 0.68}


 34%|███▍      | 1280/3750 [06:08<11:21,  3.62it/s]

{'loss': 5.0413, 'grad_norm': 1.64469313621521, 'learning_rate': 3.293333333333333e-05, 'epoch': 0.68}


 34%|███▍      | 1290/3750 [06:10<11:15,  3.64it/s]

{'loss': 5.0112, 'grad_norm': 1.7471096515655518, 'learning_rate': 3.2800000000000004e-05, 'epoch': 0.69}


 35%|███▍      | 1300/3750 [06:13<11:13,  3.64it/s]

{'loss': 5.0029, 'grad_norm': 1.7746918201446533, 'learning_rate': 3.266666666666667e-05, 'epoch': 0.69}


 35%|███▍      | 1310/3750 [06:16<11:10,  3.64it/s]

{'loss': 5.0244, 'grad_norm': 2.99838924407959, 'learning_rate': 3.253333333333333e-05, 'epoch': 0.7}


 35%|███▌      | 1320/3750 [06:19<11:08,  3.63it/s]

{'loss': 5.0245, 'grad_norm': 1.8908040523529053, 'learning_rate': 3.24e-05, 'epoch': 0.7}


 35%|███▌      | 1330/3750 [06:21<11:05,  3.64it/s]

{'loss': 5.0266, 'grad_norm': 2.872054100036621, 'learning_rate': 3.226666666666667e-05, 'epoch': 0.71}


 36%|███▌      | 1340/3750 [06:24<11:05,  3.62it/s]

{'loss': 5.0289, 'grad_norm': 1.9686161279678345, 'learning_rate': 3.213333333333334e-05, 'epoch': 0.71}


 36%|███▌      | 1350/3750 [06:27<10:57,  3.65it/s]

{'loss': 5.0275, 'grad_norm': 3.0692131519317627, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.72}


 36%|███▋      | 1360/3750 [06:30<10:53,  3.66it/s]

{'loss': 5.0351, 'grad_norm': 3.7537925243377686, 'learning_rate': 3.1866666666666664e-05, 'epoch': 0.73}


 37%|███▋      | 1370/3750 [06:32<11:08,  3.56it/s]

{'loss': 5.0391, 'grad_norm': 1.6657209396362305, 'learning_rate': 3.173333333333334e-05, 'epoch': 0.73}


 37%|███▋      | 1380/3750 [06:35<11:55,  3.31it/s]

{'loss': 5.0219, 'grad_norm': 1.8294754028320312, 'learning_rate': 3.16e-05, 'epoch': 0.74}


 37%|███▋      | 1390/3750 [06:38<11:30,  3.42it/s]

{'loss': 5.0151, 'grad_norm': 2.875176191329956, 'learning_rate': 3.146666666666667e-05, 'epoch': 0.74}


 37%|███▋      | 1400/3750 [06:41<11:19,  3.46it/s]

{'loss': 5.0448, 'grad_norm': 1.663692831993103, 'learning_rate': 3.1333333333333334e-05, 'epoch': 0.75}


 38%|███▊      | 1410/3750 [06:44<11:21,  3.43it/s]

{'loss': 5.0021, 'grad_norm': 1.7369016408920288, 'learning_rate': 3.12e-05, 'epoch': 0.75}


 38%|███▊      | 1420/3750 [06:47<12:18,  3.16it/s]

{'loss': 5.0219, 'grad_norm': 2.7993967533111572, 'learning_rate': 3.1066666666666665e-05, 'epoch': 0.76}


 38%|███▊      | 1430/3750 [06:51<11:24,  3.39it/s]

{'loss': 5.0246, 'grad_norm': 2.857797384262085, 'learning_rate': 3.093333333333334e-05, 'epoch': 0.76}


 38%|███▊      | 1440/3750 [06:53<10:56,  3.52it/s]

{'loss': 4.9893, 'grad_norm': 1.55553138256073, 'learning_rate': 3.08e-05, 'epoch': 0.77}


 39%|███▊      | 1450/3750 [06:57<11:55,  3.21it/s]

{'loss': 5.0462, 'grad_norm': 1.6703948974609375, 'learning_rate': 3.066666666666667e-05, 'epoch': 0.77}


 39%|███▉      | 1460/3750 [07:00<11:02,  3.45it/s]

{'loss': 5.0351, 'grad_norm': 3.6495416164398193, 'learning_rate': 3.0533333333333335e-05, 'epoch': 0.78}


 39%|███▉      | 1470/3750 [07:03<11:11,  3.40it/s]

{'loss': 5.0187, 'grad_norm': 1.533376693725586, 'learning_rate': 3.04e-05, 'epoch': 0.78}


 39%|███▉      | 1480/3750 [07:06<11:08,  3.39it/s]

{'loss': 5.0285, 'grad_norm': 1.6335912942886353, 'learning_rate': 3.0266666666666666e-05, 'epoch': 0.79}


 40%|███▉      | 1490/3750 [07:09<10:53,  3.46it/s]

{'loss': 5.006, 'grad_norm': 3.6588027477264404, 'learning_rate': 3.0133333333333335e-05, 'epoch': 0.79}


 40%|████      | 1500/3750 [07:12<10:29,  3.57it/s]

{'loss': 5.0492, 'grad_norm': 2.759274959564209, 'learning_rate': 3e-05, 'epoch': 0.8}


 40%|████      | 1510/3750 [07:20<12:52,  2.90it/s]  

{'loss': 5.0299, 'grad_norm': 1.5956676006317139, 'learning_rate': 2.986666666666667e-05, 'epoch': 0.81}


 41%|████      | 1520/3750 [07:23<10:17,  3.61it/s]

{'loss': 5.0225, 'grad_norm': 2.9244282245635986, 'learning_rate': 2.9733333333333336e-05, 'epoch': 0.81}


 41%|████      | 1530/3750 [07:25<10:08,  3.65it/s]

{'loss': 4.9955, 'grad_norm': 1.6059097051620483, 'learning_rate': 2.96e-05, 'epoch': 0.82}


 41%|████      | 1540/3750 [07:28<10:07,  3.64it/s]

{'loss': 5.0285, 'grad_norm': 2.8536908626556396, 'learning_rate': 2.946666666666667e-05, 'epoch': 0.82}


 41%|████▏     | 1550/3750 [07:31<10:02,  3.65it/s]

{'loss': 5.0138, 'grad_norm': 3.648655652999878, 'learning_rate': 2.9333333333333336e-05, 'epoch': 0.83}


 42%|████▏     | 1560/3750 [07:34<10:17,  3.55it/s]

{'loss': 5.0014, 'grad_norm': 4.321866035461426, 'learning_rate': 2.9199999999999998e-05, 'epoch': 0.83}


 42%|████▏     | 1570/3750 [07:37<10:43,  3.39it/s]

{'loss': 5.0195, 'grad_norm': 1.6863266229629517, 'learning_rate': 2.906666666666667e-05, 'epoch': 0.84}


 42%|████▏     | 1580/3750 [07:40<10:33,  3.42it/s]

{'loss': 5.0191, 'grad_norm': 1.7170865535736084, 'learning_rate': 2.8933333333333333e-05, 'epoch': 0.84}


 42%|████▏     | 1590/3750 [07:43<10:27,  3.44it/s]

{'loss': 5.0292, 'grad_norm': 2.7437572479248047, 'learning_rate': 2.88e-05, 'epoch': 0.85}


 43%|████▎     | 1600/3750 [07:45<10:20,  3.47it/s]

{'loss': 5.044, 'grad_norm': 1.534181833267212, 'learning_rate': 2.8666666666666668e-05, 'epoch': 0.85}


 43%|████▎     | 1610/3750 [07:48<10:03,  3.55it/s]

{'loss': 5.0074, 'grad_norm': 1.6131919622421265, 'learning_rate': 2.8533333333333333e-05, 'epoch': 0.86}


 43%|████▎     | 1620/3750 [07:51<09:54,  3.58it/s]

{'loss': 5.005, 'grad_norm': 1.5446022748947144, 'learning_rate': 2.84e-05, 'epoch': 0.86}


 43%|████▎     | 1630/3750 [07:54<11:13,  3.15it/s]

{'loss': 5.043, 'grad_norm': 1.5599403381347656, 'learning_rate': 2.8266666666666668e-05, 'epoch': 0.87}


 44%|████▎     | 1640/3750 [08:01<29:42,  1.18it/s]

{'loss': 5.0373, 'grad_norm': 2.8273677825927734, 'learning_rate': 2.8133333333333334e-05, 'epoch': 0.87}


 44%|████▍     | 1650/3750 [08:06<11:48,  2.96it/s]

{'loss': 5.0162, 'grad_norm': 2.7342097759246826, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.88}


 44%|████▍     | 1660/3750 [08:09<10:16,  3.39it/s]

{'loss': 5.0112, 'grad_norm': 2.8452296257019043, 'learning_rate': 2.786666666666667e-05, 'epoch': 0.89}


 45%|████▍     | 1670/3750 [08:12<10:10,  3.40it/s]

{'loss': 4.9995, 'grad_norm': 2.8334267139434814, 'learning_rate': 2.7733333333333334e-05, 'epoch': 0.89}


 45%|████▍     | 1680/3750 [08:15<10:44,  3.21it/s]

{'loss': 5.0106, 'grad_norm': 2.7116281986236572, 'learning_rate': 2.7600000000000003e-05, 'epoch': 0.9}


 45%|████▌     | 1690/3750 [08:18<10:58,  3.13it/s]

{'loss': 5.0379, 'grad_norm': 2.7910313606262207, 'learning_rate': 2.746666666666667e-05, 'epoch': 0.9}


 45%|████▌     | 1700/3750 [08:21<10:33,  3.23it/s]

{'loss': 5.0427, 'grad_norm': 1.532429575920105, 'learning_rate': 2.733333333333333e-05, 'epoch': 0.91}


 46%|████▌     | 1710/3750 [08:24<10:06,  3.36it/s]

{'loss': 5.0266, 'grad_norm': 1.4945663213729858, 'learning_rate': 2.7200000000000004e-05, 'epoch': 0.91}


 46%|████▌     | 1720/3750 [08:27<09:53,  3.42it/s]

{'loss': 5.0328, 'grad_norm': 1.5072131156921387, 'learning_rate': 2.706666666666667e-05, 'epoch': 0.92}


 46%|████▌     | 1730/3750 [08:30<09:34,  3.51it/s]

{'loss': 5.0184, 'grad_norm': 1.5087556838989258, 'learning_rate': 2.6933333333333332e-05, 'epoch': 0.92}


 46%|████▋     | 1740/3750 [08:33<09:25,  3.55it/s]

{'loss': 5.0356, 'grad_norm': 1.5227700471878052, 'learning_rate': 2.6800000000000004e-05, 'epoch': 0.93}


 47%|████▋     | 1750/3750 [08:36<09:16,  3.59it/s]

{'loss': 5.0416, 'grad_norm': 2.968895435333252, 'learning_rate': 2.6666666666666667e-05, 'epoch': 0.93}


 47%|████▋     | 1760/3750 [08:39<09:27,  3.51it/s]

{'loss': 5.0139, 'grad_norm': 1.5443806648254395, 'learning_rate': 2.6533333333333332e-05, 'epoch': 0.94}


 47%|████▋     | 1770/3750 [08:42<09:16,  3.56it/s]

{'loss': 5.0124, 'grad_norm': 1.4623035192489624, 'learning_rate': 2.64e-05, 'epoch': 0.94}


 47%|████▋     | 1780/3750 [08:45<09:50,  3.33it/s]

{'loss': 4.9914, 'grad_norm': 2.800147771835327, 'learning_rate': 2.6266666666666667e-05, 'epoch': 0.95}


 48%|████▊     | 1790/3750 [08:48<09:23,  3.48it/s]

{'loss': 5.0205, 'grad_norm': 3.218068838119507, 'learning_rate': 2.6133333333333333e-05, 'epoch': 0.95}


 48%|████▊     | 1800/3750 [08:51<09:23,  3.46it/s]

{'loss': 5.041, 'grad_norm': 3.6499836444854736, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.96}


 48%|████▊     | 1810/3750 [08:54<09:16,  3.49it/s]

{'loss': 5.0092, 'grad_norm': 1.5574818849563599, 'learning_rate': 2.5866666666666667e-05, 'epoch': 0.97}


 49%|████▊     | 1820/3750 [08:56<08:55,  3.61it/s]

{'loss': 5.0211, 'grad_norm': 2.828359842300415, 'learning_rate': 2.5733333333333337e-05, 'epoch': 0.97}


 49%|████▉     | 1830/3750 [08:59<09:26,  3.39it/s]

{'loss': 5.0354, 'grad_norm': 2.6359193325042725, 'learning_rate': 2.5600000000000002e-05, 'epoch': 0.98}


 49%|████▉     | 1840/3750 [09:02<08:58,  3.55it/s]

{'loss': 4.9997, 'grad_norm': 2.8757219314575195, 'learning_rate': 2.5466666666666668e-05, 'epoch': 0.98}


 49%|████▉     | 1850/3750 [09:05<08:53,  3.56it/s]

{'loss': 5.0088, 'grad_norm': 2.7583603858947754, 'learning_rate': 2.5333333333333337e-05, 'epoch': 0.99}


 50%|████▉     | 1860/3750 [09:08<08:59,  3.51it/s]

{'loss': 5.0297, 'grad_norm': 1.5706846714019775, 'learning_rate': 2.5200000000000003e-05, 'epoch': 0.99}


 50%|████▉     | 1870/3750 [09:11<08:59,  3.49it/s]

{'loss': 4.9978, 'grad_norm': 1.5416786670684814, 'learning_rate': 2.5066666666666665e-05, 'epoch': 1.0}


                                                   
 50%|█████     | 1875/3750 [09:21<09:54,  3.16it/s]

{'eval_loss': 5.014677047729492, 'eval_runtime': 8.1814, 'eval_samples_per_second': 366.687, 'eval_steps_per_second': 45.836, 'epoch': 1.0}


 50%|█████     | 1880/3750 [09:23<28:53,  1.08it/s]  

{'loss': 5.0163, 'grad_norm': 2.7050719261169434, 'learning_rate': 2.4933333333333334e-05, 'epoch': 1.0}


 50%|█████     | 1890/3750 [09:26<09:16,  3.34it/s]

{'loss': 5.0309, 'grad_norm': 1.5518616437911987, 'learning_rate': 2.48e-05, 'epoch': 1.01}


 51%|█████     | 1900/3750 [09:29<08:52,  3.47it/s]

{'loss': 5.0245, 'grad_norm': 1.5811983346939087, 'learning_rate': 2.466666666666667e-05, 'epoch': 1.01}


 51%|█████     | 1910/3750 [09:32<08:46,  3.50it/s]

{'loss': 5.0165, 'grad_norm': 2.8027079105377197, 'learning_rate': 2.4533333333333334e-05, 'epoch': 1.02}


 51%|█████     | 1920/3750 [09:34<08:40,  3.52it/s]

{'loss': 5.0175, 'grad_norm': 1.616010308265686, 'learning_rate': 2.44e-05, 'epoch': 1.02}


 51%|█████▏    | 1930/3750 [09:37<08:29,  3.57it/s]

{'loss': 5.019, 'grad_norm': 3.6236753463745117, 'learning_rate': 2.426666666666667e-05, 'epoch': 1.03}


 52%|█████▏    | 1940/3750 [09:40<08:39,  3.49it/s]

{'loss': 5.0139, 'grad_norm': 1.643107533454895, 'learning_rate': 2.4133333333333335e-05, 'epoch': 1.03}


 52%|█████▏    | 1950/3750 [09:43<08:19,  3.60it/s]

{'loss': 5.0177, 'grad_norm': 3.6250030994415283, 'learning_rate': 2.4e-05, 'epoch': 1.04}


 52%|█████▏    | 1960/3750 [09:46<08:31,  3.50it/s]

{'loss': 5.0238, 'grad_norm': 2.8704771995544434, 'learning_rate': 2.3866666666666666e-05, 'epoch': 1.05}


 53%|█████▎    | 1970/3750 [09:49<09:02,  3.28it/s]

{'loss': 5.0362, 'grad_norm': 2.882352352142334, 'learning_rate': 2.3733333333333335e-05, 'epoch': 1.05}


 53%|█████▎    | 1980/3750 [09:52<08:30,  3.47it/s]

{'loss': 5.0389, 'grad_norm': 1.601444959640503, 'learning_rate': 2.36e-05, 'epoch': 1.06}


 53%|█████▎    | 1990/3750 [09:55<08:16,  3.54it/s]

{'loss': 5.0047, 'grad_norm': 2.8128857612609863, 'learning_rate': 2.3466666666666667e-05, 'epoch': 1.06}


 53%|█████▎    | 2000/3750 [09:58<08:10,  3.56it/s]

{'loss': 5.0098, 'grad_norm': 1.6462125778198242, 'learning_rate': 2.3333333333333336e-05, 'epoch': 1.07}


 54%|█████▎    | 2010/3750 [10:06<09:55,  2.92it/s]

{'loss': 5.0251, 'grad_norm': 2.7701051235198975, 'learning_rate': 2.32e-05, 'epoch': 1.07}


 54%|█████▍    | 2020/3750 [10:09<07:56,  3.63it/s]

{'loss': 5.0166, 'grad_norm': 2.784715175628662, 'learning_rate': 2.3066666666666667e-05, 'epoch': 1.08}


 54%|█████▍    | 2030/3750 [10:12<07:52,  3.64it/s]

{'loss': 5.0114, 'grad_norm': 4.335941314697266, 'learning_rate': 2.2933333333333333e-05, 'epoch': 1.08}


 54%|█████▍    | 2040/3750 [10:15<07:49,  3.64it/s]

{'loss': 5.0351, 'grad_norm': 2.7849361896514893, 'learning_rate': 2.2800000000000002e-05, 'epoch': 1.09}


 55%|█████▍    | 2050/3750 [10:17<07:46,  3.65it/s]

{'loss': 5.02, 'grad_norm': 1.6537408828735352, 'learning_rate': 2.2666666666666668e-05, 'epoch': 1.09}


 55%|█████▍    | 2060/3750 [10:20<07:43,  3.64it/s]

{'loss': 5.009, 'grad_norm': 1.6093568801879883, 'learning_rate': 2.2533333333333333e-05, 'epoch': 1.1}


 55%|█████▌    | 2070/3750 [10:23<07:41,  3.64it/s]

{'loss': 5.0113, 'grad_norm': 1.6236519813537598, 'learning_rate': 2.2400000000000002e-05, 'epoch': 1.1}


 55%|█████▌    | 2080/3750 [10:26<07:41,  3.62it/s]

{'loss': 5.0277, 'grad_norm': 1.5431344509124756, 'learning_rate': 2.2266666666666668e-05, 'epoch': 1.11}


 56%|█████▌    | 2090/3750 [10:28<07:35,  3.65it/s]

{'loss': 5.0225, 'grad_norm': 1.4770076274871826, 'learning_rate': 2.2133333333333334e-05, 'epoch': 1.11}


 56%|█████▌    | 2100/3750 [10:31<07:32,  3.65it/s]

{'loss': 5.0183, 'grad_norm': 1.5542486906051636, 'learning_rate': 2.2000000000000003e-05, 'epoch': 1.12}


 56%|█████▋    | 2110/3750 [10:34<07:32,  3.63it/s]

{'loss': 5.0242, 'grad_norm': 3.8840460777282715, 'learning_rate': 2.186666666666667e-05, 'epoch': 1.13}


 57%|█████▋    | 2120/3750 [10:37<07:25,  3.66it/s]

{'loss': 5.0171, 'grad_norm': 2.774930238723755, 'learning_rate': 2.1733333333333334e-05, 'epoch': 1.13}


 57%|█████▋    | 2130/3750 [10:39<07:23,  3.65it/s]

{'loss': 5.0081, 'grad_norm': 1.7752742767333984, 'learning_rate': 2.16e-05, 'epoch': 1.14}


 57%|█████▋    | 2140/3750 [10:42<07:22,  3.64it/s]

{'loss': 5.0323, 'grad_norm': 2.872545003890991, 'learning_rate': 2.146666666666667e-05, 'epoch': 1.14}


 57%|█████▋    | 2150/3750 [10:45<07:31,  3.55it/s]

{'loss': 5.0054, 'grad_norm': 2.7893989086151123, 'learning_rate': 2.1333333333333335e-05, 'epoch': 1.15}


 58%|█████▊    | 2160/3750 [10:48<07:15,  3.65it/s]

{'loss': 5.0196, 'grad_norm': 1.636918306350708, 'learning_rate': 2.12e-05, 'epoch': 1.15}


 58%|█████▊    | 2170/3750 [10:51<07:11,  3.66it/s]

{'loss': 5.028, 'grad_norm': 1.55983304977417, 'learning_rate': 2.106666666666667e-05, 'epoch': 1.16}


 58%|█████▊    | 2180/3750 [10:53<07:12,  3.63it/s]

{'loss': 5.0185, 'grad_norm': 2.86751389503479, 'learning_rate': 2.0933333333333335e-05, 'epoch': 1.16}


 58%|█████▊    | 2190/3750 [10:56<07:08,  3.64it/s]

{'loss': 5.0033, 'grad_norm': 1.5145115852355957, 'learning_rate': 2.08e-05, 'epoch': 1.17}


 59%|█████▊    | 2200/3750 [10:59<07:07,  3.63it/s]

{'loss': 5.0168, 'grad_norm': 3.052233934402466, 'learning_rate': 2.0666666666666666e-05, 'epoch': 1.17}


 59%|█████▉    | 2210/3750 [11:02<07:01,  3.65it/s]

{'loss': 5.0273, 'grad_norm': 3.896301507949829, 'learning_rate': 2.0533333333333336e-05, 'epoch': 1.18}


 59%|█████▉    | 2220/3750 [11:04<06:59,  3.65it/s]

{'loss': 5.0019, 'grad_norm': 1.7871277332305908, 'learning_rate': 2.04e-05, 'epoch': 1.18}


 59%|█████▉    | 2230/3750 [11:07<06:55,  3.65it/s]

{'loss': 5.0211, 'grad_norm': 2.911146640777588, 'learning_rate': 2.0266666666666667e-05, 'epoch': 1.19}


 60%|█████▉    | 2240/3750 [11:10<06:55,  3.64it/s]

{'loss': 5.0255, 'grad_norm': 1.626009225845337, 'learning_rate': 2.0133333333333336e-05, 'epoch': 1.19}


 60%|██████    | 2250/3750 [11:13<06:50,  3.65it/s]

{'loss': 4.9926, 'grad_norm': 3.6207468509674072, 'learning_rate': 2e-05, 'epoch': 1.2}


 60%|██████    | 2260/3750 [11:15<06:49,  3.64it/s]

{'loss': 5.0093, 'grad_norm': 2.810546636581421, 'learning_rate': 1.9866666666666667e-05, 'epoch': 1.21}


 61%|██████    | 2270/3750 [11:18<06:45,  3.65it/s]

{'loss': 5.0189, 'grad_norm': 2.1305789947509766, 'learning_rate': 1.9733333333333333e-05, 'epoch': 1.21}


 61%|██████    | 2280/3750 [11:21<06:42,  3.65it/s]

{'loss': 5.0008, 'grad_norm': 1.5703930854797363, 'learning_rate': 1.9600000000000002e-05, 'epoch': 1.22}


 61%|██████    | 2290/3750 [11:24<06:40,  3.64it/s]

{'loss': 5.0374, 'grad_norm': 4.283422946929932, 'learning_rate': 1.9466666666666668e-05, 'epoch': 1.22}


 61%|██████▏   | 2300/3750 [11:26<06:38,  3.64it/s]

{'loss': 5.0155, 'grad_norm': 1.535410761833191, 'learning_rate': 1.9333333333333333e-05, 'epoch': 1.23}


 62%|██████▏   | 2310/3750 [11:29<06:36,  3.63it/s]

{'loss': 5.0282, 'grad_norm': 2.7922985553741455, 'learning_rate': 1.9200000000000003e-05, 'epoch': 1.23}


 62%|██████▏   | 2320/3750 [11:32<06:33,  3.64it/s]

{'loss': 5.0361, 'grad_norm': 3.9202346801757812, 'learning_rate': 1.9066666666666668e-05, 'epoch': 1.24}


 62%|██████▏   | 2330/3750 [11:35<06:36,  3.58it/s]

{'loss': 5.0348, 'grad_norm': 1.5553337335586548, 'learning_rate': 1.8933333333333334e-05, 'epoch': 1.24}


 62%|██████▏   | 2340/3750 [11:38<06:28,  3.63it/s]

{'loss': 5.0261, 'grad_norm': 3.690824031829834, 'learning_rate': 1.88e-05, 'epoch': 1.25}


 63%|██████▎   | 2350/3750 [11:40<06:33,  3.56it/s]

{'loss': 5.0427, 'grad_norm': 1.6337206363677979, 'learning_rate': 1.866666666666667e-05, 'epoch': 1.25}


 63%|██████▎   | 2360/3750 [11:43<06:21,  3.64it/s]

{'loss': 5.0068, 'grad_norm': 1.7313930988311768, 'learning_rate': 1.8533333333333334e-05, 'epoch': 1.26}


 63%|██████▎   | 2370/3750 [11:46<06:19,  3.64it/s]

{'loss': 5.0138, 'grad_norm': 1.6062843799591064, 'learning_rate': 1.84e-05, 'epoch': 1.26}


 63%|██████▎   | 2380/3750 [11:49<06:19,  3.61it/s]

{'loss': 5.0331, 'grad_norm': 1.6512880325317383, 'learning_rate': 1.826666666666667e-05, 'epoch': 1.27}


 64%|██████▎   | 2390/3750 [11:51<06:15,  3.63it/s]

{'loss': 5.0117, 'grad_norm': 1.650955080986023, 'learning_rate': 1.8133333333333335e-05, 'epoch': 1.27}


 64%|██████▍   | 2400/3750 [11:54<06:16,  3.59it/s]

{'loss': 5.0283, 'grad_norm': 3.65177583694458, 'learning_rate': 1.8e-05, 'epoch': 1.28}


 64%|██████▍   | 2410/3750 [11:57<06:07,  3.65it/s]

{'loss': 5.0179, 'grad_norm': 2.728630304336548, 'learning_rate': 1.7866666666666666e-05, 'epoch': 1.29}


 65%|██████▍   | 2420/3750 [12:00<06:04,  3.65it/s]

{'loss': 5.011, 'grad_norm': 2.699319362640381, 'learning_rate': 1.7733333333333335e-05, 'epoch': 1.29}


 65%|██████▍   | 2430/3750 [12:03<06:03,  3.63it/s]

{'loss': 5.0183, 'grad_norm': 2.0057737827301025, 'learning_rate': 1.76e-05, 'epoch': 1.3}


 65%|██████▌   | 2440/3750 [12:05<05:59,  3.64it/s]

{'loss': 5.0108, 'grad_norm': 3.426788091659546, 'learning_rate': 1.7466666666666667e-05, 'epoch': 1.3}


 65%|██████▌   | 2450/3750 [12:08<05:57,  3.64it/s]

{'loss': 5.0064, 'grad_norm': 1.7832400798797607, 'learning_rate': 1.7333333333333336e-05, 'epoch': 1.31}


 66%|██████▌   | 2460/3750 [12:11<06:10,  3.48it/s]

{'loss': 5.0075, 'grad_norm': 2.923548698425293, 'learning_rate': 1.7199999999999998e-05, 'epoch': 1.31}


 66%|██████▌   | 2470/3750 [12:14<06:03,  3.52it/s]

{'loss': 5.0375, 'grad_norm': 2.762422561645508, 'learning_rate': 1.7066666666666667e-05, 'epoch': 1.32}


 66%|██████▌   | 2480/3750 [12:17<05:51,  3.62it/s]

{'loss': 5.025, 'grad_norm': 3.1968960762023926, 'learning_rate': 1.6933333333333333e-05, 'epoch': 1.32}


 66%|██████▋   | 2490/3750 [12:20<05:50,  3.60it/s]

{'loss': 5.0137, 'grad_norm': 1.553161859512329, 'learning_rate': 1.6800000000000002e-05, 'epoch': 1.33}


 67%|██████▋   | 2500/3750 [12:22<05:59,  3.47it/s]

{'loss': 5.0148, 'grad_norm': 1.5934566259384155, 'learning_rate': 1.6666666666666667e-05, 'epoch': 1.33}


 67%|██████▋   | 2510/3750 [12:32<07:22,  2.80it/s]

{'loss': 4.9947, 'grad_norm': 3.5251617431640625, 'learning_rate': 1.6533333333333333e-05, 'epoch': 1.34}


 67%|██████▋   | 2520/3750 [12:35<05:38,  3.63it/s]

{'loss': 5.0001, 'grad_norm': 1.5931254625320435, 'learning_rate': 1.6400000000000002e-05, 'epoch': 1.34}


 67%|██████▋   | 2530/3750 [12:37<05:32,  3.66it/s]

{'loss': 5.0241, 'grad_norm': 1.622199535369873, 'learning_rate': 1.6266666666666665e-05, 'epoch': 1.35}


 68%|██████▊   | 2540/3750 [12:40<05:31,  3.66it/s]

{'loss': 5.0061, 'grad_norm': 2.8794193267822266, 'learning_rate': 1.6133333333333334e-05, 'epoch': 1.35}


 68%|██████▊   | 2550/3750 [12:43<05:27,  3.66it/s]

{'loss': 5.0215, 'grad_norm': 1.6137274503707886, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.36}


 68%|██████▊   | 2560/3750 [12:46<05:24,  3.67it/s]

{'loss': 5.0319, 'grad_norm': 1.5526845455169678, 'learning_rate': 1.586666666666667e-05, 'epoch': 1.37}


 69%|██████▊   | 2570/3750 [12:48<05:25,  3.63it/s]

{'loss': 5.0177, 'grad_norm': 2.7664167881011963, 'learning_rate': 1.5733333333333334e-05, 'epoch': 1.37}


 69%|██████▉   | 2580/3750 [12:51<05:19,  3.67it/s]

{'loss': 5.014, 'grad_norm': 2.6837570667266846, 'learning_rate': 1.56e-05, 'epoch': 1.38}


 69%|██████▉   | 2590/3750 [12:54<05:16,  3.66it/s]

{'loss': 5.0265, 'grad_norm': 1.594706416130066, 'learning_rate': 1.546666666666667e-05, 'epoch': 1.38}


 69%|██████▉   | 2600/3750 [12:57<05:15,  3.64it/s]

{'loss': 5.0157, 'grad_norm': 1.772100567817688, 'learning_rate': 1.5333333333333334e-05, 'epoch': 1.39}


 70%|██████▉   | 2610/3750 [12:59<05:10,  3.67it/s]

{'loss': 4.9997, 'grad_norm': 3.2264328002929688, 'learning_rate': 1.52e-05, 'epoch': 1.39}


 70%|██████▉   | 2620/3750 [13:02<05:08,  3.66it/s]

{'loss': 5.0142, 'grad_norm': 1.6414415836334229, 'learning_rate': 1.5066666666666668e-05, 'epoch': 1.4}


 70%|███████   | 2630/3750 [13:05<05:07,  3.65it/s]

{'loss': 5.0397, 'grad_norm': 1.4844627380371094, 'learning_rate': 1.4933333333333335e-05, 'epoch': 1.4}


 70%|███████   | 2640/3750 [13:08<05:02,  3.66it/s]

{'loss': 5.0185, 'grad_norm': 1.8188114166259766, 'learning_rate': 1.48e-05, 'epoch': 1.41}


 71%|███████   | 2650/3750 [13:10<05:00,  3.66it/s]

{'loss': 5.0345, 'grad_norm': 4.312217712402344, 'learning_rate': 1.4666666666666668e-05, 'epoch': 1.41}


 71%|███████   | 2660/3750 [13:13<04:58,  3.65it/s]

{'loss': 5.0279, 'grad_norm': 2.8527824878692627, 'learning_rate': 1.4533333333333335e-05, 'epoch': 1.42}


 71%|███████   | 2670/3750 [13:16<04:55,  3.65it/s]

{'loss': 5.0233, 'grad_norm': 1.580461025238037, 'learning_rate': 1.44e-05, 'epoch': 1.42}


 71%|███████▏  | 2680/3750 [13:19<04:52,  3.66it/s]

{'loss': 5.0234, 'grad_norm': 1.529295802116394, 'learning_rate': 1.4266666666666667e-05, 'epoch': 1.43}


 72%|███████▏  | 2690/3750 [13:21<04:49,  3.66it/s]

{'loss': 5.0176, 'grad_norm': 2.697564125061035, 'learning_rate': 1.4133333333333334e-05, 'epoch': 1.43}


 72%|███████▏  | 2700/3750 [13:24<04:47,  3.66it/s]

{'loss': 5.0322, 'grad_norm': 2.8068814277648926, 'learning_rate': 1.4000000000000001e-05, 'epoch': 1.44}


 72%|███████▏  | 2710/3750 [13:27<04:43,  3.66it/s]

{'loss': 5.0327, 'grad_norm': 1.6408483982086182, 'learning_rate': 1.3866666666666667e-05, 'epoch': 1.45}


 73%|███████▎  | 2720/3750 [13:30<04:42,  3.64it/s]

{'loss': 5.0163, 'grad_norm': 2.869729518890381, 'learning_rate': 1.3733333333333335e-05, 'epoch': 1.45}


 73%|███████▎  | 2730/3750 [13:32<04:38,  3.66it/s]

{'loss': 5.0188, 'grad_norm': 2.7319836616516113, 'learning_rate': 1.3600000000000002e-05, 'epoch': 1.46}


 73%|███████▎  | 2740/3750 [13:35<04:35,  3.67it/s]

{'loss': 5.0038, 'grad_norm': 3.650707721710205, 'learning_rate': 1.3466666666666666e-05, 'epoch': 1.46}


 73%|███████▎  | 2750/3750 [13:38<04:33,  3.65it/s]

{'loss': 5.0126, 'grad_norm': 3.342423915863037, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.47}


 74%|███████▎  | 2760/3750 [13:40<04:29,  3.67it/s]

{'loss': 5.0108, 'grad_norm': 2.8106629848480225, 'learning_rate': 1.32e-05, 'epoch': 1.47}


 74%|███████▍  | 2770/3750 [13:43<04:26,  3.67it/s]

{'loss': 5.0345, 'grad_norm': 2.7609376907348633, 'learning_rate': 1.3066666666666666e-05, 'epoch': 1.48}


 74%|███████▍  | 2780/3750 [13:46<04:25,  3.66it/s]

{'loss': 5.0195, 'grad_norm': 2.3294999599456787, 'learning_rate': 1.2933333333333334e-05, 'epoch': 1.48}


 74%|███████▍  | 2790/3750 [13:49<04:22,  3.66it/s]

{'loss': 5.0304, 'grad_norm': 4.123051643371582, 'learning_rate': 1.2800000000000001e-05, 'epoch': 1.49}


 75%|███████▍  | 2800/3750 [13:51<04:18,  3.67it/s]

{'loss': 5.0269, 'grad_norm': 2.83894419670105, 'learning_rate': 1.2666666666666668e-05, 'epoch': 1.49}


 75%|███████▍  | 2810/3750 [13:54<04:17,  3.65it/s]

{'loss': 5.0324, 'grad_norm': 4.468384742736816, 'learning_rate': 1.2533333333333332e-05, 'epoch': 1.5}


 75%|███████▌  | 2820/3750 [13:57<04:14,  3.66it/s]

{'loss': 5.0157, 'grad_norm': 1.5796661376953125, 'learning_rate': 1.24e-05, 'epoch': 1.5}


 75%|███████▌  | 2830/3750 [14:00<04:10,  3.67it/s]

{'loss': 5.0069, 'grad_norm': 1.8194960355758667, 'learning_rate': 1.2266666666666667e-05, 'epoch': 1.51}


 76%|███████▌  | 2840/3750 [14:02<04:08,  3.66it/s]

{'loss': 5.0229, 'grad_norm': 2.7857775688171387, 'learning_rate': 1.2133333333333335e-05, 'epoch': 1.51}


 76%|███████▌  | 2850/3750 [14:05<04:06,  3.65it/s]

{'loss': 5.0352, 'grad_norm': 2.753293514251709, 'learning_rate': 1.2e-05, 'epoch': 1.52}


 76%|███████▋  | 2860/3750 [14:08<04:02,  3.67it/s]

{'loss': 5.0158, 'grad_norm': 1.556828260421753, 'learning_rate': 1.1866666666666668e-05, 'epoch': 1.53}


 77%|███████▋  | 2870/3750 [14:11<04:02,  3.63it/s]

{'loss': 5.0264, 'grad_norm': 2.812811851501465, 'learning_rate': 1.1733333333333333e-05, 'epoch': 1.53}


 77%|███████▋  | 2880/3750 [14:13<03:57,  3.67it/s]

{'loss': 5.0331, 'grad_norm': 2.8100833892822266, 'learning_rate': 1.16e-05, 'epoch': 1.54}


 77%|███████▋  | 2890/3750 [14:16<03:54,  3.67it/s]

{'loss': 5.0334, 'grad_norm': 3.6167333126068115, 'learning_rate': 1.1466666666666666e-05, 'epoch': 1.54}


 77%|███████▋  | 2900/3750 [14:19<03:53,  3.64it/s]

{'loss': 5.011, 'grad_norm': 1.6547760963439941, 'learning_rate': 1.1333333333333334e-05, 'epoch': 1.55}


 78%|███████▊  | 2910/3750 [14:21<03:49,  3.66it/s]

{'loss': 5.0162, 'grad_norm': 2.7664430141448975, 'learning_rate': 1.1200000000000001e-05, 'epoch': 1.55}


 78%|███████▊  | 2920/3750 [14:24<03:46,  3.67it/s]

{'loss': 5.0211, 'grad_norm': 3.627753496170044, 'learning_rate': 1.1066666666666667e-05, 'epoch': 1.56}


 78%|███████▊  | 2930/3750 [14:27<03:45,  3.64it/s]

{'loss': 5.0057, 'grad_norm': 1.62455415725708, 'learning_rate': 1.0933333333333334e-05, 'epoch': 1.56}


 78%|███████▊  | 2940/3750 [14:30<03:41,  3.66it/s]

{'loss': 5.0122, 'grad_norm': 1.5753886699676514, 'learning_rate': 1.08e-05, 'epoch': 1.57}


 79%|███████▊  | 2950/3750 [14:32<03:37,  3.67it/s]

{'loss': 5.0157, 'grad_norm': 3.5693485736846924, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.57}


 79%|███████▉  | 2960/3750 [14:35<03:36,  3.64it/s]

{'loss': 5.0135, 'grad_norm': 2.803635597229004, 'learning_rate': 1.0533333333333335e-05, 'epoch': 1.58}


 79%|███████▉  | 2970/3750 [14:38<03:32,  3.67it/s]

{'loss': 5.017, 'grad_norm': 3.702570676803589, 'learning_rate': 1.04e-05, 'epoch': 1.58}


 79%|███████▉  | 2980/3750 [14:41<03:30,  3.66it/s]

{'loss': 5.0149, 'grad_norm': 4.071106433868408, 'learning_rate': 1.0266666666666668e-05, 'epoch': 1.59}


 80%|███████▉  | 2990/3750 [14:43<03:28,  3.64it/s]

{'loss': 5.0105, 'grad_norm': 3.579202175140381, 'learning_rate': 1.0133333333333333e-05, 'epoch': 1.59}


 80%|████████  | 3000/3750 [14:46<03:25,  3.66it/s]

{'loss': 5.024, 'grad_norm': 2.7905328273773193, 'learning_rate': 1e-05, 'epoch': 1.6}


 80%|████████  | 3010/3750 [14:56<04:23,  2.80it/s]

{'loss': 5.02, 'grad_norm': 3.0755069255828857, 'learning_rate': 9.866666666666667e-06, 'epoch': 1.61}


 81%|████████  | 3020/3750 [14:58<03:22,  3.60it/s]

{'loss': 5.0343, 'grad_norm': 1.5295876264572144, 'learning_rate': 9.733333333333334e-06, 'epoch': 1.61}


 81%|████████  | 3030/3750 [15:01<03:17,  3.65it/s]

{'loss': 5.0166, 'grad_norm': 2.7622780799865723, 'learning_rate': 9.600000000000001e-06, 'epoch': 1.62}


 81%|████████  | 3040/3750 [15:04<03:14,  3.65it/s]

{'loss': 5.0204, 'grad_norm': 2.7418594360351562, 'learning_rate': 9.466666666666667e-06, 'epoch': 1.62}


 81%|████████▏ | 3050/3750 [15:07<03:13,  3.62it/s]

{'loss': 5.0213, 'grad_norm': 1.6133452653884888, 'learning_rate': 9.333333333333334e-06, 'epoch': 1.63}


 82%|████████▏ | 3060/3750 [15:09<03:08,  3.66it/s]

{'loss': 5.0189, 'grad_norm': 2.7249653339385986, 'learning_rate': 9.2e-06, 'epoch': 1.63}


 82%|████████▏ | 3070/3750 [15:12<03:06,  3.65it/s]

{'loss': 5.0232, 'grad_norm': 2.722229242324829, 'learning_rate': 9.066666666666667e-06, 'epoch': 1.64}


 82%|████████▏ | 3080/3750 [15:15<03:03,  3.64it/s]

{'loss': 5.0249, 'grad_norm': 2.746267557144165, 'learning_rate': 8.933333333333333e-06, 'epoch': 1.64}


 82%|████████▏ | 3090/3750 [15:18<03:00,  3.65it/s]

{'loss': 5.0121, 'grad_norm': 2.6491503715515137, 'learning_rate': 8.8e-06, 'epoch': 1.65}


 83%|████████▎ | 3100/3750 [15:20<02:58,  3.64it/s]

{'loss': 5.0177, 'grad_norm': 1.5857007503509521, 'learning_rate': 8.666666666666668e-06, 'epoch': 1.65}


 83%|████████▎ | 3110/3750 [15:23<02:55,  3.65it/s]

{'loss': 5.0055, 'grad_norm': 1.5128023624420166, 'learning_rate': 8.533333333333334e-06, 'epoch': 1.66}


 83%|████████▎ | 3120/3750 [15:26<02:58,  3.53it/s]

{'loss': 5.0188, 'grad_norm': 1.4895061254501343, 'learning_rate': 8.400000000000001e-06, 'epoch': 1.66}


 83%|████████▎ | 3130/3750 [15:29<02:50,  3.63it/s]

{'loss': 5.0142, 'grad_norm': 1.4211151599884033, 'learning_rate': 8.266666666666667e-06, 'epoch': 1.67}


 84%|████████▎ | 3140/3750 [15:31<02:47,  3.65it/s]

{'loss': 5.0168, 'grad_norm': 2.76082181930542, 'learning_rate': 8.133333333333332e-06, 'epoch': 1.67}


 84%|████████▍ | 3150/3750 [15:34<02:44,  3.66it/s]

{'loss': 5.0013, 'grad_norm': 2.7276501655578613, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.68}


 84%|████████▍ | 3160/3750 [15:37<02:41,  3.64it/s]

{'loss': 5.0192, 'grad_norm': 1.6432873010635376, 'learning_rate': 7.866666666666667e-06, 'epoch': 1.69}


 85%|████████▍ | 3170/3750 [15:40<02:40,  3.62it/s]

{'loss': 5.0178, 'grad_norm': 3.6139140129089355, 'learning_rate': 7.733333333333334e-06, 'epoch': 1.69}


 85%|████████▍ | 3180/3750 [15:42<02:36,  3.65it/s]

{'loss': 5.0105, 'grad_norm': 3.393378734588623, 'learning_rate': 7.6e-06, 'epoch': 1.7}


 85%|████████▌ | 3190/3750 [15:45<02:34,  3.63it/s]

{'loss': 5.0087, 'grad_norm': 1.4618873596191406, 'learning_rate': 7.4666666666666675e-06, 'epoch': 1.7}


 85%|████████▌ | 3200/3750 [15:48<02:31,  3.63it/s]

{'loss': 5.0201, 'grad_norm': 1.529545783996582, 'learning_rate': 7.333333333333334e-06, 'epoch': 1.71}


 86%|████████▌ | 3210/3750 [15:51<02:28,  3.63it/s]

{'loss': 5.0189, 'grad_norm': 2.743926763534546, 'learning_rate': 7.2e-06, 'epoch': 1.71}


 86%|████████▌ | 3220/3750 [15:54<02:25,  3.64it/s]

{'loss': 5.0226, 'grad_norm': 3.864109754562378, 'learning_rate': 7.066666666666667e-06, 'epoch': 1.72}


 86%|████████▌ | 3230/3750 [15:56<02:23,  3.64it/s]

{'loss': 5.0135, 'grad_norm': 3.610963821411133, 'learning_rate': 6.933333333333334e-06, 'epoch': 1.72}


 86%|████████▋ | 3240/3750 [15:59<02:19,  3.65it/s]

{'loss': 5.0117, 'grad_norm': 2.727806806564331, 'learning_rate': 6.800000000000001e-06, 'epoch': 1.73}


 87%|████████▋ | 3250/3750 [16:02<02:16,  3.65it/s]

{'loss': 5.025, 'grad_norm': 1.505232572555542, 'learning_rate': 6.666666666666667e-06, 'epoch': 1.73}


 87%|████████▋ | 3260/3750 [16:04<02:13,  3.66it/s]

{'loss': 5.0153, 'grad_norm': 4.289198875427246, 'learning_rate': 6.533333333333333e-06, 'epoch': 1.74}


 87%|████████▋ | 3270/3750 [16:07<02:10,  3.67it/s]

{'loss': 5.0003, 'grad_norm': 1.933202862739563, 'learning_rate': 6.4000000000000006e-06, 'epoch': 1.74}


 87%|████████▋ | 3280/3750 [16:10<02:08,  3.65it/s]

{'loss': 5.0195, 'grad_norm': 1.4971963167190552, 'learning_rate': 6.266666666666666e-06, 'epoch': 1.75}


 88%|████████▊ | 3290/3750 [16:13<02:08,  3.58it/s]

{'loss': 5.0099, 'grad_norm': 1.550596833229065, 'learning_rate': 6.133333333333334e-06, 'epoch': 1.75}


 88%|████████▊ | 3300/3750 [16:15<02:03,  3.66it/s]

{'loss': 5.0149, 'grad_norm': 1.5357730388641357, 'learning_rate': 6e-06, 'epoch': 1.76}


 88%|████████▊ | 3310/3750 [16:18<02:00,  3.66it/s]

{'loss': 5.0209, 'grad_norm': 3.0172648429870605, 'learning_rate': 5.866666666666667e-06, 'epoch': 1.77}


 89%|████████▊ | 3320/3750 [16:21<01:57,  3.65it/s]

{'loss': 5.0265, 'grad_norm': 2.8280842304229736, 'learning_rate': 5.733333333333333e-06, 'epoch': 1.77}


 89%|████████▉ | 3330/3750 [16:24<01:54,  3.67it/s]

{'loss': 5.0248, 'grad_norm': 2.7648000717163086, 'learning_rate': 5.600000000000001e-06, 'epoch': 1.78}


 89%|████████▉ | 3340/3750 [16:26<01:52,  3.64it/s]

{'loss': 5.0213, 'grad_norm': 1.5429611206054688, 'learning_rate': 5.466666666666667e-06, 'epoch': 1.78}


 89%|████████▉ | 3350/3750 [16:29<01:53,  3.52it/s]

{'loss': 5.0104, 'grad_norm': 1.7775843143463135, 'learning_rate': 5.333333333333334e-06, 'epoch': 1.79}


 90%|████████▉ | 3360/3750 [16:32<01:46,  3.66it/s]

{'loss': 5.0119, 'grad_norm': 2.7464494705200195, 'learning_rate': 5.2e-06, 'epoch': 1.79}


 90%|████████▉ | 3370/3750 [16:35<01:45,  3.59it/s]

{'loss': 5.0185, 'grad_norm': 2.799138069152832, 'learning_rate': 5.066666666666667e-06, 'epoch': 1.8}


 90%|█████████ | 3380/3750 [16:37<01:41,  3.65it/s]

{'loss': 5.0195, 'grad_norm': 2.076788902282715, 'learning_rate': 4.933333333333333e-06, 'epoch': 1.8}


 90%|█████████ | 3390/3750 [16:40<01:38,  3.67it/s]

{'loss': 5.0172, 'grad_norm': 3.448338747024536, 'learning_rate': 4.800000000000001e-06, 'epoch': 1.81}


 91%|█████████ | 3400/3750 [16:43<01:35,  3.67it/s]

{'loss': 5.0135, 'grad_norm': 1.595160961151123, 'learning_rate': 4.666666666666667e-06, 'epoch': 1.81}


 91%|█████████ | 3410/3750 [16:46<01:33,  3.65it/s]

{'loss': 5.0234, 'grad_norm': 3.796657085418701, 'learning_rate': 4.533333333333334e-06, 'epoch': 1.82}


 91%|█████████ | 3420/3750 [16:48<01:30,  3.65it/s]

{'loss': 5.0158, 'grad_norm': 1.634120225906372, 'learning_rate': 4.4e-06, 'epoch': 1.82}


 91%|█████████▏| 3430/3750 [16:51<01:27,  3.67it/s]

{'loss': 5.02, 'grad_norm': 1.5951825380325317, 'learning_rate': 4.266666666666667e-06, 'epoch': 1.83}


 92%|█████████▏| 3440/3750 [16:54<01:25,  3.64it/s]

{'loss': 4.9965, 'grad_norm': 2.0029680728912354, 'learning_rate': 4.133333333333333e-06, 'epoch': 1.83}


 92%|█████████▏| 3450/3750 [16:57<01:21,  3.66it/s]

{'loss': 5.0122, 'grad_norm': 2.7034502029418945, 'learning_rate': 4.000000000000001e-06, 'epoch': 1.84}


 92%|█████████▏| 3460/3750 [16:59<01:19,  3.63it/s]

{'loss': 5.0141, 'grad_norm': 2.8401877880096436, 'learning_rate': 3.866666666666667e-06, 'epoch': 1.85}


 93%|█████████▎| 3470/3750 [17:02<01:16,  3.65it/s]

{'loss': 5.0244, 'grad_norm': 2.787203788757324, 'learning_rate': 3.7333333333333337e-06, 'epoch': 1.85}


 93%|█████████▎| 3480/3750 [17:05<01:13,  3.65it/s]

{'loss': 5.011, 'grad_norm': 1.5763678550720215, 'learning_rate': 3.6e-06, 'epoch': 1.86}


 93%|█████████▎| 3490/3750 [17:08<01:11,  3.65it/s]

{'loss': 5.0198, 'grad_norm': 2.8147525787353516, 'learning_rate': 3.466666666666667e-06, 'epoch': 1.86}


 93%|█████████▎| 3500/3750 [17:10<01:08,  3.66it/s]

{'loss': 5.0193, 'grad_norm': 1.5708460807800293, 'learning_rate': 3.3333333333333333e-06, 'epoch': 1.87}


 94%|█████████▎| 3510/3750 [17:19<01:22,  2.93it/s]

{'loss': 5.0185, 'grad_norm': 1.9528287649154663, 'learning_rate': 3.2000000000000003e-06, 'epoch': 1.87}


 94%|█████████▍| 3520/3750 [17:21<01:03,  3.63it/s]

{'loss': 5.0284, 'grad_norm': 3.846543788909912, 'learning_rate': 3.066666666666667e-06, 'epoch': 1.88}


 94%|█████████▍| 3530/3750 [17:24<01:00,  3.66it/s]

{'loss': 5.0129, 'grad_norm': 3.6976613998413086, 'learning_rate': 2.9333333333333333e-06, 'epoch': 1.88}


 94%|█████████▍| 3540/3750 [17:27<00:57,  3.67it/s]

{'loss': 5.0191, 'grad_norm': 2.905935525894165, 'learning_rate': 2.8000000000000003e-06, 'epoch': 1.89}


 95%|█████████▍| 3550/3750 [17:30<00:54,  3.66it/s]

{'loss': 5.0198, 'grad_norm': 1.524248719215393, 'learning_rate': 2.666666666666667e-06, 'epoch': 1.89}


 95%|█████████▍| 3560/3750 [17:32<00:51,  3.66it/s]

{'loss': 5.0111, 'grad_norm': 1.6580009460449219, 'learning_rate': 2.5333333333333334e-06, 'epoch': 1.9}


 95%|█████████▌| 3570/3750 [17:35<00:49,  3.65it/s]

{'loss': 5.0233, 'grad_norm': 2.6744492053985596, 'learning_rate': 2.4000000000000003e-06, 'epoch': 1.9}


 95%|█████████▌| 3580/3750 [17:38<00:46,  3.66it/s]

{'loss': 5.0188, 'grad_norm': 1.5005227327346802, 'learning_rate': 2.266666666666667e-06, 'epoch': 1.91}


 96%|█████████▌| 3590/3750 [17:41<00:43,  3.67it/s]

{'loss': 5.0064, 'grad_norm': 2.744971513748169, 'learning_rate': 2.1333333333333334e-06, 'epoch': 1.91}


 96%|█████████▌| 3600/3750 [17:43<00:40,  3.66it/s]

{'loss': 5.0127, 'grad_norm': 1.499510407447815, 'learning_rate': 2.0000000000000003e-06, 'epoch': 1.92}


 96%|█████████▋| 3610/3750 [17:46<00:38,  3.65it/s]

{'loss': 5.02, 'grad_norm': 3.136204719543457, 'learning_rate': 1.8666666666666669e-06, 'epoch': 1.93}


 97%|█████████▋| 3620/3750 [17:49<00:35,  3.66it/s]

{'loss': 5.0159, 'grad_norm': 2.8858301639556885, 'learning_rate': 1.7333333333333334e-06, 'epoch': 1.93}


 97%|█████████▋| 3630/3750 [17:52<00:32,  3.67it/s]

{'loss': 5.0028, 'grad_norm': 4.2108659744262695, 'learning_rate': 1.6000000000000001e-06, 'epoch': 1.94}


 97%|█████████▋| 3640/3750 [17:54<00:30,  3.65it/s]

{'loss': 5.0222, 'grad_norm': 1.5356874465942383, 'learning_rate': 1.4666666666666667e-06, 'epoch': 1.94}


 97%|█████████▋| 3650/3750 [17:57<00:27,  3.66it/s]

{'loss': 5.0132, 'grad_norm': 2.733123540878296, 'learning_rate': 1.3333333333333334e-06, 'epoch': 1.95}


 98%|█████████▊| 3660/3750 [18:00<00:24,  3.68it/s]

{'loss': 5.0198, 'grad_norm': 1.5113340616226196, 'learning_rate': 1.2000000000000002e-06, 'epoch': 1.95}


 98%|█████████▊| 3670/3750 [18:03<00:21,  3.66it/s]

{'loss': 5.0142, 'grad_norm': 1.5985277891159058, 'learning_rate': 1.0666666666666667e-06, 'epoch': 1.96}


 98%|█████████▊| 3680/3750 [18:05<00:19,  3.66it/s]

{'loss': 5.0165, 'grad_norm': 1.5213712453842163, 'learning_rate': 9.333333333333334e-07, 'epoch': 1.96}


 98%|█████████▊| 3690/3750 [18:08<00:16,  3.61it/s]

{'loss': 5.0196, 'grad_norm': 4.379112720489502, 'learning_rate': 8.000000000000001e-07, 'epoch': 1.97}


 99%|█████████▊| 3700/3750 [18:11<00:13,  3.66it/s]

{'loss': 5.0159, 'grad_norm': 1.5179146528244019, 'learning_rate': 6.666666666666667e-07, 'epoch': 1.97}


 99%|█████████▉| 3710/3750 [18:14<00:10,  3.66it/s]

{'loss': 5.0387, 'grad_norm': 1.4304097890853882, 'learning_rate': 5.333333333333333e-07, 'epoch': 1.98}


 99%|█████████▉| 3720/3750 [18:16<00:08,  3.52it/s]

{'loss': 5.0095, 'grad_norm': 2.7124741077423096, 'learning_rate': 4.0000000000000003e-07, 'epoch': 1.98}


 99%|█████████▉| 3730/3750 [18:19<00:05,  3.61it/s]

{'loss': 5.0225, 'grad_norm': 2.8318891525268555, 'learning_rate': 2.6666666666666667e-07, 'epoch': 1.99}


100%|█████████▉| 3740/3750 [18:22<00:02,  3.60it/s]

{'loss': 5.0045, 'grad_norm': 1.5591527223587036, 'learning_rate': 1.3333333333333334e-07, 'epoch': 1.99}


100%|██████████| 3750/3750 [18:25<00:00,  3.63it/s]

{'loss': 5.0061, 'grad_norm': 2.7860090732574463, 'learning_rate': 0.0, 'epoch': 2.0}


                                                   
100%|██████████| 3750/3750 [18:40<00:00,  3.35it/s]


{'eval_loss': 5.0111284255981445, 'eval_runtime': 7.9533, 'eval_samples_per_second': 377.203, 'eval_steps_per_second': 47.15, 'epoch': 2.0}
{'train_runtime': 1120.6935, 'train_samples_per_second': 26.769, 'train_steps_per_second': 3.346, 'train_loss': 5.022509630330403, 'epoch': 2.0}


('tokenizer_gstar_finetuned/tokenizer_config.json',
 'tokenizer_gstar_finetuned/special_tokens_map.json',
 'tokenizer_gstar_finetuned/vocab.json',
 'tokenizer_gstar_finetuned/merges.txt',
 'tokenizer_gstar_finetuned/added_tokens.json',
 'tokenizer_gstar_finetuned/tokenizer.json')

In [6]:
# Cell 5: Generate Synthetic Data

# Load the GPT-2 medium tokenizer and model for text generation
model_name_opt = "gpt2-medium"
tokenizer_opt = AutoTokenizer.from_pretrained(model_name_opt)
model_opt = AutoModelForCausalLM.from_pretrained(model_name_opt)

# Check and set the available device (MPS, CUDA, or CPU)
if torch.backends.mps.is_available():
    device = torch.device("mps")
    model_opt = model_opt.to(device)
    print("Using MPS device for acceleration.")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    model_opt = model_opt.to(device)
    print("Using CUDA device for acceleration.")
else:
    device = torch.device("cpu")
    model_opt = model_opt.to(device)
    print("Using CPU.")

def build_prompt(intent_name, real_examples, max_examples=3):
    """
    Create a prompt for the language model with example queries.
    """
    random.shuffle(real_examples)
    real_examples = real_examples[:max_examples]
    
    prompt_lines = [
        f"Intent: {intent_name}",
        "",
        "Here are some example user queries for this intent:"
    ]
    for i, ex in enumerate(real_examples, start=1):
        prompt_lines.append(f"{i}) \"{ex}\"")
    prompt_lines.append("")
    prompt_lines.append(f"Now, provide a new user query that also belongs to the '{intent_name}' intent:\n")
    
    return "\n".join(prompt_lines)

def generate_synthetic_utterances(
    prompt_text,
    tokenizer,
    model,
    num_samples=2,
    max_new_tokens=20,
    top_p=0.9,
    temperature=0.9,
    repetition_penalty=1.1
):
    """
    Generate synthetic utterances based on the prompt using the language model.
    """
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    generated_texts = []
    for _ in range(num_samples):
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        gen_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        completion = gen_text[len(prompt_text):].strip()
        generated_texts.append(completion)
    return generated_texts

def generate_synthetic_dataset(
    real_dataset: Dataset,
    tokenizer_llm,
    model_llm,
    data_multiplier=1,
    max_real_examples=3,
    num_gen_per_prompt=1,
    top_p=0.9,
    temperature=0.9,
    repetition_penalty=1.1,
    max_new_tokens=20
):
    """
    Generate a synthetic dataset by augmenting real data with generated queries.
    """
    real_list = real_dataset.to_list()
    # Group queries by their intent
    intent_to_queries = defaultdict(list)
    for row in real_list:
        intent_to_queries[row["intent"]].append(row["query"])

    synthetic_rows = []
    for intent_name, queries in intent_to_queries.items():
        real_count = len(queries)
        target_count = data_multiplier * real_count
        total_generated = 0
        while total_generated < target_count:
            prompt_text = build_prompt(
                intent_name=intent_name,
                real_examples=queries,
                max_examples=max_real_examples
            )
            gen_texts = generate_synthetic_utterances(
                prompt_text=prompt_text,
                tokenizer=tokenizer_llm,
                model=model_llm,
                num_samples=num_gen_per_prompt,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty
            )
            for gtext in gen_texts:
                gtext = gtext.strip()
                synthetic_rows.append({
                    "query": gtext,
                    "intent": intent_name,
                    "source": "synthetic"
                })
            total_generated += num_gen_per_prompt
        print(
            f"Finished intent '{intent_name}'. "
            f"Generated {target_count} synthetic samples for it."
        )

    # Convert synthetic data to a Hugging Face Dataset
    df_synth = pd.DataFrame(synthetic_rows)
    synth_dataset = Dataset.from_pandas(df_synth)
    return synth_dataset

# Generate synthetic data with specified parameters
synthetic_dataset = generate_synthetic_dataset(
    real_dataset=train_dataset,
    tokenizer_llm=tokenizer_opt,
    model_llm=model_opt,
    data_multiplier=2,       # Generate ~2x the size of real data -> 30k new data
    max_real_examples=1,     # Use up to 1 real query in prompts
    num_gen_per_prompt=8,    # Generate 8 utterances per prompt
    top_p=0.9,
    temperature=0.9,
    repetition_penalty=1.1,
    max_new_tokens=20
)

# Output the size and a sample of the synthetic dataset
print("Synthetic dataset size:", len(synthetic_dataset))
print("Sample synthetic row:")
print(synthetic_dataset[0])


Using MPS device for acceleration.
Finished intent 'translate'. Generated 200 synthetic samples for it.
Finished intent 'transfer'. Generated 200 synthetic samples for it.
Finished intent 'timer'. Generated 200 synthetic samples for it.
Finished intent 'definition'. Generated 200 synthetic samples for it.
Finished intent 'meaning_of_life'. Generated 200 synthetic samples for it.
Finished intent 'insurance_change'. Generated 200 synthetic samples for it.
Finished intent 'find_phone'. Generated 200 synthetic samples for it.
Finished intent 'travel_alert'. Generated 200 synthetic samples for it.
Finished intent 'pto_request'. Generated 200 synthetic samples for it.
Finished intent 'improve_credit_score'. Generated 200 synthetic samples for it.
Finished intent 'fun_fact'. Generated 200 synthetic samples for it.
Finished intent 'change_language'. Generated 200 synthetic samples for it.
Finished intent 'payday'. Generated 200 synthetic samples for it.
Finished intent 'replacement_card_durati

In [7]:

# PVI Filtering

# Using the trained models g & g^* to compute:
# PVI(x->y) = -log2[g^*(y|empty)] + log2[g(y|x)]

# 1) Probability helpers

def get_label_probs(model, tokenizer, text, use_empty=False):
    """
    Returns a dictionary of label probabilities for the given text or an empty input.
    """
    # Use "[EMPTY]" as input if use_empty is True
    text_input = "[EMPTY]" if use_empty else text
    # Tokenize the input text
    inputs = tokenizer(text_input, return_tensors="pt", truncation=True, padding=True)
    device = torch.device("mps")  # Change to "cuda" or "cpu" if needed
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model = model.to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits[0]  # Get logits for the first (and only) example
        probs = torch.softmax(logits, dim=-1)  # Apply softmax to get probabilities
    
    # Map label indices to their probabilities
    label_prob = {id2label[idx]: float(p.item()) for idx, p in enumerate(probs)}
    return label_prob

def compute_pvi_for_sample(query_text, intent_label):
    """
    Computes the PVI value for a single query-intent pair.
    """
    eps = 1e-12  # Small epsilon to prevent log(0)
    # Get probability from model g
    prob_g = get_label_probs(model_g, tokenizer_g, query_text, use_empty=False)
    p_g_xy = prob_g.get(intent_label, 0.0)
    
    # Get probability from model g^*
    prob_g_star = get_label_probs(model_g_star, tokenizer_g_star, query_text, use_empty=True)
    p_g_star_y = prob_g_star.get(intent_label, 0.0)
    
    # Calculate PVI
    pvi_val = -math.log2(max(p_g_star_y, eps)) + math.log2(max(p_g_xy, eps))
    return pvi_val

def filter_synthetic_data_by_pvi(
    synthetic_data: Dataset,
    pvi_threshold=0.5
):
    """
    Filters the synthetic dataset by keeping only samples with PVI >= pvi_threshold.
    Adds a 'pvi' column to the filtered dataset.
    """
    rows = synthetic_data.to_list()
    filtered = []
    
    for row in rows:
        x = row["query"]
        y = row["intent"]
        pvi_val = compute_pvi_for_sample(x, y)
        if pvi_val >= pvi_threshold:
            new_row = dict(row)
            new_row["pvi"] = pvi_val  # Add PVI value to the row
            filtered.append(new_row)
    
    # Convert filtered rows back to a Dataset
    df_filt = pd.DataFrame(filtered)
    return Dataset.from_pandas(df_filt)

# Apply PVI Filtering to Synthetic Data

# This function was not used in the final synthesis
# Initially used, then discarded for average PVI value for each class for filtering.
filtered_synthetic_dataset = filter_synthetic_data_by_pvi(
    synthetic_data=synthetic_dataset,
    pvi_threshold=0.5
)


In [14]:
def compute_classwise_avg_pvi(val_dataset):
    # Store a list of PVI values for each intent
    class_to_pvi_vals = {}
    
    for example in val_dataset:
        x_query = example["query"]
        y_intent = example["intent"]
        
        pvi_val = compute_pvi_for_sample(x_query, y_intent) 
        if y_intent not in class_to_pvi_vals:
            class_to_pvi_vals[y_intent] = []
        class_to_pvi_vals[y_intent].append(pvi_val)
    
    # Compute average per class
    classwise_avg_pvi = {}
    for intent_label, pvi_list in class_to_pvi_vals.items():
        if len(pvi_list) == 0:
            # If somehow a class doesn't have any labels
            classwise_avg_pvi[intent_label] = 0.0
        else:
            avg_val = sum(pvi_list) / len(pvi_list)
            classwise_avg_pvi[intent_label] = avg_val
    
    return classwise_avg_pvi

# Example usage:
threshold_dict = compute_classwise_avg_pvi(val_dataset)

print("Per-class average PVI threshold:")
for k, v in threshold_dict.items():
    print(f"  Class={k}, threshold={v:.4f}")



Per-class average PVI threshold:
  Class=translate, threshold=-0.0256
  Class=transfer, threshold=0.0328
  Class=timer, threshold=0.0395
  Class=definition, threshold=-0.0469
  Class=meaning_of_life, threshold=0.0163
  Class=insurance_change, threshold=0.0612
  Class=find_phone, threshold=-0.0788
  Class=travel_alert, threshold=-0.0040
  Class=pto_request, threshold=-0.0392
  Class=improve_credit_score, threshold=0.1001
  Class=fun_fact, threshold=-0.0602
  Class=change_language, threshold=0.0579
  Class=payday, threshold=-0.1091
  Class=replacement_card_duration, threshold=0.0804
  Class=time, threshold=0.0001
  Class=application_status, threshold=0.0301
  Class=flight_status, threshold=0.0277
  Class=flip_coin, threshold=0.0046
  Class=change_user_name, threshold=-0.0542
  Class=where_are_you_from, threshold=-0.0019
  Class=shopping_list_update, threshold=0.1012
  Class=what_can_i_ask_you, threshold=-0.0076
  Class=maybe, threshold=0.0002
  Class=oil_change_how, threshold=0.0349
  Cl

In [9]:
def filter_synthetic_data_by_pvi_per_class(
    synthetic_data, 
    threshold_dict
):
    """
    For each row (x, y) in synthetic_data, compute PVI(x->y).
    Keep only those with PVI >= threshold_dict[y].
    """
    filtered_rows = []
    original_rows = synthetic_data.to_list()  # if it's a HF Dataset
    for row in original_rows:
        x_query = row["query"]
        y_intent = row["intent"]
        
        pvi_val = compute_pvi_for_sample(x_query, y_intent)
        
        # Compare with the class-specific threshold
        class_threshold = threshold_dict.get(y_intent, 0.0)
        
        if pvi_val >= class_threshold:
            new_row = dict(row)
            new_row["pvi"] = pvi_val
            filtered_rows.append(new_row)
    
    import pandas as pd
    df_filtered = pd.DataFrame(filtered_rows)
    
    from datasets import Dataset
    filtered_dataset = Dataset.from_pandas(df_filtered)
    
    # print stats
    original_count = len(original_rows)
    filtered_count = len(filtered_dataset)
    percent = 100.0 * filtered_count / original_count if original_count > 0 else 0.0
    print(f"Filtered synthetic data: kept {filtered_count}/{original_count} = {percent:.1f}%")
    
    return filtered_dataset

filtered_synth_dataset = filter_synthetic_data_by_pvi_per_class(
    synthetic_data=synthetic_dataset,
    threshold_dict=threshold_dict
)



Filtered synthetic data: kept 13860/30000 = 46.2%


In [None]:
# Cell 7: Combine real + filtered synthetic

df_real = train_dataset.to_pandas()
df_filt_synth = filtered_synth_dataset.to_pandas()

df_aug = pd.concat([df_real, df_filt_synth], ignore_index=True)
aug_dataset = Dataset.from_pandas(df_aug)
print(f"\nAugmented dataset size: {len(aug_dataset)}")
if len(aug_dataset) > 0:
    print("Augmented sample row:", aug_dataset[0])


Augmented dataset size: 28860
Augmented sample row: {'query': 'what expression would i use to say i love you if i were an italian', 'intent': 'translate', 'source': None, 'pvi': None}


In [11]:
# Define the output file path
output_json_path = "augmented_dataset.json"

# Save the dataset to a JSON file
aug_dataset.to_json(output_json_path)

print(f"\nAugmented dataset has been saved to {output_json_path}")


Creating json from Arrow format: 100%|██████████| 29/29 [00:00<00:00, 596.01ba/s]


Augmented dataset has been saved to augmented_dataset.json





In [12]:
# Write newly generated data to a txt file
df_synthetic = synthetic_dataset.to_pandas()
with open("synthetic_dataset.txt", "w", encoding="utf-8") as file:
    for query in df_synthetic["query"]:
        file.write(query + "\n")

print("Synthetic dataset has been written to 'synthetic_dataset.txt'.")

Synthetic dataset has been written to 'synthetic_dataset.txt'.


In [13]:
print(torch.backends.mps.is_available())  # should be True
print(torch.backends.mps.is_built())

True
True
