# In [None]:
import os
import xarray as xr
import shutil

def classify_cluster(elevation, temperature, humidity, thresholds=None):
    """Classify a sample into one of 8 clusters using three threshold tests.

    Each test that is exceeded (strict ``>``) sets one bit of the cluster index:

    * bit 0 (value 1): elevation above the elevation threshold
    * bit 1 (value 2): temperature above the temperature threshold
    * bit 2 (value 4): humidity above the humidity threshold

    Args:
        elevation: Scalar elevation value to test.
        temperature: Scalar temperature value to test.
        humidity: Scalar humidity value to test.
        thresholds: Optional ``(elevation, temperature, humidity)`` threshold
            tuple. When omitted, the module-level globals ``threshold_elev``,
            ``threshold_temp`` and ``threshold_hum`` are used (as assigned by
            ``process_netcdf_files``). Calling without ``thresholds`` before
            those globals exist raises ``NameError``.

    Returns:
        int: Cluster index in ``range(8)``.
    """
    if thresholds is None:
        # Backward-compatible fallback to the globals set by process_netcdf_files.
        thresholds = (threshold_elev, threshold_temp, threshold_hum)
    elev_thr, temp_thr, hum_thr = thresholds

    cluster = 0
    if elevation > elev_thr:
        cluster |= 1  # bit 0
    if temperature > temp_thr:
        cluster |= 2  # bit 1
    if humidity > hum_thr:
        cluster |= 4  # bit 2
    return cluster

def process_netcdf_files(input_dir, output_dir, thresholds, extensions=(".nc", ".nz")):
    """Sort NetCDF files from ``input_dir`` into ``output_dir/cluster_<k>``.

    For every matching file, the mean of the variables ``T_2M``, ``RELHUM_2M``
    and ``HSURF`` is computed and passed to ``classify_cluster``. The file is
    then *moved* (not copied) into the folder of the resulting cluster.

    Args:
        input_dir: Directory scanned (non-recursively) for input files. A
            leading ``~`` is expanded.
        output_dir: Directory under which ``cluster_0`` ... ``cluster_7`` are
            created if missing. A leading ``~`` is expanded.
        thresholds: ``(elevation, temperature, humidity)`` thresholds. They are
            also stored in the module globals ``threshold_elev``,
            ``threshold_temp`` and ``threshold_hum``, which ``classify_cluster``
            reads.
        extensions: File-name suffixes treated as NetCDF inputs. The default
            includes the conventional ``.nc`` as well as ``.nz``, the suffix
            the original filter used.
    """
    global threshold_elev, threshold_temp, threshold_hum
    threshold_elev, threshold_temp, threshold_hum = thresholds

    # os.listdir / os.makedirs do not expand "~" themselves.
    input_dir = os.path.expanduser(input_dir)
    output_dir = os.path.expanduser(output_dir)

    # Ensure cluster directories exist
    for i in range(8):
        os.makedirs(os.path.join(output_dir, f'cluster_{i}'), exist_ok=True)

    # List regular files with a NetCDF suffix. sorted() gives a deterministic
    # processing and log order. str.endswith needs a tuple for multiple suffixes.
    suffixes = tuple(extensions)
    files = sorted(
        f for f in os.listdir(input_dir)
        if f.endswith(suffixes) and os.path.isfile(os.path.join(input_dir, f))
    )

    for file in files:
        file_path = os.path.join(input_dir, file)

        # Extract the field means. The dataset is closed before the move.
        with xr.open_dataset(file_path) as ds:
            # DataArray.mean() skips NaN for float data (skipna defaults to True),
            # whereas ndarray.mean() would return NaN if any cell is NaN. A NaN
            # mean makes every ">" test in classify_cluster False (cluster 0).
            temperature = float(ds['T_2M'].mean())
            humidity = float(ds['RELHUM_2M'].mean())
            elevation = float(ds['HSURF'].mean())

        # Classify file into a cluster
        cluster = classify_cluster(elevation, temperature, humidity)

        # Move file to the corresponding cluster folder
        dest_path = os.path.join(output_dir, f'cluster_{cluster}', file)
        shutil.move(file_path, dest_path)
        print(f"Moved {file} to cluster_{cluster}")

# Example usage. The paths are "~"-relative, so expand them explicitly:
# os.listdir / os.makedirs would otherwise treat "~" as a literal directory name.
input_directory = os.path.expanduser("~/data/1h_2D_sel_cropped_blurred_x8_gridded")
output_directory = os.path.expanduser("~/data/1h_2D_sel_cropped_blurred_x8_clustered")
# Example thresholds for (elevation, temperature, humidity). Units must match the
# input variables (presumably m, K and % -- confirm against the dataset).
thresh_values = (800, 290, 50)
process_netcdf_files(input_directory, output_directory, thresh_values)
