In [None]:
# 🌿 Plant Disease AI Classifier (Optimized)

# Key Updates in this Version:
# 1. Hardware Agnostic: Automatically detects and uses CUDA (GPU) if available.
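# 2. Production Normalization: Uses standard ImageNet mean/std statistics (these must match the normalization used when the model was trained).
# 3. Top-K Probabilities: Displays the top 3 classes with confidence bars for better UX.
# 4. Model Loading: Registers the `ResNet9` class with `add_safe_globals`; the full-model checkpoint is still loaded with `weights_only=False`, so only load trusted files.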

# Environment Setup
# !pip install gradio torch torchvision pillow

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
from torch.serialization import add_safe_globals

# Pick the compute device once for the whole notebook: CUDA when a GPU is
# visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Model Architecture: ResNet9
class ResNet9(nn.Module):
    """
    Compact residual CNN: 8 conv layers (two with identity shortcuts) + 1 linear.

    Args:
        in_channels: number of channels in the input image (e.g. 3 for RGB).
        num_classes: length of the output logit vector.

    Spatial sizing: with 256x256 inputs (the size the preprocessing transform
    produces), the three 4x poolings shrink the feature map 256 -> 64 -> 16 -> 4,
    and the classifier's MaxPool2d(4) reduces it to 1x1. Flatten therefore yields
    exactly 512 features for the final Linear layer.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        def conv_block(in_channels, out_channels, pool=False):
            # 3x3 conv (padding=1 keeps H and W) -> BatchNorm -> ReLU,
            # optionally followed by a 4x spatial downsample.
            layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                      nn.BatchNorm2d(out_channels),
                      nn.ReLU(inplace=True)]
            if pool:
                # 4x (not 2x) pooling is what brings a 256x256 input down to the
                # 1x1x512 map the Linear(512, ...) head expects. MaxPool has no
                # parameters, so this does not affect state_dict compatibility.
                layers.append(nn.MaxPool2d(4))
            return nn.Sequential(*layers)

        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))
        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))
        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            # Honour the constructor argument instead of hard-coding 38 outputs.
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.res1(x) + x  # identity shortcut around res1
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.res2(x) + x  # identity shortcut around res2
        return self.classifier(x)

# Model Loading
# add_safe_globals expects a *list* of classes/callables. Passing a dict would
# only register its string key, so register the class object itself. This
# allowlist only matters for weights_only=True loads.
add_safe_globals([ResNet9])

model_path = "/kaggle/input/plant-disease-01/plant-disease-model-complete.pth"  # Update if path changes
model = None  # remains None if loading fails, so the name is always defined
try:
    # SECURITY: weights_only=False fully unpickles the checkpoint, which can
    # execute arbitrary code. It is needed for a complete pickled model, so only
    # point model_path at files from a trusted source.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.to(device)
    model.eval()  # inference mode for BatchNorm
    print("Model loaded successfully!")
except Exception as e:
    # Notebook top-level boundary: report and continue so the UI cell can still run.
    model = None
    print(f"Error loading model: {e}. Please check your model path.")

# 🏷 Class Mapping
# Maps each model output index to its "<Crop>___<Condition>" label (38 entries).
# NOTE(review): this order must match the label indices used during training.
# It is not alphabetical (torchvision ImageFolder sorts class folders
# alphabetically), so confirm it against the training pipeline.
class_names = dict(enumerate([
    'Tomato___Late_blight',
    'Tomato___healthy',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Potato___healthy',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Tomato___Early_blight',
    'Tomato___Septoria_leaf_spot',
    'Corn_(maize)___Cercospora_leaf_spot_Gray_leaf_spot',
    'Strawberry___Leaf_scorch',
    'Peach___healthy',
    'Apple___Apple_scab',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Bacterial_spot',
    'Apple___Black_rot',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Peach___Bacterial_spot',
    'Apple___Cedar_apple_rust',
    'Tomato___Target_Spot',
    'Pepper,_bell___healthy',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Potato___Late_blight',
    'Tomato___Tomato_mosaic_virus',
    'Strawberry___healthy',
    'Apple___healthy',
    'Grape___Black_rot',
    'Potato___Early_blight',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Common_rust_',
    'Grape___Esca_(Black_Measles)',
    'Raspberry___healthy',
    'Tomato___Leaf_Mold',
    'Tomato___Spider_mites_Two-spotted_spider_mite',
    'Pepper,_bell___Bacterial_spot',
    'Corn_(maize)___healthy',
]))

# Image Transformation
# Per-channel statistics used by the final Normalize step (the standard ImageNet values).
# NOTE(review): normalization only helps if the model was trained with the same
# statistics; confirm against the training pipeline.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every upload: fixed 256x256 resize, PIL -> float
# tensor, then per-channel mean/std normalization.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Prediction Logic
def predict(img):
    """
    Classify a leaf photo and return per-class probabilities.

    Args:
        img: PIL image from the Gradio upload widget, or None when nothing was uploaded.

    Returns:
        A dict mapping every label in ``class_names`` to its softmax probability
        (gr.Label then shows the top entries), or a prompt string when ``img`` is None.
    """
    if img is None:
        return "Please upload an image."

    # Force 3 channels so the 3-value Normalize step (and presumably the model)
    # receives RGB input even for grayscale/RGBA uploads.
    img_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(img_tensor)
        # Softmax over the single batch item, then copy to the host once. This
        # avoids one device->host sync per class when running on a GPU.
        probabilities = F.softmax(logits[0], dim=0).cpu().tolist()

    # Keys follow class_names' index order, exactly as before.
    return {name: probabilities[idx] for idx, name in class_names.items()}

# Launch Gradio UI
# predict() returns probabilities for every class; gr.Label(num_top_classes=3)
# renders only the three most likely ones as confidence bars.
gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil', label="Upload Leaf Photo"),
    outputs=gr.Label(num_top_classes=3, label="Top Predictions"),
    # Title previously contained mojibake ("üåø") instead of the leaf emoji.
    title="🌿 Plant Disease AI Classifier",
    # 12 of the 38 classes are "healthy" categories, so not all 38 are diseases.
    description="Classify leaf photos into 38 plant disease and healthy-leaf classes using deep learning. Optimized for ResNet9.",
    theme="soft",
).launch(share=True)  # share=True creates a public link, useful on Kaggle/Colab