# Lab 7: Model Evaluation (2 Hours)

## ⏱️ Time Allocation
- **Part 1 (30 min):** Load model and generate predictions
- **Part 2 (30 min):** Calculate metrics
- **Part 3 (30 min):** Confusion matrix and visualization
- **Part 4 (30 min):** Error analysis

## 🎯 Learning Objectives

### Core (Essential)
- ✅ Load trained model and test data
- ✅ Generate predictions
- ✅ Compute accuracy, precision, recall, F1
- ✅ Create confusion matrix
- ✅ Visualize correct/incorrect predictions

### Optional (For Early Finishers)
- 🔵 ROC and PR curves
- 🔵 Per-class analysis
- 🔵 Spatial error mapping
- 🔵 Model comparison
- 🔵 Statistical testing

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import os
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report
)
from tqdm import tqdm

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 100

print("✅ Libraries imported successfully")

### Load Test Dataset

In [None]:
# Define paths (the PROJECT_training2600 environment variable must be set)
project_root = os.getenv('PROJECT_training2600')
if project_root is None:
    raise EnvironmentError("PROJECT_training2600 is not set")

project_dir = Path(project_root) / 'my_workspace'
data_dir = project_dir / 'data' / 'preprocessed'
model_dir = project_dir / 'models'
results_dir = project_dir / 'results' / 'evaluation'
results_dir.mkdir(parents=True, exist_ok=True)

print(f"📂 Data directory: {data_dir}")
print(f"📂 Model directory: {model_dir}")
print(f"📂 Results directory: {results_dir}")

# Load test data
print("\n📥 Loading test data...")
X_test = np.load(data_dir / 'X_test.npy')
y_test = np.load(data_dir / 'y_test.npy')

# Load metadata
with open(data_dir / 'dataset_metadata.json', 'r') as f:
    metadata = json.load(f)

class_names = metadata['class_names']
num_classes = metadata['num_classes']

print("\n✅ Test data loaded:")
print(f"   Shape: {X_test.shape}")
print(f"   Classes: {num_classes}")
print(f"   Samples: {len(X_test)}")

### Load Trained Model

In [None]:
# Re-define model architecture (must match training)
class SatelliteCNN(nn.Module):
    """Baseline CNN for satellite image classification."""
    
    def __init__(self, in_channels=6, num_classes=7):
        super(SatelliteCNN, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.conv4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.gap = nn.AdaptiveAvgPool2d(1)
        
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SatelliteCNN(
    in_channels=metadata['num_bands'],
    num_classes=num_classes
).to(device)

# Load checkpoint
checkpoint = torch.load(model_dir / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✅ Model loaded from checkpoint (epoch {checkpoint['epoch']})")
print(f"   Validation accuracy: {checkpoint['val_acc']:.2f}%")
print(f"   Device: {device}")

## Section 2: Generate Predictions (5 min)

### Create Test DataLoader

In [None]:
class SatelliteDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

test_dataset = SatelliteDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"✅ Test DataLoader created ({len(test_loader)} batches)")

### Run Inference

In [None]:
print("🔮 Generating predictions on test set...\n")

all_predictions = []
all_labels = []
all_probabilities = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Inference'):
        images = images.to(device)
        
        # Get predictions
        outputs = model(images)
        probabilities = torch.softmax(outputs, dim=1)
        _, predicted = outputs.max(1)
        
        # Store results
        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())
        all_probabilities.extend(probabilities.cpu().numpy())

all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
all_probabilities = np.array(all_probabilities)

print(f"\n✅ Predictions generated for {len(all_predictions)} samples")

## Section 3: Calculate Classification Metrics (8 min)

### Overall Metrics

In [None]:
# Calculate metrics
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)
f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)

print("="*60)
print("📊 OVERALL CLASSIFICATION METRICS")
print("="*60)
print(f"Accuracy:  {accuracy*100:.2f}%")
print(f"Precision: {precision*100:.2f}%")
print(f"Recall:    {recall*100:.2f}%")
print(f"F1-Score:  {f1*100:.2f}%")
print("="*60)

### Per-Class Metrics

In [None]:
# Detailed classification report (labels= keeps rows aligned with class_names
# even if a class is absent from the test set)
report = classification_report(
    all_labels,
    all_predictions,
    labels=np.arange(num_classes),
    target_names=class_names,
    digits=3,
    zero_division=0
)

print("\n📋 PER-CLASS CLASSIFICATION REPORT")
print("="*80)
print(report)

# Save report
with open(results_dir / 'classification_report.txt', 'w') as f:
    f.write(report)

print(f"\n💾 Saved classification report to: {results_dir / 'classification_report.txt'}")

### Metrics Visualization

In [None]:
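# Per-class metrics (labels= guarantees one value per entry in class_names)
class_labels = np.arange(num_classes)
precision_per_class = precision_score(all_labels, all_predictions, labels=class_labels, average=None, zero_division=0)
recall_per_class = recall_score(all_labels, all_predictions, labels=class_labels, average=None, zero_division=0)
f1_per_class = f1_score(all_labels, all_predictions, labels=class_labels, average=None, zero_division=0)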

# Plot
fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(class_names))
width = 0.25

ax.bar(x - width, precision_per_class, width, label='Precision', alpha=0.8)
ax.bar(x, recall_per_class, width, label='Recall', alpha=0.8)
ax.bar(x + width, f1_per_class, width, label='F1-Score', alpha=0.8)

ax.set_xlabel('Land Cover Class', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Per-Class Performance Metrics', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
ax.set_ylim([0, 1.05])

plt.tight_layout()
plt.savefig(results_dir / 'per_class_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"💾 Saved metrics plot to: {results_dir / 'per_class_metrics.png'}")

## Section 4: Confusion Matrix (8 min)

### Compute Confusion Matrix

In [None]:
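# Compute confusion matrix (rows = true class, columns = predicted class).
# labels= keeps the matrix num_classes x num_classes even if a class is missing.
cm = confusion_matrix(all_labels, all_predictions, labels=np.arange(num_classes))

print("📊 Confusion Matrix (counts):")
print(cm)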

### Visualize Confusion Matrix

In [None]:
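# Normalize each row by the number of true samples in that class
# (rows with no samples stay at 0 instead of producing NaN)
row_sums = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)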

# Plot
fig, axes = plt.subplots(1, 2, figsize=(18, 7))

# Plot 1: Counts
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=axes[0],
    cbar_kws={'label': 'Count'}
)
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Predicted Label', fontsize=12)
axes[0].set_ylabel('True Label', fontsize=12)

# Plot 2: Normalized (percentages)
sns.heatmap(
    cm_normalized,
    annot=True,
    fmt='.2%',
    cmap='Greens',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=axes[1],
    cbar_kws={'label': 'Percentage'}
)
axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Predicted Label', fontsize=12)
axes[1].set_ylabel('True Label', fontsize=12)

plt.tight_layout()
plt.savefig(results_dir / 'confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n💾 Saved confusion matrix to: {results_dir / 'confusion_matrix.png'}")

### Analyze Common Misclassifications

In [None]:
# Find most common misclassifications (off-diagonal elements)
print("\n🔍 Most Common Misclassifications:")
print("="*60)

misclassifications = []
for i in range(len(class_names)):
    for j in range(len(class_names)):
        if i != j and cm[i, j] > 0:
            misclassifications.append({
                'true': class_names[i],
                'predicted': class_names[j],
                'count': cm[i, j],
                'percentage': cm_normalized[i, j] * 100
            })

# Sort by count
misclassifications = sorted(misclassifications, key=lambda x: x['count'], reverse=True)

for idx, mc in enumerate(misclassifications[:5], 1):
    print(f"{idx}. {mc['true']} → {mc['predicted']}: "
          f"{mc['count']} samples ({mc['percentage']:.1f}% of {mc['true']})")

print("="*60)

## Section 5: Visualize Predictions (9 min)

### Display Correct Predictions

In [None]:
# Find correctly classified samples
correct_indices = np.where(all_predictions == all_labels)[0]

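# Helper function: percentile-stretch an image to [0, 1] for RGB display
def normalize_for_display(img, percentile=2):
    vmin, vmax = np.percentile(img, [percentile, 100 - percentile])
    if vmax <= vmin:  # constant image: avoid division by zero
        return np.zeros_like(img, dtype=float)
    return np.clip((img - vmin) / (vmax - vmin), 0, 1)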

# Plot up to 6 correct predictions
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

n_show = min(6, len(correct_indices))
sample_indices = np.random.choice(correct_indices, n_show, replace=False)

for ax, sample_idx in zip(axes, sample_indices):
    img = X_test[sample_idx]
    rgb = img[[2, 1, 0], :, :].transpose(1, 2, 0)  # B4, B3, B2
    rgb_display = normalize_for_display(rgb)
    
    true_label = class_names[all_labels[sample_idx]]
    pred_label = class_names[all_predictions[sample_idx]]
    confidence = all_probabilities[sample_idx][all_predictions[sample_idx]] * 100
    
    ax.imshow(rgb_display)
    ax.set_title(
        f"✓ True: {true_label}\nPred: {pred_label} ({confidence:.1f}%)",
        fontsize=10,
        color='green'
    )

# Hide axis frames on every panel, including any left unused
for ax in axes:
    ax.axis('off')

plt.suptitle('Correctly Classified Samples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(results_dir / 'correct_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"💾 Saved correct predictions to: {results_dir / 'correct_predictions.png'}")

### Display Incorrect Predictions

In [None]:
# Find incorrectly classified samples
incorrect_indices = np.where(all_predictions != all_labels)[0]

if len(incorrect_indices) > 0:
    # Plot up to 6 incorrect predictions
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    
    sample_size = min(6, len(incorrect_indices))
    sample_indices = np.random.choice(incorrect_indices, sample_size, replace=False)
    
    for ax, sample_idx in zip(axes, sample_indices):
        img = X_test[sample_idx]
        rgb = img[[2, 1, 0], :, :].transpose(1, 2, 0)  # B4, B3, B2
        rgb_display = normalize_for_display(rgb)
        
        true_label = class_names[all_labels[sample_idx]]
        pred_label = class_names[all_predictions[sample_idx]]
        confidence = all_probabilities[sample_idx][all_predictions[sample_idx]] * 100
        
        ax.imshow(rgb_display)
        ax.set_title(
            f"✗ True: {true_label}\nPred: {pred_label} ({confidence:.1f}%)",
            fontsize=10,
            color='red'
        )
    
    # Hide axis frames on every panel, including any left unused
    for ax in axes:
        ax.axis('off')
    
    plt.suptitle('Misclassified Samples', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(results_dir / 'incorrect_predictions.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"💾 Saved incorrect predictions to: {results_dir / 'incorrect_predictions.png'}")
else:
    print("🎉 Perfect classification! No errors to display.")

## Section 6: Performance Summary (5 min)

### Class-Wise Performance Summary

In [None]:
# Create summary dataframe
import pandas as pd

summary_data = []
for i, class_name in enumerate(class_names):
    class_mask = (all_labels == i)
    class_count = class_mask.sum()
    correct_count = ((all_labels == i) & (all_predictions == i)).sum()
    
    summary_data.append({
        'Class': class_name,
        'Samples': class_count,
        'Correct': correct_count,
        'Precision': precision_per_class[i],
        'Recall': recall_per_class[i],
        'F1-Score': f1_per_class[i]
    })

summary_df = pd.DataFrame(summary_data)

print("\n📊 CLASS-WISE PERFORMANCE SUMMARY")
print("="*80)
print(summary_df.to_string(index=False))
print("="*80)

# Save summary
summary_df.to_csv(results_dir / 'class_performance_summary.csv', index=False)
print(f"\n💾 Saved summary to: {results_dir / 'class_performance_summary.csv'}")

### Generate Final Report

In [None]:
# Comprehensive evaluation report
report_content = f"""
{'='*80}
MODEL EVALUATION REPORT
{'='*80}

DATASET INFORMATION
{'-'*80}
Test Samples: {len(X_test)}
Classes: {num_classes}
Class Names: {', '.join(class_names)}
Image Size: {X_test.shape[2]}x{X_test.shape[3]} pixels
Bands: {X_test.shape[1]}

OVERALL METRICS
{'-'*80}
Accuracy:  {accuracy*100:.2f}%
Precision: {precision*100:.2f}%
Recall:    {recall*100:.2f}%
F1-Score:  {f1*100:.2f}%

BEST PERFORMING CLASSES
{'-'*80}
"""

# Add top 3 classes by F1-score
top_classes = summary_df.nlargest(3, 'F1-Score')
for _, row in top_classes.iterrows():
    report_content += f"{row['Class']:15s} - F1: {row['F1-Score']:.3f}, Samples: {row['Samples']}\n"

report_content += f"""
WORST PERFORMING CLASSES
{'-'*80}
"""

# Add bottom 3 classes by F1-score
worst_classes = summary_df.nsmallest(3, 'F1-Score')
for _, row in worst_classes.iterrows():
    report_content += f"{row['Class']:15s} - F1: {row['F1-Score']:.3f}, Samples: {row['Samples']}\n"

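# Describe the most confused class pair, if there were any misclassifications
if misclassifications:
    top_mc = misclassifications[0]
    most_confused = f"{top_mc['true']} → {top_mc['predicted']} ({top_mc['count']} cases)"
else:
    most_confused = "none (no misclassifications)"

report_content += f"""
KEY FINDINGS
{'-'*80}
- Total correct predictions: {(all_predictions == all_labels).sum()} / {len(all_labels)}
- Total misclassifications: {(all_predictions != all_labels).sum()}
- Most confused pair: {most_confused}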

RECOMMENDATIONS
{'-'*80}
1. Classes with low F1-scores may need more training samples
2. Consider data augmentation for underrepresented classes
3. Analyze misclassified samples for common patterns
4. Experiment with different architectures or pre-trained models

{'='*80}
Generated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
{'='*80}
"""

# Save report (UTF-8 so non-ASCII characters such as arrows are written safely)
with open(results_dir / 'evaluation_report.txt', 'w', encoding='utf-8') as f:
    f.write(report_content)

print(report_content)
print(f"\n💾 Saved evaluation report to: {results_dir / 'evaluation_report.txt'}")

## Summary & Course Conclusion

### What We Covered in This Lab
✅ Loaded trained model and test data  
✅ Generated predictions on test set  
✅ Calculated classification metrics  
✅ Visualized confusion matrices  
✅ Analyzed model strengths and weaknesses  
✅ Created comprehensive evaluation report  

### Course Journey Recap
Over 6 labs, we covered the complete ML pipeline for Earth Observation:

1. **Lab 1:** HPC access and Judoor setup
2. **Lab 2:** Jupyter-JSC and Git version control
3. **Lab 3:** Sentinel-2 data acquisition with GEE
4. **Lab 4:** Data preprocessing and patch extraction
5. **Lab 5.1:** Baseline CNN model training
6. **Lab 5.2:** Comprehensive model evaluation (this lab)

### Key Evaluation Concepts
- **Accuracy:** Overall correctness (can be misleading with imbalanced data)
- **Precision:** Of predicted positives, how many are actually positive?
- **Recall:** Of actual positives, how many did we find?
- **F1-Score:** Harmonic mean of precision and recall
- **Confusion Matrix:** Table of true vs. predicted classes; off-diagonal cells show which classes get confused with each other
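
To make these definitions concrete, the sketch below derives per-class precision, recall, and F1 directly from a confusion matrix. The 3-class matrix is hypothetical and only for illustration; it follows the same convention as the lab (rows are true classes, columns are predicted classes) and assumes every class has at least one true and one predicted sample, so no division by zero occurs.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (not lab results):
# rows = true class, columns = predicted class
cm = np.array([
    [8, 1, 1],
    [2, 6, 2],
    [0, 2, 8],
])

tp = np.diag(cm)            # correctly classified samples per class
fp = cm.sum(axis=0) - tp    # predicted as this class but truly another class
fn = cm.sum(axis=1) - tp    # truly this class but predicted as another class

precision = tp / (tp + fp)  # of predicted positives, how many are correct
recall = tp / (tp + fn)     # of actual positives, how many were found
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
accuracy = tp.sum() / cm.sum()

for c in range(len(cm)):
    print(f"class {c}: precision={precision[c]:.3f}, "
          f"recall={recall[c]:.3f}, f1={f1[c]:.3f}")
print(f"accuracy={accuracy:.3f}")
```

Precision is computed down a column of the matrix and recall along a row, which is why the row-normalized confusion matrix plotted in Section 4 shows recall on its diagonal.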

### When to Use Each Metric
- **Accuracy:** Overall performance with balanced classes
- **Precision:** When false positives are costly (e.g., urban detection)
- **Recall:** When false negatives are costly (e.g., fire detection)
- **F1-Score:** General balance, especially with imbalanced classes
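
As a rough illustration of why accuracy can mislead on imbalanced data, the sketch below scores a degenerate classifier that always predicts the majority class. The labels are synthetic, not lab data. It also shows that support-weighted recall (what `average='weighted'` computes) equals accuracy, so a macro average is a useful additional check when classes are imbalanced.

```python
import numpy as np

# Synthetic, imbalanced test set: 95 samples of class 0, 5 samples of class 1
y_true = np.array([0] * 95 + [1] * 5)
# A degenerate "model" that always predicts the majority class
y_pred = np.zeros_like(y_true)

accuracy = (y_true == y_pred).mean()

classes = np.unique(y_true)
support = np.array([(y_true == c).sum() for c in classes])
recall_per_class = np.array([
    ((y_true == c) & (y_pred == c)).sum() / (y_true == c).sum()
    for c in classes
])

# Macro average: every class counts equally
macro_recall = recall_per_class.mean()
# Weighted average: classes count in proportion to their support
weighted_recall = (recall_per_class * support).sum() / support.sum()

print(f"accuracy:        {accuracy:.2f}")
print(f"recall per class: {recall_per_class}")
print(f"macro recall:    {macro_recall:.2f}")
print(f"weighted recall: {weighted_recall:.2f}")
```

The minority class is never detected, yet accuracy and weighted recall both stay high; only the per-class and macro values expose the failure.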

### Next Steps for Your Projects
1. **Improve Model Performance:**
   - Try pre-trained models (ResNet, EfficientNet)
   - Implement data augmentation
   - Use geospatial foundation models (e.g., Prithvi via TerraTorch)
   - Ensemble multiple models

2. **Expand Dataset:**
   - Acquire more Sentinel-2 scenes
   - Include temporal information (time series)
   - Add auxiliary data (elevation, climate)

3. **Advanced Techniques:**
   - Semantic segmentation (pixel-level classification)
   - Change detection (compare multiple dates)
   - Uncertainty estimation
   - Active learning

4. **Production Deployment:**
   - Export model to ONNX for faster inference
   - Create web interface with Flask/FastAPI
   - Scale inference with distributed computing
   - Monitor model performance over time

### Additional Resources
- **Metrics Tutorial:** https://scikit-learn.org/stable/modules/model_evaluation.html
- **Confusion Matrix:** https://en.wikipedia.org/wiki/Confusion_matrix
- **Foundation Models:** https://github.com/NASA-IMPACT/terratorch
- **EO Datasets:** https://github.com/satellite-image-deep-learning/datasets

### Course Feedback
Please share your feedback:
- What worked well?
- What could be improved?
- Topics you'd like to explore further?

### Stay Connected
- **Course Slack:** Continue discussions and share results
- **GitHub:** Share your code and projects
- **Email:** s.hashim@fz-juelich.de for questions

---

## 🎓 Congratulations!

You've successfully completed the Machine Learning for Earth Observation course!

You now have the skills to:
- ✅ Access and process satellite imagery
- ✅ Build and train deep learning models
- ✅ Evaluate model performance rigorously
- ✅ Deploy models on HPC infrastructure

**Keep learning, keep building, and keep exploring!** 🚀🛰️🌍

---

## ✅ Lab 7 Completion Checklist

### Core Tasks
- [ ] Model loaded successfully
- [ ] Predictions generated
- [ ] Metrics calculated (accuracy, precision, recall, F1)
- [ ] Confusion matrix created
- [ ] Sample visualizations shown

### Optional Tasks
- [ ] ROC/PR curves plotted
- [ ] Per-class deep dive completed
- [ ] Spatial analysis performed
- [ ] Comparison with baseline
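
For the optional ROC task, one possible starting point is a one-vs-rest ROC curve built with NumPy alone. This is a minimal sketch, not the lab's reference solution: the helper names `roc_curve_one_vs_rest` and `auc_trapezoid` are hypothetical, the demo arrays are made up, and tied scores are not merged into a single threshold (usually harmless for continuous softmax outputs). In the notebook, the inputs would be `all_labels` and one column of `all_probabilities`.

```python
import numpy as np

def roc_curve_one_vs_rest(y_true, scores, positive_class):
    """Return (fpr, tpr) for one class against all others.

    Assumes both the positive class and the rest have at least one sample.
    """
    is_pos = (y_true == positive_class)
    order = np.argsort(-scores, kind="stable")  # highest score first
    is_pos_sorted = is_pos[order]
    tps = np.cumsum(is_pos_sorted)    # true positives as the threshold drops
    fps = np.cumsum(~is_pos_sorted)   # false positives as the threshold drops
    tpr = np.concatenate([[0.0], tps / is_pos.sum()])
    fpr = np.concatenate([[0.0], fps / (~is_pos).sum()])
    return fpr, tpr

def auc_trapezoid(fpr, tpr):
    """Area under the ROC curve using the trapezoidal rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))

# Made-up demo data; in the lab use all_labels and all_probabilities[:, c]
demo_labels = np.array([0, 0, 1, 1])
demo_scores = np.array([0.1, 0.4, 0.35, 0.8])  # predicted probability of class 1

fpr, tpr = roc_curve_one_vs_rest(demo_labels, demo_scores, positive_class=1)
auc = auc_trapezoid(fpr, tpr)
print(f"AUC (class 1 vs rest): {auc:.2f}")
```

Plotting `fpr` against `tpr` with `plt.plot` for each class gives the per-class ROC curves; an AUC near 0.5 indicates scores no better than chance for that class.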

## 🚀 Next Lab
**Lab 8: TerraTorch Fine-tuning** - Use foundation models for better performance