# ViT for Image Classification on the Bean Leaf Disease Dataset 🌱

## Setup & Dataset Loading

In [None]:
# Install required packages
!pip install datasets transformers evaluate torchvision matplotlib scikit-learn --q


In [None]:
pip install --upgrade datasets --q

### Load Dataset

In [None]:
import matplotlib.pyplot as plt
from datasets import load_dataset

# Fetch the "beans" leaf-disease dataset from the Hugging Face Hub.
dataset = load_dataset("beans")

# Show the available splits together with their sizes and features.
print(dataset)


### Visualize Class Distribution

In [None]:
from collections import Counter

# Tally how many training images belong to each class id, and fetch the
# human-readable class names from the dataset's `labels` feature.
label_counts = Counter(dataset['train']['labels'])
label_names = dataset['train'].features['labels'].names

# Report each class and collect bar data in the Counter's key order.
bar_names = []
bar_heights = []
for class_id, n_images in label_counts.items():
    print(f"{label_names[class_id]}: {n_images} images")
    bar_names.append(label_names[class_id])
    bar_heights.append(n_images)

# Bar chart of the training-set class balance.
plt.bar(bar_names, bar_heights)
plt.title("Beans Dataset - Class Distribution")
plt.xlabel("Class")
plt.ylabel("Count")
plt.xticks(rotation=15)
plt.show()


### Visualize Sample Images

In [None]:
import random

# One row of five randomly chosen training images, titled with their class.
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
train_ds = dataset['train']
for ax in axes:
    # Uniform pick over the whole training split (may repeat an image).
    pick = train_ds[random.randrange(len(train_ds))]
    ax.imshow(pick['image'])
    ax.set_title(label_names[pick['labels']])
    ax.axis("off")

plt.tight_layout()
plt.show()


## Preprocessing + Model Setup

### Preprocessing with `AutoImageProcessor`

In [None]:
from transformers import AutoImageProcessor

# Reuse the preprocessing configuration that ships with the ViT checkpoint.
checkpoint = "google/vit-base-patch16-224"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)


def transform(example):
    """Run the ViT image processor over ``example['image']``.

    Returns the processor output with PyTorch tensors (``return_tensors="pt"``).
    """
    images = example['image']
    return image_processor(images, return_tensors="pt")


# With batched=True, `map` hands `transform` a batch of examples per call
# (a dict of lists) rather than a single example.
encoded_dataset = dataset.map(transform, batched=True)


**Fix Input Format**

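The previous `map` call returns batched `pixel_values` tensors, which the dataset stores as nested Python lists rather than PyTorch tensors, so the Hugging Face `Trainer` can't use them directly. Instead, we process one example at a time, unwrap the single-image tensor (index `[0]`), and set the dataset format so that `pixel_values` and `labels` are returned as PyTorch tensors: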

In [None]:
def transform(example):
    """Attach ViT-ready ``pixel_values`` to a single (unbatched) example.

    The processor output always carries a leading batch dimension, so the
    first (and only) entry is stored back on ``example``, which is returned.
    """
    batch = image_processor(example['image'], return_tensors="pt")
    example['pixel_values'] = batch['pixel_values'][0]
    return example


# No batched=True here: `transform` sees one example per call.
encoded_dataset = dataset.map(transform)
# Expose only the model inputs and the targets, as torch tensors.
encoded_dataset.set_format(type="torch", columns=["pixel_values", "labels"])


### Load Pre-trained ViT Model

In [None]:
from transformers import ViTForImageClassification

# Build the id <-> name maps once from the dataset's class names, so the head
# size and the config's label names can never disagree with the data.
id2label = {i: label for i, label in enumerate(label_names)}
label2id = {label: i for i, label in id2label.items()}

# Load the pre-trained ViT with a classification head sized to this dataset.
# ignore_mismatched_sizes=True lets from_pretrained swap out the checkpoint's
# original head, whose output size differs from len(label_names).
model = ViTForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(label_names),
    ignore_mismatched_sizes=True,
    id2label=id2label,
    label2id=label2id,
)


## Training & Evaluation with Trainer

### Define Accuracy Metric
We'll use the `evaluate` library to compute accuracy on the validation set after each training epoch.

In [None]:
import evaluate

accuracy_metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy for a Trainer evaluation pass.

    ``eval_pred`` unpacks to ``(logits, labels)``; the predicted class is the
    arg-max over the last axis of ``logits``. Returns the dict produced by the
    ``accuracy`` metric.
    """
    logits, labels = eval_pred
    return accuracy_metric.compute(
        predictions=logits.argmax(axis=-1),
        references=labels,
    )


### Set Training Arguments

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    # Checkpoint location, with at most 2 checkpoints kept on disk.
    output_dir="./vit-beans-model",
    save_total_limit=2,
    # Optimisation schedule.
    num_train_epochs=5,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Evaluate and checkpoint once per epoch; at the end, reload the
    # checkpoint that scored best on validation accuracy.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Log every 10 steps locally; no wandb or other reporting integrations.
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
)


### Initialize Trainer

In [None]:
from transformers import Trainer

train_ds = encoded_dataset["train"]
val_ds = encoded_dataset["validation"]

# The validation split drives per-epoch evaluation and best-model selection.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)


### Train the Model

In [None]:
# Fine-tune on encoded_dataset["train"], evaluating on the validation split
# once per epoch as configured in training_args.
trainer.train()

### Evaluate the Model

In [None]:
# Score the held-out test split. metric_key_prefix="test" names the results
# "test_accuracy", "test_loss", ... instead of the default "eval_" prefix,
# which would make them look like validation metrics.
metrics = trainer.evaluate(encoded_dataset["test"], metric_key_prefix="test")
print(metrics)


## Visualize Predictions & Performance

### Get Model Predictions on Test Set

In [None]:
import torch

# Model outputs (logits) and ground-truth label ids for the whole test split.
predictions = trainer.predict(encoded_dataset["test"])

# Keep the true ids next to the predicted ids (arg-max over the class axis).
y_true = predictions.label_ids
y_preds = predictions.predictions.argmax(axis=1)


### Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Pin the label set so the matrix is always len(label_names) x len(label_names)
# and lines up with display_labels, even if a class never appears in
# y_true or y_preds.
cm = confusion_matrix(y_true, y_preds, labels=list(range(len(label_names))))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_names)

# Create the axes explicitly: disp.plot() without ax= opens its own figure,
# which would ignore figsize and leave an empty figure behind.
fig, ax = plt.subplots(figsize=(6, 6))
disp.plot(ax=ax, cmap="Blues", xticks_rotation=15)
ax.set_title("Confusion Matrix - ViT on Bean Leaf Dataset")
plt.show()


### Visualize Sample Predictions

In [None]:
import random

test_set = dataset["test"]

# Show up to 5 distinct random test images with their predictions.
# random.sample draws without replacement, so no image appears twice.
num_samples = min(5, len(y_preds))
sample_indices = random.sample(range(len(y_preds)), k=num_samples)

# squeeze=False keeps `axes` 2-D even when only one subplot is drawn.
fig, axes = plt.subplots(1, num_samples, figsize=(18, 4), squeeze=False)

for ax, idx in zip(axes[0], sample_indices):
    # NOTE(review): assumes trainer.predict keeps the test split's order, so
    # idx refers to the same example in test_set, y_true and y_preds.
    img = test_set[idx]["image"]
    true_label = label_names[y_true[idx]]
    pred_label = label_names[y_preds[idx]]

    ax.imshow(img)
    # Green title for a correct prediction, red for a mistake.
    ax.set_title(
        f"True: {true_label}\nPred: {pred_label}",
        color="green" if true_label == pred_label else "red",
    )
    ax.axis("off")

plt.tight_layout()
plt.show()


### Save Trained Model (Optional)

In [None]:
# Write the fine-tuned model and its image processor into the same directory,
# model first, then processor.
save_dir = "vit-beans-model"
for artifact in (model, image_processor):
    artifact.save_pretrained(save_dir)
