# In [None]:
import os
import sys
import numpy as np
from qgis.PyQt.QtCore import QSettings, QTranslator, QCoreApplication, Qt
from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtWidgets import (QAction, QDialog, QVBoxLayout, QHBoxLayout,
                                 QPushButton, QLabel, QComboBox, QSpinBox,
                                 QTableWidget, QTableWidgetItem, QGroupBox,
                                 QProgressBar, QMessageBox, QListWidget)
from qgis.core import (
    QgsProject,
    QgsRasterLayer,
    QgsPointXY,
    QgsRaster,
    QgsMapLayerProxyModel,
    QgsCoordinateTransform,
    QgsFeature,
    QgsGeometry,
    QgsVectorLayer,
    QgsField,
    QgsFeatureRequest,
    QgsSpatialIndex,
    QgsRectangle
)
from qgis.gui import QgsMapToolEmitPoint, QgsMapCanvas, QgsVertexMarker
from PyQt5.QtCore import pyqtSignal, QVariant
from PyQt5.QtGui import QColor

# Import modAL for active learning
from modAL.models import ActiveLearner
from modAL.uncertainty import margin_sampling  # Changed from uncertainty_sampling
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report
import joblib
from datetime import datetime


class PixelLabelingTool(QgsMapToolEmitPoint):
    """Map tool that turns canvas clicks into pixel-labeling requests.

    When the mouse is released on the canvas, the position is converted to
    map coordinates and ``labelAssigned(x, y, class_value)`` is emitted with
    the class currently selected in the owning plugin. Nothing is emitted
    while no class is selected or no owner was given.
    """

    # (map x, map y, class value) of the clicked location
    labelAssigned = pyqtSignal(float, float, int)

    def __init__(self, canvas, parent=None):
        """Create the tool for ``canvas``.

        :param canvas: map canvas the tool operates on; available afterwards
            through the inherited ``canvas()`` method.
        :param parent: owning plugin object whose ``current_class_value``
            attribute supplies the class for each click.
        """
        super().__init__(canvas)
        # Stored under a private name: assigning ``self.parent`` would shadow
        # QObject.parent(), and ``self.canvas`` would shadow QgsMapTool.canvas().
        self._plugin = parent

    def canvasReleaseEvent(self, event):
        """Emit labelAssigned for the released point if a class is selected."""
        class_value = getattr(self._plugin, 'current_class_value', None)
        if class_value is None:
            return
        point = self.toMapCoordinates(event.pos())
        self.labelAssigned.emit(point.x(), point.y(), class_value)


class ActiveLearningRasterPlugin:
    """QGIS Plugin implementing active learning for raster classification using Random Forest + Margin Sampling"""
    
    def __init__(self, iface):
        """Constructor.

        :param iface: An interface instance that will be passed to this class
                     which provides the hook by which you can manipulate the QGIS
                     application at run time.
        :type iface: QgsInterface
        """
        # Save reference to the QGIS interface
        self.iface = iface

        # Initialize plugin directory
        self.plugin_dir = os.path.dirname(__file__)

        # Initialize locale. value() returns None when the setting has never
        # been written; slicing None would raise TypeError, so fall back to ''
        # (no matching .qm file -> the UI simply stays untranslated).
        user_locale = QSettings().value('locale/userLocale') or ''
        locale = str(user_locale)[0:2]
        locale_path = os.path.join(
            self.plugin_dir,
            'i18n',
            'ActiveLearningRaster_{}.qm'.format(locale))

        if os.path.exists(locale_path):
            self.translator = QTranslator()
            self.translator.load(locale_path)
            QCoreApplication.installTranslator(self.translator)

        # Menu / toolbar bookkeeping (see add_action and unload)
        self.actions = []
        self.menu = 'Active Learning Raster'
        self.toolbar = self.iface.addToolBar('Active Learning Raster')
        self.toolbar.setObjectName('ActiveLearningRaster')

        # Active-learning state, populated by initialize_active_learning
        self.active_learner = None
        self.scaler = StandardScaler()
        self.X_pool = None          # standardized feature rows still available for querying
        self.y_pool = None
        self.pool_indices = None    # original pool row index of each remaining X_pool row
        self.labeled_pixels = []    # feature rows of labeled samples
        self.labeled_classes = []   # class value for each row of labeled_pixels
        self.pixel_positions = []   # (x, y) coordinate of each remaining pool row
        self.current_raster = None
        self.map_tool = None
        self.current_class_value = None
        self.query_history = []
        self.markers = []
        self.dialog = None          # created by run()

        # Default class configuration: class value -> display name and marker colour
        self.class_config = {
            1: {'name': 'Water', 'color': QColor(0, 0, 255)},
            2: {'name': 'Vegetation', 'color': QColor(0, 255, 0)},
            3: {'name': 'Urban', 'color': QColor(255, 0, 0)},
            4: {'name': 'Bare Soil', 'color': QColor(139, 69, 19)}
        }
        
    def add_action(self, icon_path, text, callback, enabled_flag=True,
                   add_to_menu=True, add_to_toolbar=True, status_tip=None,
                   whats_this=None, parent=None):
        """Create a QAction, wire it to ``callback`` and register it in the GUI.

        :param icon_path: path of the icon image for the action.
        :param text: text shown for the action.
        :param callback: slot connected to the action's ``triggered`` signal.
        :param enabled_flag: initial enabled state.
        :param add_to_menu: also list the action under ``self.menu`` in the Raster menu.
        :param add_to_toolbar: also place the action on ``self.toolbar``.
        :param status_tip: optional status-bar text.
        :param whats_this: optional "What's This?" text.
        :param parent: Qt parent of the new action.
        :returns: the new QAction, which is also appended to ``self.actions``.
        """
        new_action = QAction(QIcon(icon_path), text, parent)
        new_action.triggered.connect(callback)
        new_action.setEnabled(enabled_flag)

        # Only apply the optional texts that were actually supplied.
        optional_setters = (
            (status_tip, new_action.setStatusTip),
            (whats_this, new_action.setWhatsThis),
        )
        for value, setter in optional_setters:
            if value is not None:
                setter(value)

        if add_to_toolbar:
            self.toolbar.addAction(new_action)
        if add_to_menu:
            self.iface.addPluginToRasterMenu(self.menu, new_action)

        self.actions.append(new_action)
        return new_action
    
    def initGui(self):
        """Register the plugin's single entry point (menu item + toolbar button)."""
        self.add_action(
            os.path.join(self.plugin_dir, 'icon.png'),
            text="Active Learning for Raster",
            callback=self.run,
            parent=self.iface.mainWindow(),
        )
    
    def unload(self):
        """Remove everything the plugin added to the QGIS GUI.

        Removes the menu entries and toolbar icons registered by ``add_action``,
        deactivates the labeling map tool if one exists, closes the plugin
        dialog, deletes the plugin toolbar and removes all vertex markers.
        """
        canvas = self.iface.mapCanvas()
        for action in self.actions:
            # Use self.menu so removal always matches the name used in add_action.
            self.iface.removePluginRasterMenu(self.menu, action)
            self.iface.removeToolBarIcon(action)
        self.actions = []

        # A still-active PixelLabelingTool keeps a signal connection to this
        # plugin; unset it so canvas clicks no longer reach the unloaded plugin.
        if self.map_tool is not None:
            canvas.unsetMapTool(self.map_tool)
            self.map_tool = None

        dialog = getattr(self, 'dialog', None)
        if dialog is not None:
            dialog.close()

        # `del` alone only drops the Python reference; the toolbar widget is
        # owned by the main window and would stay visible, so delete it in Qt.
        self.toolbar.deleteLater()
        del self.toolbar

        # Clean up markers
        self.clear_markers()
    
    def clear_markers(self):
        """Detach every marker created by add_marker from the canvas scene and forget them."""
        if self.markers:
            scene = self.iface.mapCanvas().scene()
            for stale_marker in self.markers:
                scene.removeItem(stale_marker)
        self.markers = []
    
    def add_marker(self, x, y, color=None):
        """Draw a circular vertex marker on the map canvas and remember it.

        :param x: marker x coordinate in map canvas coordinates.
        :param y: marker y coordinate in map canvas coordinates.
        :param color: marker colour; yellow (255, 255, 0) when omitted.
        :returns: the new QgsVertexMarker (also appended to ``self.markers``).
        """
        if color is None:
            # Build the default per call rather than sharing one QColor created
            # at class-definition time (mutable default argument).
            color = QColor(255, 255, 0)
        marker = QgsVertexMarker(self.iface.mapCanvas())
        marker.setCenter(QgsPointXY(x, y))
        marker.setColor(color)
        marker.setIconSize(10)
        marker.setIconType(QgsVertexMarker.ICON_CIRCLE)
        marker.setPenWidth(3)
        self.markers.append(marker)
        return marker
    
    def run(self):
        """Build and show the plugin's non-modal main dialog.

        The dialog holds the raster/sampling inputs, the Random Forest
        settings, the class list, the workflow buttons and a status panel.
        Each call creates a new dialog stored in ``self.dialog``. Widgets that
        other methods read or update are kept as attributes of ``self``. Only
        "Initialize Active Learning" starts enabled; the other workflow buttons
        are enabled later as the workflow progresses.
        """
        # QgsMapLayerComboBox lives in qgis.gui but is missing from the
        # module-level imports; without this the dialog fails with NameError.
        from qgis.gui import QgsMapLayerComboBox

        # Parent the dialog to the QGIS main window so it stays in front of it
        # instead of being an independent top-level window.
        self.dialog = QDialog(self.iface.mainWindow())
        self.dialog.setWindowTitle("Active Learning for Raster Classification - Random Forest + Margin Sampling")
        self.dialog.setMinimumWidth(600)

        # Set up the user interface
        main_layout = QVBoxLayout()

        # === Data Selection Group ===
        data_group = QGroupBox("Data Selection")
        data_layout = QVBoxLayout()

        # Raster layer selection (raster layers only)
        layer_layout = QHBoxLayout()
        self.layer_label = QLabel("Select raster layer:")
        self.layer_combo = QgsMapLayerComboBox()
        self.layer_combo.setFilters(QgsMapLayerProxyModel.RasterLayer)
        layer_layout.addWidget(self.layer_label)
        layer_layout.addWidget(self.layer_combo)
        data_layout.addLayout(layer_layout)

        # Sampling parameters
        sample_layout = QHBoxLayout()
        sample_layout.addWidget(QLabel("Initial sample size:"))
        self.initial_sample_spin = QSpinBox()
        self.initial_sample_spin.setMinimum(5)
        self.initial_sample_spin.setMaximum(100)
        self.initial_sample_spin.setValue(10)
        sample_layout.addWidget(self.initial_sample_spin)

        sample_layout.addWidget(QLabel("Pool size:"))
        self.pool_size_spin = QSpinBox()
        self.pool_size_spin.setMinimum(100)
        self.pool_size_spin.setMaximum(10000)
        self.pool_size_spin.setValue(1000)
        self.pool_size_spin.setSingleStep(100)
        sample_layout.addWidget(self.pool_size_spin)
        data_layout.addLayout(sample_layout)

        data_group.setLayout(data_layout)
        main_layout.addWidget(data_group)

        # === Model Configuration Group ===
        model_group = QGroupBox("Model Configuration")
        model_layout = QVBoxLayout()

        # Random Forest parameters
        rf_layout = QHBoxLayout()
        rf_layout.addWidget(QLabel("Number of trees:"))
        self.n_trees_spin = QSpinBox()
        self.n_trees_spin.setMinimum(10)
        self.n_trees_spin.setMaximum(500)
        self.n_trees_spin.setValue(100)
        self.n_trees_spin.setSingleStep(10)
        rf_layout.addWidget(self.n_trees_spin)

        rf_layout.addWidget(QLabel("Max depth:"))
        self.max_depth_spin = QSpinBox()
        self.max_depth_spin.setMinimum(1)
        self.max_depth_spin.setMaximum(50)
        self.max_depth_spin.setValue(10)
        # Shown instead of the minimum value (1); interpreted as "no depth limit".
        self.max_depth_spin.setSpecialValueText("None")
        rf_layout.addWidget(self.max_depth_spin)
        model_layout.addLayout(rf_layout)

        model_group.setLayout(model_layout)
        main_layout.addWidget(model_group)

        # === Class Selection Group ===
        class_group = QGroupBox("Class Selection")
        class_layout = QVBoxLayout()

        # Entries are "Class <value>: <name>"; on_class_selected parses this text.
        self.class_list = QListWidget()
        for class_val, class_info in self.class_config.items():
            item_text = f"Class {class_val}: {class_info['name']}"
            self.class_list.addItem(item_text)
        self.class_list.itemClicked.connect(self.on_class_selected)
        class_layout.addWidget(self.class_list)

        self.current_class_label = QLabel("Current class: None selected")
        class_layout.addWidget(self.current_class_label)

        class_group.setLayout(class_layout)
        main_layout.addWidget(class_group)

        # === Control Buttons ===
        button_layout = QVBoxLayout()

        self.init_button = QPushButton("Initialize Active Learning")
        self.init_button.clicked.connect(self.initialize_active_learning)
        button_layout.addWidget(self.init_button)

        self.sample_button = QPushButton("Query Next Sample (Margin Sampling)")
        self.sample_button.clicked.connect(self.get_next_sample)
        self.sample_button.setEnabled(False)
        button_layout.addWidget(self.sample_button)

        self.manual_label_button = QPushButton("Enable Manual Labeling")
        self.manual_label_button.clicked.connect(self.toggle_manual_labeling)
        self.manual_label_button.setEnabled(False)
        self.manual_label_button.setCheckable(True)
        button_layout.addWidget(self.manual_label_button)

        self.train_button = QPushButton("Train on All Labeled Data")
        self.train_button.clicked.connect(self.train_model)
        self.train_button.setEnabled(False)
        button_layout.addWidget(self.train_button)

        self.classify_button = QPushButton("Classify Entire Raster")
        self.classify_button.clicked.connect(self.classify_raster)
        self.classify_button.setEnabled(False)
        button_layout.addWidget(self.classify_button)

        self.save_model_button = QPushButton("Save Model")
        self.save_model_button.clicked.connect(self.save_model)
        self.save_model_button.setEnabled(False)
        button_layout.addWidget(self.save_model_button)

        main_layout.addLayout(button_layout)

        # === Status Information ===
        status_group = QGroupBox("Status")
        status_layout = QVBoxLayout()

        self.progress_bar = QProgressBar()
        status_layout.addWidget(self.progress_bar)

        self.status_label = QLabel("Ready to start")
        status_layout.addWidget(self.status_label)

        self.accuracy_label = QLabel("Model accuracy: Not trained yet")
        status_layout.addWidget(self.accuracy_label)

        self.sample_count_label = QLabel("Labeled samples: 0")
        status_layout.addWidget(self.sample_count_label)

        status_group.setLayout(status_layout)
        main_layout.addWidget(status_group)

        self.dialog.setLayout(main_layout)

        # Show the dialog (non-modal, so the user can keep using the map canvas)
        self.dialog.show()
    
    def on_class_selected(self, item):
        """Make the clicked class-list entry the active class for labeling.

        Entries read "Class <value>: <name>", so the class value is the second
        word of the text before the first colon.
        """
        label_head = item.text().split(':')[0]
        self.current_class_value = int(label_head.split()[1])
        class_name = self.class_config[self.current_class_value]['name']
        self.current_class_label.setText(f"Current class: {class_name}")
    
    def toggle_manual_labeling(self):
        """Turn click-to-label mode on or off to match the button's checked state.

        When switched on, a PixelLabelingTool becomes the active canvas map tool
        and its ``labelAssigned`` signal is routed to ``label_pixel_manually``.
        The tool is created once and reused, so repeated toggling does not pile
        up tools and signal connections. If QGIS deactivates the tool (e.g. the
        user picks another map tool), the button is reset to its unchecked state.
        """
        canvas = self.iface.mapCanvas()
        if self.manual_label_button.isChecked():
            if self.map_tool is None:
                self.map_tool = PixelLabelingTool(canvas, self)
                self.map_tool.labelAssigned.connect(self.label_pixel_manually)
                self.map_tool.deactivated.connect(self._on_labeling_tool_deactivated)
            canvas.setMapTool(self.map_tool)
            self.manual_label_button.setText("Disable Manual Labeling")
            self.status_label.setText("Click on the map to label pixels")
        else:
            # Deactivate the map tool
            if self.map_tool is not None:
                canvas.unsetMapTool(self.map_tool)
            self.manual_label_button.setText("Enable Manual Labeling")
            self.status_label.setText("Manual labeling disabled")

    def _on_labeling_tool_deactivated(self):
        """Reset the manual-labeling button once the labeling map tool is no longer active."""
        # setChecked() does not emit `clicked`, so this cannot re-enter toggle_manual_labeling.
        self.manual_label_button.setChecked(False)
        self.manual_label_button.setText("Enable Manual Labeling")
    
    def label_pixel_manually(self, x, y, class_value):
        """Label the raster pixel under a canvas click and teach it to the learner.

        Connected to ``PixelLabelingTool.labelAssigned``.

        :param x: clicked x coordinate in the map canvas CRS.
        :param y: clicked y coordinate in the map canvas CRS.
        :param class_value: key of ``self.class_config`` to assign.

        The click is reprojected into the raster's CRS before sampling, because
        the canvas and the raster may use different CRSs. Band values are
        standardized with ``self.scaler`` before being stored and taught. This
        keeps manually labeled rows in the same scaled feature space as the
        rows stored by initialization and by margin-sampling queries.

        Clicks are ignored until active learning is initialized. The scaler is
        fitted during initialization, and the manual-labeling button is only
        enabled after it.
        """
        from qgis.core import QgsCsException

        if not self.current_raster or self.active_learner is None:
            return

        canvas_crs = self.iface.mapCanvas().mapSettings().destinationCrs()
        to_raster = QgsCoordinateTransform(
            canvas_crs, self.current_raster.crs(), QgsProject.instance())
        try:
            raster_point = to_raster.transform(QgsPointXY(x, y))
        except QgsCsException:
            self.iface.messageBar().pushMessage(
                "Warning", "No data at this location", level=1)
            return

        # Extract pixel values at the clicked location, one per band.
        provider = self.current_raster.dataProvider()
        pixel_values = []
        for band in range(1, provider.bandCount() + 1):
            value, ok = provider.sample(raster_point, band)
            # Treat a failed sample or a NaN value as "no data here".
            if not ok or np.isnan(value):
                self.iface.messageBar().pushMessage(
                    "Warning", "No data at this location", level=1)
                return
            pixel_values.append(value)

        raw_pixel = np.array(pixel_values, dtype=float).reshape(1, -1)
        scaled_pixel = self.scaler.transform(raw_pixel)

        # Add to training data (scaled, like every other labeled row)
        if len(self.labeled_pixels) == 0:
            self.labeled_pixels = scaled_pixel
            self.labeled_classes = np.array([class_value])
        else:
            self.labeled_pixels = np.vstack([self.labeled_pixels, scaled_pixel])
            self.labeled_classes = np.append(self.labeled_classes, class_value)

        # Marker at the clicked canvas position, coloured by class
        self.add_marker(x, y, self.class_config[class_value]['color'])

        # Update active learner
        self.active_learner.teach(X=scaled_pixel, y=np.array([class_value]))
        self.update_model_accuracy()

        # Update UI
        n_labeled = len(self.labeled_classes)
        self.sample_count_label.setText(f"Labeled samples: {n_labeled}")
        self.train_button.setEnabled(n_labeled >= 10)

        self.iface.messageBar().pushMessage(
            "Success", f"Pixel labeled as {self.class_config[class_value]['name']}", level=3)
    
    def initialize_active_learning(self):
        """Sample a pixel pool, collect initial labels from the user and create the learner.

        Workflow:

        1. Sample one random pixel per cell of a square grid over the selected
           raster (about ``pool size`` cells) and keep the pixels that have a
           valid value in every band.
        2. Standardize the pool with a freshly fitted ``StandardScaler``.
        3. Choose ``initial sample size`` spread-out pixels with
           ``_diverse_initial_sampling``. Zoom to each one and ask the user for
           its class; Cancel skips a pixel, which then stays in the pool.
        4. Create a modAL ``ActiveLearner`` (Random Forest + margin sampling)
           from the labeled pixels, and remove those pixels from the pool.

        The plugin's raster, scaler, pool, labels and learner are only replaced
        once at least one initial pixel has been labeled. The markers of a
        previous session are cleared when labeling starts.
        ``self.pixel_positions`` holds pool coordinates in the raster's CRS.
        Markers and zooming use those coordinates reprojected to the canvas CRS.
        """
        raster = self.layer_combo.currentLayer()
        if not raster:
            self.iface.messageBar().pushMessage(
                "Error", "Please select a raster layer", level=1)
            return

        self.progress_bar.setValue(0)
        self.status_label.setText("Initializing active learning...")

        pool_size = self.pool_size_spin.value()
        initial_size = self.initial_sample_spin.value()

        raw_pool, positions = self._sample_pixel_pool(raster, pool_size)
        if len(raw_pool) == 0:
            self.status_label.setText("No valid pixels found in the selected raster")
            self.iface.messageBar().pushMessage(
                "Error", "No pixel with data in every band could be sampled", level=1)
            return

        # Normalize features
        scaler = StandardScaler()
        pool = scaler.fit_transform(raw_pool)

        # Spread-out candidates for the initial labels
        candidate_indices = self._diverse_initial_sampling(pool, initial_size)

        # Ask the user for the true class of each candidate.
        self.clear_markers()
        self.status_label.setText("Please label the initial samples...")
        labeled_indices = []
        labels = []
        total = len(candidate_indices)
        for number, idx in enumerate(candidate_indices, start=1):
            x_map, y_map = self._raster_to_canvas_point(raster, *positions[idx])
            marker = self.add_marker(x_map, y_map, QColor(255, 255, 0))
            label = self._prompt_initial_label(x_map, y_map, number, total)
            if label is None:
                # Skipped: drop its marker; the pixel stays in the pool.
                self.iface.mapCanvas().scene().removeItem(marker)
                self.markers.remove(marker)
                continue
            marker.setColor(self.class_config[label]['color'])
            labeled_indices.append(int(idx))
            labels.append(label)

        if not labels:
            self.status_label.setText("Initialization cancelled: no sample was labeled")
            self.iface.messageBar().pushMessage(
                "Warning", "Label at least one initial sample to start active learning", level=1)
            return

        labeled_indices = np.array(labeled_indices, dtype=int)
        labeled_pixels = pool[labeled_indices]
        labeled_classes = np.array(labels)

        # The spin box shows "None" at its minimum (1): no depth limit.
        max_depth = self.max_depth_spin.value() if self.max_depth_spin.value() > 1 else None

        learner = ActiveLearner(
            estimator=RandomForestClassifier(
                n_estimators=self.n_trees_spin.value(),
                max_depth=max_depth,
                random_state=42,
                n_jobs=-1
            ),
            query_strategy=margin_sampling,  # Using margin sampling strategy
            X_training=labeled_pixels,
            y_training=labeled_classes
        )

        # Commit the new session; labeled rows leave the pool.
        keep = np.ones(len(pool), dtype=bool)
        keep[labeled_indices] = False
        self.current_raster = raster
        self.scaler = scaler
        self.active_learner = learner
        self.labeled_pixels = labeled_pixels
        self.labeled_classes = labeled_classes
        self.X_pool = pool[keep]
        self.pool_indices = np.flatnonzero(keep)
        self.pixel_positions = [pos for pos, kept in zip(positions, keep) if kept]
        self.query_history = []

        # Enable buttons
        self.sample_button.setEnabled(True)
        self.manual_label_button.setEnabled(True)
        self.train_button.setEnabled(True)
        self.save_model_button.setEnabled(True)

        n_labeled = len(labels)
        self.sample_count_label.setText(f"Labeled samples: {n_labeled}")
        self.status_label.setText(f"Active learning initialized with {n_labeled} samples")
        self.progress_bar.setValue(self.progress_bar.maximum())

        self.iface.messageBar().pushMessage(
            "Success", f"Active learning initialized with {n_labeled} diverse samples", level=3)

    def _sample_pixel_pool(self, raster, pool_size):
        """Sample one random pixel per cell of a square grid laid over ``raster``.

        :param raster: raster layer to sample.
        :param pool_size: requested number of samples; the grid has
            ``int(sqrt(pool_size))`` cells per side (at least 1).
        :returns: ``(values, positions)``. ``values`` is a float array of shape
            ``(n, band_count)``. ``positions`` lists the matching pixel-centre
            ``(x, y)`` coordinates in the raster's CRS. Only pixels with a valid
            value in every band are included.
        """
        provider = raster.dataProvider()
        band_count = provider.bandCount()
        extent = raster.extent()
        width = raster.width()
        height = raster.height()
        res_x = raster.rasterUnitsPerPixelX()
        res_y = raster.rasterUnitsPerPixelY()

        grid_size = max(1, int(np.sqrt(pool_size)))
        # A raster with fewer columns/rows than grid cells would give a step of
        # 0, and np.random.randint(0, 0) raises ValueError.
        x_step = max(1, width // grid_size)
        y_step = max(1, height // grid_size)

        self.progress_bar.setMaximum(grid_size * grid_size)
        values, positions = [], []
        for i in range(grid_size):
            for j in range(grid_size):
                # Random pixel inside grid cell (i, j), clamped to the raster.
                col = min(i * x_step + np.random.randint(0, x_step), width - 1)
                row = min(j * y_step + np.random.randint(0, y_step), height - 1)
                x_coord = extent.xMinimum() + (col + 0.5) * res_x
                y_coord = extent.yMaximum() - (row + 0.5) * res_y
                point = QgsPointXY(x_coord, y_coord)

                pixel = []
                for band in range(1, band_count + 1):
                    value, ok = provider.sample(point, band)
                    if not ok or np.isnan(value):
                        break
                    pixel.append(value)
                else:
                    # Only reached when every band produced a valid value.
                    values.append(pixel)
                    positions.append((x_coord, y_coord))

                self.progress_bar.setValue(i * grid_size + j + 1)

        return np.array(values, dtype=float).reshape(len(values), band_count), positions

    def _raster_to_canvas_point(self, raster, x, y):
        """Reproject ``(x, y)`` from ``raster``'s CRS to the map canvas CRS."""
        canvas_crs = self.iface.mapCanvas().mapSettings().destinationCrs()
        transform = QgsCoordinateTransform(raster.crs(), canvas_crs, QgsProject.instance())
        point = transform.transform(QgsPointXY(x, y))
        return point.x(), point.y()

    def _prompt_initial_label(self, x_map, y_map, number, total):
        """Zoom to an initial sample and ask the user for its class.

        :param x_map: sample x coordinate in the canvas CRS.
        :param y_map: sample y coordinate in the canvas CRS.
        :param number: 1-based position of this sample, shown to the user.
        :param total: number of initial samples, shown to the user.
        :returns: the chosen class value, or None if the user skipped the sample.
        """
        canvas = self.iface.mapCanvas()
        canvas.zoomScale(500)  # same detail scale as get_next_sample
        canvas.setCenter(QgsPointXY(x_map, y_map))
        canvas.refresh()

        msg = QMessageBox(getattr(self, 'dialog', None))
        msg.setWindowTitle("Label Pixel")
        msg.setText(f"Initial sample {number} of {total}:\n"
                    "Please select the class for the highlighted pixel.")
        buttons = {}
        for class_val, class_info in self.class_config.items():
            button = msg.addButton(class_info['name'], QMessageBox.ActionRole)
            buttons[button] = class_val
        # Cancel (or Esc) skips this sample.
        msg.addButton(QMessageBox.Cancel)
        msg.exec_()
        return buttons.get(msg.clickedButton())
    
    def _diverse_initial_sampling(self, X, n_samples):
        """Pick ``n_samples`` mutually distant rows of ``X`` (k-means++ seeding).

        The first row is drawn uniformly at random. Each following row is drawn
        with probability proportional to its squared Euclidean distance to the
        nearest row already chosen. Chosen rows (distance 0) are therefore not
        drawn again, and spread-out rows are favoured.

        :param X: 2-D array with one sample per row.
        :param n_samples: number of rows to select.
        :returns: 1-D integer array of selected row indices. All rows are
            returned if ``n_samples`` >= number of rows, and an empty array if
            ``n_samples`` <= 0.
        """
        n_points = X.shape[0]
        if n_samples >= n_points:
            return np.arange(n_points)
        if n_samples <= 0:
            return np.array([], dtype=int)

        # Start with a random point
        first = np.random.randint(n_points)
        indices = [first]
        # Distance from every row to its nearest selected row. Updating it
        # incrementally is O(n) per pick instead of recomputing all pairs.
        nearest = np.linalg.norm(X - X[first], axis=1)

        for _ in range(n_samples - 1):
            weights = nearest ** 2
            total = weights.sum()
            if total > 0:
                probabilities = weights / total
            else:
                # Every remaining row duplicates a selected one. The d^2
                # weighting would divide by zero, so draw uniformly among the
                # rows not selected yet.
                probabilities = np.ones(n_points)
                probabilities[indices] = 0.0
                probabilities /= probabilities.sum()

            next_index = np.random.choice(n_points, p=probabilities)
            indices.append(next_index)
            nearest = np.minimum(nearest, np.linalg.norm(X - X[next_index], axis=1))

        return np.array(indices)
    
    def get_next_sample(self):
        """Query the most uncertain pool pixel (margin sampling) and ask the user to label it.

        Margin sampling picks the pool pixel whose two most probable classes
        have the smallest probability gap. The canvas zooms to that pixel, and
        its coordinates are reprojected from the raster CRS stored in
        ``self.pixel_positions`` to the canvas CRS. A magenta marker is drawn,
        and a dialog offers one button per class plus Cancel:

        * a class button teaches the learner the label, appends the pixel to
          ``labeled_pixels``/``labeled_classes`` (always together, so they stay
          aligned) and recolours the marker;
        * anything else (Cancel, Esc) skips the pixel and removes its marker.

        Either way the pixel leaves the pool, so it is not queried again.
        """
        if self.active_learner is None or self.X_pool is None or len(self.X_pool) == 0:
            self.iface.messageBar().pushMessage(
                "Error", "Active learning not initialized or no more samples", level=1)
            return

        self.status_label.setText("Querying next sample using margin sampling...")

        query_idx, query_instance = self.active_learner.query(self.X_pool)
        pool_row = int(query_idx[0])

        # Gap between the two highest class probabilities. With a single known
        # class there is no runner-up, so report a margin of 0.
        probabilities = np.sort(self.active_learner.predict_proba(query_instance)[0])[::-1]
        margin = probabilities[0] - probabilities[1] if len(probabilities) > 1 else 0.0

        # Pool coordinates are in the raster CRS; the canvas may use another one.
        canvas = self.iface.mapCanvas()
        to_canvas = QgsCoordinateTransform(
            self.current_raster.crs(),
            canvas.mapSettings().destinationCrs(),
            QgsProject.instance())
        center = to_canvas.transform(QgsPointXY(*self.pixel_positions[pool_row]))

        # Zoom to the location
        canvas.zoomScale(500)  # Zoom to a detailed scale
        canvas.setCenter(center)
        canvas.refresh()

        # Add a special marker for the queried point
        marker = self.add_marker(center.x(), center.y(), QColor(255, 0, 255))

        # Show margin information
        self.status_label.setText(f"Sample margin: {margin:.3f} (smaller margin = more uncertain)")

        # Create a dialog for user to label this pixel
        msg = QMessageBox(getattr(self, 'dialog', None))
        msg.setWindowTitle("Label Pixel")
        msg.setText(f"Please select the class for the highlighted pixel.\nMargin: {margin:.3f}")
        buttons = {}
        for class_val, class_info in self.class_config.items():
            button = msg.addButton(f"{class_info['name']}", QMessageBox.ActionRole)
            buttons[button] = class_val
        # Explicit way to skip a pixel the user cannot classify.
        msg.addButton(QMessageBox.Cancel)
        msg.exec_()
        user_label = buttons.get(msg.clickedButton())

        # Remove the queried instance from the pool, whether labeled or skipped
        self.X_pool = np.delete(self.X_pool, pool_row, axis=0)
        self.pixel_positions.pop(pool_row)
        if self.pool_indices is not None:
            self.pool_indices = np.delete(self.pool_indices, pool_row)

        if user_label is None:
            canvas.scene().removeItem(marker)
            self.markers.remove(marker)
            self.status_label.setText(f"Sample skipped. {len(self.X_pool)} samples remaining.")
        else:
            # Update the active learner with the new labeled example
            self.active_learner.teach(X=query_instance, y=np.array([user_label]))

            # Store the (scaled) features and the label together
            if len(self.labeled_pixels) == 0:
                self.labeled_pixels = np.array(query_instance)
            else:
                self.labeled_pixels = np.vstack([self.labeled_pixels, query_instance])
            self.labeled_classes = np.append(self.labeled_classes, user_label)

            marker.setColor(self.class_config[user_label]['color'])

            self.update_model_accuracy()

            self.sample_count_label.setText(f"Labeled samples: {len(self.labeled_classes)}")
            self.status_label.setText(f"Pixel labeled as {self.class_config[user_label]['name']}. "
                                      f"{len(self.X_pool)} samples remaining.")

            # Store query information
            self.query_history.append({
                'iteration': len(self.query_history) + 1,
                'margin': margin,
                'label': user_label,
                'accuracy': self.calculate_accuracy()
            })

            # Enable classification if we have enough samples
            if len(self.labeled_classes) >= 20:
                self.classify_button.setEnabled(True)

        # If we're out of samples, disable the button
        if len(self.X_pool) == 0:
            self.sample_button.setEnabled(False)
            self.iface.messageBar().pushMessage(
                "Info", "All samples have been labeled", level=0)
    
    def update_model_accuracy(self):
        """Show the Random Forest out-of-bag (OOB) accuracy in the status panel.

        Does nothing until a learner exists and at least 5 samples are labeled.
        The forest only has an ``oob_score_`` after it has been fit with
        ``oob_score=True`` (``train_model`` enables this). Until then the label
        says the score is unavailable, instead of showing a misleading 0.00%.
        """
        if self.active_learner is None or len(self.labeled_classes) < 5:
            return

        oob_score = getattr(self.active_learner.estimator, 'oob_score_', None)
        if oob_score is None:
            self.accuracy_label.setText(
                "Model accuracy (OOB): not available - train the model to compute it")
        else:
            self.accuracy_label.setText(f"Model accuracy (OOB): {oob_score:.2%}")
    
    def calculate_accuracy(self):
        """Return the model's out-of-bag accuracy as a float in [0, 1].

        Returns NaN when no estimate exists: no learner yet, fewer than 5
        labeled samples, or the forest was not fit with ``oob_score=True``.
        The result is recorded in ``query_history`` and saved model metadata,
        so it must be a real measurement, never a placeholder.
        """
        if self.active_learner is None or len(self.labeled_classes) < 5:
            return float('nan')

        oob_score = getattr(self.active_learner.estimator, 'oob_score_', None)
        return float('nan') if oob_score is None else float(oob_score)
    
    def train_model(self):
        """Refit the Random Forest on every labeled sample and report the results.

        Requires an initialized learner and at least 10 labeled samples.
        ``oob_score`` is enabled on the forest before fitting, so an
        out-of-bag accuracy becomes available. Per-band feature importances
        are shown in the status label. Fit errors (e.g. features and labels of
        different lengths) are reported in the message bar.
        """
        if self.active_learner is None:
            self.iface.messageBar().pushMessage(
                "Error", "Initialize active learning first", level=1)
            return

        if len(self.labeled_pixels) < 10:
            self.iface.messageBar().pushMessage(
                "Error", "Need at least 10 labeled samples to train", level=1)
            return

        self.status_label.setText("Training model on all labeled data...")
        # The bar's maximum may still be the pool size from initialization;
        # use a 0-100 scale for this step.
        self.progress_bar.setMaximum(100)
        self.progress_bar.setValue(50)

        # Retrain with OOB scoring so an accuracy estimate is available
        self.active_learner.estimator.set_params(oob_score=True)
        try:
            self.active_learner.fit(self.labeled_pixels, self.labeled_classes)
        except ValueError as err:
            self.progress_bar.setValue(0)
            self.status_label.setText("Training failed")
            self.iface.messageBar().pushMessage(
                "Error", f"Training failed: {err}", level=1)
            return

        importances = self.active_learner.estimator.feature_importances_

        self.progress_bar.setValue(100)
        self.update_model_accuracy()

        # Show feature importance (one entry per raster band)
        band_importance = ", ".join(f"Band {i + 1}: {imp:.3f}"
                                    for i, imp in enumerate(importances))

        self.status_label.setText(f"Model trained successfully. Feature importances: {band_importance}")
        self.iface.messageBar().pushMessage(
            "Success", "Model trained on all labeled data", level=3)
    
    def classify_raster(self):
        """Ask for confirmation, then (eventually) classify the whole raster.

        Currently a placeholder. After the user confirms, it only updates the
        status label and posts an informational message. Returns immediately
        if there is no learner or no raster.
        """
        if not (self.active_learner and self.current_raster):
            return

        answer = QMessageBox.question(
            self.dialog,
            "Classify Raster",
            "This will classify the entire raster. Continue?",
            QMessageBox.Yes | QMessageBox.No,
        )
        if answer == QMessageBox.No:
            return

        self.status_label.setText("Classifying raster...")

        # TODO: read the raster in chunks, apply the model and write the
        # classified output.
        self.iface.messageBar().pushMessage(
            "Info", "Full raster classification would be implemented here", level=0)
    
    def save_model(self):
        """Save the fitted estimator, the feature scaler and run metadata with joblib.

        Three files sharing one timestamp are written to the plugin directory:
        ``rf_margin_model_<ts>.joblib``, ``scaler_<ts>.joblib`` and
        ``metadata_<ts>.joblib``. Write failures (e.g. a read-only plugin
        directory) are reported in the message bar instead of escaping the
        Qt slot.

        NOTE: joblib files are pickles; only load them from trusted sources.
        """
        if self.active_learner is None:
            return

        # Save model with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(self.plugin_dir, f"rf_margin_model_{timestamp}.joblib")
        scaler_path = os.path.join(self.plugin_dir, f"scaler_{timestamp}.joblib")
        metadata_path = os.path.join(self.plugin_dir, f"metadata_{timestamp}.joblib")

        metadata = {
            'timestamp': timestamp,
            'n_samples': len(self.labeled_classes),
            'classes': list(self.class_config.keys()),
            # Class value -> display name, needed to interpret predictions later.
            'class_names': {value: info['name'] for value, info in self.class_config.items()},
            'accuracy': self.calculate_accuracy(),
            'query_history': self.query_history
        }

        try:
            joblib.dump(self.active_learner.estimator, model_path)
            joblib.dump(self.scaler, scaler_path)
            joblib.dump(metadata, metadata_path)
        except OSError as err:
            self.iface.messageBar().pushMessage(
                "Error", f"Could not save model: {err}", level=1)
            return

        self.iface.messageBar().pushMessage(
            "Success", f"Model saved to {model_path}", level=3)