# In [1]:
import tensorflow as tf
import numpy as np
import logging
from typing import Dict, List, Optional

class RiceClassifier:
    """Rice grain classifier backed by a small Keras CNN.

    NOTE: despite earlier documentation, no MobileNetV2 / transfer-learning
    backbone is used. ``_load_model`` builds a compact CNN whose weights are
    Keras's random initialisation (demo mode), so predictions are not
    meaningful until trained weights are loaded.
    """

    def __init__(self):
        """Set up class metadata and recommendations, then build the model."""
        # Output-index order of the final softmax layer.
        self.model = None
        self.class_names = ['Arborio', 'Basmati', 'Ipsala', 'Jasmine', 'Karacadag']
        # Lower-cased name -> output index, derived from class_names so the
        # two mappings cannot drift apart.
        self.class_labels = {
            name.lower(): index for index, name in enumerate(self.class_names)
        }

        # Agricultural recommendations for each rice type, keyed by the
        # Title-case names in class_names.
        self.recommendations = {
            'Arborio': {
                'description': 'Short-grain rice ideal for risotto and Mediterranean dishes.',
                'cultivation': 'Requires consistent moisture and warm temperatures. Best grown in flooded fields.',
                'water_needs': 'High water requirement - maintain flooded conditions during growing season.',
                'fertilizer': 'Use balanced NPK fertilizer. Apply nitrogen in split doses.',
                'harvest_time': '120-150 days from planting'
            },
            'Basmati': {
                'description': 'Long-grain aromatic rice prized for its fragrance and fluffy texture.',
                'cultivation': 'Grows best in well-drained, fertile soil with good organic matter.',
                'water_needs': 'Moderate to high water requirement. Avoid waterlogging during grain filling.',
                'fertilizer': 'Organic matter-rich soil preferred. Use phosphorus-rich fertilizer.',
                'harvest_time': '120-140 days from planting'
            },
            'Ipsala': {
                'description': 'Turkish rice variety known for its cooking qualities and grain structure.',
                'cultivation': 'Adapted to Mediterranean climate conditions. Requires careful water management.',
                'water_needs': 'Moderate water requirement with good drainage.',
                'fertilizer': 'Balanced fertilization with emphasis on potassium for grain quality.',
                'harvest_time': '130-150 days from planting'
            },
            'Jasmine': {
                'description': 'Fragrant long-grain rice popular in Asian cuisine.',
                'cultivation': 'Thrives in tropical and subtropical conditions with high humidity.',
                'water_needs': 'High water requirement during vegetative growth, reduce before harvest.',
                'fertilizer': 'Nitrogen-rich fertilizer in early stages, reduce during flowering.',
                'harvest_time': '110-130 days from planting'
            },
            'Karacadag': {
                'description': 'Traditional Turkish rice variety with excellent cooking properties.',
                'cultivation': 'Hardy variety suitable for various soil types and climate conditions.',
                'water_needs': 'Moderate water requirement with tolerance to slight water stress.',
                'fertilizer': 'Responds well to organic fertilizers and balanced NPK application.',
                'harvest_time': '125-145 days from planting'
            }
        }

        self._load_model()

    def _load_model(self):
        """Build and compile the demo CNN and store it in ``self.model``.

        No trained weights are loaded; the layers keep Keras's random
        initialisation. On any failure the error is logged with its
        traceback and ``self.model`` is left as ``None``.
        """
        try:
            logging.info("Loading rice classification model...")

            # Create a lightweight CNN model for demo purposes
            self.model = tf.keras.Sequential([
                tf.keras.layers.Resizing(224, 224),
                # The model rescales by 1/255 itself. NOTE(review): confirm the
                # caller's preprocessing does not also normalise to [0, 1].
                tf.keras.layers.Rescaling(1./255),

                # Simple CNN architecture
                tf.keras.layers.Conv2D(32, 3, activation='relu'),
                tf.keras.layers.MaxPooling2D(),
                tf.keras.layers.Conv2D(64, 3, activation='relu'),
                tf.keras.layers.MaxPooling2D(),
                tf.keras.layers.Conv2D(128, 3, activation='relu'),
                tf.keras.layers.MaxPooling2D(),

                tf.keras.layers.GlobalAveragePooling2D(),
                tf.keras.layers.Dropout(0.5),
                tf.keras.layers.Dense(128, activation='relu'),
                # One softmax output per entry of class_names, in that order.
                tf.keras.layers.Dense(len(self.class_names), activation='softmax')
            ])

            # Compile the model (integer labels -> sparse categorical loss)
            self.model.compile(
                optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )

            # Materialise the weights by building with a batched RGB input shape.
            self.model.build((None, 224, 224, 3))

            # Weights are still randomly initialised here; in production,
            # load trained weights (e.g. model.load_weights) at this point.
            logging.info("Model initialized successfully (demo mode)")

        except Exception as e:
            logging.exception("Error loading model: %s", e)
            self.model = None

    def is_loaded(self) -> bool:
        """Return True when a model object is available for prediction."""
        return self.model is not None

    def predict(self, image: np.ndarray) -> Optional[Dict]:
        """
        Classify an image and attach agricultural recommendations.

        Args:
            image: Image array of shape (1, 224, 224, 3). A single unbatched
                image of shape (224, 224, 3) is also accepted; a batch axis
                is added. Only the first batch item is scored.

        Returns:
            Dictionary with keys ``predicted_class``, ``confidence`` (percent,
            rounded to 2 decimals), ``all_predictions`` (per-class percentages,
            highest first) and ``recommendations`` (a copy of the entry for the
            predicted class), or None if the model is missing or prediction
            fails.
        """
        if not self.is_loaded():
            logging.error("Model not loaded")
            return None

        try:
            # Conv layers need a batch axis; add it for a single image.
            if np.ndim(image) == 3:
                image = np.expand_dims(image, axis=0)

            predictions = self.model(image, training=False)
            # Convert once so the rest works on plain NumPy values.
            probabilities = np.asarray(predictions)[0]

            if len(probabilities) != len(self.class_names):
                logging.error(
                    "Model returned %d scores for %d classes",
                    len(probabilities), len(self.class_names),
                )
                return None

            predicted_class_idx = int(np.argmax(probabilities))
            confidence = float(probabilities[predicted_class_idx]) * 100
            predicted_class = self.class_names[predicted_class_idx]

            # Per-class percentages, highest first (stable for ties).
            all_predictions = sorted(
                (
                    {'class': class_name, 'probability': round(float(prob) * 100, 2)}
                    for class_name, prob in zip(self.class_names, probabilities)
                ),
                key=lambda entry: entry['probability'],
                reverse=True,
            )

            # Copy so callers cannot mutate the shared recommendation table.
            recommendations = dict(self.recommendations.get(predicted_class, {}))

            return {
                'predicted_class': predicted_class,
                'confidence': round(confidence, 2),
                'all_predictions': all_predictions,
                'recommendations': recommendations
            }

        except Exception as e:
            logging.exception("Error making prediction: %s", e)
            return None

    def get_class_info(self, class_name: str) -> Dict:
        """Return a copy of the recommendations for ``class_name``.

        The exact (Title-case) name is tried first, then a case-insensitive
        match, so the lower-case keys used in ``class_labels`` also work.
        Unknown names yield an empty dict.
        """
        info = self.recommendations.get(class_name)
        if info is None and isinstance(class_name, str):
            wanted = class_name.casefold()
            info = next(
                (details for name, details in self.recommendations.items()
                 if name.casefold() == wanted),
                None,
            )
        # Copy so callers cannot mutate the shared recommendation table.
        return dict(info) if info is not None else {}
