# Common Houseplant Identification Assistant

This notebook implements a machine learning-based application that helps identify common houseplants from images and provides basic care recommendations. The application uses computer vision techniques to classify houseplant images and displays relevant care information for the identified plants.

## Notebook Contents

1. [Setup and Installation](#setup)
2. [Data Preparation](#data)
3. [Model Development](#model)
4. [Training the Model](#training)
5. [Inference and Evaluation](#inference)
6. [Care Recommendations](#care)
7. [Interactive Interface with Gradio](#interface)

Let's begin by setting up our environment and installing the necessary dependencies.

## 1. Setup and Installation <a id="setup"></a>

First, we'll install the required packages:

In [1]:
# Install required packages
%pip install torch torchvision tqdm numpy matplotlib pandas pillow gradio scikit-learn






In [2]:
# Import necessary libraries
import os
import sys
import json
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm.notebook import tqdm
from pathlib import Path
import requests
from io import BytesIO
import time
from sklearn.metrics import classification_report, confusion_matrix
import gradio as gr
import warnings
warnings.filterwarnings('ignore')

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create necessary directories
os.makedirs('data', exist_ok=True)
os.makedirs('models', exist_ok=True)
os.makedirs('outputs', exist_ok=True)

Using device: cpu


## 2. Data Preparation <a id="data"></a>

We'll now define our dataset of common houseplants. For the purposes of this notebook, we'll list 20 common houseplant species and build a care database that covers a subset of them (the first three species in the list).

In [None]:
# Define a list of common houseplant species
COMMON_HOUSEPLANTS = [
    "Monstera deliciosa",        # Swiss Cheese Plant
    "Ficus lyrata",              # Fiddle Leaf Fig
    "Sansevieria trifasciata",   # Snake Plant
    "Chlorophytum comosum",      # Spider Plant
    "Epipremnum aureum",         # Pothos
    "Spathiphyllum wallisii",    # Peace Lily
    "Zamioculcas zamiifolia",    # ZZ Plant
    "Dracaena marginata",        # Dragon Tree
    "Calathea makoyana",         # Peacock Plant
    "Pilea peperomioides",       # Chinese Money Plant
    "Philodendron bipinnatifidum", # Split-leaf Philodendron
    "Aloe vera",                 # Aloe Vera
    "Ficus elastica",            # Rubber Plant
    "Maranta leuconeura",        # Prayer Plant
    "Aglaonema commutatum",      # Chinese Evergreen
    "Peperomia obtusifolia",     # Baby Rubber Plant
    "Anthurium andraeanum",      # Flamingo Flower
    "Schlumbergera bridgesii",   # Christmas Cactus
    "Crassula ovata",            # Jade Plant
    "Aspidistra elatior"         # Cast Iron Plant
]

In [None]:
# Create care information database
care_database = {
  "Monstera deliciosa": {
    "common_name": "Swiss Cheese Plant",
    "light": "Bright indirect light. Avoid direct sunlight which can burn the leaves.",
    "water": "Allow top 2-3 inches of soil to dry out between waterings. Water less in winter.",
    "soil": "Well-draining potting mix with some peat and perlite.",
    "temperature": "65-85°F (18-29°C). Keep away from cold drafts.",
    "humidity": "Prefers moderate to high humidity. Regular misting is beneficial.",
    "fertilizer": "Feed monthly during growing season with balanced houseplant fertilizer.",
    "common_issues": "Brown leaf tips (low humidity), yellow leaves (overwatering), lack of fenestration (insufficient light).",
    "toxicity": "Toxic to pets if ingested. Contains calcium oxalate crystals."
  },
  "Ficus lyrata": {
    "common_name": "Fiddle Leaf Fig",
    "light": "Bright indirect light. Some direct morning sun is beneficial.",
    "water": "Water when top inch of soil is dry. Ensure thorough drainage.",
    "soil": "Well-draining potting mix with peat and perlite.",
    "temperature": "60-75°F (15-24°C). Avoid temperature fluctuations.",
    "humidity": "Moderate humidity, around 40-60%.",
    "fertilizer": "Feed monthly in spring and summer with diluted houseplant fertilizer.",
    "common_issues": "Brown spots (overwatering or bacterial infection), leaf drop (stress from relocation or temperature changes).",
    "toxicity": "Mildly toxic to pets and humans if ingested."
  },
  "Sansevieria trifasciata": {
    "common_name": "Snake Plant",
    "light": "Adaptable to various light conditions from low light to bright indirect light.",
    "water": "Allow soil to dry completely between waterings. Water sparingly in winter.",
    "soil": "Well-draining, sandy soil mix.",
    "temperature": "55-85°F (13-29°C). Can tolerate temperature fluctuations.",
    "humidity": "Adaptable to various humidity levels, including dry air.",
    "fertilizer": "Feed sparingly with cactus fertilizer during growing season.",
    "common_issues": "Root rot (overwatering), brown tips (fluoride in water).",
    "toxicity": "Mildly toxic to pets if ingested."
  }
}

# Save the care database to a JSON file
with open("data/care_database.json", "w") as f:
    json.dump(care_database, f, indent=2)

print(f"Created care database with {len(care_database)} plant species")
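Because the care database covers only three of the twenty species, it is worth checking coverage before adding classes to the model. A minimal sketch with a hypothetical helper `missing_care_entries`, shown on a trimmed-down species list and database:

```python
def missing_care_entries(species, care_db):
    """Return the species that have no entry in the care database, in list order."""
    return [name for name in species if name not in care_db]

# Illustrative inputs: a short species list and a database with one entry
species = ["Monstera deliciosa", "Ficus lyrata", "Aloe vera"]
care_db = {"Monstera deliciosa": {"common_name": "Swiss Cheese Plant"}}

print(missing_care_entries(species, care_db))
# → ['Ficus lyrata', 'Aloe vera']
```

Called with `COMMON_HOUSEPLANTS` and `care_database` from the cells above, it would list the 17 species that still need care entries.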

### Dataset and DataLoader

Now, let's define our dataset and image transformations for training, validation, and inference.
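Both pipelines end with `ToTensor()`, which scales 8-bit pixel values to [0, 1], followed by `Normalize`, which computes `(x - mean) / std` per channel with the ImageNet statistics that the pretrained backbones expect. The arithmetic for a single RGB pixel, as a plain-Python sketch (the helper `normalize_pixel` is illustrative, not part of torchvision):

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """Mimic ToTensor() + Normalize() for one 8-bit RGB pixel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )

print([round(v, 4) for v in normalize_pixel((0, 0, 0))])
# → [-2.1179, -2.0357, -1.8044]
print([round(v, 4) for v in normalize_pixel((255, 255, 255))])
# → [2.2489, 2.4286, 2.64]
```

Normalized values therefore fall roughly in the range -2.1 to 2.6 rather than [0, 1].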

In [None]:
# Define image size and transformations
IMAGE_SIZE = 224

# Create transformations for training and validation
def create_transforms(is_training=True):
    """Create image transformation pipelines for training and validation sets."""
    if is_training:
        # More aggressive augmentation for training set
        transform = transforms.Compose([
            transforms.RandomResizedCrop(IMAGE_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # ImageNet means
                std=[0.229, 0.224, 0.225]     # ImageNet stds
            )
        ])
    else:
        # Simpler preprocessing for validation/inference
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
    
    return transform

# Create dataset class
class HouseplantDataset(Dataset):
    """Dataset class for houseplant images."""
    
    def __init__(self, data_dir, transform=None, class_names=None):
        """Initialize the dataset."""
        self.data_dir = data_dir
        self.transform = transform
        
        # If there's actual data in the directory
        if os.path.exists(data_dir) and len(os.listdir(data_dir)) > 0:
            self.classes = sorted([d for d in os.listdir(data_dir) 
                             if os.path.isdir(os.path.join(data_dir, d))])
            self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
            self.samples = self._make_dataset()
        # If we're just initializing with class names (for testing/demo)
        elif class_names is not None:
            self.classes = class_names
            self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
            self.samples = []  # Empty samples list
        else:
            # Fallback to empty lists
            self.classes = []
            self.class_to_idx = {}
            self.samples = []
    
    def _make_dataset(self):
        """Create a list of (image_path, class_idx) tuples."""
        samples = []
        for class_name in self.classes:
            class_dir = os.path.join(self.data_dir, class_name)
            class_idx = self.class_to_idx[class_name]
            
            for root, _, files in os.walk(class_dir):
                for file in files:
                    if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                        samples.append((os.path.join(root, file), class_idx))
        
        return samples
    
    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)
    
    def __getitem__(self, idx):
        """Get a sample from the dataset."""
        if len(self.samples) == 0:
            # Create a dummy sample for testing/demo
            dummy_tensor = torch.zeros((3, IMAGE_SIZE, IMAGE_SIZE))
            return dummy_tensor, 0
        
        image_path, class_idx = self.samples[idx]
        
        # Load and convert image
        try:
            image = Image.open(image_path).convert('RGB')
            
            # Apply transforms if available
            if self.transform:
                image = self.transform(image)
            
            return image, class_idx
            
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Return a placeholder if image loading fails
            placeholder = torch.zeros((3, IMAGE_SIZE, IMAGE_SIZE))
            return placeholder, class_idx
    
    def get_class_name(self, class_idx):
        """Get the class name for a given class index."""
        if class_idx < len(self.classes):
            return self.classes[class_idx]
        else:
            return "Unknown"
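`HouseplantDataset` assumes a folder-per-class layout (`data_dir/<class_name>/<image files>`), the same convention torchvision's `ImageFolder` uses. Class names come from the sorted subdirectory names, so class indices follow alphabetical order, and only `.png`, `.jpg`, and `.jpeg` files are collected. A standard-library sketch of that scanning logic on a throwaway directory (the helper `scan_image_folder` is hypothetical, and unlike the class above it also sorts file names for a deterministic order):

```python
import os
import tempfile

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

def scan_image_folder(data_dir):
    """Return (classes, samples) for a folder-per-class layout."""
    classes = sorted(d for d in os.listdir(data_dir)
                     if os.path.isdir(os.path.join(data_dir, d)))
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        for root, _, files in os.walk(os.path.join(data_dir, name)):
            for file in sorted(files):
                if file.lower().endswith(IMAGE_EXTENSIONS):
                    samples.append((os.path.join(root, file), class_to_idx[name]))
    return classes, samples

with tempfile.TemporaryDirectory() as tmp:
    # Two classes; the .txt file should be skipped
    layout = {"Ficus_lyrata": ["a.JPG", "notes.txt"], "Aloe_vera": ["b.png"]}
    for cls, files in layout.items():
        os.makedirs(os.path.join(tmp, cls))
        for f in files:
            open(os.path.join(tmp, cls, f), "w").close()
    classes, samples = scan_image_folder(tmp)

print(classes)  # → ['Aloe_vera', 'Ficus_lyrata']
print([(os.path.basename(p), i) for p, i in samples])  # → [('b.png', 0), ('a.JPG', 1)]
```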

For this notebook, we'll create a small dummy dataset for demonstration purposes: one solid-color placeholder image for each of the first three species. In a real scenario, you would download and prepare a real dataset with plant images.

In [None]:
# Let's create a dummy dataset for demonstration
# In a real scenario, you would download a real dataset

# Function to create a placeholder image (no download takes place)
def create_placeholder_image(plant_name, save_dir="data/demo"):
    """Create a solid-color placeholder image in a per-species class folder."""
    # Folder names use underscores, e.g. "Monstera_deliciosa"
    class_dir = os.path.join(save_dir, plant_name.replace(" ", "_"))
    os.makedirs(class_dir, exist_ok=True)
    
    # Every placeholder has the same color, so the model cannot
    # learn to tell the demo classes apart from these images
    img = Image.new('RGB', (300, 300), color=(73, 109, 137))
    
    img_path = os.path.join(class_dir, f"{plant_name.replace(' ', '_')}_sample.jpg")
    img.save(img_path)
    
    return img_path

# Create a small dummy dataset
demo_plants = COMMON_HOUSEPLANTS[:3]  # The first 3 plants, which match the care database
demo_dir = "data/demo"
os.makedirs(demo_dir, exist_ok=True)

demo_images = []
for plant in demo_plants:
    img_path = create_placeholder_image(plant, demo_dir)
    demo_images.append(img_path)

print(f"Created demo dataset with {len(demo_plants)} plant species")

In [None]:
# Create train and validation datasets
train_transform = create_transforms(is_training=True)
val_transform = create_transforms(is_training=False)

# We'll use the demo dataset for both training and validation
train_dataset = HouseplantDataset(demo_dir, transform=train_transform)
val_dataset = HouseplantDataset(demo_dir, transform=val_transform)

# Create data loaders
batch_size = 4  # Small batch size for demo
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Created datasets with {len(train_dataset.classes)} classes")
print(f"Class names: {train_dataset.classes}")
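Validating on the training images, as this demo does, says nothing about how the model generalizes. With a real dataset you would hold out a fraction of each class instead. A minimal sketch of a per-class (stratified) split over `(path, class_idx)` tuples like `HouseplantDataset.samples`; the helper name, the 20% fraction, and the fixed seed are illustrative choices:

```python
import random
from collections import defaultdict

def stratified_split(samples, val_fraction=0.2, seed=42):
    """Split (path, class_idx) tuples into train/val lists, holding out a fraction of each class."""
    by_class = defaultdict(list)
    for sample in samples:
        by_class[sample[1]].append(sample)

    rng = random.Random(seed)  # local RNG so the split is reproducible
    train, val = [], []
    for class_idx in sorted(by_class):
        items = by_class[class_idx][:]
        rng.shuffle(items)
        # Hold out at least one image per class when the class has 2+ images
        n_val = max(1, round(len(items) * val_fraction)) if len(items) > 1 else 0
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    return train, val

samples = [(f"img_{c}_{i}.jpg", c) for c in range(3) for i in range(10)]
train, val = stratified_split(samples)
print(len(train), len(val))  # → 24 6
```

The two lists could then be assigned to the `samples` attribute of a training and a validation `HouseplantDataset`, each built with its own transform.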

## 3. Model Development <a id="model"></a>

Now let's define our model for plant identification. We'll use transfer learning with a pretrained MobileNetV3-Small backbone by default; the function below also supports ResNet-18 and EfficientNet-B0 backbones.

In [None]:
def create_model(num_classes, model_type="mobilenet", pretrained=True):
    """Create a model for plant classification."""
    # Uses the torchvision >= 0.13 `weights=` API; `pretrained=` is deprecated
    if model_type == "mobilenet":
        # Load MobileNetV3-Small for better efficiency
        weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
        model = models.mobilenet_v3_small(weights=weights)
        
        # Replace the last classifier layer for our number of classes
        in_features = model.classifier[3].in_features
        model.classifier[3] = nn.Linear(in_features, num_classes)
    
    elif model_type == "resnet":
        # Load ResNet-18 for a good balance of performance and speed
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        model = models.resnet18(weights=weights)
        
        # Replace the final fully connected layer for our number of classes
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, num_classes)
    
    elif model_type == "efficientnet":
        # Load EfficientNet-B0 for a good balance of accuracy and size
        weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
        model = models.efficientnet_b0(weights=weights)
        
        # Replace the last classifier layer for our number of classes
        in_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(in_features, num_classes)
    
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    
    return model

# Create a model with our current classes
model = create_model(len(train_dataset.classes), model_type="mobilenet")
model = model.to(device)

print(f"Created {type(model).__name__} model with {len(train_dataset.classes)} output classes")
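Only the final linear layer is replaced, so the number of freshly initialized parameters is small: a `Linear(in_features, num_classes)` layer has `in_features * num_classes` weights plus `num_classes` biases. The `in_features` values below are the torchvision defaults for the layer each branch of `create_model` replaces (1024 for MobileNetV3-Small's `classifier[3]`, 512 for ResNet-18's `fc`, 1280 for EfficientNet-B0's `classifier[1]`); the helper itself is illustrative:

```python
# Input width of the layer that create_model() replaces, per backbone
HEAD_IN_FEATURES = {"mobilenet": 1024, "resnet": 512, "efficientnet": 1280}

def new_head_parameters(model_type, num_classes):
    """Count weights + biases in the replacement nn.Linear head."""
    in_features = HEAD_IN_FEATURES[model_type]
    return in_features * num_classes + num_classes

for model_type in HEAD_IN_FEATURES:
    print(model_type, new_head_parameters(model_type, num_classes=3))
# → mobilenet 3075
# → resnet 1539
# → efficientnet 3843
```

All other layers start from pretrained weights, although `train_model` fine-tunes every layer because the optimizer receives `model.parameters()`.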

## 4. Training the Model <a id="training"></a>

Let's set up our training functions and train the model.

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Train the model for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # Progress bar
    pbar = tqdm(dataloader, desc="Training")
    
    # Count batches ourselves: pbar.n is only refreshed periodically
    # and can lag behind the number of batches processed
    for batch_idx, (inputs, targets) in enumerate(pbar, start=1):
        # Move data to device
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        
        # Calculate loss
        loss = criterion(outputs, targets)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Update metrics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        # Update progress bar with the running averages
        pbar.set_postfix({
            'loss': running_loss / batch_idx,
            'acc': 100. * correct / total if total > 0 else 0
        })
    
    # Calculate epoch metrics
    epoch_loss = running_loss / len(dataloader) if len(dataloader) > 0 else 0
    epoch_acc = 100. * correct / total if total > 0 else 0
    
    return epoch_loss, epoch_acc

def validate(model, dataloader, criterion, device):
    """Validate the model on validation data."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for inputs, targets in dataloader:
            # Move data to device
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            
            # Calculate loss
            loss = criterion(outputs, targets)
            
            # Update metrics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Store predictions and targets for detailed metrics
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
    
    # Calculate validation metrics
    val_loss = running_loss / len(dataloader) if len(dataloader) > 0 else 0
    val_acc = 100. * correct / total if total > 0 else 0
    
    return val_loss, val_acc, all_preds, all_targets

def train_model(model, train_loader, val_loader, num_epochs=5, learning_rate=0.001):
    """Train the model for a specified number of epochs."""
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Training metrics
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = -1.0  # Below any real accuracy, so the first epoch always saves a checkpoint
    
    print(f"Starting training for {num_epochs} epochs...")
    start_time = time.time()
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        # Train
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        
        # Validate
        val_loss, val_acc, all_preds, all_targets = validate(
            model, val_loader, criterion, device
        )
        
        # Save metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_acc': val_acc,
                'train_acc': train_acc
            }, "models/best_model.pth")
            print(f"Saved new best model with validation accuracy: {val_acc:.2f}%")
    
    # Calculate training time
    total_time = time.time() - start_time
    print(f"Training completed in {total_time/60:.2f} minutes")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    
    # Return metrics for plotting
    return train_losses, val_losses, train_accs, val_accs, best_val_acc
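`train_one_epoch` and `validate` report the epoch loss as the mean of per-batch mean losses. When the last batch is smaller than the others, this gives its samples more weight than a true per-sample average would. A plain-Python comparison using made-up batch losses (illustrative numbers, not results from this notebook):

```python
def mean_of_batch_means(batch_losses):
    """What the notebook computes: running_loss / len(dataloader)."""
    return sum(batch_losses) / len(batch_losses)

def per_sample_mean(batch_losses, batch_sizes):
    """Weight each batch's mean loss by its size, e.g. loss.item() * targets.size(0)."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

batch_losses = [1.0, 0.5, 2.0]  # mean loss of each batch
batch_sizes = [4, 4, 2]         # the final batch is only half full

print(round(mean_of_batch_means(batch_losses), 4))           # → 1.1667
print(round(per_sample_mean(batch_losses, batch_sizes), 4))  # → 1.0
```

Accuracy is already counted per sample (`correct / total`), so only the loss is affected, and the two loss averages agree whenever all batches have the same size.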

In [None]:
# Train the model with our dummy data
# In a real scenario, you would train with a larger dataset
# or use a pre-trained model

# We'll just do a quick training run for demonstration
if len(train_dataset.samples) > 0 and len(val_dataset.samples) > 0:
    train_losses, val_losses, train_accs, val_accs, best_val_acc = train_model(
        model, train_loader, val_loader, num_epochs=2, learning_rate=0.0001
    )
    
    # Plotting training metrics
    plt.figure(figsize=(12, 5))
    
    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    
    # Plot accuracies
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Training Accuracy')
    plt.plot(val_accs, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
else:
    print("Skipping training due to empty dataset.")
    print("For a real implementation, use a proper dataset or a pre-trained model.")
    
    # Save a dummy model for demonstration
    torch.save({
        'model_state_dict': model.state_dict(),
        'epoch': 0,
        'val_acc': 0.0,
        'train_acc': 0.0
    }, "models/best_model.pth")
    
    print("Saved a dummy model for demonstration purposes.")

## 5. Inference and Evaluation <a id="inference"></a>

Now let's implement functions to use our trained model for inference.

In [None]:
def preprocess_image(image, size=IMAGE_SIZE):
    """Preprocess an image for inference."""
    # Create inference transform
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    
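    # Load and transform image
    if isinstance(image, str):
        # If image is a file path
        image = Image.open(image).convert('RGB')
    elif isinstance(image, np.ndarray):
        # If image is a numpy array (e.g. from gr.Image with type="numpy")
        image = Image.fromarray(image).convert('RGB')
    elif isinstance(image, Image.Image):
        # If image is already a PIL image, make sure it has 3 channels
        image = image.convert('RGB')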
    
    # Apply transform
    image_tensor = transform(image)
    
    # Add batch dimension
    image_tensor = image_tensor.unsqueeze(0)
    return image_tensor

def identify_plant(model, image, class_names, top_k=3):
    """Identify a plant from an image."""
    # Preprocess the image
    image_tensor = preprocess_image(image)
    image_tensor = image_tensor.to(device)
    
    # Set model to evaluation mode
    model.eval()
    
    # Run inference
    with torch.no_grad():
        outputs = model(image_tensor)
        
        # Apply softmax to get probabilities
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        
        # Get top-k probabilities and indices
        top_probs, top_indices = torch.topk(probabilities, min(top_k, len(class_names)))
        
        # Convert to lists
        top_probs = top_probs.cpu().numpy()[0]
        top_indices = top_indices.cpu().numpy()[0]
    
    # Create result list
    results = []
    for i, (idx, prob) in enumerate(zip(top_indices, top_probs)):
        if idx < len(class_names):
            results.append({
                'rank': i + 1,
                'class_name': class_names[idx],
                'probability': float(prob)
            })
    
    return results

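def get_plant_care_info(plant_name, care_database):
    """Get care information for a plant species."""
    # Dataset folder names use underscores (e.g. "Monstera_deliciosa"),
    # while the care database is keyed by "Monstera deliciosa"
    plant_name = plant_name.replace("_", " ")
    
    # Look for exact match first
    if plant_name in care_database:
        return care_database[plant_name]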
    
    # Try common name match
    for scientific_name, info in care_database.items():
        if info.get('common_name') == plant_name:
            return info
    
    # Try case-insensitive partial match
    plant_name_lower = plant_name.lower()
    for scientific_name, info in care_database.items():
        if (plant_name_lower in scientific_name.lower() or 
            plant_name_lower in info.get('common_name', '').lower()):
            return info
    
    # No match found
    return None
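`identify_plant` converts the model's raw logits into probabilities with softmax and keeps the `top_k` most likely classes. The same computation in plain Python, using hypothetical logits for the three demo classes (subtracting the maximum logit is the usual numerical-stability trick and does not change the result):

```python
import math

def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probabilities, class_names, k=3):
    """Return (rank, class_name, probability) for the k most likely classes."""
    order = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [(rank + 1, class_names[i], probabilities[i])
            for rank, i in enumerate(order[:min(k, len(class_names))])]

names = ["Ficus_lyrata", "Monstera_deliciosa", "Sansevieria_trifasciata"]
probs = softmax([1.0, 2.0, 0.1])
for rank, name, p in top_k(probs, names, k=2):
    print(f"{rank}. {name}: {p:.2%}")
# → 1. Monstera_deliciosa: 65.90%
# → 2. Ficus_lyrata: 24.24%
```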

In [None]:
# Load the model for inference
def load_model_for_inference(model_path, num_classes):
    """Load a trained model for inference."""
    # Create a new model
    model = create_model(num_classes, model_type="mobilenet")
    
    # Load saved state dict
    try:
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded model from {model_path} with validation accuracy: {checkpoint.get('val_acc', 0):.2f}%")
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Using the model with pretrained backbone weights and an untrained classifier head instead.")
    
    model = model.to(device)
    model.eval()
    return model

# Load our trained model
inference_model = load_model_for_inference("models/best_model.pth", len(train_dataset.classes))

# For demonstration, we'll use the class names from the training dataset
class_names = train_dataset.classes

# Load care database
with open("data/care_database.json", "r") as f:
    care_database = json.load(f)
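Output index `i` only means `class_names[i]` if the list has the same order it had during training. Since classes come from sorted folder names, adding a new folder can shift every index. One option is to store the class list next to the checkpoint and reload it at inference time. A stdlib sketch; the helpers and the `class_names.json` file name are hypothetical, and the list could equally be stored inside the checkpoint dictionary:

```python
import json
import os
import tempfile

def save_class_names(class_names, path):
    """Persist the index -> class name order used during training."""
    with open(path, "w") as f:
        json.dump(list(class_names), f)

def load_class_names(path):
    """Reload the class list; position i is the name for output index i."""
    with open(path) as f:
        return json.load(f)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "class_names.json")  # e.g. models/class_names.json
    save_class_names(["Ficus_lyrata", "Monstera_deliciosa", "Sansevieria_trifasciata"], path)
    restored = load_class_names(path)

print(restored[1])  # → Monstera_deliciosa
```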

Let's test our inference functionality with a sample image:

In [None]:
# Let's test with one of our demo images
if len(demo_images) > 0:
    test_image_path = demo_images[0]
    
    # Show the image
    plt.figure(figsize=(8, 8))
    img = Image.open(test_image_path).convert('RGB')
    plt.imshow(img)
    plt.title("Test Image")
    plt.axis('off')
    plt.show()
    
    # Identify the plant
    results = identify_plant(inference_model, test_image_path, class_names)
    
    # Display results
    print("\nIdentification Results:")
    for result in results:
        print(f"{result['rank']}. {result['class_name']}: {result['probability']:.2%} confidence")
    
    # Get care information for the top prediction
    top_prediction = results[0]['class_name']
    care_info = get_plant_care_info(top_prediction, care_database)
    
    if care_info:
        print("\nCare Information:")
        print(f"Common name: {care_info.get('common_name', 'Unknown')}")
        print(f"Light: {care_info.get('light', 'Unknown')}")
        print(f"Water: {care_info.get('water', 'Unknown')}")
        print(f"Temperature: {care_info.get('temperature', 'Unknown')}")
    else:
        print("\nNo care information available for this plant.")
else:
    print("No demo images available for testing.")

## 6. Care Recommendations <a id="care"></a>

Let's implement a function to format care information for a better user experience:

In [None]:
def format_care_info(care_info):
    """Format care information for display."""
    if not care_info:
        return "No care information available for this plant."
    
    common_name = care_info.get('common_name', 'Unknown')
    
    # Format care sections with icons
    care_sections = [
        ('light', 'Light Requirements', '☀️'),
        ('water', 'Watering Needs', '💧'),
        ('soil', 'Soil Type', '🌱'),
        ('temperature', 'Temperature', '🌡️'),
        ('humidity', 'Humidity', '💨'),
        ('fertilizer', 'Fertilizer', '🧪'),
        ('common_issues', 'Common Issues', '⚠️')
    ]
    
    # Build the formatted output
    output = f"## {common_name} Care Guide\n\n"
    
    for key, label, icon in care_sections:
        if key in care_info and care_info[key]:
            output += f"### {icon} {label}\n"
            output += f"{care_info[key]}\n\n"
    
    # Add toxicity warning if applicable
    if 'toxicity' in care_info and care_info['toxicity']:
        output += f"### ⚠️ Toxicity\n"
        output += f"{care_info['toxicity']}\n"
    
    return output

# Display formatted care information for a sample plant
if "Monstera deliciosa" in care_database:
    care_info = care_database["Monstera deliciosa"]
    formatted_care = format_care_info(care_info)
    print(formatted_care)

## 7. Interactive Interface with Gradio <a id="interface"></a>

Now, let's create an interactive interface using Gradio to make our plant identification assistant user-friendly:

In [None]:
def identify_and_get_care(image):
    """Identify a plant from an image and provide care information."""
    if image is None:
        return "Please upload an image", "", 0
    
    # Identify the plant
    results = identify_plant(inference_model, image, class_names)
    
    # Format the identification results
    identification_html = "<div style='padding: 10px; background-color: #f8f9fa; border-radius: 10px;'>"
    identification_html += "<h3>Identification Results:</h3>"
    
    # Add top prediction with larger styling
    top_pred = results[0]
    scientific_name = top_pred['class_name'].replace('_', ' ')  # Folder names use underscores
    
    # Try to get common name from care database
    care_info = get_plant_care_info(scientific_name, care_database)
    common_name = care_info.get('common_name', '') if care_info else ''
    
    name_display = f"{scientific_name}"
    if common_name:
        name_display += f" <span style='font-style: italic;'>({common_name})</span>"
    
    identification_html += f"<div style='margin-bottom: 15px;'>"
    identification_html += f"<p style='font-size: 18px; font-weight: bold;'>{name_display}</p>"
    identification_html += f"<p>Confidence: {top_pred['probability']:.1%}</p>"
    identification_html += "</div>"
    
    # Add other predictions if there are any
    if len(results) > 1:
        identification_html += "<div style='margin-top: 10px;'>"
        identification_html += "<h4>Other possibilities:</h4>"
        identification_html += "<ul>"
        
        for pred in results[1:]:
            # Try to get common name
            pred_care = get_plant_care_info(pred['class_name'], care_database)
            pred_common = pred_care.get('common_name', '') if pred_care else ''
            
            pred_display = pred['class_name'].replace('_', ' ')
            if pred_common:
                pred_display += f" <span style='font-style: italic;'>({pred_common})</span>"
                
            identification_html += f"<li>{pred_display} <span>({pred['probability']:.1%})</span></li>"
        
        identification_html += "</ul>"
        identification_html += "</div>"
    
    identification_html += "</div>"
    
    # Get and format care information
    if care_info:
        care_text = format_care_info(care_info)
    else:
        care_text = "No care information available for this plant."
    
    # Return confidence as a percentage for the meter
    confidence = float(top_pred['probability']) * 100
    
    return identification_html, care_text, confidence

# Create the Gradio interface
def create_gradio_interface():
    """Create a Gradio interface for the plant identification assistant."""
    # Title and description
    title = "Common Houseplant Identification Assistant"
    description = """
    Upload an image of your houseplant, and this tool will identify the species and provide care recommendations.
    The model only recognizes the species it was trained on, and care instructions are shown for species that have an entry in the care database.
    """
    
    # CSS for styling
    css = """
    .gradio-container {max-width: 900px;}
    h1 {text-align: center; color: #2e7d32;}
    .confidence-meter {height: 25px; border-radius: 5px; overflow: hidden;}
    .confidence-meter .bar {height: 100%; background-color: #4caf50;}
    """
    
    # Create the interface
    with gr.Blocks(css=css) as interface:
        gr.Markdown(f"<h1>{title}</h1>")
        gr.Markdown(description)
        
        with gr.Row():
            with gr.Column(scale=1):
                # Input components
                input_image = gr.Image(
                    type="pil",
                    label="Upload or take a photo of your houseplant"
                )
                identify_button = gr.Button("Identify Plant", variant="primary")
                
            with gr.Column(scale=1):
                # Output components
                confidence_meter = gr.Number(
                    label="Confidence (%)", 
                    value=0,
                    interactive=False
                )
                identification_result = gr.HTML(
                    label="Identification Results",
                    value="Upload an image and click 'Identify Plant' to get results."
                )
                care_info = gr.Markdown(
                    label="Care Information",
                    value="Plant care information will appear here after identification."
                )
        
        # Set up the action
        identify_button.click(
            fn=identify_and_get_care,
            inputs=[input_image],
            outputs=[identification_result, care_info, confidence_meter]
        )
        
        # Footer info
        gr.Markdown("""
        ### Tips for Best Results
        
        - Take photos in natural light
        - Capture both leaves and overall plant structure
        - Avoid blurry or poorly lit images
        - For more accurate identification, try multiple angles
        
        This tool works best with common houseplants. If your plant isn't identified correctly, 
        consider consulting a plant specialist or reference book.
        """)
    
    return interface

# Create and launch the interface
interface = create_gradio_interface()
interface.launch(inline=True)
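The confidence shown in the interface is the top softmax probability. It is not a guarantee of correctness, especially for a model trained on placeholder images. One way to act on it is to flag low-confidence predictions instead of presenting them as identifications. A sketch with a hypothetical `describe_prediction` helper and an arbitrary 0.5 threshold:

```python
def describe_prediction(class_name, probability, threshold=0.5):
    """Return a user-facing summary, flagging predictions below the threshold."""
    label = class_name.replace("_", " ")
    if probability < threshold:
        return f"Not sure: best guess is {label} ({probability:.1%}). Try another photo."
    return f"{label} ({probability:.1%})"

print(describe_prediction("Monstera_deliciosa", 0.659))
# → Monstera deliciosa (65.9%)
print(describe_prediction("Ficus_lyrata", 0.242))
# → Not sure: best guess is Ficus lyrata (24.2%). Try another photo.
```

A suitable threshold would need to be chosen on held-out validation data.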

## Conclusion

Congratulations! You've built an end-to-end prototype of a plant identification assistant in this Jupyter notebook. The notebook includes:

1. Data preparation and loading
2. Model creation using transfer learning
3. Training and evaluation functions
4. Plant identification and care recommendation features
5. An interactive Gradio interface

### Next Steps

To create a full-featured application, you could:

1. **Collect a larger dataset**: Gather more plant images for better accuracy
2. **Train with more epochs**: Run training for longer to improve the model
3. **Expand the care database**: Add more plant species and detailed care information
4. **Add plant health analysis**: Extend the model to detect common plant diseases
5. **Deploy the application**: Export and deploy as a web service or mobile app
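When expanding the care database, a quick schema check helps keep new entries consistent with the fields that `format_care_info` displays. A minimal sketch; the required-field set mirrors the keys of the existing entries, and `validate_care_entry` is a hypothetical helper:

```python
REQUIRED_FIELDS = {
    "common_name", "light", "water", "soil", "temperature",
    "humidity", "fertilizer", "common_issues", "toxicity",
}

def validate_care_entry(entry):
    """Return the sorted list of required fields that are missing or empty."""
    return sorted(field for field in REQUIRED_FIELDS if not entry.get(field))

new_entry = {
    "common_name": "Jade Plant",
    "light": "Bright light.",
    "water": "",  # empty values count as missing
}
print(validate_care_entry(new_entry))
# → ['common_issues', 'fertilizer', 'humidity', 'soil', 'temperature', 'toxicity', 'water']
```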

This notebook provides a solid foundation that you can build upon to create a practical tool for plant enthusiasts.