In [9]:
import rasterio
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import StandardScaler
from skimage.util import view_as_windows
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from scipy.ndimage import median_filter
import joblib

# === Configuration ===
PATCH_SIZE = 3
TEST_SIZE = 0.2
RANDOM_STATE = 42
MIN_SAMPLES_PER_CLASS = 5  # Balanced between accuracy and memory
OUTPUT_DIR = "E:/mtito_andei/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Memory-optimized Random Forest parameters
RF_PARAMS = {
    'n_estimators': 50,      # Reduced from 100 to save memory
    'max_depth': 15,         # Limited depth to prevent overfitting
    'min_samples_split': 5,  # Increased to reduce tree complexity
    'max_features': 'sqrt',  # Fewer features considered per split
    'n_jobs': -1,            # Still use all cores
    'random_state': RANDOM_STATE
}

# === Memory-efficient raster loading ===
def read_raster(path):
    """Load raster with memory optimization"""
    with rasterio.open(path) as src:
        # Read only the first band by default
        data = src.read(1, masked=True)
        profile = src.profile
        print(f"Loaded {os.path.basename(path)}: {data.shape}, {data.dtype}")
        return data.filled(0), profile

# === Optimized patch extraction ===
def extract_patches_efficient(*arrays, patch_size=PATCH_SIZE):
    """Memory-efficient patch extraction using sliding window"""
    print("Extracting patches efficiently...")
    pad = patch_size // 2
    
    # Process features (all arrays except last)
    features = []
    for arr in arrays[:-1]:
        # Use sliding window view (memory efficient)
        window_shape = (patch_size, patch_size)
        arr_patches = view_as_windows(arr, window_shape)
        arr_patches = arr_patches.reshape(-1, patch_size*patch_size)
        features.append(arr_patches)
    
    # Combine all features
    features = np.concatenate(features, axis=1)
    
    # Get labels (center pixels)
    labels = arrays[-1][pad:-pad, pad:-pad].flatten()
    
    return features, labels

# === Main Processing ===
def main():
    # Load data
    print("Loading rasters...")
    ld, profile = read_raster(os.path.join("E:/mtito_andei", "LD.tif"))
    d2f, _ = read_raster(os.path.join("E:/mtito_andei", "D2F.tif"))
    target, _ = read_raster(os.path.join("E:/mtito_andei", "Map.tif"))

    # Extract patches efficiently
    features, labels = extract_patches_efficient(ld, d2f, target)
    
    # Filter data
    valid_mask = (labels > 0)
    features = features[valid_mask]
    labels = labels[valid_mask]

    # Filter rare classes
    label_counts = pd.Series(labels).value_counts()
    valid_classes = label_counts[label_counts >= MIN_SAMPLES_PER_CLASS].index
    class_mask = np.isin(labels, valid_classes)
    features = features[class_mask]
    labels = labels[class_mask]

    print(f"\nClass distribution:\n{label_counts}")

    # Normalize features (memory-efficient StandardScaler)
    scaler = StandardScaler()
    features = scaler.fit_transform(features)
    joblib.dump(scaler, os.path.join(OUTPUT_DIR, "scaler.joblib"))

    # Train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        stratify=labels
    )

    # Train memory-optimized Random Forest
    print("\nTraining memory-efficient Random Forest...")
    rf = RandomForestClassifier(**RF_PARAMS)
    rf.fit(X_train, y_train)

    # Save model
    model_path = os.path.join(OUTPUT_DIR, "memory_efficient_rf_model.joblib")
    joblib.dump(rf, model_path)
    print(f"Model saved to {model_path}")

    # Evaluation
    print("\nEvaluating model...")
    y_pred = rf.predict(X_test)

    print("\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=[f"Class {c}" for c in valid_classes]))
    print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")

    # Feature importance plot
    importances = rf.feature_importances_
    plt.figure(figsize=(12, 6))
    sorted_idx = importances.argsort()[::-1]
    plt.bar(range(X_train.shape[1]), importances[sorted_idx], align='center')
    plt.xticks(range(X_train.shape[1]), sorted_idx, rotation=90)
    plt.xlabel("Feature index")
    plt.ylabel("Feature importance")
    plt.title("Feature Importance Ranking")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "feature_importance.png"), dpi=300)
    plt.close()

    # Full prediction with chunk processing
    print("\nGenerating full prediction map (chunked processing)...")
    chunk_size = 50000  # Process in chunks to save memory
    all_preds = np.zeros(len(features), dtype=np.uint8)
    
    for i in tqdm(range(0, len(features), chunk_size), desc="Predicting"):
        chunk = features[i:i + chunk_size]
        all_preds[i:i + chunk_size] = rf.predict(chunk)

    # Create prediction map
    predicted_map = np.zeros_like(target, dtype=np.uint8)
    pad = PATCH_SIZE // 2
    valid_pixels = (target[pad:-pad, pad:-pad] > 0) & (np.isin(target[pad:-pad, pad:-pad], valid_classes))
    predicted_map[pad:-pad, pad:-pad][valid_pixels] = all_preds

    # Apply median filter to reduce noise
    predicted_map = median_filter(predicted_map, size=3)

    # Save prediction
    pred_profile = profile.copy()
    pred_profile.update(dtype=rasterio.uint8, count=1, nodata=0)

    with rasterio.open(os.path.join(OUTPUT_DIR, "Predicted_Map.tif"), 'w', **pred_profile) as dst:
        dst.write(predicted_map, 1)

    print("\n=== Process Complete ===")
    print(f"✅ All outputs saved to: {OUTPUT_DIR}")

    # === LEARNING CURVE ===
    print("\nGenerating learning curve...")

    from sklearn.model_selection import learning_curve
    import seaborn as sns

    train_sizes = np.linspace(0.1, 1.0, 5)
    train_sizes_abs, train_scores, test_scores = learning_curve(
        estimator=rf,
        X=X_train,
        y=y_train,
        train_sizes=train_sizes,
        cv=3,
        scoring='accuracy',
        n_jobs=1,
        shuffle=True,
        random_state=RANDOM_STATE
    )

    train_scores_mean = np.mean(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)

    plt.figure(figsize=(8, 6))
    sns.set(style="whitegrid")

    sns.lineplot(x=train_sizes_abs, y=train_scores_mean, label="Training Accuracy", marker='o')
    sns.lineplot(x=train_sizes_abs, y=test_scores_mean, label="Validation Accuracy", marker='o')

    plt.title("Learning Curve")
    plt.xlabel("Number of Training Samples")
    plt.ylabel("Accuracy")
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "learning_curve.png"), dpi=300)
    plt.close()
    print("Learning curve saved.")

      


if __name__ == "__main__":
    main()

Loading rasters...
Loaded LD.tif: (1024, 1280), uint8
Loaded D2F.tif: (1024, 1280), uint8
Loaded Map.tif: (1024, 1280), uint8
Extracting patches efficiently...

Class distribution:
138    697044
224    580040
226     15099
225      5042
227      2219
        ...  
71          3
243         2
170         2
254         1
249         1
Name: count, Length: 202, dtype: int64

Training memory-efficient Random Forest...
Model saved to E:/mtito_andei/outputs\memory_efficient_rf_model.joblib

Evaluating model...

Classification Report:


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

   Class 138       0.99      1.00      0.99       154
   Class 224       0.62      0.62      0.62         8
   Class 226       0.69      0.73      0.71        15
   Class 225       1.00      0.50      0.67         4
   Class 227       0.00      0.00      0.00         2
    Class 50       0.50      0.25      0.33         4
   Class 228       0.80      0.92      0.86        13
   Class 232       1.00      1.00      1.00         3
   Class 223       1.00      1.00      1.00         2
   Class 240       0.00      0.00      0.00         2
   Class 167       0.50      0.20      0.29         5
   Class 101       0.67      0.67      0.67         3
   Class 197       0.75      1.00      0.86         3
   Class 204       0.50      0.75      0.60         4
   Class 178       0.67      1.00      0.80         4
   Class 229       1.00      0.75      0.86         4
    Class 52       1.00      1.00      1.00         2
   Class 217       1.00    

Predicting: 100%|██████████| 27/27 [00:47<00:00,  1.75s/it]



=== Process Complete ===
✅ All outputs saved to: E:/mtito_andei/outputs

Generating learning curve...
Learning curve saved.
