<a href="https://colab.research.google.com/github/karthik7147/few-shortlearning-with-vision-tranformers/blob/main/fewshort_learning_with_vision_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
import kagglehub
# Log in to Kaggle first; the note above says some data sources are private,
# so this has to run before any dataset is downloaded.
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Fetch both Kaggle datasets and keep whatever dataset_download() returns in
# module-level *_path variables.
# NOTE(review): neither *_path variable is read again in the visible part of
# this notebook; the training cells use a hard-coded "../input/..." path
# instead -- confirm that path exists in the environment this runs in.
amirberenji_thermal_images_of_induction_motor_path = kagglehub.dataset_download('amirberenji/thermal-images-of-induction-motor')
kkaarrtthhiikkr_thermal_path = kagglehub.dataset_download('kkaarrtthhiikkr/thermal')

print('Data source import complete.')


In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk /kaggle/input recursively and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import timm
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

# ==============================
# Step 1: Data Augmentation (Fixed)
# ==============================
# Training pipeline: resize, random flip/rotation for augmentation, then
# conversion to a tensor normalised with mean 0.5 / std 0.5 per channel.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Evaluation pipeline: identical resize/normalisation but NO random ops, so
# held-out metrics are deterministic and not distorted by augmentation.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Original name kept so that any code referring to `transform` still works.
transform = train_transform

dataset_path = "../input/thermal-images-of-induction-motor"
# `dataset` (augmented view) remains the object whose `.classes` is used
# elsewhere; `eval_dataset` is a second view of the same folder.
# NOTE(review): this assumes both ImageFolder instances enumerate the files in
# the same order, so that an index means the same image in both -- confirm.
dataset = datasets.ImageFolder(root=dataset_path, transform=train_transform)
eval_dataset = datasets.ImageFolder(root=dataset_path, transform=eval_transform)

train_size = int(0.85 * len(dataset))
test_size = len(dataset) - train_size

# Seed the split so the train/test membership is reproducible between runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, held_out_split = random_split(
    dataset, [train_size, test_size], generator=split_generator
)
# Re-point the held-out indices at the non-augmented view of the data.
test_dataset = torch.utils.data.Subset(eval_dataset, held_out_split.indices)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# ==============================
# Step 2: Vision Transformer Model (Fixed)
# ==============================
class ViTModel(nn.Module):
    """Pretrained ``vit_base_patch16_224`` backbone with a small custom head.

    The head applies BatchNorm1d, Dropout(0.4) and a Linear layer, in that
    order, to the backbone features.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
        # Width of the features feeding timm's own classification head.
        embed_dim = self.vit.head.in_features
        self.fc = nn.Linear(embed_dim, num_classes)
        self.dropout = nn.Dropout(0.4)
        self.bn = nn.BatchNorm1d(embed_dim)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        tokens = self.vit.forward_features(x)
        # Keep only position 0 along the second axis (presumably the class
        # token -- confirm against the timm version in use).
        pooled = tokens[:, 0, :]
        regularised = self.dropout(self.bn(pooled))
        return self.fc(regularised)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of classes reported by the dataset.
num_classes = len(dataset.classes)
# Instantiate the ViT wrapper and move its weights to the chosen device.
model = ViTModel(num_classes).to(device)

# ==============================
# Step 3: Extract Features (Fixed)
# ==============================
def extract_features(model, data_loader):
    """Run the backbone over ``data_loader`` and collect features and labels.

    The model is switched to eval mode and evaluated without gradients. Each
    batch is moved to the device the model's parameters live on, so the
    function does not rely on a module-level ``device`` variable.

    Args:
        model: Module exposing ``model.vit.forward_features``.
        data_loader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays, each concatenated over
        all batches along the first axis.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()

    # Follow the model's own device; a parameter-free model runs on CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cpu")

    features, labels = [], []

    with torch.no_grad():
        for images, lbls in data_loader:
            images = images.to(target_device)
            # Keep only position 0 along the second axis of the backbone output.
            output = model.vit.forward_features(images)[:, 0, :]
            features.append(output.cpu().numpy())
            labels.append(lbls.cpu().numpy())

    if not features:
        # np.concatenate([]) would fail with a much less helpful message.
        raise ValueError("data_loader yielded no batches; nothing to extract")

    return np.concatenate(features), np.concatenate(labels)

# Extract features once, up front, for both splits; extract_features returns
# NumPy arrays (features, labels).
train_features, train_labels = extract_features(model, train_loader)
test_features, test_labels = extract_features(model, test_loader)

# ==============================
# Step 4: Few-Shot Model (Fixed)
# ==============================
class FewShotClassifier(nn.Module):
    """Two-layer MLP operating on pre-extracted feature vectors.

    Layer stack: Linear(feature_dim, 512) -> BatchNorm1d(512) -> ReLU ->
    Dropout(0.3) -> Linear(512, num_classes).

    Args:
        feature_dim: Size of each input feature vector.
        num_classes: Number of output logits.
    """

    def __init__(self, feature_dim, num_classes):
        super().__init__()
        hidden_dim = 512
        layers = [
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits for the feature batch ``x``."""
        logits = self.fc(x)
        return logits

# Classifier head sized to the width of the extracted feature vectors.
few_shot_model = FewShotClassifier(train_features.shape[1], num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(few_shot_model.parameters(), lr=0.005, momentum=0.9, weight_decay=1e-4)
# Halve the learning rate every 10 scheduler steps (one step per epoch below).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Move the pre-extracted NumPy features/labels onto the training device.
train_features = torch.tensor(train_features, dtype=torch.float32).to(device)
train_labels = torch.tensor(train_labels, dtype=torch.long).to(device)

# Train Few-Shot Classifier with Stability Fixes
# Single source of truth for the epoch count used by the loop and the message.
num_epochs = 25  # Increased training epochs
# The labels never change, so copy them to host memory once instead of once
# per epoch.
train_labels_np = train_labels.cpu().numpy()

few_shot_model.train()
for epoch in range(num_epochs):
    # Full-batch update: the whole training feature matrix is one batch.
    optimizer.zero_grad()
    outputs = few_shot_model(train_features)
    loss = criterion(outputs, train_labels)
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Accuracy of the pre-update forward pass (computed in train mode).
    accuracy = accuracy_score(train_labels_np, torch.argmax(outputs, 1).cpu().numpy())
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Train Accuracy: {accuracy * 100:.2f}%")

# ==============================
# Step 5: Evaluate Model
# ==============================
# Move the pre-extracted test features/labels onto the evaluation device.
test_features = torch.tensor(test_features, dtype=torch.float32).to(device)
test_labels = torch.tensor(test_labels, dtype=torch.long).to(device)

few_shot_model.eval()
with torch.no_grad():
    outputs = few_shot_model(test_features)
    # Predicted class = index of the largest logit per row.
    _, predicted = torch.max(outputs, 1)
    accuracy = accuracy_score(test_labels.cpu().numpy(), predicted.cpu().numpy())
    # The emoji was mis-decoded (mojibake) in the original message.
    print(f"✅ Improved Accuracy: {accuracy * 100:.2f}%")

# ==============================
# Step 6: Confusion Matrix & Report
# ==============================
true_labels = test_labels.cpu().numpy()
predicted_labels = predicted.cpu().numpy()

conf_matrix = confusion_matrix(true_labels, predicted_labels)
class_report = classification_report(true_labels, predicted_labels, target_names=dataset.classes)

print("\n📊 Classification Report:\n", class_report)

# Heat-map of the confusion matrix with class names on both axes.
plt.figure(figsize=(10, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap="Blues", xticklabels=dataset.classes, yticklabels=dataset.classes)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.show()


In [None]:
def plot_predictions(model, few_shot_model, data_loader, num_images=5):
    """Show images from the first batch with predicted and actual class names.

    Features are extracted with ``extract_features`` and classified by
    ``few_shot_model``; both models are put in eval mode.

    Args:
        model: Backbone passed to ``extract_features``.
        few_shot_model: Classifier applied to the extracted features.
        data_loader: Iterable of ``(images, labels)`` batches; only the first
            batch is used.
        num_images: Maximum number of images to display. It is capped at the
            size of the first batch; nothing is drawn when the result is 0.
    """
    model.eval()
    few_shot_model.eval()

    images, labels = next(iter(data_loader))
    # The first batch can hold fewer samples than requested.
    num_images = min(num_images, len(images))
    if num_images <= 0:
        return
    images, labels = images[:num_images].to(device), labels[:num_images].to(device)

    with torch.no_grad():
        features, _ = extract_features(model, [(images, labels)])
        features = torch.tensor(features, dtype=torch.float32).to(device)
        outputs = few_shot_model(features)
        _, preds = torch.max(outputs, 1)

    # squeeze=False keeps `axes` two-dimensional even for a single image.
    fig, axes = plt.subplots(1, num_images, figsize=(12, 4), squeeze=False)
    for i, ax in enumerate(axes[0]):
        # Channels-first tensor -> channels-last array for imshow.
        img = images[i].cpu().permute(1, 2, 0).numpy()
        # Min-max rescale to [0, 1]; the floor avoids 0/0 on a constant image.
        img = (img - img.min()) / max(img.max() - img.min(), 1e-8)
        ax.imshow(img)
        ax.set_title(
            f"Pred: {dataset.classes[preds[i].item()]}\nActual: {dataset.classes[labels[i].item()]}"
        )
        ax.axis('off')

    plt.show()

# Visualise predictions for the first batch of the test loader.
plot_predictions(model, few_shot_model, test_loader)


In [None]:
def plot_predictions(model, data_loader, num_images=5):
    """Show images from the first batch with predicted and actual class names.

    Args:
        model: Classifier mapping an image batch to logits; put in eval mode.
        data_loader: Iterable of ``(images, labels)`` batches; only the first
            batch is used.
        num_images: Maximum number of images to display. It is capped at the
            size of the first batch; nothing is drawn when the result is 0.
    """
    model.eval()

    images, labels = next(iter(data_loader))
    num_images = min(num_images, len(images))  # Ensure we don't exceed available samples
    if num_images <= 0:
        # plt.subplots cannot build a grid with zero columns.
        return
    images, labels = images[:num_images].to(device), labels[:num_images].to(device)

    with torch.no_grad():
        outputs = model(images)
        _, preds = torch.max(outputs, 1)

    # squeeze=False keeps `axes` two-dimensional even for a single image.
    fig, axes = plt.subplots(1, num_images, figsize=(14, 4), squeeze=False)
    for i, ax in enumerate(axes[0]):
        # Channels-first tensor -> channels-last array for imshow.
        img = images[i].cpu().permute(1, 2, 0).numpy()
        # Normalize for display; the floor avoids 0/0 on a constant image.
        img = (img - img.min()) / max(img.max() - img.min(), 1e-8)
        ax.imshow(img)
        ax.set_title(
            f"Pred: {dataset.classes[preds[i].item()]}\nActual: {dataset.classes[labels[i].item()]}"
        )
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import timm
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

# ========== 1. Augmentation ==========
# Training pipeline: resize, random flip/affine/colour jitter for
# augmentation, then a tensor normalised with mean 0.5 / std 0.5 per channel.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Evaluation pipeline: identical resize/normalisation but NO random ops, so
# test metrics are deterministic and not distorted by augmentation.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Original name kept so that any code referring to `transform` still works.
transform = train_transform

dataset_path = "../input/thermal-images-of-induction-motor"
# `dataset` (augmented view) remains the object whose `.classes` is used
# elsewhere; `eval_dataset` is a second view of the same folder.
# NOTE(review): this assumes both ImageFolder instances enumerate the files in
# the same order, so that an index means the same image in both -- confirm.
dataset = datasets.ImageFolder(root=dataset_path, transform=train_transform)
eval_dataset = datasets.ImageFolder(root=dataset_path, transform=eval_transform)
train_size = int(0.85 * len(dataset))
test_size = len(dataset) - train_size

# Seed the split so the train/test membership is reproducible between runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, held_out_split = random_split(
    dataset, [train_size, test_size], generator=split_generator
)
# Re-point the held-out indices at the non-augmented view of the data.
test_dataset = torch.utils.data.Subset(eval_dataset, held_out_split.indices)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ========== 2. Full ViT Fine-tuning ==========
class ViTClassifier(nn.Module):
    """Wrapper around a pretrained timm ``vit_small_patch16_224`` model.

    Args:
        num_classes: Passed to ``timm.create_model`` as ``num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = timm.create_model(
            'vit_small_patch16_224',
            pretrained=True,
            num_classes=num_classes,
        )
        self.vit = backbone

    def forward(self, x):
        """Return the wrapped model's output for the batch ``x``."""
        logits = self.vit(x)
        return logits

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of classes reported by the dataset.
num_classes = len(dataset.classes)
model = ViTClassifier(num_classes).to(device)

# ========== 3. Training Setup ==========
criterion = nn.CrossEntropyLoss()
# All model parameters are handed to the optimiser (full fine-tuning).
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Halve the learning rate every 10 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# ========== 4. Training Loop with Early Stopping ==========
# Trains for at most 30 epochs. After each epoch the model is checkpointed if
# the epoch's training accuracy is a new best; training stops once `patience`
# consecutive epochs fail to improve it.
# NOTE(review): the monitored quantity is *training* accuracy, not a held-out
# metric, so the saved "best" model is the one that fits the training data
# best -- consider monitoring a validation split instead.
best_accuracy = 0
patience, trigger_times = 5, 0

for epoch in range(30):
    model.train()
    total_loss = 0  # sum of per-batch losses (not averaged)
    all_preds, all_labels = [], []

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Predictions come from the same forward pass used for the update.
        all_preds.extend(torch.argmax(outputs, 1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    train_acc = accuracy_score(all_labels, all_preds)
    print(f"Epoch [{epoch+1}/30], Loss: {total_loss:.4f}, Train Accuracy: {train_acc * 100:.2f}%")

    # Early stopping logic
    if train_acc > best_accuracy:
        best_accuracy = train_acc
        trigger_times = 0
        # Overwrites the previous checkpoint with the new best weights.
        torch.save(model.state_dict(), "best_vit_model.pth")
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print("Early stopping triggered.")
            break

    # Reached only when training continues (skipped on the early-stop break).
    scheduler.step()

# ========== 5. Evaluation ==========
# Restore the best checkpoint. map_location lets a checkpoint written on a GPU
# be loaded in a CPU-only session (and vice versa).
model.load_state_dict(torch.load("best_vit_model.pth", map_location=device))
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit per row.
        all_preds.extend(torch.argmax(outputs, 1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(all_labels, all_preds)
# The emoji in the two messages below were mis-decoded (mojibake) originally.
print(f"\n✅ Final Test Accuracy: {test_acc * 100:.2f}%")

# ========== 6. Report & Confusion Matrix ==========
print("\n📊 Classification Report:\n", classification_report(all_labels, all_preds, target_names=dataset.classes))

# Heat-map of the confusion matrix with class names on both axes.
conf_matrix = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap="Blues", xticklabels=dataset.classes, yticklabels=dataset.classes)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.show()


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import timm
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

# ============================
# 1. Augmentation
# ============================
# Training pipeline: resize, random flip/affine/colour jitter for
# augmentation, then a tensor normalised with mean 0.5 / std 0.5 per channel.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Evaluation pipeline: identical resize/normalisation but NO random ops, so
# test metrics are deterministic and not distorted by augmentation.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Original name kept so that any code referring to `transform` still works.
transform = train_transform

dataset_path = "../input/thermal-images-of-induction-motor"
# `dataset` (augmented view) remains the object whose `.classes` is used
# elsewhere; `eval_dataset` is a second view of the same folder.
# NOTE(review): this assumes both ImageFolder instances enumerate the files in
# the same order, so that an index means the same image in both -- confirm.
dataset = datasets.ImageFolder(root=dataset_path, transform=train_transform)
eval_dataset = datasets.ImageFolder(root=dataset_path, transform=eval_transform)
train_size = int(0.85 * len(dataset))
test_size = len(dataset) - train_size

# Seed the split so the train/test membership is reproducible between runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, held_out_split = random_split(
    dataset, [train_size, test_size], generator=split_generator
)
# Re-point the held-out indices at the non-augmented view of the data.
test_dataset = torch.utils.data.Subset(eval_dataset, held_out_split.indices)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ============================
# 2. Model
# ============================
class ViTClassifier(nn.Module):
    """Wrapper around a pretrained timm ``vit_small_patch16_224`` model.

    Args:
        num_classes: Passed to ``timm.create_model`` as ``num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__()
        architecture = 'vit_small_patch16_224'
        self.vit = timm.create_model(
            architecture,
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the wrapped model's output for the batch ``x``."""
        predictions = self.vit(x)
        return predictions

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of classes reported by the dataset.
num_classes = len(dataset.classes)
model = ViTClassifier(num_classes).to(device)

# ============================
# 3. Training Setup
# ============================
criterion = nn.CrossEntropyLoss()
# All model parameters are handed to the optimiser (full fine-tuning).
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Halve the learning rate every 10 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# ---- Tracking variables ----
# One entry is appended per completed epoch; used for the graphs later on.
loss_values = []
train_acc_values = []
lr_values = []

# ============================
# 4. Training Loop + Early Stopping
# ============================
# Trains for at most 30 epochs, recording loss / accuracy / learning rate per
# epoch. The model is checkpointed whenever training accuracy reaches a new
# best, and training stops after `patience` epochs without improvement.
# NOTE(review): the monitored quantity is *training* accuracy, not a held-out
# metric -- consider monitoring a validation split instead.
best_accuracy = 0
patience, trigger_times = 5, 0

for epoch in range(30):
    model.train()
    total_loss = 0  # sum of per-batch losses (not averaged)
    all_preds, all_labels = [], []

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Predictions come from the same forward pass used for the update.
        all_preds.extend(torch.argmax(outputs, 1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    train_acc = accuracy_score(all_labels, all_preds)
    # Learning rate of the first parameter group, read before scheduler.step().
    current_lr = optimizer.param_groups[0]["lr"]

    print(f"Epoch [{epoch+1}/30], Loss: {total_loss:.4f}, Train Acc: {train_acc*100:.2f}%, LR: {current_lr}")

    # ---- Store for graphs ----
    loss_values.append(total_loss)
    train_acc_values.append(train_acc)
    lr_values.append(current_lr)

    # ---- Early Stopping ----
    if train_acc > best_accuracy:
        best_accuracy = train_acc
        trigger_times = 0
        # Overwrites the previous checkpoint with the new best weights.
        torch.save(model.state_dict(), "best_vit_model.pth")
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print("Early stopping triggered.")
            break

    # Reached only when training continues (skipped on the early-stop break).
    scheduler.step()

# ============================
# 5. Evaluation
# ============================
# Restore the best checkpoint. map_location lets a checkpoint written on a GPU
# be loaded in a CPU-only session (and vice versa).
model.load_state_dict(torch.load("best_vit_model.pth", map_location=device))
model.eval()

all_preds, all_labels = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit per row.
        all_preds.extend(torch.argmax(outputs, 1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(all_labels, all_preds)
# The emoji in the two messages below were mis-decoded (mojibake) originally.
print(f"\n✅ Final Test Accuracy: {test_acc * 100:.2f}%")

# ============================
# 6. Classification Report + Confusion Matrix
# ============================
print("\n📊 Classification Report:\n", classification_report(all_labels, all_preds, target_names=dataset.classes))

# Heat-map of the confusion matrix with class names on both axes.
conf_matrix = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap="Blues",
            xticklabels=dataset.classes, yticklabels=dataset.classes)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.show()

# ============================
# 7. Training Graphs
# ============================
# One stand-alone figure per tracked quantity, drawn in this order:
# loss, accuracy (shown in percent), learning rate.
_curve_specs = (
    ("Training Loss Curve", "Loss", loss_values),
    ("Training Accuracy Curve", "Accuracy (%)", [acc * 100 for acc in train_acc_values]),
    ("Learning Rate Curve", "Learning Rate", lr_values),
)

for _title, _ylabel, _series in _curve_specs:
    plt.figure(figsize=(8, 5))
    plt.plot(_series, marker='o')
    plt.title(_title)
    plt.xlabel("Epoch")
    plt.ylabel(_ylabel)
    plt.grid(True)
    plt.show()


In [None]:
import matplotlib.pyplot as plt

# ============================
# 3 GRAPHS IN ONE FIGURE
# ============================
# Same three curves as the per-figure version, stacked vertically in a single
# 3-row figure: loss, accuracy (shown in percent), learning rate.
_stacked_specs = (
    ("Training Loss Curve", "Loss", loss_values),
    ("Training Accuracy Curve", "Accuracy (%)", [acc * 100 for acc in train_acc_values]),
    ("Learning Rate Curve", "Learning Rate", lr_values),
)

plt.figure(figsize=(12, 12))

for _row, (_title, _ylabel, _series) in enumerate(_stacked_specs, start=1):
    plt.subplot(3, 1, _row)
    plt.plot(_series, marker='o')
    plt.title(_title)
    plt.xlabel("Epoch")
    plt.ylabel(_ylabel)
    plt.grid(True)

plt.tight_layout()
plt.show()


In [None]:
def plot_predictions(model, data_loader, num_images=10):
    """Show a grid of images from the first batch with predicted/actual labels.

    The grid has 5 columns and at least 2 rows (the original 2 x 5 layout);
    further rows are added when more than 10 images are requested. Unused
    cells are left blank.

    Args:
        model: Classifier mapping an image batch to logits; put in eval mode.
        data_loader: Iterable of ``(images, labels)`` batches; only the first
            batch is used.
        num_images: Maximum number of images to display. It is capped at the
            size of the first batch; nothing is drawn when the result is 0.
    """
    model.eval()

    images, labels = next(iter(data_loader))
    # The first batch can hold fewer samples than requested.
    num_images = min(num_images, len(images))
    if num_images <= 0:
        return
    images, labels = images[:num_images].to(device), labels[:num_images].to(device)

    with torch.no_grad():
        outputs = model(images)
        _, preds = torch.max(outputs, 1)

    # ---- Create the grid: 5 columns, 2 rows unless more are needed ----
    cols = 5
    rows = max(2, -(-num_images // cols))  # ceiling division
    # 3.5 inches per row reproduces the original (16, 7) size for 2 rows.
    fig, axes = plt.subplots(rows, cols, figsize=(16, 3.5 * rows), squeeze=False)

    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx >= num_images:
            # Leave the surplus cells empty.
            continue

        # Channels-first tensor -> channels-last array for imshow.
        img = images[idx].cpu().permute(1, 2, 0).numpy()
        # Min-max rescale to [0, 1]; the floor avoids 0/0 on a constant image.
        img = (img - img.min()) / max(img.max() - img.min(), 1e-8)

        ax.imshow(img)
        ax.set_title(
            f"Pred: {dataset.classes[preds[idx].item()]}\nActual: {dataset.classes[labels[idx].item()]}",
            fontsize=10
        )

    plt.tight_layout()
    plt.show()


In [None]:
# Visualise up to 10 predictions from the first batch of the test loader.
plot_predictions(model, test_loader, num_images=10)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import timm
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

# ============================
# 1. Augmentation
# ============================
# Training pipeline: resize, random flip/affine/colour jitter for
# augmentation, then a tensor normalised with mean 0.5 / std 0.5 per channel.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Evaluation pipeline: identical resize/normalisation but NO random ops, so
# validation/test metrics (and early stopping) are not driven by augmentation
# noise.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Original name kept so that any code referring to `transform` still works.
transform = train_transform

dataset_path = "../input/thermal-images-of-induction-motor"
# `dataset` (augmented view) remains the object whose `.classes` is used
# elsewhere; `eval_dataset` is a second view of the same folder.
# NOTE(review): this assumes both ImageFolder instances enumerate the files in
# the same order, so that an index means the same image in both -- confirm.
dataset = datasets.ImageFolder(root=dataset_path, transform=train_transform)
eval_dataset = datasets.ImageFolder(root=dataset_path, transform=eval_transform)

# ---- Train / Val / Test split ----
train_size = int(0.7 * len(dataset))
val_size   = int(0.15 * len(dataset))
test_size  = len(dataset) - train_size - val_size

# Seed the split so the membership of each subset is reproducible.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_split, test_split = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)
# Re-point the held-out indices at the non-augmented view of the data.
val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ============================
# 2. Model
# ============================
class ViTClassifier(nn.Module):
    """Wrapper around a pretrained timm ``vit_small_patch16_224`` model.

    Args:
        num_classes: Passed to ``timm.create_model`` as ``num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Keyword options for the timm factory, gathered in one place.
        factory_options = {
            "pretrained": True,
            "num_classes": num_classes,
        }
        self.vit = timm.create_model('vit_small_patch16_224', **factory_options)

    def forward(self, x):
        """Return the wrapped model's output for the batch ``x``."""
        result = self.vit(x)
        return result

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of classes reported by the dataset.
num_classes = len(dataset.classes)
model = ViTClassifier(num_classes).to(device)

# ============================
# 3. Training Setup
# ============================
# Cross-entropy with label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# All model parameters are handed to the optimiser (full fine-tuning).
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# Cosine annealing with T_max=30, matching the 30-epoch loop that follows.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

# ---- Tracking ----
# One entry is appended per completed epoch; used for the graphs later on.
train_loss_vals, val_loss_vals = [], []
train_acc_vals, val_acc_vals = [], []
lr_vals = []

# ============================
# 4. Training Loop (Early Stop on VAL ACC)
# ============================
# Each epoch runs one optimisation pass over train_loader and one no-grad pass
# over val_loader. The model is checkpointed whenever validation accuracy
# reaches a new best; training stops after `patience` epochs without
# improvement.
#
# Losses are stored as per-sample means. Summing per-batch losses would make
# the training value several times larger than the validation value simply
# because the training loader has more batches, which makes the two curves
# incomparable on a shared axis.
best_val_acc = 0
patience, trigger = 5, 0

for epoch in range(30):
    # -------- TRAIN --------
    model.train()
    train_loss = 0
    train_preds, train_labels = [], []

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss.item() is a batch mean under CrossEntropyLoss's default
        # reduction; weight it by the batch size to accumulate a sample sum.
        train_loss += loss.item() * labels.size(0)
        train_preds.extend(outputs.argmax(1).cpu().numpy())
        train_labels.extend(labels.cpu().numpy())

    train_acc = accuracy_score(train_labels, train_preds)
    # max(..., 1) guards against an empty loader.
    train_loss /= max(len(train_labels), 1)

    # -------- VALIDATION --------
    model.eval()
    val_loss = 0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * labels.size(0)
            val_preds.extend(outputs.argmax(1).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_acc = accuracy_score(val_labels, val_preds)
    val_loss /= max(len(val_labels), 1)

    # ---- Store ----
    train_loss_vals.append(train_loss)
    val_loss_vals.append(val_loss)
    train_acc_vals.append(train_acc)
    val_acc_vals.append(val_acc)
    # Learning rate of the first parameter group, read before scheduler.step().
    lr_vals.append(optimizer.param_groups[0]["lr"])

    print(f"Epoch [{epoch+1}/30] | "
          f"Train Acc: {train_acc*100:.2f}% | "
          f"Val Acc: {val_acc*100:.2f}%")

    # ---- Early Stopping ----
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        trigger = 0
        # Overwrites the previous checkpoint with the new best weights.
        torch.save(model.state_dict(), "best_vit_model.pth")
    else:
        trigger += 1
        if trigger >= patience:
            print("Early stopping triggered.")
            break

    # Reached only when training continues (skipped on the early-stop break).
    scheduler.step()

# ============================
# 5. Testing
# ============================
# Restore the best checkpoint. map_location lets a checkpoint written on a GPU
# be loaded in a CPU-only session (and vice versa).
model.load_state_dict(torch.load("best_vit_model.pth", map_location=device))
model.eval()

test_preds, test_labels = [], []

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        # Predicted class = index of the largest logit per row.
        test_preds.extend(outputs.argmax(1).cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(test_labels, test_preds)
# The emoji in the two messages below were mis-decoded (mojibake) originally.
print(f"\n✅ Final Test Accuracy: {test_acc*100:.2f}%")

# ============================
# 6. Report + Confusion Matrix
# ============================
print("\n📊 Classification Report\n",
      classification_report(test_labels, test_preds, target_names=dataset.classes))

# Heat-map of the confusion matrix with class names on both axes.
cm = confusion_matrix(test_labels, test_preds)
plt.figure(figsize=(10,6))
sns.heatmap(cm, annot=True, fmt='d', cmap="Blues",
            xticklabels=dataset.classes,
            yticklabels=dataset.classes)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

# ============================
# 7. ONE FIGURE - 3 GRAPHS
# ============================
# Side-by-side panels: loss, accuracy (shown in percent), learning rate.
epochs = range(1, len(train_loss_vals)+1)

# Each panel: (title, y-axis label, [(series, legend label or None), ...]).
_panel_specs = (
    ("Loss Curve", "Loss",
     [(train_loss_vals, "Train Loss"), (val_loss_vals, "Val Loss")]),
    ("Accuracy Curve", "Accuracy (%)",
     [([a*100 for a in train_acc_vals], "Train Acc"),
      ([a*100 for a in val_acc_vals], "Val Acc")]),
    ("Learning Rate Schedule", "LR",
     [(lr_vals, None)]),
)

plt.figure(figsize=(18,5))

for _position, (_title, _ylabel, _curves) in enumerate(_panel_specs, start=1):
    plt.subplot(1,3,_position)
    _has_legend = False
    for _series, _label in _curves:
        if _label is None:
            plt.plot(epochs, _series)
        else:
            plt.plot(epochs, _series, label=_label)
            _has_legend = True
    plt.title(_title)
    plt.xlabel("Epoch")
    plt.ylabel(_ylabel)
    # Only the labelled panels get a legend (the LR panel has none).
    if _has_legend:
        plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()



In [None]:
# Loss and accuracy curves (train vs. validation) side by side.
epochs = range(1, len(train_loss_vals)+1)

# NOTE(review): the grid is 1 x 3 but only two panels are drawn, so the third
# slot stays empty -- kept as in the original layout.
plt.figure(figsize=(18,5))

# ---- Loss ----
plt.subplot(1,3,1)
plt.plot(epochs, train_loss_vals, label="Train Loss")
plt.plot(epochs, val_loss_vals, label="Val Loss")
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)

# ---- Accuracy ----
# Accuracies are stored as fractions; show them in percent.
plt.subplot(1,3,2)
plt.plot(epochs, [a*100 for a in train_acc_vals], label="Train Acc")
plt.plot(epochs, [a*100 for a in val_acc_vals], label="Val Acc")
plt.title("Accuracy Curve")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.grid(True)

# Finalise the figure like the other plotting cells do, so it is laid out
# without overlapping labels and is displayed outside inline backends too.
plt.tight_layout()
plt.show()


In [None]:
# Larger-font rendering of the loss and accuracy curves (train vs. validation).
# NOTE(review): relies on `epochs` having been defined by an earlier cell.
# Each panel: (title, y-axis label, [(series, legend label), ...]).
_styled_panels = (
    ("Loss Curve", "Loss",
     [(train_loss_vals, "Train Loss"), (val_loss_vals, "Val Loss")]),
    ("Accuracy Curve", "Accuracy (%)",
     [([a*100 for a in train_acc_vals], "Train Accuracy"),
      ([a*100 for a in val_acc_vals], "Validation Accuracy")]),
)

plt.figure(figsize=(18, 5))  # wider figure

for _slot, (_title, _ylabel, _curves) in enumerate(_styled_panels, start=1):
    # Panels occupy the first two slots of a 1 x 3 grid.
    plt.subplot(1, 3, _slot)
    for _series, _label in _curves:
        plt.plot(epochs, _series, label=_label, linewidth=2)
    plt.title(_title, fontsize=20)
    plt.xlabel("Epoch", fontsize=18)
    plt.ylabel(_ylabel, fontsize=18)
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tick_params(axis='both', which='major', labelsize=12)


plt.tight_layout()
plt.show()
