In [1]:
import evaluate
import numpy as np
import torch
from datasets import load_from_disk
from transformers import (
    Trainer,
    TrainingArguments,
    ViTForImageClassification,
    ViTImageProcessor,
)
from transformers.image_utils import load_image

  from .autonotebook import tqdm as notebook_tqdm


# Load Dataset

In [2]:
# Processed DatasetDict on disk (presumably produced by an earlier preprocessing
# step); the class names of its "train" label feature define the label space.
processed_dataset_path = "../data/processed/huggingface"
dataset = load_from_disk(processed_dataset_path)
label_feature = dataset["train"].features["label"]
labels = label_feature.names

# Training

In [3]:
# `evaluate` accuracy metric, loaded once and reused by the Trainer hook below.
accuracy = evaluate.load("accuracy")

def compute_accuracy(p):
    """Metric hook for ``Trainer(compute_metrics=...)``.

    Takes the highest-scoring class in each row of ``p.predictions``
    (presumably classifier logits of shape (n_samples, n_labels) — TODO confirm)
    and scores it against ``p.label_ids`` with the ``accuracy`` metric.

    Returns the metric's result dict (``{"accuracy": ...}``; Trainer reports it
    as ``eval_accuracy``).
    """
    return accuracy.compute(
        references=p.label_ids,
        predictions=np.argmax(p.predictions, axis=1),
    )

In [4]:
# Load the pre-trained ViT checkpoint and replace its 1000-class classifier head
# with a freshly initialized one sized to our label set.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(labels),
    # Use integer ids in BOTH mappings: the config casts id2label keys to int but
    # stores label2id values as given, so string ids would leave them inconsistent.
    id2label={i: c for i, c in enumerate(labels)},
    label2id={c: i for i, c in enumerate(labels)},
    ignore_mismatched_sizes=True,  # checkpoint head shape differs; re-initialize it instead of failing
)

# Define training arguments
training_args = TrainingArguments(
    output_dir="../models/results",
    # NOTE(review): newer transformers releases rename this to `eval_strategy`
    # and eventually drop the old spelling — update when upgrading.
    evaluation_strategy="steps",
    eval_steps=100,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    save_strategy="epoch",
    save_total_limit=2,
    logging_dir="../models/logs",
    logging_steps=10,
)

# The "test" split doubles as the periodic eval set. No checkpoint selection is
# configured (load_best_model_at_end is not set), so it is used for monitoring only.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_accuracy,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([63]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([63, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Fine-tune the model, then print and persist the training metrics the same way
# the evaluation cell does, and save the final trainer state.
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
train_result

  3%|▎         | 10/342 [04:53<2:40:12, 28.95s/it]

{'loss': 4.1874, 'grad_norm': 8.067720413208008, 'learning_rate': 4.853801169590643e-05, 'epoch': 0.09}


  6%|▌         | 20/342 [09:42<2:33:25, 28.59s/it]

{'loss': 3.9065, 'grad_norm': 8.660638809204102, 'learning_rate': 4.707602339181287e-05, 'epoch': 0.18}


  9%|▉         | 30/342 [13:06<1:45:28, 20.28s/it]

{'loss': 3.6774, 'grad_norm': 9.284564018249512, 'learning_rate': 4.56140350877193e-05, 'epoch': 0.26}


 12%|█▏        | 40/342 [16:27<1:41:46, 20.22s/it]

{'loss': 3.3542, 'grad_norm': 8.14759635925293, 'learning_rate': 4.4152046783625734e-05, 'epoch': 0.35}


 15%|█▍        | 50/342 [19:46<1:36:30, 19.83s/it]

{'loss': 2.7119, 'grad_norm': 8.0725736618042, 'learning_rate': 4.269005847953216e-05, 'epoch': 0.44}


 18%|█▊        | 60/342 [23:08<1:34:27, 20.10s/it]

{'loss': 2.2976, 'grad_norm': 7.691367149353027, 'learning_rate': 4.12280701754386e-05, 'epoch': 0.53}


 20%|██        | 70/342 [26:28<1:29:55, 19.84s/it]

{'loss': 1.7508, 'grad_norm': 7.591429233551025, 'learning_rate': 3.976608187134503e-05, 'epoch': 0.61}


 23%|██▎       | 80/342 [29:48<1:26:43, 19.86s/it]

{'loss': 1.3877, 'grad_norm': 7.125931262969971, 'learning_rate': 3.8304093567251465e-05, 'epoch': 0.7}


 26%|██▋       | 90/342 [33:09<1:25:06, 20.26s/it]

{'loss': 1.079, 'grad_norm': 5.668850898742676, 'learning_rate': 3.6842105263157895e-05, 'epoch': 0.79}


 29%|██▉       | 100/342 [36:29<1:20:06, 19.86s/it]

{'loss': 0.7651, 'grad_norm': 4.3466997146606445, 'learning_rate': 3.538011695906433e-05, 'epoch': 0.88}


                                                   
 29%|██▉       | 100/342 [39:58<1:20:06, 19.86s/it]

{'eval_loss': 0.5862319469451904, 'eval_accuracy': 0.9515418502202643, 'eval_runtime': 208.9474, 'eval_samples_per_second': 2.173, 'eval_steps_per_second': 0.139, 'epoch': 0.88}


 32%|███▏      | 110/342 [43:19<1:26:30, 22.37s/it]

{'loss': 0.5687, 'grad_norm': 6.023655414581299, 'learning_rate': 3.391812865497076e-05, 'epoch': 0.96}


 35%|███▌      | 120/342 [46:30<1:13:00, 19.73s/it]

{'loss': 0.3875, 'grad_norm': 3.1926794052124023, 'learning_rate': 3.24561403508772e-05, 'epoch': 1.05}


 38%|███▊      | 130/342 [49:51<1:11:20, 20.19s/it]

{'loss': 0.2432, 'grad_norm': 3.69233775138855, 'learning_rate': 3.0994152046783626e-05, 'epoch': 1.14}


 41%|████      | 140/342 [53:09<1:06:36, 19.78s/it]

{'loss': 0.1933, 'grad_norm': 2.1777749061584473, 'learning_rate': 2.9532163742690062e-05, 'epoch': 1.23}


 44%|████▍     | 150/342 [56:29<1:03:39, 19.90s/it]

{'loss': 0.1738, 'grad_norm': 3.4778926372528076, 'learning_rate': 2.8070175438596492e-05, 'epoch': 1.32}


 47%|████▋     | 160/342 [59:49<1:00:00, 19.79s/it]

{'loss': 0.137, 'grad_norm': 4.059450149536133, 'learning_rate': 2.6608187134502928e-05, 'epoch': 1.4}


 50%|████▉     | 170/342 [1:03:09<57:03, 19.91s/it]  

{'loss': 0.0994, 'grad_norm': 3.3373701572418213, 'learning_rate': 2.5146198830409358e-05, 'epoch': 1.49}


 53%|█████▎    | 180/342 [1:06:29<54:13, 20.08s/it]

{'loss': 0.1114, 'grad_norm': 0.8837440609931946, 'learning_rate': 2.368421052631579e-05, 'epoch': 1.58}


 56%|█████▌    | 190/342 [1:09:48<50:05, 19.77s/it]

{'loss': 0.0892, 'grad_norm': 1.311502456665039, 'learning_rate': 2.2222222222222223e-05, 'epoch': 1.67}


 58%|█████▊    | 200/342 [1:13:08<47:05, 19.89s/it]

{'loss': 0.074, 'grad_norm': 1.9475340843200684, 'learning_rate': 2.0760233918128656e-05, 'epoch': 1.75}


                                                   
 58%|█████▊    | 200/342 [1:17:01<47:05, 19.89s/it]

{'eval_loss': 0.07098432630300522, 'eval_accuracy': 0.9911894273127754, 'eval_runtime': 233.4033, 'eval_samples_per_second': 1.945, 'eval_steps_per_second': 0.124, 'epoch': 1.75}


 61%|██████▏   | 210/342 [1:21:41<1:06:57, 30.43s/it]

{'loss': 0.0543, 'grad_norm': 0.793697714805603, 'learning_rate': 1.929824561403509e-05, 'epoch': 1.84}


 64%|██████▍   | 220/342 [1:26:15<55:30, 27.30s/it]  

{'loss': 0.0477, 'grad_norm': 0.5363067388534546, 'learning_rate': 1.7836257309941522e-05, 'epoch': 1.93}


 67%|██████▋   | 230/342 [1:30:32<46:59, 25.17s/it]

{'loss': 0.0562, 'grad_norm': 0.30392104387283325, 'learning_rate': 1.6374269005847955e-05, 'epoch': 2.02}


 70%|███████   | 240/342 [1:34:23<35:22, 20.81s/it]

{'loss': 0.027, 'grad_norm': 0.2660457193851471, 'learning_rate': 1.4912280701754386e-05, 'epoch': 2.11}


 73%|███████▎  | 250/342 [1:37:43<30:55, 20.16s/it]

{'loss': 0.0241, 'grad_norm': 0.244992196559906, 'learning_rate': 1.3450292397660819e-05, 'epoch': 2.19}


 76%|███████▌  | 260/342 [1:41:03<27:07, 19.85s/it]

{'loss': 0.0227, 'grad_norm': 0.18153800070285797, 'learning_rate': 1.1988304093567252e-05, 'epoch': 2.28}


 79%|███████▉  | 270/342 [1:44:23<24:16, 20.23s/it]

{'loss': 0.0224, 'grad_norm': 0.20052620768547058, 'learning_rate': 1.0526315789473684e-05, 'epoch': 2.37}


 82%|████████▏ | 280/342 [1:47:41<20:26, 19.79s/it]

{'loss': 0.0179, 'grad_norm': 0.22117994725704193, 'learning_rate': 9.064327485380117e-06, 'epoch': 2.46}


 85%|████████▍ | 290/342 [1:51:01<17:14, 19.88s/it]

{'loss': 0.0197, 'grad_norm': 0.2551204264163971, 'learning_rate': 7.602339181286549e-06, 'epoch': 2.54}


 88%|████████▊ | 300/342 [1:54:22<14:17, 20.42s/it]

{'loss': 0.018, 'grad_norm': 0.300246000289917, 'learning_rate': 6.140350877192982e-06, 'epoch': 2.63}


                                                   
 88%|████████▊ | 300/342 [1:58:39<14:17, 20.42s/it]

{'eval_loss': 0.027735428884625435, 'eval_accuracy': 1.0, 'eval_runtime': 257.0178, 'eval_samples_per_second': 1.766, 'eval_steps_per_second': 0.113, 'epoch': 2.63}


 91%|█████████ | 310/342 [2:03:30<18:03, 33.85s/it]  

{'loss': 0.0186, 'grad_norm': 0.192501038312912, 'learning_rate': 4.678362573099415e-06, 'epoch': 2.72}


 94%|█████████▎| 320/342 [2:08:15<10:24, 28.37s/it]

{'loss': 0.019, 'grad_norm': 0.25078198313713074, 'learning_rate': 3.216374269005848e-06, 'epoch': 2.81}


 96%|█████████▋| 330/342 [2:12:53<05:19, 26.63s/it]

{'loss': 0.0193, 'grad_norm': 0.19568824768066406, 'learning_rate': 1.7543859649122807e-06, 'epoch': 2.89}


 99%|█████████▉| 340/342 [2:16:14<00:40, 20.14s/it]

{'loss': 0.0166, 'grad_norm': 0.20978698134422302, 'learning_rate': 2.9239766081871344e-07, 'epoch': 2.98}


100%|██████████| 342/342 [2:16:45<00:00, 23.99s/it]

{'train_runtime': 8205.3428, 'train_samples_per_second': 0.663, 'train_steps_per_second': 0.042, 'train_loss': 0.8064742495975735, 'epoch': 3.0}





TrainOutput(global_step=342, training_loss=0.8064742495975735, metrics={'train_runtime': 8205.3428, 'train_samples_per_second': 0.663, 'train_steps_per_second': 0.042, 'total_flos': 4.219419671059784e+17, 'train_loss': 0.8064742495975735, 'epoch': 3.0})

In [6]:
# Score the trained model on the Trainer's eval_dataset (the "test" split), then
# print the metrics table and save the metrics with the Trainer helpers.
eval_results = trainer.evaluate()
for report in (trainer.log_metrics, trainer.save_metrics):
    report("eval", eval_results)

100%|██████████| 29/29 [03:18<00:00,  6.86s/it]

***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =        1.0
  eval_loss               =      0.026
  eval_runtime            = 0:03:27.34
  eval_samples_per_second =       2.19
  eval_steps_per_second   =       0.14





In [7]:
# Write the fine-tuned model to disk so the inference section can reload it.
final_model_dir = "../models/results/final_model"
model.save_pretrained(final_model_dir)

# Inference

In [None]:
# Reload the fine-tuned model from the directory `model.save_pretrained` wrote to
# in the training section (the previous "./final_model" path did not match it).
model = ViTForImageClassification.from_pretrained("../models/results/final_model")
model.eval();  # switch to evaluation mode; ";" suppresses the long module repr

In [None]:
# Turn one known image into a batched pixel tensor for the model.
# NOTE(review): this cell used to call `load_image(..., t)`, but neither name is
# defined in this notebook, so it failed on a fresh kernel. The base checkpoint's
# image processor is used instead — confirm it matches the transform that built
# ../data/processed/huggingface (size, normalization, RGBA handling); otherwise
# inference sees differently preprocessed input than training did.
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# A training image: a correct prediction here only smoke-tests the inference
# path; it is not evidence of generalization.
example_path = "../data/processed/selected/480px-ISO_7000_-_Ref-No_0082.svg.png"
expected_class = "0082"
expected_class_idx = labels.index(expected_class)
example_image = load_image(example_path)  # transformers helper: opens the file as an RGB PIL image
# The processor already returns a batch dimension, so no unsqueeze is needed.
example = image_processor(images=example_image, return_tensors="pt")["pixel_values"]

In [None]:
# Sanity-check the model input instead of dumping every pixel value.
# Presumably (1, 3, 224, 224) float32 for this 224px checkpoint — TODO confirm.
example.shape, example.dtype

In [None]:
# Forward pass without autograd, then convert logits to class probabilities.
with torch.no_grad():
    # Put the input on the model's device: the Trainer may have moved the model
    # to a GPU, while `example` is built on the CPU.
    outputs = model(pixel_values=example.to(model.device))
    logits = outputs.logits

probs = torch.softmax(logits, dim=-1)  # one row per input image
predicted_class_idx = torch.argmax(probs, dim=-1).item()

In [None]:
# Index of the most probable class for `example` (compare with `expected_class_idx`).
predicted_class_idx

In [None]:
# Ground-truth class index for `example`, looked up from `expected_class` in `labels`.
expected_class_idx

In [None]:
# Show the most likely classes by name instead of the full probability vector.
# probs has one row per input image; `example` holds a single image, so use row 0.
top_probs, top_ids = torch.topk(probs[0], k=min(5, probs.shape[-1]))
[(labels[i], round(p, 4)) for i, p in zip(top_ids.tolist(), top_probs.tolist())]