# PlantCLEF 2025 Inference API

**Purpose:** Provide a REST API endpoint for plant identification using the PlantCLEF 2025 dataset.

**Dataset:** PlantCLEF 2025 (1TB+, 10,000+ species)

**Usage:**
1. Upload this notebook to Kaggle
2. Add the PlantCLEF 2025 dataset
3. Enable internet access
4. Run all cells (set your ngrok auth token in the ngrok setup cell first)
5. Use the printed ngrok URL as the API endpoint

**API Endpoint:**
```
POST /predict
Content-Type: multipart/form-data
Body: form field "image" (the image file); optional form field "top_k" (default 5)

Response:
{
  "predictions": [
    {"species": "Rosa damascena", "confidence": 0.95, "common_name": "Rosa Damascena", "species_id": "1234"},
    ...
  ],
  "inference_time": 2.3,
  "model": "resnet50",
  "device": "cuda"
}
```

In [None]:
# Install required packages
!pip install flask flask-cors pyngrok pillow torch torchvision timm -q

In [None]:
# Import libraries
import os
import io
import json
import time
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from flask import Flask, request, jsonify
from flask_cors import CORS
from pyngrok import ngrok
import timm

print("✅ All libraries imported successfully")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎮 CUDA available: {torch.cuda.is_available()}")

In [None]:
# Location where Kaggle mounts the attached PlantCLEF 2025 dataset.
DATASET_PATH = '/kaggle/input/plantclef2025'

# Report whether the dataset is attached. When it is, preview up to ten
# top-level entries (in whatever order the filesystem returns them).
if not os.path.exists(DATASET_PATH):
    print("❌ Dataset not found! Please add PlantCLEF 2025 dataset to this notebook.")
    print("   Go to: Add Data > Search 'PlantCLEF 2025'")
else:
    print(f"✅ PlantCLEF 2025 dataset found at {DATASET_PATH}")
    subdirs = os.listdir(DATASET_PATH)[:10]
    print(f"📁 First 10 directories: {subdirs}")

In [None]:
# Load the class-index -> species-name mapping used to label predictions.
#
# Preferred source: `species_map.json` at the dataset root. The prediction
# code looks entries up by the class index rendered as a string, so the file
# is assumed to be a JSON object keyed that way — TODO confirm against the
# actual dataset layout.
# Fallback: enumerate the dataset's sub-directories and number them.
species_map = {}
species_map_path = os.path.join(DATASET_PATH, 'species_map.json')

if os.path.exists(species_map_path):
    # Explicit UTF-8: species names may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(species_map_path, 'r', encoding='utf-8') as f:
        species_map = json.load(f)
    print(f"✅ Loaded {len(species_map)} species mappings")
else:
    print("⚠️ Species map not found, using directory names as species IDs")
    if os.path.isdir(DATASET_PATH):
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> directory assignment reproducible across runs and machines.
        # NOTE(review): the indices only mean something if the model was
        # trained with the same (sorted) class ordering — confirm.
        species_dirs = sorted(
            d for d in os.listdir(DATASET_PATH)
            if os.path.isdir(os.path.join(DATASET_PATH, d))
        )
        species_map = {str(i): name for i, name in enumerate(species_dirs)}
        print(f"✅ Created mapping for {len(species_map)} species")

In [None]:
# Build the classification model.
#
# The backbone comes from timm with its pretrained weights, and the classifier
# head is sized to the number of known species (falling back to 10000 when no
# species map could be built). No PlantCLEF-specific checkpoint is loaded
# anywhere in this notebook, so the head has never been trained on these
# classes: load fine-tuned weights (e.g. via `model.load_state_dict(...)`)
# before relying on the predictions.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🎮 Using device: {device}")

model_name = 'resnet50'
num_classes = len(species_map) if species_map else 10000  # PlantCLEF has ~10k species

print(f"📦 Loading model: {model_name}")
model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
model = model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm statistic updates

print("✅ Model loaded successfully")
print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Make the limitation visible at run time instead of silently serving
# predictions from an untrained classifier head.
print("⚠️ No PlantCLEF fine-tuned weights were loaded; predictions are not "
      "meaningful until a trained checkpoint is loaded into the model")

In [None]:
# Preprocessing applied to every uploaded image before inference:
# resize the shorter side to 256, take the central 224x224 crop, convert to a
# tensor and normalize each channel (these are the usual ImageNet channel
# statistics, matching the ImageNet-pretrained backbone).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocess_steps)

def predict_image(image_bytes, top_k=5):
    """
    Predict plant species from raw image bytes.

    The image is decoded, converted to RGB, run through the module-level
    ``transform`` and ``model``, and the ``top_k`` most probable classes are
    returned.

    Args:
        image_bytes: Encoded image file contents (any format PIL can open).
        top_k: Number of top predictions to return. Values larger than the
            number of model output classes are clamped to that number.

    Returns:
        Dict with keys:
            "predictions": list of dicts with "species", "confidence",
                "common_name" (the species name with underscores replaced by
                spaces and title-cased) and "species_id" (class index as a
                string), ordered by decreasing confidence.
            "inference_time": elapsed seconds (decode + inference), rounded
                to 3 decimals.
            "model": the model name.
            "device": the device the model runs on, as a string.

    Raises:
        Whatever PIL raises for undecodable data (``UnidentifiedImageError``,
        an ``OSError`` subclass) and whatever torch raises during inference.
    """
    start_time = time.time()

    # Decode inside a context manager so the underlying image handle is
    # released promptly; convert() returns an independent RGB copy.
    with Image.open(io.BytesIO(image_bytes)) as raw_image:
        image = raw_image.convert('RGB')

    # Preprocess and add the batch dimension expected by the model.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        # torch.topk raises when k exceeds the size of the dimension, so cap
        # the request at the number of classes the model actually outputs.
        k = min(top_k, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities, k)

    # Format results; tolist() yields plain Python floats/ints that are
    # directly JSON-serializable.
    predictions = []
    for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
        species_id = str(idx)
        # Assumes species_map values are strings — TODO confirm for maps
        # loaded from species_map.json.
        species_name = species_map.get(species_id, f"Unknown_Species_{idx}")

        predictions.append({
            "species": species_name,
            "confidence": float(prob),
            "common_name": species_name.replace('_', ' ').title(),
            "species_id": species_id
        })

    inference_time = time.time() - start_time

    return {
        "predictions": predictions,
        "inference_time": round(inference_time, 3),
        "model": model_name,
        "device": str(device)
    }

print("✅ Prediction function ready")

In [None]:
# Create the Flask application object that the route handlers register on.
app = Flask(__name__)
# Enable CORS with flask-cors' default settings so browser clients served from
# other origins can call the API.
# NOTE(review): the defaults are believed to allow every origin — confirm this
# is intended for a publicly tunnelled endpoint.
CORS(app)

@app.route('/', methods=['GET'])
def home():
    """Return a JSON description of the service: name, status, model name,
    number of mapped species and the available endpoints."""
    endpoints = {
        "/": "GET - Service info",
        "/predict": "POST - Plant identification (multipart/form-data)",
        "/health": "GET - Health check"
    }
    service_info = {
        "service": "PlantCLEF 2025 Inference API",
        "status": "running",
        "model": model_name,
        "num_species": len(species_map),
        "endpoints": endpoints
    }
    return jsonify(service_info)

@app.route('/health', methods=['GET'])
def health():
    """Return a JSON health report with the compute device and CUDA availability.

    "model_loaded" is reported as a constant True; it is not derived from the
    model object.
    """
    gpu_available = torch.cuda.is_available()
    report = {
        "status": "healthy",
        "model_loaded": True,
        "device": str(device),
        "cuda_available": gpu_available
    }
    return jsonify(report)

@app.route('/predict', methods=['POST'])
def predict():
    """
    Identify the plant in an uploaded image.

    Expects a multipart/form-data request with:
        image: the image file (required).
        top_k: number of predictions to return (optional, default 5,
            must be an integer >= 1).

    Returns:
        200 with the prediction result as JSON on success.
        400 with {"error": ...} when the request is malformed: missing or
            empty file, invalid ``top_k``, or data that cannot be decoded as
            an image.
        500 with {"error": ..., "type": ...} for any other failure.
    """
    try:
        # Check if image file is present
        if 'image' not in request.files:
            return jsonify({"error": "No image file provided"}), 400

        file = request.files['image']
        if file.filename == '':
            return jsonify({"error": "Empty filename"}), 400

        # Read image bytes; a zero-byte upload is a client error, not a
        # server failure.
        image_bytes = file.read()
        if not image_bytes:
            return jsonify({"error": "Empty image file"}), 400

        # Validate top_k here so bad client input yields 400 rather than
        # surfacing as an int() ValueError reported as a 500.
        try:
            top_k = int(request.form.get('top_k', 5))
        except (TypeError, ValueError):
            return jsonify({"error": "top_k must be an integer"}), 400
        if top_k < 1:
            return jsonify({"error": "top_k must be at least 1"}), 400

        # Predict. PIL signals undecodable or truncated image data with
        # OSError (UnidentifiedImageError is an OSError subclass); treat that
        # as a bad request.
        try:
            result = predict_image(image_bytes, top_k=top_k)
        except OSError:
            return jsonify({"error": "Uploaded file is not a valid image"}), 400

        return jsonify(result)

    except Exception as e:
        # Last-resort boundary: report unexpected failures as JSON instead of
        # Flask's default HTML error page.
        return jsonify({
            "error": str(e),
            "type": type(e).__name__
        }), 500

print("✅ Flask API created")

In [None]:
# Setup ngrok tunnel (for external access)
# You'll need an ngrok auth token
# Get free token from: https://dashboard.ngrok.com/get-started/your-authtoken
#
# The token is read from the NGROK_AUTH_TOKEN environment variable when it is
# set, so the secret does not have to be saved inside the notebook. If the
# variable is missing or blank, the placeholder below is used; you can still
# replace the placeholder string with your token directly.
NGROK_AUTH_TOKEN = (
    os.environ.get("NGROK_AUTH_TOKEN", "").strip()
    or "YOUR_NGROK_TOKEN_HERE"
)

# The placeholder value doubles as the "not configured" marker.
if NGROK_AUTH_TOKEN != "YOUR_NGROK_TOKEN_HERE":
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)
    print("✅ Ngrok auth token set")
else:
    print("⚠️ Please set your ngrok auth token above")
    print("   Get one free at: https://dashboard.ngrok.com/get-started/your-authtoken")

In [None]:
# Start Flask server with ngrok tunnel
from threading import Thread

def run_flask():
    """Run the Flask server on all interfaces, port 5000 (blocks the caller)."""
    app.run(host='0.0.0.0', port=5000)

# Start Flask in a background daemon thread so this cell returns and the
# thread does not keep the process alive on its own.
thread = Thread(target=run_flask, daemon=True)
thread.start()

# Give Flask a moment to start before exposing it
time.sleep(2)

# Create ngrok tunnel (only when a real token has been configured)
if NGROK_AUTH_TOKEN != "YOUR_NGROK_TOKEN_HERE":
    tunnel = ngrok.connect(5000)
    # Recent pyngrok versions return an NgrokTunnel object whose str() is a
    # descriptive label, not the URL; use its `public_url` attribute so the
    # printed endpoint and curl example are directly usable. Fall back to the
    # returned value itself for versions that return the URL string.
    public_url = getattr(tunnel, "public_url", tunnel)
    print("\n" + "="*70)
    print("🌐 PUBLIC API URL:")
    print(f"   {public_url}")
    print("="*70)
    print("\n📝 Usage:")
    print(f"   curl -X POST {public_url}/predict -F 'image=@plant.jpg'")
    print("\n⚠️ Copy this URL to your backend .env file as KAGGLE_NOTEBOOK_URL")
    print("\n🔄 Keep this notebook running to maintain the API endpoint")
    print("="*70)
else:
    print("⚠️ Ngrok token not set, API only available locally on port 5000")

In [None]:
# Keep this cell (and therefore the notebook session) running until it is
# interrupted, waking up once a minute.
for _status_line in (
    "✅ API Server is running!",
    "   Press Ctrl+C to stop (but this will stop the API)",
    "   Keep this cell running to maintain the API endpoint",
):
    print(_status_line)

try:
    while True:
        time.sleep(60)
except KeyboardInterrupt:
    # Interrupting the cell is the expected way to stop.
    print("\n🛑 Server stopped")