In [1]:
import numpy as np
from scipy.stats import rankdata
import xarray as xr
from glob import glob
from tqdm import tqdm
import pandas as pd
from multiprocess import Pool
from datetime import datetime, timezone
from zoneinfo import ZoneInfo # Python 3.9
# from tzwhere import tzwhere
import pytz
import warnings
warnings.filterwarnings('ignore')

In [2]:
def lat_lon_tz(lat,lon):
    """Return the ZoneInfo for the time zone at (lat, lon), looked up with tzwhere.

    NOTE(review): the `from tzwhere import tzwhere` line in the import cell is
    commented out, so this raises NameError unless `tzwhere` has been imported
    some other way in the session. No visible cell calls this function.
    A new tzwhere index is built on every call, which is presumably expensive —
    TODO cache it if this helper is used again.
    """
    finder = tzwhere.tzwhere()
    # forceTZ=True asks tzwhere for a zone even when the point is not inside one.
    return ZoneInfo(finder.tzNameAt(lat, lon, forceTZ=True))

def local_to_utc(local_time, local_tz, dt_format = "%Y-%m-%d %H:%M:%S"):
    """Convert naive local timestamp strings to timezone-aware UTC datetimes.

    Args:
        local_time: iterable of timestamp strings in `dt_format`.
        local_tz: tzinfo attached to each parsed (naive) datetime.
        dt_format: `datetime.strptime` format of the strings.

    Returns:
        list of datetimes with ``tzinfo=timezone.utc``, in input order.
    """
    utc_times = []
    for stamp in local_time:
        # Attach the zone without shifting the wall-clock time, then convert.
        wall_clock = datetime.strptime(stamp, dt_format).replace(tzinfo=local_tz)
        utc_times.append(wall_clock.astimezone(timezone.utc))
    return utc_times

# Accepted state spellings -> IANA time-zone name. Keys are upper-case with all
# whitespace removed, matching the normalisation applied in find_aus_tz, so
# "WA", "wa" and "Western Australia" all resolve to the same zone.
_AUS_STATE_TZ = {
    "WA": "Australia/West",
    "WESTERNAUSTRALIA": "Australia/West",
    "SA": "Australia/South",
    "SOUTHAUSTRALIA": "Australia/South",
    "VIC": "Australia/Victoria",
    "VICTORIA": "Australia/Victoria",
    "TAS": "Australia/Tasmania",
    "TASMANIA": "Australia/Tasmania",
    # NOTE(review): McMurdo follows New Zealand time; Australian Antarctic
    # stations have their own IANA zones (e.g. Antarctica/Casey) — confirm intent.
    "ANT": "Antarctica/McMurdo",
    "ANTARCTICA": "Antarctica/McMurdo",
    "NSW": "Australia/NSW",
    "NEWSOUTHWALES": "Australia/NSW",
    "ACT": "Australia/NSW",
    "AUSTRALIANCAPITALTERRITORY": "Australia/NSW",
    "QLD": "Australia/Queensland",
    "QUEENSLAND": "Australia/Queensland",
    "NT": "Australia/North",
    "NORTHERNTERRITORY": "Australia/North",
}


def find_aus_tz(state):
    """Return the ZoneInfo for an Australian state/territory name or abbreviation.

    Matching ignores whitespace and letter case, so both "NSW" and
    "New South Wales" work. (Previously, spaces were stripped before comparing
    against full names that contain spaces, so those names never matched.)

    Args:
        state: abbreviation (e.g. "WA", "ACT") or full name (e.g. "Western Australia").

    Returns:
        zoneinfo.ZoneInfo for the state's time zone.

    Raises:
        ValueError: if `state` is not a recognised state/territory.
    """
    key = "".join(state.split()).upper()
    tz_string = _AUS_STATE_TZ.get(key)
    if tz_string is None:
        # Fail loudly instead of printing and then crashing on an unbound name.
        raise ValueError(f"wrong state: {state!r}")
    return ZoneInfo(tz_string)

In [3]:
# One netCDF per station; sorting makes the order (and so any index-based slice of `files`) reproducible.
station_glob = "/g/data/w40/dl6968/BoM_daily_stations/netcdf/*.nc"
files = sorted(glob(station_glob))

In [4]:
# Every calendar day in the study period, and the same days at 09:00
# (percentile_nc treats 09:00 as station-local time).
timestamps = pd.date_range("1940-03-02", "2024-06-30", freq="D")
full_time_range = timestamps + pd.Timedelta(hours=9)

In [5]:
def _utc_sample_times(state):
    """Return naive-UTC np.datetime64 values for 09:00 local time on each day in `timestamps`.

    `state` is mapped to a time zone with `find_aus_tz`. The local wall times are
    converted with `local_to_utc`, and tzinfo is then dropped, because numpy
    deprecates parsing timezone-aware datetimes.
    """
    local_tz = find_aus_tz(state)
    local_time = (timestamps + pd.to_timedelta(9, unit="h")).astype(str)
    utc_time = local_to_utc(local_time, local_tz)
    return [np.datetime64(utc_ts.replace(tzinfo=None)) for utc_ts in utc_time]


def _wet_day_percentiles(rain, threshold=1):
    """Percentile rank of each value of `rain` above `threshold`; 0 everywhere else.

    Ranks come from ``rankdata(method='average')`` computed over the values above
    the threshold only. They are divided by the number of such values and scaled
    to 100, so a unique maximum maps to 100. The result has the same dims/coords
    as `rain` but no attrs (e.g. rainfall units do not describe a percentile).
    """
    wet = (rain > threshold).values
    ranks = np.zeros(rain.shape, dtype=float)
    n_wet = int(wet.sum())
    if n_wet:  # no wet days -> leave everything at 0
        ranks[wet] = rankdata(rain.values[wet], method="average") / n_wet * 100
    return xr.DataArray(ranks, coords=rain.coords, dims=rain.dims)


def percentile_nc(file, threshold=1):
    """Write daily rainfall percentiles for one station netCDF file.

    Reads ``prcp`` and the ``State`` global attribute from `file`. It samples
    ``prcp`` at the UTC instant of 09:00 station-local time on each day in
    `timestamps`, and writes a dataset to the same path with "netcdf" replaced
    by "percentiles". The dataset contains:

    * ``percentile``: percentile rank among days with rain > `threshold`, 0 otherwise;
    * ``prcp``: the sampled values, with -1 for missing/absent days.

    NOTE(review): the reindex matches exact timestamps, so this assumes the
    file's ``time`` coordinate is naive UTC at those instants — confirm.

    Args:
        file: path to a station netCDF file.
        threshold: values strictly above this are ranked (default 1, as before).

    Raises:
        ValueError: if the output path would be identical to `file`.
    """
    out_path = file.replace("netcdf", "percentiles")
    if out_path == file:
        # Writing to the path being read would clobber the input file.
        raise ValueError(f"output path equals input path (no 'netcdf' in it): {file}")

    with xr.open_dataset(file) as ds:
        out_time = _utc_sample_times(ds.attrs["State"])
        # Reindex straight onto the UTC sample times. A prior date-string clip
        # (slice("1940-03-02", ...)) dropped the first day for zones more than
        # 9 h ahead of UTC, whose first 09:00-local sample falls on 1940-03-01 UTC.
        rain = ds["prcp"].reindex(time=out_time, fill_value=-1).fillna(-1)
        ds_pc = _wet_day_percentiles(rain, threshold).to_dataset(name="percentile")
        # NOTE(review): the dates in this text differ from the sampled range
        # (1940-03-02 .. 2024-06-30 in `timestamps`); kept as-is — confirm wording.
        ds_pc.attrs["Description"] = (
            f"Percentile for rain rate >{threshold} between 1940-03-01 and 2024-06-01"
        )
        ds_pc["prcp"] = rain
        # Write while `ds` is open in case `rain` is still lazily backed by the input.
        ds_pc.to_netcdf(out_path)

In [7]:
# Run percentile_nc over the station files in parallel with multiprocess.
# max_pool: number of worker processes (CPUs) to use.
# START_INDEX: first entry of the sorted `files` list to process; 7440 presumably
# skips files handled by an earlier run — set to 0 to process every file.
max_pool = 24
START_INDEX = 7440
pending_files = files[START_INDEX:]

with Pool(max_pool) as p:
    # percentile_nc writes its output to disk and returns None, so pool_outputs
    # is only a list of Nones; list() forces every task to finish.
    pool_outputs = list(
        tqdm(
            p.imap(percentile_nc, pending_files),
            total=len(pending_files),
            position=0, leave=True
        )
    )
    # Shut the pool down cleanly before the context manager calls terminate().
    p.close()
    p.join()

100%|██████████| 10388/10388 [04:05<00:00, 42.32it/s]
