# Earth4D LFMC Ablation Study

**Comprehensive ablation study for LFMC prediction**

This notebook runs 8 experiments to evaluate different feature combinations:

**Without Species Embeddings:**
1. Earth4D alone
2. AEF alone
3. Earth4D + AEF
4. Earth4D + AEF + Daymet

**With Species Embeddings (768D):**
5. Earth4D + Species
6. AEF + Species
7. Earth4D + AEF + Species
8. Earth4D + AEF + Daymet + Species (Full Model)

**Configuration:**
- 100 epochs per experiment
- 768-dimension species embeddings
- Automatic dataset download from cloud storage
- Comprehensive visualization with labeled plots
- PNG and SVG outputs (DPI 300)

## 1. Environment Setup

In [None]:
# Report the PyTorch build and, when present, the first CUDA device (index 0).
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    # Only a warning is printed here; execution is not stopped.
    print("WARNING: CUDA not available - this model requires GPU!")
else:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    # total_memory is divided by 1e9 to print decimal gigabytes.
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q numpy pandas scikit-learn scipy tqdm matplotlib seaborn
!pip install -q ninja pybind11
!pip install -q pyarrow fastparquet  # For Parquet support

print("Dependencies installed!")

In [None]:
# Mount Google Drive and switch into the study folder.
# WORK_DIR is the root used by later cells for the repository checkout,
# the data directory and the results directory.
import os

from google.colab import drive

drive.mount('/content/drive')

WORK_DIR = '/content/drive/MyDrive/Earth4D_LFMC_Ablation_Study'

# Create the folder on first use, then make it the current directory so the
# relative paths used by the download cell resolve inside it.
os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)

print(f"Working directory: {os.getcwd()}")

## 2. Download Repository (earth-observation branch)

In [None]:
# Download earth-observation branch from GitHub
import os
import shutil

# Clean previous installation
if os.path.exists('deepearth-earth-observation'):
    shutil.rmtree('deepearth-earth-observation')
if os.path.exists('deepearth-earth-observation.zip'):
    os.remove('deepearth-earth-observation.zip')

# Download earth-observation branch
!wget -O deepearth-earth-observation.zip https://github.com/legel/deepearth/archive/refs/heads/earth-observation.zip
!unzip -q deepearth-earth-observation.zip

print("Repository downloaded!")
print("\nContents:")
!ls -la deepearth-earth-observation/encoders/xyzt/ | head -20

In [None]:
# Install Earth4D
EARTH4D_DIR = os.path.join(WORK_DIR, 'deepearth-earth-observation', 'encoders', 'xyzt')
os.chdir(EARTH4D_DIR)

print(f"Installing Earth4D from: {os.getcwd()}")
!pip install -e .

print("\nEarth4D installed!")

In [None]:
# Smoke test: build a small Earth4D encoder and run a single forward pass.
from earth4d import Earth4D
import torch

# Use the GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Deliberately small settings; this only checks that the package works.
earth4d_kwargs = dict(
    spatial_levels=4,
    temporal_levels=3,
    features_per_level=2,
    verbose=False,
)
earth4d = Earth4D(**earth4d_kwargs).to(device)

# One sample with four values.
# NOTE(review): presumably three spatial coordinates plus time - confirm
# ordering and units against the Earth4D documentation.
test_coords = torch.tensor([[0.0, 0.0, 0.0, 0.5]], device=device)

# No gradients are needed for a forward-only check.
with torch.no_grad():
    features = earth4d(test_coords)

print(f"Earth4D import successful!")
print(f"Test output shape: {features.shape}")

## 3. Configuration

In [None]:
# Filesystem layout: everything lives under WORK_DIR.
EARTH4D_DIR = os.path.join(WORK_DIR, 'deepearth-earth-observation', 'encoders', 'xyzt')
DATA_DIR = os.path.join(WORK_DIR, 'data')
OUTPUT_DIR = os.path.join(WORK_DIR, 'ablation_results')

# Make sure the data and results folders exist before anything writes to them.
for required_dir in (DATA_DIR, OUTPUT_DIR):
    os.makedirs(required_dir, exist_ok=True)

# Global training config, passed unchanged to every experiment's command line.
EPOCHS = 100
SPECIES_DIM = 768
BATCH_SIZE = 30000
LEARNING_RATE = 0.03
SEED = 0

# Echo the settings so they are recorded in the notebook output.
print("Configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Species dim: {SPECIES_DIM}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Random seed: {SEED}")

print(f"\nDirectories:")
print(f"  Earth4D: {EARTH4D_DIR}")
print(f"  Data: {DATA_DIR}")
print(f"  Output: {OUTPUT_DIR}")

In [None]:
# Define 8 ablation experiments.
# Each entry of `experiments` is a dict with keys:
#   name  - identifier, also used as the per-experiment output folder name
#   label - human-readable title used in logs and plots
#   use_earth4d / use_species / use_aef / use_daymet - feature toggles
flag_keys = ('use_earth4d', 'use_species', 'use_aef', 'use_daymet')

# (name, label, (earth4d, species, aef, daymet))
experiment_table = [
    ('exp1_earth4d_only', 'Earth4D Only',
     (True, False, False, False)),
    ('exp2_aef_only', 'AEF Only',
     (False, False, True, False)),
    ('exp3_earth4d_aef', 'Earth4D + AEF',
     (True, False, True, False)),
    ('exp4_earth4d_aef_daymet', 'Earth4D + AEF + Daymet',
     (True, False, True, True)),
    ('exp5_earth4d_species', 'Earth4D + Species',
     (True, True, False, False)),
    ('exp6_aef_species', 'AEF + Species',
     (False, True, True, False)),
    ('exp7_earth4d_aef_species', 'Earth4D + AEF + Species',
     (True, True, True, False)),
    ('exp8_full_model', 'Full Model (Earth4D + AEF + Daymet + Species)',
     (True, True, True, True)),
]

experiments = [
    {'name': exp_id, 'label': exp_label, **dict(zip(flag_keys, flags))}
    for exp_id, exp_label, flags in experiment_table
]

# Display tag for each toggle, in the order the tags are printed.
feature_tags = (
    ('use_earth4d', 'Earth4D'),
    ('use_species', f'Species({SPECIES_DIM}D)'),
    ('use_aef', 'AEF(64D)'),
    ('use_daymet', 'Daymet(21D)'),
)

print("Ablation Experiments:")
for i, exp in enumerate(experiments, 1):
    features = [tag for flag, tag in feature_tags if exp[flag]]
    print(f"  {i}. {exp['label']:45s} = {' + '.join(features)}")

## 4. Run All 8 Ablation Experiments

In [None]:
import subprocess
import sys
import time
from datetime import datetime

# The training script is referenced by a relative path, so run from the
# Earth4D directory.
os.chdir(EARTH4D_DIR)

# One record per experiment:
#   {'experiment': name, 'label': label, 'success': bool, 'time_minutes': float}
results = []
n_experiments = len(experiments)

print("="*80)
print(f"RUNNING {n_experiments} ABLATION EXPERIMENTS")
print("="*80)

for i, exp in enumerate(experiments, 1):
    print(f"\n{'='*80}")
    print(f"EXPERIMENT {i}/{n_experiments}: {exp['label']}")
    print(f"{'='*80}")

    # Each experiment writes into its own sub-folder of OUTPUT_DIR
    exp_output_dir = os.path.join(OUTPUT_DIR, exp['name'])
    os.makedirs(exp_output_dir, exist_ok=True)

    # Build command. sys.executable is the interpreter running this kernel,
    # i.e. the environment Earth4D was installed into; a bare 'python' could
    # resolve to a different interpreter on PATH.
    cmd = [
        sys.executable, 'earth4d-aef-daymet_to_lfmc.py',
        '--data-dir', DATA_DIR,
        '--epochs', str(EPOCHS),
        '--species-dim', str(SPECIES_DIM),
        '--batch-size', str(BATCH_SIZE),
        '--lr', str(LEARNING_RATE),
        '--output-dir', exp_output_dir,
        '--seed', str(SEED),
        '--auto-download'  # Auto-download datasets
    ]

    # Add feature flags.
    # NOTE(review): the flag names suggest Earth4D/species are opt-out and
    # AEF/Daymet are opt-in in the training script - confirm against its CLI.
    if not exp['use_earth4d']:
        cmd.append('--no-earth4d')
    if not exp['use_species']:
        cmd.append('--no-species')
    if exp['use_aef']:
        cmd.append('--use-aef')
    if exp['use_daymet']:
        cmd.append('--use-daymet')

    print(f"\nCommand: {' '.join(cmd)}")
    print(f"\nStarting training at {datetime.now().strftime('%H:%M:%S')}...")

    start_time = time.time()

    # Run experiment; output is not captured, and a non-zero exit code does
    # not raise - it is recorded below so the remaining experiments still run.
    result = subprocess.run(cmd)

    elapsed = time.time() - start_time
    succeeded = result.returncode == 0

    if succeeded:
        print(f"\n[SUCCESS] Experiment {i} completed in {elapsed/60:.1f} minutes")
    else:
        print(f"\n[ERROR] Experiment {i} failed!")

    results.append({
        'experiment': exp['name'],
        'label': exp['label'],
        'success': succeeded,
        'time_minutes': elapsed/60
    })

print("\n" + "="*80)
print("ALL EXPERIMENTS COMPLETED")
print("="*80)

# Print summary
print("\nSummary:")
for r in results:
    status = "SUCCESS" if r['success'] else "FAILED"
    print(f"  {r['label']:45s} - {status} ({r['time_minutes']:.1f} min)")

n_success = sum(1 for r in results if r['success'])
print(f"\nTotal: {n_success}/{len(results)} experiments successful")

## 5. Collect and Analyze Results

In [None]:
import pandas as pd
import glob

# Collect the metrics CSV of every experiment.
# all_metrics maps experiment name -> {'label': str, 'data': DataFrame}.
# Experiments without a metrics file are skipped with a warning.
all_metrics = {}

for exp in experiments:
    exp_output_dir = os.path.join(OUTPUT_DIR, exp['name'])

    # Find metrics CSV
    csv_files = glob.glob(os.path.join(exp_output_dir, 'training_metrics_*.csv'))

    if not csv_files:
        print(f"WARNING: No metrics found for {exp['label']}")
        continue

    # Load the most recently written CSV. Modification time is used because
    # filename order is only chronological if the name suffix happens to sort
    # that way; the path itself breaks ties deterministically.
    csv_path = max(csv_files, key=lambda path: (os.path.getmtime(path), path))
    df = pd.read_csv(csv_path)
    all_metrics[exp['name']] = {
        'label': exp['label'],
        'data': df
    }
    print(f"Loaded: {exp['label']} ({len(df)} epochs)")

print(f"\nLoaded {len(all_metrics)}/{len(experiments)} experiment results")

## 6. Create Comparison Visualizations

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Render and save every figure at 300 DPI (publication quality).
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
})

# Two palettes so the "without species" and "with species" experiment
# families can be told apart at a glance.
colors_no_species = ['#E74C3C', '#E67E22', '#F39C12', '#F1C40F']  # Reds/Oranges
colors_with_species = ['#3498DB', '#2ECC71', '#1ABC9C', '#9B59B6']  # Blues/Greens

# Experiment name -> line colour; experiments 1-4 take the first palette in
# order, experiments 5-8 the second.
color_map = dict(zip(
    [
        'exp1_earth4d_only',
        'exp2_aef_only',
        'exp3_earth4d_aef',
        'exp4_earth4d_aef_daymet',
        'exp5_earth4d_species',
        'exp6_aef_species',
        'exp7_earth4d_aef_species',
        'exp8_full_model',
    ],
    colors_no_species + colors_with_species,
))

print("Color scheme defined:")
print("  Red/Orange: Experiments WITHOUT species")
print("  Blue/Green: Experiments WITH species")

In [None]:
# Figure 1: training MAE per epoch, one curve per loaded experiment.
fig, ax = plt.subplots(figsize=(14, 8))

# Line styling shared by every curve
curve_style = dict(linewidth=2.5, alpha=0.9)

for exp_name in all_metrics:
    exp_data = all_metrics[exp_name]
    df = exp_data['data']
    ax.plot(
        df['epoch'],
        df['train_mae'],
        label=exp_data['label'],
        color=color_map[exp_name],
        **curve_style,
    )

# Axis decoration
ax.set_title('Training Loss Curves - 8 Ablation Experiments',
             fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('Epoch', fontsize=14, fontweight='bold')
ax.set_ylabel('Training MAE (percentage points)', fontsize=14, fontweight='bold')
ax.set_xlim(0, EPOCHS)
ax.grid(True, alpha=0.3, linestyle='--')
ax.legend(loc='upper right', fontsize=11, framealpha=0.95)

fig.tight_layout()

# Write both a raster (PNG) and a vector (SVG) copy into OUTPUT_DIR
png_path = os.path.join(OUTPUT_DIR, 'fig1_training_loss_curves.png')
svg_path = os.path.join(OUTPUT_DIR, 'fig1_training_loss_curves.svg')
for target in (png_path, svg_path):
    fig.savefig(target, dpi=300, bbox_inches='tight')
print(f"Saved: {png_path}")
print(f"Saved: {svg_path}")

plt.show()

In [None]:
# Figure 2: test MAE per epoch on the temporal, spatial and random test sets,
# one panel per test set and one curve per loaded experiment.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# (metrics column, panel title), in left-to-right panel order
panel_specs = [
    ('temporal_mae', 'Temporal Test MAE'),
    ('spatial_mae', 'Spatial Test MAE'),
    ('random_mae', 'Random Test MAE'),
]

curve_style = dict(linewidth=2.5, alpha=0.9)

for ax, (metric, title) in zip(axes, panel_specs):
    for exp_name in all_metrics:
        exp_data = all_metrics[exp_name]
        df = exp_data['data']
        ax.plot(
            df['epoch'],
            df[metric],
            label=exp_data['label'],
            color=color_map[exp_name],
            **curve_style,
        )

    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
    ax.set_ylabel('MAE (percentage points)', fontsize=12, fontweight='bold')
    ax.set_xlim(0, EPOCHS)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(loc='upper right', fontsize=9, framealpha=0.95)

fig.suptitle('Test Set Performance - 8 Ablation Experiments',
             fontsize=16, fontweight='bold', y=1.02)
fig.tight_layout()

# Write both a raster (PNG) and a vector (SVG) copy into OUTPUT_DIR
png_path = os.path.join(OUTPUT_DIR, 'fig2_test_performance.png')
svg_path = os.path.join(OUTPUT_DIR, 'fig2_test_performance.svg')
for target in (png_path, svg_path):
    fig.savefig(target, dpi=300, bbox_inches='tight')
print(f"Saved: {png_path}")
print(f"Saved: {svg_path}")

plt.show()

In [None]:
# Figure 3: Final Performance Comparison (Bar Chart)
# final_performance maps experiment name -> label plus the MAE values from the
# last row of that experiment's metrics table. The summary-table cell reads
# it as well.
final_performance = {}

for exp_name, exp_data in all_metrics.items():
    df = exp_data['data']
    if df.empty:
        # A metrics file without rows has no final epoch to report.
        print(f"WARNING: Empty metrics for {exp_data['label']} - skipped")
        continue
    final_epoch = df.iloc[-1]

    final_performance[exp_name] = {
        'label': exp_data['label'],
        'train_mae': final_epoch['train_mae'],
        'temporal_mae': final_epoch['temporal_mae'],
        'spatial_mae': final_epoch['spatial_mae'],
        'random_mae': final_epoch['random_mae']
    }

# Create bar chart
fig, ax = plt.subplots(figsize=(14, 8))

# Experiments that have results, in the order they were defined. All series
# below are built from this one list so they always have matching lengths.
# (Previously `labels` iterated over the characters of the first experiment's
# name, was therefore always empty, and made ax.bar fail.)
plotted_names = [exp['name'] for exp in experiments if exp['name'] in final_performance]

labels = [final_performance[name]['label'] for name in plotted_names]
train_vals = [final_performance[name]['train_mae'] for name in plotted_names]
temporal_vals = [final_performance[name]['temporal_mae'] for name in plotted_names]
spatial_vals = [final_performance[name]['spatial_mae'] for name in plotted_names]
random_vals = [final_performance[name]['random_mae'] for name in plotted_names]

x = np.arange(len(labels))
width = 0.2

# Four bars per experiment, centred on the tick: offsets -1.5w, -0.5w, +0.5w, +1.5w
bars1 = ax.bar(x - 1.5*width, train_vals, width, label='Train', color='#34495E', alpha=0.8)
bars2 = ax.bar(x - 0.5*width, temporal_vals, width, label='Temporal Test', color='#E74C3C', alpha=0.8)
bars3 = ax.bar(x + 0.5*width, spatial_vals, width, label='Spatial Test', color='#3498DB', alpha=0.8)
bars4 = ax.bar(x + 1.5*width, random_vals, width, label='Random Test', color='#2ECC71', alpha=0.8)

ax.set_xlabel('Experiment', fontsize=14, fontweight='bold')
ax.set_ylabel('Final MAE (percentage points)', fontsize=14, fontweight='bold')
ax.set_title(f'Final Performance After {EPOCHS} Epochs - 8 Ablation Experiments', 
             fontsize=16, fontweight='bold', pad=20)
ax.set_xticks(x)
# Tick text is "<experiment number>. <label up to the first ' ('>"; the number
# is the position in `experiments`, so it stays stable when results are missing.
ax.set_xticklabels([f"{i+1}. {final_performance[exp['name']]['label'].split(' (')[0]}" 
                     for i, exp in enumerate(experiments) if exp['name'] in final_performance], 
                    rotation=45, ha='right', fontsize=10)
ax.legend(loc='upper right', fontsize=12)
ax.grid(True, alpha=0.3, axis='y', linestyle='--')

plt.tight_layout()

# Save
png_path = os.path.join(OUTPUT_DIR, 'fig3_final_performance_comparison.png')
svg_path = os.path.join(OUTPUT_DIR, 'fig3_final_performance_comparison.svg')
plt.savefig(png_path, dpi=300, bbox_inches='tight')
plt.savefig(svg_path, dpi=300, bbox_inches='tight')
print(f"Saved: {png_path}")
print(f"Saved: {svg_path}")

plt.show()

In [None]:
# Figure 4: Species vs No-Species Comparison
# A 2x2 grid; each panel overlays the temporal test MAE of one feature set
# trained without species embeddings (dashed) and with them (solid).
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# (experiment without species, experiment with species, panel label, axis)
comparisons = [
    ('exp1_earth4d_only', 'exp5_earth4d_species', 'Earth4D', axes[0, 0]),
    ('exp2_aef_only', 'exp6_aef_species', 'AEF', axes[0, 1]),
    ('exp3_earth4d_aef', 'exp7_earth4d_aef_species', 'Earth4D + AEF', axes[1, 0]),
    ('exp4_earth4d_aef_daymet', 'exp8_full_model', 'Earth4D + AEF + Daymet', axes[1, 1])
]

for exp_no_species, exp_with_species, base_label, ax in comparisons:
    ax.set_title(f'{base_label}: Species Impact', fontsize=13, fontweight='bold')

    if exp_no_species not in all_metrics or exp_with_species not in all_metrics:
        # Without both runs there is nothing to compare; mark the panel
        # instead of leaving a blank, unlabelled axis.
        ax.text(0.5, 0.5, 'No data (experiment results missing)',
                ha='center', va='center', fontsize=12, transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
        continue

    # Plot without species
    df_no = all_metrics[exp_no_species]['data']
    ax.plot(df_no['epoch'], df_no['temporal_mae'], 
            label=f'{base_label} (No Species)', 
            color=color_map[exp_no_species], 
            linewidth=2.5,
            linestyle='--',
            alpha=0.9)

    # Plot with species
    df_with = all_metrics[exp_with_species]['data']
    ax.plot(df_with['epoch'], df_with['temporal_mae'], 
            label=f'{base_label} + Species', 
            color=color_map[exp_with_species], 
            linewidth=2.5,
            alpha=0.9)

    ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
    ax.set_ylabel('Temporal Test MAE (pp)', fontsize=12, fontweight='bold')
    ax.legend(loc='upper right', fontsize=11)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_xlim(0, EPOCHS)

# Take the embedding size from the configuration so the title cannot drift
# from the value the experiments were actually run with.
fig.suptitle(f'Impact of Species Embeddings ({SPECIES_DIM}D) on Different Feature Sets', 
             fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()

# Save
png_path = os.path.join(OUTPUT_DIR, 'fig4_species_impact_comparison.png')
svg_path = os.path.join(OUTPUT_DIR, 'fig4_species_impact_comparison.svg')
plt.savefig(png_path, dpi=300, bbox_inches='tight')
plt.savefig(svg_path, dpi=300, bbox_inches='tight')
print(f"Saved: {png_path}")
print(f"Saved: {svg_path}")

plt.show()

## 7. Summary Table

In [None]:
# Build the final-epoch summary table: one row per experiment that has an
# entry in final_performance, in the order the experiments were defined.
# MAE values are formatted as strings with two decimals.
available_perf = (
    final_performance[exp['name']]
    for exp in experiments
    if exp['name'] in final_performance
)

summary_data = [
    {
        'Experiment': perf['label'],
        'Train MAE': f"{perf['train_mae']:.2f}",
        'Temporal MAE': f"{perf['temporal_mae']:.2f}",
        'Spatial MAE': f"{perf['spatial_mae']:.2f}",
        'Random MAE': f"{perf['random_mae']:.2f}"
    }
    for perf in available_perf
]

summary_df = pd.DataFrame(summary_data)

# Plain-text rendering framed by 100-character rules
banner = "="*100
print("\n" + banner)
print(f"FINAL PERFORMANCE SUMMARY (After {EPOCHS} Epochs)")
print(banner)
print(summary_df.to_string(index=False))
print(banner)

# Persist the table next to the figures
csv_path = os.path.join(OUTPUT_DIR, 'ablation_summary_table.csv')
summary_df.to_csv(csv_path, index=False)
print(f"\nSummary table saved to: {csv_path}")

## 8. Package and Download Results

In [None]:
import zipfile
from datetime import datetime

# Bundle figures, the summary table and every experiment's CSV files into one
# timestamped archive under /content, then trigger a browser download.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_name = f"ablation_study_results_{timestamp}.zip"
zip_path = os.path.join('/content', zip_name)

print(f"Creating results package: {zip_name}")

with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Top-level artefacts (plots and tables) go to the archive root
    for file in os.listdir(OUTPUT_DIR):
        if not file.endswith(('.png', '.svg', '.csv')):
            continue
        zipf.write(os.path.join(OUTPUT_DIR, file), file)
        print(f"  Added: {file}")

    # Per-experiment CSVs keep their experiment folder inside the archive
    for exp in experiments:
        exp_dir = os.path.join(OUTPUT_DIR, exp['name'])
        if not os.path.exists(exp_dir):
            continue
        for file in os.listdir(exp_dir):
            if not file.endswith('.csv'):
                continue
            arc_name = os.path.join(exp['name'], file)
            zipf.write(os.path.join(exp_dir, file), arc_name)
            print(f"  Added: {arc_name}")

# Report the archive size in MiB (bytes / 1024^2)
zip_size = os.path.getsize(zip_path) / (1024*1024)
print(f"\nResults package created: {zip_name} ({zip_size:.1f} MB)")

# Download
from google.colab import files
files.download(zip_path)
print("Download started!")

## 9. Final Summary

In [None]:
# Closing recap of the study configuration and generated artefacts.
# The experiment list below is the configured list; per-run success or failure
# is reported by the cell that ran the experiments.
n_experiments = len(experiments)

print("\n" + "="*100)
print("ABLATION STUDY COMPLETE")
print("="*100)
print(f"\nConfiguration:")
print(f"  Epochs per experiment: {EPOCHS}")
print(f"  Species embedding dimension: {SPECIES_DIM}D")
print(f"  Total experiments: {n_experiments}")
print(f"\nExperiments completed:")
for i, exp in enumerate(experiments, 1):
    print(f"  {i}. {exp['label']}")
# Bullets were mis-encoded ("â€¢", UTF-8 bytes decoded as cp1252); restored to "•".
print(f"\nFigures generated:")
print(f"  • Figure 1: Training loss curves (all experiments)")
print(f"  • Figure 2: Test performance (temporal, spatial, random)")
print(f"  • Figure 3: Final performance comparison (bar chart)")
print(f"  • Figure 4: Species impact comparison")
print(f"\nAll figures saved in:")
print(f"  • PNG format (DPI 300)")
print(f"  • SVG format (DPI 300, vector graphics)")
print(f"\nResults location: {OUTPUT_DIR}")
print(f"Download package: {zip_name}")
print("="*100)