# 🌡️ Climate Analysis with xarray - Complete Version

This notebook analyzes daily maximum air temperature from the GSHTD dataset (accessed through Google Earth Engine) using xarray. It includes:
- File browser for raster upload
- Inline plots within notebook cells
- Full ROI selection tools (drawing, coordinates, raster upload)
- Data exploration before analysis
- Export to ../outputs directory

## Key Features:
- ⚡ **Fast analysis**: Data is pulled from GEE once, then all metrics are computed locally without repeated API calls
- 🔧 **Data exploration**: Examine xarray structure before analysis
- 📊 **Inline plots**: All visualizations appear in notebook cells
- 🎯 **File browser**: Easy raster file selection
- 📁 **Full pixel export**: Every individual pixel value preserved

In [1]:
# Import required libraries
import ee
import geemap
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import rasterio
from tkinter import filedialog
import tkinter as tk

# Set matplotlib to display inline
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)

# Initialize Earth Engine with your project
try:
    ee.Initialize(project='tl-cities')
    print('✅ Earth Engine initialized successfully')
except Exception as e:
    print(f'❌ Earth Engine initialization failed: {e}')

print('📦 Available packages:')
print(f'   - xarray: {xr.__version__}')
print(f'   - pandas: {pd.__version__}')
print(f'   - numpy: {np.__version__}')

# Create outputs directory
os.makedirs('../outputs', exist_ok=True)
print('📁 Created outputs directory: ../outputs')

✅ Earth Engine initialized successfully
📦 Available packages:
   - xarray: 2024.7.0
   - pandas: 2.3.1
   - numpy: 1.26.4
📁 Created outputs directory: ../outputs
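The setup cell imports `tkinter` only for the Step 1 file browser. That browser opens a native dialog, so it needs a desktop session. The sketch below checks for one before any button is clicked. `gui_available` is a hypothetical helper, not part of the notebook. It assumes that a failed `tk.Tk()` call raises `tkinter.TclError`, which is the usual behavior on a kernel with no display.

```python
def gui_available() -> bool:
    """Return True if this kernel can open a Tk window (hypothetical helper)."""
    try:
        import tkinter as tk
    except ImportError:
        # Python was built without Tk support
        return False
    try:
        root = tk.Tk()
    except tk.TclError:
        # No reachable display, e.g. a remote or headless Jupyter server
        return False
    root.withdraw()  # never show the empty root window
    root.destroy()
    return True


print('🖥️ GUI file browser available:', gui_available())
```

If this prints `False`, Method 3 in Step 1 cannot be used as written, because its path field is read-only. Drawing on the map or entering coordinates still works.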


## 🎯 Step 1: Enhanced ROI Selection with File Browser

In [2]:
# Global variables
analysis_geom = None
temperature_data = None

# Create map for ROI selection
m = geemap.Map(center=[-12.9714, -38.5014], zoom=10)  # Salvador, Brazil
m.add_basemap('SATELLITE')
m.add('draw_control')

def set_roi_from_drawing():
    '''Extract ROI from map drawing'''
    global analysis_geom
    
    try:
        if hasattr(m, 'draw_control') and len(m.draw_control.data) > 0:
            # Get the last drawn feature
            feature = m.draw_control.data[-1]
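            geom_type = feature['geometry']['type']
            coords = feature['geometry']['coordinates']
            
            # Rectangles drawn on the map are exported as GeoJSON Polygons,
            # so a single Polygon branch covers both shapes
            if geom_type != 'Polygon':
                print(f'❌ Unsupported drawing type "{geom_type}". Please draw a polygon or rectangle.')
                return False
            analysis_geom = ee.Geometry.Polygon(coords)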
            
            area_km2 = analysis_geom.area().divide(1000000).getInfo()
            bounds_info = analysis_geom.bounds().getInfo()['coordinates'][0]
            west, south = bounds_info[0]
            east, north = bounds_info[2]
            
            print(f'✅ ROI set from drawing: {area_km2:.1f} km²')
            print(f'   Bounds: W={west:.3f}, E={east:.3f}, S={south:.3f}, N={north:.3f}')
            return True
        else:
            print('❌ No drawing found. Please draw a polygon or rectangle on the map.')
            return False
    except Exception as e:
        print(f'❌ Error setting ROI from drawing: {e}')
        return False

def set_roi_from_coordinates():
    '''Set ROI from coordinate inputs'''
    global analysis_geom
    
    try:
        # Read the values directly: FloatText always holds a float, and 0.0 is a valid coordinate
        west = float(west_input.value)
        east = float(east_input.value)
        south = float(south_input.value)
        north = float(north_input.value)
        
        if west >= east or south >= north:
            print('❌ Invalid bounds - West must be less than East and South less than North')
            return False
        
        analysis_geom = ee.Geometry.Rectangle([west, south, east, north])
        area_km2 = analysis_geom.area().divide(1000000).getInfo()
        
        # Add rectangle to map with proper visualization
        roi_image = ee.Image().paint(analysis_geom, 1, 2)
        m.addLayer(roi_image, {'palette': ['red'], 'max': 1}, 'ROI')
        m.centerObject(analysis_geom, 11)
        
        print(f'✅ ROI set from coordinates: {area_km2:.1f} km²')
        print(f'   Bounds: W={west:.3f}, E={east:.3f}, S={south:.3f}, N={north:.3f}')
        return True
    except Exception as e:
        print(f'❌ Error setting ROI from coordinates: {e}')
        return False

def browse_raster_file():
    '''Open file browser to select raster file'''
    try:
        # Create a temporary tkinter root window
        root = tk.Tk()
        root.withdraw()  # Hide the root window
        
        # Open file dialog
        file_path = filedialog.askopenfilename(
            title='Select Reference Raster File',
            filetypes=[
                ('Raster files', '*.tif *.tiff *.img *.nc *.hdf *.jp2'),
                ('GeoTIFF', '*.tif *.tiff'),
                ('NetCDF', '*.nc'),
                ('All files', '*.*')
            ]
        )
        
        root.destroy()  # Clean up
        
        if file_path:
            raster_path_display.value = file_path
            print(f'📁 Selected file: {os.path.basename(file_path)}')
            print(f'    Full path: {file_path}')
            return file_path
        else:
            print('❌ No file selected')
            return None
            
    except Exception as e:
        print(f'❌ Error opening file browser: {e}')
        print('   Note: the file browser requires a local GUI (desktop) session')
        return None

def set_roi_from_raster():
    '''Set ROI from selected raster extent with proper CRS handling'''
    global analysis_geom
    
    try:
        raster_path = raster_path_display.value.strip()
        
        if not raster_path or not os.path.exists(raster_path):
            print('❌ Please select a valid raster file first')
            return False
        
        print(f'📖 Reading raster: {os.path.basename(raster_path)}')
        
        # Read raster bounds and CRS information
        with rasterio.open(raster_path) as src:
            bounds = src.bounds
            crs = src.crs
            shape = src.shape
            transform = src.transform
            
            # Get bounds in original CRS
            west, south, east, north = bounds.left, bounds.bottom, bounds.right, bounds.top
            
            print(f'   📊 Raster info:')
            print(f'      CRS: {crs}')
            print(f'      Shape: {shape}')
            print(f'      Original bounds: W={west:.3f}, E={east:.3f}, S={south:.3f}, N={north:.3f}')
            
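            # A raster without a CRS cannot be placed on the map
            if crs is None:
                print('❌ Raster has no CRS defined - cannot determine its location')
                return False
            
            # Transform to WGS84 if needed
            if crs.to_epsg() != 4326: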
                from rasterio.warp import transform_bounds
                west_wgs84, south_wgs84, east_wgs84, north_wgs84 = transform_bounds(
                    crs, 'EPSG:4326', west, south, east, north
                )
                print(f'      Transformed to WGS84:')
                print(f'      WGS84 bounds: W={west_wgs84:.6f}, E={east_wgs84:.6f}, S={south_wgs84:.6f}, N={north_wgs84:.6f}')
                west, south, east, north = west_wgs84, south_wgs84, east_wgs84, north_wgs84
            else:
                print(f'      Already in WGS84')
        
        # Validate bounds are reasonable
        if abs(west) > 180 or abs(east) > 180 or abs(south) > 90 or abs(north) > 90:
            print(f'❌ Invalid bounds detected - coordinates out of valid range')
            print(f'   This suggests a CRS projection issue')
            return False
        
        if west >= east or south >= north:
            print(f'❌ Invalid bounds - west >= east or south >= north')
            return False
        
        # Create geometry in WGS84
        analysis_geom = ee.Geometry.Rectangle([west, south, east, north], 'EPSG:4326')
        area_km2 = analysis_geom.area().divide(1000000).getInfo()
        
        # Add to map with proper visualization
        roi_image = ee.Image().paint(analysis_geom, 1, 2)
        m.addLayer(roi_image, {'palette': ['blue'], 'max': 1}, 'Raster ROI')
        m.centerObject(analysis_geom, 11)
        
        print(f'   ✅ ROI set from raster extent: {area_km2:.1f} km²')
        print(f'   Final WGS84 bounds: W={west:.6f}, E={east:.6f}, S={south:.6f}, N={north:.6f}')
        
        # Test if ROI overlaps with GSHTD data coverage
        test_centroid = analysis_geom.centroid().coordinates().getInfo()
        test_lon, test_lat = test_centroid[0], test_centroid[1]
        print(f'   🎯 ROI center: lat {test_lat:.3f}°, lon {test_lon:.3f}°')
        
        # Check if in valid GSHTD coverage areas
        valid_coverage = False
        if test_lat > 15 and test_lon > -140 and test_lon < -40:  # North America
            print(f'   🌍 ROI appears to be in North America coverage')
            valid_coverage = True
        elif test_lat < 35 and test_lon > -120 and test_lon < -30:  # Latin America  
            print(f'   🌍 ROI appears to be in Latin America coverage')
            valid_coverage = True
        elif test_lat > 30 and test_lon > -15 and test_lon < 180:  # Europe & Asia
            print(f'   🌍 ROI appears to be in Europe/Asia coverage')
            valid_coverage = True
        elif test_lat < 40 and test_lon > -20 and test_lon < 55:  # Africa
            print(f'   🌍 ROI appears to be in Africa coverage')
            valid_coverage = True
        elif test_lat < -5 and test_lon > 110 and test_lon < 180:  # Australia
            print(f'   🌍 ROI appears to be in Australia coverage')
            valid_coverage = True
        
        if not valid_coverage:
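            print('   ⚠️ Warning: ROI may be outside GSHTD coverage areas')
            print('   ⚠️ GSHTD covers: North America, Latin America, Europe/Asia, Africa, Australia')
            print('   ⚠️ Extraction would fall back to the north_america collection, which may not cover this ROI')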
        
        return True
        
    except Exception as e:
        print(f'❌ Error setting ROI from raster: {e}')
        import traceback
        print(f'   Details: {traceback.format_exc()}')
        return False

# ROI input widgets
west_input = widgets.FloatText(value=-38.7, description='West:')
east_input = widgets.FloatText(value=-38.3, description='East:')
south_input = widgets.FloatText(value=-13.1, description='South:')
north_input = widgets.FloatText(value=-12.8, description='North:')

# File browser widgets
raster_path_display = widgets.Text(
    value='',
    placeholder='No file selected...',
    description='Selected File:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='500px'),
    disabled=True  # Read-only display
)

browse_button = widgets.Button(
    description='📂 Browse Files',
    button_style='info',
    tooltip='Click to select a raster file'
)
browse_button.on_click(lambda b: browse_raster_file())

# Action buttons
set_drawing_button = widgets.Button(description='📍 Use Drawing', button_style='success')
set_coords_button = widgets.Button(description='📍 Use Coordinates', button_style='info')
set_raster_button = widgets.Button(description='📍 Use Raster Extent', button_style='warning')

set_drawing_button.on_click(lambda b: set_roi_from_drawing())
set_coords_button.on_click(lambda b: set_roi_from_coordinates())
set_raster_button.on_click(lambda b: set_roi_from_raster())

roi_interface = widgets.VBox([
    widgets.HTML('<h3>🎯 ROI Selection - Enhanced with File Browser</h3>'),
    
    widgets.HTML('<b>Method 1: Draw on Map</b>'),
    widgets.HTML('Draw a polygon or rectangle on the map below, then click:'),
    set_drawing_button,
    
    widgets.HTML('<b>Method 2: Enter Coordinates</b>'),
    widgets.HBox([west_input, east_input]),
    widgets.HBox([south_input, north_input]),
    set_coords_button,
    
    widgets.HTML('<b>Method 3: Use Raster File Extent</b>'),
    widgets.HTML('Click Browse to select a reference raster file:'),
    widgets.HBox([browse_button, raster_path_display]),
    set_raster_button
])

display(roi_interface)
display(m)

print('🎯 Enhanced ROI Selection Ready with File Browser')
print('🔧 Improved CRS handling and coordinate validation')
print('Choose any method to define your region of interest')

🎯 Enhanced ROI Selection Ready with File Browser
🔧 Improved CRS handling and coordinate validation
Choose any method to define your region of interest
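Step 1's coverage check and Step 3's `get_region_collection` repeat the same latitude/longitude boxes. When none of them match, `get_region_collection` falls back to the North America collection without telling you. The pure-Python sketch below shows one shared lookup that uses the boxes exactly as the notebook defines them. `lookup_gshtd_region` and `format_latlon` are hypothetical names. The boxes are the notebook's heuristic, not the dataset's official extents.

```python
from typing import Optional

COLLECTION_PREFIX = "projects/sat-io/open-datasets/global-daily-air-temp/"

# The notebook's heuristic boxes, checked in this order (first match wins)
REGION_BOXES = [
    ("north_america", lambda lat, lon: lat > 15 and -140 < lon < -40),
    ("latin_america", lambda lat, lon: lat < 35 and -120 < lon < -30),
    ("europe_asia", lambda lat, lon: lat > 30 and -15 < lon < 180),
    ("africa", lambda lat, lon: lat < 40 and -20 < lon < 55),
    ("australia", lambda lat, lon: lat < -5 and 110 < lon < 180),
]


def lookup_gshtd_region(lat: float, lon: float) -> Optional[str]:
    """Return the first matching region name, or None when no box contains the point."""
    for name, contains in REGION_BOXES:
        if contains(lat, lon):
            return name
    return None


def format_latlon(lat: float, lon: float) -> str:
    """Format a point with hemisphere letters instead of a fixed °N/°E suffix."""
    ns = "N" if lat >= 0 else "S"
    ew = "E" if lon >= 0 else "W"
    return f"{abs(lat):.3f}°{ns}, {abs(lon):.3f}°{ew}"


lat, lon = -12.9714, -38.5014  # Salvador, Brazil (the notebook's default map center)
region = lookup_gshtd_region(lat, lon)
if region is None:
    print(f"⚠️ {format_latlon(lat, lon)} is outside every regional box")
else:
    print(f"🌍 {format_latlon(lat, lon)} -> {COLLECTION_PREFIX}{region}")
```

The boxes are tested in order. A point such as Bangkok (13.75° N, 100.5° E) matches none of them, so `None` is returned and the gap becomes visible instead of silently defaulting to North America.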


## 📊 Step 2: Analysis Configuration

In [3]:
# Analysis configuration
analysis_year = widgets.IntSlider(value=2020, min=2003, max=2020, description='Analysis Year:')
reference_start = widgets.IntSlider(value=2010, min=2003, max=2019, description='Reference Start:')
reference_end = widgets.IntSlider(value=2019, min=2004, max=2020, description='Reference End:')
absolute_threshold = widgets.FloatSlider(value=35.0, min=20.0, max=45.0, step=0.5, description='Threshold (°C):')
percentile_threshold = widgets.FloatSlider(value=95.0, min=50.0, max=99.0, step=1.0, description='Percentile:')

config_interface = widgets.VBox([
    widgets.HTML('<h3>📊 Analysis Configuration</h3>'),
    widgets.HTML('<div style="background-color: #fff3cd; padding: 10px; border-radius: 5px;">' +
                '<b>Note:</b> High thresholds (35°C, 95th percentile) may yield few or zero heat days ' +
                'in some regions. For climatological extremes this is an expected result, not an error.</div>'),
    analysis_year,
    widgets.HBox([reference_start, reference_end]),
    widgets.HBox([absolute_threshold, percentile_threshold])
])

display(config_interface)
print('📊 Configuration Ready')

📊 Configuration Ready
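The sliders above do not stop you from setting a reference start later than the reference end. In that case Step 3's `range(ref_start, ref_end + 1)` is empty and only the analysis year is extracted. The sketch below is a check that could run before extraction. `years_to_extract` here is a hypothetical standalone version of the list Step 3 builds inline. The 2003-2020 bounds are the slider limits above.

```python
def years_to_extract(analysis_year: int, ref_start: int, ref_end: int,
                     first_year: int = 2003, last_year: int = 2020) -> list:
    """Validate the configuration and return the sorted years Step 3 would extract."""
    if ref_start > ref_end:
        raise ValueError(f"Reference start {ref_start} is after reference end {ref_end}")
    for y in (analysis_year, ref_start, ref_end):
        if not first_year <= y <= last_year:
            raise ValueError(f"Year {y} is outside {first_year}-{last_year}")
    # Reference years plus the analysis year, de-duplicated and sorted
    return sorted(set(range(ref_start, ref_end + 1)) | {analysis_year})


print(years_to_extract(2020, 2010, 2019))
```

Suppose the analysis year falls inside the reference period, for example 2015 with 2010-2019. No extra year is fetched, but 2015 then also contributes to its own percentile baseline in Step 4.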


## 🔄 Step 3: Data Extraction
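Earth Engine's `filterDate` treats its end date as exclusive, so a period has to end on the first day of the following one. A single year may still be too large for one `getRegion` request at 1 km. Splitting it into months is one way to keep that resolution instead of coarsening it. The sketch below generates such ranges. `monthly_ranges` is a hypothetical helper, not part of the notebook.

```python
from datetime import date


def monthly_ranges(year: int):
    """Yield (start, end) ISO date strings per month; end is exclusive, as filterDate expects."""
    for month in range(1, 13):
        start = date(year, month, 1)
        end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
        yield start.isoformat(), end.isoformat()


ranges = list(monthly_ranges(2020))
print(ranges[0], ranges[-1])
```

Each `(start, end)` pair could be passed to the notebook's nested `get_temperature_collection(analysis_geom, start, end, 'tmax')`. Each request would then be roughly a twelfth of a year's size, at the cost of twelve round trips per year.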

In [4]:
def extract_temperature_data():
    '''Extract temperature data from GSHTD using temporal chunking for 1km resolution'''
    global temperature_data, analysis_geom
    
    if analysis_geom is None:
        print('❌ Please set an ROI first!')
        return False
    
    try:
        print('🔄 Extracting temperature data from GSHTD with temporal chunking for 1km resolution...')
        
        year = analysis_year.value
        ref_start = reference_start.value
        ref_end = reference_end.value
        
        # Debug ROI information
        area_km2 = analysis_geom.area().divide(1000000).getInfo()
        bounds = analysis_geom.bounds().getInfo()
        print(f'   📏 ROI area: {area_km2:.2f} km²')
        print(f'   🗺️ ROI bounds: {bounds}')
        
        # Function to get regional collection based on location
        def get_region_collection(geom):
            """Determine which regional GSHTD collection to use based on geometry location"""
            centroid = geom.centroid().coordinates().getInfo()
            lon, lat = centroid[0], centroid[1]
            
            if lat > 15 and lon > -140 and lon < -40:  # North America
                return "projects/sat-io/open-datasets/global-daily-air-temp/north_america"
            elif lat < 35 and lon > -120 and lon < -30:  # Latin America  
                return "projects/sat-io/open-datasets/global-daily-air-temp/latin_america"
            elif lat > 30 and lon > -15 and lon < 180:  # Europe & Asia
                return "projects/sat-io/open-datasets/global-daily-air-temp/europe_asia"
            elif lat < 40 and lon > -20 and lon < 55:  # Africa
                return "projects/sat-io/open-datasets/global-daily-air-temp/africa"
            elif lat < -5 and lon > 110 and lon < 180:  # Australia
                return "projects/sat-io/open-datasets/global-daily-air-temp/australia"
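            else:
                # No box matched: fall back to North America, which may not cover the ROI
                print('   ⚠️ ROI center is outside all regional boxes; defaulting to north_america')
                return "projects/sat-io/open-datasets/global-daily-air-temp/north_america"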
        
        # Function to get the temperature collection (raw values; scaling to °C happens after getRegion)
        def get_temperature_collection(region_geom, start_date, end_date, temp_type='tmax'):
            """Get the daily air temperature collection for the region; end_date is exclusive"""
            collection_id = get_region_collection(region_geom)
            print(f'   📡 Using collection: {collection_id.split("/")[-1]}')
            
            collection = ee.ImageCollection(collection_id)
            
            # Filter by date, bounds, and temperature type using prop_type metadata
            filtered_collection = (collection.filterDate(start_date, end_date)
                                 .filterBounds(region_geom)
                                 .filter(ee.Filter.eq('prop_type', temp_type)))
            
            # Select and clip only; raw values are converted to °C after getRegion.
            # copyProperties returns a generic Element, so cast back to ee.Image.
            temp_collection = filtered_collection.map(lambda img:
                ee.Image(img.select('b1')
                         .clip(region_geom)
                         .copyProperties(img, ['system:time_start']))
            )
            
            return temp_collection
        
        # Test pixel count at 1km resolution
        test_collection = get_temperature_collection(analysis_geom, f'{year}-01-01', f'{year}-01-02', 'tmax')
        
        if test_collection.size().getInfo() == 0:
            print('❌ No images found for test date - check ROI coverage')
            return False
        
        first_image = test_collection.first()
        pixel_count = first_image.select('b1').reduceRegion(
            reducer=ee.Reducer.count(),
            geometry=analysis_geom,
            scale=1000,  # 1km resolution
            maxPixels=1e9
        ).getInfo()
        
        expected_pixels = pixel_count.get('b1', 0)
        print(f'   🔍 Expected pixels per image at 1km: {expected_pixels}')
        
        if expected_pixels == 0:
            print('❌ No pixels found in ROI - check if ROI overlaps with data coverage')
            return False
        
        # Calculate years to extract
        years_to_extract = list(range(ref_start, ref_end + 1)) + [year]
        years_to_extract = sorted(list(set(years_to_extract)))  # Remove duplicates and sort
        
        print(f'   📅 Will extract {len(years_to_extract)} years: {years_to_extract}')
        print(f'   🎯 Using temporal chunking to maintain 1km resolution')
        
        # Extract data year by year, recording the scale used for each year
        all_dataframes = []
        scales_used = {}
        
        for extract_year in years_to_extract:
            print(f'\n   📅 Extracting year {extract_year}...')
            
            # filterDate's end date is exclusive: end on Jan 1 of the next year so Dec 31 is kept
            year_collection = get_temperature_collection(
                analysis_geom, f'{extract_year}-01-01', f'{extract_year + 1}-01-01', 'tmax'
            )
            
            year_size = year_collection.size().getInfo()
            estimated_values = expected_pixels * year_size
            
            print(f'      Images: {year_size}, Estimated values: {estimated_values:,}')
            
            if estimated_values > 900000:  # Safety margin below GEE's getRegion value limit
                print('      ⚠️ Too large for one request at 1km, using 2km scale')
                scale = 2000
            else:
                print('      ✅ Using 1km scale')
                scale = 1000
            
            try:
                region_data = year_collection.getRegion(
                    geometry=analysis_geom,
                    scale=scale,
                    crs='EPSG:4326'
                ).getInfo()
                
                print(f'      ✅ Extracted {len(region_data)} rows')
                
                if len(region_data) > 1:  # More than just the header row
                    header = region_data[0]
                    data = region_data[1:]
                    
                    df_year = pd.DataFrame(data, columns=header)
                    df_year['time'] = pd.to_datetime(df_year['time'], unit='ms')
                    
                    # Convert raw values to °C and rename the column before dropping nulls
                    if 'b1' in df_year.columns:
                        df_year['temperature'] = pd.to_numeric(df_year['b1'], errors='coerce') / 10.0
                        df_year = df_year.drop(columns=['b1'])
                    
                    # Drop masked pixels (null temperatures)
                    df_year = df_year.dropna(subset=['temperature'])
                    
                    print(f'      📊 Valid observations: {len(df_year)}')
                    
                    if len(df_year) > 0:
                        all_dataframes.append(df_year)
                        scales_used[extract_year] = scale
                
            except Exception as e:
                print(f'      ❌ Failed to extract {extract_year}: {e}')
                continue
        
        if not all_dataframes:
            print('❌ No data extracted for any year')
            return False
        
        # Combine all years
        print(f'\n   🔗 Combining {len(all_dataframes)} years of data...')
        df = pd.concat(all_dataframes, ignore_index=True)
        
        print(f'   📊 Total combined data: {len(df)} observations')
        
        unique_pixels = df[['longitude', 'latitude']].drop_duplicates()
        print(f'   📍 Unique spatial pixels: {len(unique_pixels)}')
        print(f'   🌡️ Temperature range: {df["temperature"].min():.1f}°C to {df["temperature"].max():.1f}°C')
        
        # Report every scale used; mixed scales produce pixel grids that do not align
        resolutions = sorted(set(scales_used.values()))
        res_label = ', '.join(f'{s}m' for s in resolutions)
        print(f'   📐 Resolution(s) used: {res_label}')
        if len(resolutions) > 1:
            print('   ⚠️ Years were extracted at different scales; their grids will not align in xarray')
        
        # Show sample of the data
        print('\n📋 Sample of extracted data:')
        print(df[['time', 'latitude', 'longitude', 'temperature']].head())
        
        # Convert to xarray
        try:
            temperature_data = df.set_index(['time', 'latitude', 'longitude']).to_xarray()
            
            print('\n✅ Xarray dataset created successfully!')
            print(f'   📅 Time range: {temperature_data.time.min().values} to {temperature_data.time.max().values}')
            print(f'   🌍 Spatial dimensions: {temperature_data.sizes["latitude"]} × {temperature_data.sizes["longitude"]} pixels')
            print(f'   📊 Total observations: {temperature_data.temperature.count().values}')
            print(f'   🎯 Dataset: GSHTD Daily Air Temperature at {res_label} resolution')
            
            return True
            
        except Exception as e:
            print(f'❌ Error converting to xarray: {e}')
            print('   Raw DataFrame saved as backup for debugging')
            globals()['debug_df'] = df
            return False
        
    except Exception as e:
        print(f'❌ Error extracting data: {e}')
        import traceback
        print(f'   Details: {traceback.format_exc()}')
        return False

extract_button = widgets.Button(description='🔄 Extract Data', button_style='primary')
extract_button.on_click(lambda b: extract_temperature_data())

display(extract_button)
print('🔄 Ready to extract pixel-level temperature data using temporal chunking')
print('🎯 Maintains 1km resolution for intra-urban analysis')
print('⚡ Extracts year-by-year to avoid GEE limits')

🔄 Ready to extract pixel-level temperature data using temporal chunking
🎯 Maintains 1km resolution for intra-urban analysis
⚡ Extracts year-by-year to avoid GEE limits
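Step 3 falls back to a fixed 2 km scale when a year's estimate exceeds 900,000 values. That figure is the notebook's safety margin under Earth Engine's cap on the values one `getRegion` call may return. Pixel count falls roughly with the square of the pixel size, so 2 km only divides the estimate by about four. Years above roughly 3.6 million estimated values would therefore still fail.

The sketch below instead picks the smallest sufficient multiple of 1 km. `choose_scale` is hypothetical. It assumes the ROI is large compared with a pixel, so the square-law estimate holds.

```python
import math


def choose_scale(pixels_at_base: int, n_images: int,
                 base_scale: int = 1000, limit: int = 900_000) -> int:
    """Return the smallest multiple of base_scale (m) whose estimated value count fits under limit.

    Pixel count scales with 1 / scale**2, so an overshoot of k times needs a
    scale factor of ceil(sqrt(k)).
    """
    estimated = pixels_at_base * n_images
    if estimated <= limit:
        return base_scale
    factor = math.ceil(math.sqrt(estimated / limit))
    return base_scale * factor


for pixels in (2_000, 3_000, 20_000):
    print(f"{pixels:>6} px × 365 days -> {choose_scale(pixels, 365)} m")
```

Whatever rule is used, applying one scale to every year keeps the pixel grids aligned. That matters when the yearly DataFrames are concatenated and pivoted into xarray.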


## 🔍 Step 3.5: Data Exploration

In [5]:
def explore_temperature_data():
    '''Explore the extracted temperature xarray dataset with inline visualization'''
    global temperature_data
    
    if temperature_data is None:
        print('❌ Please extract temperature data first!')
        return
    
    try:
        print('🔍 EXPLORING TEMPERATURE XARRAY DATASET')
        print('='*50)
        
        # Dataset overview
        print('📊 DATASET OVERVIEW:')
        print(f'   Data variables: {list(temperature_data.data_vars)}')
        print(f'   Coordinates: {list(temperature_data.coords)}')
        print(f'   Dimensions: {dict(temperature_data.sizes)}')
        print(f'   Size in memory: {temperature_data.nbytes / 1024**2:.1f} MB')
        
        # Check if we have valid spatial dimensions
        if 'latitude' not in temperature_data.dims or 'longitude' not in temperature_data.dims:
            print('❌ Missing spatial dimensions (latitude/longitude)')
            return
            
        if len(temperature_data.latitude) == 0 or len(temperature_data.longitude) == 0:
            print('❌ Empty spatial dimensions')
            return
        
        # Spatial coverage
        lats = temperature_data.latitude.values
        lons = temperature_data.longitude.values
        print(f'\n🌍 SPATIAL COVERAGE:')
        print(f'   Latitude range: {lats.min():.3f} to {lats.max():.3f}')
        print(f'   Longitude range: {lons.min():.3f} to {lons.max():.3f}')
        print(f'   Grid cells (lat × lon, including cells without data): {len(lats) * len(lons)}')
        
        # Temperature statistics
        temp_vals = temperature_data.temperature.values
        temp_vals_clean = temp_vals[~np.isnan(temp_vals)]
        
        if len(temp_vals_clean) == 0:
            print('❌ No valid temperature values found')
            return
            
        print(f'\n🌡️ TEMPERATURE STATISTICS:')
        print(f'   Valid values: {len(temp_vals_clean):,} of {temp_vals.size:,} total')
        print(f'   Temperature range: {temp_vals_clean.min():.1f}°C to {temp_vals_clean.max():.1f}°C')
        print(f'   Mean temperature: {temp_vals_clean.mean():.1f}°C')
        
        # Check thresholds
        abs_thresh = absolute_threshold.value
        pct_thresh = percentile_threshold.value
        overall_percentile = np.percentile(temp_vals_clean, pct_thresh)
        
        print(f'\n⚠️ THRESHOLD ANALYSIS:')
        print(f'   Absolute threshold: {abs_thresh}°C')
        print(f'   {pct_thresh}th percentile: {overall_percentile:.1f}°C')
        
        values_above_abs = (temp_vals_clean > abs_thresh).sum()
        values_above_pct = (temp_vals_clean > overall_percentile).sum()
        
        print(f'   Values above {abs_thresh}°C: {values_above_abs:,} ({values_above_abs/len(temp_vals_clean)*100:.1f}%)')
        print(f'   Values above {pct_thresh}th percentile: {values_above_pct:,} ({values_above_pct/len(temp_vals_clean)*100:.1f}%)')
        
        if values_above_abs < 10 and values_above_pct < 10:
            print('   🚨 WARNING: Very few values exceed your thresholds!')
            print('   🚨 Results may have many zero heat days (this may be scientifically correct)')
        
        # Create inline visualization with error handling
        print('\n📊 CREATING EXPLORATION PLOTS...')
        
        try:
            fig, axes = plt.subplots(2, 2, figsize=(15, 10))
            
            # Temperature histogram
            # At least one bin, even when there are fewer than 10 values
            axes[0,0].hist(temp_vals_clean, bins=max(1, min(50, len(temp_vals_clean)//10)), alpha=0.7, color='skyblue', edgecolor='black')
            axes[0,0].axvline(abs_thresh, color='red', linestyle='--', linewidth=2, label=f'Absolute threshold ({abs_thresh}°C)')
            axes[0,0].axvline(overall_percentile, color='orange', linestyle='--', linewidth=2, label=f'{pct_thresh}th percentile ({overall_percentile:.1f}°C)')
            axes[0,0].set_title('Temperature Distribution')
            axes[0,0].set_xlabel('Temperature (°C)')
            axes[0,0].set_ylabel('Frequency')
            axes[0,0].legend()
            axes[0,0].grid(True, alpha=0.3)
            
            # Sample time series - handle single pixel case
            if len(lats) > 0 and len(lons) > 0:
                sample_lat = lats[len(lats)//2] if len(lats) > 1 else lats[0]
                sample_lon = lons[len(lons)//2] if len(lons) > 1 else lons[0]
                sample_pixel = temperature_data.sel(latitude=sample_lat, longitude=sample_lon, method='nearest')
                
                if len(sample_pixel.temperature.dropna('time')) > 0:
                    sample_pixel.temperature.plot(ax=axes[0,1], linewidth=1, alpha=0.7)
                    axes[0,1].axhline(abs_thresh, color='red', linestyle='--', alpha=0.7, label=f'Threshold ({abs_thresh}°C)')
                    axes[0,1].set_title(f'Sample Pixel Time Series\n({sample_lat:.3f}°, {sample_lon:.3f}°)')
                    axes[0,1].set_ylabel('Temperature (°C)')
                    axes[0,1].legend()
                    axes[0,1].grid(True, alpha=0.3)
                else:
                    axes[0,1].text(0.5, 0.5, 'No valid data\nfor sample pixel', ha='center', va='center', transform=axes[0,1].transAxes)
                    axes[0,1].set_title('Sample Pixel Time Series - No Data')
            
            # Spatial temperature mean - handle single pixel case
            temp_mean = temperature_data.temperature.mean(dim='time')
            if len(lats) > 1 and len(lons) > 1:
                temp_mean.plot(ax=axes[1,0], cmap='RdYlBu_r', add_colorbar=True, cbar_kwargs={'label': 'Mean Temperature (°C)'})
                axes[1,0].set_title('Spatial Mean Temperature')
                axes[1,0].set_xlabel('Longitude')
                axes[1,0].set_ylabel('Latitude')
            else:
                # For single pixel, show as text
                mean_temp = temp_mean.values.item() if temp_mean.size == 1 else np.nanmean(temp_mean.values)
                axes[1,0].text(0.5, 0.5, f'Single Pixel\nMean: {mean_temp:.1f}°C', ha='center', va='center', transform=axes[1,0].transAxes, fontsize=14)
                axes[1,0].set_title('Spatial Mean Temperature')
            
            # Monthly temperature cycle
            try:
                monthly_temps = temperature_data.temperature.groupby('time.month').mean()
                monthly_avg = monthly_temps.mean(dim=['latitude', 'longitude'])
                monthly_avg.plot(ax=axes[1,1], marker='o', linewidth=2)
                axes[1,1].axhline(abs_thresh, color='red', linestyle='--', alpha=0.7, label=f'Threshold ({abs_thresh}°C)')
                axes[1,1].set_title('Monthly Temperature Cycle')
                axes[1,1].set_xlabel('Month')
                axes[1,1].set_ylabel('Temperature (°C)')
                axes[1,1].legend()
                axes[1,1].grid(True, alpha=0.3)
            except Exception as e:
                axes[1,1].text(0.5, 0.5, f'Error creating\nmonthly plot:\n{str(e)[:50]}', ha='center', va='center', transform=axes[1,1].transAxes)
                axes[1,1].set_title('Monthly Temperature Cycle - Error')
            
            plt.tight_layout()
            plt.show()
            
        except Exception as e:
            print(f'❌ Error creating plots: {e}')
            import traceback
            print(f'   Details: {traceback.format_exc()}')
        
        print('\n✅ Data exploration complete!')
        
    except Exception as e:
        print(f'❌ Error in data exploration: {e}')
        import traceback
        print(f'   Details: {traceback.format_exc()}')

explore_button = widgets.Button(description='🔍 Explore Data', button_style='info')
explore_button.on_click(lambda b: explore_temperature_data())

display(explore_button)
print('🔍 Click to explore your extracted xarray dataset with inline plots')
print('🛡️ Enhanced with better error handling for small datasets')

🔍 Click to explore your extracted xarray dataset with inline plots
🛡️ Enhanced with better error handling for small datasets
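The exploration step computes a single percentile over every pixel and every year (`np.percentile(temp_vals_clean, pct_thresh)`). Step 4 instead computes one percentile per pixel, over the reference period only. In a heterogeneous city the two differ: the pooled value sits above a cool pixel's own percentile and below a hot pixel's. The synthetic NumPy sketch below (made-up data, two pixels) shows the difference. The pooled figure from exploration is therefore only a rough guide to the thresholds applied later.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic daily maxima, shaped (days, pixels): one cool pixel and one hot pixel
cool = rng.normal(28.0, 2.0, size=3650)
hot = rng.normal(34.0, 2.0, size=3650)
temps = np.stack([cool, hot], axis=1)

pooled_p95 = np.percentile(temps, 95)             # one value for the whole ROI (as in exploration)
per_pixel_p95 = np.percentile(temps, 95, axis=0)  # one value per pixel (as in Step 4)

print(f"pooled 95th percentile: {pooled_p95:.1f} °C")
print("per-pixel 95th percentiles:", np.round(per_pixel_p95, 1))
```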


## 📈 Step 4: Analysis with Inline Visualization
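Step 4 flags a heat day wherever the analysis-year temperature exceeds the stricter of two thresholds:

- the pixel's reference-period percentile;
- the absolute threshold.

In the notebook this is `xr.where(reference_percentile > abs_threshold, reference_percentile, abs_threshold)`. The plain-NumPy sketch below applies the same rule to a (time, pixel) array, using small made-up numbers so the result can be checked by hand. `count_heat_days` is a hypothetical helper. It assumes `np.nanpercentile`'s default linear interpolation matches xarray's `quantile` default.

```python
import numpy as np


def count_heat_days(reference: np.ndarray, analysis: np.ndarray,
                    pct: float, abs_threshold: float) -> np.ndarray:
    """Count, per pixel, analysis days above max(reference percentile, abs_threshold).

    Both arrays are shaped (time, pixel); NaN marks missing observations.
    """
    ref_pct = np.nanpercentile(reference, pct, axis=0)
    # Use the percentile only where it is stricter than the absolute threshold
    threshold = np.where(ref_pct > abs_threshold, ref_pct, abs_threshold)
    # NaN > x is False, so missing days are never counted as heat days
    return (analysis > threshold).sum(axis=0)


reference = np.array([[30.0, 36.0],
                      [31.0, 37.0],
                      [32.0, 38.0],
                      [33.0, 39.0]])   # 4 reference days × 2 pixels
analysis = np.array([[34.0, 38.5],
                     [36.0, 39.5],
                     [np.nan, 40.0]])  # 3 analysis days × 2 pixels

print(count_heat_days(reference, analysis, pct=75, abs_threshold=35.0))
# → [1 3]
```

In the first pixel, the 75th percentile (32.25 °C) is lower than 35 °C, so the absolute threshold applies. In the second pixel, the percentile (38.25 °C) is the stricter value and is used instead.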

In [6]:
def calculate_and_visualize_metrics():
    '''Calculate climate metrics and create inline visualizations'''
    global temperature_data, analysis_results
    
    if temperature_data is None:
        print('❌ Please extract temperature data first!')
        return None
    
    try:
        print('📈 Calculating climate metrics with inline visualization...')
        
        year = analysis_year.value
        ref_start = reference_start.value
        ref_end = reference_end.value
        abs_threshold = absolute_threshold.value
        pct_threshold = percentile_threshold.value
        
        results = {}
        
        # Filter data
        analysis_data = temperature_data.sel(time=str(year))
        reference_data = temperature_data.sel(time=slice(f'{ref_start}-01-01', f'{ref_end}-12-31'))
        
        print(f'   📅 Analysis year: {len(analysis_data.time)} days')
        print(f'   📅 Reference period: {len(reference_data.time)} days')
        
        # DEBUG: Check analysis data
        print(f'\n🔍 DEBUG - Analysis data for {year}:')
        print(f'   Analysis data shape: {analysis_data.temperature.shape}')
        print(f'   Analysis temp range: {analysis_data.temperature.min().values:.1f} to {analysis_data.temperature.max().values:.1f}°C')
        print(f'   Analysis data sample dates: {analysis_data.time[:3].values}')
        
        # Calculate reference percentile
        print(f'   🧮 Calculating {pct_threshold}th percentile...')
        reference_percentile = reference_data.temperature.quantile(pct_threshold/100, dim='time')
        results['reference_percentile'] = reference_percentile
        
        # Heat Days calculation with DEBUG
        print('   🔥 Calculating heat days...')
        threshold = xr.where(
            reference_percentile > abs_threshold,
            reference_percentile,
            abs_threshold
        )
        
        # DEBUG: Check threshold calculation
        print(f'\n🔍 DEBUG - Threshold calculation:')
        print(f'   Reference percentile range: {reference_percentile.min().values:.1f} to {reference_percentile.max().values:.1f}°C')
        print(f'   Absolute threshold: {abs_threshold}°C')
        print(f'   Final threshold range: {threshold.min().values:.1f} to {threshold.max().values:.1f}°C')
        
        # DEBUG: Check the comparison
        print(f'\n🔍 DEBUG - Heat days calculation:')
        temp_above_threshold = analysis_data.temperature > threshold
        print(f'   Boolean comparison shape: {temp_above_threshold.shape}')
        print(f'   Days above threshold (before sum): {temp_above_threshold.sum().values} total instances')
        
        # Check a specific pixel manually
        if len(analysis_data.latitude) > 0 and len(analysis_data.longitude) > 0:
            test_lat = analysis_data.latitude.values[0]
            test_lon = analysis_data.longitude.values[0] 
            test_temps = analysis_data.temperature.sel(latitude=test_lat, longitude=test_lon)
            test_threshold = threshold.sel(latitude=test_lat, longitude=test_lon)
            test_above = (test_temps > test_threshold).sum()
            
            print(f'   Test pixel ({test_lat:.3f}, {test_lon:.3f}):')
            print(f'     Temperature range: {test_temps.min().values:.1f} to {test_temps.max().values:.1f}°C')
            print(f'     Threshold: {test_threshold.values:.1f}°C')
            print(f'     Days above threshold: {test_above.values}')
        
        heat_days = (analysis_data.temperature > threshold).sum(dim='time').fillna(0)
        print(f'   Final heat days range: {heat_days.min().values} to {heat_days.max().values}')
        
        results['heat_days'] = heat_days
        results['threshold_used'] = threshold
        
        # Temperature Trends
        print('   📈 Calculating temperature trends...')
        trends = reference_data.temperature.polyfit(dim='time', deg=1)
        trend_slope = trends.polyfit_coefficients.sel(degree=1)
        ns_per_year = 365.25 * 24 * 60 * 60 * 1e9
        trend_per_year = (trend_slope * ns_per_year).fillna(0)
        results['temperature_trend'] = trend_per_year
        
        # Seasonal Means
        print('   🌅 Calculating seasonal means...')
        seasonal_means = analysis_data.temperature.groupby('time.season').mean()
        results['seasonal_means'] = seasonal_means
        
        # Annual extremes
        print('   🌡️ Calculating annual extremes...')
        # Keep NaN for masked pixels so they are not reported as 0 °C in the summary or maps
        annual_max = analysis_data.temperature.max(dim='time')
        annual_min = analysis_data.temperature.min(dim='time')
        annual_range = annual_max - annual_min
        
        results['annual_max'] = annual_max
        results['annual_min'] = annual_min
        results['annual_range'] = annual_range
        
        print('\n✅ Climate metrics calculated successfully!')
        
        # Print summary
        print('\n📊 RESULTS SUMMARY:')
        print(f'   Mean heat days: {heat_days.mean().values:.1f}')
        print(f'   Max heat days: {heat_days.max().values:.0f}')
        print(f'   Pixels with >0 heat days: {(heat_days > 0).sum().values} of {heat_days.count().values}')
        print(f'   Mean temperature trend: {trend_per_year.mean().values:.3f} °C/year')
        
        # Create comprehensive inline visualization
        print('\n📊 CREATING RESULTS VISUALIZATION...')
        
        try:
            fig, axes = plt.subplots(2, 3, figsize=(18, 12))
            
            # Heat Days Map
            heat_days.plot(ax=axes[0,0], cmap='Reds', add_colorbar=True, cbar_kwargs={'label': 'Heat Days'})
            axes[0,0].set_title('Heat Days per Pixel', fontweight='bold')
            
            # Temperature Trends
            trend_per_year.plot(ax=axes[0,1], cmap='RdBu_r', add_colorbar=True, cbar_kwargs={'label': '°C/year'})
            axes[0,1].set_title('Temperature Trends (°C/year)', fontweight='bold')
            
            # Annual Temperature Range
            annual_range.plot(ax=axes[0,2], cmap='viridis', add_colorbar=True, cbar_kwargs={'label': '°C'})
            axes[0,2].set_title('Annual Temperature Range (°C)', fontweight='bold')
            
            # Time Series
            daily_avg = analysis_data.temperature.mean(dim=['latitude', 'longitude'])
            daily_avg.plot(ax=axes[1,0], linewidth=1.5, color='blue')
            axes[1,0].axhline(abs_threshold, color='red', linestyle='--', alpha=0.7, label=f'Absolute threshold ({abs_threshold}°C)')
            axes[1,0].set_title(f'Daily Average Temperature - {year}', fontweight='bold')
            axes[1,0].set_ylabel('Temperature (°C)')
            axes[1,0].legend()
            axes[1,0].grid(True, alpha=0.3)
            
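            # Seasonal Means (groupby sorts seasons alphabetically; reorder to calendar order)
            seasonal_avg = seasonal_means.mean(dim=['latitude', 'longitude']).reindex(season=['DJF', 'MAM', 'JJA', 'SON'])
            # DataArray.plot has no bar method, so draw the bars with matplotlib directly
            axes[1,1].bar(seasonal_avg.season.values, seasonal_avg.values, color=['lightblue', 'lightgreen', 'orange', 'lightcoral'])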
            axes[1,1].set_title('Seasonal Temperature Means', fontweight='bold')
            axes[1,1].set_ylabel('Temperature (°C)')
            axes[1,1].tick_params(axis='x', rotation=45)
            
            # Heat Days Distribution
            heat_days_flat = heat_days.values.flatten()
            heat_days_clean = heat_days_flat[~np.isnan(heat_days_flat)]
            axes[1,2].hist(heat_days_clean, bins=20, alpha=0.7, color='red', edgecolor='black')
            axes[1,2].set_title('Heat Days Distribution', fontweight='bold')
            axes[1,2].set_xlabel('Heat Days per Pixel')
            axes[1,2].set_ylabel('Frequency')
            axes[1,2].grid(True, alpha=0.3)
            
            plt.tight_layout()
            plt.show()
            
        except Exception as viz_error:
            print(f'⚠️ Visualization error (results still calculated): {viz_error}')
        
        print('\n✅ Analysis and visualization complete!')
        
        return results
        
    except Exception as e:
        print(f'❌ Error calculating metrics: {e}')
        import traceback
        print(f'   Details: {traceback.format_exc()}')
        return None

# Analyze button: stores the returned results in the global analysis_results for export
analyze_button = widgets.Button(description='📈 Calculate & Visualize', button_style='success')
analysis_results = None

def run_analysis_with_viz(button):
    global analysis_results  # Make sure this is global
    print('🔄 Starting analysis...')
    analysis_results = calculate_and_visualize_metrics()
    if analysis_results is not None:
        print('✅ Analysis results stored successfully!')
        print('📁 Ready for export')
    else:
        print('❌ Analysis failed - results not available for export')

analyze_button.on_click(run_analysis_with_viz)

display(analyze_button)
print('📈 Ready to calculate climate metrics with inline visualization')
print('🔧 Fixed to properly store results for export')
print('🐛 Added debug output to trace heat days calculation issue')


📈 Ready to calculate climate metrics with inline visualization
🔧 Fixed to properly store results for export
🐛 Added debug output to trace heat days calculation issue
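Step 4 plots the seasonal means as bars with a fixed colour list. `groupby('time.season')` returns seasons in sorted label order (DJF, JJA, MAM, SON), not calendar order, so colours intended for winter, spring, summer, autumn can land on the wrong seasons. A small sketch on a synthetic one-year series (the cosine seasonal cycle is an assumption for illustration) showing the order and a `reindex` to calendar order:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Illustrative daily series with a Northern Hemisphere-like cycle peaking near day 200
time = pd.date_range('2023-01-01', '2023-12-31', freq='D')
temps = xr.DataArray(
    20 + 10 * np.cos(2 * np.pi * (time.dayofyear.values - 200) / 365.25),
    coords={'time': time},
    dims='time',
)

seasonal = temps.groupby('time.season').mean()
groupby_order = [str(s) for s in seasonal.season.values]  # sorted labels, not calendar order
print(groupby_order)

# Reorder to calendar order before plotting or assigning colours
calendar_order = ['DJF', 'MAM', 'JJA', 'SON']
seasonal = seasonal.reindex(season=calendar_order)
print(seasonal.round(1).values)
```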


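The trend step relies on `polyfit` expressing a `datetime64` coordinate as nanoseconds, so the raw slope is in °C per nanosecond and the `ns_per_year` factor rescales it to °C/year. A quick check on a synthetic series with a known warming rate of 0.05 °C/year (an assumed value for illustration) should recover that rate:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Illustrative series with a known warming rate of 0.05 °C/year
time = pd.date_range('2014-01-01', '2023-12-31', freq='D')
years_elapsed = (time - time[0]) / pd.Timedelta(days=365.25)
temps = xr.DataArray(20 + 0.05 * np.asarray(years_elapsed), coords={'time': time}, dims='time')

# polyfit converts the datetime64 coordinate to nanoseconds, so the slope is in °C/ns
fit = temps.polyfit(dim='time', deg=1)
slope_per_ns = fit.polyfit_coefficients.sel(degree=1)

# Rescale to °C/year with the same Julian-year constant used in Step 4
ns_per_year = 365.25 * 24 * 60 * 60 * 1e9
trend_per_year = float(slope_per_ns * ns_per_year)
print(f'{trend_per_year:.4f} °C/year')
```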
## 📁 Step 5: Export Results to ../outputs

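The cell below builds the pixel table with a nested Python loop over latitude and longitude. For larger grids, an alternative (not what this notebook does) is to let xarray flatten all 2-D metrics at once with `Dataset.to_dataframe`. A minimal sketch with a hypothetical 2×3 results grid named `example_results`; unlike the notebook, it leaves masked pixels as NaN rather than writing 0.0:

```python
import numpy as np
import xarray as xr

# Hypothetical 2x3 results grid standing in for analysis_results
coords = {'latitude': [10.0, 10.1], 'longitude': [20.0, 20.1, 20.2]}
dims = ('latitude', 'longitude')
example_results = {
    'heat_days': xr.DataArray([[0, 3, 5], [1, 0, 2]], coords=coords, dims=dims),
    'annual_max': xr.DataArray([[36.1, np.nan, 38.0], [35.2, 34.9, 37.5]], coords=coords, dims=dims),
}

# Keep only plain (latitude, longitude) metrics, as the notebook's export does
spatial = {k: v for k, v in example_results.items() if set(v.dims) == {'latitude', 'longitude'}}

# One row per pixel, one column per metric; latitude is the outer index level
pixel_df = xr.Dataset(spatial).to_dataframe(dim_order=['latitude', 'longitude']).reset_index()

# pixel_id in the notebook's f'{i}_{j}' format (row-major, matching the index order above)
n_lon = len(coords['longitude'])
pixel_df.insert(0, 'pixel_id', [f'{k // n_lon}_{k % n_lon}' for k in range(len(pixel_df))])
print(pixel_df)
```

Calling `.fillna(0)` on the Dataset before `to_dataframe()` would reproduce the notebook's zero convention if that is required downstream.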
In [7]:
def export_results():
    '''Export analysis results to ../outputs directory with proper CRS'''
    global analysis_results, temperature_data
    
    if analysis_results is None:
        print('❌ Please run analysis first!')
        return
    
    try:
        print('📁 Exporting results to ../outputs directory...')
        year = analysis_year.value
        
        # Create outputs directory
        os.makedirs('../outputs', exist_ok=True)
        
        # 1. Summary Table
        print('   📄 Creating summary statistics table...')
        summary_data = []
        
        for metric_name, metric_data in analysis_results.items():
            if hasattr(metric_data, 'mean') and len(metric_data.dims) <= 2:
                try:
                    if metric_name != 'seasonal_means':
                        valid_count = metric_data.count().values
                        if valid_count > 0:
                            summary_data.append({
                                'metric': metric_name,
                                'valid_pixels': int(valid_count),
                                'mean': float(metric_data.mean().values),
                                'min': float(metric_data.min().values),
                                'max': float(metric_data.max().values),
                                'std': float(metric_data.std().values),
                                'median': float(metric_data.median().values)
                            })
                except Exception as e:
                    print(f'     ⚠️ Skipping {metric_name}: {e}')
        
        if summary_data:
            summary_df = pd.DataFrame(summary_data)
            summary_filename = f'../outputs/climate_summary_{year}.csv'
            summary_df.to_csv(summary_filename, index=False)
            print(f'   ✅ Summary saved: {summary_filename}')
            
            # Display summary table inline
            print('\n📊 SUMMARY STATISTICS:')
            display(summary_df)
        
        # 2. Full pixel-wise data export (coordinates taken from analysis_results)
        print('\n   📊 Creating full pixel-wise data export...')
        
        # Use coordinates from analysis_results instead of temperature_data
        sample_metric = next(iter(analysis_results.values()))
        if hasattr(sample_metric, 'latitude') and hasattr(sample_metric, 'longitude'):
            lats = sample_metric.latitude.values
            lons = sample_metric.longitude.values
        else:
            print('❌ No spatial coordinates found in analysis results')
            return
        
        total_pixels = len(lats) * len(lons)
        print(f'      Processing {total_pixels} pixels...')
        
        pixel_data = []
        processed = 0
        
        for i, lat in enumerate(lats):
            for j, lon in enumerate(lons):
                row = {
                    'pixel_id': f'{i}_{j}',
                    'latitude': float(lat),
                    'longitude': float(lon)
                }
                
                # Extract each metric's value at this pixel
                for metric_name, metric_data in analysis_results.items():
                    if hasattr(metric_data, 'sel') and hasattr(metric_data, 'dims'):
                        try:
                            if 'latitude' in metric_data.dims and 'longitude' in metric_data.dims:
                                if len(metric_data.dims) == 2:
                                    # Exact coordinate match by name (independent of dimension order).
                                    # .item() returns a Python scalar; .values would return a 0-d
                                    # array, which np.isscalar() rejects
                                    value = metric_data.sel(latitude=lat, longitude=lon).item()
                                    row[metric_name] = float(value) if not np.isnan(value) else 0.0
                                else:
                                    row[metric_name] = 0.0
                        except (KeyError, IndexError) as e:
                            # Coordinate not on this metric's grid (masked pixels are NaN, handled above)
                            row[metric_name] = 0.0
                        except Exception as e:
                            print(f'      ⚠️ Error extracting {metric_name} at ({lat}, {lon}): {e}')
                            row[metric_name] = 0.0
                    else:
                        row[metric_name] = 0.0
                
                pixel_data.append(row)
                processed += 1
                
                if processed % max(1, total_pixels // 10) == 0:
                    print(f'      Progress: {processed}/{total_pixels} ({processed/total_pixels*100:.0f}%)')
        
        if pixel_data:
            pixel_df = pd.DataFrame(pixel_data)
            pixel_filename = f'../outputs/climate_pixels_{year}.csv'
            pixel_df.to_csv(pixel_filename, index=False)
            
            print(f'\n   ✅ Full pixel data saved: {pixel_filename}')
            print(f'      Rows: {len(pixel_df):,}')
            print(f'      Columns: {len(pixel_df.columns)}')
            
            # Check for non-zero values
            non_zero_counts = {}
            for col in pixel_df.columns:
                if col not in ['pixel_id', 'latitude', 'longitude']:
                    non_zero_counts[col] = (pixel_df[col] != 0).sum()
            
            print(f'      Non-zero value counts:')
            for col, count in non_zero_counts.items():
                print(f'        {col}: {count}')
            
            # Show sample inline
            print('\n📊 SAMPLE PIXEL DATA:')
            display(pixel_df.head(10))
        
        # 3. NetCDF Export with proper CRS
        print('\n   📦 Creating NetCDF file with proper CRS...')
        
        spatial_results = {}
        for k, v in analysis_results.items():
            if hasattr(v, 'dims') and 'latitude' in v.dims and 'longitude' in v.dims:
                if len(v.dims) == 2:
                    spatial_results[k] = v.fillna(0)
        
        if spatial_results:
            results_ds = xr.Dataset(spatial_results)
            
            # Add proper CRS information (WGS84)
            results_ds.latitude.attrs['standard_name'] = 'latitude'
            results_ds.latitude.attrs['long_name'] = 'latitude'
            results_ds.latitude.attrs['units'] = 'degrees_north'
            results_ds.latitude.attrs['axis'] = 'Y'
            
            results_ds.longitude.attrs['standard_name'] = 'longitude'
            results_ds.longitude.attrs['long_name'] = 'longitude'
            results_ds.longitude.attrs['units'] = 'degrees_east'
            results_ds.longitude.attrs['axis'] = 'X'
            
            # Add CRS variable following CF conventions
            crs = xr.DataArray(
                data=np.int32(1),
                attrs={
                    'grid_mapping_name': 'latitude_longitude',
                    'longitude_of_prime_meridian': 0.0,
                    'semi_major_axis': 6378137.0,
                    'inverse_flattening': 298.257223563,
                    'spatial_ref': 'GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]]',
                    'crs_wkt': 'GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]]'
                }
            )
            results_ds['crs'] = crs
            
            # Add grid_mapping attribute to all data variables
            for var_name in results_ds.data_vars:
                if var_name != 'crs':
                    results_ds[var_name].attrs['grid_mapping'] = 'crs'
            
            # Add metadata
            results_ds.attrs['analysis_year'] = year
            results_ds.attrs['reference_period'] = f'{reference_start.value}-{reference_end.value}'
            results_ds.attrs['created'] = datetime.now().isoformat()
            results_ds.attrs['absolute_threshold'] = absolute_threshold.value
            results_ds.attrs['percentile_threshold'] = percentile_threshold.value
            results_ds.attrs['crs'] = 'EPSG:4326'
            
            netcdf_filename = f'../outputs/climate_analysis_{year}.nc'
            results_ds.to_netcdf(netcdf_filename)
            
            print(f'   ✅ NetCDF saved with proper CRS: {netcdf_filename}')
            print(f'      Variables: {list(results_ds.data_vars)}')
            print(f'      Dimensions: {dict(results_ds.sizes)}')
            print(f'      CRS: EPSG:4326 (WGS84)')
            print(f'      File size: {os.path.getsize(netcdf_filename) / 1024**2:.1f} MB')
        
        print('\n✅ Export complete to ../outputs directory!')
        print('\n📊 Files created:')
        print(f'   • ../outputs/climate_summary_{year}.csv - Summary statistics')
        print(f'   • ../outputs/climate_pixels_{year}.csv - ALL pixel values with coordinates')
        print(f'   • ../outputs/climate_analysis_{year}.nc - NetCDF spatial dataset WITH CRS')
        
    except Exception as e:
        print(f'❌ Error exporting: {e}')
        import traceback
        print(f'   Details: {traceback.format_exc()}')

export_button = widgets.Button(description='📁 Export to ../outputs', button_style='warning')
export_button.on_click(lambda b: export_results())

display(export_button)
print('📁 Ready to export full pixel-level results to ../outputs directory')
print('🌍 Fixed: NetCDF files now include proper WGS84 CRS information')
print('🔧 Fixed: CSV export now properly extracts values from analysis results')



📁 Ready to export full pixel-level results to ../outputs directory
🌍 Fixed: NetCDF files now include proper WGS84 CRS information
🔧 Fixed: CSV export now properly extracts values from analysis results


## 🎯 Summary

This notebook provides the complete xarray climate analysis workflow with:

### ✅ **Enhanced Features:**
- **File browser button** - Easy raster file selection (no typing paths)
- **Inline plots** - All visualizations appear in notebook cells
- **Drawing tools** - Interactive ROI selection on map
- **Data exploration** - Examine dataset before analysis
- **Full pixel export** - Every individual pixel value preserved
- **Organized output** - All files saved to ../outputs

### 📊 **Complete Workflow:**
1. **ROI Selection** - Draw, coordinates, or browse for raster
2. **Configure Analysis** - Set thresholds and time periods
3. **Extract Data** - Get time series from Google Earth Engine
4. **Explore Dataset** - Examine structure with inline plots
5. **Run Analysis** - Calculate metrics with inline visualization
6. **Export Results** - Save to ../outputs with inline preview

### 🚀 **Performance Benefits:**
- **Extract once, analyze many times** - No repeated GEE API calls
- **Fast local analysis** - xarray operations are vectorized
- **Easy iteration** - Re-run with different thresholds without re-extracting data
- **Complete spatial data** - Every pixel value exported

This version combines file-browser raster selection, interactive ROI drawing, inline visualization, and full pixel-level export in a single workflow.

In [None]:
# Notes to self: at 35 °C or the 90th percentile there are at most 6 extreme days per year, relative to a 10-year reference period and a 15-day moving window.
# The current implementation in this notebook calculates percentiles across the whole year, which produces false "hot periods":
# the colder winter days drag the percentile down. We actually need a moving window so that each day's percentile is calculated against
# the same time of year in the reference period. 5 days is quite short; perhaps let the user choose 10 or 15 days too?
# Open question: how useful is this kind of analysis if the user picks an analysis year that was not particularly hot?
# Is there a way of making the tool select a few recent hot years?
# Or is there a better metric that shows the current / most recent number of extreme heat days per year, plus the trend (i.e. the increase
# in extreme heat days), and also which year within the last 10 years had the most extreme heat days?
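The notes above describe replacing the whole-year percentile with one computed against the same time of year. One way to sketch that, assuming a daily `time` coordinate: for each day of year, pool all reference days within ±`half_window` days (wrapping across December/January), take the per-pixel quantile, then compare each analysis day with its own day-of-year threshold, still floored at the absolute threshold. The helper names `doy_percentile_threshold` and `count_heat_days`, the 15-day default window (`half_window=7`), and the synthetic data are illustrative assumptions; leap days are handled only approximately.

```python
import numpy as np
import pandas as pd
import xarray as xr


def doy_percentile_threshold(reference, q, half_window=7):
    """Hypothetical helper: per-pixel q-quantile for each day of year (1-366),
    pooling every reference day within +/- half_window days of it."""
    doy = reference.time.dt.dayofyear.values
    per_day = []
    for d in range(1, 367):
        diff = np.abs(doy - d)
        # Circular distance so late December and early January share a window
        # (approximate: assumes a 365-day year and ignores the leap-day offset)
        circular = np.minimum(diff, 365 - diff)
        pooled = reference.isel(time=np.flatnonzero(circular <= half_window))
        per_day.append(pooled.quantile(q, dim='time'))
    return xr.concat(per_day, dim=pd.Index(np.arange(1, 367), name='dayofyear'))


def count_heat_days(analysis, doy_threshold, abs_threshold):
    """Hypothetical helper: days above max(day-of-year percentile, absolute threshold)."""
    # Look up each analysis day's threshold by its day of year
    daily = doy_threshold.sel(dayofyear=analysis.time.dt.dayofyear)
    daily = xr.where(daily > abs_threshold, daily, abs_threshold)
    return (analysis > daily).sum(dim='time')


# Illustrative synthetic data: seasonal cycle peaking near day 196 plus noise
rng = np.random.default_rng(1)
time = pd.date_range('2013-01-01', '2023-12-31', freq='D')
cycle = 10 * np.sin(2 * np.pi * (time.dayofyear.values - 105) / 365.25)
temperature = xr.DataArray(
    25 + cycle[:, None, None] + rng.normal(0, 2, (time.size, 2, 2)),
    coords={'time': time, 'latitude': [10.0, 10.1], 'longitude': [20.0, 20.1]},
    dims=('time', 'latitude', 'longitude'),
)

reference = temperature.sel(time=slice('2013-01-01', '2022-12-31'))
analysis = temperature.sel(time='2023')

doy_threshold = doy_percentile_threshold(reference, q=0.9)
heat_days = count_heat_days(analysis, doy_threshold, abs_threshold=20.0)
print(heat_days.values)
```

With a seasonal threshold, a summer day must be unusually hot for summer to count, so winter days no longer pull the threshold down. Exposing `half_window` as a widget would let the user choose between the 5-, 10- and 15-day windows the notes mention.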