# InstaGeo Data Splitter Demo

This notebook demonstrates how to use the `data_splitter.py` module to split geospatial datasets into train, validation, and test sets while maintaining geographical coherence. The splitter currently works best on data that is distributed worldwide.

## Overview
- Splits datasets into train, validation, and test sets
- Maintains geographical coherence using MGRS tiles
- Supports multiple splitting strategies (KMeans, MGRS, year-based, random)
- Generates visualizations for split analysis
- Handles temporal and spatial distribution

The data splitter tries the following strategies in order, falling back to the next one whenever the current one fails, so that splitting always succeeds:
1. **KMeans clustering (default):** groups samples by spatial coordinates for geographical coherence
2. **MGRS tile-based splitting:** groups samples by MGRS tile when KMeans fails
3. **Year-based splitting:** groups samples by year for temporal coherence when MGRS fails
4. **Random splitting:** assigns samples at random when all other strategies fail
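
The fallback order above can be sketched as a loop over candidate strategies. This is only an illustration of the pattern, not the module's implementation: `split_with_fallback`, `failing_strategy`, and `halves_strategy` are hypothetical names, and treating `ValueError` as "strategy not applicable" is an assumption.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Split = Dict[str, List[str]]
Strategy = Callable[[Sequence[str]], Optional[Split]]


def split_with_fallback(
    samples: Sequence[str], strategies: Sequence[Tuple[str, Strategy]]
) -> Tuple[str, Split]:
    """Try each strategy in order and return the first usable split."""
    for name, strategy in strategies:
        try:
            result = strategy(samples)
        except ValueError:
            result = None  # assumption: a ValueError means "not applicable"
        if result is not None:
            return name, result
    raise RuntimeError("no strategy produced a split")


def failing_strategy(samples: Sequence[str]) -> Optional[Split]:
    # Hypothetical stand-in for a spatial strategy that cannot form groups.
    raise ValueError("not enough groups")


def halves_strategy(samples: Sequence[str]) -> Optional[Split]:
    # Hypothetical stand-in for the last-resort strategy.
    mid = len(samples) // 2
    return {"train": list(samples[:mid]), "test": list(samples[mid:])}


name, split = split_with_fallback(
    ["a", "b", "c", "d"],
    [("kmeans", failing_strategy), ("random", halves_strategy)],
)
print(name, split)
```

Keeping the strategies in an ordered list makes the fallback order explicit and easy to change.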

## Prerequisites
1. InstaGeo installed
2. A CSV file with Input and Label columns

> **💡 Tip**: Run the examples in the [Chip Creator Demo Notebook](chip_creator_demo.ipynb) first to generate sample chips and segmentation maps. `data_splitter.py` reads the dataset CSV that lists the chip and segmentation map file paths.
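
Before running the splitter on your own data, it can help to confirm that the CSV has the two required columns. The check below is a minimal sketch using only the standard library; `missing_columns` is a hypothetical helper, not part of InstaGeo.

```python
import csv
import io
from typing import Set, TextIO

REQUIRED_COLUMNS = {"Input", "Label"}


def missing_columns(handle: TextIO) -> Set[str]:
    """Return the required column names that are absent from the CSV header."""
    reader = csv.DictReader(handle)
    return REQUIRED_COLUMNS - set(reader.fieldnames or [])


# In-memory examples; for a real file, pass open(path, newline="") instead.
good = io.StringIO("Input,Label\nchips/a.tif,seg_maps/a.tif\n")
bad = io.StringIO("Input,Mask\nchips/a.tif,seg_maps/a.tif\n")
print(missing_columns(good))
print(missing_columns(bad))
```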


## Setup and Configuration

In [None]:
import os
import re
import subprocess
from pathlib import Path
from typing import List, Set, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import Image, display

# Fix the seed so the sample dataset is reproducible
np.random.seed(42)


def setup_directories():
    """Set up working directories for the demo."""
    notebook_dir = Path.cwd()
    data_dir = notebook_dir / "demo_data"
    
    # Create directories
    data_dir.mkdir(exist_ok=True)
    
    print(f"Data directory: {data_dir}")
    return data_dir

data_dir = setup_directories()


## Sample Data Creation


In [None]:
# The sample CSV mirrors the one produced when generating chips and
# segmentation maps (for instance, hls_raster_dataset.csv).

def create_sample_dataset(data_dir, n_samples=2000):
    """Create a sample dataset CSV for demonstration."""
    
    # Create sample data with MGRS tiles and years
    sample_data = []
    # Worldwide MGRS tiles representing different continents/regions
    mgrs_tiles = [
        '18TWL', '18TXM', '18TYN',
        '11TNK', '11TNL', '11TNM',
        '12TUK', '12TUL', '12TUM',
        '32TNT', '32TNU', '32TNV',
        '50TMT', '50TMU', '50TMV',
        '36MTT', '36MTU', '36MTV',
        '20KMU', '20KMV', '15TUF', 
        '15TUG', '15TUH', '15TUJ', 
        '15TUK', '46TGT', '46TGU', 
        '46TGV', '52TGT', '52TGU', 
        '52TGV', '39MTT', '39MTU', 
        '39MTV', '33MTT', '33MTU', 
        '33MTV', '23KMT', '23KMU', 
        '23KMV', 
    ]
    years = [2020, 2021, 2022, 2023, 2024]
    
    for i in range(n_samples):
        mgrs_tile = np.random.choice(mgrs_tiles)
        year = np.random.choice(years)
        month = np.random.randint(1, 13)
        day = np.random.randint(1, 29)
        
        # Create file paths similar to actual chip creator output
        # Include year in filename so it can be extracted
        chip_filename = f"chip_{i+1:x}_{mgrs_tile}_{year}{month:02d}{day:02d}.tif"
        seg_map_filename = f"label_{i+1:x}_{mgrs_tile}_{year}{month:02d}{day:02d}.tif"
        
        sample_data.append({
            'Input': f"chips/{chip_filename}",
            'Label': f"seg_maps/{seg_map_filename}",
        })
    
    df = pd.DataFrame(sample_data)
    
    # Save sample dataset
    dataset_file = data_dir / "sample_dataset.csv"
    df.to_csv(dataset_file, index=False)
    
    print(f"Sample dataset saved to: {dataset_file}")
    return df, dataset_file

# Create sample dataset
df, dataset_file = create_sample_dataset(data_dir)
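
The sample file names embed the MGRS tile and the acquisition date so that both can be recovered from the path alone. The sketch below shows one way to do that with a regular expression. The pattern is an assumption based on the sample names in this notebook (two digits, three uppercase letters, then `YYYYMMDD`), and `parse_chip_path` is a hypothetical helper; the module's own parsing may differ.

```python
import re
from typing import Optional, Tuple

# Assumed layout: ..._<tile>_<YYYYMMDD>.tif, e.g. chip_1_18TWL_20210314.tif
TILE_DATE = re.compile(r"_(\d{2}[A-Z]{3})_(\d{4})(\d{2})(\d{2})\.tif$")


def parse_chip_path(path: str) -> Optional[Tuple[str, int]]:
    """Return (mgrs_tile, year) parsed from a chip path, or None if absent."""
    match = TILE_DATE.search(path)
    if match is None:
        return None
    return match.group(1), int(match.group(2))


print(parse_chip_path("chips/chip_1_18TWL_20210314.tif"))
print(parse_chip_path("chips/notes.txt"))
```

Returning `None` instead of raising makes it easy to count how many rows lack a recognizable tile before attempting a split.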


In [None]:
def run_data_splitter_example(example_num, output_dir, extra_args=""):
    """Run a data splitter example and display results."""
    cmd = f"""
    python -m instageo.data.data_splitter \
        --input_file="demo_data/sample_dataset.csv" \
        --output_dir="{output_dir}" \
        {extra_args}
    """

    print(f"=== Running Example {example_num} ===")
    print(f"Output directory: {output_dir}")

    try:
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        if result.returncode == 0:
            print("✅ Command executed successfully!")

            # Display visualizations if they exist
            viz_dir = Path(output_dir) / "splits" / "visualizations"
            if viz_dir.exists():
                print(f"\n📊 Generated visualizations:")

                # Look for specific visualization files
                map_viz = None
                year_viz = None

                for viz_file in viz_dir.glob("*.png"):
                    filename = viz_file.name.lower()
                    print(f"  - {viz_file.name}")

                    # Identify visualization types
                    if "map" in filename:
                        map_viz = viz_file
                    elif "year" in filename:
                        year_viz = viz_file

                # Display visualizations with labels
                if map_viz:
                    print(f"\n🗺️  Spatial Distribution Map:")
                    display(Image(str(map_viz)))

                if year_viz:
                    print(f"\n📅 Temporal Distribution by Year:")
                    display(Image(str(year_viz)))

            else:
                print("No visualizations found.")

        else:
            print(f"❌ Command failed with error:")
            print(result.stderr)
    except Exception as e:
        print(f"❌ Error running command: {e}")

## Example 1: Basic splitting
```bash
python -m instageo.data.data_splitter \
    --input_file="demo_data/sample_dataset.csv" \
    --output_dir="split_output_1"
```

In [None]:
run_data_splitter_example(1, "split_output_1")

## Example 2: Custom KMeans Clustering
```bash
python -m instageo.data.data_splitter \
    --input_file="demo_data/sample_dataset.csv" \
    --output_dir="split_output_2" \
    --use_kmeans=True \
    --n_clusters=25
```

In [None]:
run_data_splitter_example(2, "split_output_2", "--use_kmeans=True --n_clusters=25")
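
Clustering only produces groups; the split stays geographically coherent when every group lands in exactly one split. The sketch below shows a simple greedy way to do that, given one cluster id per sample. It is an illustration under stated assumptions, not the algorithm `data_splitter.py` uses: `assign_groups` is a hypothetical helper and the cluster ids are made up.

```python
from collections import Counter
from typing import Dict, Hashable, Sequence


def assign_groups(
    groups: Sequence[Hashable], test_ratio: float, val_ratio: float
) -> Dict[Hashable, str]:
    """Map each group id to one split so that no group is divided."""
    sizes = Counter(groups)
    total = len(groups)
    targets = {"test": test_ratio * total, "val": val_ratio * total}
    filled = {"test": 0, "val": 0}
    assignment: Dict[Hashable, str] = {}
    # Place large groups first so small ones can fill the remaining gaps.
    for group, size in sorted(sizes.items(), key=lambda item: -item[1]):
        for split in ("test", "val"):
            if filled[split] + size <= targets[split]:
                filled[split] += size
                assignment[group] = split
                break
        else:
            # The group fits neither test nor val, so it goes to train.
            assignment[group] = "train"
    return assignment


# Made-up cluster ids for eight samples: cluster 0 has 4, clusters 1 and 2 have 2.
cluster_ids = [0, 0, 0, 0, 1, 1, 2, 2]
print(assign_groups(cluster_ids, test_ratio=0.25, val_ratio=0.25))
```

Because whole groups are moved, the resulting proportions only approximate the requested ratios when groups are large relative to the dataset.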

## Example 3: Splitting Based on Connected MGRS Tiles
```bash
python -m instageo.data.data_splitter \
    --input_file="demo_data/sample_dataset.csv" \
    --output_dir="split_output_3" \
    --use_kmeans=False \
    --distance_threshold=200.0
```

In [None]:
run_data_splitter_example(3, "split_output_3", "--use_kmeans=False --distance_threshold=200.0")
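
To see what `--distance_threshold` controls, the sketch below links tiles whose centers lie within the threshold and collects the connected groups. It assumes great-circle (haversine) distance between tile centers and uses made-up tile names and coordinates; `haversine_km` and `connected_groups` are hypothetical helpers, and the module may compute distances differently.

```python
import math
from typing import Dict, List, Tuple

LatLon = Tuple[float, float]


def haversine_km(a: LatLon, b: LatLon) -> float:
    """Great-circle distance in km between two (lat, lon) points in degrees."""
    lat1, lon1, lat2, lon2 = map(math.radians, (*a, *b))
    h = (
        math.sin((lat2 - lat1) / 2) ** 2
        + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2
    )
    return 2 * 6371.0 * math.asin(math.sqrt(h))


def connected_groups(centers: Dict[str, LatLon], threshold_km: float) -> List[List[str]]:
    """Group tiles that are linked, directly or transitively, within the threshold."""
    parent = {tile: tile for tile in centers}

    def find(tile: str) -> str:
        while parent[tile] != tile:
            parent[tile] = parent[parent[tile]]  # path compression
            tile = parent[tile]
        return tile

    tiles = sorted(centers)
    for i, first in enumerate(tiles):
        for second in tiles[i + 1:]:
            if haversine_km(centers[first], centers[second]) <= threshold_km:
                parent[find(first)] = find(second)

    groups: Dict[str, List[str]] = {}
    for tile in tiles:
        groups.setdefault(find(tile), []).append(tile)
    return sorted(groups.values())


# Made-up centers: T1..T3 sit about 111 km apart on the equator, T4 is far away.
centers = {"T1": (0.0, 0.0), "T2": (0.0, 1.0), "T3": (0.0, 2.0), "T4": (40.0, 100.0)}
print(connected_groups(centers, threshold_km=200.0))
print(connected_groups(centers, threshold_km=100.0))
```

A larger threshold merges more tiles into fewer groups; a smaller one yields more, smaller groups.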

## Example 4: Train/Test Only (No Validation)
```bash
python -m instageo.data.data_splitter \
    --input_file="demo_data/sample_dataset.csv" \
    --output_dir="split_output_4" \
    --include_val=False \
    --use_kmeans=False \
    --distance_threshold=200.0 \
    --test_ratio=0.3
```

In [None]:
run_data_splitter_example(4, "split_output_4", "--use_kmeans=False --distance_threshold=200.0 --include_val=False --test_ratio=0.3")

## Example 5: Custom Split Ratios
```bash
python -m instageo.data.data_splitter \
    --input_file="demo_data/sample_dataset.csv" \
    --output_dir="split_output_5" \
    --test_ratio=0.15 \
    --val_ratio=0.15
```

In [None]:
run_data_splitter_example(5, "split_output_5", "--test_ratio=0.15 --val_ratio=0.15")
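
With `--test_ratio=0.15 --val_ratio=0.15`, whatever is left goes to training. The arithmetic for the 2000-sample demo dataset is sketched below, assuming plain rounding. Because the splitter assigns whole groups, the counts it actually produces can deviate from these targets. `target_sizes` is a hypothetical helper.

```python
from typing import Dict


def target_sizes(n_samples: int, test_ratio: float, val_ratio: float) -> Dict[str, int]:
    """Ideal per-split sample counts for the requested ratios."""
    if test_ratio < 0 or val_ratio < 0 or test_ratio + val_ratio >= 1:
        raise ValueError("ratios must be non-negative and sum to less than 1")
    n_test = round(n_samples * test_ratio)
    n_val = round(n_samples * val_ratio)
    # Train receives the remainder so the three counts always sum to n_samples.
    return {"train": n_samples - n_test - n_val, "val": n_val, "test": n_test}


print(target_sizes(2000, test_ratio=0.15, val_ratio=0.15))
```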

## Example 6: No Visualizations
```bash
python -m instageo.data.data_splitter \
    --input_file="demo_data/sample_dataset.csv" \
    --output_dir="split_output_6" \
    --visualize=False
```

In [None]:
run_data_splitter_example(6, "split_output_6", "--visualize=False")

## Results comparison

In [None]:
# Compare Results from Different Examples

def compare_split_results():
    """Compare the results from different splitting examples."""
    print("=== Comparing Split Results ===")
    
    results = []
    for i in range(1, 7):  # Examples 1 through 6
        output_dir = f"split_output_{i}"
        splits_dir = Path(output_dir) / "splits"
        
        if splits_dir.exists():
            train_file = splits_dir / "train.csv"
            val_file = splits_dir / "val.csv"
            test_file = splits_dir / "test.csv"
            
            train_count = len(pd.read_csv(train_file)) if train_file.exists() else 0
            val_count = len(pd.read_csv(val_file)) if val_file.exists() else 0
            test_count = len(pd.read_csv(test_file)) if test_file.exists() else 0
            total = train_count + val_count + test_count
            
            results.append({
                'Example': i,
                'Train': train_count,
                'Val': val_count,
                'Test': test_count,
                'Total': total,
                'Train%': f"{train_count/total*100:.1f}%" if total > 0 else "0%",
                'Val%': f"{val_count/total*100:.1f}%" if total > 0 else "0%",
                'Test%': f"{test_count/total*100:.1f}%" if total > 0 else "0%"
            })
    
    if results:
        results_df = pd.DataFrame(results)
        print(results_df.to_string(index=False))
        
        # Create a comparison visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        
        # Bar chart of split counts
        x = range(len(results))
        train_counts = [r['Train'] for r in results]
        val_counts = [r['Val'] for r in results]
        test_counts = [r['Test'] for r in results]
        
        ax1.bar(x, train_counts, label='Train', alpha=0.8)
        ax1.bar(x, val_counts, bottom=train_counts, label='Val', alpha=0.8)
        ax1.bar(x, test_counts, bottom=[t+v for t,v in zip(train_counts, val_counts)], label='Test', alpha=0.8)
        
        ax1.set_xlabel('Example')
        ax1.set_ylabel('Number of Samples')
        ax1.set_title('Split Counts by Example')
        ax1.set_xticks(x)
        ax1.set_xticklabels([f'Ex {r["Example"]}' for r in results])
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Pie chart for the first example that produced results
        first_example = results[0]
        sizes = [first_example['Train'], first_example['Val'], first_example['Test']]
        labels = ['Train', 'Val', 'Test']
        colors = ['skyblue', 'lightcoral', 'lightgreen']

        ax2.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
        ax2.set_title(f'Example {first_example["Example"]} Split Distribution')
        
        plt.tight_layout()
        plt.show()
    else:
        print("No results found. Run the examples first.")

# Compare results
compare_split_results()
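
Sample counts alone do not show whether the splits are spatially separated. One quick check is to look for MGRS tiles that appear in more than one split. The sketch below works on lists of file paths, such as the `Input` column of each split CSV, and assumes the tile code appears between underscores as in the sample file names. `tiles_in` and `shared_tiles` are hypothetical helpers.

```python
import re
from typing import Dict, Iterable, Set, Tuple

# Assumed naming: the tile code sits between underscores, e.g. _18TWL_
TILE = re.compile(r"_(\d{2}[A-Z]{3})_")


def tiles_in(paths: Iterable[str]) -> Set[str]:
    """Collect the tile codes found in a list of file paths."""
    return {m.group(1) for path in paths if (m := TILE.search(path))}


def shared_tiles(splits: Dict[str, Iterable[str]]) -> Dict[Tuple[str, str], Set[str]]:
    """Return the tiles that occur in more than one split, keyed by split pair."""
    tile_sets = {name: tiles_in(paths) for name, paths in splits.items()}
    names = sorted(tile_sets)
    overlaps = {}
    for i, first in enumerate(names):
        for second in names[i + 1:]:
            common = tile_sets[first] & tile_sets[second]
            if common:
                overlaps[(first, second)] = common
    return overlaps


example = {
    "train": ["chips/chip_1_18TWL_20200101.tif", "chips/chip_2_11TNK_20200101.tif"],
    "val": ["chips/chip_4_32TNT_20220101.tif"],
    "test": ["chips/chip_3_18TWL_20210101.tif"],
}
print(shared_tiles(example))
```

An empty result means no tile is shared. Shared tiles are not necessarily an error, for instance when groups are allowed to overlap, but they are worth inspecting.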


## Summary

### 🎯 Key Takeaways:

1. **Data Splitter** maintains geographical coherence in train/validation/test splits
2. Supports **multiple splitting strategies** (KMeans, MGRS, year-based, random)
3. Generates **comprehensive visualizations** for split analysis
4. **Flexible configuration** for different use cases
   
| Parameter | Description | Default |
|-----------|-------------|---------|
| `input_file` | Path to input CSV file containing Input and Label columns | Required |
| `output_dir` | Base directory for output files | Required |
| `include_val` | Whether to include validation split | True |
| `include_test` | Whether to include test split | True |
| `test_ratio` | Ratio of data to use for test set (0.0-1.0) | 0.20 |
| `val_ratio` | Ratio of data to use for validation set (0.0-1.0) | 0.20 |
| `random_state` | Random seed for reproducibility | 42 |
| `visualize` | Whether to generate visualizations | True |
| `allow_group_overlap` | Whether to allow groups to be split across sets | True |
| `distance_threshold` | Maximum distance in km to consider MGRS tiles as close or connected | 400.0 |
| `n_clusters` | Number of clusters for KMeans | 20 |
| `use_kmeans` | Whether to use KMeans clustering | True |
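
When many parameters vary between runs, building the command from a dictionary avoids quoting mistakes in a hand-written shell string. The sketch below uses only the flags listed in the table and the `--name=value` form shown in the examples; `build_command` is a hypothetical helper.

```python
import shlex
import sys
from typing import Dict, List


def build_command(params: Dict[str, object]) -> List[str]:
    """Build an argument list for the data splitter from a parameter dict."""
    args = [sys.executable, "-m", "instageo.data.data_splitter"]
    for name, value in params.items():
        # Booleans render as True/False, matching the examples above.
        args.append(f"--{name}={value}")
    return args


cmd = build_command(
    {
        "input_file": "demo_data/sample_dataset.csv",
        "output_dir": "split_output_3",
        "use_kmeans": False,
        "distance_threshold": 200.0,
    }
)
print(shlex.join(cmd))
# To run it: subprocess.run(cmd, capture_output=True, text=True)
```

Passing the list to `subprocess.run` without `shell=True` means paths containing spaces need no extra quoting.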

### Troubleshooting Guide

| Issue | Problem | Solution | Command or check |
|-------|---------|----------|------------------|
| **Invalid MGRS Tiles** | Cannot extract MGRS tiles from file paths | Check the file naming convention | Ensure file names contain an MGRS tile code (two digits followed by three letters, as in `18TWL`) |
| **Insufficient Groups** | Not enough groups for splitting | Reduce `distance_threshold` or use random splitting | `--distance_threshold=100.0` |
| **Imbalanced Splits** | Splits don't meet target proportions | Adjust the ratios or allow group overlap | `--allow_group_overlap=True` |
| **Memory Issues** | Out of memory during clustering | Reduce `n_clusters` or process in batches | `--n_clusters=10` |
| **No Valid Splits** | Cannot create valid splits with any strategy | Check data quality and parameters | Verify the input file format and content |

### 📚 Next Steps:
- Try running the commands with your own dataset
- Experiment with different splitting strategies
- Analyze the generated visualizations
- Use the splits in your ML pipeline

### 🔗 Related Demos:
- `chip_creator_demo.ipynb`: Point-based chip creation
- `raster_chip_creator_demo.ipynb`: Raster-based chip creation
- `data_cleaner_demo.ipynb`: Data cleaning and preprocessing
