In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16
from sklearn.metrics import classification_report

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths
train_dir = "/kaggle/input/tomato/train"
val_dir = "/kaggle/input/tomato/valid"

# Image transformations
image_size = 224
batch_size = 32

train_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

val_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Load datasets
train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(root=val_dir, transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Class names
class_names = train_dataset.classes
num_classes = len(class_names)

print(f"Classes: {class_names}")

# Load pretrained ViT and modify the classifier
model = vit_b_16(pretrained=True)
model.heads = nn.Sequential(
    nn.Linear(model.heads.head.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, num_classes)
)

model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
epochs = 10

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Train Loss: {running_loss / len(train_loader):.4f}")

    # Validation
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_acc = 100 * correct / total
    print(f"Validation Accuracy: {val_acc:.2f}%")



Classes: ['Bacterial_spot', 'Early_blight', 'Late_blight', 'Leaf_Mold', 'Septoria_leaf_spot', 'Spider_mites Two-spotted_spider_mite', 'Target_Spot', 'Tomato_Yellow_Leaf_Curl_Virus', 'Tomato_mosaic_virus', 'healthy', 'powdery_mildew']


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 183MB/s]  



Epoch 1/10
Train Loss: 0.3693
Validation Accuracy: 92.77%

Epoch 2/10
Train Loss: 0.1328
Validation Accuracy: 97.10%

Epoch 3/10
Train Loss: 0.0900
Validation Accuracy: 94.16%

Epoch 4/10
Train Loss: 0.0692
Validation Accuracy: 94.45%

Epoch 5/10
Train Loss: 0.0651
Validation Accuracy: 96.62%

Epoch 6/10
Train Loss: 0.0541
Validation Accuracy: 96.80%

Epoch 7/10
Train Loss: 0.0562
Validation Accuracy: 95.80%

Epoch 8/10
Train Loss: 0.0431
Validation Accuracy: 97.44%

Epoch 9/10
Train Loss: 0.0474
Validation Accuracy: 97.16%

Epoch 10/10
Train Loss: 0.0452
Validation Accuracy: 97.22%


In [2]:
# Classification Report
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=class_names))

# Save model
torch.save(model.state_dict(), "vit_tomato_disease.pth")
print("\nModel saved as 'vit_tomato_disease.pth'")


Classification Report:
                                      precision    recall  f1-score   support

                      Bacterial_spot       0.93      0.98      0.96       732
                        Early_blight       0.95      0.97      0.96       643
                         Late_blight       0.98      0.97      0.97       792
                           Leaf_Mold       0.99      0.99      0.99       739
                  Septoria_leaf_spot       0.96      0.96      0.96       746
Spider_mites Two-spotted_spider_mite       0.99      0.95      0.97       435
                         Target_Spot       0.94      0.98      0.96       457
       Tomato_Yellow_Leaf_Curl_Virus       1.00      0.98      0.99       498
                 Tomato_mosaic_virus       0.99      0.99      0.99       584
                             healthy       0.99      0.97      0.98       805
                      powdery_mildew       1.00      0.94      0.97       252

                            accuracy  