In [2]:
# %pip install accelerate # charset-normalizer  # pandas python-dotenv transformers
# %pip install --upgrade accelerate

In [1]:
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Trainer, TrainingArguments
from huggingface_hub import login
from dotenv import load_dotenv
from torch.utils.data import Dataset

import pandas as pd

import torch
import os

  from .autonotebook import tqdm as notebook_tqdm





In [2]:
# List all available CUDA devices
for i in range(torch.cuda.device_count()):
    print(f"Device {i}: {torch.cuda.get_device_name(i)}")

Device 0: NVIDIA GeForce RTX 3060 Laptop GPU


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


In [4]:
# Load environment variables from .env file
load_dotenv()

token = os.getenv("HUGGINGFACE_API_TOKEN")

login(token)

In [5]:
# model_id = "meta-llama/Llama-3.2-1B"
model_id = "gpt2-medium"
output_dir = "model/gpt2-medium-food-v2"
enpoint_url = "../apis/model/gpt-v1"

In [6]:
pipe = pipeline(
    "text-generation",
    model=model_id, 
    torch_dtype=torch.bfloat16,
    device=device
)

# response = pipe("What is the most common eaten foodstuff in USA?", return_full_text=False, truncation=True)
response = pipe("What is the most eaten food in Algeria?")
response

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


[{'generated_text': 'What is the most eaten food in Algeria?\n\nEating fish is the norm in Algeria, probably more so than any other nation across Africa. Fish has become famous by the millions as one of the dishes that we eat every day, especially as'}]

In [7]:
from charset_normalizer import detect

def return_dataset_encoding(dataset_path: str) -> str:
    # Read a sample of the file
    with open(dataset_path, 'rb') as file:
        raw_data = file.read()

    # Detect encoding
    result = detect(raw_data)
    print(f"Detected encoding: {result['encoding']}")

    return result['encoding']

In [8]:
# Dishes script
def get_most_eaten_food_in_dishes_dataset():
    sentences = []

    dish = pd.read_csv("data/dishes.csv")
    dish['english_name'] = dish['english_name'].fillna(dish['local_name'])

    unique_countries = dish['countries'].unique()
    list_of_foods = []
    for country in unique_countries:
        country_data = dish[dish['countries'] == country]

        country_regions = country_data['regions'].unique()

        for region in country_regions:
            region_data = country_data[country_data['regions'] == region]

            if len(region_data['english_name'].values) == 0:
                continue
            
            sentences.append(f"What is the most eaten food in {country}, {region}? In {country}, {region} the most eaten food is {region_data['english_name'].values[0]}")
            sentences.append(f"What is the most eaten food in {region}? In {country}, {region} the most eaten food is {region_data['english_name'].values[0]}")
            sentences.append(f"What do people in {region} eat? In {country}, {region} the most eaten food is {region_data['english_name'].values[0]}")
            sentences.append(f"What do people in {country}, {region} eat? In {country}, {region} the most eaten food is {region_data['english_name'].values[0]}")
            list_of_foods.append(f"{region_data['english_name'].values[0]}")
            # print(f"What is the most eaten food in {country}? In {country}, {region} the most eaten food is {region_data['english_name'].values[0]}")

        sentences.append(f"What is the most eaten food in {country}? The most common eaten foods in {country}: {', '.join(list_of_foods).rstrip()}")
        sentences.append(f"What is the most eaten food in {country}? The most eaten foods in {country}: {', '.join(list_of_foods).rstrip()}")
        sentences.append(f"What do people in {country} eat? The most eaten foods in {country}: {', '.join(list_of_foods).rstrip()}")
        list_of_foods = []

    return sentences

In [9]:
# Get some sentences to ask our AI
sentences = get_most_eaten_food_in_dishes_dataset()

# print(sentences)

filtered_items = [item for item in sentences if "Bosnia and Herzegovina" in item]

print(filtered_items)

['What is the most eaten food in Bosnia and Herzegovina, Balkans? In Bosnia and Herzegovina, Balkans the most eaten food is Cevapi', 'What is the most eaten food in Balkans? In Bosnia and Herzegovina, Balkans the most eaten food is Cevapi', 'What do people in Balkans eat? In Bosnia and Herzegovina, Balkans the most eaten food is Cevapi', 'What do people in Bosnia and Herzegovina, Balkans eat? In Bosnia and Herzegovina, Balkans the most eaten food is Cevapi', 'What is the most eaten food in Bosnia and Herzegovina? The most common eaten foods in Bosnia and Herzegovina: Cevapi', 'What is the most eaten food in Bosnia and Herzegovina? The most eaten foods in Bosnia and Herzegovina: Cevapi', 'What do people in Bosnia and Herzegovina eat? The most eaten foods in Bosnia and Herzegovina: Cevapi']


In [10]:
# Custom Dataset class for text generation
class TextGenerationDataset(Dataset):
    def __init__(self, texts, tokenizer, device, max_length=512):
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length
        self.device = device

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize the text with padding and truncation
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        ).to(device)
        # Labels are the same as input_ids for causal language modeling
        encoding["labels"] = encoding["input_ids"]
        return {key: val.squeeze(0) for key, val in encoding.items()}

In [11]:
def train_model(model_to_train, tokenizer_for_model, texts, save_model_dir):
    # Prepare the dataset
    train_dataset = TextGenerationDataset(texts, tokenizer_for_model, device)

    # Training arguments
    training_args = TrainingArguments(
        output_dir="trainer",               # Output directory
        learning_rate=5e-5,                 # Learning rate
        per_device_train_batch_size=2,      # Batch size
        weight_decay=0.01,                  # Weight decay
        save_strategy="no",                 # No saving on checkpoints
        logging_dir="logs",                 # Log directory
        logging_steps=10,                   # Log every 10 steps
        fp16=True,                          # Enable mixed precision (if supported)
    )

    # Initialize the Trainer
    trainer = Trainer(
        model=model_to_train,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer_for_model,
    )

    # Fine-tune the model
    trainer.train()

    model_to_train.save_pretrained(save_model_dir)
    tokenizer_for_model.save_pretrained(save_model_dir)

In [12]:
# Load the pretrained text-generation model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Ensure padding tokens are set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Loading dataset
texts = get_most_eaten_food_in_dishes_dataset()

train_model(model, tokenizer, texts, output_dir)

  trainer = Trainer(
  0%|          | 10/2262 [00:07<52:43,  1.40s/it]

{'loss': 9.8534, 'grad_norm': 175.21209716796875, 'learning_rate': 4.995579133510168e-05, 'epoch': 0.01}


  1%|          | 20/2262 [00:36<1:45:24,  2.82s/it]

{'loss': 0.8459, 'grad_norm': 1.2516539096832275, 'learning_rate': 4.9734748010610085e-05, 'epoch': 0.03}


  1%|▏         | 30/2262 [01:05<1:46:23,  2.86s/it]

{'loss': 0.1593, 'grad_norm': 0.9526187777519226, 'learning_rate': 4.9513704686118484e-05, 'epoch': 0.04}


  2%|▏         | 40/2262 [01:33<1:46:01,  2.86s/it]

{'loss': 0.1424, 'grad_norm': 0.6546040773391724, 'learning_rate': 4.9292661361626876e-05, 'epoch': 0.05}


  2%|▏         | 50/2262 [02:02<1:45:25,  2.86s/it]

{'loss': 0.1601, 'grad_norm': 0.9747548699378967, 'learning_rate': 4.907161803713528e-05, 'epoch': 0.07}


  3%|▎         | 60/2262 [02:30<1:44:59,  2.86s/it]

{'loss': 0.1298, 'grad_norm': 0.7499803304672241, 'learning_rate': 4.885057471264368e-05, 'epoch': 0.08}


  3%|▎         | 70/2262 [02:59<1:44:35,  2.86s/it]

{'loss': 0.2515, 'grad_norm': 0.5427398085594177, 'learning_rate': 4.862953138815208e-05, 'epoch': 0.09}


  4%|▎         | 80/2262 [03:28<1:44:15,  2.87s/it]

{'loss': 0.0965, 'grad_norm': 0.49027547240257263, 'learning_rate': 4.840848806366048e-05, 'epoch': 0.11}


  4%|▍         | 90/2262 [03:56<1:43:43,  2.87s/it]

{'loss': 0.1357, 'grad_norm': 0.5854550004005432, 'learning_rate': 4.818744473916888e-05, 'epoch': 0.12}


  4%|▍         | 100/2262 [04:25<1:43:25,  2.87s/it]

{'loss': 0.1181, 'grad_norm': 0.8604946732521057, 'learning_rate': 4.796640141467728e-05, 'epoch': 0.13}


  5%|▍         | 110/2262 [04:54<1:42:57,  2.87s/it]

{'loss': 0.0905, 'grad_norm': 0.6448808908462524, 'learning_rate': 4.7745358090185675e-05, 'epoch': 0.15}


  5%|▌         | 120/2262 [05:23<1:42:41,  2.88s/it]

{'loss': 0.1076, 'grad_norm': 0.8801952600479126, 'learning_rate': 4.752431476569408e-05, 'epoch': 0.16}


  6%|▌         | 130/2262 [05:51<1:41:51,  2.87s/it]

{'loss': 0.1227, 'grad_norm': 1.362291693687439, 'learning_rate': 4.730327144120248e-05, 'epoch': 0.17}


  6%|▌         | 140/2262 [06:20<1:41:27,  2.87s/it]

{'loss': 0.1087, 'grad_norm': 0.980648398399353, 'learning_rate': 4.708222811671088e-05, 'epoch': 0.19}


  7%|▋         | 150/2262 [06:49<1:40:54,  2.87s/it]

{'loss': 0.1002, 'grad_norm': 0.6042937636375427, 'learning_rate': 4.686118479221928e-05, 'epoch': 0.2}


  7%|▋         | 160/2262 [07:20<1:53:10,  3.23s/it]

{'loss': 0.0815, 'grad_norm': 0.7473962903022766, 'learning_rate': 4.6640141467727676e-05, 'epoch': 0.21}


  8%|▊         | 170/2262 [07:53<1:57:36,  3.37s/it]

{'loss': 0.1003, 'grad_norm': 0.5895016193389893, 'learning_rate': 4.6419098143236075e-05, 'epoch': 0.23}


  8%|▊         | 180/2262 [08:28<2:05:59,  3.63s/it]

{'loss': 0.2424, 'grad_norm': 0.6577419638633728, 'learning_rate': 4.6198054818744474e-05, 'epoch': 0.24}


  8%|▊         | 190/2262 [09:04<2:03:00,  3.56s/it]

{'loss': 0.1059, 'grad_norm': 0.5609899759292603, 'learning_rate': 4.597701149425287e-05, 'epoch': 0.25}


  9%|▉         | 200/2262 [09:40<2:02:45,  3.57s/it]

{'loss': 0.1657, 'grad_norm': 0.5589551329612732, 'learning_rate': 4.575596816976128e-05, 'epoch': 0.27}


  9%|▉         | 210/2262 [10:17<2:07:43,  3.73s/it]

{'loss': 0.0789, 'grad_norm': 0.5004270672798157, 'learning_rate': 4.553492484526968e-05, 'epoch': 0.28}


 10%|▉         | 220/2262 [10:52<2:00:13,  3.53s/it]

{'loss': 0.0595, 'grad_norm': 0.6022709012031555, 'learning_rate': 4.531388152077807e-05, 'epoch': 0.29}


 10%|█         | 230/2262 [11:28<1:59:34,  3.53s/it]

{'loss': 0.1104, 'grad_norm': 0.7172099351882935, 'learning_rate': 4.5092838196286476e-05, 'epoch': 0.31}


 11%|█         | 240/2262 [12:03<1:58:24,  3.51s/it]

{'loss': 0.0796, 'grad_norm': 0.7717597484588623, 'learning_rate': 4.4871794871794874e-05, 'epoch': 0.32}


 11%|█         | 250/2262 [12:38<1:57:49,  3.51s/it]

{'loss': 0.0736, 'grad_norm': 0.6089113354682922, 'learning_rate': 4.465075154730327e-05, 'epoch': 0.33}


 11%|█▏        | 260/2262 [13:14<1:57:15,  3.51s/it]

{'loss': 0.0975, 'grad_norm': 1.2935066223144531, 'learning_rate': 4.442970822281167e-05, 'epoch': 0.34}


 12%|█▏        | 270/2262 [13:50<1:59:59,  3.61s/it]

{'loss': 0.079, 'grad_norm': 0.6654071807861328, 'learning_rate': 4.420866489832007e-05, 'epoch': 0.36}


 12%|█▏        | 280/2262 [14:26<1:56:55,  3.54s/it]

{'loss': 0.0694, 'grad_norm': 0.6702502369880676, 'learning_rate': 4.398762157382848e-05, 'epoch': 0.37}


 13%|█▎        | 290/2262 [15:01<1:56:09,  3.53s/it]

{'loss': 0.0975, 'grad_norm': 0.5751957893371582, 'learning_rate': 4.376657824933687e-05, 'epoch': 0.38}


 13%|█▎        | 300/2262 [15:37<1:57:16,  3.59s/it]

{'loss': 0.069, 'grad_norm': 1.4003931283950806, 'learning_rate': 4.3545534924845275e-05, 'epoch': 0.4}


 14%|█▎        | 310/2262 [16:16<2:00:17,  3.70s/it]

{'loss': 0.0558, 'grad_norm': 0.5723843574523926, 'learning_rate': 4.3324491600353674e-05, 'epoch': 0.41}


 14%|█▍        | 320/2262 [16:51<1:52:02,  3.46s/it]

{'loss': 0.0912, 'grad_norm': 1.2069740295410156, 'learning_rate': 4.3103448275862066e-05, 'epoch': 0.42}


 15%|█▍        | 330/2262 [17:26<1:51:13,  3.45s/it]

{'loss': 0.0957, 'grad_norm': 0.7491254210472107, 'learning_rate': 4.288240495137047e-05, 'epoch': 0.44}


 15%|█▌        | 340/2262 [18:00<1:51:38,  3.49s/it]

{'loss': 0.0661, 'grad_norm': 0.6237572431564331, 'learning_rate': 4.266136162687887e-05, 'epoch': 0.45}


 15%|█▌        | 350/2262 [18:35<1:51:52,  3.51s/it]

{'loss': 0.0535, 'grad_norm': 0.4811672270298004, 'learning_rate': 4.244031830238727e-05, 'epoch': 0.46}


 16%|█▌        | 360/2262 [19:10<1:49:57,  3.47s/it]

{'loss': 0.054, 'grad_norm': 0.7957543134689331, 'learning_rate': 4.221927497789567e-05, 'epoch': 0.48}


 16%|█▋        | 370/2262 [19:45<1:48:39,  3.45s/it]

{'loss': 0.0607, 'grad_norm': 1.0442675352096558, 'learning_rate': 4.199823165340407e-05, 'epoch': 0.49}


 17%|█▋        | 380/2262 [20:19<1:48:21,  3.45s/it]

{'loss': 0.0603, 'grad_norm': 0.5527072548866272, 'learning_rate': 4.177718832891247e-05, 'epoch': 0.5}


 17%|█▋        | 390/2262 [20:55<1:51:02,  3.56s/it]

{'loss': 0.0678, 'grad_norm': 1.4365309476852417, 'learning_rate': 4.155614500442087e-05, 'epoch': 0.52}


 18%|█▊        | 400/2262 [21:30<1:49:50,  3.54s/it]

{'loss': 0.0796, 'grad_norm': 2.0800652503967285, 'learning_rate': 4.1335101679929264e-05, 'epoch': 0.53}


 18%|█▊        | 410/2262 [22:06<1:50:22,  3.58s/it]

{'loss': 0.0797, 'grad_norm': 1.225983738899231, 'learning_rate': 4.111405835543767e-05, 'epoch': 0.54}


 19%|█▊        | 420/2262 [22:42<1:48:58,  3.55s/it]

{'loss': 0.0513, 'grad_norm': 0.7133216261863708, 'learning_rate': 4.089301503094607e-05, 'epoch': 0.56}


 19%|█▉        | 430/2262 [23:18<1:49:32,  3.59s/it]

{'loss': 0.0435, 'grad_norm': 0.4223274290561676, 'learning_rate': 4.067197170645447e-05, 'epoch': 0.57}


 19%|█▉        | 440/2262 [23:54<1:49:33,  3.61s/it]

{'loss': 0.0799, 'grad_norm': 0.5866912603378296, 'learning_rate': 4.0450928381962866e-05, 'epoch': 0.58}


 20%|█▉        | 450/2262 [24:31<1:53:38,  3.76s/it]

{'loss': 0.0589, 'grad_norm': 1.0984677076339722, 'learning_rate': 4.0229885057471265e-05, 'epoch': 0.6}


 20%|██        | 460/2262 [25:08<1:51:21,  3.71s/it]

{'loss': 0.0606, 'grad_norm': 0.4143548607826233, 'learning_rate': 4.000884173297967e-05, 'epoch': 0.61}


 21%|██        | 470/2262 [25:45<1:50:12,  3.69s/it]

{'loss': 0.0485, 'grad_norm': 0.7672938108444214, 'learning_rate': 3.978779840848806e-05, 'epoch': 0.62}


 21%|██        | 480/2262 [26:22<1:49:01,  3.67s/it]

{'loss': 0.0534, 'grad_norm': 0.6271587014198303, 'learning_rate': 3.956675508399647e-05, 'epoch': 0.64}


 22%|██▏       | 490/2262 [26:59<1:48:48,  3.68s/it]

{'loss': 0.0909, 'grad_norm': 0.8482453227043152, 'learning_rate': 3.934571175950487e-05, 'epoch': 0.65}


 22%|██▏       | 500/2262 [27:36<1:48:08,  3.68s/it]

{'loss': 0.0847, 'grad_norm': 1.0207377672195435, 'learning_rate': 3.912466843501326e-05, 'epoch': 0.66}


 23%|██▎       | 510/2262 [28:13<1:47:11,  3.67s/it]

{'loss': 0.0462, 'grad_norm': 0.476155549287796, 'learning_rate': 3.8903625110521665e-05, 'epoch': 0.68}


 23%|██▎       | 520/2262 [28:50<1:49:20,  3.77s/it]

{'loss': 0.0524, 'grad_norm': 0.713136613368988, 'learning_rate': 3.8682581786030064e-05, 'epoch': 0.69}


 23%|██▎       | 530/2262 [29:27<1:45:33,  3.66s/it]

{'loss': 0.0718, 'grad_norm': 0.6278985738754272, 'learning_rate': 3.846153846153846e-05, 'epoch': 0.7}


 24%|██▍       | 540/2262 [30:04<1:44:55,  3.66s/it]

{'loss': 0.0631, 'grad_norm': 1.3655047416687012, 'learning_rate': 3.824049513704686e-05, 'epoch': 0.72}


 24%|██▍       | 550/2262 [30:40<1:41:32,  3.56s/it]

{'loss': 0.0481, 'grad_norm': 0.40692824125289917, 'learning_rate': 3.801945181255526e-05, 'epoch': 0.73}


 25%|██▍       | 560/2262 [31:15<1:41:32,  3.58s/it]

{'loss': 0.0792, 'grad_norm': 1.2302744388580322, 'learning_rate': 3.7798408488063666e-05, 'epoch': 0.74}


 25%|██▌       | 570/2262 [31:52<1:44:04,  3.69s/it]

{'loss': 0.0442, 'grad_norm': 1.1150234937667847, 'learning_rate': 3.757736516357206e-05, 'epoch': 0.76}


 26%|██▌       | 580/2262 [32:28<1:39:56,  3.57s/it]

{'loss': 0.0444, 'grad_norm': 0.5058020949363708, 'learning_rate': 3.735632183908046e-05, 'epoch': 0.77}


 26%|██▌       | 590/2262 [33:03<1:37:41,  3.51s/it]

{'loss': 0.0655, 'grad_norm': 0.5333741307258606, 'learning_rate': 3.713527851458886e-05, 'epoch': 0.78}


 27%|██▋       | 600/2262 [33:38<1:37:31,  3.52s/it]

{'loss': 0.0455, 'grad_norm': 0.44664397835731506, 'learning_rate': 3.691423519009726e-05, 'epoch': 0.8}


 27%|██▋       | 610/2262 [34:14<1:37:43,  3.55s/it]

{'loss': 0.0623, 'grad_norm': 1.5896797180175781, 'learning_rate': 3.669319186560566e-05, 'epoch': 0.81}


 27%|██▋       | 620/2262 [34:50<1:37:37,  3.57s/it]

{'loss': 0.0521, 'grad_norm': 0.4629710018634796, 'learning_rate': 3.647214854111406e-05, 'epoch': 0.82}


 28%|██▊       | 630/2262 [35:25<1:36:32,  3.55s/it]

{'loss': 0.0477, 'grad_norm': 0.5778882503509521, 'learning_rate': 3.625110521662246e-05, 'epoch': 0.84}


 28%|██▊       | 640/2262 [36:01<1:35:09,  3.52s/it]

{'loss': 0.0423, 'grad_norm': 0.42588740587234497, 'learning_rate': 3.6030061892130864e-05, 'epoch': 0.85}


 29%|██▊       | 650/2262 [36:36<1:34:16,  3.51s/it]

{'loss': 0.0368, 'grad_norm': 0.3419855833053589, 'learning_rate': 3.5809018567639256e-05, 'epoch': 0.86}


 29%|██▉       | 660/2262 [37:11<1:33:41,  3.51s/it]

{'loss': 0.0371, 'grad_norm': 0.8928683996200562, 'learning_rate': 3.558797524314766e-05, 'epoch': 0.88}


 30%|██▉       | 670/2262 [37:47<1:33:24,  3.52s/it]

{'loss': 0.0517, 'grad_norm': 0.41624972224235535, 'learning_rate': 3.536693191865606e-05, 'epoch': 0.89}


 30%|███       | 680/2262 [38:22<1:33:10,  3.53s/it]

{'loss': 0.0772, 'grad_norm': 0.6066255569458008, 'learning_rate': 3.514588859416445e-05, 'epoch': 0.9}


 31%|███       | 690/2262 [38:57<1:32:02,  3.51s/it]

{'loss': 0.0606, 'grad_norm': 0.7702192068099976, 'learning_rate': 3.492484526967286e-05, 'epoch': 0.92}


 31%|███       | 700/2262 [39:32<1:31:17,  3.51s/it]

{'loss': 0.0476, 'grad_norm': 0.6097784638404846, 'learning_rate': 3.470380194518126e-05, 'epoch': 0.93}


 31%|███▏      | 710/2262 [40:08<1:30:33,  3.50s/it]

{'loss': 0.0458, 'grad_norm': 0.7179458141326904, 'learning_rate': 3.4482758620689657e-05, 'epoch': 0.94}


 32%|███▏      | 720/2262 [40:43<1:30:08,  3.51s/it]

{'loss': 0.049, 'grad_norm': 0.5139002799987793, 'learning_rate': 3.4261715296198055e-05, 'epoch': 0.95}


 32%|███▏      | 730/2262 [41:18<1:29:27,  3.50s/it]

{'loss': 0.053, 'grad_norm': 0.5167627930641174, 'learning_rate': 3.4040671971706454e-05, 'epoch': 0.97}


 33%|███▎      | 740/2262 [41:53<1:29:04,  3.51s/it]

{'loss': 0.04, 'grad_norm': 0.4545002579689026, 'learning_rate': 3.381962864721486e-05, 'epoch': 0.98}


 33%|███▎      | 750/2262 [42:30<1:32:03,  3.65s/it]

{'loss': 0.0454, 'grad_norm': 0.7502549886703491, 'learning_rate': 3.359858532272325e-05, 'epoch': 0.99}


 34%|███▎      | 760/2262 [43:06<1:29:13,  3.56s/it]

{'loss': 0.0385, 'grad_norm': 0.7701166272163391, 'learning_rate': 3.337754199823165e-05, 'epoch': 1.01}


 34%|███▍      | 770/2262 [43:41<1:28:08,  3.54s/it]

{'loss': 0.0455, 'grad_norm': 0.7713834643363953, 'learning_rate': 3.315649867374006e-05, 'epoch': 1.02}


 34%|███▍      | 780/2262 [44:18<1:29:05,  3.61s/it]

{'loss': 0.0403, 'grad_norm': 0.5538753271102905, 'learning_rate': 3.2935455349248456e-05, 'epoch': 1.03}


 35%|███▍      | 790/2262 [44:54<1:29:54,  3.66s/it]

{'loss': 0.0419, 'grad_norm': 0.39843693375587463, 'learning_rate': 3.2714412024756855e-05, 'epoch': 1.05}


 35%|███▌      | 800/2262 [45:31<1:28:31,  3.63s/it]

{'loss': 0.0344, 'grad_norm': 0.9922560453414917, 'learning_rate': 3.2493368700265253e-05, 'epoch': 1.06}


 36%|███▌      | 810/2262 [46:08<1:28:37,  3.66s/it]

{'loss': 0.0459, 'grad_norm': 0.3455717861652374, 'learning_rate': 3.227232537577365e-05, 'epoch': 1.07}


 36%|███▋      | 820/2262 [46:44<1:27:36,  3.65s/it]

{'loss': 0.0341, 'grad_norm': 0.6272831559181213, 'learning_rate': 3.205128205128206e-05, 'epoch': 1.09}


 37%|███▋      | 830/2262 [47:21<1:27:55,  3.68s/it]

{'loss': 0.0325, 'grad_norm': 0.6280528903007507, 'learning_rate': 3.183023872679045e-05, 'epoch': 1.1}


 37%|███▋      | 840/2262 [47:58<1:30:31,  3.82s/it]

{'loss': 0.0394, 'grad_norm': 0.7014539241790771, 'learning_rate': 3.160919540229885e-05, 'epoch': 1.11}


 38%|███▊      | 850/2262 [48:36<1:29:17,  3.79s/it]

{'loss': 0.0473, 'grad_norm': 0.6335972547531128, 'learning_rate': 3.1388152077807255e-05, 'epoch': 1.13}


 38%|███▊      | 860/2262 [49:13<1:26:07,  3.69s/it]

{'loss': 0.0382, 'grad_norm': 0.4598778784275055, 'learning_rate': 3.116710875331565e-05, 'epoch': 1.14}


 38%|███▊      | 870/2262 [49:50<1:24:59,  3.66s/it]

{'loss': 0.0431, 'grad_norm': 0.33862200379371643, 'learning_rate': 3.094606542882405e-05, 'epoch': 1.15}


 39%|███▉      | 880/2262 [50:26<1:24:15,  3.66s/it]

{'loss': 0.0383, 'grad_norm': 1.2104374170303345, 'learning_rate': 3.072502210433245e-05, 'epoch': 1.17}


 39%|███▉      | 890/2262 [51:03<1:23:39,  3.66s/it]

{'loss': 0.0327, 'grad_norm': 0.3921877145767212, 'learning_rate': 3.0503978779840854e-05, 'epoch': 1.18}


 40%|███▉      | 900/2262 [51:40<1:22:53,  3.65s/it]

{'loss': 0.0419, 'grad_norm': 0.6071534156799316, 'learning_rate': 3.028293545534925e-05, 'epoch': 1.19}


 40%|████      | 910/2262 [52:17<1:23:17,  3.70s/it]

{'loss': 0.0394, 'grad_norm': 0.5639289617538452, 'learning_rate': 3.0061892130857648e-05, 'epoch': 1.21}


 41%|████      | 920/2262 [52:54<1:24:07,  3.76s/it]

{'loss': 0.0389, 'grad_norm': 0.5945936441421509, 'learning_rate': 2.984084880636605e-05, 'epoch': 1.22}


 41%|████      | 930/2262 [53:32<1:24:46,  3.82s/it]

{'loss': 0.0313, 'grad_norm': 0.7268903255462646, 'learning_rate': 2.9619805481874446e-05, 'epoch': 1.23}


 42%|████▏     | 940/2262 [54:09<1:21:45,  3.71s/it]

{'loss': 0.0318, 'grad_norm': 0.44109708070755005, 'learning_rate': 2.9398762157382848e-05, 'epoch': 1.25}


 42%|████▏     | 950/2262 [54:46<1:20:44,  3.69s/it]

{'loss': 0.0313, 'grad_norm': 0.8760420680046082, 'learning_rate': 2.9177718832891247e-05, 'epoch': 1.26}


 42%|████▏     | 960/2262 [55:23<1:19:37,  3.67s/it]

{'loss': 0.034, 'grad_norm': 0.841117262840271, 'learning_rate': 2.895667550839965e-05, 'epoch': 1.27}


 43%|████▎     | 970/2262 [55:59<1:18:57,  3.67s/it]

{'loss': 0.0344, 'grad_norm': 0.6635660529136658, 'learning_rate': 2.8735632183908045e-05, 'epoch': 1.29}


 43%|████▎     | 980/2262 [56:36<1:18:19,  3.67s/it]

{'loss': 0.0345, 'grad_norm': 0.7031806111335754, 'learning_rate': 2.8514588859416447e-05, 'epoch': 1.3}


 44%|████▍     | 990/2262 [57:13<1:18:20,  3.69s/it]

{'loss': 0.0319, 'grad_norm': 0.7232035398483276, 'learning_rate': 2.829354553492485e-05, 'epoch': 1.31}


 44%|████▍     | 1000/2262 [57:50<1:17:12,  3.67s/it]

{'loss': 0.0479, 'grad_norm': 0.744864821434021, 'learning_rate': 2.8072502210433245e-05, 'epoch': 1.33}


 45%|████▍     | 1010/2262 [58:27<1:16:51,  3.68s/it]

{'loss': 0.0316, 'grad_norm': 0.5130192041397095, 'learning_rate': 2.7851458885941644e-05, 'epoch': 1.34}


 45%|████▌     | 1020/2262 [59:04<1:16:22,  3.69s/it]

{'loss': 0.0375, 'grad_norm': 0.3779190480709076, 'learning_rate': 2.7630415561450046e-05, 'epoch': 1.35}


 46%|████▌     | 1030/2262 [59:40<1:15:16,  3.67s/it]

{'loss': 0.0536, 'grad_norm': 0.5492127537727356, 'learning_rate': 2.740937223695845e-05, 'epoch': 1.37}


 46%|████▌     | 1040/2262 [1:00:18<1:14:56,  3.68s/it]

{'loss': 0.0353, 'grad_norm': 1.2836363315582275, 'learning_rate': 2.7188328912466844e-05, 'epoch': 1.38}


 46%|████▋     | 1050/2262 [1:00:54<1:14:02,  3.67s/it]

{'loss': 0.0301, 'grad_norm': 0.7622262835502625, 'learning_rate': 2.6967285587975243e-05, 'epoch': 1.39}


 47%|████▋     | 1060/2262 [1:01:31<1:13:28,  3.67s/it]

{'loss': 0.0347, 'grad_norm': 0.9500275254249573, 'learning_rate': 2.6746242263483645e-05, 'epoch': 1.41}


 47%|████▋     | 1070/2262 [1:02:08<1:12:55,  3.67s/it]

{'loss': 0.0289, 'grad_norm': 0.6841263771057129, 'learning_rate': 2.6525198938992047e-05, 'epoch': 1.42}


 48%|████▊     | 1080/2262 [1:02:45<1:12:24,  3.68s/it]

{'loss': 0.0348, 'grad_norm': 1.4756230115890503, 'learning_rate': 2.6304155614500443e-05, 'epoch': 1.43}


 48%|████▊     | 1090/2262 [1:03:22<1:12:00,  3.69s/it]

{'loss': 0.0443, 'grad_norm': 0.5264756083488464, 'learning_rate': 2.6083112290008842e-05, 'epoch': 1.45}


 49%|████▊     | 1100/2262 [1:04:00<1:13:23,  3.79s/it]

{'loss': 0.0419, 'grad_norm': 0.9467110633850098, 'learning_rate': 2.5862068965517244e-05, 'epoch': 1.46}


 49%|████▉     | 1110/2262 [1:04:37<1:11:00,  3.70s/it]

{'loss': 0.036, 'grad_norm': 0.39834728837013245, 'learning_rate': 2.564102564102564e-05, 'epoch': 1.47}


 50%|████▉     | 1120/2262 [1:05:14<1:10:08,  3.69s/it]

{'loss': 0.0419, 'grad_norm': 0.5657243132591248, 'learning_rate': 2.5419982316534042e-05, 'epoch': 1.49}


 50%|████▉     | 1130/2262 [1:05:55<1:22:07,  4.35s/it]

{'loss': 0.0352, 'grad_norm': 0.7309568524360657, 'learning_rate': 2.519893899204244e-05, 'epoch': 1.5}


 50%|█████     | 1140/2262 [1:06:33<1:09:39,  3.72s/it]

{'loss': 0.0371, 'grad_norm': 0.43793901801109314, 'learning_rate': 2.497789566755084e-05, 'epoch': 1.51}


 51%|█████     | 1150/2262 [1:07:10<1:09:32,  3.75s/it]

{'loss': 0.0342, 'grad_norm': 0.5139971375465393, 'learning_rate': 2.4756852343059242e-05, 'epoch': 1.53}


 51%|█████▏    | 1160/2262 [1:07:47<1:07:28,  3.67s/it]

{'loss': 0.0312, 'grad_norm': 0.781981348991394, 'learning_rate': 2.453580901856764e-05, 'epoch': 1.54}


 52%|█████▏    | 1170/2262 [1:08:24<1:07:39,  3.72s/it]

{'loss': 0.0573, 'grad_norm': 0.9538938403129578, 'learning_rate': 2.431476569407604e-05, 'epoch': 1.55}


 52%|█████▏    | 1180/2262 [1:09:01<1:06:20,  3.68s/it]

{'loss': 0.0289, 'grad_norm': 0.895419180393219, 'learning_rate': 2.409372236958444e-05, 'epoch': 1.56}


 53%|█████▎    | 1190/2262 [1:09:38<1:05:40,  3.68s/it]

{'loss': 0.0308, 'grad_norm': 0.5615108609199524, 'learning_rate': 2.3872679045092838e-05, 'epoch': 1.58}


 53%|█████▎    | 1200/2262 [1:10:15<1:04:57,  3.67s/it]

{'loss': 0.0335, 'grad_norm': 0.4663921892642975, 'learning_rate': 2.365163572060124e-05, 'epoch': 1.59}


 53%|█████▎    | 1210/2262 [1:10:51<1:04:23,  3.67s/it]

{'loss': 0.0306, 'grad_norm': 0.7378522157669067, 'learning_rate': 2.343059239610964e-05, 'epoch': 1.6}


 54%|█████▍    | 1220/2262 [1:11:28<1:03:42,  3.67s/it]

{'loss': 0.0279, 'grad_norm': 0.4682999849319458, 'learning_rate': 2.3209549071618038e-05, 'epoch': 1.62}


 54%|█████▍    | 1230/2262 [1:12:05<1:03:05,  3.67s/it]

{'loss': 0.0439, 'grad_norm': 0.4807395935058594, 'learning_rate': 2.2988505747126437e-05, 'epoch': 1.63}


 55%|█████▍    | 1240/2262 [1:12:42<1:02:42,  3.68s/it]

{'loss': 0.028, 'grad_norm': 0.4725799858570099, 'learning_rate': 2.276746242263484e-05, 'epoch': 1.64}


 55%|█████▌    | 1250/2262 [1:13:19<1:02:27,  3.70s/it]

{'loss': 0.0334, 'grad_norm': 0.4375993013381958, 'learning_rate': 2.2546419098143238e-05, 'epoch': 1.66}


 56%|█████▌    | 1260/2262 [1:13:56<1:01:15,  3.67s/it]

{'loss': 0.0315, 'grad_norm': 0.528128981590271, 'learning_rate': 2.2325375773651637e-05, 'epoch': 1.67}


 56%|█████▌    | 1270/2262 [1:14:33<1:02:07,  3.76s/it]

{'loss': 0.0267, 'grad_norm': 0.43073126673698425, 'learning_rate': 2.2104332449160036e-05, 'epoch': 1.68}


 57%|█████▋    | 1280/2262 [1:15:12<1:02:47,  3.84s/it]

{'loss': 0.0344, 'grad_norm': 0.43659839034080505, 'learning_rate': 2.1883289124668434e-05, 'epoch': 1.7}


 57%|█████▋    | 1290/2262 [1:15:49<59:30,  3.67s/it]  

{'loss': 0.029, 'grad_norm': 0.6256820559501648, 'learning_rate': 2.1662245800176837e-05, 'epoch': 1.71}


 57%|█████▋    | 1300/2262 [1:16:25<58:46,  3.67s/it]

{'loss': 0.0318, 'grad_norm': 0.4573845863342285, 'learning_rate': 2.1441202475685236e-05, 'epoch': 1.72}


 58%|█████▊    | 1310/2262 [1:17:02<58:14,  3.67s/it]

{'loss': 0.0331, 'grad_norm': 0.61989825963974, 'learning_rate': 2.1220159151193635e-05, 'epoch': 1.74}


 58%|█████▊    | 1320/2262 [1:17:39<57:57,  3.69s/it]

{'loss': 0.0295, 'grad_norm': 1.1312410831451416, 'learning_rate': 2.0999115826702033e-05, 'epoch': 1.75}


 59%|█████▉    | 1330/2262 [1:18:16<57:10,  3.68s/it]

{'loss': 0.0268, 'grad_norm': 0.5128161907196045, 'learning_rate': 2.0778072502210436e-05, 'epoch': 1.76}


 59%|█████▉    | 1340/2262 [1:18:53<56:30,  3.68s/it]

{'loss': 0.0316, 'grad_norm': 0.6683998107910156, 'learning_rate': 2.0557029177718835e-05, 'epoch': 1.78}


 60%|█████▉    | 1350/2262 [1:19:30<55:39,  3.66s/it]

{'loss': 0.0534, 'grad_norm': 0.6485111713409424, 'learning_rate': 2.0335985853227234e-05, 'epoch': 1.79}


 60%|██████    | 1360/2262 [1:20:06<55:08,  3.67s/it]

{'loss': 0.0296, 'grad_norm': 0.6382503509521484, 'learning_rate': 2.0114942528735632e-05, 'epoch': 1.8}


 61%|██████    | 1370/2262 [1:20:43<54:33,  3.67s/it]

{'loss': 0.0529, 'grad_norm': 0.5866284370422363, 'learning_rate': 1.989389920424403e-05, 'epoch': 1.82}


 61%|██████    | 1380/2262 [1:21:20<53:45,  3.66s/it]

{'loss': 0.0333, 'grad_norm': 0.5697847604751587, 'learning_rate': 1.9672855879752434e-05, 'epoch': 1.83}


 61%|██████▏   | 1390/2262 [1:21:57<53:13,  3.66s/it]

{'loss': 0.0561, 'grad_norm': 0.39651262760162354, 'learning_rate': 1.9451812555260833e-05, 'epoch': 1.84}


 62%|██████▏   | 1400/2262 [1:22:33<52:35,  3.66s/it]

{'loss': 0.0282, 'grad_norm': 0.5984615683555603, 'learning_rate': 1.923076923076923e-05, 'epoch': 1.86}


 62%|██████▏   | 1410/2262 [1:23:10<52:05,  3.67s/it]

{'loss': 0.0302, 'grad_norm': 0.5398409366607666, 'learning_rate': 1.900972590627763e-05, 'epoch': 1.87}


 63%|██████▎   | 1420/2262 [1:23:47<51:33,  3.67s/it]

{'loss': 0.0371, 'grad_norm': 0.6111605167388916, 'learning_rate': 1.878868258178603e-05, 'epoch': 1.88}


 63%|██████▎   | 1430/2262 [1:24:24<50:53,  3.67s/it]

{'loss': 0.0264, 'grad_norm': 0.5900337100028992, 'learning_rate': 1.856763925729443e-05, 'epoch': 1.9}


 64%|██████▎   | 1440/2262 [1:25:01<50:06,  3.66s/it]

{'loss': 0.0269, 'grad_norm': 0.6407743096351624, 'learning_rate': 1.834659593280283e-05, 'epoch': 1.91}


 64%|██████▍   | 1450/2262 [1:25:38<51:17,  3.79s/it]

{'loss': 0.0262, 'grad_norm': 0.6549135446548462, 'learning_rate': 1.812555260831123e-05, 'epoch': 1.92}


 65%|██████▍   | 1460/2262 [1:26:16<49:42,  3.72s/it]

{'loss': 0.0273, 'grad_norm': 0.2588050365447998, 'learning_rate': 1.7904509283819628e-05, 'epoch': 1.94}


 65%|██████▍   | 1470/2262 [1:26:53<48:24,  3.67s/it]

{'loss': 0.0273, 'grad_norm': 0.29877492785453796, 'learning_rate': 1.768346595932803e-05, 'epoch': 1.95}


 65%|██████▌   | 1480/2262 [1:27:30<48:03,  3.69s/it]

{'loss': 0.0272, 'grad_norm': 0.5545470118522644, 'learning_rate': 1.746242263483643e-05, 'epoch': 1.96}


 66%|██████▌   | 1490/2262 [1:28:07<47:14,  3.67s/it]

{'loss': 0.0245, 'grad_norm': 0.4303203821182251, 'learning_rate': 1.7241379310344828e-05, 'epoch': 1.98}


 66%|██████▋   | 1500/2262 [1:28:43<46:38,  3.67s/it]

{'loss': 0.0262, 'grad_norm': 0.4416343867778778, 'learning_rate': 1.7020335985853227e-05, 'epoch': 1.99}


 67%|██████▋   | 1510/2262 [1:29:21<46:07,  3.68s/it]

{'loss': 0.0244, 'grad_norm': 0.5554056167602539, 'learning_rate': 1.6799292661361626e-05, 'epoch': 2.0}


 67%|██████▋   | 1520/2262 [1:29:58<45:50,  3.71s/it]

{'loss': 0.0231, 'grad_norm': 0.6968348622322083, 'learning_rate': 1.657824933687003e-05, 'epoch': 2.02}


 68%|██████▊   | 1530/2262 [1:30:35<44:57,  3.69s/it]

{'loss': 0.0215, 'grad_norm': 0.7150402665138245, 'learning_rate': 1.6357206012378427e-05, 'epoch': 2.03}


 68%|██████▊   | 1540/2262 [1:31:12<44:07,  3.67s/it]

{'loss': 0.0315, 'grad_norm': 0.5159174799919128, 'learning_rate': 1.6136162687886826e-05, 'epoch': 2.04}


 69%|██████▊   | 1550/2262 [1:31:48<43:27,  3.66s/it]

{'loss': 0.0247, 'grad_norm': 0.5122671127319336, 'learning_rate': 1.5915119363395225e-05, 'epoch': 2.06}


 69%|██████▉   | 1560/2262 [1:32:25<43:01,  3.68s/it]

{'loss': 0.0234, 'grad_norm': 0.7496623396873474, 'learning_rate': 1.5694076038903627e-05, 'epoch': 2.07}


 69%|██████▉   | 1570/2262 [1:33:02<42:27,  3.68s/it]

{'loss': 0.0234, 'grad_norm': 0.4910532832145691, 'learning_rate': 1.5473032714412026e-05, 'epoch': 2.08}


 70%|██████▉   | 1580/2262 [1:33:39<41:47,  3.68s/it]

{'loss': 0.0204, 'grad_norm': 0.551832914352417, 'learning_rate': 1.5251989389920427e-05, 'epoch': 2.1}


 70%|███████   | 1590/2262 [1:34:17<43:36,  3.89s/it]

{'loss': 0.0269, 'grad_norm': 0.3653830885887146, 'learning_rate': 1.5030946065428824e-05, 'epoch': 2.11}


 71%|███████   | 1600/2262 [1:34:53<40:17,  3.65s/it]

{'loss': 0.0297, 'grad_norm': 0.5041195750236511, 'learning_rate': 1.4809902740937223e-05, 'epoch': 2.12}


 71%|███████   | 1610/2262 [1:35:30<39:18,  3.62s/it]

{'loss': 0.0228, 'grad_norm': 0.5661701560020447, 'learning_rate': 1.4588859416445624e-05, 'epoch': 2.14}


 72%|███████▏  | 1620/2262 [1:36:06<38:44,  3.62s/it]

{'loss': 0.0244, 'grad_norm': 0.648736298084259, 'learning_rate': 1.4367816091954022e-05, 'epoch': 2.15}


 72%|███████▏  | 1630/2262 [1:36:44<39:26,  3.74s/it]

{'loss': 0.0223, 'grad_norm': 0.36003851890563965, 'learning_rate': 1.4146772767462425e-05, 'epoch': 2.16}


 73%|███████▎  | 1640/2262 [1:37:20<37:41,  3.64s/it]

{'loss': 0.0291, 'grad_norm': 0.47690993547439575, 'learning_rate': 1.3925729442970822e-05, 'epoch': 2.18}


 73%|███████▎  | 1650/2262 [1:37:57<37:35,  3.69s/it]

{'loss': 0.027, 'grad_norm': 0.5857700109481812, 'learning_rate': 1.3704686118479224e-05, 'epoch': 2.19}


 73%|███████▎  | 1660/2262 [1:38:34<36:34,  3.65s/it]

{'loss': 0.0272, 'grad_norm': 0.4790198504924774, 'learning_rate': 1.3483642793987621e-05, 'epoch': 2.2}


 74%|███████▍  | 1670/2262 [1:39:10<35:53,  3.64s/it]

{'loss': 0.0244, 'grad_norm': 0.4217016398906708, 'learning_rate': 1.3262599469496024e-05, 'epoch': 2.21}


 74%|███████▍  | 1680/2262 [1:39:47<35:08,  3.62s/it]

{'loss': 0.0242, 'grad_norm': 0.6732864379882812, 'learning_rate': 1.3041556145004421e-05, 'epoch': 2.23}


 75%|███████▍  | 1690/2262 [1:40:23<34:32,  3.62s/it]

{'loss': 0.0236, 'grad_norm': 0.5583467483520508, 'learning_rate': 1.282051282051282e-05, 'epoch': 2.24}


 75%|███████▌  | 1700/2262 [1:40:59<33:47,  3.61s/it]

{'loss': 0.0288, 'grad_norm': 0.7565642595291138, 'learning_rate': 1.259946949602122e-05, 'epoch': 2.25}


 76%|███████▌  | 1710/2262 [1:41:36<33:15,  3.62s/it]

{'loss': 0.0376, 'grad_norm': 1.0969020128250122, 'learning_rate': 1.2378426171529621e-05, 'epoch': 2.27}


 76%|███████▌  | 1720/2262 [1:42:12<32:43,  3.62s/it]

{'loss': 0.0254, 'grad_norm': 0.7018120288848877, 'learning_rate': 1.215738284703802e-05, 'epoch': 2.28}


 76%|███████▋  | 1730/2262 [1:42:48<32:04,  3.62s/it]

{'loss': 0.0274, 'grad_norm': 0.6171402931213379, 'learning_rate': 1.1936339522546419e-05, 'epoch': 2.29}


 77%|███████▋  | 1740/2262 [1:43:25<31:33,  3.63s/it]

{'loss': 0.0238, 'grad_norm': 0.5536542534828186, 'learning_rate': 1.171529619805482e-05, 'epoch': 2.31}


 77%|███████▋  | 1750/2262 [1:44:01<30:59,  3.63s/it]

{'loss': 0.0223, 'grad_norm': 0.9141562581062317, 'learning_rate': 1.1494252873563218e-05, 'epoch': 2.32}


 78%|███████▊  | 1760/2262 [1:44:38<30:26,  3.64s/it]

{'loss': 0.0216, 'grad_norm': 0.6766693592071533, 'learning_rate': 1.1273209549071619e-05, 'epoch': 2.33}


 78%|███████▊  | 1770/2262 [1:45:15<30:52,  3.76s/it]

{'loss': 0.0247, 'grad_norm': 0.6207599639892578, 'learning_rate': 1.1052166224580018e-05, 'epoch': 2.35}


 79%|███████▊  | 1780/2262 [1:45:52<29:10,  3.63s/it]

{'loss': 0.0246, 'grad_norm': 0.4240770936012268, 'learning_rate': 1.0831122900088418e-05, 'epoch': 2.36}


 79%|███████▉  | 1790/2262 [1:46:28<28:33,  3.63s/it]

{'loss': 0.0239, 'grad_norm': 0.45768219232559204, 'learning_rate': 1.0610079575596817e-05, 'epoch': 2.37}


 80%|███████▉  | 1800/2262 [1:47:05<28:21,  3.68s/it]

{'loss': 0.0212, 'grad_norm': 0.45484381914138794, 'learning_rate': 1.0389036251105218e-05, 'epoch': 2.39}


 80%|████████  | 1810/2262 [1:47:42<28:01,  3.72s/it]

{'loss': 0.0231, 'grad_norm': 0.3923461437225342, 'learning_rate': 1.0167992926613617e-05, 'epoch': 2.4}


 80%|████████  | 1820/2262 [1:48:19<26:47,  3.64s/it]

{'loss': 0.0266, 'grad_norm': 0.5502606630325317, 'learning_rate': 9.946949602122016e-06, 'epoch': 2.41}


 81%|████████  | 1830/2262 [1:48:55<26:03,  3.62s/it]

{'loss': 0.0241, 'grad_norm': 0.4722103774547577, 'learning_rate': 9.725906277630416e-06, 'epoch': 2.43}


 81%|████████▏ | 1840/2262 [1:49:31<25:25,  3.62s/it]

{'loss': 0.0275, 'grad_norm': 1.5219930410385132, 'learning_rate': 9.504862953138815e-06, 'epoch': 2.44}


 82%|████████▏ | 1850/2262 [1:50:08<24:49,  3.62s/it]

{'loss': 0.0316, 'grad_norm': 0.7192596793174744, 'learning_rate': 9.283819628647216e-06, 'epoch': 2.45}


 82%|████████▏ | 1860/2262 [1:50:44<24:12,  3.61s/it]

{'loss': 0.0243, 'grad_norm': 0.7733884453773499, 'learning_rate': 9.062776304155615e-06, 'epoch': 2.47}


 83%|████████▎ | 1870/2262 [1:51:21<24:07,  3.69s/it]

{'loss': 0.0224, 'grad_norm': 0.5705422759056091, 'learning_rate': 8.841732979664015e-06, 'epoch': 2.48}


 83%|████████▎ | 1880/2262 [1:51:58<24:06,  3.79s/it]

{'loss': 0.0227, 'grad_norm': 0.37063199281692505, 'learning_rate': 8.620689655172414e-06, 'epoch': 2.49}


 84%|████████▎ | 1890/2262 [1:52:35<22:39,  3.66s/it]

{'loss': 0.0246, 'grad_norm': 0.5365554094314575, 'learning_rate': 8.399646330680813e-06, 'epoch': 2.51}


 84%|████████▍ | 1900/2262 [1:53:11<21:57,  3.64s/it]

{'loss': 0.0214, 'grad_norm': 0.6299540400505066, 'learning_rate': 8.178603006189214e-06, 'epoch': 2.52}


 84%|████████▍ | 1910/2262 [1:53:48<21:25,  3.65s/it]

{'loss': 0.0242, 'grad_norm': 0.4092448949813843, 'learning_rate': 7.957559681697613e-06, 'epoch': 2.53}


 85%|████████▍ | 1920/2262 [1:54:24<20:40,  3.63s/it]

{'loss': 0.034, 'grad_norm': 0.5499797463417053, 'learning_rate': 7.736516357206013e-06, 'epoch': 2.55}


 85%|████████▌ | 1930/2262 [1:55:01<20:02,  3.62s/it]

{'loss': 0.0233, 'grad_norm': 0.7217724919319153, 'learning_rate': 7.515473032714412e-06, 'epoch': 2.56}


 86%|████████▌ | 1940/2262 [1:55:37<19:28,  3.63s/it]

{'loss': 0.0234, 'grad_norm': 0.5868905782699585, 'learning_rate': 7.294429708222812e-06, 'epoch': 2.57}


 86%|████████▌ | 1950/2262 [1:56:17<19:43,  3.79s/it]

{'loss': 0.0246, 'grad_norm': 0.4296846091747284, 'learning_rate': 7.073386383731212e-06, 'epoch': 2.59}


 87%|████████▋ | 1960/2262 [1:56:54<18:14,  3.62s/it]

{'loss': 0.0232, 'grad_norm': 0.7558116912841797, 'learning_rate': 6.852343059239612e-06, 'epoch': 2.6}


 87%|████████▋ | 1970/2262 [1:57:30<17:38,  3.62s/it]

{'loss': 0.0214, 'grad_norm': 0.33739110827445984, 'learning_rate': 6.631299734748012e-06, 'epoch': 2.61}


 88%|████████▊ | 1980/2262 [1:58:07<17:37,  3.75s/it]

{'loss': 0.0267, 'grad_norm': 0.4066608250141144, 'learning_rate': 6.41025641025641e-06, 'epoch': 2.63}


 88%|████████▊ | 1990/2262 [1:58:44<16:39,  3.67s/it]

{'loss': 0.0228, 'grad_norm': 0.5428808927536011, 'learning_rate': 6.1892130857648105e-06, 'epoch': 2.64}


 88%|████████▊ | 2000/2262 [1:59:21<15:52,  3.64s/it]

{'loss': 0.0259, 'grad_norm': 0.6457363963127136, 'learning_rate': 5.968169761273209e-06, 'epoch': 2.65}


 89%|████████▉ | 2010/2262 [1:59:57<15:11,  3.62s/it]

{'loss': 0.0221, 'grad_norm': 0.8092381358146667, 'learning_rate': 5.747126436781609e-06, 'epoch': 2.67}


 89%|████████▉ | 2020/2262 [2:00:34<14:38,  3.63s/it]

{'loss': 0.0228, 'grad_norm': 0.5548058748245239, 'learning_rate': 5.526083112290009e-06, 'epoch': 2.68}


 90%|████████▉ | 2030/2262 [2:01:10<13:59,  3.62s/it]

{'loss': 0.0228, 'grad_norm': 0.4031423032283783, 'learning_rate': 5.305039787798409e-06, 'epoch': 2.69}


 90%|█████████ | 2040/2262 [2:01:46<13:24,  3.63s/it]

{'loss': 0.0368, 'grad_norm': 0.5534734129905701, 'learning_rate': 5.083996463306808e-06, 'epoch': 2.71}


 91%|█████████ | 2050/2262 [2:02:23<12:49,  3.63s/it]

{'loss': 0.0251, 'grad_norm': 0.5048367381095886, 'learning_rate': 4.862953138815208e-06, 'epoch': 2.72}


 91%|█████████ | 2060/2262 [2:02:59<12:10,  3.61s/it]

{'loss': 0.0208, 'grad_norm': 0.25707119703292847, 'learning_rate': 4.641909814323608e-06, 'epoch': 2.73}


 92%|█████████▏| 2070/2262 [2:03:36<11:36,  3.63s/it]

{'loss': 0.0302, 'grad_norm': 0.3438515067100525, 'learning_rate': 4.420866489832008e-06, 'epoch': 2.75}


 92%|█████████▏| 2080/2262 [2:04:12<10:57,  3.61s/it]

{'loss': 0.0231, 'grad_norm': 0.4979548752307892, 'learning_rate': 4.1998231653404065e-06, 'epoch': 2.76}


 92%|█████████▏| 2090/2262 [2:04:48<10:22,  3.62s/it]

{'loss': 0.0223, 'grad_norm': 0.7897560596466064, 'learning_rate': 3.978779840848806e-06, 'epoch': 2.77}


 93%|█████████▎| 2100/2262 [2:05:25<10:17,  3.81s/it]

{'loss': 0.0314, 'grad_norm': 0.5497084259986877, 'learning_rate': 3.757736516357206e-06, 'epoch': 2.79}


 93%|█████████▎| 2110/2262 [2:06:10<11:34,  4.57s/it]

{'loss': 0.0212, 'grad_norm': 0.6572733521461487, 'learning_rate': 3.536693191865606e-06, 'epoch': 2.8}


 94%|█████████▎| 2120/2262 [2:06:57<11:16,  4.76s/it]

{'loss': 0.0215, 'grad_norm': 0.4980076849460602, 'learning_rate': 3.315649867374006e-06, 'epoch': 2.81}


 94%|█████████▍| 2130/2262 [2:07:43<10:04,  4.58s/it]

{'loss': 0.0222, 'grad_norm': 0.42545634508132935, 'learning_rate': 3.0946065428824053e-06, 'epoch': 2.82}


 95%|█████████▍| 2140/2262 [2:08:32<10:04,  4.95s/it]

{'loss': 0.0238, 'grad_norm': 0.4769528806209564, 'learning_rate': 2.8735632183908046e-06, 'epoch': 2.84}


 95%|█████████▌| 2150/2262 [2:09:23<09:07,  4.89s/it]

{'loss': 0.0244, 'grad_norm': 0.44186508655548096, 'learning_rate': 2.6525198938992043e-06, 'epoch': 2.85}


 95%|█████████▌| 2160/2262 [2:10:09<07:39,  4.50s/it]

{'loss': 0.0237, 'grad_norm': 0.4451349377632141, 'learning_rate': 2.431476569407604e-06, 'epoch': 2.86}


 96%|█████████▌| 2170/2262 [2:10:53<06:45,  4.41s/it]

{'loss': 0.0289, 'grad_norm': 0.9875059127807617, 'learning_rate': 2.210433244916004e-06, 'epoch': 2.88}


 96%|█████████▋| 2180/2262 [2:11:37<06:02,  4.42s/it]

{'loss': 0.0215, 'grad_norm': 0.8981293439865112, 'learning_rate': 1.989389920424403e-06, 'epoch': 2.89}


 97%|█████████▋| 2190/2262 [2:12:21<05:15,  4.38s/it]

{'loss': 0.0193, 'grad_norm': 0.5024239420890808, 'learning_rate': 1.768346595932803e-06, 'epoch': 2.9}


 97%|█████████▋| 2200/2262 [2:13:05<04:30,  4.37s/it]

{'loss': 0.021, 'grad_norm': 0.8217297792434692, 'learning_rate': 1.5473032714412026e-06, 'epoch': 2.92}


 98%|█████████▊| 2210/2262 [2:13:49<03:49,  4.41s/it]

{'loss': 0.0205, 'grad_norm': 0.5031179785728455, 'learning_rate': 1.3262599469496022e-06, 'epoch': 2.93}


 98%|█████████▊| 2220/2262 [2:14:33<03:03,  4.37s/it]

{'loss': 0.0433, 'grad_norm': 0.6864629983901978, 'learning_rate': 1.105216622458002e-06, 'epoch': 2.94}


 99%|█████████▊| 2230/2262 [2:15:17<02:22,  4.44s/it]

{'loss': 0.0191, 'grad_norm': 0.302197128534317, 'learning_rate': 8.841732979664015e-07, 'epoch': 2.96}


 99%|█████████▉| 2240/2262 [2:16:01<01:35,  4.34s/it]

{'loss': 0.0335, 'grad_norm': 1.1364024877548218, 'learning_rate': 6.631299734748011e-07, 'epoch': 2.97}


 99%|█████████▉| 2250/2262 [2:16:45<00:52,  4.34s/it]

{'loss': 0.0244, 'grad_norm': 0.4399726390838623, 'learning_rate': 4.4208664898320077e-07, 'epoch': 2.98}


100%|█████████▉| 2260/2262 [2:17:29<00:08,  4.49s/it]

{'loss': 0.0222, 'grad_norm': 0.4678157866001129, 'learning_rate': 2.2104332449160039e-07, 'epoch': 3.0}


100%|██████████| 2262/2262 [2:17:38<00:00,  3.65s/it]


{'train_runtime': 8258.6341, 'train_samples_per_second': 0.548, 'train_steps_per_second': 0.274, 'train_loss': 0.09346748259705742, 'epoch': 3.0}


In [13]:
model2 = AutoModelForCausalLM.from_pretrained(output_dir)

In [17]:
# Fine-tuned model

classifier = pipeline(
    "text-generation", 
    model=model2,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    max_length=100,
    device=device
)

response = classifier("What do people in Canada eat?")
response

[{'generated_text': 'What do people in Canada eat? The most eaten foods in Canada: Spaghetti and meatballs'}]

In [18]:
# Move model to endpoint
finished_model = AutoModelForCausalLM.from_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(output_dir)

finished_model.save_pretrained(enpoint_url)
tokenizer.save_pretrained(enpoint_url)

('../apis/model/gpt-v1\\tokenizer_config.json',
 '../apis/model/gpt-v1\\special_tokens_map.json',
 '../apis/model/gpt-v1\\vocab.json',
 '../apis/model/gpt-v1\\merges.txt',
 '../apis/model/gpt-v1\\added_tokens.json',
 '../apis/model/gpt-v1\\tokenizer.json')