In [None]:
import json
import os
import shutil
import threading
import time
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path

import numpy as np
import rasterio
import requests
from pystac_client import Client
from rasterio.enums import Resampling
from rasterio.warp import reproject
from requests.adapters import HTTPAdapter
from shapely.geometry import Polygon
from urllib3.util.retry import Retry
warnings.filterwarnings('ignore')

In [None]:
# Configuration — adjust these values before running the notebook.

# Month window to search, as "YYYY-MM" strings; END_MONTH is included in full.
START_MONTH, END_MONTH = "2025-06", "2025-07"

# Location, geocoded with Nominatim as "<CITY>, <PROVINCE>, <COUNTRY>".
CITY, PROVINCE, COUNTRY = "Toronto", "Ontario", "Canada"

# Root folder; imagery lands under <OUTPUT_DIR>/raw/images/<satellite>/.
OUTPUT_DIR = "satellite_data"

# Sentinel-2 scenes are kept only when eo:cloud_cover is below this percentage.
CLOUD_COVERAGE_THRESHOLD = 30

# NOTE(review): not read by any code in this notebook — per-request timeouts are
# hard-coded inside SatelliteDownloader.download_file.
DOWNLOAD_TIMEOUT = 300

# Thread-pool size used for each satellite's concurrent downloads.
MAX_WORKERS = 10

In [None]:
class SatelliteDownloader:
    """Find Sentinel-1/Sentinel-2 scenes over a city via Earth Search STAC and download their bands.

    Files are written to ``<base_dir>/raw/images/<sentinel1|sentinel2>/<item_id>/<BAND>.<ext>``.
    Items whose required bands are already on disk, and individual non-empty band files,
    are skipped, so re-running resumes an interrupted download. Downloads run on a thread
    pool, and every worker thread uses its own ``requests.Session``.
    """

    def __init__(self, base_dir):
        """Open the STAC client and create the output directory tree.

        Args:
            base_dir: Root output directory (str or Path).
        """
        self.base_dir = Path(base_dir)
        self.stac_client = Client.open("https://earth-search.aws.element84.com/v1")
        self.create_output_structure()
        self.download_lock = threading.Lock()  # serialises progress printing
        # download_file runs on many worker threads at once. A single shared Session let
        # one failing thread close it underneath the others, so sessions are per thread.
        self._thread_local = threading.local()
        self._sessions_lock = threading.Lock()
        self._sessions = []  # registry of live sessions so close() can release them all

    def __del__(self):
        """Release HTTP sessions when the downloader is garbage-collected."""
        self.close()

    def close(self):
        """Close every HTTP session opened by this downloader; safe to call repeatedly."""
        sessions = getattr(self, '_sessions', None)
        if not sessions:
            return
        with self._sessions_lock:
            to_close, self._sessions = self._sessions, []
        for session in to_close:
            session.close()

    @staticmethod
    def _required_bands(satellite_type):
        """Band names that must exist for an item of ``satellite_type`` to count as complete."""
        if satellite_type == 'sentinel2':
            return ('B02', 'B03', 'B04', 'B08')
        return ('VV', 'VH')

    def create_output_structure(self):
        """Create ``raw/images/sentinel1`` and ``raw/images/sentinel2`` under ``base_dir``."""
        for satellite in ['sentinel1', 'sentinel2']:
            (self.base_dir / "raw" / "images" / satellite).mkdir(parents=True, exist_ok=True)

    def is_item_already_downloaded(self, item_id, satellite_type):
        """Return True if every required band of ``item_id`` is already on disk.

        A band counts as present when a non-empty file named ``<band>.<ext>`` exists in the
        item directory. In-progress downloads use hidden ``.<name>.part`` files, which this
        pattern never matches.
        """
        image_dir = self.base_dir / "raw" / "images" / satellite_type / item_id
        if not image_dir.is_dir():
            return False
        return all(
            any(path.is_file() and path.stat().st_size > 0 for path in image_dir.glob(f"{band}.*"))
            for band in self._required_bands(satellite_type)
        )

    def convert_s3_to_https(self, url):
        """Rewrite ``s3://bucket/key`` as ``https://bucket.s3.amazonaws.com/key``; other URLs pass through."""
        if url.startswith('s3://'):
            s3_path = url[5:]
            parts = s3_path.split('/', 1)
            bucket = parts[0]
            path = parts[1] if len(parts) == 2 else ""
            return f"https://{bucket}.s3.amazonaws.com/{path}"
        return url

    def create_session(self):
        """Build a requests session with status-based retries and a 20-connection pool."""
        session = requests.Session()

        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS"]
        )

        adapter = HTTPAdapter(
            max_retries=retry_strategy,
            pool_connections=20,
            pool_maxsize=20,
            pool_block=True
        )

        session.mount("http://", adapter)
        session.mount("https://", adapter)
        session.headers.update({
            'User-Agent': 'SatelliteDownloader/2.0',
            'Accept': '*/*',
            'Connection': 'keep-alive'
        })

        return session

    def _get_session(self):
        """Return the calling thread's session, creating and registering one if needed."""
        session = getattr(self._thread_local, 'session', None)
        with self._sessions_lock:
            # A session missing from the registry was closed by close(); replace it.
            if session is None or session not in self._sessions:
                session = self.create_session()
                self._sessions.append(session)
        self._thread_local.session = session
        return session

    def _reset_session(self):
        """Close and forget the calling thread's session; other threads are unaffected."""
        session = getattr(self._thread_local, 'session', None)
        self._thread_local.session = None
        if session is None:
            return
        with self._sessions_lock:
            if session in self._sessions:
                self._sessions.remove(session)
        session.close()

    def download_file(self, url, output_path, max_retries=2):
        """Download ``url`` to ``output_path`` unless a non-empty file is already there.

        Data is streamed into a hidden ``.<name>.part`` file next to the target. That file is
        renamed into place only after a complete, non-empty transfer, so an interrupted run
        never leaves a file that later runs mistake for a finished download.

        Args:
            url: HTTP(S) or ``s3://`` URL (S3 URLs are rewritten to public HTTPS).
            output_path: Destination path (str or Path); parent dirs are created.
            max_retries: Number of attempts; each attempt also gets the adapter's
                status-based retries from ``create_session``.

        Returns:
            ``str(output_path)`` on success, otherwise None (the last error is printed).
        """
        url = self.convert_s3_to_https(url)
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if output_path.exists() and output_path.stat().st_size > 0:
            return str(output_path)

        part_path = output_path.with_name(f".{output_path.name}.part")

        for attempt in range(max_retries):
            try:
                session = self._get_session()
                with session.get(url, stream=True, timeout=(15, 180)) as response:
                    response.raise_for_status()
                    expected = response.headers.get('Content-Length')
                    # Content-Length counts encoded bytes, and iter_content decodes gzip,
                    # so only compare lengths when no Content-Encoding is applied.
                    check_length = expected is not None and not response.headers.get('Content-Encoding')
                    written = 0
                    with open(part_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=32768):
                            if chunk:
                                f.write(chunk)
                                written += len(chunk)

                if written == 0:
                    raise ValueError("Downloaded file is empty")
                if check_length and written != int(expected):
                    raise IOError(f"Truncated download: got {written} of {expected} bytes")

                part_path.replace(output_path)
                return str(output_path)

            except Exception as exc:
                if part_path.exists():
                    part_path.unlink()
                if attempt < max_retries - 1:
                    time.sleep(1)
                    # Only this thread's connections are reset; other downloads continue.
                    self._reset_session()
                else:
                    print(f"   ⚠️  Download failed for {url}: {exc}")

        return None

    def resample_to_10m(self, src_path, dst_path, resampling=None):
        """Resample a raster to 10 m pixels and save it as an LZW-compressed, tiled GeoTIFF.

        The output is written to a hidden ``.<name>.part`` file and renamed into place only on
        success, so ``dst_path`` never holds a half-written raster.

        Args:
            src_path: Input raster path.
            dst_path: Output path (str or Path); always written with the GTiff driver.
            resampling: A ``Resampling`` member or its name (e.g. ``'nearest'``). Defaults to
                bilinear. Categorical rasters such as the SCL class map need ``'nearest'``,
                because interpolating class codes produces codes that do not exist.

        Returns:
            True on success, False on any failure (the error is printed).
        """
        dst_path = Path(dst_path)
        part_path = dst_path.with_name(f".{dst_path.name}.part")
        try:
            if resampling is None:
                method = Resampling.bilinear
            elif isinstance(resampling, str):
                method = Resampling[resampling]
            else:
                method = resampling

            with rasterio.open(src_path) as src:
                # NOTE(review): assumes square pixels in metres — only the x resolution is used.
                scale_factor = src.res[0] / 10.0
                new_width = max(1, round(src.width * scale_factor))
                new_height = max(1, round(src.height * scale_factor))
                transform = src.transform * src.transform.scale(
                    src.width / new_width, src.height / new_height
                )

                profile = src.profile.copy()
                # Block sizes copied from the source (JPEG2000 or striped TIFFs) may be invalid
                # for a tiled GeoTIFF, so let GDAL pick the tile size instead.
                profile.pop('blockxsize', None)
                profile.pop('blockysize', None)
                profile.update({
                    'driver': 'GTiff',  # the source may be JPEG2000; the output is a .tif
                    'width': new_width,
                    'height': new_height,
                    'transform': transform,
                    'compress': 'lzw',
                    'tiled': True
                })

                with rasterio.open(part_path, 'w', **profile) as dst:
                    for band_index in range(1, src.count + 1):
                        reproject(
                            source=rasterio.band(src, band_index),
                            destination=rasterio.band(dst, band_index),
                            src_transform=src.transform,
                            src_crs=src.crs,
                            dst_transform=transform,
                            dst_crs=src.crs,
                            resampling=method
                        )
            part_path.replace(dst_path)
            return True
        except Exception as exc:
            if part_path.exists():
                part_path.unlink()
            print(f"   ⚠️  Resampling failed for {src_path}: {exc}")
            return False

    def get_date_range_from_months(self, start_month, end_month):
        """Convert "YYYY-MM" strings into a [start, end) datetime pair.

        Returns:
            ``(first day of start_month, first day of the month after end_month)``.

        Raises:
            ValueError: On malformed input or when ``end_month`` precedes ``start_month``.
        """
        start_year, start_month_num = map(int, start_month.split('-'))
        end_year, end_month_num = map(int, end_month.split('-'))

        if (end_year, end_month_num) < (start_year, start_month_num):
            raise ValueError(f"End month {end_month} is before start month {start_month}")

        start_date = datetime(start_year, start_month_num, 1)
        if end_month_num == 12:
            end_date = datetime(end_year + 1, 1, 1)
        else:
            end_date = datetime(end_year, end_month_num + 1, 1)

        return start_date, end_date

    def get_city_polygon(self, city, province, country):
        """Look up the city's bounding box with the Nominatim search API.

        Returns:
            Dict with float ``min_lat``, ``max_lat``, ``min_lon``, ``max_lon``, or None when the
            lookup fails or finds nothing (errors are printed, not raised).
        """
        try:
            address = f"{city}, {province}, {country}"
            response = requests.get(
                "https://nominatim.openstreetmap.org/search",
                params={'q': address, 'format': 'json', 'limit': 1},
                headers={'User-Agent': 'SatelliteDownloader/2.0'},
                timeout=15
            )
            # Report HTTP errors (e.g. rate limiting) instead of failing on a non-JSON body.
            response.raise_for_status()

            data = response.json()
            if data and 'boundingbox' in data[0]:
                # Nominatim orders the box as [min_lat, max_lat, min_lon, max_lon].
                bbox = data[0]['boundingbox']
                return {
                    'min_lat': float(bbox[0]), 'max_lat': float(bbox[1]),
                    'min_lon': float(bbox[2]), 'max_lon': float(bbox[3])
                }
        except Exception as e:
            print(f"Error with location lookup: {e}")

        return None

    def query_stac_data(self, collection, start_date, end_date, bounds, max_cloud_cover=None):
        """Return every STAC item of ``collection`` intersecting ``bounds`` in the date range.

        Args:
            collection: STAC collection id, e.g. ``"sentinel-2-l2a"``.
            start_date, end_date: datetimes bounding the search interval.
            bounds: Dict with ``min_lon``/``min_lat``/``max_lon``/``max_lat``.
            max_cloud_cover: If given, keep items with ``eo:cloud_cover`` below this value.
        """
        bbox = [bounds['min_lon'], bounds['min_lat'], bounds['max_lon'], bounds['max_lat']]

        search_params = {
            "collections": [collection],
            "datetime": f"{start_date.isoformat()}/{end_date.isoformat()}",
            "bbox": bbox,
            "limit": 500  # page size; items() still pages through every match
        }

        if max_cloud_cover is not None:
            search_params["query"] = {"eo:cloud_cover": {"lt": max_cloud_cover}}

        search = self.stac_client.search(**search_params)
        items = list(search.items())
        print(f"   📡 {collection}: {len(items)} items found")

        return items

    def filter_by_city_polygon(self, items, bounds):
        """Keep items whose bbox rectangle intersects the city's bounding box."""
        city_polygon = Polygon([
            (bounds['min_lon'], bounds['min_lat']),
            (bounds['max_lon'], bounds['min_lat']),
            (bounds['max_lon'], bounds['max_lat']),
            (bounds['min_lon'], bounds['max_lat']),
            (bounds['min_lon'], bounds['min_lat'])
        ])

        filtered_items = []
        for item in items:
            try:
                bbox_coords = item.bbox
                image_polygon = Polygon([
                    (bbox_coords[0], bbox_coords[1]), (bbox_coords[2], bbox_coords[1]),
                    (bbox_coords[2], bbox_coords[3]), (bbox_coords[0], bbox_coords[3]),
                    (bbox_coords[0], bbox_coords[1])
                ])

                if city_polygon.intersects(image_polygon):
                    filtered_items.append(item)
            except Exception:
                # Items with a missing or malformed bbox are skipped.
                continue

        return filtered_items

    def _cleanup_item_dir(self, image_dir):
        """Delete an incomplete item directory and everything in it (no-op if absent)."""
        shutil.rmtree(image_dir, ignore_errors=True)

    def download_and_process_band(self, item, asset_name, band_name, output_dir,
                                  needs_resample=False, resampling=None):
        """Download one asset of ``item`` as ``<band_name>.<ext>``, optionally resampled to 10 m.

        Args:
            item: STAC item whose ``assets[asset_name].href`` is fetched.
            asset_name: Key in ``item.assets``.
            band_name: Output file stem, e.g. ``"B04"``.
            output_dir: Directory (Path) receiving the file.
            needs_resample: If True, download to a temp file and write ``<band_name>.tif``
                at 10 m via ``resample_to_10m``.
            resampling: Forwarded to ``resample_to_10m`` (None means bilinear).

        Returns:
            Path string of the final file, or None if the asset is missing or a step failed.
        """
        if asset_name not in item.assets:
            return None

        asset_href = item.assets[asset_name].href
        file_ext = '.jp2' if asset_href.endswith('.jp2') else '.tif'

        if not needs_resample:
            return self.download_file(asset_href, output_dir / f"{band_name}{file_ext}")

        final_path = output_dir / f"{band_name}.tif"
        if final_path.exists() and final_path.stat().st_size > 0:
            return str(final_path)  # resampled on an earlier run

        temp_path = output_dir / f"{band_name}_temp{file_ext}"
        if not self.download_file(asset_href, temp_path):
            return None

        try:
            success = self.resample_to_10m(temp_path, final_path, resampling)
        finally:
            if temp_path.exists():
                temp_path.unlink()  # the full-resolution temp file is never kept

        return str(final_path) if success else None

    def download_sentinel2_item(self, item):
        """Download the Sentinel-2 bands B02, B03, B04, B08, B11 and SCL for one item.

        B11 and SCL are resampled from 20 m to 10 m (SCL with nearest-neighbour, since it
        holds class codes). The item succeeds only when all of B02/B03/B04/B08 are present,
        which is the same rule ``is_item_already_downloaded`` applies. Otherwise the item
        directory is deleted.

        Returns:
            ``item.id`` on success, otherwise None.
        """
        try:
            if self.is_item_already_downloaded(item.id, "sentinel2"):
                return item.id

            image_dir = self.base_dir / "raw" / "images" / "sentinel2" / item.id
            image_dir.mkdir(parents=True, exist_ok=True)

            # band -> (candidate asset keys, needs 20 m -> 10 m resampling, resampling method)
            band_specs = {
                'B02': (['B02', 'blue'], False, None),
                'B03': (['B03', 'green'], False, None),
                'B04': (['B04', 'red'], False, None),
                'B08': (['B08', 'nir'], False, None),
                'B11': (['B11', 'swir16'], True, 'bilinear'),
                'SCL': (['SCL', 'scl'], True, 'nearest'),
            }

            downloaded = set()
            for band_name, (asset_names, needs_resample, method) in band_specs.items():
                asset_name = next((name for name in asset_names if name in item.assets), None)
                if asset_name is None:
                    continue
                result = self.download_and_process_band(
                    item, asset_name, band_name, image_dir, needs_resample, resampling=method
                )
                if result:
                    downloaded.add(band_name)

            if downloaded.issuperset(self._required_bands("sentinel2")):
                return item.id

            self._cleanup_item_dir(image_dir)
            return None

        except Exception as exc:
            print(f"   ⚠️  Sentinel-2 item {getattr(item, 'id', item)} failed: {exc}")
            return None

    def download_sentinel1_item(self, item):
        """Download the Sentinel-1 VV and VH polarisations for one item.

        Returns:
            ``item.id`` when both bands are on disk; otherwise None, after deleting the
            item directory.
        """
        try:
            if self.is_item_already_downloaded(item.id, "sentinel1"):
                return item.id

            image_dir = self.base_dir / "raw" / "images" / "sentinel1" / item.id
            image_dir.mkdir(parents=True, exist_ok=True)

            downloaded = set()
            for pol in ['vv', 'vh']:
                if pol in item.assets and self.download_and_process_band(item, pol, pol.upper(), image_dir):
                    downloaded.add(pol.upper())

            if downloaded.issuperset(self._required_bands("sentinel1")):
                return item.id

            self._cleanup_item_dir(image_dir)
            return None

        except Exception as exc:
            print(f"   ⚠️  Sentinel-1 item {getattr(item, 'id', item)} failed: {exc}")
            return None

    def download_items_concurrently(self, items, download_func, satellite_name, max_workers=None):
        """Run ``download_func`` over ``items`` on a thread pool and collect the successes.

        Args:
            items: STAC items to download.
            download_func: Callable taking one item and returning its id, or None on failure.
            satellite_name: Label used in progress messages.
            max_workers: Pool size; None or 0 means the ``MAX_WORKERS`` config value.

        Returns:
            Ids returned by ``download_func`` for successful items, in completion order.
        """
        successful_downloads = []
        total_items = len(items)

        with ThreadPoolExecutor(max_workers=max_workers or MAX_WORKERS) as executor:
            future_to_item = {executor.submit(download_func, item): item for item in items}

            for completed, future in enumerate(as_completed(future_to_item), start=1):
                item_id = getattr(future_to_item[future], 'id', future_to_item[future])
                try:
                    result = future.result()
                except Exception as exc:
                    result = None
                    print(f"   ❌ {satellite_name}: {item_id} raised {exc}")

                with self.download_lock:
                    if result is not None:
                        successful_downloads.append(result)
                        print(f"   ✅ {satellite_name}: {len(successful_downloads)}/{completed} completed ({completed}/{total_items} processed)")
                    else:
                        print(f"   ⚠️  {satellite_name}: {item_id} incomplete ({completed}/{total_items} processed)")

        return successful_downloads

    def download_satellite_data(self, start_month, end_month, city, province, country):
        """Query, filter and download Sentinel-1 and Sentinel-2 imagery for a city.

        Returns:
            Dict with ``sentinel1_downloaded``, ``sentinel2_downloaded`` and
            ``total_downloaded`` item counts.

        Raises:
            ValueError: If the city cannot be geocoded or the month range is invalid.
            Any other error is printed and re-raised. HTTP sessions are closed either way.
        """
        try:
            start_date, end_date = self.get_date_range_from_months(start_month, end_month)
            bounds = self.get_city_polygon(city, province, country)

            if not bounds:
                raise ValueError(f"Could not find bounds for {city}, {province}, {country}")

            print(f"🛰️  Downloading satellite data for {city}, {province}, {country}")
            print(f"📅 Date range: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
            print(f"🚀 Using {MAX_WORKERS} concurrent workers")
            print("🎯 Mode: Images only (no metadata)")

            print("\n🔍 Querying satellite catalogs...")
            s1_items = self.query_stac_data("sentinel-1-grd", start_date, end_date, bounds)
            s2_items = self.query_stac_data("sentinel-2-l2a", start_date, end_date, bounds, CLOUD_COVERAGE_THRESHOLD)

            print("🌍 Filtering by location...")
            s1_filtered = self.filter_by_city_polygon(s1_items, bounds)
            s2_filtered = self.filter_by_city_polygon(s2_items, bounds)

            print(f"📊 After filtering: {len(s1_filtered)} S1 + {len(s2_filtered)} S2 items")

            print("\n⬇️  Starting downloads...")

            s1_successful = []
            s2_successful = []

            if s1_filtered:
                print("🛰️  Downloading Sentinel-1 items...")
                s1_successful = self.download_items_concurrently(s1_filtered, self.download_sentinel1_item, "S1")

            if s2_filtered:
                print("🛰️  Downloading Sentinel-2 items...")
                s2_successful = self.download_items_concurrently(s2_filtered, self.download_sentinel2_item, "S2")

            print("\n🎉 Download complete!")
            print("📊 Results:")
            print(f"   🛰️  Sentinel-1: {len(s1_successful)}/{len(s1_filtered)} items downloaded")
            print(f"   🛰️  Sentinel-2: {len(s2_successful)}/{len(s2_filtered)} items downloaded")
            print(f"📁 Images saved to: {self.base_dir}/raw/images/")
            print("🚀 Ready for processing!")

            return {
                'sentinel1_downloaded': len(s1_successful),
                'sentinel2_downloaded': len(s2_successful),
                'total_downloaded': len(s1_successful) + len(s2_successful)
            }

        except Exception as e:
            print(f"❌ Error in download process: {e}")
            raise
        finally:
            # Worker threads have exited by now; release the sessions they opened.
            self.close()

In [None]:
def main():
    """Run the downloader with the settings from the configuration cell.

    Returns:
        The summary dict returned by ``download_satellite_data`` on success, or None when the
        run failed. Failures are reported (message plus traceback) rather than raised, so the
        notebook cell finishes instead of aborting.
    """
    print("🚀 Satellite Data Downloader")
    print(f"📍 Location: {CITY}, {PROVINCE}, {COUNTRY}")
    print(f"📅 Date range: {START_MONTH} to {END_MONTH}")
    print(f"☁️  Max cloud cover: {CLOUD_COVERAGE_THRESHOLD}%")

    try:
        downloader = SatelliteDownloader(OUTPUT_DIR)
        summary = downloader.download_satellite_data(START_MONTH, END_MONTH, CITY, PROVINCE, COUNTRY)
    except Exception as e:
        # Deliberate best-effort: report the failure without killing the notebook run.
        print(f"❌ Error in main execution: {e}")
        traceback.print_exc()
        return None

    print("\n✅ Download completed successfully!")
    print(f"📂 Input structure created: {OUTPUT_DIR}/raw/images/")
    # Returned (previously discarded) so callers can inspect the per-satellite counts.
    return summary

In [None]:
# Script entry point. Inside an IPython/Jupyter kernel the cell namespace's __name__ is
# also "__main__", so executing this cell starts the download as well.
if __name__ == "__main__":
    main()