# 📘 Module Usage Examples

This notebook demonstrates how to import and use functions from the `src/` modules.

**Modules available:**
- `utils.py` - Data loading & metadata
- `preprocessing.py` - Data preprocessing
- `visualization.py` - Plotting functions
- `models.py` - CNN architectures

**See:** `src/README.md` for detailed documentation

## 1. Setup: Import Modules

In [None]:
# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# Import from src/ modules
# Option 1: Import specific functions
from src import (
    # Utils
    load_tiff,
    check_tiff_metadata,
    get_tiff_stats,
    load_ground_truth,
    
    # Preprocessing
    normalize_band,
    handle_nan,
    extract_patch,
    
    # Visualization
    plot_band,
    plot_band_comparison,
    plot_statistics,
    plot_indices_comparison,
)

# Option 2: Import modules
from src import utils
from src import preprocessing
from src import visualization
from src import models

print("✅ Modules imported successfully!")
print("\nAvailable functions:")
print("  - load_tiff, check_tiff_metadata, get_tiff_stats")
print("  - normalize_band, handle_nan, extract_patch")
print("  - plot_band, plot_band_comparison, plot_statistics")
print("  - models.get_model, models.count_parameters")

## 2. Example: Check TIFF Metadata

In [None]:
# Define file path
s1_file = Path('../data/raw/sentinel1/S1_2024_02_04_matched_S2_2024_01_30.tif')

if s1_file.exists():
    # Check metadata
    print("📊 Checking metadata...")
    meta = check_tiff_metadata(s1_file, verbose=True)
    
    # Access metadata
    print(f"\n✅ File has {meta['bands']} bands")
    print(f"✅ Size: {meta['width']} × {meta['height']} pixels")
    print(f"✅ Memory: {meta['memory_mb']:.2f} MB")
else:
    print(f"⚠️ File not found: {s1_file}")

## 3. Example: Get Band Statistics

In [None]:
if s1_file.exists():
    print("📈 Getting band statistics...")
    stats = get_tiff_stats(s1_file, sample_size=1000)
    
    print("\n✅ Statistics DataFrame:")
    print(stats[['band', 'mean', 'std', 'min', 'max', 'nan_percent']])
else:
    print(f"⚠️ File not found: {s1_file}")
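
The `sample_size` argument suggests that statistics are estimated from a subset of pixels rather than the full band. The sketch below shows one way such sampling could work with plain NumPy. It is not the actual implementation of `get_tiff_stats`; the helper name `sampled_band_stats`, its `seed` parameter, and the returned keys are assumptions chosen to mirror the DataFrame columns printed above.

```python
import numpy as np


def sampled_band_stats(band, sample_size=1000, seed=0):
    """Hypothetical helper: summary statistics from a random pixel sample."""
    flat = np.asarray(band, dtype=float).ravel()
    if flat.size == 0:
        raise ValueError("band is empty")

    # Draw a random subset only when the band is larger than the sample
    if sample_size < flat.size:
        rng = np.random.default_rng(seed)
        flat = rng.choice(flat, size=sample_size, replace=False)

    # NaN share is measured on the sample; other stats ignore NaN pixels
    valid = flat[~np.isnan(flat)]
    nan_percent = 100.0 * (flat.size - valid.size) / flat.size
    if valid.size == 0:
        return {"mean": np.nan, "std": np.nan, "min": np.nan,
                "max": np.nan, "nan_percent": nan_percent}

    return {
        "mean": float(valid.mean()),
        "std": float(valid.std()),
        "min": float(valid.min()),
        "max": float(valid.max()),
        "nan_percent": nan_percent,
    }


demo = np.array([[1.0, 2.0], [3.0, np.nan]])
print(sampled_band_stats(demo, sample_size=10))
```

Sampling trades accuracy for speed: with a small `sample_size`, rare values such as isolated NaN pixels may be missed entirely.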

## 4. Example: Load and Visualize Band

In [None]:
if s1_file.exists():
    print("📥 Loading band 1...")
    
    # Load only a window of the raster (faster than reading the full band)
    import rasterio
    from rasterio.windows import Window
    
    with rasterio.open(s1_file) as src:
        # Center 500x500 pixels (assumes the raster is at least 500 px in each dimension)
        center_x = src.width // 2
        center_y = src.height // 2
        window = Window(center_x - 250, center_y - 250, 500, 500)
        
        # Load band 1
        data = src.read(1, window=window)
    
    print(f"✅ Loaded data shape: {data.shape}")
    
    # Visualize using our plotting function
    print("\n📊 Plotting...")
    plot_band(data, title='S1 VH 2024 (500x500 sample)', cmap='viridis')
else:
    print(f"⚠️ File not found: {s1_file}")
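
The centered window above yields negative offsets when the raster is smaller than 500 pixels in either dimension. A minimal sketch of a clamped variant follows; `centered_window` is a hypothetical helper, not part of `src/`, and it uses only integer arithmetic so it can be checked without opening a file.

```python
def centered_window(width, height, size=500):
    """Hypothetical helper: centered window clamped to the raster bounds.

    Returns (col_off, row_off, win_width, win_height).
    """
    # Shrink the window if the raster is smaller than the requested size
    win_w = min(size, width)
    win_h = min(size, height)

    # Center the (possibly shrunken) window; offsets are never negative
    col_off = (width - win_w) // 2
    row_off = (height - win_h) // 2
    return col_off, row_off, win_w, win_h


print(centered_window(2000, 1000))  # (750, 250, 500, 500)
print(centered_window(300, 1000))   # (0, 250, 300, 500)
```

The tuple is ordered to match `rasterio.windows.Window(col_off, row_off, width, height)`, so it could be unpacked as `Window(*centered_window(src.width, src.height))`.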

## 5. Example: Load Ground Truth

In [None]:
gt_file = Path('../data/raw/ground_truth/Training_Points_CSV.csv')

if gt_file.exists():
    print("📥 Loading ground truth...")
    gt_df = load_ground_truth(gt_file)
    
    print(f"\n✅ Loaded {len(gt_df)} points")
    print("\n📊 Sample data:")
    print(gt_df.head())
    
    print("\n🏷️ Class distribution:")
    print(gt_df['label'].value_counts())
else:
    print(f"⚠️ File not found: {gt_file}")

## 6. Example: Data Preprocessing

In [None]:
# Create dummy data with NaN
dummy_data = np.random.randn(100, 100)
dummy_data[30:35, 40:45] = np.nan  # Add some NaN values

print(f"Original data: {dummy_data.shape}")
print(f"NaN count: {np.isnan(dummy_data).sum()}")

# Handle NaN
print("\n🔧 Handling NaN...")
clean_data = handle_nan(dummy_data, method='interpolate')
print(f"After interpolation - NaN count: {np.isnan(clean_data).sum()}")

# Normalize
print("\n🔧 Normalizing...")
norm_data = normalize_band(clean_data, method='standardize')
print(f"Normalized - Mean: {norm_data.mean():.4f}, Std: {norm_data.std():.4f}")

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].imshow(dummy_data, cmap='viridis')
axes[0].set_title('Original (with NaN)')
axes[0].axis('off')

axes[1].imshow(clean_data, cmap='viridis')
axes[1].set_title('After NaN handling')
axes[1].axis('off')

axes[2].imshow(norm_data, cmap='viridis')
axes[2].set_title('After normalization')
axes[2].axis('off')

plt.tight_layout()
plt.show()
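
For intuition about the two preprocessing steps, here is a NumPy-only sketch of a mean fill followed by standardization. It is an illustration under assumptions, not the code behind `handle_nan` or `normalize_band`: the helpers `fill_nan_with_mean` and `standardize_ignoring_nan` are hypothetical, and a mean fill is a simpler stand-in for the `'interpolate'` method used above.

```python
import numpy as np


def fill_nan_with_mean(band):
    """Hypothetical helper: replace NaN pixels with the mean of valid pixels."""
    band = np.asarray(band, dtype=float)
    return np.where(np.isnan(band), np.nanmean(band), band)


def standardize_ignoring_nan(band, eps=1e-8):
    """Hypothetical helper: zero mean, unit variance, NaN pixels left as NaN."""
    band = np.asarray(band, dtype=float)
    mean = np.nanmean(band)
    std = np.nanstd(band)
    # eps guards against division by zero for constant bands
    return (band - mean) / (std + eps)


demo = np.array([[1.0, 2.0], [3.0, np.nan]])
filled = fill_nan_with_mean(demo)
scaled = standardize_ignoring_nan(filled)
print(filled)
print(f"mean={scaled.mean():.4f}, std={scaled.std():.4f}")
```

Filling before standardizing keeps NaN out of the mean and standard deviation; `np.nanmean` and `np.nanstd` offer the alternative of standardizing first and leaving the gaps in place.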

## 7. Example: CNN Models

In [None]:
import torch

print("🧠 Testing CNN models...\n")

# Test all 3 models
model_names = ['spatial_cnn', 'multiscale_cnn', 'shallow_unet']

for model_name in model_names:
    print(f"{'='*80}")
    print(f"Model: {model_name.upper()}")
    print(f"{'='*80}")
    
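    # Get model (in_channels must match the channel dimension of the input tensor)
    n_channels = 14  # set this to the number of bands in your stacked input
    model = models.get_model(model_name, in_channels=n_channels)
    
    # Count parameters
    n_params = models.count_parameters(model)
    print(f"\n📊 Parameters: {n_params:,}")
    
    # Test forward pass
    x = torch.randn(2, n_channels, 128, 128)  # Batch of 2 patches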
    
    with torch.no_grad():
        y = model(x)
    
    print("\n✅ Forward pass successful:")
    print(f"   Input:  {tuple(x.shape)}")
    print(f"   Output: {tuple(y.shape)}")
    print(f"   Output range: [{y.min():.4f}, {y.max():.4f}]")
    print()

## 8. Example: Model Summary

In [None]:
# Get Shallow U-Net (the most complex)
n_channels = 14  # must match the channel dimension of input_size below
model = models.get_model('shallow_unet', in_channels=n_channels)

# Print detailed summary
models.print_model_summary(model, input_size=(1, n_channels, 128, 128))

## 9. Summary

### ✅ What We Learned

1. **Import modules** from `src/` package
2. **Load and check** TIFF metadata
3. **Get statistics** from bands
4. **Visualize** bands with plotting functions
5. **Load ground truth** labels
6. **Preprocess data** (handle NaN, normalize)
7. **Instantiate CNN models** and run a test forward pass

### 📚 Next Steps

- See `01_data_exploration.ipynb` for full data analysis
- Use `src.preprocessing.create_patches_dataset()` to create training data
- Train models (will be in `02_training_analysis.ipynb`)
- Evaluate results (will be in `03_results_visualization.ipynb`)

### 📖 Documentation

- **Detailed docs:** `src/README.md`
- **Function help:** Use `help(function_name)` or `function_name?` in Jupyter
- **Module source:** Look at `src/*.py` files

### 💡 Tips

1. **Reload modules** during development:
   ```python
   import importlib
   from src import utils
   importlib.reload(utils)
   ```

2. **View docstrings:**
   ```python
   help(load_tiff)
   # or in Jupyter:
   load_tiff?
   ```

3. **Import what you need:**
   ```python
   # Only import specific functions
   from src import load_tiff, plot_band
   ```