# In [None]:
import os
import re
import numpy as np
import xarray as xr
from datetime import datetime, timedelta
import argparse

import idd_climate_models.constants as rfc
from idd_climate_models.validate_model_functions import int_to_date, is_monthly

PROCESSED_DATA_PATH = rfc.PROCESSED_DATA_PATH  # root under which output folders are built

# Inclusive calendar-year window; years outside it are never written to disk.
MIN_YEAR, MAX_YEAR = 1950, 2100


def define_dest_dir(model, variant, scenario, variable, grid, time_period):
    """Return the output directory for one model configuration, creating it if needed.

    The path is ``PROCESSED_DATA_PATH/model/variant/scenario/variable/grid/time_period``;
    an already-existing directory is accepted as-is.
    """
    path_parts = (model, variant, scenario, variable, grid, time_period)
    target = os.path.join(PROCESSED_DATA_PATH, *path_parts)
    os.makedirs(target, exist_ok=True)
    return target


def fill_nans_xarray(ds):
    """Return a copy of ``ds`` with NaNs filled by nearest-neighbour interpolation.

    For every data variable that contains NaNs:

    1. it is interpolated along ``time`` (``method='nearest'``, extrapolating
       past the ends) when it has a ``time`` dimension;
    2. any NaNs that remain are then filled along each of its other
       dimensions, in the order they appear, stopping once none remain.

    Variables without NaNs are left untouched. Dask-backed inputs (e.g. opened
    with ``chunks=``) are supported.

    Args:
        ds: xarray Dataset.

    Returns:
        A shallow copy of ``ds`` whose NaN-containing variables are replaced by
        their interpolated versions.
    """
    ds_filled = ds.copy()

    for var in ds.data_vars:
        if not ds[var].isnull().any():
            continue
        print(f"    Filling NaNs in variable '{var}'...")

        filled = ds[var]
        if 'time' in filled.dims:
            filled = _interpolate_nearest(filled, 'time')

        # Gaps that time interpolation cannot close (e.g. a cell that is NaN at
        # every time step) are filled from neighbours along the other dims.
        for dim in [d for d in ds[var].dims if d != 'time']:
            if not filled.isnull().any():
                break
            filled = _interpolate_nearest(filled, dim)

        ds_filled[var] = filled

    return ds_filled


def _interpolate_nearest(da, dim):
    """Nearest-neighbour ``interpolate_na`` along ``dim`` that also works on dask arrays."""
    # interpolate_na runs through apply_ufunc(dask="parallelized") with ``dim`` as a
    # core dimension, which raises if ``dim`` spans several dask chunks (as ``time``
    # does after open_dataset(chunks={'time': 100})). Merge it into a single chunk.
    if da.chunks is not None:
        da = da.chunk({dim: -1})
    return da.interpolate_na(dim=dim, method='nearest', fill_value='extrapolate')


def write_yearly_files_optimized(ds, src_file, dest_dir):
    """Write ``ds`` into ``dest_dir`` as one compressed NetCDF file per calendar year.

    Only years within [MIN_YEAR, MAX_YEAR] are written. Each output name is the
    basename of ``src_file`` with its trailing ``_<start>-<end>.nc`` date range
    replaced by that year's range. The format is chosen from the name of the
    folder holding ``src_file``:

    - a name containing 'mon' gives ``YYYY01-YYYY12``;
    - a name equal to 'day' gives ``YYYY0101-YYYY1231``;
    - anything else gives just ``YYYY``.

    If the basename has no date range in the expected format, ``_<year>`` is
    appended to its stem instead, so different years never share an output path.

    Args:
        ds: xarray Dataset with a ``time`` coordinate.
        src_file: Path of the source file; used only to derive names and frequency.
        dest_dir: Directory to write into (not created here).

    Raises:
        Any exception raised while selecting or writing a year is re-raised
        after the failing year is printed.
    """
    time_folder = os.path.basename(os.path.dirname(src_file)).lower()
    # Local names chosen so the module-level ``is_monthly`` import is not shadowed.
    monthly = 'mon' in time_folder
    daily = time_folder == 'day'

    # Plain ints so printed lists read "[1850, ...]" rather than numpy reprs.
    all_years = [int(y) for y in np.unique(ds["time.year"].values)]
    years = [y for y in all_years if MIN_YEAR <= y <= MAX_YEAR]

    print(f"  Time frequency: {'Monthly' if monthly else 'Daily' if daily else 'Unknown'}")

    if len(years) < len(all_years):
        excluded_years = [y for y in all_years if y < MIN_YEAR or y > MAX_YEAR]
        print(f"  Filtering years {MIN_YEAR}-{MAX_YEAR}: keeping {len(years)}/{len(all_years)} years")
        if len(excluded_years) <= 10:  # only list them when the list is short
            print(f"  Excluded years: {excluded_years}")
        else:
            print(f"  Excluded {len(excluded_years)} years outside range")
    else:
        print(f"  All {len(years)} years are within {MIN_YEAR}-{MAX_YEAR} range")

    if not years:
        print(f"  No years to process (all outside {MIN_YEAR}-{MAX_YEAR} range)")
        return
    print(f"  Years to process: {len(years)} ({years[0]}-{years[-1]})")

    # Loop invariants.
    base_name = os.path.basename(src_file)
    step_years = ds.time.dt.year

    for year in years:
        try:
            ds_year = ds.sel(time=step_years == year)
            if len(ds_year.time) == 0:
                print(f"    Skipping {year} (no data)")
                continue

            out_fname = _yearly_filename(base_name, year, monthly, daily)

            # Fresh per-variable dicts each year so nothing downstream can
            # mutate settings shared between writes; chunking left to xarray.
            encoding = {
                var: {'zlib': True, 'complevel': 4, 'shuffle': True, 'chunksizes': None}
                for var in ds_year.data_vars
            }
            ds_year.to_netcdf(
                os.path.join(dest_dir, out_fname),
                encoding=encoding,
                engine='netcdf4',
            )
            print(f"    Wrote: {out_fname} ({len(ds_year.time)} time steps)")

        except Exception as e:
            print(f"    ✗ Failed to write year {year}: {str(e)}")
            raise


def _yearly_filename(base_name, year, monthly, daily):
    """Return the output file name for ``year`` derived from ``base_name``."""
    if monthly:
        name, n_subs = re.subn(r'_(\d{6})-(\d{6})\.nc$', f'_{year}01-{year}12.nc', base_name)
    elif daily:
        name, n_subs = re.subn(r'_(\d{8})-(\d{8})\.nc$', f'_{year}0101-{year}1231.nc', base_name)
    else:
        name, n_subs = re.subn(r'_(\d{6,8})-(\d{6,8})\.nc$', f'_{year}.nc', base_name)
    if n_subs:
        return name
    # No date range to replace: the unchanged name would make every year
    # overwrite the same output file, so tag the stem with the year instead.
    stem, ext = os.path.splitext(base_name)
    return f"{stem}_{year}{ext or '.nc'}"


def process_file(file_path, dest_dir):
    """Fill NaNs in one NetCDF file and split it into per-year files in ``dest_dir``.

    Args:
        file_path: Path to the input NetCDF file.
        dest_dir: Directory that receives the yearly output files.

    Returns:
        True on success; False if any step raised. Errors are printed rather
        than propagated so the caller can turn them into an exit status.
    """
    ds = None
    ds_filled = None
    try:
        start_time = datetime.now()
        print(f"Processing: {os.path.basename(file_path)}")

        # Chunk along time so dask reads the file lazily in 100-step pieces.
        ds = xr.open_dataset(file_path, engine='netcdf4', chunks={'time': 100})

        print(f"  Dataset info: {len(ds.time)} time steps, {list(ds.data_vars.keys())} variables")

        print("  Checking for NaNs...")
        nan_vars = []
        for var in ds.data_vars:
            # One reduction answers both "any NaNs?" and "how many?".
            nan_count = int(ds[var].isnull().sum())
            if nan_count:
                nan_vars.append((var, nan_count))

        has_nans = bool(nan_vars)

        if has_nans:
            print(f"  Found NaNs in {len(nan_vars)} variables:")
            for var, count in nan_vars:
                print(f"    {var}: {count} NaN values")
            print("  Filling NaNs using interpolation...")
            ds_filled = fill_nans_xarray(ds)
        else:
            print("  No NaNs found.")
            ds_filled = ds

        print("  Writing yearly files...")
        write_yearly_files_optimized(ds_filled, file_path, dest_dir)

        # Build the summary while the dataset is still open.
        elapsed = (datetime.now() - start_time).total_seconds()
        status = f"filled NaNs in {len(nan_vars)} vars" if has_nans else "no NaNs"
        all_years = np.unique(ds.time.dt.year.values) if hasattr(ds, 'time') else []
        filtered_years = [year for year in all_years if MIN_YEAR <= year <= MAX_YEAR]

        print(f"  ✓ Completed in {elapsed:.1f}s ({status}, {len(filtered_years)} years processed)")
        return True

    except Exception as e:
        print(f"  ✗ Error processing {os.path.basename(file_path)}: {str(e)}")
        return False

    finally:
        # Release the underlying file handle on success and on failure alike.
        if ds_filled is not None and ds_filled is not ds:
            ds_filled.close()
        if ds is not None:
            ds.close()


def main():
    """Command-line entry point: fill and split one climate-model NetCDF file by year.

    Returns:
        Process exit status: 0 on success, 1 if the input file is missing or
        processing failed.
    """
    parser = argparse.ArgumentParser(description="Fill and yearly split climate model data")
    cli_options = (
        ("--model", "Climate model name"),
        ("--variant", "Model variant"),
        ("--scenario", "Climate scenario"),
        ("--variable", "Climate variable"),
        ("--grid", "Grid type"),
        ("--time_period", "Time period of the data"),
        ("--file_path", "Path to the input file"),
    )
    for flag, help_text in cli_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    args = parser.parse_args()

    # Validate before define_dest_dir so a bad path creates no output folder.
    if not os.path.exists(args.file_path):
        print(f"Error: Input file does not exist: {args.file_path}")
        return 1

    dest_dir = define_dest_dir(
        args.model, args.variant, args.scenario,
        args.variable, args.grid, args.time_period,
    )

    run_summary = [
        ("Model", args.model),
        ("Variant", args.variant),
        ("Scenario", args.scenario),
        ("Variable", args.variable),
        ("Grid", args.grid),
        ("Time period", args.time_period),
        ("Input file", args.file_path),
        ("Output directory", dest_dir),
        ("Year range", f"{MIN_YEAR}-{MAX_YEAR}"),
    ]
    print("Processing climate model data:")
    for label, value in run_summary:
        print(f"  {label}: {value}")
    print()

    if not process_file(args.file_path, dest_dir):
        print("\n✗ Processing failed!")
        return 1
    print("\n✓ Processing completed successfully!")
    return 0


if __name__ == "__main__":
    # sys.exit, not the site-injected ``exit`` builtin, which is missing under
    # ``python -S`` and in frozen executables.
    import sys

    sys.exit(main())