# 06. Train Random Forest Model

This notebook trains a Random Forest classifier for deforestation detection.

**Approach:**
- Traditional machine learning (no deep learning)
- Flattened features: 128×128×14 = 229,376 features per sample
- Ensemble of decision trees (100 trees)
- Feature importance analysis

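**Outputs:**
- Trained model: `checkpoints/random_forest_best.pkl`
- Training metrics (train / validation / test), also written to `logs/random_forest_training.txt`
- Feature importance analysis (per band and per pixel), with confusion-matrix and ROC figures saved under `figures/`
- Test-set metrics for later comparison with the CNN models (the comparison itself is done in notebook 04)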

## 1. Setup

In [None]:
import sys
import time
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.patches import Patch

# Locate the project relative to the notebook's working directory: the parent
# of the current directory is treated as the project root.
project_root = Path.cwd().parent
src_path = project_root / 'src'

# Prepend both locations to the import search path unless already present.
# src_path is handled last, so when both are new it ends up ahead of
# project_root in sys.path.
for location in map(str, (project_root, src_path)):
    if location not in sys.path:
        sys.path.insert(0, location)

print(f"Project root: {project_root}")
print(f"Source dir: {src_path}")

In [None]:
from ml_models import RandomForestModel, load_patches_for_ml
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc

# Set style
# Reset matplotlib's rcParams to the library defaults first, then install the
# seaborn "husl" palette as the colour cycle. Keep this order: applying the
# default style afterwards would overwrite the palette again.
plt.style.use('default')
sns.set_palette("husl")

# Reached only if the imports at the top of this cell succeeded.
print("Libraries imported successfully!")

## 2. Load Data

In [None]:
# Input: directory handed to the patch loader.
patches_dir = project_root / 'data' / 'patches'

# Outputs: one folder each for saved models, text logs and figures.
checkpoints_dir, logs_dir, figures_dir = (
    project_root / folder for folder in ('checkpoints', 'logs', 'figures')
)

# Create the output folders if needed; existing folders are left untouched.
for output_dir in (checkpoints_dir, logs_dir, figures_dir):
    output_dir.mkdir(exist_ok=True)

print("Directories created/verified!")

In [None]:
rule = "=" * 80

print(rule, "LOADING PATCHES", rule, sep="\n")
print()

# Each split comes back from the project loader as a (features, labels) pair.
# The blank print() after every call keeps the output sections separated.
X_train, y_train = load_patches_for_ml(patches_dir, 'train')
print()
X_val, y_val = load_patches_for_ml(patches_dir, 'val')
print()
X_test, y_test = load_patches_for_ml(patches_dir, 'test')
print()

# Sample counts per split and overall.
n_train, n_val, n_test = len(X_train), len(X_val), len(X_test)
print(rule, "DATASET SUMMARY", rule, sep="\n")
print(f"Train: {n_train} samples")
print(f"Val: {n_val} samples")
print(f"Test: {n_test} samples")
print(f"Total: {n_train + n_val + n_test} samples")
print()

# Second array dimension = number of flattened features per sample.
print(f"Features per sample: {X_train.shape[1]:,} (128×128×14)")

# In-memory footprint of each feature array, in megabytes (1e6 bytes).
mb_train, mb_val, mb_test = (split.nbytes / 1e6 for split in (X_train, X_val, X_test))
print(f"Memory: Train={mb_train:.1f}MB, Val={mb_val:.1f}MB, Test={mb_test:.1f}MB")

## 3. Model Configuration

In [None]:
# Hyperparameters, later unpacked as keyword arguments into RandomForestModel.
model_config = dict(
    n_estimators=100,       # Number of trees
    max_depth=20,           # Maximum depth of trees
    min_samples_split=10,   # Min samples to split a node
    min_samples_leaf=4,     # Min samples at leaf node
    random_state=42,        # For reproducibility
    n_jobs=-1,              # Use all CPU cores
)

# Echo the configuration, one aligned "name: value" row per setting.
print("Model Configuration:")
print("="*80)
print("\n".join(f"  {name:20s}: {setting}" for name, setting in model_config.items()))
print()
print("Note: n_jobs=-1 means using all available CPU cores for parallel training")

## 4. Train Random Forest

In [None]:
print("="*80)
print("TRAINING RANDOM FOREST")
print("="*80)
print()

# Create model from the hyperparameters defined in the configuration cell
model = RandomForestModel(**model_config)

# Train model.
# `start_time` / `end_time` are wall-clock timestamps (the training log cell
# records `start_time` as the training date). The elapsed duration is measured
# separately with a monotonic clock so that system clock adjustments during a
# long fit cannot distort, or even make negative, the reported training time.
start_time = datetime.now()
timer_start = time.perf_counter()
metrics = model.train(X_train, y_train, X_val, y_val)
training_time = time.perf_counter() - timer_start  # seconds
end_time = datetime.now()

print()
print(f"Training completed in {training_time:.2f} seconds ({training_time/60:.2f} minutes)")

## 5. Training Metrics

In [None]:
print("="*80)
print("TRAINING METRICS")
print("="*80)
print()

# (printed label, key suffix) pairs shared by both splits. The entries read
# from `metrics` are named "<split>_<suffix>", e.g. "train_acc" or "val_f1".
metric_rows = (
    ("  Accuracy:  ", "acc"),
    ("  Precision: ", "precision"),
    ("  Recall:    ", "recall"),
    ("  F1 Score:  ", "f1"),
    ("  AUC:       ", "auc"),
)
split_sections = (("Train Set:", "train"), ("Validation Set:", "val"))

for position, (heading, split) in enumerate(split_sections):
    if position:
        print()  # blank line between sections, none after the last one
    print(heading)
    print("-"*80)
    for label, suffix in metric_rows:
        metric_key = f"{split}_{suffix}"
        print(f"{label}{metrics[metric_key]:.4f}")

## 6. Test Set Evaluation

In [None]:
print("="*80)
print("TEST SET EVALUATION")
print("="*80)
print()

# Score the trained model on the held-out test split.
test_metrics = model.evaluate(X_test, y_test)

# (printed label, key in test_metrics) pairs, in display order.
test_metric_rows = (
    ("  Accuracy:  ", "accuracy"),
    ("  Precision: ", "precision"),
    ("  Recall:    ", "recall"),
    ("  F1 Score:  ", "f1"),
    ("  AUC:       ", "auc"),
)

print("Test Set:")
print("-"*80)
for label, metric_key in test_metric_rows:
    print(f"{label}{test_metrics[metric_key]:.4f}")

In [None]:
# Class predictions for the test split; `y_test_pred` is reused by the
# confusion-matrix cell further down.
y_test_pred = model.predict(X_test)

# Display names for the two classes, in label order.
class_labels = ['No Deforestation', 'Deforestation']

print()
print("Detailed Classification Report:")
print("="*80)
print(classification_report(y_test, y_test_pred, target_names=class_labels))

## 7. Confusion Matrix

In [None]:
class_labels = ['No Deforestation', 'Deforestation']
cm_figure_path = figures_dir / 'random_forest_confusion_matrix.png'

# Counts of (actual, predicted) class pairs on the test split.
cm = confusion_matrix(y_test, y_test_pred)

# Annotated heatmap: rows are labelled "Actual", columns "Predicted".
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_labels,
    yticklabels=class_labels,
    cbar_kws={'label': 'Count'},
    ax=ax,
)
ax.set_xlabel('Predicted', fontsize=12)
ax.set_ylabel('Actual', fontsize=12)
ax.set_title('Random Forest - Confusion Matrix (Test Set)', fontsize=14, fontweight='bold')

# Persist the figure before showing it.
fig.tight_layout()
fig.savefig(cm_figure_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Confusion matrix saved to: {cm_figure_path}")

## 8. ROC Curve

In [None]:
roc_figure_path = figures_dir / 'random_forest_roc_curve.png'

# Scores used for ranking: column 1 of predict_proba
# (presumably the "Deforestation" class -- confirm against the model wrapper).
y_test_proba = model.predict_proba(X_test)[:, 1]

# ROC points and the area under them.
fpr, tpr, thresholds = roc_curve(y_test, y_test_proba)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(8, 6))

# Model curve plus the diagonal that a random classifier would trace.
ax.plot(fpr, tpr, color='darkorange', lw=2,
        label=f'Random Forest (AUC = {roc_auc:.4f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Guess')

# Slight headroom above 1.0 so the curve's top edge is not clipped.
ax.set(xlim=[0.0, 1.0], ylim=[0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.set_title('Random Forest - ROC Curve (Test Set)', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=11)
ax.grid(alpha=0.3)

fig.tight_layout()
fig.savefig(roc_figure_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"ROC curve saved to: {roc_figure_path}")

## 9. Feature Importance Analysis

In [None]:
# Flat per-feature importance vector from the trained model; reused by the
# band-wise and spatial analyses in the following cells.
importance = model.get_feature_importance()

print("="*80)
print("FEATURE IMPORTANCE ANALYSIS")
print("="*80)
print()

# Basic distribution statistics of the importance values.
summary_lines = [
    f"Total features: {len(importance):,}",
    f"Feature importance range: [{importance.min():.6f}, {importance.max():.6f}]",
    f"Mean importance: {importance.mean():.6f}",
    f"Std importance: {importance.std():.6f}",
]
print("\n".join(summary_lines))

In [None]:
# Reshape importance to (128, 128, 14)
# NOTE(review): assumes load_patches_for_ml flattens each patch in
# (height, width, band) order with the band axis last -- confirm against the
# loader; with a different layout the per-band and spatial summaries would be
# scrambled without raising an error.
importance_map = importance.reshape(128, 128, 14)

# Calculate band-wise importance (average over spatial dimensions)
band_importance = importance_map.mean(axis=(0, 1))

# Band names, in the same order as the last axis of importance_map
band_names = [
    'Blue_2024', 'Green_2024', 'Red_2024', 'NIR_2024',
    'NDVI_2024', 'NBR_2024', 'NDMI_2024',
    'Blue_2025', 'Green_2025', 'Red_2025', 'NIR_2025',
    'NDVI_2025', 'NBR_2025', 'NDMI_2025'
]

# Width, in characters, of the longest bar in the text chart below.
BAR_WIDTH = 40

# Bars are scaled relative to the strongest band. Each per-band value is a
# mean over 128*128 pixel positions; if the importances are normalised to sum
# to 1 (as scikit-learn's feature_importances_ are), every value is far below
# 0.001, so a fixed multiplier such as 1000 truncates every bar to zero length.
max_band_importance = float(band_importance.max())

print()
print("Average Importance by Band:")
print("-"*80)
for i, (name, imp) in enumerate(zip(band_names, band_importance)):
    if max_band_importance > 0:
        bar_length = int(round(float(imp) / max_band_importance * BAR_WIDTH))
    else:
        bar_length = 0  # degenerate case: no band carries any importance
    bar = '█' * bar_length
    print(f"  Band {i:2d} ({name:12s}): {imp:.6f} {bar}")

In [None]:
# Plot band importance: one bar per band, coloured by acquisition year.
band_figure_path = figures_dir / 'random_forest_band_importance.png'

# Derive the bar count and colours from band_names instead of hard-coding
# 14 bars / 7 per year, so the plot stays correct if the band list changes.
n_bands = len(band_names)
year_colors = {'2024': 'steelblue', '2025': 'coral'}
# The year is the suffix after the last underscore (e.g. 'NDVI_2025');
# names without a known year fall back to a neutral colour.
colors = [year_colors.get(name.rsplit('_', 1)[-1], 'gray') for name in band_names]

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(range(n_bands), band_importance, color=colors, edgecolor='black', alpha=0.8)

ax.set_xlabel('Band', fontsize=12)
ax.set_ylabel('Average Importance', fontsize=12)
ax.set_title('Random Forest - Feature Importance by Band', fontsize=14, fontweight='bold')
ax.set_xticks(range(n_bands))
ax.set_xticklabels(band_names, rotation=45, ha='right')
ax.grid(axis='y', alpha=0.3)

# Add legend: one proxy patch per year (Patch is imported in the setup cell).
legend_elements = [
    Patch(facecolor=color, edgecolor='black', label=year)
    for year, color in year_colors.items()
]
ax.legend(handles=legend_elements, loc='upper right')

plt.tight_layout()
plt.savefig(band_figure_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Band importance plot saved to: {band_figure_path}")

In [None]:
spatial_figure_path = figures_dir / 'random_forest_spatial_importance.png'

# Collapse the last axis of importance_map (the bands) to get one averaged
# importance value per pixel position.
spatial_importance = importance_map.mean(axis=2)

# Render the per-pixel map as a heatmap with a labelled colour bar.
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(spatial_importance, cmap='hot', interpolation='nearest')
ax.set_title(
    'Random Forest - Spatial Feature Importance\n(Averaged over all 14 bands)',
    fontsize=14,
    fontweight='bold',
)
ax.set_xlabel('X (pixels)', fontsize=12)
ax.set_ylabel('Y (pixels)', fontsize=12)

cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Importance', fontsize=12, rotation=270, labelpad=20)

fig.tight_layout()
fig.savefig(spatial_figure_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Spatial importance heatmap saved to: {spatial_figure_path}")
print()
print("Note: Brighter areas indicate more important spatial locations for classification")

## 10. Save Model

In [None]:
# Persist the fitted model under the checkpoints folder.
model_path = checkpoints_dir / 'random_forest_best.pkl'
model.save(model_path)

print()
print(f"Model saved to: {model_path}")

# Report the on-disk size in megabytes (1e6 bytes).
model_size_mb = model_path.stat().st_size / 1e6
print(f"Model size: {model_size_mb:.2f} MB")

## 11. Save Training Log

In [None]:
# Destination of the plain-text training report.
log_path = logs_dir / 'random_forest_training.txt'

# Separator lines reused throughout the report.
heavy_rule = "="*80 + "\n"
light_rule = "-"*80 + "\n"

# (label, key suffix) rows for the train/validation sections; the values are
# read from `metrics` under "<split>_<suffix>".
split_metric_rows = (
    ("  Accuracy:  ", "acc"),
    ("  Precision: ", "precision"),
    ("  Recall:    ", "recall"),
    ("  F1 Score:  ", "f1"),
    ("  AUC:       ", "auc"),
)
# (label, key) rows for the test section, read from `test_metrics`.
test_metric_rows = (
    ("  Accuracy:  ", "accuracy"),
    ("  Precision: ", "precision"),
    ("  Recall:    ", "recall"),
    ("  F1 Score:  ", "f1"),
    ("  AUC:       ", "auc"),
)

with open(log_path, 'w', encoding='utf-8') as f:
    # Title banner
    f.write(heavy_rule)
    f.write("RANDOM FOREST TRAINING LOG\n")
    f.write(heavy_rule + "\n")

    # When training started and how long it took
    f.write(f"Training date: {start_time.strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(f"Training time: {training_time:.2f} seconds ({training_time/60:.2f} minutes)\n\n")

    # Hyperparameters
    f.write("MODEL CONFIGURATION:\n")
    f.write(light_rule)
    f.writelines(f"  {option}: {setting}\n" for option, setting in model_config.items())
    f.write("\n")

    # Split sizes and feature count
    f.write("DATASET:\n")
    f.write(light_rule)
    f.write(f"  Train: {len(X_train)} samples\n")
    f.write(f"  Val: {len(X_val)} samples\n")
    f.write(f"  Test: {len(X_test)} samples\n")
    f.write(f"  Features: {X_train.shape[1]:,}\n\n")

    # Train and validation metrics; each section ends with a blank line
    f.write("TRAINING METRICS:\n")
    f.write(light_rule)
    for heading, split in (("Train set:\n", "train"), ("Validation set:\n", "val")):
        f.write(heading)
        for label, suffix in split_metric_rows:
            f.write(f"{label}{metrics[split + '_' + suffix]:.4f}\n")
        f.write("\n")

    # Test metrics, followed by a blank line
    f.write("TEST METRICS:\n")
    f.write(light_rule)
    for label, metric_key in test_metric_rows:
        f.write(f"{label}{test_metrics[metric_key]:.4f}\n")
    f.write("\n")

    # Per-band average importance, in band order
    f.write("FEATURE IMPORTANCE BY BAND:\n")
    f.write(light_rule)
    f.writelines(
        f"  Band {index:2d} ({band:12s}): {score:.6f}\n"
        for index, (band, score) in enumerate(zip(band_names, band_importance))
    )

print(f"Training log saved to: {log_path}")

## 12. Summary

In [None]:
# Every artefact written by this notebook, in the order it is listed below.
output_files = [
    ("Model", model_path),
    ("Training log", log_path),
    ("Confusion matrix", figures_dir / 'random_forest_confusion_matrix.png'),
    ("ROC curve", figures_dir / 'random_forest_roc_curve.png'),
    ("Band importance", figures_dir / 'random_forest_band_importance.png'),
    ("Spatial importance", figures_dir / 'random_forest_spatial_importance.png'),
]

print()
print("="*80)
print("TRAINING COMPLETED SUCCESSFULLY")
print("="*80)
print()

print("Output Files:")
for number, (label, path) in enumerate(output_files, start=1):
    print(f"  {number}. {label}: {path}")
print()

# Headline test metrics, shown as percentages.
print("Test Set Performance:")
for label, metric_key in (
    ("  Accuracy:  ", "accuracy"),
    ("  F1 Score:  ", "f1"),
    ("  AUC:       ", "auc"),
):
    print(f"{label}{test_metrics[metric_key]:.2%}")
print()

print("Key Insights:")
print(f"  - Training time: {training_time/60:.2f} minutes")
# argmax picks the band with the largest average importance.
top_band = band_names[band_importance.argmax()]
print(f"  - Most important band: {top_band}")
print(f"  - Model size: {model_path.stat().st_size / 1e6:.2f} MB")
print()

print("Next Steps:")
for step in (
    "Compare with CNN models in notebook 04",
    "Analyze feature importance for insights",
    "Consider hyperparameter tuning if needed",
):
    print(f"  - {step}")