In [1]:
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, TrainingArguments, Trainer, EarlyStoppingCallback, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv
import os
import wandb

In [2]:
WB_KEY = os.getenv("WB_KEY")
wandb.login(key=WB_KEY)
run = wandb.init(project="Digital Self-Replica", job_type="Training", name="train_test2")

wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\Francesco\_netrc
wandb: Currently logged in as: francescobrigante (francescobrigante_s_projects) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


In [3]:
#pip install -U bitsandbytes

In [4]:
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if (HF_TOKEN == None):
    raise ValueError("HF_TOKEN is not set")
login(token=HF_TOKEN)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [5]:
# 4 bit quantization
# could be further increased to 8b for more precision
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# LoRA configuration for Qwen model architecture
lora_config = LoraConfig(
    r=32,                       #rank of the added low-rank matrices
    lora_alpha=64,              #generally 2*r
    target_modules=[            #modules where LoRA is applied
        "q_proj",               # query, key, value, output projection layers in the self-attention
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",            # gate, up, down are part of the FFNN in the model
        "up_proj",
        "down_proj"
    ],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

In [6]:
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quantization_config
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# # preparing model for LoRA
# model = prepare_model_for_kbit_training(model)
# model = get_peft_model(model, lora_config)


# # training arguments
# training_args = TrainingArguments(
#     output_dir="./francesco_lora",
#     num_train_epochs=3,
#     per_device_train_batch_size=6,
#     gradient_accumulation_steps=4,      # effective batch size = per_device_train_batch_size * gradinet_accumulation_steps
#     #per_device_eval_batch_size=4,
#     #eval_accumulation_steps=6,
#     # warmup_steps=5,
#     warmup_ratio=0.03,
#     learning_rate=3e-4,                # Slightly lower for distilled model
#     optim="paged_adamw_8bit",         # 8bit optimizer <- ADDED
#     #optim="adamw_torch",
#     lr_scheduler_type="cosine",       # cosine learning rate scheduler <- ADDED
#     fp16=True,
#     logging_steps=5,
#     eval_strategy="no",
#     #eval_strategy="steps",
#     #eval_steps=5,
#     save_steps=50,
#     save_strategy="steps",
#     #load_best_model_at_end=True,
#     #save_total_limit=1,
#     metric_for_best_model="loss",
#     greater_is_better=False,            #lower loss is better
#     gradient_checkpointing=False,
#     disable_tqdm=False,
#     report_to=["wandb"],                                # Enable W&B logging 
#     label_names=["labels"]  # Explicitly specify label field
# )

In [8]:
# preparing model for LoRA
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)


# TRY  AND BATCH 8x8

# TRY ALSO THIS:

    # eval_strategy="steps",
    # eval_steps=100,
    # per_device_eval_batch_size=16,
    # eval_accumulation_steps=2,
    # compute_metrics=None,  # Loss-only evaluation
    # load_best_model_at_end=True,

# training arguments
training_args = TrainingArguments(
    output_dir="./francesco_lora",
    num_train_epochs=3,
    per_device_train_batch_size=6,
    gradient_accumulation_steps=4,      # effective batch size = per_device_train_batch_size * gradinet_accumulation_steps
    #per_device_eval_batch_size=4,
    #eval_accumulation_steps=6,
    # warmup_steps=5,
    warmup_ratio=0.1,
    learning_rate=2e-4,                # Slightly lower for distilled model
    optim="paged_adamw_8bit",         # 8bit optimizer <- ADDED
    #optim="adamw_torch",
    lr_scheduler_type="cosine",       # cosine learning rate scheduler <- ADDED
    weight_decay=0.01,
    fp16=True,
    logging_steps=5,
    #eval_strategy="no",
    eval_strategy="steps",
    per_device_eval_batch_size=16,
    eval_accumulation_steps=2,
    eval_steps=5,
    save_steps=50,
    save_strategy="steps",
    load_best_model_at_end=True,
    #save_total_limit=1,
    metric_for_best_model="loss",
    greater_is_better=False,            #lower loss is better
    gradient_checkpointing=False,
    max_grad_norm=0.5,
    disable_tqdm=False,
    report_to=["wandb"],                                # Enable W&B logging
    label_names=["labels"]  # Explicitly specify label field
)

In [9]:
from transformers import DataCollatorWithPadding, DataCollatorForLanguageModeling, DefaultDataCollator

tokenizer = AutoTokenizer.from_pretrained(model_id)

data_collator = DefaultDataCollator()

#Dynamic padding for causal LM
# data_collator = DataCollatorForLanguageModeling(
#    tokenizer=tokenizer,
#    mlm=False
# )

#Dynamic padding more general purpose
# it doesnt handle shifting lables: you have to implement it manually

# data_collator = DataCollatorWithPadding(
#     tokenizer=tokenizer,
#     padding=True,            # pad to longest in batch
#     return_tensors="pt",
# )


In [10]:
# loading datasets
tokenized_train = load_from_disk('datasets/tokenized_train')
tokenized_val = load_from_disk('datasets/tokenized_val')
#tokenized_test = load_from_disk('datasets/tokenized_test')

# print
print(f"Training examples: {len(tokenized_train)}")
print(f"Validation examples: {len(tokenized_val)}")
#print(f"Test examples: {len(tokenized_test)}")

print("\nOne training example:")
print(tokenized_train[1000])

Training examples: 8720
Validation examples: 1090

One training example:
{'input_ids': [151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 

In [11]:
#IF ON COLAB
# from google.colab import drive

# drive.mount('/content/drive')
# drive_base_path = '/content/drive/My Drive/datasets'

# tokenized_train = load_from_disk(os.path.join(drive_base_path, 'tokenized_train'))
# tokenized_val = load_from_disk(os.path.join(drive_base_path, 'tokenized_val'))
# tokenized_test = load_from_disk(os.path.join(drive_base_path, 'tokenized_test'))

# print("Datasets loaded successfully from Google Drive!")
# print(f"Training examples: {len(tokenized_train)}")
# print(f"Validation examples: {len(tokenized_val)}")
# print(f"Test examples: {len(tokenized_test)}")

# print("\nOne training example:")
# print(tokenized_train[8000])

In [12]:
# print trainable parameters
model.print_trainable_parameters()

# training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
)

# add early stopping
early_stopping = EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)
trainer.add_callback(early_stopping)

trainable params: 80,740,352 || all params: 7,696,356,864 || trainable%: 1.0491


In [13]:
trainer.train()



  0%|          | 0/1089 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
  return fn(*args, **kwargs)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 8.532, 'grad_norm': 13.746661186218262, 'learning_rate': 9.174311926605506e-06, 'epoch': 0.01}


  0%|          | 0/69 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
#trainer.train(resume_from_checkpoint="./francesco_lora/checkpoint-50")

	save_steps: 50 (from args) != 25 (from trainer_state.json)


  0%|          | 0/1089 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
  return fn(*args, **kwargs)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.9748, 'grad_norm': 2.294355630874634, 'learning_rate': 4.545454545454545e-05, 'epoch': 0.08}
{'loss': 2.6041, 'grad_norm': 2.8884832859039307, 'learning_rate': 9.09090909090909e-05, 'epoch': 0.1}
{'loss': 2.4398, 'grad_norm': 3.4224636554718018, 'learning_rate': 0.00013636363636363634, 'epoch': 0.11}
{'loss': 2.9799, 'grad_norm': 4.153173923492432, 'learning_rate': 0.0001818181818181818, 'epoch': 0.12}
{'loss': 3.3979, 'grad_norm': 5.842740058898926, 'learning_rate': 0.00022727272727272725, 'epoch': 0.14}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.0973, 'grad_norm': 4.533180236816406, 'learning_rate': 0.0002727272727272727, 'epoch': 0.15}
{'loss': 3.6818, 'grad_norm': 3.4354896545410156, 'learning_rate': 0.00029999734483275115, 'epoch': 0.17}
{'loss': 4.0499, 'grad_norm': 5.2512383460998535, 'learning_rate': 0.0002999674752807096, 'epoch': 0.18}
{'loss': 3.8722, 'grad_norm': 4.789249897003174, 'learning_rate': 0.00029990442384854874, 'epoch': 0.19}
{'loss': 3.8372, 'grad_norm': 3.21344256401062, 'learning_rate': 0.0002998082044870607, 'epoch': 0.21}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.4786, 'grad_norm': 3.6903276443481445, 'learning_rate': 0.0002996788384857905, 'epoch': 0.22}
{'loss': 3.7358, 'grad_norm': 3.7469398975372314, 'learning_rate': 0.0002995163544683256, 'epoch': 0.23}
{'loss': 3.6151, 'grad_norm': 2.9869725704193115, 'learning_rate': 0.0002993207883859627, 'epoch': 0.25}
{'loss': 3.7306, 'grad_norm': 3.024510622024536, 'learning_rate': 0.00029909218350975285, 'epoch': 0.26}
{'loss': 3.5931, 'grad_norm': 2.9555509090423584, 'learning_rate': 0.00029883059042092774, 'epoch': 0.28}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.6485, 'grad_norm': 3.150639057159424, 'learning_rate': 0.00029853606699970766, 'epoch': 0.29}
{'loss': 3.567, 'grad_norm': 2.8044278621673584, 'learning_rate': 0.0002982086784124952, 'epoch': 0.3}
{'loss': 3.5932, 'grad_norm': 3.487525701522827, 'learning_rate': 0.00029784849709745616, 'epoch': 0.32}
{'loss': 3.5303, 'grad_norm': 1.7299824953079224, 'learning_rate': 0.00029745560274849214, 'epoch': 0.33}
{'loss': 3.8586, 'grad_norm': 3.0844006538391113, 'learning_rate': 0.00029703008229760736, 'epoch': 0.34}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.5269, 'grad_norm': 3.211728572845459, 'learning_rate': 0.00029657202989567393, 'epoch': 0.36}
{'loss': 3.3147, 'grad_norm': 2.5923051834106445, 'learning_rate': 0.0002960815468916, 'epoch': 0.37}
{'loss': 3.5988, 'grad_norm': 2.683166027069092, 'learning_rate': 0.0002955587418099055, 'epoch': 0.39}
{'loss': 3.8877, 'grad_norm': 2.918941020965576, 'learning_rate': 0.0002950037303267096, 'epoch': 0.4}
{'loss': 3.3882, 'grad_norm': 3.6034300327301025, 'learning_rate': 0.0002944166352441363, 'epoch': 0.41}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.764, 'grad_norm': 2.5907528400421143, 'learning_rate': 0.00029379758646314323, 'epoch': 0.43}
{'loss': 3.4164, 'grad_norm': 3.426248788833618, 'learning_rate': 0.00029314672095477953, 'epoch': 0.44}
{'loss': 3.6221, 'grad_norm': 4.063133239746094, 'learning_rate': 0.00029246418272987993, 'epoch': 0.45}
{'loss': 3.1293, 'grad_norm': 2.362250328063965, 'learning_rate': 0.00029175012280720024, 'epoch': 0.47}
{'loss': 3.8398, 'grad_norm': 3.417720079421997, 'learning_rate': 0.0002910046991800035, 'epoch': 0.48}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.5184, 'grad_norm': 3.3896896839141846, 'learning_rate': 0.0002902280767811019, 'epoch': 0.5}
{'loss': 3.4015, 'grad_norm': 2.8821897506713867, 'learning_rate': 0.0002894204274463637, 'epoch': 0.51}
{'loss': 3.2702, 'grad_norm': 2.7369213104248047, 'learning_rate': 0.000288581929876693, 'epoch': 0.52}
{'loss': 3.3546, 'grad_norm': 3.7003273963928223, 'learning_rate': 0.00028771276959848994, 'epoch': 0.54}
{'loss': 3.6029, 'grad_norm': 2.752347469329834, 'learning_rate': 0.0002868131389226013, 'epoch': 0.55}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.3021, 'grad_norm': 2.8787314891815186, 'learning_rate': 0.00028588323690176954, 'epoch': 0.56}
{'loss': 3.5473, 'grad_norm': 3.1614179611206055, 'learning_rate': 0.00028492326928659045, 'epoch': 0.58}
{'loss': 3.5059, 'grad_norm': 2.529334306716919, 'learning_rate': 0.00028393344847998844, 'epoch': 0.59}
{'loss': 3.3777, 'grad_norm': 2.9682776927948, 'learning_rate': 0.00028291399349022036, 'epoch': 0.61}
{'loss': 3.2118, 'grad_norm': 2.576864004135132, 'learning_rate': 0.00028186512988241755, 'epoch': 0.62}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.4563, 'grad_norm': 2.8691444396972656, 'learning_rate': 0.0002807870897286772, 'epoch': 0.63}
{'loss': 3.4198, 'grad_norm': 2.484403133392334, 'learning_rate': 0.0002796801115567139, 'epoch': 0.65}
{'loss': 3.5658, 'grad_norm': 2.5332326889038086, 'learning_rate': 0.0002785444402970829, 'epoch': 0.66}
{'loss': 3.4539, 'grad_norm': 2.954714775085449, 'learning_rate': 0.00027738032722898683, 'epoch': 0.67}
{'loss': 3.419, 'grad_norm': 3.0758652687072754, 'learning_rate': 0.0002761880299246772, 'epoch': 0.69}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.4333, 'grad_norm': 2.5053317546844482, 'learning_rate': 0.000274967812192464, 'epoch': 0.7}
{'loss': 3.6111, 'grad_norm': 2.93147349357605, 'learning_rate': 0.00027371994401834555, 'epoch': 0.72}
{'loss': 3.5379, 'grad_norm': 2.1655354499816895, 'learning_rate': 0.0002724447015062708, 'epoch': 0.73}
{'loss': 3.4123, 'grad_norm': 2.559091329574585, 'learning_rate': 0.000271142366817049, 'epoch': 0.74}
{'loss': 3.5006, 'grad_norm': 3.0507078170776367, 'learning_rate': 0.00026981322810591793, 'epoch': 0.76}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.4624, 'grad_norm': 3.4602880477905273, 'learning_rate': 0.00026845757945878737, 'epoch': 0.77}
{'loss': 3.5757, 'grad_norm': 2.1083385944366455, 'learning_rate': 0.0002670757208271687, 'epoch': 0.78}
{'loss': 3.4725, 'grad_norm': 2.472771406173706, 'learning_rate': 0.0002656679579618081, 'epoch': 0.8}
{'loss': 3.4252, 'grad_norm': 2.5242655277252197, 'learning_rate': 0.0002642346023450357, 'epoch': 0.81}
{'loss': 3.3468, 'grad_norm': 2.2211289405822754, 'learning_rate': 0.0002627759711218466, 'epoch': 0.83}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.4813, 'grad_norm': 2.8587441444396973, 'learning_rate': 0.00026129238702972987, 'epoch': 0.84}
{'loss': 3.2853, 'grad_norm': 2.517057180404663, 'learning_rate': 0.0002597841783272588, 'epoch': 0.85}
{'loss': 3.3555, 'grad_norm': 2.2894883155822754, 'learning_rate': 0.0002582516787214607, 'epoch': 0.87}
{'loss': 3.283, 'grad_norm': 2.8041491508483887, 'learning_rate': 0.0002566952272939805, 'epoch': 0.88}
{'loss': 3.1511, 'grad_norm': 2.561751365661621, 'learning_rate': 0.0002551151684260553, 'epoch': 0.89}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.4883, 'grad_norm': 2.422520399093628, 'learning_rate': 0.0002535118517223168, 'epoch': 0.91}
{'loss': 3.2834, 'grad_norm': 2.6135976314544678, 'learning_rate': 0.000251885631933437, 'epoch': 0.92}
{'loss': 3.2794, 'grad_norm': 2.83675217628479, 'learning_rate': 0.00025023686887763643, 'epoch': 0.94}
{'loss': 3.4521, 'grad_norm': 2.7402193546295166, 'learning_rate': 0.0002485659273610703, 'epoch': 0.95}
{'loss': 3.649, 'grad_norm': 1.6840540170669556, 'learning_rate': 0.0002468731770971113, 'epoch': 0.96}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.1643, 'grad_norm': 3.0399107933044434, 'learning_rate': 0.0002451589926245468, 'epoch': 0.98}
{'loss': 3.2421, 'grad_norm': 2.6941335201263428, 'learning_rate': 0.00024342375322470807, 'epoch': 0.99}
{'loss': 3.5644, 'grad_norm': 2.5662319660186768, 'learning_rate': 0.00024166784283755034, 'epoch': 1.0}
{'loss': 2.6523, 'grad_norm': 2.6218206882476807, 'learning_rate': 0.00023989164997670202, 'epoch': 1.02}
{'loss': 2.602, 'grad_norm': 2.861525535583496, 'learning_rate': 0.00023809556764350204, 'epoch': 1.03}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.8857, 'grad_norm': 2.3202781677246094, 'learning_rate': 0.000236279993240044, 'epoch': 1.05}
{'loss': 2.279, 'grad_norm': 2.7474334239959717, 'learning_rate': 0.00023444532848124715, 'epoch': 1.06}
{'loss': 2.5033, 'grad_norm': 2.9517195224761963, 'learning_rate': 0.0002325919793059723, 'epoch': 1.07}
{'loss': 2.7729, 'grad_norm': 2.9528772830963135, 'learning_rate': 0.00023072035578720388, 'epoch': 1.09}
{'loss': 2.4171, 'grad_norm': 1.7767149209976196, 'learning_rate': 0.0002288308720413169, 'epoch': 1.1}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.4914, 'grad_norm': 2.462202310562134, 'learning_rate': 0.00022692394613644932, 'epoch': 1.11}
{'loss': 2.5779, 'grad_norm': 2.972353935241699, 'learning_rate': 0.000225, 'epoch': 1.13}
{'loss': 2.8055, 'grad_norm': 3.413198709487915, 'learning_rate': 0.00022305945932527308, 'epoch': 1.14}
{'loss': 2.7694, 'grad_norm': 2.9278130531311035, 'learning_rate': 0.00022110275347728858, 'epoch': 1.16}
{'loss': 2.7265, 'grad_norm': 3.074336051940918, 'learning_rate': 0.00021913031539778116, 'epoch': 1.17}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.589, 'grad_norm': 3.1302847862243652, 'learning_rate': 0.00021714258150940685, 'epoch': 1.18}
{'loss': 2.7469, 'grad_norm': 2.52520489692688, 'learning_rate': 0.0002151399916191804, 'epoch': 1.2}
{'loss': 2.7441, 'grad_norm': 3.4955756664276123, 'learning_rate': 0.00021312298882116286, 'epoch': 1.21}
{'loss': 2.8741, 'grad_norm': 3.422661304473877, 'learning_rate': 0.0002110920193984228, 'epoch': 1.22}
{'loss': 2.616, 'grad_norm': 3.250969409942627, 'learning_rate': 0.0002090475327242912, 'epoch': 1.24}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.6407, 'grad_norm': 2.9759998321533203, 'learning_rate': 0.00020698998116293302, 'epoch': 1.25}
{'loss': 2.8436, 'grad_norm': 2.732311964035034, 'learning_rate': 0.0002049198199692569, 'epoch': 1.27}
{'loss': 2.763, 'grad_norm': 2.571718692779541, 'learning_rate': 0.00020283750718818501, 'epoch': 1.28}
{'loss': 2.5906, 'grad_norm': 3.319793939590454, 'learning_rate': 0.0002007435035533061, 'epoch': 1.29}
{'loss': 2.7229, 'grad_norm': 2.5533857345581055, 'learning_rate': 0.00019863827238493308, 'epoch': 1.31}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.6287, 'grad_norm': 3.140554904937744, 'learning_rate': 0.00019652227948758878, 'epoch': 1.32}
{'loss': 2.694, 'grad_norm': 2.986036539077759, 'learning_rate': 0.00019439599304694154, 'epoch': 1.33}
{'loss': 2.7893, 'grad_norm': 3.369769811630249, 'learning_rate': 0.00019225988352621445, 'epoch': 1.35}
{'loss': 2.488, 'grad_norm': 2.9034252166748047, 'learning_rate': 0.00019011442356209023, 'epoch': 1.36}
{'loss': 2.5245, 'grad_norm': 2.8083109855651855, 'learning_rate': 0.0001879600878601355, 'epoch': 1.38}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.8457, 'grad_norm': 3.626420736312866, 'learning_rate': 0.00018579735308976727, 'epoch': 1.39}
{'loss': 2.4842, 'grad_norm': 2.571101427078247, 'learning_rate': 0.00018362669777878453, 'epoch': 1.4}
{'loss': 2.2869, 'grad_norm': 2.690263271331787, 'learning_rate': 0.00018144860220748932, 'epoch': 1.42}
{'loss': 2.7093, 'grad_norm': 2.85054349899292, 'learning_rate': 0.00017926354830241924, 'epoch': 1.43}
{'loss': 2.6496, 'grad_norm': 2.6131203174591064, 'learning_rate': 0.0001770720195297166, 'epoch': 1.44}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.8099, 'grad_norm': 2.9764621257781982, 'learning_rate': 0.0001748745007881561, 'epoch': 1.46}
{'loss': 2.6152, 'grad_norm': 3.266005277633667, 'learning_rate': 0.00017267147830185608, 'epoch': 1.47}
{'loss': 2.4533, 'grad_norm': 2.7621657848358154, 'learning_rate': 0.00017046343951269621, 'epoch': 1.49}
{'loss': 2.8339, 'grad_norm': 3.6515700817108154, 'learning_rate': 0.00016825087297246582, 'epoch': 1.5}
{'loss': 2.5948, 'grad_norm': 2.8908591270446777, 'learning_rate': 0.00016603426823476693, 'epoch': 1.51}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.8234, 'grad_norm': 2.528038263320923, 'learning_rate': 0.000163814115746695, 'epoch': 1.53}
{'loss': 2.638, 'grad_norm': 3.00260853767395, 'learning_rate': 0.00016159090674032267, 'epoch': 1.54}
{'loss': 2.3241, 'grad_norm': 2.7044365406036377, 'learning_rate': 0.00015936513312400936, 'epoch': 1.55}
{'loss': 2.5686, 'grad_norm': 2.942772150039673, 'learning_rate': 0.00015713728737356137, 'epoch': 1.57}
{'loss': 2.6249, 'grad_norm': 3.6798386573791504, 'learning_rate': 0.00015490786242326643, 'epoch': 1.58}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.6257, 'grad_norm': 2.6975250244140625, 'learning_rate': 0.00015267735155682688, 'epoch': 1.6}
{'loss': 2.4316, 'grad_norm': 2.7367029190063477, 'learning_rate': 0.0001504462482982155, 'epoch': 1.61}
{'loss': 2.6303, 'grad_norm': 3.0093727111816406, 'learning_rate': 0.00014821504630247785, 'epoch': 1.62}
{'loss': 2.6357, 'grad_norm': 2.9841971397399902, 'learning_rate': 0.0001459842392465063, 'epoch': 1.64}
{'loss': 2.5635, 'grad_norm': 3.042201042175293, 'learning_rate': 0.0001437543207198086, 'epoch': 1.65}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.6592, 'grad_norm': 3.206308364868164, 'learning_rate': 0.0001415257841152961, 'epoch': 1.66}
{'loss': 2.5825, 'grad_norm': 2.834376335144043, 'learning_rate': 0.00013929912252011516, 'epoch': 1.68}
{'loss': 2.7186, 'grad_norm': 2.9710707664489746, 'learning_rate': 0.0001370748286065468, 'epoch': 1.69}
{'loss': 2.4292, 'grad_norm': 2.4457507133483887, 'learning_rate': 0.00013485339452299754, 'epoch': 1.71}
{'loss': 2.8385, 'grad_norm': 2.866943359375, 'learning_rate': 0.00013263531178510647, 'epoch': 1.72}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.5841, 'grad_norm': 3.4312779903411865, 'learning_rate': 0.00013042107116699228, 'epoch': 1.73}
{'loss': 2.527, 'grad_norm': 3.25289249420166, 'learning_rate': 0.0001282111625926641, 'epoch': 1.75}
{'loss': 2.4421, 'grad_norm': 2.446511745452881, 'learning_rate': 0.00012600607502762096, 'epoch': 1.76}
{'loss': 2.5528, 'grad_norm': 2.9619667530059814, 'learning_rate': 0.00012380629637066297, 'epoch': 1.77}
{'loss': 2.8119, 'grad_norm': 3.2970638275146484, 'learning_rate': 0.00012161231334593851, 'epoch': 1.79}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 2.5309, 'grad_norm': 2.570861339569092, 'learning_rate': 0.00011942461139525123, 'epoch': 1.8}
{'loss': 2.6826, 'grad_norm': 2.899683952331543, 'learning_rate': 0.00011724367457065135, 'epoch': 1.82}
{'loss': 2.5655, 'grad_norm': 2.263598680496216, 'learning_rate': 0.00011506998542733373, 'epoch': 1.83}
{'loss': 2.6892, 'grad_norm': 2.794584274291992, 'learning_rate': 0.00011290402491686766, 'epoch': 1.84}
{'loss': 2.6247, 'grad_norm': 3.17472767829895, 'learning_rate': 0.0001107462722807811, 'epoch': 1.86}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
