# DrainageAI Demo - Google Colab Integration with BYOL (Unlabeled Data Focus)

This notebook demonstrates the DrainageAI workflow using Google Colab's GPU acceleration. It focuses on the BYOL (Bootstrap Your Own Latent) self-supervised approach, which learns from unlabeled data only. It supports grayscale (single-channel) images as well as multispectral imagery, and it does not require MCP (Model Context Protocol) dependencies.

## Step 1: Check GPU Availability

In [None]:
import torch

# Query CUDA once and reuse the answer for both the report and the branch below.
gpu_available = torch.cuda.is_available()
print(f"GPU available: {gpu_available}")

if not gpu_available:
    print("WARNING: No GPU detected. Processing will be slow.")
else:
    # Report the first CUDA device.
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

## Step 2: Install Dependencies

In [None]:
!pip install rasterio geopandas scikit-image matplotlib pytorch-lightning torch-geometric

## Step 3: Clone the DrainageAI Repository

In [None]:
# Clone the repository
!git clone https://github.com/yourusername/DrainageAI.git

%cd DrainageAI

## Step 4: Create MCP-Free Main Script

To avoid dependency issues with the Model Context Protocol (MCP), we'll create a simplified version of the main script that doesn't require MCP.

In [None]:
%%writefile main_no_mcp.py
"""
Simplified main script for DrainageAI without MCP dependencies.
"""

import os
import sys
import argparse
import torch
import numpy as np
import rasterio
import geopandas as gpd
from pathlib import Path
from shapely.geometry import LineString

from models import EnsembleModel, CNNModel, GNNModel, SelfSupervisedModel, SemiSupervisedModel, BYOLModel, GrayscaleBYOLModel
from preprocessing import DataLoader, ImageProcessor, GraphBuilder, Augmentation
from preprocessing.fixmatch_augmentation import WeakAugmentation, StrongAugmentation


def _probability(value):
    """argparse ``type`` converter for ``--threshold``.

    Args:
        value: Raw command-line string.

    Returns:
        The value as a float in the closed interval [0, 1].

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a number or lies outside
            [0, 1]. argparse reports this as a usage error instead of silently
            accepting a value outside the documented (0-1) range.
    """
    try:
        number = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {value!r}") from None
    # Chained comparison is False for NaN, so NaN is rejected as well.
    if not 0.0 <= number <= 1.0:
        raise argparse.ArgumentTypeError(f"must be between 0 and 1, got {value!r}")
    return number


def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Argument list to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before. Passing a list
            lets notebooks and tests drive the CLI without touching sys.argv.

    Returns:
        argparse.Namespace whose ``command`` is "detect", "indices", "vectorize",
        or ``None`` when no sub-command was given, plus that sub-command's
        options.
    """
    parser = argparse.ArgumentParser(description="DrainageAI: AI-powered drainage pipe detection")

    # Create subparsers for different commands
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # Detect command
    detect_parser = subparsers.add_parser("detect", help="Detect drainage pipes in satellite imagery")
    detect_parser.add_argument("--imagery", required=True, help="Path to satellite imagery file")
    detect_parser.add_argument("--elevation", help="Path to elevation data file (optional)")
    detect_parser.add_argument("--indices", help="Path to spectral indices file (optional)")
    detect_parser.add_argument("--sar", help="Path to SAR imagery file (optional)")
    detect_parser.add_argument("--output", required=True, help="Path to save detection results")
    detect_parser.add_argument(
        "--model",
        default="ensemble",
        choices=["ensemble", "cnn", "gnn", "ssl", "semi", "byol", "grayscale-byol"],
        help="Model type to use",
    )
    detect_parser.add_argument("--weights", help="Path to model weights file (optional)")
    # The default is a float, not a string, so argparse returns it unchanged
    # without passing it through _probability.
    detect_parser.add_argument("--threshold", type=_probability, default=0.5, help="Confidence threshold for detection (0-1)")

    # Calculate indices command
    indices_parser = subparsers.add_parser("indices", help="Calculate spectral indices from multispectral imagery")
    indices_parser.add_argument("--imagery", required=True, help="Path to multispectral imagery file")
    indices_parser.add_argument("--output", required=True, help="Path to save indices as a multi-band raster")
    indices_parser.add_argument("--indices", default="ndvi,ndmi,msavi2", help="Comma-separated list of indices to calculate")
    indices_parser.add_argument("--red-band", type=int, default=3, help="Band number for red (default: 3)")
    indices_parser.add_argument("--nir-band", type=int, default=4, help="Band number for NIR (default: 4)")
    indices_parser.add_argument("--swir-band", type=int, default=5, help="Band number for SWIR (default: 5)")
    indices_parser.add_argument("--green-band", type=int, default=2, help="Band number for green (default: 2)")

    # Vectorize command
    vectorize_parser = subparsers.add_parser("vectorize", help="Convert raster detection results to vector format")
    vectorize_parser.add_argument("--input", required=True, help="Path to raster detection results")
    vectorize_parser.add_argument("--output", required=True, help="Path to save vector results")
    vectorize_parser.add_argument("--simplify", type=float, default=1.0, help="Tolerance for line simplification")

    return parser.parse_args(argv)


def create_model(model_type):
    """Instantiate the DrainageAI model registered under ``model_type``.

    Args:
        model_type: One of "ensemble", "cnn", "gnn", "ssl", "semi", "byol",
            or "grayscale-byol".

    Returns:
        A freshly constructed model instance.

    Raises:
        ValueError: if ``model_type`` is not a supported name.
    """
    # Each entry pairs a name with a zero-argument factory. The lambdas defer the
    # class lookup, so only the selected model class is ever resolved and built.
    registry = (
        ("ensemble", lambda: EnsembleModel()),
        ("cnn", lambda: CNNModel()),
        ("gnn", lambda: GNNModel()),
        ("ssl", lambda: SelfSupervisedModel(fine_tuned=True)),
        ("semi", lambda: SemiSupervisedModel(pretrained=True)),
        ("byol", lambda: BYOLModel(fine_tuned=True)),
        ("grayscale-byol", lambda: GrayscaleBYOLModel(fine_tuned=True)),
    )
    for name, build in registry:
        if model_type == name:
            return build()
    raise ValueError(f"Invalid model type: {model_type}")


def load_model(model_type, weights_path=None, with_sar=False):
    """Build a DrainageAI model, optionally load weights, and put it in eval mode.

    Args:
        model_type: Model name understood by ``create_model``.
        weights_path: Optional path handed to ``model.load``; skipped when falsy.
        with_sar: Request the SAR-aware constructor. This is honoured for "cnn",
            "byol" and "grayscale-byol"; other types ignore it.

    Returns:
        The model after ``model.eval()`` has been called on it.
    """
    # "grayscale-byol" always goes through its own constructor so that the
    # with_sar flag is forwarded. "cnn" and "byol" only do so when SAR is
    # requested, and everything else falls back to create_model().
    # NOTE(review): create_model() passes fine_tuned=True for "byol" and
    # "grayscale-byol", but the constructors below do not; confirm intended.
    if model_type == "grayscale-byol":
        model = GrayscaleBYOLModel(with_sar=with_sar)
    elif with_sar and model_type == "cnn":
        model = CNNModel(with_sar=True)
    elif with_sar and model_type == "byol":
        model = BYOLModel(with_sar=True)
    else:
        model = create_model(model_type)

    if weights_path:
        model.load(weights_path)

    # Inference mode.
    model.eval()
    return model


def main():
    """Parse the command line and dispatch to the selected sub-command."""
    args = parse_args()
    command = args.command

    # The command implementations are imported lazily, inside the branch that
    # needs them, to avoid circular imports (per the original author's note).
    if command == "detect":
        from models import detect as run_command
    elif command == "indices":
        from preprocessing import calculate_indices as run_command
    elif command == "vectorize":
        from preprocessing import vectorize as run_command
    else:
        print("Please specify a command. Use --help for more information.")
        return

    run_command(args)


if __name__ == "__main__":
    main()

## Step 5: Upload Test Imagery

In [None]:
from google.colab import files

print("Please upload your multispectral imagery file (GeoTIFF format):")
uploaded = files.upload()

# files.upload() returns a {filename: bytes} dict. An empty dict means nothing
# was uploaded (e.g. the dialog was cancelled), which would otherwise surface
# as an opaque IndexError.
if not uploaded:
    raise RuntimeError("No file was uploaded; re-run this cell and select a GeoTIFF.")
if len(uploaded) > 1:
    print(f"{len(uploaded)} files were uploaded; only the first one is used as the test image.")

# Get the filename of the first uploaded file (dicts preserve insertion order)
imagery_filename = next(iter(uploaded))
print(f"Uploaded file: {imagery_filename}")

## Step 6: Create Output Directory

In [None]:
import os

# Pure-Python `mkdir -p`. It is idempotent, and unlike a `!mkdir` shell call, a
# failure raises here instead of the notebook carrying on without the directory.
os.makedirs("colab_results", exist_ok=True)

## Step 7: Check Image Channels

In [None]:
import rasterio

# Only the band count of the uploaded raster is needed to choose a model.
with rasterio.open(imagery_filename) as src:
    num_channels = src.count
print(f"Image has {num_channels} channel(s)")

# Rasters with three or more bands use the standard BYOL model; anything
# smaller goes to the grayscale-compatible variant.
if num_channels >= 3:
    print("This is a multi-channel image. We'll use the standard model.")
    recommended_model = "byol"
else:
    print("This is a grayscale or 2-channel image. We'll use the grayscale-compatible model.")
    recommended_model = "grayscale-byol"

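## Step 8: Calculate Spectral Indices

The default indices (NDVI, NDMI, MSAVI2) are computed from specific multispectral bands: by default green = band 2, red = band 3, NIR = band 4 and SWIR = band 5. This step is therefore only meaningful for multispectral imagery, not for grayscale input.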

In [None]:
print("\n=== Step 1: Calculate Spectral Indices ===\n")

!python main_no_mcp.py indices --imagery {imagery_filename} --output colab_results/indices.tif --indices ndvi,ndmi,msavi2

## Step 9: Run Drainage Detection with Grayscale Support

In [None]:
print("\n=== Step 2: Detect Drainage Pipes ===\n")

# Choose one of the following model options based on the image type:

# For grayscale images (1 channel) or 2-channel images
if num_channels < 3:
    print("Using grayscale-compatible BYOL model...")
    !python main_no_mcp.py detect --imagery {imagery_filename} --indices colab_results/indices.tif --output colab_results/drainage_grayscale_byol.tif --model grayscale-byol
# For RGB images (3+ channels)
else:
    print("Using standard BYOL model...")
    !python main_no_mcp.py detect --imagery {imagery_filename} --indices colab_results/indices.tif --output colab_results/drainage_byol.tif --model byol

## Step 10: Vectorize Results

In [None]:
print("\n=== Step 3: Vectorize Results ===\n")

# Determine which detection result to use based on the model used
if num_channels < 3:
    detection_file = "colab_results/drainage_grayscale_byol.tif"
else:
    detection_file = "colab_results/drainage_byol.tif"

!python main_no_mcp.py vectorize --input {detection_file} --output colab_results/drainage_lines.shp

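## Using the Grayscale-Compatible BYOL Model

The following cells do three things:

- upload your own unlabeled GeoTIFFs into `data/unlabeled/imagery/`;
- report how many channels each image has;
- run the grayscale BYOL example script on that directory, using the image uploaded in Step 5 as the test image.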

In [None]:
import os

# Import here too so this section also works without re-running Step 5.
from google.colab import files

UNLABELED_DIR = "data/unlabeled/imagery"

# Create directories for unlabeled data (no-op if they already exist)
os.makedirs(UNLABELED_DIR, exist_ok=True)

# Upload unlabeled imagery (you can upload multiple files)
print("Please upload unlabeled imagery files (GeoTIFF format):")
uploaded_unlabeled = files.upload()

# files.upload() returns {filename: bytes}; copy each upload into UNLABELED_DIR.
for filename, content in uploaded_unlabeled.items():
    # basename() keeps an unexpected upload name from writing outside the directory.
    safe_name = os.path.basename(filename)
    with open(os.path.join(UNLABELED_DIR, safe_name), 'wb') as f:
        f.write(content)
    print(f"Saved {safe_name} to data/unlabeled/imagery/")

### Check Image Channels

In [None]:
import os

import rasterio

UNLABELED_DIR = "data/unlabeled/imagery"

# Check all images in the unlabeled directory. The per-image band count is
# stored in `unlabeled_channels`, not `num_channels`, so this loop does not
# overwrite the value Step 7 recorded for the uploaded test image.
has_grayscale = False
for filename in sorted(os.listdir(UNLABELED_DIR)):
    # Case-insensitive so files named *.TIF / *.TIFF are included too.
    if not filename.lower().endswith(('.tif', '.tiff')):
        continue
    input_path = os.path.join(UNLABELED_DIR, filename)

    with rasterio.open(input_path) as src:
        unlabeled_channels = src.count
    print(f"Image {filename} has {unlabeled_channels} channel(s)")

    if unlabeled_channels < 3:
        has_grayscale = True

print(f"\nDetected grayscale images: {has_grayscale}")

### Run Grayscale BYOL Example

In [None]:
print("\n=== Using Grayscale-Compatible BYOL Model ===\n")

# Run the grayscale BYOL example
!python examples/grayscale_byol_example.py \
    --optical-dir data/unlabeled/imagery \
    --output-dir colab_results \
    --epochs 20 \
    --test-image {imagery_filename}