<a href="https://colab.research.google.com/github/karri-ten/Plant-disease-detection/blob/main/Plant_Disease_Detection_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🌿 Plant Disease Detection Using PyTorch and EfficientNet

In [None]:
# ✅ Mount Google Drive (optional)

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# ✅ Enable GPU

In [2]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# ✅ Install Required Libraries (if needed)

In [4]:
!pip install -q kaggle
from google.colab import files
files.upload()  # Upload kaggle.json

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"teniolakareemat","key":"0a3af859538a3c98ede653ca6231f13a"}'}

In [None]:
# ✅ Set up Kaggle credentials

In [6]:
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

mkdir: cannot create directory ‘/root/.kaggle’: File exists


In [None]:
# ✅ Download and unzip the dataset

In [7]:
!kaggle datasets download -d abdallahalidev/plantvillage-dataset
!unzip -q plantvillage-dataset.zip -d ./data

Dataset URL: https://www.kaggle.com/datasets/abdallahalidev/plantvillage-dataset
License(s): CC-BY-NC-SA-4.0
Downloading plantvillage-dataset.zip to /content
 99% 2.01G/2.04G [00:10<00:00, 234MB/s]
100% 2.04G/2.04G [00:10<00:00, 211MB/s]


In [None]:
# ✅ Check contents of the dataset

In [8]:
import os
print(os.listdir('./data/plantvillage dataset/'))
print(os.listdir('./data/plantvillage dataset/color/'))

['color', 'grayscale', 'segmented']
['Pepper,_bell___healthy', 'Potato___Early_blight', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot', 'Tomato___Tomato_mosaic_virus', 'Tomato___Late_blight', 'Tomato___healthy', 'Cherry_(including_sour)___healthy', 'Grape___Black_rot', 'Peach___Bacterial_spot', 'Blueberry___healthy', 'Apple___Apple_scab', 'Tomato___Leaf_Mold', 'Apple___healthy', 'Apple___Cedar_apple_rust', 'Strawberry___Leaf_scorch', 'Orange___Haunglongbing_(Citrus_greening)', 'Tomato___Early_blight', 'Peach___healthy', 'Grape___healthy', 'Corn_(maize)___Northern_Leaf_Blight', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Potato___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Tomato___Bacterial_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Cherry_(including_sour)___Powdery_mildew', 'Strawberry___healthy', 'Pepper,_bell___Bacterial_spo

In [None]:
# ✅ Set up DataLoaders from 'color' folder

In [10]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

IMG_SIZE = 224
BATCH_SIZE = 32

train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

data_dir = './data/plantvillage dataset/color'
dataset = datasets.ImageFolder(data_dir, transform=train_transforms)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# ✅ Load Pretrained Model (EfficientNet)

In [11]:
from torchvision import models
import torch.nn as nn

model = models.efficientnet_b0(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(dataset.classes))
model = model.to(device)

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 99.0MB/s]


In [None]:
# ✅ Define Loss and Optimizer

In [12]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# ✅ Training and Evaluation Functions

In [13]:
def train_model(model, train_loader, val_loader, epochs=10):
    for epoch in range(epochs):
        model.train()
        running_loss, correct = 0.0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
        val_acc = evaluate(model, val_loader)
        print(f"Epoch {epoch+1}: Loss={running_loss:.3f}, Train Acc={correct/len(train_loader.dataset):.3f}, Val Acc={val_acc:.3f}")

def evaluate(model, loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            correct += (outputs.argmax(1) == labels).sum().item()
    return correct / len(loader.dataset)

In [14]:
# ✅ Save and Load the Model

In [15]:
torch.save(model.state_dict(), '/content/drive/MyDrive/plant_disease_model.pth')
# model.load_state_dict(torch.load('/content/drive/MyDrive/plant_disease_model.pth'))
# model.to(device)
# model.eval()

In [None]:
# ✅ Prediction Function

In [16]:
from PIL import Image
def predict(image_path, model):
    image = Image.open(image_path).convert("RGB")
    image = val_transforms(image).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(image)
        pred = output.argmax(1).item()
    return dataset.classes[pred]

In [18]:
# ✅ Export to ONNX
!pip3 install onnx



In [19]:
dummy_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
torch.onnx.export(model, dummy_input, "plant_disease.onnx", input_names=["input"], output_names=["output"], opset_version=11)

In [None]:
# ✅ Styled Gradio Interface

In [None]:
import gradio as gr
def gradio_predict(img):
    img = img.convert("RGB")
    image_tensor = val_transforms(img).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(image_tensor)
        pred = output.argmax(1).item()
    return dataset.classes[pred]

custom_css = '''
#component-0 {
    border: 4px solid #4CAF50 !important;
    border-radius: 10px !important;
    padding: 20px !important;
    background-color: #f9fff9 !important;
}
h1, .output_class {
    color: #2E7D32 !important;
    font-family: 'Segoe UI', sans-serif;
}
button {
    background-color: #4CAF50 !important;
    color: white !important;
    font-weight: bold !important;
    border-radius: 8px !important;
}
'''

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 🌿 Plant Disease Detector")
    gr.Markdown("Upload a plant leaf image below to detect disease.")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Leaf Image")
        output_label = gr.Label(label="Prediction")
    submit_btn = gr.Button("Detect Disease")
    submit_btn.click(fn=gradio_predict, inputs=image_input, outputs=output_label)
demo.launch(debug=True)

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://f4b8efd733a57717a3.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
