# In [None]:  (Jupyter notebook cell exported as a script)
import xarray as xr
import rasterio as rio
import geopandas as gpd
import rasterstats as rstats
import rioxarray
from rasterio.enums import Resampling
import numpy as np
import pandas as pd
import os
import glob
import gc
import warnings
warnings.filterwarnings('ignore')

print("Imported packages")

# Load shapefile
shp_fo = r'/Users/collinsmatiza/Documents/Analysis/protected_sites_with_srtm_elevation.shp'
print(f"Loading shapefile: {shp_fo}")

# NOTE: failures below end the run with `raise SystemExit(1)` rather than
# `exit(1)`. `exit` is a helper injected by the `site` module for interactive
# shells; it is not guaranteed to exist, and inside a notebook kernel it does
# not reliably stop the code that follows.
try:
    shp_df = gpd.read_file(shp_fo)
except Exception as e:
    print(f"Error loading shapefile: {e}")
    raise SystemExit(1) from e
print("Shapefile loaded")

# Filter for specific protected areas
target_areas = ['Garden Route National Park', 'Mphaphuli Protected Environment']
print(f"Filtering for target areas: {target_areas}")

# The area names live in the 'CUR_NME' column; report a wrong/missing column
# explicitly instead of failing with a bare KeyError.
if 'CUR_NME' not in shp_df.columns:
    print("Column 'CUR_NME' not found in shapefile. Available columns:", list(shp_df.columns))
    raise SystemExit(1)

filtered_shp = shp_df[shp_df['CUR_NME'].isin(target_areas)].copy()

if filtered_shp.empty:
    print("No matching protected areas found. Please check the area names and column name.")
    print("Available area names (first 10):", shp_df['CUR_NME'].head(10).tolist())
    raise SystemExit(1)
print(f"Found {len(filtered_shp)} matching protected areas")

# Convert shapefile to WGS84 if it's not already
print(f"Original shapefile CRS: {shp_df.crs}")
if shp_df.crs is None:
    # to_crs() cannot transform geometries that carry no CRS at all.
    print("Shapefile has no CRS defined; cannot convert to WGS84 (EPSG:4326).")
    raise SystemExit(1)
if shp_df.crs != 'EPSG:4326':
    print("Converting shapefile to WGS84 (EPSG:4326)...")
    filtered_shp = filtered_shp.to_crs('EPSG:4326')
    print("Conversion complete")

# Report each area's bounding box so it can be compared by eye with the data
# extent printed below (the actual overlap test happens in the yearly loop).
print("\nChecking protected area locations (WGS84 coordinates):")
for _, row in filtered_shp.iterrows():
    bounds = row.geometry.bounds
    print(f"{row['CUR_NME']}: lon {bounds[0]:.4f} to {bounds[2]:.4f}, lat {bounds[1]:.4f} to {bounds[3]:.4f}")

# Data extent from diagnostic: lon 16.25 to 32.75, lat -35.75 to -22.25
data_lon_range = (16.25, 32.75)
data_lat_range = (-35.75, -22.25)
print(f"Data extent: lon {data_lon_range[0]} to {data_lon_range[1]}, lat {data_lat_range[0]} to {data_lat_range[1]}")

# Define target resolution
TARGET_RESOLUTION = 0.025  # degrees
print(f"Target pixel resolution: {TARGET_RESOLUTION} degrees")

# Base path for NetCDF files
base_path = '/Users/collinsmatiza/Downloads/isimip3a/counterclim/20crv3/mon/'

# Find all precipitation files: try the strict '<var>_mon_*' pattern first and
# fall back to the looser '<var>_*mon*' pattern only when nothing matched.
# `raise SystemExit(1)` is used instead of `exit(1)`, which is an
# interactive-shell helper and does not reliably stop a notebook cell.
pr_files = sorted(
    glob.glob(os.path.join(base_path, 'pr_mon_*.nc'))
    or glob.glob(os.path.join(base_path, 'pr_*mon*.nc'))
)
if not pr_files:
    print(f"No precipitation files found in: {base_path}")
    raise SystemExit(1)

# Find all temperature files using the same two-step pattern match
tas_files = sorted(
    glob.glob(os.path.join(base_path, 'tas_mon_*.nc'))
    or glob.glob(os.path.join(base_path, 'tas_*mon*.nc'))
)
if not tas_files:
    print(f"No temperature files found in: {base_path}")
    raise SystemExit(1)

print(f"Found {len(pr_files)} precipitation files")
print(f"Found {len(tas_files)} temperature files")

# Dictionaries used to match precipitation and temperature files by year
pr_files_by_year = {}
tas_files_by_year = {}

def extract_year_from_filename(filename):
    """Return the year encoded in a file name, or None when there is none.

    Only the base name is inspected. It is split on underscores, any '.nc'
    is removed from each piece, and the first piece made of exactly four
    digits is returned as an int. Names without an underscore yield None.
    """
    name = os.path.basename(filename)
    if '_' not in name:
        return None

    tokens = (piece.replace('.nc', '') for piece in name.split('_'))
    return next(
        (int(token) for token in tokens if len(token) == 4 and token.isdigit()),
        None,
    )

# Group files by year, keeping only the 1982-2015 study period. If several
# files map to the same year, the last one in the (sorted) list wins.
for file_list, files_by_year in ((pr_files, pr_files_by_year),
                                 (tas_files, tas_files_by_year)):
    for climate_file in file_list:
        year = extract_year_from_filename(climate_file)
        if year is not None and 1982 <= year <= 2015:
            files_by_year[year] = climate_file

# Find common years
common_years = set(pr_files_by_year) & set(tas_files_by_year)
print(f"Common years with both precipitation and temperature data: {len(common_years)}")
print(f"Years: {sorted(common_years)}")

if not common_years:
    print("No matching years found between precipitation and temperature files!")
    # `exit` is an interactive-shell helper from the `site` module; raising
    # SystemExit stops the run both as a script and inside a notebook cell.
    raise SystemExit(1)

# Multipliers that turn a precipitation rate expressed in the keyed unit into
# mm/month. The source data is in kg m-2 s-1; a month is taken as 30.44 days
# on average.
_AVG_DAYS_PER_MONTH = 30.44
_PER_SECOND_TO_PER_MONTH = 86400 * _AVG_DAYS_PER_MONTH         # seconds/day * days/month
_M_PER_SECOND_TO_MM_PER_MONTH = 86400000 * _AVG_DAYS_PER_MONTH  # as above, times 1000 for m -> mm

conversion_factors = {
    **dict.fromkeys(('kg m-2 s-1', 'kg/m²/s', 'kg/m2/s'), _PER_SECOND_TO_PER_MONTH),
    **dict.fromkeys(('m s-1', 'm/s'), _M_PER_SECOND_TO_MM_PER_MONTH),
    **dict.fromkeys(('mm s-1', 'mm/s'), _PER_SECOND_TO_PER_MONTH),
    **dict.fromkeys(('mm/day', 'mm day-1'), _AVG_DAYS_PER_MONTH),  # daily -> monthly
    'mm/month': 1,
}

def resample_to_target_resolution(data_array, original_transform, target_resolution):
    """
    Bilinearly resample a 2-D grid onto a finer/coarser grid over the same bounds.

    Args:
        data_array: 2-D array (rows, cols) of source values.
        original_transform: Affine transform mapping (col, row) of
            ``data_array`` to coordinates; both grids are treated as EPSG:4326.
        target_resolution: Desired cell size, in the transform's units.

    Returns:
        Tuple ``(resampled_array, resampled_transform)``. The output keeps the
        source dtype and spans the same bounds; its dimensions are the extent
        divided by ``target_resolution``, rounded up.
    """
    from rasterio.warp import reproject, Resampling as RioResampling
    from rasterio.transform import from_bounds

    # Outer corners of the source grid: pixel (0, 0) and pixel (width, height).
    # NOTE(review): naming these west/north assumes a north-up transform.
    n_rows, n_cols = data_array.shape
    left, top = original_transform * (0, 0)
    right, bottom = original_transform * (n_cols, n_rows)

    # Size of the output grid and the transform that fits it to the same bounds
    out_cols = int(np.ceil((right - left) / target_resolution))
    out_rows = int(np.ceil((top - bottom) / target_resolution))
    out_transform = from_bounds(left, bottom, right, top, out_cols, out_rows)

    # Destination buffer, pre-filled with NaN
    resampled = np.empty((out_rows, out_cols), dtype=data_array.dtype)
    resampled[...] = np.nan

    reproject(
        source=data_array,
        destination=resampled,
        src_transform=original_transform,
        src_crs='EPSG:4326',
        dst_transform=out_transform,
        dst_crs='EPSG:4326',
        resampling=RioResampling.bilinear,
    )

    return resampled, out_transform

def process_climate_data(file_path, var_name, conversion_factor=1, target_units=''):
    """Read one NetCDF variable and resample every time step to TARGET_RESOLUTION.

    Args:
        file_path: Path of the NetCDF file, opened with xarray.
        var_name: Variable to read. For ``'pr'`` the values are multiplied by
            ``conversion_factor``; for ``'tas'`` it is added; any other
            variable is left unconverted.
        conversion_factor: Multiplier (``pr``) or additive offset (``tas``).
        target_units: Label handed back unchanged on success; it does not
            influence the conversion.

    Returns:
        ``(time_data_dict, target_units)`` where ``time_data_dict`` maps a
        ``'YYYY-MM-DD'`` string to ``{'data': 2-D array, 'affine': Affine}``.
        Time steps that are entirely NaN are omitted. On any error the problem
        is printed and ``({}, '')`` is returned (best-effort behaviour).
    """
    try:
        from rasterio.transform import from_bounds

        # The context manager closes the file even when processing fails.
        with xr.open_dataset(file_path) as ds:
            data_var = ds[var_name]

            # The file's units are reported for information only; the
            # conversion applied is whatever the caller passed in.
            units = data_var.attrs.get('units', '')
            print(f"    Original {var_name} units: {units}")

            # Normalise the spatial dimension names to x/y. Both the long
            # (longitude/latitude) and short (lon/lat) spellings are accepted;
            # data already using x/y is left alone.
            axis_aliases = {'longitude': 'x', 'lon': 'x', 'latitude': 'y', 'lat': 'y'}
            renames = {dim: axis_aliases[dim] for dim in data_var.dims if dim in axis_aliases}
            if renames:
                data_var = data_var.rename(renames)

            # Set CRS explicitly
            data_var = data_var.rio.write_crs("EPSG:4326")

            # Grid spacing from the first two coordinates (needs >= 2 cells per axis)
            lon_res = float(data_var.x[1] - data_var.x[0])
            lat_res = float(data_var.y[1] - data_var.y[0])

            # Coordinates are cell centres; push the bounds out to the cell edges
            west = float(data_var.x.min()) - lon_res / 2
            east = float(data_var.x.max()) + lon_res / 2
            south = float(data_var.y.min()) - abs(lat_res) / 2
            north = float(data_var.y.max()) + abs(lat_res) / 2

            original_affine = from_bounds(west, south, east, north, len(data_var.x), len(data_var.y))

            # from_bounds() yields a north-up transform: array row 0 sits at
            # `north`. Rows therefore have to run north -> south. Data whose
            # y coordinate ascends (south first) must be flipped; data whose y
            # coordinate already descends must NOT be.
            rows_run_south_to_north = lat_res > 0

            # Process each time step
            # NOTE(review): assumes each time slice is laid out as (y, x).
            time_data_dict = {}
            for time_idx, time_val in enumerate(data_var.time.values):
                date_str = pd.to_datetime(time_val).strftime('%Y-%m-%d')

                time_data = data_var.isel(time=time_idx)

                # Apply conversion
                if var_name == 'pr':
                    time_data = time_data * conversion_factor
                elif var_name == 'tas':
                    time_data = time_data + conversion_factor

                data_array = time_data.values
                if rows_run_south_to_north:
                    data_array = np.flipud(data_array)

                # Skip if all NaN
                if np.all(np.isnan(data_array)):
                    continue

                # Resample to target resolution
                resampled_array, resampled_affine = resample_to_target_resolution(
                    data_array, original_affine, TARGET_RESOLUTION
                )

                time_data_dict[date_str] = {
                    'data': resampled_array,
                    'affine': resampled_affine
                }

            return time_data_dict, target_units

    except Exception as e:
        print(f"    Error processing {var_name} data: {e}")
        return {}, ''

all_pixel_data = []   # One dict per (pixel, date) record
pixel_counter = 0     # Global pixel counter for sequential IDs
pixel_id_map = {}     # Maps (lon, lat) rounded to 4 decimals -> unique pixel ID

# Neither the polygons nor the data extent change between years, so decide
# once which protected areas can overlap the climate grid (bounding-box test).
areas_in_extent = []
rows_in_extent = []
for _, row in filtered_shp.iterrows():
    min_lon, min_lat, max_lon, max_lat = row.geometry.bounds
    if (min_lon <= data_lon_range[1] and max_lon >= data_lon_range[0] and
            min_lat <= data_lat_range[1] and max_lat >= data_lat_range[0]):
        areas_in_extent.append(row['CUR_NME'])
        rows_in_extent.append(row)

# Process each year
for year in sorted(common_years):
    pr_file = pr_files_by_year[year]
    tas_file = tas_files_by_year[year]

    print(f"\nProcessing year {year}:")
    print(f"  Precipitation file: {os.path.basename(pr_file)}")
    print(f"  Temperature file: {os.path.basename(tas_file)}")

    # Nothing to extract: skip before paying for the file processing.
    if not rows_in_extent:
        print(f"  WARNING: No protected areas overlap with data extent for year {year}")
        continue

    # Process precipitation data (kg m-2 s-1 -> mm/month)
    print("  Processing precipitation...")
    pr_data_dict, pr_units = process_climate_data(pr_file, 'pr',
                                                  conversion_factors.get('kg m-2 s-1', 86400 * 30.44),
                                                  'mm/month')

    # Process temperature data
    print("  Processing temperature...")
    # Temperature is typically in Kelvin, convert to Celsius by subtracting 273.15
    tas_data_dict, tas_units = process_climate_data(tas_file, 'tas', -273.15, 'Celsius')

    if not pr_data_dict or not tas_data_dict:
        print(f"  Skipping year {year} - missing data")
        continue

    # Dates present in both variables, sorted once up front
    common_dates = sorted(set(pr_data_dict) & set(tas_data_dict))
    print(f"  Processing {len(common_dates)} common time steps")

    for date_idx, date_str in enumerate(common_dates):
        # Progress message on every third date when there are many dates
        if len(common_dates) > 6 and date_idx % 3 == 0:
            print(f"    Processing date: {date_str}")

        pr_info = pr_data_dict[date_str]
        tas_info = tas_data_dict[date_str]

        for row in rows_in_extent:
            area_name = row['CUR_NME']
            # Best effort: a failure in one area/date is reported and skipped
            # so that a long extraction run is not lost.
            try:
                pr_stats = rstats.zonal_stats(
                    row.geometry,
                    pr_info['data'],
                    affine=pr_info['affine'],
                    raster_out=True,
                    nodata=np.nan,
                    all_touched=True
                )
                tas_stats = rstats.zonal_stats(
                    row.geometry,
                    tas_info['data'],
                    affine=tas_info['affine'],
                    raster_out=True,
                    nodata=np.nan,
                    all_touched=True
                )

                if (not pr_stats or pr_stats[0] is None or
                        not tas_stats or tas_stats[0] is None):
                    continue

                pr_arr = pr_stats[0].get('mini_raster_array')
                tas_arr = tas_stats[0].get('mini_raster_array')
                if pr_arr is None or tas_arr is None:
                    continue
                if pr_arr.shape != tas_arr.shape:
                    print(f"    Error in zonal stats for {area_name}: "
                          f"window shapes differ ({pr_arr.shape} vs {tas_arr.shape})")
                    continue
                pr_affine = pr_stats[0]['mini_raster_affine']

                # rasterstats returns the geometry's bounding-box window as a
                # masked array, with cells outside the polygon (or nodata)
                # masked. np.where()/np.nonzero() on a masked array look only
                # at the underlying data, so the mask has to be applied
                # explicitly; otherwise cells outside the protected area that
                # merely fall in its bounding box would be recorded.
                pr_values = np.ma.getdata(pr_arr)
                tas_values = np.ma.getdata(tas_arr)
                valid = ~(np.ma.getmaskarray(pr_arr) | np.ma.getmaskarray(tas_arr) |
                          np.isnan(pr_values) | np.isnan(tas_values))

                # Row-major order gives a deterministic pixel-ID assignment.
                valid_rows, valid_cols = np.nonzero(valid)
                if valid_rows.size == 0:
                    continue

                elevation = row.get('elevation', np.nan)
                biome = row.get('T_BIOME', 'NA')

                for r, c in zip(valid_rows, valid_cols):
                    # Pixel-centre coordinates straight from the affine
                    # transform; always plain scalars.
                    lon, lat = pr_affine * (c + 0.5, r + 0.5)

                    # Unique key for the pixel based on its coordinates
                    pixel_key = (round(float(lon), 4), round(float(lat), 4))
                    if pixel_key not in pixel_id_map:
                        pixel_id_map[pixel_key] = pixel_counter
                        pixel_counter += 1

                    # Record with both climate variables
                    all_pixel_data.append({
                        'pixel_id': pixel_id_map[pixel_key],
                        'NDVI': np.nan,  # Placeholder
                        'Human_Modification': np.nan,  # Placeholder
                        'date': date_str,
                        'area_name': area_name,
                        'elevation': elevation,
                        'T_BIOME': biome,
                        'Temperature_C': float(tas_values[r, c]),
                        'Precipitation_mm': float(pr_values[r, c])
                    })

            except Exception as e:
                print(f"    Error in zonal stats for {area_name}: {e}")
                continue

    print(f"  Completed year {year}. Current total records: {len(all_pixel_data)}")
    gc.collect()

print(f"\nExtraction complete. Total records: {len(all_pixel_data)}")

# Write the extracted records to CSV, print a per-area report and store a
# yearly precipitation summary; if nothing was extracted, only say so.
if not all_pixel_data:
    print("No data extracted - likely no overlap between protected areas and data extent")
else:
    print("Converting to DataFrame...")
    pixel_df = pd.DataFrame(all_pixel_data)

    # Full dataset
    output_csv = f'/Users/collinsmatiza/Downloads/isimip3a/pixel_timeseries_monthly_1982_2015_mm_month_{TARGET_RESOLUTION}deg.csv'
    print(f"Saving data to: {output_csv}")
    pixel_df.to_csv(output_csv, index=False)
    print(f"Data saved successfully")

    # Overall figures
    print("\nSummary of extracted data:")
    print(f"Total records: {len(pixel_df)}")
    print(f"Total unique pixels: {pixel_df['pixel_id'].nunique()}")
    print(f"Date range: {pixel_df['date'].min()} to {pixel_df['date'].max()}")
    print(f"Units: mm/month")
    print(f"Pixel resolution: {TARGET_RESOLUTION} degrees")
    print(f"Columns: {list(pixel_df.columns)}")

    # Per-area figures
    for area in target_areas:
        area_data = pixel_df[pixel_df['area_name'] == area]
        if area_data.empty:
            print(f"\n{area}: No data found (likely outside data extent)")
            continue

        area_pixel_ids = area_data['pixel_id']
        area_precip = area_data['Precipitation_mm']
        print(f"\n{area}:")
        print(f"  Total records: {len(area_data)}")
        print(f"  Unique pixels: {area_pixel_ids.nunique()}")
        print(f"  Pixel ID range: {area_pixel_ids.min()} to {area_pixel_ids.max()}")
        print(f"  Years covered: {sorted(area_data['date'].str[:4].unique())}")
        print(f"  Precipitation range: {area_precip.min():.4f} to {area_precip.max():.4f} mm/month")

    # Yearly precipitation statistics per area
    if not pixel_df.empty and 'date' in pixel_df.columns:
        print("Creating yearly summary...")
        # The year is the first four characters of the 'YYYY-MM-DD' date string
        pixel_df['year'] = pixel_df['date'].str[:4]
        summary_stats = ['count', 'mean', 'min', 'max', 'std']
        yearly_summary = (
            pixel_df.groupby(['area_name', 'year'])['Precipitation_mm']
            .agg(summary_stats)
            .reset_index()
        )
        yearly_summary['units'] = 'mm/month'
        yearly_summary['pixel_resolution_deg'] = TARGET_RESOLUTION

        summary_csv = f'/Users/collinsmatiza/Downloads/isimip3a/yearly_summary_1982_2015_mm_month_{TARGET_RESOLUTION}deg.csv'
        yearly_summary.to_csv(summary_csv, index=False)
        print(f"Yearly summary saved to: {summary_csv}")

print("Processing complete")