# Common Houseplant Identification Assistant with PlantNet

This notebook implements a machine learning-based application that helps identify common houseplants from images and provides basic care recommendations. We'll be using a pre-trained PlantNet model for improved accuracy.

## Notebook Contents

1. [Setup and Installation](#setup)
2. [Data Preparation](#data)
3. [PlantNet Model Integration](#plantnet)
4. [Model Fine-tuning](#fine-tuning)
5. [Inference and Evaluation](#inference)
6. [Care Recommendations](#care)
7. [Interactive Interface with Gradio](#interface)

Let's begin by setting up our environment and installing the necessary dependencies.

## 1. Setup and Installation <a id="setup"></a>

First, we'll install the required packages:

In [None]:
# Install required packages
!pip install torch torchvision tqdm numpy matplotlib pandas pillow gradio scikit-learn requests

In [None]:
# Import necessary libraries
import os
import sys
import json
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm.notebook import tqdm
from pathlib import Path
import requests
from io import BytesIO
import time
from sklearn.metrics import classification_report, confusion_matrix
import gradio as gr
import warnings
warnings.filterwarnings('ignore')

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create necessary directories
os.makedirs('data', exist_ok=True)
os.makedirs('models', exist_ok=True)
os.makedirs('outputs', exist_ok=True)
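
The imports above include `random`, and later cells shuffle training batches, so results can differ from run to run. A minimal sketch of a seeding helper is shown below. `set_seed` is a hypothetical name that does not appear elsewhere in this notebook, and NumPy and PyTorch are seeded only if they can be imported.

```python
import random


def set_seed(seed: int = 42) -> None:
    """Seed the available random number generators (hypothetical helper)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; only the stdlib generator is seeded
    try:
        import torch
        torch.manual_seed(seed)  # seeds PyTorch's random number generators
    except ImportError:
        pass


# Re-seeding with the same value reproduces the same random draw
set_seed(42)
first = random.random()
set_seed(42)
second = random.random()
print(first == second)
```

Calling such a helper once, right after the imports, is usually enough for a demo; full determinism on GPU may need additional settings that are not covered here.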

## 2. Data Preparation <a id="data"></a>

We'll now define our dataset of common houseplants. For the purposes of this notebook, we list 20 common houseplant species and build a small care database. Only the first five species receive care entries here; entries for the remaining species follow the same structure.

In [None]:
# Define a list of common houseplant species
COMMON_HOUSEPLANTS = [
    "Monstera deliciosa",        # Swiss Cheese Plant
    "Ficus lyrata",              # Fiddle Leaf Fig
    "Sansevieria trifasciata",   # Snake Plant
    "Chlorophytum comosum",      # Spider Plant
    "Epipremnum aureum",         # Pothos
    "Spathiphyllum wallisii",    # Peace Lily
    "Zamioculcas zamiifolia",    # ZZ Plant
    "Dracaena marginata",        # Dragon Tree
    "Calathea makoyana",         # Peacock Plant
    "Pilea peperomioides",       # Chinese Money Plant
    "Philodendron bipinnatifidum", # Split-leaf Philodendron
    "Aloe vera",                 # Aloe Vera
    "Ficus elastica",            # Rubber Plant
    "Maranta leuconeura",        # Prayer Plant
    "Aglaonema commutatum",      # Chinese Evergreen
    "Peperomia obtusifolia",     # Baby Rubber Plant
    "Anthurium andraeanum",      # Flamingo Flower
    "Schlumbergera bridgesii",   # Christmas Cactus
    "Crassula ovata",            # Jade Plant
    "Aspidistra elatior"         # Cast Iron Plant
]

In [None]:
# Create care information database
care_database = {
  "Monstera deliciosa": {
    "common_name": "Swiss Cheese Plant",
    "light": "Bright indirect light. Avoid direct sunlight which can burn the leaves.",
    "water": "Allow top 2-3 inches of soil to dry out between waterings. Water less in winter.",
    "soil": "Well-draining potting mix with some peat and perlite.",
    "temperature": "65-85°F (18-29°C). Keep away from cold drafts.",
    "humidity": "Prefers moderate to high humidity. Regular misting is beneficial.",
    "fertilizer": "Feed monthly during growing season with balanced houseplant fertilizer.",
    "common_issues": "Brown leaf tips (low humidity), yellow leaves (overwatering), lack of fenestration (insufficient light).",
    "toxicity": "Toxic to pets if ingested. Contains calcium oxalate crystals."
  },
  "Ficus lyrata": {
    "common_name": "Fiddle Leaf Fig",
    "light": "Bright indirect light. Some direct morning sun is beneficial.",
    "water": "Water when top inch of soil is dry. Ensure thorough drainage.",
    "soil": "Well-draining potting mix with peat and perlite.",
    "temperature": "60-75°F (15-24°C). Avoid temperature fluctuations.",
    "humidity": "Moderate humidity, around 40-60%.",
    "fertilizer": "Feed monthly in spring and summer with diluted houseplant fertilizer.",
    "common_issues": "Brown spots (overwatering or bacterial infection), leaf drop (stress from relocation or temperature changes).",
    "toxicity": "Mildly toxic to pets and humans if ingested."
  },
  "Sansevieria trifasciata": {
    "common_name": "Snake Plant",
    "light": "Adaptable to various light conditions from low light to bright indirect light.",
    "water": "Allow soil to dry completely between waterings. Water sparingly in winter.",
    "soil": "Well-draining, sandy soil mix.",
    "temperature": "55-85°F (13-29°C). Can tolerate temperature fluctuations.",
    "humidity": "Adaptable to various humidity levels, including dry air.",
    "fertilizer": "Feed sparingly with cactus fertilizer during growing season.",
    "common_issues": "Root rot (overwatering), brown tips (fluoride in water).",
    "toxicity": "Mildly toxic to pets if ingested."
  }
}

# Add more plants to the care database
care_database["Chlorophytum comosum"] = {
    "common_name": "Spider Plant",
    "light": "Bright indirect light. Can tolerate lower light conditions.",
    "water": "Keep soil moderately moist. Allow top to dry slightly between waterings.",
    "soil": "Well-draining potting mix.",
    "temperature": "60-75°F (15-24°C).",
    "humidity": "Prefers moderate humidity but tolerates dry air.",
    "fertilizer": "Feed monthly during growing season with balanced fertilizer.",
    "common_issues": "Brown tips (fluoride in water, dry air), pale leaves (too much light).",
    "toxicity": "Non-toxic to pets and humans."
}

care_database["Epipremnum aureum"] = {
    "common_name": "Pothos",
    "light": "Adaptable to various light conditions. Prefers moderate indirect light.",
    "water": "Allow top inch of soil to dry between waterings.",
    "soil": "Standard potting mix with good drainage.",
    "temperature": "65-85°F (18-29°C).",
    "humidity": "Adaptable to normal household humidity.",
    "fertilizer": "Feed monthly with balanced houseplant fertilizer.",
    "common_issues": "Yellow leaves (overwatering), leggy growth (insufficient light).",
    "toxicity": "Toxic to pets if ingested. Contains calcium oxalate crystals."
}

# Save the care database to a JSON file
with open("data/care_database.json", "w") as f:
    json.dump(care_database, f, indent=2)

print(f"Created care database with {len(care_database)} plant species")
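
Because the species list and the care database are maintained separately, it can help to check which species still lack care entries or are missing fields. The sketch below uses small stand-in data rather than the notebook's own `COMMON_HOUSEPLANTS` and `care_database`; `audit_care_database` and `REQUIRED_FIELDS` are hypothetical names introduced only for illustration.

```python
# Stand-in data; in the notebook these would be COMMON_HOUSEPLANTS and care_database
species = ["Monstera deliciosa", "Ficus lyrata", "Crassula ovata"]
care = {
    "Monstera deliciosa": {
        "common_name": "Swiss Cheese Plant",
        "light": "Bright indirect light.",
        "water": "Let the top of the soil dry out between waterings.",
    },
    "Ficus lyrata": {
        "common_name": "Fiddle Leaf Fig",
        "light": "Bright indirect light.",
        # "water" is deliberately missing to show the check
    },
}

REQUIRED_FIELDS = ("common_name", "light", "water")  # assumed minimum set of fields


def audit_care_database(species, care, required=REQUIRED_FIELDS):
    """Return (species without an entry, {species: missing fields})."""
    missing_species = [name for name in species if name not in care]
    incomplete = {}
    for name, info in care.items():
        missing_fields = [field for field in required if field not in info]
        if missing_fields:
            incomplete[name] = missing_fields
    return missing_species, incomplete


missing_species, incomplete = audit_care_database(species, care)
print("No care entry:", missing_species)
print("Incomplete entries:", incomplete)
```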

### Dataset and DataLoader

Now, let's define our dataset class for loading and preprocessing plant images.

In [None]:
# Define image size and transformations
IMAGE_SIZE = 224

# Create ImageNet-style transformations (assumed to match the pre-trained weights)
def get_plantnet_transforms():
    """Get ImageNet-style transforms, assumed to match PlantNet's preprocessing."""
    
    # The mean and std below are the standard ImageNet channel statistics
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(IMAGE_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std)
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std)
        ]),
    }
    
    return data_transforms

# Create dataset class
class HouseplantDataset(Dataset):
    """Dataset class for houseplant images."""
    
    def __init__(self, data_dir, transform=None, class_names=None):
        """Initialize the dataset."""
        self.data_dir = data_dir
        self.transform = transform
        
        # If there's actual data in the directory
        if os.path.exists(data_dir) and len(os.listdir(data_dir)) > 0:
            self.classes = sorted([d for d in os.listdir(data_dir) 
                             if os.path.isdir(os.path.join(data_dir, d))])
            self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
            self.samples = self._make_dataset()
        # If we're just initializing with class names (for testing/demo)
        elif class_names is not None:
            self.classes = class_names
            self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
            self.samples = []  # Empty samples list
        else:
            # Fallback to empty lists
            self.classes = []
            self.class_to_idx = {}
            self.samples = []
    
    def _make_dataset(self):
        """Create a list of (image_path, class_idx) tuples."""
        samples = []
        for class_name in self.classes:
            class_dir = os.path.join(self.data_dir, class_name)
            class_idx = self.class_to_idx[class_name]
            
            for root, _, files in os.walk(class_dir):
                for file in files:
                    if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                        samples.append((os.path.join(root, file), class_idx))
        
        return samples
    
    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)
    
    def __getitem__(self, idx):
        """Get a sample from the dataset."""
        if len(self.samples) == 0:
            # Create a dummy sample for testing/demo
            dummy_tensor = torch.zeros((3, IMAGE_SIZE, IMAGE_SIZE))
            return dummy_tensor, 0
        
        image_path, class_idx = self.samples[idx]
        
        # Load and convert image
        try:
            image = Image.open(image_path).convert('RGB')
            
            # Apply transforms if available
            if self.transform:
                image = self.transform(image)
            
            return image, class_idx
            
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Return a placeholder if image loading fails
            placeholder = torch.zeros((3, IMAGE_SIZE, IMAGE_SIZE))
            return placeholder, class_idx
    
    def get_class_name(self, class_idx):
        """Get the class name for a given class index."""
        if class_idx < len(self.classes):
            return self.classes[class_idx]
        else:
            return "Unknown"
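
The `Normalize` step in the transforms above subtracts a per-channel mean and divides by a per-channel standard deviation after `ToTensor` has scaled 8-bit pixel values to the range 0 to 1. The arithmetic for a single pixel can be sketched in plain Python as follows; `normalize_pixel` and `denormalize_pixel` are hypothetical helper names, not part of torchvision.

```python
# Per-channel mean and standard deviation used by the transforms above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Map an (R, G, B) pixel with 0-255 values to normalized floats."""
    scaled = [value / 255.0 for value in rgb]  # the scaling ToTensor applies to 8-bit images
    return [(s - m) / d for s, m, d in zip(scaled, MEAN, STD)]


def denormalize_pixel(values):
    """Invert normalize_pixel, e.g. to display a normalized image again."""
    return [round((v * d + m) * 255.0) for v, m, d in zip(values, MEAN, STD)]


pixel = (73, 109, 137)  # an arbitrary RGB color
normalized = normalize_pixel(pixel)
print(normalized)
print(denormalize_pixel(normalized))
```

The inverse mapping is useful when plotting a tensor that has already been through the transform pipeline.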

For this notebook, let's create a small dummy dataset for demonstration purposes. In a real scenario, you would download and prepare a real dataset with plant images.

In [None]:
# Let's create a dummy dataset for demonstration
# In a real scenario, you would download a real dataset

# Helper that creates a solid-color placeholder image (nothing is downloaded, despite the name)
def download_sample_image(plant_name, save_dir="data/demo"):
    """Create a placeholder image for a given plant species."""
    # Create a placeholder image
    img = Image.new('RGB', (300, 300), color=(73, 109, 137))
    
    # Save in the appropriate directory
    os.makedirs(os.path.join(save_dir, plant_name.replace(" ", "_")), exist_ok=True)
    img_path = os.path.join(save_dir, plant_name.replace(" ", "_"), f"{plant_name.replace(' ', '_')}_sample.jpg")
    img.save(img_path)
    
    return img_path

# Create a small dummy dataset
demo_plants = COMMON_HOUSEPLANTS[:5]  # Just use the first 5 plants for demo
demo_dir = "data/demo"
os.makedirs(demo_dir, exist_ok=True)

demo_images = []
for plant in demo_plants:
    img_path = download_sample_image(plant, demo_dir)
    demo_images.append(img_path)

print(f"Created demo dataset with {len(demo_plants)} plant species")
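
The demo data follows a one-folder-per-class layout, which is the layout `HouseplantDataset` scans. The sketch below reproduces that scan with the standard library only, using empty stand-in files in a temporary directory; `scan_image_folders` is a hypothetical name. Note that class indices come from the alphabetical order of folder names, not from the order of `COMMON_HOUSEPLANTS`.

```python
import os
import tempfile

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def scan_image_folders(data_dir):
    """Return (class_names, samples), where samples are (path, class_index) pairs."""
    classes = sorted(
        d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
    )
    samples = []
    for idx, name in enumerate(classes):
        for root, _, files in os.walk(os.path.join(data_dir, name)):
            for file in sorted(files):
                if file.lower().endswith(IMAGE_EXTENSIONS):
                    samples.append((os.path.join(root, file), idx))
    return classes, samples


with tempfile.TemporaryDirectory() as tmp:
    for folder in ("Monstera_deliciosa", "Ficus_lyrata"):
        os.makedirs(os.path.join(tmp, folder))
        # Empty stand-in file; the scan only looks at names, not contents
        open(os.path.join(tmp, folder, "sample.jpg"), "wb").close()
    classes, samples = scan_image_folders(tmp)

print(classes)
print([idx for _, idx in samples])
```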

## 3. PlantNet Model Integration <a id="plantnet"></a>

Now, let's integrate the pre-trained PlantNet model for improved plant identification. PlantNet has been trained on a large dataset of plant images, so its learned features should transfer to houseplants better than a model trained from scratch on a small dataset.

In [None]:
def download_plantnet_weights(url, save_path):
    """
    Download PlantNet pre-trained weights if they don't exist locally.
    
    Args:
        url (str): URL to download the weights from
        save_path (str): Path to save the weights file
        
    Returns:
        str: Path to the weights file
    """
    import os
    import requests
    from tqdm.notebook import tqdm
    
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    
    # Check if file already exists
    if os.path.exists(save_path):
        print(f"PlantNet weights already exist at {save_path}")
        return save_path
    
    # File doesn't exist, download it
    print(f"Downloading PlantNet weights to {save_path}...")
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()  # Fail early instead of saving an error page as weights
    total_size = int(response.headers.get('content-length', 0))
    
    with open(save_path, 'wb') as file, tqdm(
        desc=os.path.basename(save_path),
        total=total_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            bar.update(size)
    
    print("Download complete!")
    return save_path

def load_plantnet_model(weights_path, num_classes=1081, use_gpu=True):
    """
    Load the PlantNet pre-trained model.
    
    Args:
        weights_path (str): Path to the pre-trained weights
        num_classes (int): Number of classes in the pre-trained model
        use_gpu (bool): Whether to load the model on GPU
        
    Returns:
        torch.nn.Module: Loaded model
    """
    import torch
    from torchvision.models import resnet18
    
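    # Create the model architecture without pre-trained weights
    # (weights=None replaces the deprecated pretrained=False in torchvision >= 0.13)
    model = resnet18(weights=None)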
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    
    # Load the weights
    if use_gpu and torch.cuda.is_available():
        checkpoint = torch.load(weights_path)
    else:
        checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
    
    # Load state dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    
    # Remove 'module.' prefix if present (from DataParallel)
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    
    model.load_state_dict(state_dict)
    
    # Move model to GPU if requested
    if use_gpu and torch.cuda.is_available():
        model = model.cuda()
    
    model.eval()  # Set to evaluation mode
    return model
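
A downloaded weights file that is truncated or corrupted will fail only later, when it is loaded. One way to catch this earlier is to compare the file's SHA-256 digest with a published value. The sketch below assumes such a reference digest is available from the weights provider, which this notebook does not confirm; `sha256_of_file` and `verify_checksum` are hypothetical helpers.

```python
import hashlib
import os
import tempfile


def sha256_of_file(path, chunk_size=1 << 20):
    """Compute the SHA-256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checksum(path, expected_hex):
    """Return True if the file's digest matches the expected hex string."""
    return sha256_of_file(path) == expected_hex.lower()


# Demonstration on a small temporary file instead of a real weights file
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"not real weights")
    tmp_path = tmp.name

expected = hashlib.sha256(b"not real weights").hexdigest()
ok = verify_checksum(tmp_path, expected)
bad = verify_checksum(tmp_path, "0" * 64)
os.remove(tmp_path)
print(ok, bad)
```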

Let's try to download and load the PlantNet pre-trained model. Note that for demonstration purposes, we'll simulate this since the actual weights file might be quite large.

In [None]:
# Normally, we would download the actual PlantNet weights from their site
# For this notebook, we'll simulate it by creating a dummy weights file

def create_dummy_weights_file():
    """Create a dummy weights file for demonstration."""
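    # Stand-in for the real checkpoint: an ImageNet-pretrained ResNet18 backbone
    # with a randomly initialized 1081-class head (these are not actual PlantNet weights)
    dummy_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    dummy_model.fc = nn.Linear(dummy_model.fc.in_features, 1081)  # PlantNet has 1081 classes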
    
    # Save the model
    os.makedirs('models/plantnet', exist_ok=True)
    weights_path = 'models/plantnet/resnet18_weights_best_acc.tar'
    torch.save({
        'state_dict': dummy_model.state_dict(),
        'epoch': 100,
        'best_acc': 0.85
    }, weights_path)
    
    return weights_path

# Create dummy weights file for demonstration
weights_path = create_dummy_weights_file()
print(f"Created dummy PlantNet weights file at {weights_path}")

In [None]:
# Load the PlantNet model
plantnet_model = load_plantnet_model(
    weights_path=weights_path,
    num_classes=1081,  # PlantNet has 1081 classes
    use_gpu=(device.type == 'cuda')
)

print(f"Loaded PlantNet model: {type(plantnet_model).__name__}")
print(f"Output layer: {plantnet_model.fc}")
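
The loader above strips a `module.` prefix from checkpoint keys, which appears when a model was saved while wrapped in `DataParallel`. It decides based on the first key only. A slightly more defensive variant checks every key, sketched here on a plain dictionary so that it runs without PyTorch; `strip_prefix` is a hypothetical helper name.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Plain integers stand in for tensors; only the keys matter here
wrapped = {"module.conv1.weight": 1, "module.fc.weight": 2, "module.fc.bias": 3}
plain = strip_prefix(wrapped)
print(sorted(plain))
```

Applying the helper to keys that have no prefix leaves them unchanged, so it is safe to call unconditionally.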

Now, let's adapt the PlantNet model for our specific task of identifying common houseplants:

In [None]:
def adapt_plantnet_model(plantnet_model, target_classes, device):
    """
    Adapt the PlantNet pre-trained model for our specific plant classes.
    
    Args:
        plantnet_model: The pre-trained PlantNet model
        target_classes (list): List of our target class names
        device: Device to load the model on
        
    Returns:
        torch.nn.Module: Adapted model
    """
    import torch
    import torch.nn as nn
    
    # Freeze all parameters to use the feature extraction part as is
    for param in plantnet_model.parameters():
        param.requires_grad = False
    
    # Replace the final fully connected layer
    num_ftrs = plantnet_model.fc.in_features
    plantnet_model.fc = nn.Linear(num_ftrs, len(target_classes))
    
    # Move model to the appropriate device
    plantnet_model = plantnet_model.to(device)
    
    return plantnet_model

# Adapt the PlantNet model for our houseplant classes
adapted_model = adapt_plantnet_model(
    plantnet_model=plantnet_model,
    target_classes=COMMON_HOUSEPLANTS,
    device=device
)

print(f"Adapted PlantNet model for {len(COMMON_HOUSEPLANTS)} houseplant classes")
print(f"New output layer: {adapted_model.fc}")
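
Replacing the output layer also changes how many values remain to be trained. A fully connected layer with a bias has `in_features * out_features + out_features` parameters. The short calculation below assumes the ResNet18 feature width of 512, which is the value of `fc.in_features` for that architecture; `linear_param_count` is a hypothetical helper.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of learnable values in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


RESNET18_FEATURES = 512  # width of ResNet18's final feature vector

original_head = linear_param_count(RESNET18_FEATURES, 1081)  # 1081-class head
new_head = linear_param_count(RESNET18_FEATURES, 20)         # 20 houseplant classes

print(original_head, new_head)
```

With the backbone frozen, the new head is the only part of the network that receives gradient updates.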

## 4. Model Fine-tuning <a id="fine-tuning"></a>

Now that we have adapted the PlantNet model for our specific classes, let's fine-tune it on our dataset. Because the backbone parameters were frozen in the previous step, only the new final layer is trained.

In [None]:
def fine_tune_model(model, train_loader, val_loader, num_epochs=5, learning_rate=0.001):
    """
    Fine-tune the adapted PlantNet model on our dataset.
    
    Args:
        model: The adapted PlantNet model
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        num_epochs: Number of training epochs
        learning_rate: Learning rate for the optimizer
        
    Returns:
        tuple: (trained model, training metrics)
    """
    # Only train the final layer
    params_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
    
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params_to_update, lr=learning_rate)
    
    # Training metrics
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0
    
    print(f"Starting fine-tuning for {num_epochs} epochs...")
    num_trainable = sum(p.numel() for p in params_to_update)
    print(f"Training {num_trainable:,} parameters in {len(params_to_update)} tensors")
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
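        # Training phase
        model.train()
        # Keep BatchNorm layers of the frozen backbone in eval mode so their
        # running statistics are not updated while only the final layer trains
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()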
        running_loss = 0.0
        correct = 0
        total = 0
        
        # Progress bar
        pbar = tqdm(train_loader, desc="Training")
        
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(inputs)
            
            # Calculate loss
            loss = criterion(outputs, targets)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Update metrics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Update progress bar
            pbar.set_postfix({
                'loss': running_loss / (pbar.n + 1),
                'acc': 100. * correct / total if total > 0 else 0
            })
        
        # Calculate epoch metrics
        train_loss = running_loss / len(train_loader) if len(train_loader) > 0 else 0
        train_acc = 100. * correct / total if total > 0 else 0
        
        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                
                # Forward pass
                outputs = model(inputs)
                
                # Calculate loss
                loss = criterion(outputs, targets)
                
                # Update metrics
                val_running_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()
        
        # Calculate validation metrics
        val_loss = val_running_loss / len(val_loader) if len(val_loader) > 0 else 0
        val_acc = 100. * val_correct / val_total if val_total > 0 else 0
        
        # Save metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_acc': val_acc,
                'train_acc': train_acc
            }, "models/best_plantnet_adapted.pth")
            print(f"Saved new best model with validation accuracy: {val_acc:.2f}%")
    
    print(f"Fine-tuning completed. Best validation accuracy: {best_val_acc:.2f}%")
    
    # Return metrics for plotting
    return model, (train_losses, val_losses, train_accs, val_accs)
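
The training loop above always runs for the full `num_epochs` and keeps the checkpoint with the best validation accuracy. A common extension is to stop once validation accuracy has stopped improving. The sketch below is one possible form of that idea and is not part of the function above; the `EarlyStopping` class, its `patience` and `min_delta` parameters, and the accuracy values are all made up for illustration.

```python
class EarlyStopping:
    """Signal a stop when validation accuracy has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.stale_epochs = 0

    def step(self, val_acc):
        """Record one epoch's validation accuracy; return True if training should stop."""
        if val_acc > self.best + self.min_delta:
            self.best = val_acc
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


stopper = EarlyStopping(patience=2)
history = [60.0, 72.5, 71.0, 72.0, 80.0]  # made-up validation accuracies
stopped_at = None
for epoch, acc in enumerate(history, start=1):
    if stopper.step(acc):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
```

In a training loop, `stopper.step(val_acc)` would be called once per epoch after the validation phase, followed by a `break` when it returns `True`.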

In [None]:
# Create train and validation datasets with PlantNet transforms
transforms_dict = get_plantnet_transforms()
train_transform = transforms_dict['train']
val_transform = transforms_dict['val']

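# We'll use the demo dataset for both training and validation
train_dataset = HouseplantDataset(demo_dir, transform=train_transform)
val_dataset = HouseplantDataset(demo_dir, transform=val_transform)

# The dataset indexes classes by sorted folder name, but the adapted model has one
# output per entry of COMMON_HOUSEPLANTS. Remap the labels so both use the same order.
folder_names = [name.replace(" ", "_") for name in COMMON_HOUSEPLANTS]
for ds in (train_dataset, val_dataset):
    ds.samples = [(path, folder_names.index(ds.classes[idx])) for path, idx in ds.samples]
    ds.classes = folder_names
    ds.class_to_idx = {name: i for i, name in enumerate(folder_names)}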

# Create data loaders
batch_size = 4  # Small batch size for demo
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Created datasets with {len(train_dataset.classes)} classes")
print(f"Class names: {train_dataset.classes}")
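
The cell above uses the same images for training and validation, so the validation accuracy says nothing about generalization. With a real dataset, the list of `(image_path, class_idx)` pairs could be split per class instead. The following is a minimal sketch under that assumption; `stratified_split` is a hypothetical helper and the file names are made up.

```python
import random
from collections import defaultdict


def stratified_split(samples, val_fraction=0.2, seed=0):
    """Split (path, class_idx) pairs so each class appears in both subsets when possible."""
    by_class = defaultdict(list)
    for sample in samples:
        by_class[sample[1]].append(sample)

    rng = random.Random(seed)  # local generator, so the global random state is untouched
    train, val = [], []
    for class_idx in sorted(by_class):
        items = by_class[class_idx][:]
        rng.shuffle(items)
        # Hold out at least one image per class, unless the class has a single image
        n_val = max(1, round(len(items) * val_fraction)) if len(items) > 1 else 0
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    return train, val


# Made-up file names: 10 images for each of 3 classes
samples = [(f"img_{c}_{i}.jpg", c) for c in range(3) for i in range(10)]
train, val = stratified_split(samples, val_fraction=0.2)
print(len(train), len(val))
```

The two resulting lists could then be assigned to the `samples` attribute of two separate `HouseplantDataset` instances, one with the training transform and one with the validation transform.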

In [None]:
# Fine-tune the model
if len(train_dataset.samples) > 0 and len(val_dataset.samples) > 0:
    finetuned_model, metrics = fine_tune_model(
        model=adapted_model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=2,  # Just 2 epochs for demonstration
        learning_rate=0.001
    )
    
    train_losses, val_losses, train_accs, val_accs = metrics
    
    # Plotting training metrics
    plt.figure(figsize=(12, 5))
    
    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    
    # Plot accuracies
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Training Accuracy')
    plt.plot(val_accs, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
else:
    print("Skipping fine-tuning due to empty dataset.")
    print("For a real implementation, you would use a proper dataset.")
    
    # Save the adapted model for demonstration
    torch.save({
        'model_state_dict': adapted_model.state_dict(),
        'epoch': 0,
        'val_acc': 0.0,
        'train_acc': 0.0
    }, "models/best_plantnet_adapted.pth")
    
    print("Saved the adapted model for demonstration purposes.")
    finetuned_model = adapted_model  # Just use the adapted model

## 5. Inference and Evaluation <a id="inference"></a>

Now let's implement functions to use our fine-tuned PlantNet model for inference.

In [None]:
def identify_plant_with_plantnet(model, image, class_names, device, top_k=3):
    """
    Identify a plant using the PlantNet model.
    
    Args:
        model: The PlantNet model
        image: PIL Image or path to image
        class_names: List of class names
        device: Device to run inference on
        top_k: Number of top predictions to return
        
    Returns:
        list: Top predictions
    """
    import torch
    from PIL import Image
    from torchvision import transforms
    
    # Ensure model is in eval mode
    model.eval()
    
    # Preprocess the image
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # Load image if it's a path (plain string or pathlib.Path)
    if isinstance(image, (str, os.PathLike)):
        image = Image.open(image).convert('RGB')
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert('RGB')
    
    # Preprocess and add batch dimension
    image_tensor = preprocess(image).unsqueeze(0).to(device)
    
    # Get predictions
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        
        # Get top-k probabilities and indices
        top_probs, top_indices = torch.topk(probabilities, min(top_k, len(class_names)))
        
        # Convert to lists
        top_probs = top_probs.cpu().numpy()[0]
        top_indices = top_indices.cpu().numpy()[0]
    
    # Create result list
    results = []
    for i, (idx, prob) in enumerate(zip(top_indices, top_probs)):
        if idx < len(class_names):
            results.append({
                'rank': i + 1,
                'class_name': class_names[idx],
                'probability': float(prob)
            })
    
    return results

def get_plant_care_info(plant_name, care_database):
    """Get care information for a plant species."""
    # Look for exact match first
    if plant_name in care_database:
        return care_database[plant_name]
    
    # Try common name match
    for scientific_name, info in care_database.items():
        if info.get('common_name') == plant_name:
            return info
    
    # Try case-insensitive partial match
    plant_name_lower = plant_name.lower()
    for scientific_name, info in care_database.items():
        if (plant_name_lower in scientific_name.lower() or 
            plant_name_lower in info.get('common_name', '').lower()):
            return info
    
    # No match found
    return None
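
The identification function above turns raw model outputs into probabilities with a softmax and then keeps the `top_k` largest. The same two steps can be sketched without PyTorch to show what happens numerically; the logits and the shortened class list below are made up for illustration.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probabilities, class_names, k=3):
    """Return the k most probable (class_name, probability) pairs, best first."""
    ranked = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [(class_names[i], probabilities[i]) for i in ranked[:k]]


logits = [2.0, 0.5, 1.0, -1.0]  # made-up model outputs
names = ["Monstera deliciosa", "Ficus lyrata", "Aloe vera", "Crassula ovata"]
probs = softmax(logits)
for name, p in top_k(probs, names, k=3):
    print(f"{name}: {p:.2%}")
```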

Let's test our plant identification with a sample image:

In [None]:
# Let's test with one of our demo images
if len(demo_images) > 0:
    test_image_path = demo_images[0]
    
    # Show the image
    plt.figure(figsize=(8, 8))
    img = Image.open(test_image_path).convert('RGB')
    plt.imshow(img)
    plt.title("Test Image")
    plt.axis('off')
    plt.show()
    
    # Identify the plant
    results = identify_plant_with_plantnet(finetuned_model, test_image_path, COMMON_HOUSEPLANTS, device)
    
    # Display results
    print("\nIdentification Results:")
    for result in results:
        print(f"{result['rank']}. {result['class_name']}: {result['probability']:.2%} confidence")
    
    # Get care information for the top prediction
    top_prediction = results[0]['class_name']
    care_info = get_plant_care_info(top_prediction, care_database)
    
    if care_info:
        print("\nCare Information:")
        print(f"Common name: {care_info.get('common_name', 'Unknown')}")
        print(f"Light: {care_info.get('light', 'Unknown')}")
        print(f"Water: {care_info.get('water', 'Unknown')}")
        print(f"Temperature: {care_info.get('temperature', 'Unknown')}")
    else:
        print("\nNo care information available for this plant.")
else:
    print("No demo images available for testing.")
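
The test cell above always shows care information for the top prediction, even when the model is barely more confident than chance. With 20 classes, an untrained output layer gives each class roughly 5%. One option is to show care advice only above a confidence threshold. The sketch below assumes the result format returned by `identify_plant_with_plantnet`; `select_prediction` and the 0.5 threshold are hypothetical choices, not values taken from this notebook.

```python
def select_prediction(results, min_confidence=0.5):
    """Return the most probable result if it is confident enough, otherwise None."""
    if not results:
        return None
    best = max(results, key=lambda r: r["probability"])
    return best if best["probability"] >= min_confidence else None


# Made-up results in the same format as the identification function's output
confident = [
    {"rank": 1, "class_name": "Aloe vera", "probability": 0.91},
    {"rank": 2, "class_name": "Crassula ovata", "probability": 0.06},
]
uncertain = [
    {"rank": 1, "class_name": "Ficus lyrata", "probability": 0.07},
    {"rank": 2, "class_name": "Ficus elastica", "probability": 0.06},
]

print(select_prediction(confident))
print(select_prediction(uncertain))
print(select_prediction([]))
```

Returning `None` for an empty or low-confidence result list also avoids indexing into an empty list when no prediction is available.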
