In [1]:
from transformers import ViTForImageClassification, ViTFeatureExtractor, TrainingArguments, Trainer
from datasets import load_dataset, Dataset
import torch
from torchvision import datasets, transforms
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [41]:
train_dataset_path = "C:\MSAAI\AAI-521\Final Project Data\Transformer Data\Training Data"
train_metadata_csv = "C:\MSAAI\AAI-521\Final Project Data\Transformer Data\metadata_train.csv"

val_dataset_path = "C:\MSAAI\AAI-521\Final Project Data\Transformer Data\Validation Data"
val_metadata_csv = "C:\MSAAI\AAI-521\Final Project Data\Transformer Data\metadata_val.csv"

In [42]:
train_metadata = pd.read_csv("C:\MSAAI\AAI-521\Final Project Data\Transformer Data\metadata_train.csv")
val_metadata = pd.read_csv("C:\MSAAI\AAI-521\Final Project Data\Transformer Data\metadata_val.csv")

In [43]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # Normalize between -1 and 1
])

In [44]:
train_dataset = datasets.ImageFolder(root=train_dataset_path, transform=transform)
val_dataset = datasets.ImageFolder(root=val_dataset_path, transform=transform)

In [None]:
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(train_dataset.classes),
    ignore_mismatched_sizes=True
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([7, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [46]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [47]:
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [48]:
from PIL import Image

# Load the feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

# Function to preprocess datasets
def preprocess_dataset(dataset):
    image_paths = [sample[0] for sample in dataset.samples]
    labels = [sample[1] for sample in dataset.samples]

    def preprocess(images, labels):
        # Load images and preprocess
        processed_images = [Image.open(image_path).convert("RGB") for image_path in images]
        pixel_values = feature_extractor(processed_images, return_tensors="pt")["pixel_values"]
        return {"pixel_values": pixel_values, "labels": torch.tensor(labels)}

    dataset = Dataset.from_dict({"image_path": image_paths, "label": labels})
    return dataset.map(lambda x: preprocess(x["image_path"], x["label"]), batched=True)

# Preprocess training and validation datasets
train_hf_dataset = preprocess_dataset(train_dataset)
val_hf_dataset = preprocess_dataset(val_dataset)




[A[A[A


[A[A[A


[A[A[A


 18%|█▊        | 500/2770 [29:44<2:15:00,  3.57s/it]



[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


Map: 100%|██████████| 8864/8864 [00:34<00:00, 256.31 examples/s]
Map: 100%|██████████| 1094/1094 [00:03<00:00, 284.34 examples/s]


In [None]:
from transformers import TrainingArguments, Trainer

# Define training arguments
training_args = TrainingArguments(
    output_dir="./vit-finetune-results",
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    load_best_model_at_end=True,
    save_total_limit=2,
    fp16=True,  
    report_to="none",  
)

# Function to compute metrics, including per-class accuracy
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    preds = np.argmax(predictions, axis=1)
    report = classification_report(labels, preds, target_names=train_dataset.classes, output_dict=True)
    
    # Print per-class accuracy
    print("\nPer-Class Accuracy:")
    for class_name, metrics in report.items():
        if isinstance(metrics, dict):
            print(f"{class_name}: {metrics['precision'] * 100:.2f}%")

    # Return overall metrics for Trainer
    return {
        "accuracy": report["accuracy"],
        "precision": np.mean([metrics["precision"] for metrics in report.values() if isinstance(metrics, dict)]),
        "recall": np.mean([metrics["recall"] for metrics in report.values() if isinstance(metrics, dict)]),
        "f1": np.mean([metrics["f1-score"] for metrics in report.values() if isinstance(metrics, dict)]),
    }

# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_hf_dataset,
    eval_dataset=val_hf_dataset,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [50]:
# Fine-tune the model
trainer.train()

  4%|▎         | 100/2770 [01:19<34:01,  1.31it/s]

[A[A                                            


                                                  
[A

  4%|▎         | 100/2770 [01:19<34:01,  1.31it/s]

{'loss': 0.8743, 'grad_norm': 8.041891098022461, 'learning_rate': 4.82129963898917e-05, 'epoch': 0.18}


  7%|▋         | 200/2770 [02:40<34:06,  1.26it/s]

[A[A                                            


                                                  
[A

  7%|▋         | 200/2770 [02:40<34:06,  1.26it/s]

{'loss': 0.6424, 'grad_norm': 8.900994300842285, 'learning_rate': 4.640794223826715e-05, 'epoch': 0.36}


 11%|█         | 300/2770 [03:59<32:29,  1.27it/s]

[A[A                                            


                                                  
[A

 11%|█         | 300/2770 [03:59<32:29,  1.27it/s]

{'loss': 0.5732, 'grad_norm': 7.524267196655273, 'learning_rate': 4.46028880866426e-05, 'epoch': 0.54}


 14%|█▍        | 400/2770 [05:18<31:25,  1.26it/s]

[A[A                                            


                                                  
[A

 14%|█▍        | 400/2770 [05:18<31:25,  1.26it/s]

{'loss': 0.5017, 'grad_norm': 6.2072248458862305, 'learning_rate': 4.279783393501805e-05, 'epoch': 0.72}


 18%|█▊        | 500/2770 [06:38<29:38,  1.28it/s]

[A[A                                            


                                                  
[A

 18%|█▊        | 500/2770 [06:38<29:38,  1.28it/s]

{'loss': 0.4966, 'grad_norm': 9.182103157043457, 'learning_rate': 4.0992779783393506e-05, 'epoch': 0.9}



[A                                            

[A[A                                            


                                                  
[A

 18%|█▊        | 500/2770 [07:31<29:38,  1.28it/s]


Per-Class Accuracy:
akiec: 100.00%
bcc: 76.92%
bkl: 63.16%
df: 71.43%
mel: 41.18%
nv: 92.83%
vasc: 92.31%
macro avg: 76.83%
weighted avg: 87.85%
{'eval_loss': 0.31588634848594666, 'eval_accuracy': 0.8893967093235832, 'eval_precision': 0.7805550262086154, 'eval_recall': 0.6100277639207221, 'eval_f1': 0.6316084738077187, 'eval_runtime': 52.9763, 'eval_samples_per_second': 20.651, 'eval_steps_per_second': 1.302, 'epoch': 0.9}


 22%|██▏       | 600/2770 [08:50<29:10,  1.24it/s]   

[A[A                                            


                                                  
[A

 22%|██▏       | 600/2770 [08:51<29:10,  1.24it/s]

{'loss': 0.3758, 'grad_norm': 5.233223915100098, 'learning_rate': 3.9187725631768956e-05, 'epoch': 1.08}


 25%|██▌       | 700/2770 [10:10<26:45,  1.29it/s]

[A[A                                            


                                                  
[A

 25%|██▌       | 700/2770 [10:10<26:45,  1.29it/s]

{'loss': 0.3071, 'grad_norm': 1.5983847379684448, 'learning_rate': 3.7382671480144405e-05, 'epoch': 1.26}


 29%|██▉       | 800/2770 [11:30<25:08,  1.31it/s]

[A[A                                            


                                                  
[A

 29%|██▉       | 800/2770 [11:30<25:08,  1.31it/s]

{'loss': 0.2517, 'grad_norm': 6.933020114898682, 'learning_rate': 3.5577617328519854e-05, 'epoch': 1.44}


 32%|███▏      | 900/2770 [12:49<24:32,  1.27it/s]

[A[A                                            


                                                  
[A

 32%|███▏      | 900/2770 [12:49<24:32,  1.27it/s]

{'loss': 0.2868, 'grad_norm': 1.7802473306655884, 'learning_rate': 3.377256317689531e-05, 'epoch': 1.62}


 36%|███▌      | 1000/2770 [14:09<24:39,  1.20it/s]

[A[A                                            


                                                   
[A

 36%|███▌      | 1000/2770 [14:10<24:39,  1.20it/s]

{'loss': 0.2503, 'grad_norm': 7.261033058166504, 'learning_rate': 3.196750902527076e-05, 'epoch': 1.81}



[A                                            

[A[A                                            


                                                   
[A

 36%|███▌      | 1000/2770 [15:03<24:39,  1.20it/s]


Per-Class Accuracy:
akiec: 73.91%
bcc: 76.47%
bkl: 90.16%
df: 83.33%
mel: 56.52%
nv: 95.38%
vasc: 85.71%
macro avg: 80.21%
weighted avg: 91.94%
{'eval_loss': 0.2495461404323578, 'eval_accuracy': 0.9223034734917733, 'eval_precision': 0.8151769138234054, 'eval_recall': 0.7441569684876175, 'eval_f1': 0.7730059705785233, 'eval_runtime': 53.8845, 'eval_samples_per_second': 20.303, 'eval_steps_per_second': 1.281, 'epoch': 1.81}


 40%|███▉      | 1100/2770 [16:24<21:37,  1.29it/s]  

[A[A                                            


                                                   
[A

 40%|███▉      | 1100/2770 [16:24<21:37,  1.29it/s]

{'loss': 0.2562, 'grad_norm': 1.4506551027297974, 'learning_rate': 3.0162454873646213e-05, 'epoch': 1.99}


 43%|████▎     | 1200/2770 [17:43<20:53,  1.25it/s]

[A[A                                            


                                                   
[A

 43%|████▎     | 1200/2770 [17:43<20:53,  1.25it/s]

{'loss': 0.094, 'grad_norm': 1.1170485019683838, 'learning_rate': 2.835740072202166e-05, 'epoch': 2.17}


 47%|████▋     | 1300/2770 [19:04<19:52,  1.23it/s]

[A[A                                            


                                                   
[A

 47%|████▋     | 1300/2770 [19:04<19:52,  1.23it/s]

{'loss': 0.0758, 'grad_norm': 4.185434818267822, 'learning_rate': 2.6552346570397112e-05, 'epoch': 2.35}


 51%|█████     | 1400/2770 [20:27<18:42,  1.22it/s]

[A[A                                            


                                                   
[A

 51%|█████     | 1400/2770 [20:27<18:42,  1.22it/s]

{'loss': 0.0672, 'grad_norm': 2.9656546115875244, 'learning_rate': 2.4747292418772565e-05, 'epoch': 2.53}


 54%|█████▍    | 1500/2770 [21:50<18:18,  1.16it/s]

[A[A                                            


                                                   
[A

 54%|█████▍    | 1500/2770 [21:50<18:18,  1.16it/s]

{'loss': 0.0816, 'grad_norm': 2.1244020462036133, 'learning_rate': 2.2942238267148018e-05, 'epoch': 2.71}



[A                                            

[A[A                                            


                                                   
[A

 54%|█████▍    | 1500/2770 [22:46<18:18,  1.16it/s]


Per-Class Accuracy:
akiec: 71.43%
bcc: 78.38%
bkl: 81.01%
df: 46.15%
mel: 52.63%
nv: 96.72%
vasc: 75.00%
macro avg: 71.62%
weighted avg: 91.72%
{'eval_loss': 0.28892436623573303, 'eval_accuracy': 0.9186471663619744, 'eval_precision': 0.7385086993355778, 'eval_recall': 0.7779029573127472, 'eval_f1': 0.7517136254854577, 'eval_runtime': 56.1628, 'eval_samples_per_second': 19.479, 'eval_steps_per_second': 1.229, 'epoch': 2.71}


 58%|█████▊    | 1600/2770 [24:08<16:06,  1.21it/s]  

[A[A                                            


                                                   
[A

 58%|█████▊    | 1600/2770 [24:08<16:06,  1.21it/s]

{'loss': 0.0876, 'grad_norm': 2.789628744125366, 'learning_rate': 2.1137184115523467e-05, 'epoch': 2.89}


 61%|██████▏   | 1700/2770 [25:28<14:00,  1.27it/s]

[A[A                                            


                                                   
[A

 61%|██████▏   | 1700/2770 [25:28<14:00,  1.27it/s]

{'loss': 0.0569, 'grad_norm': 0.7724908590316772, 'learning_rate': 1.9332129963898917e-05, 'epoch': 3.07}


 65%|██████▍   | 1800/2770 [26:48<12:56,  1.25it/s]

[A[A                                            


                                                   
[A

 65%|██████▍   | 1800/2770 [26:48<12:56,  1.25it/s]

{'loss': 0.01, 'grad_norm': 0.38071224093437195, 'learning_rate': 1.752707581227437e-05, 'epoch': 3.25}


 69%|██████▊   | 1900/2770 [28:09<11:31,  1.26it/s]

[A[A                                            


                                                   
[A

 69%|██████▊   | 1900/2770 [28:09<11:31,  1.26it/s]

{'loss': 0.0061, 'grad_norm': 2.140629768371582, 'learning_rate': 1.5722021660649822e-05, 'epoch': 3.43}


 72%|███████▏  | 2000/2770 [29:30<10:36,  1.21it/s]

[A[A                                            


                                                   
[A

 72%|███████▏  | 2000/2770 [29:30<10:36,  1.21it/s]

{'loss': 0.0059, 'grad_norm': 0.015092713758349419, 'learning_rate': 1.3916967509025272e-05, 'epoch': 3.61}



[A                                            

[A[A                                            


                                                   
[A

 72%|███████▏  | 2000/2770 [30:25<10:36,  1.21it/s]


Per-Class Accuracy:
akiec: 65.38%
bcc: 90.32%
bkl: 79.07%
df: 85.71%
mel: 59.38%
nv: 96.32%
vasc: 80.00%
macro avg: 79.46%
weighted avg: 92.10%
{'eval_loss': 0.3384018838405609, 'eval_accuracy': 0.926873857404022, 'eval_precision': 0.8086025904252022, 'eval_recall': 0.7670858079383616, 'eval_f1': 0.7838459848664159, 'eval_runtime': 54.373, 'eval_samples_per_second': 20.12, 'eval_steps_per_second': 1.269, 'epoch': 3.61}


 76%|███████▌  | 2100/2770 [31:46<08:56,  1.25it/s]  

[A[A                                            


                                                   
[A

 76%|███████▌  | 2100/2770 [31:46<08:56,  1.25it/s]

{'loss': 0.0076, 'grad_norm': 0.29263395071029663, 'learning_rate': 1.2111913357400723e-05, 'epoch': 3.79}


 79%|███████▉  | 2200/2770 [33:06<07:33,  1.26it/s]

[A[A                                            


                                                   
[A

 79%|███████▉  | 2200/2770 [33:06<07:33,  1.26it/s]

{'loss': 0.0051, 'grad_norm': 2.8094382286071777, 'learning_rate': 1.0306859205776172e-05, 'epoch': 3.97}


 83%|████████▎ | 2300/2770 [34:27<06:17,  1.25it/s]

[A[A                                            


                                                   
[A

 83%|████████▎ | 2300/2770 [34:27<06:17,  1.25it/s]

{'loss': 0.0015, 'grad_norm': 0.10787701606750488, 'learning_rate': 8.501805054151625e-06, 'epoch': 4.15}


 87%|████████▋ | 2400/2770 [35:48<04:54,  1.26it/s]

[A[A                                            


                                                   
[A

 87%|████████▋ | 2400/2770 [35:48<04:54,  1.26it/s]

{'loss': 0.0007, 'grad_norm': 0.008645083755254745, 'learning_rate': 6.6967509025270755e-06, 'epoch': 4.33}


 90%|█████████ | 2500/2770 [37:09<03:37,  1.24it/s]

[A[A                                            


                                                   
[A

 90%|█████████ | 2500/2770 [37:09<03:37,  1.24it/s]

{'loss': 0.0008, 'grad_norm': 0.010107063688337803, 'learning_rate': 4.8916967509025275e-06, 'epoch': 4.51}



[A                                            

[A[A                                            


                                                   
[A

 90%|█████████ | 2500/2770 [38:04<03:37,  1.24it/s]


Per-Class Accuracy:
akiec: 76.00%
bcc: 80.00%
bkl: 81.48%
df: 85.71%
mel: 55.00%
nv: 96.75%
vasc: 85.71%
macro avg: 80.09%
weighted avg: 92.48%
{'eval_loss': 0.34534671902656555, 'eval_accuracy': 0.9287020109689214, 'eval_precision': 0.8146996530602681, 'eval_recall': 0.7809380397955548, 'eval_f1': 0.7960141173773496, 'eval_runtime': 54.1603, 'eval_samples_per_second': 20.199, 'eval_steps_per_second': 1.274, 'epoch': 4.51}


 94%|█████████▍| 2600/2770 [39:25<02:18,  1.23it/s]  

[A[A                                            


                                                   
[A

 94%|█████████▍| 2600/2770 [39:25<02:18,  1.23it/s]

{'loss': 0.0006, 'grad_norm': 0.038976412266492844, 'learning_rate': 3.0866425992779787e-06, 'epoch': 4.69}


 97%|█████████▋| 2700/2770 [40:46<00:58,  1.20it/s]

[A[A                                            


                                                   
[A

 97%|█████████▋| 2700/2770 [40:46<00:58,  1.20it/s]

{'loss': 0.0007, 'grad_norm': 0.06750218570232391, 'learning_rate': 1.2815884476534297e-06, 'epoch': 4.87}


100%|██████████| 2770/2770 [41:42<00:00,  1.67it/s]

[A[A                                            


                                                   
[A

100%|██████████| 2770/2770 [41:43<00:00,  1.11it/s]

{'train_runtime': 2503.285, 'train_samples_per_second': 17.705, 'train_steps_per_second': 1.107, 'train_loss': 0.19202751522944292, 'epoch': 5.0}





TrainOutput(global_step=2770, training_loss=0.19202751522944292, metrics={'train_runtime': 2503.285, 'train_samples_per_second': 17.705, 'train_steps_per_second': 1.107, 'total_flos': 3.4345988889388646e+18, 'train_loss': 0.19202751522944292, 'epoch': 5.0})

In [51]:
# Evaluate the model on validation dataset
metrics = trainer.evaluate()
print("Validation metrics:", metrics)

100%|██████████| 69/69 [00:54<00:00,  1.27it/s]


Per-Class Accuracy:
akiec: 73.91%
bcc: 76.47%
bkl: 90.16%
df: 83.33%
mel: 56.52%
nv: 95.38%
vasc: 85.71%
macro avg: 80.21%
weighted avg: 91.94%
Validation metrics: {'eval_loss': 0.2495461404323578, 'eval_accuracy': 0.9223034734917733, 'eval_precision': 0.8151769138234054, 'eval_recall': 0.7441569684876175, 'eval_f1': 0.7730059705785233, 'eval_runtime': 55.8367, 'eval_samples_per_second': 19.593, 'eval_steps_per_second': 1.236, 'epoch': 5.0}





In [52]:
# Print accuracy
print(f"Validation Accuracy: {metrics['eval_accuracy'] * 100:.2f}%")

Validation Accuracy: 92.23%


In [2]:
# Load the log history from the trainer
log_history = trainer.state.log_history

# Convert to a pandas DataFrame for easier manipulation
df = pd.DataFrame(log_history)

# Filter for training and validation losses
train_loss = df[df["loss"].notnull()]
eval_loss = df[df["eval_loss"].notnull()]

# Plot training and validation losses
plt.figure(figsize=(10, 6))
plt.plot(train_loss["step"], train_loss["loss"], label="Training Loss", marker="o")
plt.plot(eval_loss["step"], eval_loss["eval_loss"], label="Validation Loss", marker="o")

# Add labels, legend, and grid
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Learning Curve")
plt.legend()
plt.grid(True)
plt.show()

NameError: name 'trainer' is not defined

In [53]:
# Save the final model explicitly
trainer.save_model("./vit-final-model")