### References

[video](https://www.youtube.com/watch?v=qU7wO02urYU)

In [1]:
from datasets import load_dataset
import torch
from transformers import ViTImageProcessor, TrainingArguments, ViTForImageClassification, Trainer
import os

# Hub identifier of the pretrained Vision Transformer checkpoint to fine-tune.
model_id = 'google/vit-base-patch16-224-in21k'

# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

device

device(type='cuda')

In [2]:
# Load the training split with the generic "imagefolder" builder.
# NOTE(review): the second positional argument of `load_dataset` binds to the
# builder config `name`, not to `data_dir`; it is spelled out as a keyword here
# so that is visible. Presumably the image folders are being discovered
# relative to the working directory -- confirm this is intended.
# Add streaming=True for lazy loading.
dataset_train = load_dataset(
    "imagefolder",
    name="PlantVillage",
    split="train",
)

# Show the split summary together with its feature schema.
dataset_train, dataset_train.features

Resolving data files:   0%|          | 0/20402 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/576 [00:00<?, ?it/s]

(Dataset({
     features: ['image', 'label'],
     num_rows: 20401
 }),
 {'image': Image(mode=None, decode=True, id=None),
  'label': ClassLabel(names=['00 Chilli - Healthy', '01 Chilli - Leaf Curl Virus', '02 Pepper Bell - Bacterial Spot', '03 Pepper Bell - Healthy', '04 Potato - Early Blight', '05 Potato - Healthy', '06 Potato - Late Blight', '07 Tomato - Bacterial Spot', '08 Tomato - Early Blight', '09 Tomato - Healthy', '10 Tomato - Late Blight', '11 Tomato - Leaf Mold', '12 Tomato - Mosaic Virus', '13 Tomato - Septoria Leaf Spot', '14 Tomato - Target Spot', '15 Tomato - Two Spotted Spider Mite', '16 Tomato - Yellow Leaf Curl Virus'], id=None)})

In [3]:
# Load the held-out test split with the same "imagefolder" builder.
# NOTE(review): "PlantVillage" binds to the builder config `name` (second
# positional parameter of `load_dataset`), not to `data_dir` -- confirm intent.
dataset_test = load_dataset(
    "imagefolder",
    name="PlantVillage",
    split="test",
)

dataset_test

Resolving data files:   0%|          | 0/20402 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/576 [00:00<?, ?it/s]

Dataset({
    features: ['image', 'label'],
    num_rows: 576
})

In [4]:
# Peek at the first training example to check the raw record layout
# (an 'image' entry and an integer 'label') before any preprocessing is attached.
dataset_train[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=200x200>,
 'label': 0}

In [5]:
# Load the image processor that ships with the pretrained checkpoint so inputs
# are prepared with the checkpoint's own preprocessing configuration.
# `ignore_mismatched_sizes` is a model-loading option and is not part of the
# image processor's configuration, so it is not passed here.
feature_extractor = ViTImageProcessor.from_pretrained(model_id)
feature_extractor

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [6]:
def preprocess(batch):
    """Turn a batch of PIL images into model-ready pixel values.

    Args:
        batch: Mapping with an 'image' list of PIL images and a matching
            'label' list.

    Returns:
        The `feature_extractor` output for the images (requested as PyTorch
        tensors) with the batch's labels attached under the 'label' key.
    """
    rgb_images = []
    for picture in batch['image']:
        # Convert anything that is not already RGB before it reaches the processor.
        if picture.mode != 'RGB':
            picture = picture.convert('RGB')
        rgb_images.append(picture)

    encoded = feature_extractor(rgb_images, return_tensors='pt')
    # Carry the labels along so the collator can pick them up later.
    encoded['label'] = batch['label']
    return encoded

def collate_fn(batch):
    """Assemble individual examples into one batch dictionary.

    Args:
        batch: Sequence of examples, each providing a 'pixel_values' tensor
            and a 'label' value.

    Returns:
        Dict with 'pixel_values' (the per-example tensors stacked along a new
        leading dimension) and 'labels' (a tensor built from the labels).
    """
    pixel_list = []
    label_list = []
    for example in batch:
        pixel_list.append(example['pixel_values'])
        label_list.append(example['label'])

    stacked_pixels = torch.stack(pixel_list)
    label_tensor = torch.tensor(label_list)
    return {'pixel_values': stacked_pixels, 'labels': label_tensor}

In [7]:
# Attach `preprocess` to both splits via `with_transform`, so each split gets
# its own prepared view while the raw datasets stay available.
prepared_train, prepared_test = (
    split.with_transform(preprocess)
    for split in (dataset_train, dataset_test)
)

In [8]:
# Class names in label-index order, taken from the training split's schema.
labels = dataset_train.features['label'].names

# Load the pretrained checkpoint with a classification head sized to the
# dataset; both label maps use the stringified class index.
model = ViTForImageClassification.from_pretrained(
    model_id,
    num_labels=len(labels),
    id2label=dict(zip(map(str, range(len(labels))), labels)),
    label2id=dict(zip(labels, map(str, range(len(labels))))),
    ignore_mismatched_sizes=True,
)
# Move the weights to the selected device; the returned module is displayed.
model.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [9]:
# Hyper-parameters and bookkeeping for the fine-tuning run.
training_args = TrainingArguments(
    # Where checkpoints and the final model are written.
    output_dir="./plantvillage_model",
    push_to_hub=False,
    # Optimisation schedule.
    num_train_epochs=3,
    per_device_train_batch_size=16,
    learning_rate=2e-4,
    # Log every 10 steps; evaluate and checkpoint every 100 steps.
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    # `preprocess` runs lazily through `with_transform` and reads the raw
    # 'image' column, so the dataset columns must be left in place.
    remove_unused_columns=False,
)

# Wire the model, data and collator together.
# NOTE(review): newer transformers releases may prefer a different keyword than
# `tokenizer=` for passing the image processor -- confirm against the installed
# version before changing it.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_train,
    eval_dataset=prepared_test,
    data_collator=collate_fn,
    tokenizer=feature_extractor,
)



In [10]:
# Resume from the most recent checkpoint when one exists, otherwise train from
# scratch. Checkpoints are looked up in `training_args.output_dir`, the
# directory the Trainer was configured with; the previously hard-coded
# "./plantvillage_checkpoints" path did not match it, so a resume could never
# be triggered.
checkpoint_root = training_args.output_dir
checkpoint_path = None

if os.path.isdir(checkpoint_root):
    # Map "checkpoint-<step>" directory names to their numeric step, skipping
    # stray files and anything whose suffix is not a plain integer.
    step_by_name = {}
    for entry in os.listdir(checkpoint_root):
        prefix, _, step = entry.partition("-")
        if (
            prefix == "checkpoint"
            and step.isdecimal()
            and os.path.isdir(os.path.join(checkpoint_root, entry))
        ):
            step_by_name[entry] = int(step)

    if step_by_name:
        latest_checkpoint = max(step_by_name, key=step_by_name.get)
        checkpoint_path = os.path.join(checkpoint_root, latest_checkpoint)
        print(f"Resuming training from checkpoint: {checkpoint_path}")
    else:
        print("No checkpoint found. Starting training from scratch.")
else:
    print("No checkpoint directory found. Starting training from scratch.")

# `resume_from_checkpoint=None` is a plain fresh run, so one call covers every case.
train_results = trainer.train(resume_from_checkpoint=checkpoint_path)

No checkpoint directory found. Starting training from scratch.


  0%|          | 0/3828 [00:00<?, ?it/s]

{'loss': 2.5494, 'grad_norm': 1.5736089944839478, 'learning_rate': 0.00019947753396029258, 'epoch': 0.01}
{'loss': 2.1203, 'grad_norm': 1.9132421016693115, 'learning_rate': 0.00019895506792058518, 'epoch': 0.02}
{'loss': 1.7233, 'grad_norm': 2.2865209579467773, 'learning_rate': 0.00019843260188087775, 'epoch': 0.02}
{'loss': 1.3617, 'grad_norm': 1.8445883989334106, 'learning_rate': 0.00019791013584117032, 'epoch': 0.03}
{'loss': 1.088, 'grad_norm': 1.5586708784103394, 'learning_rate': 0.0001973876698014629, 'epoch': 0.04}
{'loss': 1.0526, 'grad_norm': 2.6481359004974365, 'learning_rate': 0.0001968652037617555, 'epoch': 0.05}
{'loss': 0.7676, 'grad_norm': 3.6279137134552, 'learning_rate': 0.00019634273772204807, 'epoch': 0.05}
{'loss': 0.8775, 'grad_norm': 2.0404765605926514, 'learning_rate': 0.00019582027168234064, 'epoch': 0.06}
{'loss': 0.7132, 'grad_norm': 1.008786678314209, 'learning_rate': 0.00019529780564263324, 'epoch': 0.07}
{'loss': 0.5184, 'grad_norm': 2.7155203819274902, 'le

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.811065673828125, 'eval_runtime': 7.9477, 'eval_samples_per_second': 72.474, 'eval_steps_per_second': 9.059, 'epoch': 0.08}
{'loss': 0.5544, 'grad_norm': 1.9914830923080444, 'learning_rate': 0.0001942528735632184, 'epoch': 0.09}
{'loss': 0.5146, 'grad_norm': 4.156414031982422, 'learning_rate': 0.00019373040752351098, 'epoch': 0.09}
{'loss': 0.5305, 'grad_norm': 2.6715192794799805, 'learning_rate': 0.00019320794148380358, 'epoch': 0.1}
{'loss': 0.389, 'grad_norm': 0.7182912826538086, 'learning_rate': 0.00019268547544409615, 'epoch': 0.11}
{'loss': 0.3797, 'grad_norm': 0.4292014241218567, 'learning_rate': 0.00019216300940438872, 'epoch': 0.12}
{'loss': 0.4057, 'grad_norm': 0.8226805329322815, 'learning_rate': 0.00019164054336468132, 'epoch': 0.13}
{'loss': 0.3866, 'grad_norm': 2.682779312133789, 'learning_rate': 0.0001911180773249739, 'epoch': 0.13}
{'loss': 0.2756, 'grad_norm': 0.31661897897720337, 'learning_rate': 0.00019059561128526647, 'epoch': 0.14}
{'loss': 0.3227, '

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.46295294165611267, 'eval_runtime': 7.8992, 'eval_samples_per_second': 72.919, 'eval_steps_per_second': 9.115, 'epoch': 0.16}
{'loss': 0.3734, 'grad_norm': 3.1243324279785156, 'learning_rate': 0.0001890282131661442, 'epoch': 0.16}
{'loss': 0.2634, 'grad_norm': 0.5445452332496643, 'learning_rate': 0.00018850574712643678, 'epoch': 0.17}
{'loss': 0.3009, 'grad_norm': 1.6611272096633911, 'learning_rate': 0.00018798328108672938, 'epoch': 0.18}
{'loss': 0.4305, 'grad_norm': 6.616272449493408, 'learning_rate': 0.00018746081504702195, 'epoch': 0.19}
{'loss': 0.2442, 'grad_norm': 3.222721576690674, 'learning_rate': 0.00018693834900731452, 'epoch': 0.2}
{'loss': 0.2477, 'grad_norm': 4.5935139656066895, 'learning_rate': 0.00018641588296760712, 'epoch': 0.2}
{'loss': 0.462, 'grad_norm': 4.241666316986084, 'learning_rate': 0.0001858934169278997, 'epoch': 0.21}
{'loss': 0.3282, 'grad_norm': 1.4599313735961914, 'learning_rate': 0.00018537095088819226, 'epoch': 0.22}
{'loss': 0.2413, 'g

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5375230312347412, 'eval_runtime': 7.5824, 'eval_samples_per_second': 75.965, 'eval_steps_per_second': 9.496, 'epoch': 0.24}
{'loss': 0.4024, 'grad_norm': 1.0279532670974731, 'learning_rate': 0.00018380355276907, 'epoch': 0.24}
{'loss': 0.1693, 'grad_norm': 1.11663019657135, 'learning_rate': 0.0001832810867293626, 'epoch': 0.25}
{'loss': 0.248, 'grad_norm': 1.0837794542312622, 'learning_rate': 0.00018275862068965518, 'epoch': 0.26}
{'loss': 0.1655, 'grad_norm': 0.43622225522994995, 'learning_rate': 0.00018223615464994778, 'epoch': 0.27}
{'loss': 0.2659, 'grad_norm': 2.0811972618103027, 'learning_rate': 0.00018171368861024035, 'epoch': 0.27}
{'loss': 0.298, 'grad_norm': 5.959192276000977, 'learning_rate': 0.00018119122257053292, 'epoch': 0.28}
{'loss': 0.2644, 'grad_norm': 0.5086017847061157, 'learning_rate': 0.00018066875653082552, 'epoch': 0.29}
{'loss': 0.1761, 'grad_norm': 3.178424835205078, 'learning_rate': 0.0001801462904911181, 'epoch': 0.3}
{'loss': 0.2087, 'grad_

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.34923282265663147, 'eval_runtime': 7.7777, 'eval_samples_per_second': 74.058, 'eval_steps_per_second': 9.257, 'epoch': 0.31}
{'loss': 0.2391, 'grad_norm': 0.7944479584693909, 'learning_rate': 0.00017857889237199583, 'epoch': 0.32}
{'loss': 0.1789, 'grad_norm': 2.3415513038635254, 'learning_rate': 0.0001780564263322884, 'epoch': 0.33}
{'loss': 0.1718, 'grad_norm': 0.4112450182437897, 'learning_rate': 0.00017753396029258098, 'epoch': 0.34}
{'loss': 0.1929, 'grad_norm': 1.9314395189285278, 'learning_rate': 0.00017701149425287358, 'epoch': 0.34}
{'loss': 0.143, 'grad_norm': 0.36084479093551636, 'learning_rate': 0.00017648902821316615, 'epoch': 0.35}
{'loss': 0.1242, 'grad_norm': 0.39720234274864197, 'learning_rate': 0.00017596656217345872, 'epoch': 0.36}
{'loss': 0.1844, 'grad_norm': 2.7724075317382812, 'learning_rate': 0.00017544409613375132, 'epoch': 0.37}
{'loss': 0.1891, 'grad_norm': 0.13133889436721802, 'learning_rate': 0.0001749216300940439, 'epoch': 0.38}
{'loss': 0.

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.27153152227401733, 'eval_runtime': 7.6511, 'eval_samples_per_second': 75.283, 'eval_steps_per_second': 9.41, 'epoch': 0.39}
{'loss': 0.132, 'grad_norm': 6.046752452850342, 'learning_rate': 0.00017335423197492163, 'epoch': 0.4}
{'loss': 0.2335, 'grad_norm': 3.5202386379241943, 'learning_rate': 0.0001728317659352142, 'epoch': 0.41}
{'loss': 0.1742, 'grad_norm': 1.6746879816055298, 'learning_rate': 0.0001723092998955068, 'epoch': 0.42}
{'loss': 0.1881, 'grad_norm': 0.19983364641666412, 'learning_rate': 0.0001717868338557994, 'epoch': 0.42}
{'loss': 0.2141, 'grad_norm': 2.1950554847717285, 'learning_rate': 0.00017126436781609197, 'epoch': 0.43}
{'loss': 0.1701, 'grad_norm': 7.203112602233887, 'learning_rate': 0.00017074190177638455, 'epoch': 0.44}
{'loss': 0.2708, 'grad_norm': 6.515425205230713, 'learning_rate': 0.00017021943573667712, 'epoch': 0.45}
{'loss': 0.2572, 'grad_norm': 1.701904535293579, 'learning_rate': 0.00016969696969696972, 'epoch': 0.45}
{'loss': 0.1581, 'gr

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.30759307742118835, 'eval_runtime': 7.8754, 'eval_samples_per_second': 73.139, 'eval_steps_per_second': 9.142, 'epoch': 0.47}
{'loss': 0.1727, 'grad_norm': 3.843146324157715, 'learning_rate': 0.00016812957157784746, 'epoch': 0.48}
{'loss': 0.0791, 'grad_norm': 0.20995312929153442, 'learning_rate': 0.00016760710553814003, 'epoch': 0.49}
{'loss': 0.2186, 'grad_norm': 4.220780372619629, 'learning_rate': 0.0001670846394984326, 'epoch': 0.49}
{'loss': 0.1289, 'grad_norm': 2.6502511501312256, 'learning_rate': 0.0001665621734587252, 'epoch': 0.5}
{'loss': 0.1439, 'grad_norm': 0.38949310779571533, 'learning_rate': 0.00016603970741901777, 'epoch': 0.51}
{'loss': 0.0492, 'grad_norm': 0.19756561517715454, 'learning_rate': 0.00016551724137931035, 'epoch': 0.52}
{'loss': 0.1264, 'grad_norm': 2.0892951488494873, 'learning_rate': 0.00016499477533960292, 'epoch': 0.53}
{'loss': 0.1236, 'grad_norm': 0.12575498223304749, 'learning_rate': 0.00016447230929989552, 'epoch': 0.53}
{'loss': 0.1

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.20943185687065125, 'eval_runtime': 8.1426, 'eval_samples_per_second': 70.739, 'eval_steps_per_second': 8.842, 'epoch': 0.55}
{'loss': 0.1223, 'grad_norm': 2.995344877243042, 'learning_rate': 0.00016290491118077326, 'epoch': 0.56}
{'loss': 0.1684, 'grad_norm': 6.286434173583984, 'learning_rate': 0.00016238244514106583, 'epoch': 0.56}
{'loss': 0.0954, 'grad_norm': 6.188852310180664, 'learning_rate': 0.0001618599791013584, 'epoch': 0.57}
{'loss': 0.269, 'grad_norm': 3.484454870223999, 'learning_rate': 0.000161337513061651, 'epoch': 0.58}
{'loss': 0.1771, 'grad_norm': 1.7957199811935425, 'learning_rate': 0.0001608150470219436, 'epoch': 0.59}
{'loss': 0.1632, 'grad_norm': 0.5671838521957397, 'learning_rate': 0.00016029258098223617, 'epoch': 0.6}
{'loss': 0.1207, 'grad_norm': 6.671136856079102, 'learning_rate': 0.00015977011494252874, 'epoch': 0.6}
{'loss': 0.1303, 'grad_norm': 6.188453674316406, 'learning_rate': 0.00015924764890282134, 'epoch': 0.61}
{'loss': 0.0489, 'grad_n

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.2543509304523468, 'eval_runtime': 8.1871, 'eval_samples_per_second': 70.354, 'eval_steps_per_second': 8.794, 'epoch': 0.63}
{'loss': 0.15, 'grad_norm': 3.0102126598358154, 'learning_rate': 0.00015768025078369906, 'epoch': 0.63}
{'loss': 0.0684, 'grad_norm': 0.08692537248134613, 'learning_rate': 0.00015715778474399166, 'epoch': 0.64}
{'loss': 0.1355, 'grad_norm': 5.503729820251465, 'learning_rate': 0.00015663531870428423, 'epoch': 0.65}
{'loss': 0.0791, 'grad_norm': 0.043729301542043686, 'learning_rate': 0.0001561128526645768, 'epoch': 0.66}
{'loss': 0.1021, 'grad_norm': 8.379542350769043, 'learning_rate': 0.0001555903866248694, 'epoch': 0.67}
{'loss': 0.1, 'grad_norm': 2.105807304382324, 'learning_rate': 0.00015506792058516197, 'epoch': 0.67}
{'loss': 0.1303, 'grad_norm': 4.429437637329102, 'learning_rate': 0.00015454545454545454, 'epoch': 0.68}
{'loss': 0.0859, 'grad_norm': 0.61519455909729, 'learning_rate': 0.00015402298850574712, 'epoch': 0.69}
{'loss': 0.1244, 'grad

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.23682396113872528, 'eval_runtime': 7.7278, 'eval_samples_per_second': 74.536, 'eval_steps_per_second': 9.317, 'epoch': 0.71}
{'loss': 0.085, 'grad_norm': 1.2954775094985962, 'learning_rate': 0.00015245559038662486, 'epoch': 0.71}
{'loss': 0.0833, 'grad_norm': 0.027222277596592903, 'learning_rate': 0.00015193312434691746, 'epoch': 0.72}
{'loss': 0.1054, 'grad_norm': 0.04538816586136818, 'learning_rate': 0.00015141065830721003, 'epoch': 0.73}
{'loss': 0.0836, 'grad_norm': 6.099230766296387, 'learning_rate': 0.00015088819226750263, 'epoch': 0.74}
{'loss': 0.1202, 'grad_norm': 6.027598857879639, 'learning_rate': 0.0001503657262277952, 'epoch': 0.74}
{'loss': 0.1001, 'grad_norm': 1.5671093463897705, 'learning_rate': 0.0001498432601880878, 'epoch': 0.75}
{'loss': 0.0664, 'grad_norm': 3.236758232116699, 'learning_rate': 0.00014932079414838037, 'epoch': 0.76}
{'loss': 0.18, 'grad_norm': 0.038063984364271164, 'learning_rate': 0.00014879832810867294, 'epoch': 0.77}
{'loss': 0.133

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.14661329984664917, 'eval_runtime': 7.4502, 'eval_samples_per_second': 77.313, 'eval_steps_per_second': 9.664, 'epoch': 0.78}
{'loss': 0.0712, 'grad_norm': 1.8470311164855957, 'learning_rate': 0.00014723092998955069, 'epoch': 0.79}
{'loss': 0.1395, 'grad_norm': 0.09286659955978394, 'learning_rate': 0.00014670846394984328, 'epoch': 0.8}
{'loss': 0.0231, 'grad_norm': 0.355596661567688, 'learning_rate': 0.00014618599791013586, 'epoch': 0.81}
{'loss': 0.1111, 'grad_norm': 2.5051848888397217, 'learning_rate': 0.00014566353187042843, 'epoch': 0.82}
{'loss': 0.0675, 'grad_norm': 5.51124382019043, 'learning_rate': 0.000145141065830721, 'epoch': 0.82}
{'loss': 0.0434, 'grad_norm': 0.02259928360581398, 'learning_rate': 0.0001446185997910136, 'epoch': 0.83}
{'loss': 0.1175, 'grad_norm': 0.07450602203607559, 'learning_rate': 0.00014409613375130617, 'epoch': 0.84}
{'loss': 0.0682, 'grad_norm': 0.03157905861735344, 'learning_rate': 0.00014357366771159874, 'epoch': 0.85}
{'loss': 0.043

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.26989784836769104, 'eval_runtime': 7.7015, 'eval_samples_per_second': 74.79, 'eval_steps_per_second': 9.349, 'epoch': 0.86}
{'loss': 0.1058, 'grad_norm': 0.027064738795161247, 'learning_rate': 0.00014200626959247648, 'epoch': 0.87}
{'loss': 0.0494, 'grad_norm': 0.020756859332323074, 'learning_rate': 0.00014148380355276906, 'epoch': 0.88}
{'loss': 0.0527, 'grad_norm': 0.08638189733028412, 'learning_rate': 0.00014096133751306166, 'epoch': 0.89}
{'loss': 0.0518, 'grad_norm': 0.13641102612018585, 'learning_rate': 0.00014043887147335423, 'epoch': 0.89}
{'loss': 0.1219, 'grad_norm': 0.026504192501306534, 'learning_rate': 0.00013991640543364683, 'epoch': 0.9}
{'loss': 0.1031, 'grad_norm': 1.6737061738967896, 'learning_rate': 0.0001393939393939394, 'epoch': 0.91}
{'loss': 0.0925, 'grad_norm': 0.02726888284087181, 'learning_rate': 0.000138871473354232, 'epoch': 0.92}
{'loss': 0.164, 'grad_norm': 1.9040111303329468, 'learning_rate': 0.00013834900731452457, 'epoch': 0.92}
{'loss':

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.17246948182582855, 'eval_runtime': 7.6266, 'eval_samples_per_second': 75.525, 'eval_steps_per_second': 9.441, 'epoch': 0.94}
{'loss': 0.1849, 'grad_norm': 5.349415302276611, 'learning_rate': 0.0001367816091954023, 'epoch': 0.95}
{'loss': 0.0789, 'grad_norm': 1.317427396774292, 'learning_rate': 0.00013625914315569488, 'epoch': 0.96}
{'loss': 0.0729, 'grad_norm': 1.2050138711929321, 'learning_rate': 0.00013573667711598748, 'epoch': 0.96}
{'loss': 0.0264, 'grad_norm': 0.05380038544535637, 'learning_rate': 0.00013521421107628005, 'epoch': 0.97}
{'loss': 0.0755, 'grad_norm': 0.2694826126098633, 'learning_rate': 0.00013469174503657263, 'epoch': 0.98}
{'loss': 0.1048, 'grad_norm': 0.07259940356016159, 'learning_rate': 0.0001341692789968652, 'epoch': 0.99}
{'loss': 0.0804, 'grad_norm': 5.865521430969238, 'learning_rate': 0.0001336468129571578, 'epoch': 1.0}
{'loss': 0.0338, 'grad_norm': 5.130529880523682, 'learning_rate': 0.00013312434691745037, 'epoch': 1.0}
{'loss': 0.04, 'gr

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.1035054549574852, 'eval_runtime': 8.0948, 'eval_samples_per_second': 71.156, 'eval_steps_per_second': 8.895, 'epoch': 1.02}
{'loss': 0.0268, 'grad_norm': 0.013340204954147339, 'learning_rate': 0.0001315569487983281, 'epoch': 1.03}
{'loss': 0.0695, 'grad_norm': 5.382990837097168, 'learning_rate': 0.00013103448275862068, 'epoch': 1.03}
{'loss': 0.0798, 'grad_norm': 1.344637393951416, 'learning_rate': 0.00013051201671891325, 'epoch': 1.04}
{'loss': 0.0897, 'grad_norm': 0.026908960193395615, 'learning_rate': 0.00012998955067920585, 'epoch': 1.05}
{'loss': 0.0575, 'grad_norm': 5.505439758300781, 'learning_rate': 0.00012946708463949843, 'epoch': 1.06}
{'loss': 0.1101, 'grad_norm': 0.42578840255737305, 'learning_rate': 0.00012894461859979102, 'epoch': 1.07}
{'loss': 0.0564, 'grad_norm': 0.7465565800666809, 'learning_rate': 0.0001284221525600836, 'epoch': 1.07}
{'loss': 0.2036, 'grad_norm': 0.06592925637960434, 'learning_rate': 0.0001278996865203762, 'epoch': 1.08}
{'loss': 0.0

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.1321529597043991, 'eval_runtime': 7.8972, 'eval_samples_per_second': 72.937, 'eval_steps_per_second': 9.117, 'epoch': 1.1}
{'loss': 0.0904, 'grad_norm': 1.355607032775879, 'learning_rate': 0.00012633228840125394, 'epoch': 1.11}
{'loss': 0.1261, 'grad_norm': 0.061568327248096466, 'learning_rate': 0.0001258098223615465, 'epoch': 1.11}
{'loss': 0.0278, 'grad_norm': 0.9808339476585388, 'learning_rate': 0.00012528735632183908, 'epoch': 1.12}
{'loss': 0.0892, 'grad_norm': 3.641172170639038, 'learning_rate': 0.00012476489028213168, 'epoch': 1.13}
{'loss': 0.0765, 'grad_norm': 0.6721726059913635, 'learning_rate': 0.00012424242424242425, 'epoch': 1.14}
{'loss': 0.0668, 'grad_norm': 5.774373531341553, 'learning_rate': 0.00012371995820271682, 'epoch': 1.14}
{'loss': 0.0193, 'grad_norm': 0.03443977236747742, 'learning_rate': 0.00012319749216300942, 'epoch': 1.15}
{'loss': 0.0509, 'grad_norm': 0.024472808465361595, 'learning_rate': 0.000122675026123302, 'epoch': 1.16}
{'loss': 0.010

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.13472314178943634, 'eval_runtime': 8.1528, 'eval_samples_per_second': 70.651, 'eval_steps_per_second': 8.831, 'epoch': 1.18}
{'loss': 0.0792, 'grad_norm': 4.959997177124023, 'learning_rate': 0.00012110762800417974, 'epoch': 1.18}
{'loss': 0.0631, 'grad_norm': 0.010228625498712063, 'learning_rate': 0.00012058516196447231, 'epoch': 1.19}
{'loss': 0.069, 'grad_norm': 0.04553980007767677, 'learning_rate': 0.0001200626959247649, 'epoch': 1.2}
{'loss': 0.0691, 'grad_norm': 0.03898899629712105, 'learning_rate': 0.00011954022988505748, 'epoch': 1.21}
{'loss': 0.0114, 'grad_norm': 0.3059581220149994, 'learning_rate': 0.00011901776384535007, 'epoch': 1.21}
{'loss': 0.0251, 'grad_norm': 2.5174131393432617, 'learning_rate': 0.00011849529780564264, 'epoch': 1.22}
{'loss': 0.0552, 'grad_norm': 1.1565138101577759, 'learning_rate': 0.00011797283176593521, 'epoch': 1.23}
{'loss': 0.0198, 'grad_norm': 0.015519246459007263, 'learning_rate': 0.00011745036572622781, 'epoch': 1.24}
{'loss': 

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.15898673236370087, 'eval_runtime': 7.4387, 'eval_samples_per_second': 77.433, 'eval_steps_per_second': 9.679, 'epoch': 1.25}
{'loss': 0.0509, 'grad_norm': 5.336925983428955, 'learning_rate': 0.00011588296760710555, 'epoch': 1.26}
{'loss': 0.0278, 'grad_norm': 0.036074332892894745, 'learning_rate': 0.00011536050156739812, 'epoch': 1.27}
{'loss': 0.0611, 'grad_norm': 0.023687895387411118, 'learning_rate': 0.00011483803552769071, 'epoch': 1.28}
{'loss': 0.0054, 'grad_norm': 0.03554030507802963, 'learning_rate': 0.00011431556948798328, 'epoch': 1.29}
{'loss': 0.0233, 'grad_norm': 0.014826677739620209, 'learning_rate': 0.00011379310344827588, 'epoch': 1.29}
{'loss': 0.039, 'grad_norm': 0.012637129984796047, 'learning_rate': 0.00011327063740856845, 'epoch': 1.3}
{'loss': 0.0584, 'grad_norm': 4.304856300354004, 'learning_rate': 0.00011274817136886102, 'epoch': 1.31}
{'loss': 0.0321, 'grad_norm': 0.011278118006885052, 'learning_rate': 0.00011222570532915362, 'epoch': 1.32}
{'lo

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.11125727742910385, 'eval_runtime': 7.464, 'eval_samples_per_second': 77.17, 'eval_steps_per_second': 9.646, 'epoch': 1.33}
{'loss': 0.0787, 'grad_norm': 0.027009518817067146, 'learning_rate': 0.00011065830721003134, 'epoch': 1.34}
{'loss': 0.0541, 'grad_norm': 0.011977170594036579, 'learning_rate': 0.00011013584117032394, 'epoch': 1.35}
{'loss': 0.0724, 'grad_norm': 0.03558238595724106, 'learning_rate': 0.00010961337513061651, 'epoch': 1.36}
{'loss': 0.0397, 'grad_norm': 11.342432975769043, 'learning_rate': 0.00010909090909090909, 'epoch': 1.36}
{'loss': 0.0325, 'grad_norm': 0.11951570212841034, 'learning_rate': 0.00010856844305120169, 'epoch': 1.37}
{'loss': 0.0217, 'grad_norm': 0.2707083523273468, 'learning_rate': 0.00010804597701149426, 'epoch': 1.38}
{'loss': 0.0066, 'grad_norm': 1.4399858713150024, 'learning_rate': 0.00010752351097178684, 'epoch': 1.39}
{'loss': 0.0204, 'grad_norm': 0.04738759249448776, 'learning_rate': 0.00010700104493207943, 'epoch': 1.39}
{'loss

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.1586693823337555, 'eval_runtime': 28.0833, 'eval_samples_per_second': 20.51, 'eval_steps_per_second': 2.564, 'epoch': 1.41}
{'loss': 0.0037, 'grad_norm': 0.015234180726110935, 'learning_rate': 0.00010543364681295715, 'epoch': 1.42}
{'loss': 0.0352, 'grad_norm': 0.5712074041366577, 'learning_rate': 0.00010491118077324975, 'epoch': 1.43}
{'loss': 0.0348, 'grad_norm': 6.078217029571533, 'learning_rate': 0.00010438871473354232, 'epoch': 1.43}
{'loss': 0.0367, 'grad_norm': 1.1986409425735474, 'learning_rate': 0.0001038662486938349, 'epoch': 1.44}
{'loss': 0.0068, 'grad_norm': 0.22961173951625824, 'learning_rate': 0.00010334378265412749, 'epoch': 1.45}
{'loss': 0.0656, 'grad_norm': 0.8412068486213684, 'learning_rate': 0.00010282131661442008, 'epoch': 1.46}
{'loss': 0.0384, 'grad_norm': 2.992851734161377, 'learning_rate': 0.00010229885057471265, 'epoch': 1.47}
{'loss': 0.0046, 'grad_norm': 0.11627233773469925, 'learning_rate': 0.00010177638453500522, 'epoch': 1.47}
{'loss': 0.

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.0657556802034378, 'eval_runtime': 7.4439, 'eval_samples_per_second': 77.379, 'eval_steps_per_second': 9.672, 'epoch': 1.49}
{'loss': 0.0416, 'grad_norm': 5.374913215637207, 'learning_rate': 0.00010020898641588296, 'epoch': 1.5}
{'loss': 0.0135, 'grad_norm': 6.722672939300537, 'learning_rate': 9.968652037617555e-05, 'epoch': 1.5}
{'loss': 0.108, 'grad_norm': 3.863485813140869, 'learning_rate': 9.916405433646813e-05, 'epoch': 1.51}
{'loss': 0.0041, 'grad_norm': 0.059663522988557816, 'learning_rate': 9.864158829676072e-05, 'epoch': 1.52}
{'loss': 0.0295, 'grad_norm': 0.03070157766342163, 'learning_rate': 9.81191222570533e-05, 'epoch': 1.53}
{'loss': 0.0249, 'grad_norm': 0.22868306934833527, 'learning_rate': 9.759665621734588e-05, 'epoch': 1.54}
{'loss': 0.0324, 'grad_norm': 0.027927089482545853, 'learning_rate': 9.707419017763846e-05, 'epoch': 1.54}
{'loss': 0.0694, 'grad_norm': 0.37730103731155396, 'learning_rate': 9.655172413793105e-05, 'epoch': 1.55}
{'loss': 0.0317, 'g

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.08151793479919434, 'eval_runtime': 7.6567, 'eval_samples_per_second': 75.228, 'eval_steps_per_second': 9.403, 'epoch': 1.57}
{'loss': 0.0361, 'grad_norm': 0.08461111038923264, 'learning_rate': 9.498432601880878e-05, 'epoch': 1.58}
{'loss': 0.009, 'grad_norm': 0.14547698199748993, 'learning_rate': 9.446185997910136e-05, 'epoch': 1.58}
{'loss': 0.0035, 'grad_norm': 0.025661392137408257, 'learning_rate': 9.393939393939395e-05, 'epoch': 1.59}
{'loss': 0.0027, 'grad_norm': 0.007606533821672201, 'learning_rate': 9.341692789968652e-05, 'epoch': 1.6}
{'loss': 0.0041, 'grad_norm': 0.007040788419544697, 'learning_rate': 9.28944618599791e-05, 'epoch': 1.61}
{'loss': 0.0058, 'grad_norm': 0.012228811159729958, 'learning_rate': 9.237199582027169e-05, 'epoch': 1.61}
{'loss': 0.0031, 'grad_norm': 0.012334192171692848, 'learning_rate': 9.184952978056427e-05, 'epoch': 1.62}
{'loss': 0.0463, 'grad_norm': 9.32374382019043, 'learning_rate': 9.132706374085685e-05, 'epoch': 1.63}
{'loss': 0.0

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.060541458427906036, 'eval_runtime': 7.9443, 'eval_samples_per_second': 72.505, 'eval_steps_per_second': 9.063, 'epoch': 1.65}
{'loss': 0.0041, 'grad_norm': 0.009063112549483776, 'learning_rate': 8.975966562173459e-05, 'epoch': 1.65}
{'loss': 0.0046, 'grad_norm': 0.010461282916367054, 'learning_rate': 8.923719958202717e-05, 'epoch': 1.66}
{'loss': 0.0035, 'grad_norm': 0.010437480174005032, 'learning_rate': 8.871473354231975e-05, 'epoch': 1.67}
{'loss': 0.0065, 'grad_norm': 0.032686349004507065, 'learning_rate': 8.819226750261233e-05, 'epoch': 1.68}
{'loss': 0.0217, 'grad_norm': 0.006377156358212233, 'learning_rate': 8.766980146290492e-05, 'epoch': 1.68}
{'loss': 0.0027, 'grad_norm': 0.022027339786291122, 'learning_rate': 8.71473354231975e-05, 'epoch': 1.69}
{'loss': 0.0522, 'grad_norm': 0.49981313943862915, 'learning_rate': 8.662486938349009e-05, 'epoch': 1.7}
{'loss': 0.0589, 'grad_norm': 4.885700225830078, 'learning_rate': 8.610240334378266e-05, 'epoch': 1.71}
{'loss':

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.094878651201725, 'eval_runtime': 7.616, 'eval_samples_per_second': 75.63, 'eval_steps_per_second': 9.454, 'epoch': 1.72}
{'loss': 0.0149, 'grad_norm': 10.201760292053223, 'learning_rate': 8.45350052246604e-05, 'epoch': 1.73}
{'loss': 0.0053, 'grad_norm': 2.7734475135803223, 'learning_rate': 8.401253918495299e-05, 'epoch': 1.74}
{'loss': 0.0347, 'grad_norm': 0.007352403365075588, 'learning_rate': 8.349007314524556e-05, 'epoch': 1.75}
{'loss': 0.0045, 'grad_norm': 0.005898187402635813, 'learning_rate': 8.296760710553814e-05, 'epoch': 1.76}
{'loss': 0.0026, 'grad_norm': 0.006251541431993246, 'learning_rate': 8.244514106583072e-05, 'epoch': 1.76}
{'loss': 0.0035, 'grad_norm': 0.008176211267709732, 'learning_rate': 8.19226750261233e-05, 'epoch': 1.77}
{'loss': 0.0149, 'grad_norm': 7.986173152923584, 'learning_rate': 8.140020898641589e-05, 'epoch': 1.78}
{'loss': 0.0072, 'grad_norm': 2.1511776447296143, 'learning_rate': 8.087774294670847e-05, 'epoch': 1.79}
{'loss': 0.0019, '

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.10194958746433258, 'eval_runtime': 7.7616, 'eval_samples_per_second': 74.211, 'eval_steps_per_second': 9.276, 'epoch': 1.8}
{'loss': 0.0645, 'grad_norm': 0.011056152172386646, 'learning_rate': 7.931034482758621e-05, 'epoch': 1.81}
{'loss': 0.0025, 'grad_norm': 0.006033923476934433, 'learning_rate': 7.878787878787879e-05, 'epoch': 1.82}
{'loss': 0.0041, 'grad_norm': 0.006800787523388863, 'learning_rate': 7.826541274817137e-05, 'epoch': 1.83}
{'loss': 0.0131, 'grad_norm': 2.7846500873565674, 'learning_rate': 7.774294670846394e-05, 'epoch': 1.83}
{'loss': 0.0048, 'grad_norm': 0.012917822226881981, 'learning_rate': 7.722048066875653e-05, 'epoch': 1.84}
{'loss': 0.0237, 'grad_norm': 0.01738261803984642, 'learning_rate': 7.669801462904911e-05, 'epoch': 1.85}
{'loss': 0.0023, 'grad_norm': 0.00777400890365243, 'learning_rate': 7.61755485893417e-05, 'epoch': 1.86}
{'loss': 0.034, 'grad_norm': 0.006071976386010647, 'learning_rate': 7.565308254963429e-05, 'epoch': 1.87}
{'loss': 0

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.058771762996912, 'eval_runtime': 7.6367, 'eval_samples_per_second': 75.426, 'eval_steps_per_second': 9.428, 'epoch': 1.88}
{'loss': 0.0051, 'grad_norm': 0.005996545311063528, 'learning_rate': 7.408568443051203e-05, 'epoch': 1.89}
{'loss': 0.013, 'grad_norm': 0.6856368780136108, 'learning_rate': 7.35632183908046e-05, 'epoch': 1.9}
{'loss': 0.0025, 'grad_norm': 0.018279504030942917, 'learning_rate': 7.304075235109719e-05, 'epoch': 1.9}
{'loss': 0.0043, 'grad_norm': 0.007061623968183994, 'learning_rate': 7.251828631138976e-05, 'epoch': 1.91}
{'loss': 0.0393, 'grad_norm': 0.004808466415852308, 'learning_rate': 7.199582027168234e-05, 'epoch': 1.92}
{'loss': 0.0249, 'grad_norm': 0.0041197314858436584, 'learning_rate': 7.147335423197491e-05, 'epoch': 1.93}
{'loss': 0.0018, 'grad_norm': 0.00760581623762846, 'learning_rate': 7.09508881922675e-05, 'epoch': 1.94}
{'loss': 0.0344, 'grad_norm': 0.005200102459639311, 'learning_rate': 7.04284221525601e-05, 'epoch': 1.94}
{'loss': 0.00

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.05883385241031647, 'eval_runtime': 7.9968, 'eval_samples_per_second': 72.029, 'eval_steps_per_second': 9.004, 'epoch': 1.96}
{'loss': 0.0131, 'grad_norm': 0.008777677081525326, 'learning_rate': 6.886102403343783e-05, 'epoch': 1.97}
{'loss': 0.0374, 'grad_norm': 0.0069143460132181644, 'learning_rate': 6.833855799373041e-05, 'epoch': 1.97}
{'loss': 0.0021, 'grad_norm': 0.007198363076895475, 'learning_rate': 6.781609195402298e-05, 'epoch': 1.98}
{'loss': 0.005, 'grad_norm': 0.02337276190519333, 'learning_rate': 6.729362591431557e-05, 'epoch': 1.99}
{'loss': 0.1204, 'grad_norm': 0.006326479837298393, 'learning_rate': 6.677115987460816e-05, 'epoch': 2.0}
{'loss': 0.0016, 'grad_norm': 0.005657107103615999, 'learning_rate': 6.624869383490073e-05, 'epoch': 2.01}
{'loss': 0.0066, 'grad_norm': 0.007888296619057655, 'learning_rate': 6.572622779519331e-05, 'epoch': 2.01}
{'loss': 0.002, 'grad_norm': 0.01220399048179388, 'learning_rate': 6.52037617554859e-05, 'epoch': 2.02}
{'loss':

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.06500769406557083, 'eval_runtime': 7.8782, 'eval_samples_per_second': 73.113, 'eval_steps_per_second': 9.139, 'epoch': 2.04}
{'loss': 0.0017, 'grad_norm': 0.008136472664773464, 'learning_rate': 6.363636363636364e-05, 'epoch': 2.05}
{'loss': 0.0046, 'grad_norm': 0.00616668863222003, 'learning_rate': 6.311389759665623e-05, 'epoch': 2.05}
{'loss': 0.0425, 'grad_norm': 0.008776778355240822, 'learning_rate': 6.25914315569488e-05, 'epoch': 2.06}
{'loss': 0.0035, 'grad_norm': 0.0067460970021784306, 'learning_rate': 6.206896551724138e-05, 'epoch': 2.07}
{'loss': 0.01, 'grad_norm': 0.004684227053076029, 'learning_rate': 6.154649947753396e-05, 'epoch': 2.08}
{'loss': 0.0024, 'grad_norm': 0.014263673685491085, 'learning_rate': 6.102403343782655e-05, 'epoch': 2.08}
{'loss': 0.0029, 'grad_norm': 0.9970391988754272, 'learning_rate': 6.050156739811913e-05, 'epoch': 2.09}
{'loss': 0.0457, 'grad_norm': 0.004902213346213102, 'learning_rate': 5.9979101358411704e-05, 'epoch': 2.1}
{'loss':

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.07348629832267761, 'eval_runtime': 7.718, 'eval_samples_per_second': 74.63, 'eval_steps_per_second': 9.329, 'epoch': 2.12}
{'loss': 0.0017, 'grad_norm': 0.020748179405927658, 'learning_rate': 5.841170323928945e-05, 'epoch': 2.12}
{'loss': 0.0162, 'grad_norm': 0.008781244046986103, 'learning_rate': 5.7889237199582026e-05, 'epoch': 2.13}
{'loss': 0.0068, 'grad_norm': 0.06717163324356079, 'learning_rate': 5.736677115987461e-05, 'epoch': 2.14}
{'loss': 0.0019, 'grad_norm': 0.005841807462275028, 'learning_rate': 5.6844305120167196e-05, 'epoch': 2.15}
{'loss': 0.0017, 'grad_norm': 0.033633340150117874, 'learning_rate': 5.632183908045977e-05, 'epoch': 2.16}
{'loss': 0.0018, 'grad_norm': 0.00458553247153759, 'learning_rate': 5.5799373040752354e-05, 'epoch': 2.16}
{'loss': 0.0018, 'grad_norm': 0.004277497995644808, 'learning_rate': 5.527690700104493e-05, 'epoch': 2.17}
{'loss': 0.0019, 'grad_norm': 0.003866010345518589, 'learning_rate': 5.475444096133752e-05, 'epoch': 2.18}
{'lo

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.07132481038570404, 'eval_runtime': 7.3651, 'eval_samples_per_second': 78.207, 'eval_steps_per_second': 9.776, 'epoch': 2.19}
{'loss': 0.0023, 'grad_norm': 0.01090237032622099, 'learning_rate': 5.318704284221526e-05, 'epoch': 2.2}
{'loss': 0.0079, 'grad_norm': 0.039902858436107635, 'learning_rate': 5.266457680250784e-05, 'epoch': 2.21}
{'loss': 0.0043, 'grad_norm': 0.004606783390045166, 'learning_rate': 5.2142110762800424e-05, 'epoch': 2.22}
{'loss': 0.0014, 'grad_norm': 0.004344211425632238, 'learning_rate': 5.1619644723092996e-05, 'epoch': 2.23}
{'loss': 0.0088, 'grad_norm': 0.005527963396161795, 'learning_rate': 5.109717868338558e-05, 'epoch': 2.23}
{'loss': 0.0015, 'grad_norm': 0.007773634511977434, 'learning_rate': 5.057471264367817e-05, 'epoch': 2.24}
{'loss': 0.0013, 'grad_norm': 0.0044192238710820675, 'learning_rate': 5.0052246603970745e-05, 'epoch': 2.25}
{'loss': 0.0014, 'grad_norm': 0.009840392507612705, 'learning_rate': 4.9529780564263324e-05, 'epoch': 2.26}


  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.053871039301157, 'eval_runtime': 7.4296, 'eval_samples_per_second': 77.528, 'eval_steps_per_second': 9.691, 'epoch': 2.27}
{'loss': 0.0011, 'grad_norm': 0.006205760408192873, 'learning_rate': 4.7962382445141066e-05, 'epoch': 2.28}
{'loss': 0.0012, 'grad_norm': 0.003113804617896676, 'learning_rate': 4.743991640543365e-05, 'epoch': 2.29}
{'loss': 0.0014, 'grad_norm': 0.005142162553966045, 'learning_rate': 4.691745036572623e-05, 'epoch': 2.3}
{'loss': 0.0087, 'grad_norm': 8.409506797790527, 'learning_rate': 4.639498432601881e-05, 'epoch': 2.3}
{'loss': 0.0295, 'grad_norm': 0.009684483520686626, 'learning_rate': 4.5872518286311394e-05, 'epoch': 2.31}
{'loss': 0.0079, 'grad_norm': 0.04902331158518791, 'learning_rate': 4.535005224660397e-05, 'epoch': 2.32}
{'loss': 0.0105, 'grad_norm': 0.004060741513967514, 'learning_rate': 4.482758620689655e-05, 'epoch': 2.33}
{'loss': 0.009, 'grad_norm': 6.375670909881592, 'learning_rate': 4.430512016718914e-05, 'epoch': 2.34}
{'loss': 0.00

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.06272099912166595, 'eval_runtime': 7.889, 'eval_samples_per_second': 73.013, 'eval_steps_per_second': 9.127, 'epoch': 2.35}
{'loss': 0.0015, 'grad_norm': 0.0067363339476287365, 'learning_rate': 4.273772204806687e-05, 'epoch': 2.36}
{'loss': 0.0015, 'grad_norm': 0.004887898452579975, 'learning_rate': 4.221525600835946e-05, 'epoch': 2.37}
{'loss': 0.0071, 'grad_norm': 0.005628807470202446, 'learning_rate': 4.1692789968652043e-05, 'epoch': 2.37}
{'loss': 0.0015, 'grad_norm': 0.010895175859332085, 'learning_rate': 4.117032392894462e-05, 'epoch': 2.38}
{'loss': 0.0021, 'grad_norm': 0.005600111559033394, 'learning_rate': 4.06478578892372e-05, 'epoch': 2.39}
{'loss': 0.0511, 'grad_norm': 0.003928138874471188, 'learning_rate': 4.012539184952978e-05, 'epoch': 2.4}
{'loss': 0.0485, 'grad_norm': 0.005031864158809185, 'learning_rate': 3.960292580982236e-05, 'epoch': 2.41}
{'loss': 0.0015, 'grad_norm': 0.028445469215512276, 'learning_rate': 3.908045977011495e-05, 'epoch': 2.41}
{'lo

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.0560675710439682, 'eval_runtime': 7.4683, 'eval_samples_per_second': 77.126, 'eval_steps_per_second': 9.641, 'epoch': 2.43}
{'loss': 0.0021, 'grad_norm': 0.009597845375537872, 'learning_rate': 3.7513061650992686e-05, 'epoch': 2.44}
{'loss': 0.0013, 'grad_norm': 0.0058478694409132, 'learning_rate': 3.6990595611285264e-05, 'epoch': 2.45}
{'loss': 0.0013, 'grad_norm': 0.0045097460970282555, 'learning_rate': 3.646812957157785e-05, 'epoch': 2.45}
{'loss': 0.0015, 'grad_norm': 0.013563988730311394, 'learning_rate': 3.5945663531870435e-05, 'epoch': 2.46}
{'loss': 0.0072, 'grad_norm': 0.20232631266117096, 'learning_rate': 3.5423197492163014e-05, 'epoch': 2.47}
{'loss': 0.0037, 'grad_norm': 0.0035689170472323895, 'learning_rate': 3.490073145245559e-05, 'epoch': 2.48}
{'loss': 0.0013, 'grad_norm': 0.0034087635576725006, 'learning_rate': 3.437826541274817e-05, 'epoch': 2.48}
{'loss': 0.0013, 'grad_norm': 0.006427367217838764, 'learning_rate': 3.3855799373040756e-05, 'epoch': 2.49}

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.06309136003255844, 'eval_runtime': 7.4449, 'eval_samples_per_second': 77.368, 'eval_steps_per_second': 9.671, 'epoch': 2.51}
{'loss': 0.0209, 'grad_norm': 0.004251124337315559, 'learning_rate': 3.22884012539185e-05, 'epoch': 2.52}
{'loss': 0.0018, 'grad_norm': 0.01826915703713894, 'learning_rate': 3.176593521421108e-05, 'epoch': 2.52}
{'loss': 0.0013, 'grad_norm': 0.003419796470552683, 'learning_rate': 3.1243469174503656e-05, 'epoch': 2.53}
{'loss': 0.0013, 'grad_norm': 0.0044568064622581005, 'learning_rate': 3.072100313479624e-05, 'epoch': 2.54}
{'loss': 0.0012, 'grad_norm': 0.00416027195751667, 'learning_rate': 3.019853709508882e-05, 'epoch': 2.55}
{'loss': 0.0013, 'grad_norm': 0.00730745168402791, 'learning_rate': 2.96760710553814e-05, 'epoch': 2.55}
{'loss': 0.0372, 'grad_norm': 0.0032188338227570057, 'learning_rate': 2.9153605015673984e-05, 'epoch': 2.56}
{'loss': 0.0012, 'grad_norm': 0.004143284168094397, 'learning_rate': 2.8631138975966566e-05, 'epoch': 2.57}
{'l

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.055200543254613876, 'eval_runtime': 7.4904, 'eval_samples_per_second': 76.898, 'eval_steps_per_second': 9.612, 'epoch': 2.59}
{'loss': 0.0013, 'grad_norm': 0.00293457112275064, 'learning_rate': 2.7063740856844305e-05, 'epoch': 2.59}
{'loss': 0.0012, 'grad_norm': 0.0035222184378653765, 'learning_rate': 2.6541274817136884e-05, 'epoch': 2.6}
{'loss': 0.0011, 'grad_norm': 0.003441382432356477, 'learning_rate': 2.601880877742947e-05, 'epoch': 2.61}
{'loss': 0.0014, 'grad_norm': 0.004808458499610424, 'learning_rate': 2.549634273772205e-05, 'epoch': 2.62}
{'loss': 0.0011, 'grad_norm': 0.004806511104106903, 'learning_rate': 2.497387669801463e-05, 'epoch': 2.63}
{'loss': 0.0013, 'grad_norm': 0.017148269340395927, 'learning_rate': 2.4451410658307212e-05, 'epoch': 2.63}
{'loss': 0.0011, 'grad_norm': 0.004004401154816151, 'learning_rate': 2.392894461859979e-05, 'epoch': 2.64}
{'loss': 0.0015, 'grad_norm': 0.0037864067126065493, 'learning_rate': 2.3406478578892372e-05, 'epoch': 2.65

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.05568113550543785, 'eval_runtime': 8.296, 'eval_samples_per_second': 69.431, 'eval_steps_per_second': 8.679, 'epoch': 2.66}
{'loss': 0.0017, 'grad_norm': 1.0392160415649414, 'learning_rate': 2.183908045977012e-05, 'epoch': 2.67}
{'loss': 0.0011, 'grad_norm': 0.0036945841275155544, 'learning_rate': 2.1316614420062697e-05, 'epoch': 2.68}
{'loss': 0.0012, 'grad_norm': 0.004460788331925869, 'learning_rate': 2.079414838035528e-05, 'epoch': 2.69}
{'loss': 0.0014, 'grad_norm': 0.0030522355809807777, 'learning_rate': 2.027168234064786e-05, 'epoch': 2.7}
{'loss': 0.0018, 'grad_norm': 0.0027197939343750477, 'learning_rate': 1.974921630094044e-05, 'epoch': 2.7}
{'loss': 0.0411, 'grad_norm': 0.0033104312606155872, 'learning_rate': 1.922675026123302e-05, 'epoch': 2.71}
{'loss': 0.0013, 'grad_norm': 0.010486757382750511, 'learning_rate': 1.8704284221525603e-05, 'epoch': 2.72}
{'loss': 0.0012, 'grad_norm': 0.009942792356014252, 'learning_rate': 1.8181818181818182e-05, 'epoch': 2.73}
{

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.05285092070698738, 'eval_runtime': 7.6616, 'eval_samples_per_second': 75.181, 'eval_steps_per_second': 9.398, 'epoch': 2.74}
{'loss': 0.001, 'grad_norm': 0.03942130133509636, 'learning_rate': 1.6614420062695925e-05, 'epoch': 2.75}
{'loss': 0.0021, 'grad_norm': 0.008209452964365482, 'learning_rate': 1.6091954022988507e-05, 'epoch': 2.76}
{'loss': 0.0012, 'grad_norm': 0.003445032751187682, 'learning_rate': 1.5569487983281085e-05, 'epoch': 2.77}
{'loss': 0.0011, 'grad_norm': 0.0030758429784327745, 'learning_rate': 1.5047021943573669e-05, 'epoch': 2.77}
{'loss': 0.0015, 'grad_norm': 0.0044005922973155975, 'learning_rate': 1.452455590386625e-05, 'epoch': 2.78}
{'loss': 0.0016, 'grad_norm': 0.009105256758630276, 'learning_rate': 1.400208986415883e-05, 'epoch': 2.79}
{'loss': 0.001, 'grad_norm': 0.005555335897952318, 'learning_rate': 1.3479623824451411e-05, 'epoch': 2.8}
{'loss': 0.0014, 'grad_norm': 0.008146214298903942, 'learning_rate': 1.2957157784743992e-05, 'epoch': 2.81}

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.0529453344643116, 'eval_runtime': 7.6328, 'eval_samples_per_second': 75.464, 'eval_steps_per_second': 9.433, 'epoch': 2.82}
{'loss': 0.0011, 'grad_norm': 0.00280874059535563, 'learning_rate': 1.1389759665621736e-05, 'epoch': 2.83}
{'loss': 0.0226, 'grad_norm': 0.012960979714989662, 'learning_rate': 1.0867293625914315e-05, 'epoch': 2.84}
{'loss': 0.0011, 'grad_norm': 0.0036664933431893587, 'learning_rate': 1.0344827586206897e-05, 'epoch': 2.84}
{'loss': 0.0095, 'grad_norm': 0.0030029420740902424, 'learning_rate': 9.822361546499479e-06, 'epoch': 2.85}
{'loss': 0.0013, 'grad_norm': 0.002860903274267912, 'learning_rate': 9.299895506792059e-06, 'epoch': 2.86}
{'loss': 0.015, 'grad_norm': 0.0090488838031888, 'learning_rate': 8.77742946708464e-06, 'epoch': 2.87}
{'loss': 0.001, 'grad_norm': 0.002948290202766657, 'learning_rate': 8.254963427377221e-06, 'epoch': 2.88}
{'loss': 0.0011, 'grad_norm': 0.005187514238059521, 'learning_rate': 7.732497387669801e-06, 'epoch': 2.88}
{'los

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.05453566834330559, 'eval_runtime': 7.6206, 'eval_samples_per_second': 75.585, 'eval_steps_per_second': 9.448, 'epoch': 2.9}
{'loss': 0.001, 'grad_norm': 0.0026585387531667948, 'learning_rate': 6.165099268547545e-06, 'epoch': 2.91}
{'loss': 0.0013, 'grad_norm': 0.0026715446729213, 'learning_rate': 5.642633228840126e-06, 'epoch': 2.92}
{'loss': 0.0011, 'grad_norm': 0.002514324616640806, 'learning_rate': 5.120167189132706e-06, 'epoch': 2.92}
{'loss': 0.0267, 'grad_norm': 0.00951069314032793, 'learning_rate': 4.5977011494252875e-06, 'epoch': 2.93}
{'loss': 0.0015, 'grad_norm': 0.004121039062738419, 'learning_rate': 4.075235109717869e-06, 'epoch': 2.94}
{'loss': 0.0011, 'grad_norm': 0.003485678927972913, 'learning_rate': 3.5527690700104498e-06, 'epoch': 2.95}
{'loss': 0.0012, 'grad_norm': 0.003161568893119693, 'learning_rate': 3.0303030303030305e-06, 'epoch': 2.95}
{'loss': 0.0322, 'grad_norm': 5.592681884765625, 'learning_rate': 2.507836990595611e-06, 'epoch': 2.96}
{'loss'

  0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.05512528494000435, 'eval_runtime': 7.6198, 'eval_samples_per_second': 75.592, 'eval_steps_per_second': 9.449, 'epoch': 2.98}
{'loss': 0.0083, 'grad_norm': 0.002920000348240137, 'learning_rate': 9.404388714733543e-07, 'epoch': 2.99}
{'loss': 0.0011, 'grad_norm': 0.00848575308918953, 'learning_rate': 4.179728317659353e-07, 'epoch': 2.99}
{'train_runtime': 1472.1349, 'train_samples_per_second': 41.574, 'train_steps_per_second': 2.6, 'train_loss': 0.10247549182473115, 'epoch': 3.0}


In [11]:
# Persist the outcome of the training run: the fine-tuned model first, then
# the run's metrics, then the Trainer state.
train_metrics = train_results.metrics

trainer.save_model("./plantvillage_final_model")
# log_metrics prints the "***** train metrics *****" summary shown in this
# cell's output; save_metrics stores the same values through the Trainer.
trainer.log_metrics(split="train", metrics=train_metrics)
trainer.save_metrics(split="train", metrics=train_metrics)
trainer.save_state()

***** train metrics *****
  epoch                    =          3.0
  total_flos               = 4417616739GF
  train_loss               =       0.1025
  train_runtime            =   0:24:32.13
  train_samples_per_second =       41.574
  train_steps_per_second   =          2.6


In [None]:
# Evaluation inputs: the "test" split, plus the fine-tuned model and its image
# processor, both read from the same directory.
dataset_test = load_dataset("imagefolder", "PlantVillage", split="test")

# To evaluate an intermediate checkpoint instead, point the path at it:
# model_name_or_path = './plantvillage_model/checkpoint-1300/'
model_name_or_path = './plantvillage_final_model/'

# NOTE(review): loading the processor from this directory presumes the
# preprocessor config was saved next to the model weights -- confirm.
feature_extractor_finetuned = ViTImageProcessor.from_pretrained(model_name_or_path)
model_finetuned = ViTForImageClassification.from_pretrained(model_name_or_path)

In [None]:
# Run the fine-tuned model over the test split, one image at a time.
#
# Results left in the notebook namespace:
#   true_labels      -- numpy array of ground-truth class ids
#   predicted_labels -- numpy array of argmax class ids from the model
#   missed           -- list of (image, true_label) for misclassified samples
import numpy as np  # local import: numpy is not among the imports in the first cell

model_finetuned.eval()

# Send inputs to whichever device the model currently lives on, so this cell
# keeps working unchanged if the model is moved to the GPU.
model_device = next(model_finetuned.parameters()).device

true_ids = []
predicted_ids = []
missed = []

# Inference only: no gradients are needed anywhere in the loop.
with torch.no_grad():
    for sample in dataset_test:
        image = sample['image']
        label = sample['label']

        # Force 3 channels before preprocessing; a grayscale or RGBA file
        # would not match the RGB input the ViT model expects.
        inputs = feature_extractor_finetuned(
            image.convert("RGB"), return_tensors="pt"
        ).to(model_device)

        logits = model_finetuned(**inputs).logits
        predicted_label = logits.argmax(-1).item()

        if label != predicted_label:
            # Keep the original (unconverted) image for later inspection.
            missed.append((image, label))

        true_ids.append(label)
        predicted_ids.append(predicted_label)

# Expose the results as numpy arrays for the metric/plot cells.
true_labels = np.array(true_ids)
predicted_labels = np.array(predicted_ids)

In [None]:
# Plot the confusion matrix of the test-set predictions
# (rows: true class, columns: predicted class).

# Class names, ordered by their integer class id.
label_names = dataset_test.features['label'].names

# Pin the matrix to every class id. Without `labels`, confusion_matrix only
# includes classes that occur in the data, so a class missing from both the
# truth and the predictions would shrink the matrix and no longer line up
# with `display_labels`.
class_ids = list(range(len(label_names)))
conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=class_ids)

disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=label_names)

# Explicit, larger axes: the default figure is too small for this many
# long class names.
fig, ax = plt.subplots(figsize=(12, 12))
disp.plot(ax=ax, cmap=plt.cm.Blues, xticks_rotation='vertical')
ax.set_title("PlantVillage test set - confusion matrix")
fig.tight_layout()

# Show the plot
plt.show()