In [None]:
import os
import numpy as np
import xarray as xr
from scipy.ndimage import distance_transform_edt
from multiprocessing import Pool
import os
import numpy as np
from scipy.ndimage import distance_transform_edt
from datetime import datetime, timedelta

In [None]:


DATA_DIR = "/mnt/team/rapidresponse/pub/tropical-storms/data"
RAW_DATA_DIR = f"{DATA_DIR}/raw"
PROCESSED_DATA_DIR = f"{DATA_DIR}/processed"
DATA_SOURCE = "cmip6"

model_names = [f.name for f in os.scandir(f"{RAW_DATA_DIR}/{DATA_SOURCE}") if f.is_dir()]

In [None]:
import os
import re
import numpy as np
import xarray as xr

def int_to_date(date_int, monthly=False):
    if monthly:
        return datetime.strptime(str(date_int), "%Y%m")
    else:
        return datetime.strptime(str(date_int), "%Y%m%d")

from datetime import datetime, timedelta


def check_model_completeness(model_name):
    base_path = os.path.join(RAW_DATA_DIR, DATA_SOURCE, model_name)
    errors = []

    # 1. Check for exactly one variant folder
    variant_folders = [f.name for f in os.scandir(base_path) if f.is_dir()]
    if len(variant_folders) != 1:
        errors.append(f"Expected 1 variant folder, found {len(variant_folders)}: {variant_folders}")
        return errors

    variant = variant_folders[0]
    variant_path = os.path.join(base_path, variant)

    # 2. Check for 4 scenario folders
    required_scenarios = ["historical", "ssp126", "ssp245", "ssp585"]
    scenario_folders = [f.name for f in os.scandir(variant_path) if f.is_dir()]
    missing_scenarios = [s for s in required_scenarios if s not in scenario_folders]
    if missing_scenarios:
        errors.append(f"Missing scenario folders: {missing_scenarios}")

    for scenario in required_scenarios:
        scenario_path = os.path.join(variant_path, scenario)
        if not os.path.isdir(scenario_path):
            continue

        # 3. Check for 6 variable folders
        required_vars = ["ua", "va", "tos", "psl", "hus", "ta"]
        var_folders = [f.name for f in os.scandir(scenario_path) if f.is_dir()]
        missing_vars = [v for v in required_vars if v not in var_folders]
        if missing_vars:
            errors.append(f"Scenario '{scenario}' missing variable folders: {missing_vars}")

        for var in required_vars:
            var_path = os.path.join(scenario_path, var)
            if not os.path.isdir(var_path):
                continue

            # 4. Check for 1 grid folder
            grid_folders = [f.name for f in os.scandir(var_path) if f.is_dir()]
            if len(grid_folders) != 1:
                errors.append(f"Scenario '{scenario}', variable '{var}' expected 1 grid folder, found {len(grid_folders)}: {grid_folders}")
                continue

            grid = grid_folders[0]
            grid_path = os.path.join(var_path, grid)

            # 5. Check for 1 time folder
            time_folders = [f.name for f in os.scandir(grid_path) if f.is_dir()]
            if len(time_folders) != 1:
                errors.append(f"Scenario '{scenario}', variable '{var}', grid '{grid}' expected 1 time folder, found {len(time_folders)}: {time_folders}")
                continue

            time_folder = time_folders[0]
            time_path = os.path.join(grid_path, time_folder)

            # 6. Check NetCDF files and date coverage
            nc_files = sorted([f for f in os.listdir(time_path) if f.endswith('.nc')])
            if not nc_files:
                errors.append(f"Scenario '{scenario}', variable '{var}', grid '{grid}', time '{time_folder}' has no NetCDF files")
                continue

            monthly = 'mon' in time_folder.lower()
            date_ranges = []
            if monthly:
                regex = rf"{var}_{time_folder}_{model_name}_{scenario}_{variant}_{grid}_(\d{{6}})-(\d{{6}})\.nc"
            else:
                regex = rf"{var}_{time_folder}_{model_name}_{scenario}_{variant}_{grid}_(\d{{8}})-(\d{{8}})\.nc"
            for fname in nc_files:
                m = re.match(regex, fname)
                if not m:
                    errors.append(f"File '{fname}' does not match expected naming convention")
                    continue
                start, end = m.groups()
                date_ranges.append((int(start), int(end)))

            date_ranges.sort(key=lambda x: x[0])
            if monthly:
                expected_start, expected_end = (197001, 201412) if scenario == "historical" else (201501, 210012)
            else:
                expected_start, expected_end = (19700101, 20141231) if scenario == "historical" else (20150101, 21001231)

            date_issue = False
            if date_ranges:
                actual_start, actual_end = date_ranges[0][0], date_ranges[-1][1]
                if actual_start > expected_start or actual_end < expected_end:
                    date_issue = True

                # Check for gaps
                for i in range(1, len(date_ranges)):
                    prev_end = int_to_date(date_ranges[i-1][1], monthly)
                    curr_start = int_to_date(date_ranges[i][0], monthly)
                    if monthly:
                        year = prev_end.year + (prev_end.month // 12)
                        month = prev_end.month % 12 + 1
                        try:
                            next_month = prev_end.replace(year=year, month=month)
                        except ValueError:
                            next_month = prev_end.replace(year=prev_end.year + 1, month=1)
                        if next_month != curr_start:
                            # date_issue = True
                            errors.append(f"Gap detected for scenario {scenario}, variable {var} between {prev_end.strftime('%Y%m')} and {curr_start.strftime('%Y%m')}")
                            break
                    else:
                        if prev_end + timedelta(days=1) != curr_start:
                            # date_issue = True
                            errors.append(f"Gap detected for scenario {scenario}, variable {var} between {prev_end.strftime('%Y%m%d')} and {curr_start.strftime('%Y%m%d')}")
                            break
            else:
                date_issue = True

            if date_issue:
                errors.append(f"Scenario '{scenario}', variable '{var}' has issues with dates")

    return errors

def check_all_models_completeness(model_names, verbose=True):
    """
    Check completeness of all climate models.
    
    Parameters:
    -----------
    model_names : list
        List of model names to check
    verbose : bool, default True
        If True, prints progress and detailed error information
        
    Returns:
    --------
    dict
        Dictionary containing:
        - 'complete_models': list of complete model names
        - 'incomplete_models': list of incomplete model names  
        - 'all_errors': dict mapping model names to their error lists
        - 'summary': dict with counts and statistics
    """
    complete_models = []
    incomplete_models = []
    all_errors = {}
    
    if verbose:
        print("Checking model completeness...")
        print("=" * 60)
    
    for i, model_name in enumerate(model_names, 1):
        if verbose:
            print(f"[{i}/{len(model_names)}] Checking model: {model_name}")
            
        errors = check_model_completeness(model_name)
        all_errors[model_name] = errors
        
        if not errors:
            if verbose:
                print(f"  ✓ Model '{model_name}' is COMPLETE.")
            complete_models.append(model_name)
        else:
            if verbose:
                print(f"  ✗ Model '{model_name}' is INCOMPLETE:")
                for err in errors:
                    print(f"    - {err}")
            incomplete_models.append(model_name)
    
    # Create summary
    summary = {
        'total_models': len(model_names),
        'complete_count': len(complete_models),
        'incomplete_count': len(incomplete_models),
        'completion_rate': len(complete_models) / len(model_names) if model_names else 0
    }
    
    if verbose:
        print("\n" + "=" * 60)
        print("COMPLETENESS SUMMARY")
        print("=" * 60)
        print(f"Complete models ({len(complete_models)}):")
        for model in complete_models:
            print(f"  - {model}")

        print(f"\nIncomplete models ({len(incomplete_models)}):")
        for model in incomplete_models:
            print(f"  - {model}")

        print(f"\nTotal: {len(model_names)} models")
        print(f"Complete: {len(complete_models)} ({summary['completion_rate']:.1%})")
        print(f"Incomplete: {len(incomplete_models)}")
        print("=" * 60)
    
    return {
        'complete_models': complete_models,
        'incomplete_models': incomplete_models,
        'all_errors': all_errors,
        'summary': summary
    }

# Usage examples:
# Verbose mode (default)
results = check_all_models_completeness(model_names, verbose=True)

# Silent mode
results = check_all_models_completeness(model_names, verbose=False)

# Access the results
complete_models = results['complete_models']
incomplete_models = results['incomplete_models']
all_errors = results['all_errors']
summary = results['summary']

print(f"Found {len(complete_models)} complete models")

In [15]:
models_to_run = ['ACCESS-CM2', 'MIROC6']

In [23]:
def fill_nans_nearest(data):
    if not np.any(np.isnan(data)):
        return data
    
    data_filled = data.copy()
    
    if data.ndim > 2:
        for t in range(data.shape[0]):
            slice_data = data[t, ...]
            mask = np.isnan(slice_data)
            if np.any(mask):
                ind = distance_transform_edt(mask, return_distances=False, return_indices=True)
                data_filled[t, ...][mask] = slice_data[tuple(ind[:, mask])]
    else:
        mask = np.isnan(data)
        ind = distance_transform_edt(mask, return_distances=False, return_indices=True)
        data_filled[mask] = data[tuple(ind[:, mask])]
    
    return data_filled


def write_yearly_files(ds, src_file, dest_dir):
    import re
    import os

    # Get the time folder name from the path
    time_folder = os.path.basename(os.path.dirname(src_file))
    is_monthly = 'mon' in time_folder.lower()
    is_daily = time_folder.lower() == 'day'

    years = np.unique(ds["time.year"].values)
    if is_monthly:
        print(f"Monthly file detected (time folder: '{time_folder}'): uses YYYYMM date format.")
    elif is_daily:
        print(f"Daily file detected (time folder: '{time_folder}'): uses YYYYMMDD date format.")
    else:
        print(f"Unknown time frequency for file: {src_file}")

    for year in years:
        ds_year = ds.sel(time=str(year))
        if is_monthly:
            out_fname = re.sub(r'_(\d{6})-(\d{6})\.nc$', f'_{year}01-{year}12.nc', os.path.basename(src_file))
        else:
            out_fname = re.sub(r'_(\d{8})-(\d{8})\.nc$', f'_{year}0101-{year}1231.nc', os.path.basename(src_file))
        out_path = os.path.join(dest_dir, out_fname)
        ds_year.to_netcdf(out_path)


def process_and_report(args):
    import xarray as xr
    import re
    from datetime import datetime

    src_file, dest_dir, file_idx, total_files = args
    try:
        start_time = datetime.now()
        ds = xr.open_dataset(src_file, engine='netcdf4')
        print("Filling NaNs if any...")
        has_nans = any(np.any(np.isnan(ds[var].values)) for var in ds.data_vars)
        if has_nans:
            ds_filled = ds.copy()
            for var in ds.data_vars:
                arr = ds[var].values
                if np.any(np.isnan(arr)):
                    ds_filled[var].values = fill_nans_nearest(arr)
        else:
            ds_filled = ds
        print("Writing yearly files...")
        write_yearly_files(ds_filled, src_file, dest_dir)
        ds.close()
        elapsed = (datetime.now() - start_time).total_seconds()
        status = "filled NaNs" if has_nans else "no NaNs"
        years = np.unique(ds_filled["time.year"].values)
        return (True, f"[{file_idx}/{total_files}] ✓ {os.path.basename(src_file)} ({status}, {len(years)} years, {elapsed:.1f}s)")
    except Exception as e:
        return (False, f"[{file_idx}/{total_files}] ✗ {os.path.basename(src_file)}: {str(e)}")

In [None]:
# Collect all files
print("=" * 80)
print("SCANNING DIRECTORIES")
print("=" * 80)

file_tasks = []
for model_idx, model_name in enumerate(models_to_run, 1):
    print(f"[{model_idx}/{len(model_names)}] Scanning model: {model_name}")
    
    src_root = os.path.join(RAW_DATA_DIR, DATA_SOURCE, model_name)
    dest_root = os.path.join(PROCESSED_DATA_DIR, DATA_SOURCE, model_name)
    
    model_file_count = 0
    for dirpath, dirnames, filenames in os.walk(src_root):
        rel_path = os.path.relpath(dirpath, src_root)
        dest_dir = os.path.join(dest_root, rel_path)
        os.makedirs(dest_dir, exist_ok=True)
        
        for fname in filenames:
            if fname.endswith('.nc'):
                src_file = os.path.join(dirpath, fname)
                file_tasks.append((src_file, dest_dir))
                model_file_count += 1
    
    print(f"    Found {model_file_count} NetCDF files")

print("\n" + "=" * 80)
print(f"PROCESSING {len(file_tasks)} FILES")
print("=" * 80)

# Add file index to each task
file_tasks_with_index = [
    (src_file, dest_dir, idx+1, len(file_tasks)) 
    for idx, (src_file, dest_dir) in enumerate(file_tasks)
]

results = []
for args in file_tasks_with_index:
    print(f"Processing file {args[2]} of {args[3]}: {os.path.basename(args[0])}")
    result = process_and_report(args)
    results.append(result)


SCANNING DIRECTORIES
[1/15] Scanning model: ACCESS-CM2
    Found 142 NetCDF files
[2/15] Scanning model: MIROC6
    Found 676 NetCDF files

PROCESSING 818 FILES
Processing file 1 of 818: ta_Amon_ACCESS-CM2_ssp585_r1i1p1f1_gn_201501-210012.nc
Filling NaNs if any...
Writing yearly files...
Monthly file detected (time folder: 'Amon'): uses YYYYMM date format.
Processing file 2 of 818: ua_day_ACCESS-CM2_ssp585_r1i1p1f1_gn_20250101-20291231.nc
Filling NaNs if any...
Writing yearly files...
Daily file detected (time folder: 'day'): uses YYYYMMDD date format.
Processing file 3 of 818: ua_day_ACCESS-CM2_ssp585_r1i1p1f1_gn_20850101-20891231.nc
Filling NaNs if any...
Writing yearly files...
Daily file detected (time folder: 'day'): uses YYYYMMDD date format.
Processing file 4 of 818: ua_day_ACCESS-CM2_ssp585_r1i1p1f1_gn_20600101-20641231.nc
Filling NaNs if any...
Writing yearly files...
Daily file detected (time folder: 'day'): uses YYYYMMDD date format.
Processing file 5 of 818: ua_day_ACCESS-C

In [None]:





start_time = datetime.now()

# Process in parallel
with Pool(processes=8) as pool:
    results = pool.map(process_single_file, file_tasks_with_index)

elapsed_total = (datetime.now() - start_time).total_seconds()

print("\n" + "=" * 80)
print("RESULTS")
print("=" * 80)

# Count successes and failures
successes = sum(1 for success, _ in results if success)
failures = sum(1 for success, _ in results if not success)

# Print all results
for success, message in results:
    print(message)

print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)
print(f"Total files: {len(file_tasks)}")
print(f"Successful: {successes}")
print(f"Failed: {failures}")
print(f"Total time: {elapsed_total/60:.1f} minutes ({elapsed_total:.1f} seconds)")
print(f"Average time per file: {elapsed_total/len(file_tasks):.1f} seconds")
print("=" * 80)

SCANNING DIRECTORIES
[1/15] Scanning model: ACCESS-CM2
    Found 138 NetCDF files
[2/15] Scanning model: EC-Earth3
    Found 656 NetCDF files
[3/15] Scanning model: ACCESS-ESM1-5
    Found 3 NetCDF files
[4/15] Scanning model: CESM2-WACCM
    Found 77 NetCDF files
[5/15] Scanning model: INM-CM5-0
    Found 62 NetCDF files
[6/15] Scanning model: MIROC6
    Found 676 NetCDF files
[7/15] Scanning model: FGOALS-g3
    Found 664 NetCDF files
[8/15] Scanning model: IPSL-CM6A-LR
    Found 9 NetCDF files
[9/15] Scanning model: CanESM5
    Found 3 NetCDF files
[10/15] Scanning model: INM-CM4-8
    Found 68 NetCDF files
[11/15] Scanning model: NorESM2-LM
    Found 64 NetCDF files
[12/15] Scanning model: KACE-1-0-G
    Found 22 NetCDF files
[13/15] Scanning model: MRI-ESM2-0
    Found 17 NetCDF files
[14/15] Scanning model: TaiESM1
    Found 303 NetCDF files
[15/15] Scanning model: MPI-ESM1-2-HR
    Found 189 NetCDF files

PROCESSING 2951 FILES WITH 4 PARALLEL WORKERS
NOTE: Running with 4 workers

KeyboardInterrupt: 