In [2]:
import torch
import evaluate
import numpy as np
from datasets import load_from_disk
from transformers import Trainer, TrainingArguments, ViTForImageClassification
from sklearn.model_selection import KFold

# Load Dataset

In [3]:
# Load the processed dataset from disk; it is indexed by split name below.
processed_dataset_path = "../data/processed/huggingface"
dataset = load_from_disk(processed_dataset_path)

# Class names read from the "label" feature of the train split; their list
# position is used as the class id when configuring the model.
train_features = dataset["train"].features
labels = train_features["label"].names

In [21]:
# Show the train split summary (column names and number of rows).
dataset["train"]

Dataset({
    features: ['image', 'label', 'pixel_values'],
    num_rows: 1814
})

In [22]:
# Show the test split summary (column names and number of rows).
dataset["test"]

Dataset({
    features: ['image', 'label', 'pixel_values'],
    num_rows: 454
})

# Training

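K-fold cross-validation with incremental learning: the same model instance is fine-tuned further on every fold instead of being re-initialized. Because later folds validate on samples the model already trained on in earlier folds, the per-fold accuracy is optimistic; the evaluation on the separate test split is the meaningful check.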

In [4]:
# Accuracy metric, loaded once and shared by every evaluation run.
accuracy = evaluate.load("accuracy")


def compute_accuracy(p):
    """Score a set of evaluation predictions with the accuracy metric.

    Args:
        p: Evaluation output exposing ``predictions`` (class scores; the
            predicted class is the argmax along axis 1) and ``label_ids``
            (the reference labels).

    Returns:
        The result of ``accuracy.compute`` for predicted vs. reference labels.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return accuracy.compute(references=p.label_ids, predictions=predicted_ids)

# Root folders; the k-fold loop appends a per-fold sub-folder to each of them.
results_folder = "../models/results"
logs_folder = "../models/logs"

# Training configuration shared by all folds.
training_args = TrainingArguments(
    # Where outputs and logs go (overridden per fold inside the loop)
    output_dir=results_folder,
    logging_dir=logs_folder,
    logging_steps=10,
    # Training length and batch sizes
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Evaluate after every epoch; do not write intermediate checkpoints
    evaluation_strategy="epoch",
    save_strategy="no",
)

# String class ids ("0", "1", ...) aligned position-by-position with `labels`.
label_ids = [str(position) for position in range(len(labels))]

# Start from the pre-trained ViT checkpoint and attach a classification head
# sized for this dataset's classes.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(labels),
    id2label=dict(zip(label_ids, labels)),
    label2id=dict(zip(labels, label_ids)),
    # The checkpoint's classifier has a different number of outputs than
    # len(labels); let the mismatching head be newly initialized instead of
    # raising on the shape difference.
    ignore_mismatched_sizes=True,
)

# Initialize k-fold splits
k = 5
# Seed the shuffle so fold membership is identical on every run; reusing the
# Trainer seed keeps a single source of randomness configuration.
kf = KFold(n_splits=k, shuffle=True, random_state=training_args.seed)
all_fold_metrics = []

# Perform k-fold training.
# NOTE: the same `model` instance is carried over from fold to fold (it is
# never re-initialized), so from the second fold on the validation split
# contains samples the model was already trained on in an earlier fold. The
# per-fold metrics are therefore optimistic; the evaluation on the separate
# test split at the end of this cell is the meaningful check.
for fold, (train_idx, val_idx) in enumerate(kf.split(dataset["train"])):
    print(f"Training fold {fold+1}/{k}")

    train_split = dataset["train"].select(train_idx)
    val_split = dataset["train"].select(val_idx)

    # Give every fold its own output and logging folders
    training_args.output_dir = f"{results_folder}/fold_{fold}"
    training_args.logging_dir = f"{logs_folder}/fold_{fold}"

    # Trainer instance for this fold (wraps the shared model)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=val_split,
        compute_metrics=compute_accuracy
    )

    trainer.train()
    eval_result = trainer.evaluate()

    print(eval_result)
    all_fold_metrics.append(eval_result)

# Persist the model as it stands after the last fold
model.save_pretrained(f"{results_folder}/final_model")

# Evaluate final model on test set (`trainer` is the last fold's Trainer,
# which wraps the same shared model)
print("Evaluating final model on test set")
test_result = trainer.evaluate(eval_dataset=dataset["test"])
print(test_result)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([63]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([63, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training fold 1/5


  5%|▌         | 10/182 [04:25<1:16:50, 26.80s/it]

{'loss': 4.1508, 'grad_norm': 7.471713542938232, 'learning_rate': 4.7252747252747257e-05, 'epoch': 0.11}


 11%|█         | 20/182 [08:53<1:13:33, 27.24s/it]

{'loss': 4.0607, 'grad_norm': 6.904857158660889, 'learning_rate': 4.4505494505494504e-05, 'epoch': 0.22}


 16%|█▋        | 30/182 [13:23<1:07:48, 26.77s/it]

{'loss': 3.7408, 'grad_norm': 8.895026206970215, 'learning_rate': 4.1758241758241765e-05, 'epoch': 0.33}


 22%|██▏       | 40/182 [17:49<1:02:36, 26.45s/it]

{'loss': 3.4398, 'grad_norm': 8.596405029296875, 'learning_rate': 3.901098901098901e-05, 'epoch': 0.44}


 27%|██▋       | 50/182 [22:15<58:15, 26.48s/it]  

{'loss': 2.9203, 'grad_norm': 8.512317657470703, 'learning_rate': 3.6263736263736266e-05, 'epoch': 0.55}


 33%|███▎      | 60/182 [26:06<42:38, 20.97s/it]

{'loss': 2.5244, 'grad_norm': 7.572287559509277, 'learning_rate': 3.3516483516483513e-05, 'epoch': 0.66}


 38%|███▊      | 70/182 [29:26<37:12, 19.94s/it]

{'loss': 2.0411, 'grad_norm': 8.584601402282715, 'learning_rate': 3.0769230769230774e-05, 'epoch': 0.77}


 44%|████▍     | 80/182 [33:49<44:26, 26.14s/it]

{'loss': 1.6961, 'grad_norm': 8.757231712341309, 'learning_rate': 2.8021978021978025e-05, 'epoch': 0.88}


 49%|████▉     | 90/182 [38:13<40:07, 26.17s/it]

{'loss': 1.2217, 'grad_norm': 6.988622188568115, 'learning_rate': 2.5274725274725276e-05, 'epoch': 0.99}


                                                
 50%|█████     | 91/182 [41:31<35:45, 23.58s/it]

{'eval_loss': 1.2526980638504028, 'eval_accuracy': 0.8898071625344353, 'eval_runtime': 180.4325, 'eval_samples_per_second': 2.012, 'eval_steps_per_second': 0.127, 'epoch': 1.0}


 55%|█████▍    | 100/182 [44:33<31:38, 23.15s/it] 

{'loss': 0.8782, 'grad_norm': 5.071971416473389, 'learning_rate': 2.252747252747253e-05, 'epoch': 1.1}


 60%|██████    | 110/182 [48:53<31:18, 26.08s/it]

{'loss': 0.7458, 'grad_norm': 4.605737209320068, 'learning_rate': 1.978021978021978e-05, 'epoch': 1.21}


 66%|██████▌   | 120/182 [53:15<27:17, 26.41s/it]

{'loss': 0.5779, 'grad_norm': 4.294705390930176, 'learning_rate': 1.7032967032967035e-05, 'epoch': 1.32}


 71%|███████▏  | 130/182 [57:41<22:45, 26.26s/it]

{'loss': 0.4928, 'grad_norm': 3.843090057373047, 'learning_rate': 1.4285714285714285e-05, 'epoch': 1.43}


 77%|███████▋  | 140/182 [1:02:05<18:21, 26.22s/it]

{'loss': 0.4166, 'grad_norm': 4.349550724029541, 'learning_rate': 1.153846153846154e-05, 'epoch': 1.54}


 82%|████████▏ | 150/182 [1:05:36<10:46, 20.21s/it]

{'loss': 0.4102, 'grad_norm': 4.17403507232666, 'learning_rate': 8.791208791208792e-06, 'epoch': 1.65}


 88%|████████▊ | 160/182 [1:08:56<07:17, 19.88s/it]

{'loss': 0.3644, 'grad_norm': 2.8008954524993896, 'learning_rate': 6.043956043956044e-06, 'epoch': 1.76}


 93%|█████████▎| 170/182 [1:12:15<03:58, 19.84s/it]

{'loss': 0.3552, 'grad_norm': 3.441009044647217, 'learning_rate': 3.2967032967032968e-06, 'epoch': 1.87}


 99%|█████████▉| 180/182 [1:15:34<00:39, 19.83s/it]

{'loss': 0.3139, 'grad_norm': 2.645235538482666, 'learning_rate': 5.494505494505495e-07, 'epoch': 1.98}


                                                   
100%|██████████| 182/182 [1:19:02<00:00, 26.06s/it]


{'eval_loss': 0.4176348149776459, 'eval_accuracy': 0.9724517906336089, 'eval_runtime': 175.5078, 'eval_samples_per_second': 2.068, 'eval_steps_per_second': 0.131, 'epoch': 2.0}
{'train_runtime': 4742.5657, 'train_samples_per_second': 0.612, 'train_steps_per_second': 0.038, 'train_loss': 1.6721074404297294, 'epoch': 2.0}


100%|██████████| 23/23 [03:27<00:00,  9.02s/it]


{'eval_loss': 0.4176348149776459, 'eval_accuracy': 0.9724517906336089, 'eval_runtime': 219.3809, 'eval_samples_per_second': 1.655, 'eval_steps_per_second': 0.105, 'epoch': 2.0}
Training fold 2/5


  5%|▌         | 10/182 [04:25<1:15:16, 26.26s/it]

{'loss': 0.415, 'grad_norm': 5.038876533508301, 'learning_rate': 4.7252747252747257e-05, 'epoch': 0.11}


 11%|█         | 20/182 [08:14<55:47, 20.66s/it]  

{'loss': 0.2794, 'grad_norm': 2.081804037094116, 'learning_rate': 4.4505494505494504e-05, 'epoch': 0.22}


 16%|█▋        | 30/182 [11:53<1:01:35, 24.31s/it]

{'loss': 0.2079, 'grad_norm': 3.814389705657959, 'learning_rate': 4.1758241758241765e-05, 'epoch': 0.33}


 22%|██▏       | 40/182 [16:15<1:01:28, 25.98s/it]

{'loss': 0.1408, 'grad_norm': 1.9292540550231934, 'learning_rate': 3.901098901098901e-05, 'epoch': 0.44}


 27%|██▋       | 50/182 [20:36<57:17, 26.04s/it]  

{'loss': 0.1259, 'grad_norm': 4.656289577484131, 'learning_rate': 3.6263736263736266e-05, 'epoch': 0.55}


 33%|███▎      | 60/182 [24:56<52:24, 25.78s/it]

{'loss': 0.1299, 'grad_norm': 2.448277473449707, 'learning_rate': 3.3516483516483513e-05, 'epoch': 0.66}


 38%|███▊      | 70/182 [29:15<48:01, 25.73s/it]

{'loss': 0.0727, 'grad_norm': 2.0427463054656982, 'learning_rate': 3.0769230769230774e-05, 'epoch': 0.77}


 44%|████▍     | 80/182 [33:35<43:37, 25.66s/it]

{'loss': 0.0606, 'grad_norm': 1.1988260746002197, 'learning_rate': 2.8021978021978025e-05, 'epoch': 0.88}


 49%|████▉     | 90/182 [37:49<38:39, 25.22s/it]

{'loss': 0.0752, 'grad_norm': 0.46147143840789795, 'learning_rate': 2.5274725274725276e-05, 'epoch': 0.99}


                                                
 50%|█████     | 91/182 [41:23<34:44, 22.91s/it]

{'eval_loss': 0.04235069453716278, 'eval_accuracy': 0.9972451790633609, 'eval_runtime': 195.7551, 'eval_samples_per_second': 1.854, 'eval_steps_per_second': 0.117, 'epoch': 1.0}


 55%|█████▍    | 100/182 [44:22<31:47, 23.26s/it] 

{'loss': 0.0265, 'grad_norm': 0.2493004947900772, 'learning_rate': 2.252747252747253e-05, 'epoch': 1.1}


 60%|██████    | 110/182 [47:40<23:44, 19.78s/it]

{'loss': 0.0227, 'grad_norm': 0.7332025170326233, 'learning_rate': 1.978021978021978e-05, 'epoch': 1.21}


 66%|██████▌   | 120/182 [50:58<20:22, 19.72s/it]

{'loss': 0.0257, 'grad_norm': 0.19210655987262726, 'learning_rate': 1.7032967032967035e-05, 'epoch': 1.32}


 71%|███████▏  | 130/182 [54:16<17:04, 19.70s/it]

{'loss': 0.02, 'grad_norm': 0.17589767277240753, 'learning_rate': 1.4285714285714285e-05, 'epoch': 1.43}


 77%|███████▋  | 140/182 [57:34<13:46, 19.67s/it]

{'loss': 0.0145, 'grad_norm': 0.1315748244524002, 'learning_rate': 1.153846153846154e-05, 'epoch': 1.54}


 82%|████████▏ | 150/182 [1:00:52<10:30, 19.71s/it]

{'loss': 0.0197, 'grad_norm': 0.7683666944503784, 'learning_rate': 8.791208791208792e-06, 'epoch': 1.65}


 88%|████████▊ | 160/182 [1:04:10<07:13, 19.72s/it]

{'loss': 0.0154, 'grad_norm': 0.23821432888507843, 'learning_rate': 6.043956043956044e-06, 'epoch': 1.76}


 93%|█████████▎| 170/182 [1:07:28<03:56, 19.74s/it]

{'loss': 0.0137, 'grad_norm': 0.1138196587562561, 'learning_rate': 3.2967032967032968e-06, 'epoch': 1.87}


 99%|█████████▉| 180/182 [1:10:46<00:39, 19.71s/it]

{'loss': 0.0152, 'grad_norm': 0.151548832654953, 'learning_rate': 5.494505494505495e-07, 'epoch': 1.98}


                                                   
100%|██████████| 182/182 [1:14:05<00:00, 24.43s/it]


{'eval_loss': 0.021328803151845932, 'eval_accuracy': 1.0, 'eval_runtime': 165.75, 'eval_samples_per_second': 2.19, 'eval_steps_per_second': 0.139, 'epoch': 2.0}
{'train_runtime': 4445.6179, 'train_samples_per_second': 0.653, 'train_steps_per_second': 0.041, 'train_loss': 0.0925239734351635, 'epoch': 2.0}


100%|██████████| 23/23 [02:36<00:00,  6.81s/it]


{'eval_loss': 0.021328803151845932, 'eval_accuracy': 1.0, 'eval_runtime': 165.4916, 'eval_samples_per_second': 2.193, 'eval_steps_per_second': 0.139, 'epoch': 2.0}
Training fold 3/5


  5%|▌         | 10/182 [03:19<56:38, 19.76s/it] 

{'loss': 0.0166, 'grad_norm': 0.8765193223953247, 'learning_rate': 4.7252747252747257e-05, 'epoch': 0.11}


 11%|█         | 20/182 [06:37<53:15, 19.73s/it]

{'loss': 0.0184, 'grad_norm': 0.3830690085887909, 'learning_rate': 4.4505494505494504e-05, 'epoch': 0.22}


 16%|█▋        | 30/182 [09:56<49:57, 19.72s/it]

{'loss': 0.0308, 'grad_norm': 0.21884013712406158, 'learning_rate': 4.1758241758241765e-05, 'epoch': 0.33}


 22%|██▏       | 40/182 [13:14<46:45, 19.75s/it]

{'loss': 0.0287, 'grad_norm': 0.1296769231557846, 'learning_rate': 3.901098901098901e-05, 'epoch': 0.44}


 27%|██▋       | 50/182 [16:32<43:24, 19.73s/it]

{'loss': 0.0169, 'grad_norm': 0.3122718334197998, 'learning_rate': 3.6263736263736266e-05, 'epoch': 0.55}


 33%|███▎      | 60/182 [19:50<40:01, 19.69s/it]

{'loss': 0.0154, 'grad_norm': 0.4938470423221588, 'learning_rate': 3.3516483516483513e-05, 'epoch': 0.66}


 38%|███▊      | 70/182 [23:09<36:55, 19.78s/it]

{'loss': 0.009, 'grad_norm': 0.4572494626045227, 'learning_rate': 3.0769230769230774e-05, 'epoch': 0.77}


 44%|████▍     | 80/182 [26:27<33:29, 19.70s/it]

{'loss': 0.0052, 'grad_norm': 0.045386239886283875, 'learning_rate': 2.8021978021978025e-05, 'epoch': 0.88}


 49%|████▉     | 90/182 [29:44<29:59, 19.56s/it]

{'loss': 0.0055, 'grad_norm': 0.17156347632408142, 'learning_rate': 2.5274725274725276e-05, 'epoch': 0.99}


 50%|█████     | 91/182 [29:57<26:46, 17.66s/it]
 50%|█████     | 91/182 [32:43<26:46, 17.66s/it]

{'eval_loss': 0.004679068922996521, 'eval_accuracy': 1.0, 'eval_runtime': 166.0163, 'eval_samples_per_second': 2.187, 'eval_steps_per_second': 0.139, 'epoch': 1.0}


 55%|█████▍    | 100/182 [35:43<30:51, 22.58s/it] 

{'loss': 0.0039, 'grad_norm': 0.0645103007555008, 'learning_rate': 2.252747252747253e-05, 'epoch': 1.1}


 60%|██████    | 110/182 [39:02<23:47, 19.82s/it]

{'loss': 0.0031, 'grad_norm': 0.04033539816737175, 'learning_rate': 1.978021978021978e-05, 'epoch': 1.21}


 66%|██████▌   | 120/182 [42:20<20:30, 19.84s/it]

{'loss': 0.0039, 'grad_norm': 0.028808321803808212, 'learning_rate': 1.7032967032967035e-05, 'epoch': 1.32}


 71%|███████▏  | 130/182 [45:39<17:13, 19.87s/it]

{'loss': 0.003, 'grad_norm': 0.026378637179732323, 'learning_rate': 1.4285714285714285e-05, 'epoch': 1.43}


 77%|███████▋  | 140/182 [48:58<13:49, 19.75s/it]

{'loss': 0.0033, 'grad_norm': 0.0350567027926445, 'learning_rate': 1.153846153846154e-05, 'epoch': 1.54}


 82%|████████▏ | 150/182 [52:16<10:32, 19.77s/it]

{'loss': 0.0065, 'grad_norm': 0.032217927277088165, 'learning_rate': 8.791208791208792e-06, 'epoch': 1.65}


 88%|████████▊ | 160/182 [55:34<07:13, 19.72s/it]

{'loss': 0.0028, 'grad_norm': 0.03513607755303383, 'learning_rate': 6.043956043956044e-06, 'epoch': 1.76}


 93%|█████████▎| 170/182 [58:52<03:56, 19.71s/it]

{'loss': 0.003, 'grad_norm': 0.026935026049613953, 'learning_rate': 3.2967032967032968e-06, 'epoch': 1.87}


 99%|█████████▉| 180/182 [1:02:11<00:39, 19.75s/it]

{'loss': 0.003, 'grad_norm': 0.07336150109767914, 'learning_rate': 5.494505494505495e-07, 'epoch': 1.98}


100%|██████████| 182/182 [1:02:44<00:00, 17.79s/it]
100%|██████████| 182/182 [1:05:28<00:00, 21.59s/it]


{'eval_loss': 0.003511958522722125, 'eval_accuracy': 1.0, 'eval_runtime': 164.7977, 'eval_samples_per_second': 2.203, 'eval_steps_per_second': 0.14, 'epoch': 2.0}
{'train_runtime': 3928.8924, 'train_samples_per_second': 0.739, 'train_steps_per_second': 0.046, 'train_loss': 0.00987853146168393, 'epoch': 2.0}


100%|██████████| 23/23 [02:35<00:00,  6.78s/it]


{'eval_loss': 0.003511958522722125, 'eval_accuracy': 1.0, 'eval_runtime': 164.7753, 'eval_samples_per_second': 2.203, 'eval_steps_per_second': 0.14, 'epoch': 2.0}
Training fold 4/5


  5%|▌         | 10/182 [03:19<56:41, 19.78s/it] 

{'loss': 0.003, 'grad_norm': 0.04462822154164314, 'learning_rate': 4.7252747252747257e-05, 'epoch': 0.11}


 11%|█         | 20/182 [06:38<53:26, 19.79s/it]

{'loss': 0.0022, 'grad_norm': 0.057535570114851, 'learning_rate': 4.4505494505494504e-05, 'epoch': 0.22}


 16%|█▋        | 30/182 [09:56<49:59, 19.73s/it]

{'loss': 0.0022, 'grad_norm': 0.032741498202085495, 'learning_rate': 4.1758241758241765e-05, 'epoch': 0.33}


 22%|██▏       | 40/182 [13:14<46:43, 19.74s/it]

{'loss': 0.0024, 'grad_norm': 0.052396323531866074, 'learning_rate': 3.901098901098901e-05, 'epoch': 0.44}


 27%|██▋       | 50/182 [16:32<43:27, 19.75s/it]

{'loss': 0.0017, 'grad_norm': 0.03404027968645096, 'learning_rate': 3.6263736263736266e-05, 'epoch': 0.55}


 33%|███▎      | 60/182 [19:51<40:11, 19.76s/it]

{'loss': 0.006, 'grad_norm': 0.028008855879306793, 'learning_rate': 3.3516483516483513e-05, 'epoch': 0.66}


 38%|███▊      | 70/182 [23:09<36:56, 19.79s/it]

{'loss': 0.0021, 'grad_norm': 0.026888012886047363, 'learning_rate': 3.0769230769230774e-05, 'epoch': 0.77}


 44%|████▍     | 80/182 [26:27<33:39, 19.80s/it]

{'loss': 0.0037, 'grad_norm': 0.01577819138765335, 'learning_rate': 2.8021978021978025e-05, 'epoch': 0.88}


 49%|████▉     | 90/182 [30:52<47:04, 30.71s/it]

{'loss': 0.0015, 'grad_norm': 0.01646994613111019, 'learning_rate': 2.5274725274725276e-05, 'epoch': 0.99}


 50%|█████     | 91/182 [31:18<44:25, 29.29s/it]
 50%|█████     | 91/182 [37:16<44:25, 29.29s/it]

{'eval_loss': 0.0013585083652287722, 'eval_accuracy': 1.0, 'eval_runtime': 358.7331, 'eval_samples_per_second': 1.012, 'eval_steps_per_second': 0.064, 'epoch': 1.0}


 55%|█████▍    | 100/182 [41:49<45:28, 33.28s/it]  

{'loss': 0.001, 'grad_norm': 0.014521874487400055, 'learning_rate': 2.252747252747253e-05, 'epoch': 1.1}


 60%|██████    | 110/182 [46:10<31:27, 26.22s/it]

{'loss': 0.0008, 'grad_norm': 0.017500251531600952, 'learning_rate': 1.978021978021978e-05, 'epoch': 1.21}


 66%|██████▌   | 120/182 [50:07<22:09, 21.44s/it]

{'loss': 0.0008, 'grad_norm': 0.02310907654464245, 'learning_rate': 1.7032967032967035e-05, 'epoch': 1.32}


 71%|███████▏  | 130/182 [53:28<17:17, 19.96s/it]

{'loss': 0.0008, 'grad_norm': 0.0071482411585748196, 'learning_rate': 1.4285714285714285e-05, 'epoch': 1.43}


 77%|███████▋  | 140/182 [56:48<14:00, 20.02s/it]

{'loss': 0.0009, 'grad_norm': 0.005525314249098301, 'learning_rate': 1.153846153846154e-05, 'epoch': 1.54}


 82%|████████▏ | 150/182 [1:00:08<10:37, 19.92s/it]

{'loss': 0.0006, 'grad_norm': 0.005409426521509886, 'learning_rate': 8.791208791208792e-06, 'epoch': 1.65}


 88%|████████▊ | 160/182 [1:03:27<07:16, 19.85s/it]

{'loss': 0.0006, 'grad_norm': 0.009046707302331924, 'learning_rate': 6.043956043956044e-06, 'epoch': 1.76}


 93%|█████████▎| 170/182 [1:06:48<03:59, 19.98s/it]

{'loss': 0.0006, 'grad_norm': 0.013576716184616089, 'learning_rate': 3.2967032967032968e-06, 'epoch': 1.87}


 99%|█████████▉| 180/182 [1:10:07<00:39, 19.89s/it]

{'loss': 0.0007, 'grad_norm': 0.004892315715551376, 'learning_rate': 5.494505494505495e-07, 'epoch': 1.98}


100%|██████████| 182/182 [1:10:40<00:00, 17.87s/it]
100%|██████████| 182/182 [1:13:26<00:00, 24.21s/it]


{'eval_loss': 0.0006699750083498657, 'eval_accuracy': 1.0, 'eval_runtime': 166.0033, 'eval_samples_per_second': 2.187, 'eval_steps_per_second': 0.139, 'epoch': 2.0}
{'train_runtime': 4406.8312, 'train_samples_per_second': 0.659, 'train_steps_per_second': 0.041, 'train_loss': 0.0017345624945157177, 'epoch': 2.0}


100%|██████████| 23/23 [02:37<00:00,  6.83s/it]


{'eval_loss': 0.0006699750083498657, 'eval_accuracy': 1.0, 'eval_runtime': 165.8944, 'eval_samples_per_second': 2.188, 'eval_steps_per_second': 0.139, 'epoch': 2.0}
Training fold 5/5


  5%|▌         | 10/182 [03:20<56:48, 19.82s/it] 

{'loss': 0.0005, 'grad_norm': 0.010252426378428936, 'learning_rate': 4.7252747252747257e-05, 'epoch': 0.11}


 11%|█         | 20/182 [07:12<1:07:23, 24.96s/it]

{'loss': 0.0004, 'grad_norm': 0.0031480155885219574, 'learning_rate': 4.4505494505494504e-05, 'epoch': 0.22}


 16%|█▋        | 30/182 [11:31<1:05:09, 25.72s/it]

{'loss': 0.0003, 'grad_norm': 0.002935146912932396, 'learning_rate': 4.1758241758241765e-05, 'epoch': 0.33}


 22%|██▏       | 40/182 [15:38<54:09, 22.88s/it]  

{'loss': 0.0002, 'grad_norm': 0.00370950810611248, 'learning_rate': 3.901098901098901e-05, 'epoch': 0.44}


 27%|██▋       | 50/182 [18:58<44:08, 20.06s/it]

{'loss': 0.0002, 'grad_norm': 0.0014416826888918877, 'learning_rate': 3.6263736263736266e-05, 'epoch': 0.55}


 33%|███▎      | 60/182 [22:17<40:18, 19.82s/it]

{'loss': 0.0002, 'grad_norm': 0.003719224128872156, 'learning_rate': 3.3516483516483513e-05, 'epoch': 0.66}


 38%|███▊      | 70/182 [25:36<37:01, 19.83s/it]

{'loss': 0.0001, 'grad_norm': 0.001170425210148096, 'learning_rate': 3.0769230769230774e-05, 'epoch': 0.77}


 44%|████▍     | 80/182 [28:55<33:47, 19.88s/it]

{'loss': 0.0001, 'grad_norm': 0.0011604935862123966, 'learning_rate': 2.8021978021978025e-05, 'epoch': 0.88}


 49%|████▉     | 90/182 [32:14<30:20, 19.79s/it]

{'loss': 0.0001, 'grad_norm': 0.0008853195467963815, 'learning_rate': 2.5274725274725276e-05, 'epoch': 0.99}


 50%|█████     | 91/182 [32:29<27:37, 18.21s/it]
 50%|█████     | 91/182 [35:14<27:37, 18.21s/it]

{'eval_loss': 9.694014443084598e-05, 'eval_accuracy': 1.0, 'eval_runtime': 165.4278, 'eval_samples_per_second': 2.188, 'eval_steps_per_second': 0.139, 'epoch': 1.0}


 55%|█████▍    | 100/182 [38:14<30:52, 22.59s/it] 

{'loss': 0.0001, 'grad_norm': 0.000707217724993825, 'learning_rate': 2.252747252747253e-05, 'epoch': 1.1}


 60%|██████    | 110/182 [41:33<23:52, 19.90s/it]

{'loss': 0.0001, 'grad_norm': 0.000756161636672914, 'learning_rate': 1.978021978021978e-05, 'epoch': 1.21}


 66%|██████▌   | 120/182 [44:51<20:25, 19.77s/it]

{'loss': 0.0001, 'grad_norm': 0.0005908960592932999, 'learning_rate': 1.7032967032967035e-05, 'epoch': 1.32}


 71%|███████▏  | 130/182 [48:10<17:09, 19.81s/it]

{'loss': 0.0001, 'grad_norm': 0.0008830704027786851, 'learning_rate': 1.4285714285714285e-05, 'epoch': 1.43}


 77%|███████▋  | 140/182 [51:28<13:50, 19.77s/it]

{'loss': 0.0001, 'grad_norm': 0.0006057813297957182, 'learning_rate': 1.153846153846154e-05, 'epoch': 1.54}


 82%|████████▏ | 150/182 [54:47<10:33, 19.80s/it]

{'loss': 0.0001, 'grad_norm': 0.0006142189959064126, 'learning_rate': 8.791208791208792e-06, 'epoch': 1.65}


 88%|████████▊ | 160/182 [58:05<07:14, 19.77s/it]

{'loss': 0.0001, 'grad_norm': 0.0005752090946771204, 'learning_rate': 6.043956043956044e-06, 'epoch': 1.76}


 93%|█████████▎| 170/182 [1:01:23<03:57, 19.77s/it]

{'loss': 0.0001, 'grad_norm': 0.0005014866474084556, 'learning_rate': 3.2967032967032968e-06, 'epoch': 1.87}


 99%|█████████▉| 180/182 [1:04:42<00:39, 19.84s/it]

{'loss': 0.0001, 'grad_norm': 0.0007713573868386447, 'learning_rate': 5.494505494505495e-07, 'epoch': 1.98}


100%|██████████| 182/182 [1:05:16<00:00, 18.16s/it]
100%|██████████| 182/182 [1:08:00<00:00, 22.42s/it]


{'eval_loss': 7.254808588186279e-05, 'eval_accuracy': 1.0, 'eval_runtime': 164.4935, 'eval_samples_per_second': 2.201, 'eval_steps_per_second': 0.14, 'epoch': 2.0}
{'train_runtime': 4080.9169, 'train_samples_per_second': 0.712, 'train_steps_per_second': 0.045, 'train_loss': 0.00015888688686063287, 'epoch': 2.0}


100%|██████████| 23/23 [02:35<00:00,  6.78s/it]


{'eval_loss': 7.254808588186279e-05, 'eval_accuracy': 1.0, 'eval_runtime': 164.7694, 'eval_samples_per_second': 2.197, 'eval_steps_per_second': 0.14, 'epoch': 2.0}
Evaluating final model on test set


100%|██████████| 29/29 [03:17<00:00,  6.82s/it]

{'eval_loss': 0.012209120206534863, 'eval_accuracy': 0.9977973568281938, 'eval_runtime': 206.7976, 'eval_samples_per_second': 2.195, 'eval_steps_per_second': 0.14, 'epoch': 2.0}





# Inference

Test with a single example from the "test" dataset.

In [20]:
# Pick an example from the test split
ex = dataset["test"][42]
expected_class_label = model.config.id2label[str(ex["label"])]

# Add a batch dimension and place the input on the same device as the model,
# so the forward pass also works when the model is not on the CPU.
ex_px_data = torch.tensor(ex["pixel_values"]).unsqueeze(0).to(model.device)

# Perform inference
model.eval()  # make evaluation mode explicit instead of relying on earlier cells
with torch.no_grad():  # Disable gradient computation during inference
    outputs = model(ex_px_data)
    logits = outputs.logits

# Convert logits to probabilities and get predicted class
probs = torch.softmax(logits, dim=-1)
predicted_class_idx = logits.argmax(-1).item()
predicted_class_label = model.config.id2label[str(predicted_class_idx)]

# Report the label the model actually predicted (not the expected one), so a
# wrong prediction is shown as such.
print(f"Predicted label '{predicted_class_label}' is correct: [{expected_class_label == predicted_class_label}]")

Predicted label '2047' is correct: [True]


In [27]:
# Softmax probability assigned to the predicted class, as a Python float.
probs[:,predicted_class_idx].item()

0.9999359846115112