# Random Forest: Multi-Source ENVO Prediction

**Goal**: Predict ENVO environmental triad terms from satellite imagery across GOLD and NMDC.

**Key Questions**:
1. How well can we predict each ENVO triad term (broad scale, local scale, environmental medium)?
2. Do different sources show different patterns?
3. What's the impact of removing exact duplicates?

## Setup

In [None]:
from pathlib import Path
import sys

In [None]:
# Make the project's ``src`` directory importable from this notebook.
# ``resolve()`` collapses the ``..`` segment so the entry is a canonical
# absolute path, and the membership check keeps re-running this cell from
# stacking duplicate entries on ``sys.path``.
_src_dir = str(Path('../src').resolve())
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)

In [None]:
from env_embeddings.rf_analysis import (
    load_source_data,
    analyze_source,
    create_comparison_table,
    print_summary,
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Plot styling shared by every figure below: start from matplotlib's default
# style, switch seaborn to the "husl" palette, then set the default figure
# size and base font size in one rcParams update.
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 11,
})

## Load Data Sources

In [None]:
# Input CSV for each data source, keyed by source name.
# Only GOLD and NMDC are listed for now; NCBI is to be added later.
SOURCES = {
    name: Path(csv_path)
    for name, csv_path in (
        ('GOLD', '../data/gold_flattened_biosamples_for_env_embeddings_202510061108_complete.csv'),
        ('NMDC', '../data/nmdc_flattened_biosample_for_env_embeddings_202510061052_complete.csv'),
    )
}

In [None]:
# Load every configured source with deduplication enabled.
# A source for which load_source_data returns None is left out of `datasets`,
# so the later cells only ever see sources that actually loaded.
datasets = {}
for source_name, file_path in SOURCES.items():
    df = load_source_data(file_path, source_name, deduplicate=True)
    if df is None:
        continue
    datasets[source_name] = df

## Train Random Forest Models

In [None]:
# Run analyze_source once per loaded dataset and collect whatever it returns
# under the source's name.
all_results = {
    source_name: analyze_source(df, source_name)
    for source_name, df in datasets.items()
}

## Results Summary

In [None]:
# Create comparison table
# Combines the per-source results gathered in `all_results` into one table.
# The cells below read its 'Source', 'Scale', 'Test_Acc' and 'Overfitting'
# columns and call `.to_string()` / `.to_csv()` on it.
comparison_df = create_comparison_table(all_results)

In [None]:
# Dump the whole comparison table as plain text (row index omitted) under a
# heading; a single print with a newline separator emits the heading and the
# table on consecutive lines.
print("\nDetailed Results:", comparison_df.to_string(index=False), sep="\n")

In [None]:
# Print actionable summary
# `print_summary` comes from env_embeddings.rf_analysis and is handed the full
# comparison table; its exact output is defined in that module.
print_summary(comparison_df)

## Visualization: Test Accuracy by Scale

In [None]:
# Line chart of test accuracy against ENVO scale, one line per source.
fig, ax = plt.subplots(figsize=(10, 6))

for source in comparison_df['Source'].unique():
    # Rows of the comparison table that belong to this source only.
    source_rows = comparison_df.loc[comparison_df['Source'] == source]
    ax.plot(
        source_rows['Scale'],
        source_rows['Test_Acc'],
        label=source,
        marker='o',
        markersize=10,
        linewidth=2,
    )

ax.set_title('Test Accuracy by ENVO Scale', fontsize=14, fontweight='bold')
# Accuracy is a fraction, so pin the y-axis to the full 0-1 range.
ax.set(xlabel='ENVO Scale', ylabel='Test Accuracy', ylim=(0, 1.0))
ax.grid(alpha=0.3)
ax.legend(fontsize=12)
fig.tight_layout()
plt.show()

## Visualization: Average Accuracy by Source

In [None]:
# Mean of the 'Test_Acc' column per 'Source' value, i.e. one average test
# accuracy per data source across all of that source's rows.
avg_by_source = comparison_df.groupby('Source')['Test_Acc'].mean()

In [None]:
# Bar chart of the per-source average test accuracy, one bar per source.
fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(
    avg_by_source.index,
    avg_by_source.values,
    alpha=0.7,
    edgecolor='black',
    linewidth=2,
)
ax.set_title('Average Test Accuracy by Source', fontsize=14, fontweight='bold')
ax.set(ylabel='Average Test Accuracy', ylim=(0, 1.0))
ax.grid(axis='y', alpha=0.3)

# Write each bar's value (three decimals) centred just above the bar.
for bar in bars:
    bar_top = bar.get_height()
    label_x = bar.get_x() + bar.get_width() / 2.
    ax.text(
        label_x,
        bar_top + 0.02,
        f'{bar_top:.3f}',
        ha='center',
        fontweight='bold',
        fontsize=12,
    )

fig.tight_layout()
plt.show()

## Visualization: Overfitting Check

In [None]:
def _severity_color(gap):
    """Return the bar colour for an overfitting gap.

    Above 0.1 is red, above 0.05 is orange, anything else is green.
    """
    if gap > 0.1:
        return 'red'
    if gap > 0.05:
        return 'orange'
    return 'green'


# Bar chart of the 'Overfitting' column (labelled as train minus test
# accuracy), one bar per row of the comparison table.
fig, ax = plt.subplots(figsize=(10, 6))

# Color bars by overfitting severity
colors = [_severity_color(gap) for gap in comparison_df['Overfitting'].values]

# Two-line tick label per bar: source on the first line, scale on the second.
x_labels = [
    f"{source}\n{scale}"
    for source, scale in zip(comparison_df['Source'], comparison_df['Scale'])
]

positions = range(len(comparison_df))
bars = ax.bar(
    positions,
    comparison_df['Overfitting'].values,
    color=colors,
    alpha=0.7,
    edgecolor='black',
    linewidth=1.5,
)

ax.set_xticks(positions)
ax.set_xticklabels(x_labels, rotation=45, ha='right')
ax.set_ylabel('Overfitting (Train - Test Accuracy)')
ax.set_title('Overfitting Analysis', fontsize=14, fontweight='bold')

# Dashed reference lines at the same thresholds used for the bar colours.
ax.axhline(y=0.1, color='red', linestyle='--', alpha=0.5, label='High (>0.1)')
ax.axhline(y=0.05, color='orange', linestyle='--', alpha=0.5, label='Moderate (>0.05)')
ax.legend()
ax.grid(axis='y', alpha=0.3)

fig.tight_layout()
plt.show()

## Save Results

In [None]:
# Directory that receives this notebook's outputs (relative to the notebook's
# working directory).
output_dir = Path('../results/rf_multi_source')
# Create it together with any missing parents; no error if it already exists.
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Write the comparison table to CSV (without the row index) inside the output
# directory and report where it went.
output_file = output_dir.joinpath('comparison_results.csv')
comparison_df.to_csv(output_file, index=False)
print(f"Results saved to: {output_file}")