In [4]:
import torch
from datasets import load_from_disk
from transformers import Trainer, TrainingArguments, ViTForImageClassification

# Load Dataset

In [5]:
dataset = load_from_disk("../data/processed/huggingface")

# Training

In [20]:
# Load the pre-trained model
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(class_labels.names),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)},
    ignore_mismatched_sizes=True,  # Ignore the classifier's size mismatch
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([63]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([63, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="no",
    per_device_train_batch_size=16,
    num_train_epochs=3,
    save_strategy="epoch",
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=10,
)

In [22]:
# Create Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

In [23]:
# Train the model
trainer.train()


  2%|▏         | 10/426 [13:30<9:21:49, 81.03s/it]
                                                  
  2%|▏         | 10/426 [04:49<3:13:08, 27.86s/it]

{'loss': 4.2508, 'grad_norm': 7.926233291625977, 'learning_rate': 4.882629107981221e-05, 'epoch': 0.07}


                                                  
  5%|▍         | 20/426 [09:33<3:11:40, 28.33s/it]

{'loss': 4.2231, 'grad_norm': 6.351879596710205, 'learning_rate': 4.765258215962441e-05, 'epoch': 0.14}


                                                  
  7%|▋         | 30/426 [12:55<2:13:24, 20.21s/it]

{'loss': 3.959, 'grad_norm': 7.181564807891846, 'learning_rate': 4.647887323943662e-05, 'epoch': 0.21}


                                                  
  9%|▉         | 40/426 [16:18<2:10:00, 20.21s/it]

{'loss': 3.7604, 'grad_norm': 7.8429718017578125, 'learning_rate': 4.530516431924883e-05, 'epoch': 0.28}


                                                  
 12%|█▏        | 50/426 [19:40<2:06:32, 20.19s/it]

{'loss': 3.1608, 'grad_norm': 10.527081489562988, 'learning_rate': 4.413145539906103e-05, 'epoch': 0.35}


                                                  
 14%|█▍        | 60/426 [23:04<2:02:33, 20.09s/it]

{'loss': 2.7857, 'grad_norm': 8.003210067749023, 'learning_rate': 4.295774647887324e-05, 'epoch': 0.42}


                                                  
 16%|█▋        | 70/426 [26:27<2:00:40, 20.34s/it]

{'loss': 2.1989, 'grad_norm': 7.3468194007873535, 'learning_rate': 4.178403755868545e-05, 'epoch': 0.49}


                                                  
 19%|█▉        | 80/426 [29:48<1:55:06, 19.96s/it]

{'loss': 1.8665, 'grad_norm': 7.855565547943115, 'learning_rate': 4.0610328638497654e-05, 'epoch': 0.56}


                                                  
 21%|██        | 90/426 [33:11<1:52:37, 20.11s/it]

{'loss': 1.4652, 'grad_norm': 6.727529525756836, 'learning_rate': 3.943661971830986e-05, 'epoch': 0.63}


                                                   
 23%|██▎       | 100/426 [36:33<1:49:40, 20.19s/it]

{'loss': 0.9334, 'grad_norm': 5.915176868438721, 'learning_rate': 3.826291079812207e-05, 'epoch': 0.7}


                                                   
 26%|██▌       | 110/426 [39:55<1:45:07, 19.96s/it]

{'loss': 0.7788, 'grad_norm': 6.130402565002441, 'learning_rate': 3.7089201877934274e-05, 'epoch': 0.77}


                                                   
 28%|██▊       | 120/426 [43:18<1:43:54, 20.37s/it]

{'loss': 0.59, 'grad_norm': 3.9398374557495117, 'learning_rate': 3.5915492957746486e-05, 'epoch': 0.85}


                                                   
 31%|███       | 130/426 [46:39<1:38:04, 19.88s/it]

{'loss': 0.3201, 'grad_norm': 2.114210367202759, 'learning_rate': 3.474178403755869e-05, 'epoch': 0.92}


                                                   
 33%|███▎      | 140/426 [50:02<1:35:51, 20.11s/it]

{'loss': 0.3384, 'grad_norm': 3.2942051887512207, 'learning_rate': 3.3568075117370895e-05, 'epoch': 0.99}


                                                   
 35%|███▌      | 150/426 [53:21<1:32:27, 20.10s/it]

{'loss': 0.2898, 'grad_norm': 1.15336275100708, 'learning_rate': 3.23943661971831e-05, 'epoch': 1.06}


                                                   
 38%|███▊      | 160/426 [56:41<1:27:42, 19.78s/it] 

{'loss': 0.1475, 'grad_norm': 1.042989730834961, 'learning_rate': 3.1220657276995305e-05, 'epoch': 1.13}


                                                     
 40%|███▉      | 170/426 [1:00:03<1:26:30, 20.28s/it]

{'loss': 0.117, 'grad_norm': 1.9178990125656128, 'learning_rate': 3.0046948356807513e-05, 'epoch': 1.2}


                                                     
 42%|████▏     | 180/426 [1:03:21<1:20:28, 19.63s/it]

{'loss': 0.1614, 'grad_norm': 1.3865584135055542, 'learning_rate': 2.887323943661972e-05, 'epoch': 1.27}


                                                     
 45%|████▍     | 190/426 [1:06:42<1:18:11, 19.88s/it]

{'loss': 0.0804, 'grad_norm': 0.6282534599304199, 'learning_rate': 2.7699530516431926e-05, 'epoch': 1.34}


                                                     
 47%|████▋     | 200/426 [1:10:01<1:14:57, 19.90s/it]

{'loss': 0.0793, 'grad_norm': 2.2629177570343018, 'learning_rate': 2.6525821596244134e-05, 'epoch': 1.41}


                                                     
 49%|████▉     | 210/426 [1:13:21<1:11:06, 19.75s/it]

{'loss': 0.0923, 'grad_norm': 1.9020198583602905, 'learning_rate': 2.535211267605634e-05, 'epoch': 1.48}


                                                     
 52%|█████▏    | 220/426 [1:16:42<1:09:43, 20.31s/it]

{'loss': 0.0595, 'grad_norm': 0.85466468334198, 'learning_rate': 2.4178403755868547e-05, 'epoch': 1.55}


                                                     
 54%|█████▍    | 230/426 [1:20:00<1:04:17, 19.68s/it]

{'loss': 0.0475, 'grad_norm': 0.3305336534976959, 'learning_rate': 2.300469483568075e-05, 'epoch': 1.62}


                                                     
 56%|█████▋    | 240/426 [1:23:21<1:02:11, 20.06s/it]

{'loss': 0.0552, 'grad_norm': 1.686480164527893, 'learning_rate': 2.1830985915492956e-05, 'epoch': 1.69}


                                                     
 59%|█████▊    | 250/426 [1:26:41<58:37, 19.98s/it] 

{'loss': 0.0427, 'grad_norm': 0.28902679681777954, 'learning_rate': 2.0657276995305167e-05, 'epoch': 1.76}


                                                     
 61%|██████    | 260/426 [1:30:02<55:09, 19.93s/it] 

{'loss': 0.0373, 'grad_norm': 0.7185868620872498, 'learning_rate': 1.9483568075117372e-05, 'epoch': 1.83}


                                                   
 63%|██████▎   | 270/426 [1:33:25<53:28, 20.56s/it] 

{'loss': 0.0334, 'grad_norm': 0.23054943978786469, 'learning_rate': 1.830985915492958e-05, 'epoch': 1.9}


                                                   
 66%|██████▌   | 280/426 [1:36:45<47:57, 19.71s/it] 

{'loss': 0.0214, 'grad_norm': 0.20681364834308624, 'learning_rate': 1.7136150234741785e-05, 'epoch': 1.97}


                                                   
 68%|██████▊   | 290/426 [1:40:05<45:42, 20.16s/it] 

{'loss': 0.0179, 'grad_norm': 0.29753947257995605, 'learning_rate': 1.5962441314553993e-05, 'epoch': 2.04}


                                                   
 70%|███████   | 300/426 [1:43:25<41:51, 19.93s/it] 

{'loss': 0.0199, 'grad_norm': 0.18191979825496674, 'learning_rate': 1.4788732394366198e-05, 'epoch': 2.11}


                                                   
 73%|███████▎  | 310/426 [1:46:47<38:28, 19.90s/it] 

{'loss': 0.0144, 'grad_norm': 0.15219365060329437, 'learning_rate': 1.3615023474178404e-05, 'epoch': 2.18}


                                                   
 75%|███████▌  | 320/426 [1:50:09<36:03, 20.41s/it] 

{'loss': 0.014, 'grad_norm': 0.14550182223320007, 'learning_rate': 1.2441314553990612e-05, 'epoch': 2.25}


                                                   
 77%|███████▋  | 330/426 [1:53:29<31:40, 19.79s/it] 

{'loss': 0.0127, 'grad_norm': 0.1351010799407959, 'learning_rate': 1.1267605633802817e-05, 'epoch': 2.32}


                                                   
 80%|███████▉  | 340/426 [1:56:51<29:02, 20.26s/it] 

{'loss': 0.0142, 'grad_norm': 0.09526295214891434, 'learning_rate': 1.0093896713615023e-05, 'epoch': 2.39}


                                                   
 82%|████████▏ | 350/426 [2:00:13<25:38, 20.24s/it] 

{'loss': 0.0115, 'grad_norm': 0.12635745108127594, 'learning_rate': 8.92018779342723e-06, 'epoch': 2.46}


                                                   
 85%|████████▍ | 360/426 [2:03:35<22:02, 20.04s/it] 

{'loss': 0.015, 'grad_norm': 0.12200980633497238, 'learning_rate': 7.746478873239436e-06, 'epoch': 2.54}


                                                   
 87%|████████▋ | 370/426 [2:06:58<19:07, 20.48s/it] 

{'loss': 0.013, 'grad_norm': 0.09888613969087601, 'learning_rate': 6.572769953051644e-06, 'epoch': 2.61}


                                                   
 89%|████████▉ | 380/426 [2:10:19<15:15, 19.89s/it] 

{'loss': 0.0121, 'grad_norm': 0.1133284643292427, 'learning_rate': 5.3990610328638506e-06, 'epoch': 2.68}


                                                   
 92%|█████████▏| 390/426 [2:13:42<12:03, 20.11s/it] 

{'loss': 0.0122, 'grad_norm': 0.16865159571170807, 'learning_rate': 4.225352112676056e-06, 'epoch': 2.75}


                                                   
 94%|█████████▍| 400/426 [2:17:04<08:48, 20.32s/it] 

{'loss': 0.0138, 'grad_norm': 0.13869592547416687, 'learning_rate': 3.051643192488263e-06, 'epoch': 2.82}


                                                   
 96%|█████████▌| 410/426 [2:20:26<05:20, 20.05s/it] 

{'loss': 0.0114, 'grad_norm': 0.09893126040697098, 'learning_rate': 1.8779342723004696e-06, 'epoch': 2.89}


                                                   
 99%|█████████▊| 420/426 [2:23:49<02:00, 20.15s/it] 

{'loss': 0.0109, 'grad_norm': 0.1360124945640564, 'learning_rate': 7.042253521126761e-07, 'epoch': 2.96}


                                                   
100%|██████████| 426/426 [2:25:49<00:00, 20.54s/it] 

{'train_runtime': 8748.9982, 'train_samples_per_second': 0.778, 'train_steps_per_second': 0.049, 'train_loss': 0.7534212004305891, 'epoch': 3.0}





TrainOutput(global_step=426, training_loss=0.7534212004305891, metrics={'train_runtime': 8748.9982, 'train_samples_per_second': 0.778, 'train_steps_per_second': 0.049, 'total_flos': 5.275437604169564e+17, 'train_loss': 0.7534212004305891, 'epoch': 3.0})

In [24]:
# Save the final trained model
model.save_pretrained("./final_model")

# Inference

In [25]:
model = ViTForImageClassification.from_pretrained("./final_model")
model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [34]:
# Pick a training image to test with
example_path = "../data/processed/selected/480px-ISO_7000_-_Ref-No_0082.svg.png"
expected_class = "0082"
expected_class_idx = labels.index(expected_class)
example = load_image({"image": example_path}, t)
example = example["pixel_values"].unsqueeze(0)  # Add batch dimension

In [35]:
example

tensor([[[[0.6000, 0.6000, 0.6000,  ..., 0.6000, 0.6000, 0.6000],
          [0.6000, 0.5804, 0.4941,  ..., 0.4941, 0.5804, 0.6000],
          [0.6000, 0.4941, 0.0745,  ..., 0.0745, 0.4941, 0.6000],
          ...,
          [0.6000, 0.4941, 0.0745,  ..., 0.0745, 0.4941, 0.6000],
          [0.6000, 0.5804, 0.4941,  ..., 0.4941, 0.5804, 0.6000],
          [0.6000, 0.6000, 0.6000,  ..., 0.6000, 0.6000, 0.6000]],

         [[0.6000, 0.6000, 0.6000,  ..., 0.6000, 0.6000, 0.6000],
          [0.6000, 0.5804, 0.4941,  ..., 0.4941, 0.5804, 0.6000],
          [0.6000, 0.4941, 0.0745,  ..., 0.0745, 0.4941, 0.6000],
          ...,
          [0.6000, 0.4941, 0.0745,  ..., 0.0745, 0.4941, 0.6000],
          [0.6000, 0.5804, 0.4941,  ..., 0.4941, 0.5804, 0.6000],
          [0.6000, 0.6000, 0.6000,  ..., 0.6000, 0.6000, 0.6000]],

         [[0.6000, 0.6000, 0.6000,  ..., 0.6000, 0.6000, 0.6000],
          [0.6000, 0.5804, 0.4941,  ..., 0.4941, 0.5804, 0.6000],
          [0.6000, 0.4941, 0.0745,  ..., 0

In [36]:
# Perform inference
with torch.no_grad():  # Disable gradient computation during inference
    outputs = model(example)
    logits = outputs.logits

# Convert logits to probabilities and get predicted class
probs = torch.softmax(logits, dim=-1)
predicted_class_idx = torch.argmax(probs, dim=-1).item()

In [37]:
predicted_class_idx

15

In [38]:
expected_class_idx

8

In [39]:
probs

tensor([[0.0098, 0.0087, 0.0029, 0.0106, 0.0069, 0.0082, 0.0210, 0.0116, 0.0117,
         0.0145, 0.0062, 0.0353, 0.0117, 0.0061, 0.0284, 0.0615, 0.0208, 0.0199,
         0.0026, 0.0400, 0.0224, 0.0056, 0.0228, 0.0242, 0.0230, 0.0052, 0.0086,
         0.0490, 0.0123, 0.0175, 0.0492, 0.0055, 0.0287, 0.0201, 0.0113, 0.0074,
         0.0220, 0.0063, 0.0053, 0.0163, 0.0062, 0.0104, 0.0224, 0.0052, 0.0215,
         0.0045, 0.0159, 0.0072, 0.0049, 0.0095, 0.0263, 0.0295, 0.0094, 0.0186,
         0.0054, 0.0172, 0.0148, 0.0187, 0.0069, 0.0047, 0.0067, 0.0276, 0.0053]])