# Calculate regional-mean ERA5 data for REZ and GCCSA regions

In [1]:
from dask.distributed import Client,LocalCluster
from dask_jobqueue import PBSCluster

In [2]:
# PBS job settings. A Gadi node has 48 cores, so fill one node (one job)
# before scaling out to several jobs.
walltime = '00:15:00'
cores = 8
memory = f'{4 * cores}GB'  # 4 GB per core

job_directives = [
    '-q normal',
    '-P w42',
    f'-l ncpus={cores}',
    f'-l mem={memory}',
    '-l storage=gdata/w42+gdata/rt52',
]

cluster = PBSCluster(
    walltime=walltime,
    cores=cores,
    memory=memory,
    processes=cores,  # processes == cores -> one thread per worker process
    job_extra_directives=job_directives,
    # Passed as the literal string '$TMPDIR'; presumably resolved inside the
    # PBS job on the compute node — confirm against the generated job script.
    local_directory='$TMPDIR',
    job_directives_skip=['select'],
    # Optional: pin the worker interpreter with python=os.environ["DASK_PYTHON"]
)

In [3]:
# Submit one PBS job's worth of workers. scale() only submits the job: workers
# join once PBS starts it, so the client summary may still show 0 workers.
cluster.scale(jobs=1)
# Connect a client to the cluster's scheduler; later dask computations run on it.
client = Client(cluster)

In [4]:
client  # rich repr: dashboard link plus current worker/thread/memory totals

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: http://10.6.23.26:8787/status,

0,1
Dashboard: http://10.6.23.26:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.23.26:44151,Workers: 0
Dashboard: http://10.6.23.26:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [5]:
# client.close()
# cluster.close()

In [6]:
# %load_ext autoreload
# %autoreload 2

In [7]:
import xarray as xr
import numpy as np

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import cartopy.crs as ccrs
import cartopy
# Point cartopy at a shared, pre-downloaded copy of its data files so nothing
# needs to be fetched at plot time.
cartopy.config.update({
    'pre_existing_data_dir': '/g/data/w42/dr6273/work/data/cartopy-data/',
    'data_dir': '/g/data/w42/dr6273/work/data/cartopy-data/',
})

In [8]:
import functions as fn

# Plotting parameters from the project's local `functions` module.
# NOTE(review): plt_params is not referenced in any of the cells shown here;
# keep this call only if fn.get_plot_params() is needed for side effects
# (e.g. setting matplotlib rcParams) — confirm in functions.py.
plt_params = fn.get_plot_params()

# Load masks

In [9]:
# REZ region mask from the project helper module. The averaging functions below
# iterate over mask['region'].values and use mask.sel(region=...) as a `where`
# condition, so this is presumably a boolean DataArray with a 'region'
# dimension and lat/lon dims matching the gridded data — confirm in functions.py.
rez_mask = fn.get_rez_mask()

In [11]:
# gccsa_mask = fn.get_gccsa_mask()

# Load gridded data

First, look at the daily (1400 UTC) data; the hourly and daily (0000 UTC) data have already been looked at. Runoff below uses monthly-averaged data.

In [12]:
# Years covered by the input files, 1959 to 2021 inclusive (range excludes its
# stop value); years[0] and years[-1] are used to build file names.
years = range(1959, 2021 + 1)

Mean surface downward shortwave radiation flux (ERA5 `msdwswrf`)

In [14]:
# Daily (1400 UTC) msdwswrf over the AUS region, as a lazily loaded zarr store.
mssrd_path = (
    '/g/data/w42/dr6273/work/data/era5/msdwswrf/'
    f'msdwswrf_era5_daily_1400UTC_sfc_{years[0]}-{years[-1]}_AUS_region.zarr'
)
mssrd = xr.open_zarr(mssrd_path, consolidated=True)

100m wind speed

In [15]:
# Daily (1400 UTC) 100 m wind speed over the AUS region, as a lazily loaded zarr store.
w100_path = (
    '/g/data/w42/dr6273/work/data/era5/100w/'
    f'100w_era5_daily_1400UTC_sfc_{years[0]}-{years[-1]}_AUS_region.zarr'
)
w100 = xr.open_zarr(w100_path, consolidated=True)

Runoff (monthly-averaged)

In [17]:
# Monthly-averaged runoff, already subset to the REZ region, as a lazily loaded zarr store.
ro_path = (
    '/g/data/w42/dr6273/work/data/era5/ro/'
    f'ro_era5_monthly-averaged_sfc_{years[0]}-{years[-1]}_REZ_region.zarr'
)
ro = xr.open_zarr(ro_path, consolidated=True)

2 m temperature (currently disabled)

In [18]:
# t = xr.open_zarr(
#     '/g/data/w42/dr6273/work/data/era5/2t/2t_era5_daily_1400UTC_sfc_'+str(years[0])+'-'+str(years[-1])+'_REZ_region.zarr',
#     consolidated=True
# )

# Calculate regional averages

In [19]:
def region_spatial_mean(da, region, mask, dims=None):
    """
    Calculate the spatial mean of `da` over a single region.

    Parameters
    ----------
    da : xarray.DataArray
        Data to average; must contain every dimension in `dims`.
    region : hashable
        Label of the region, selected with ``mask.sel(region=region)``.
    mask : xarray.DataArray
        Region mask with a 'region' dimension. The selected slice is used as a
        ``where`` condition, so it is presumably boolean — confirm.
    dims : list of str, optional
        Spatial dimensions to average over. Defaults to ``['lat', 'lon']``
        (the original hard-coded names); pass e.g.
        ``['latitude', 'longitude']`` for data with other dim names.

    Returns
    -------
    xarray.DataArray
        `da` averaged over `dims` at points where the region mask holds,
        ignoring NaNs. The mean is unweighted: no cos(latitude) area weighting
        is applied.
    """
    if dims is None:
        dims = ['lat', 'lon']
    region_mask = mask.sel(region=region)
    # drop=True removes coordinate labels that are False everywhere in the mask,
    # shrinking the array before averaging; remaining out-of-region points are
    # NaN and ignored via skipna=True.
    return da.where(region_mask, drop=True).mean(dims, skipna=True)

In [20]:
def calculate_spatial_means(da, mask):
    """
    Calculate the spatial mean of `da` for every region in `mask`.

    Each label in ``mask['region']`` is averaged with `region_spatial_mean`;
    the per-region results are stacked along a 'region' dimension in the
    same order as the mask's region labels.
    """
    per_region = [
        region_spatial_mean(da, label, mask)
        for label in mask['region'].values
    ]
    return xr.concat(per_region, dim='region')

In [21]:
def to_single_chunk(da):
    """
    Rechunk `da` so that 'region' and 'time' each form a single chunk.

    A chunk size of -1 means "the full length of that dimension".
    Returns the rechunked object.
    """
    whole_dim = -1
    return da.chunk({dim: whole_dim for dim in ('region', 'time')})

In [22]:
def calculate_and_write(da, mask, mask_name, var_name, time_freq_name,
                        year_range=None,
                        out_dir='/g/data/w42/dr6273/work/projects/Aus_energy/data/'):
    """
    Calculate regional means of `da`, rechunk them and write them to zarr.

    Parameters
    ----------
    da : xarray.DataArray
        Gridded field to average.
    mask : xarray.DataArray
        Region mask, passed to `calculate_spatial_means`.
    mask_name : str
        Mask label used in the output file name (e.g. 'REZ').
    var_name : str
        Variable name inside the written dataset and in the file name.
    time_freq_name : str
        Time-frequency label used in the file name (e.g. 'daily_1400UTC').
    year_range : sequence of int, optional
        Years the data cover; only the first and last elements are used, in the
        file name. Defaults to the notebook-level `years` (the previous,
        implicit behaviour).
    out_dir : str, optional
        Directory to write the zarr store into.

    Returns
    -------
    str
        Path of the zarr store written. Any existing store at that path is
        overwritten (mode='w').
    """
    import os

    if year_range is None:
        # Fall back to the notebook global so existing calls name files exactly as before.
        year_range = years
    regional = to_single_chunk(calculate_spatial_means(da, mask=mask))
    out_path = os.path.join(
        out_dir,
        f'{var_name}_era5_{time_freq_name}_sfc_{year_range[0]}-{year_range[-1]}_{mask_name}_region_mean.zarr',
    )
    regional.to_dataset(name=var_name).to_zarr(out_path, mode='w', consolidated=True)
    return out_path

In [23]:
# Regional means of daily (1400 UTC) 100 m wind speed over the REZ regions.
calculate_and_write(
    w100['w100'], rez_mask,
    mask_name='REZ', var_name='100w', time_freq_name='daily_1400UTC',
)

In [24]:
# Regional means of daily (1400 UTC) msdwswrf over the REZ regions,
# stored under the variable name 'mssrd'.
calculate_and_write(
    mssrd['msdwswrf'], rez_mask,
    mask_name='REZ', var_name='mssrd', time_freq_name='daily_1400UTC',
)

In [25]:
# calculate_and_write(
#     da=t['t2m'],
#     mask=gccsa_mask,
#     mask_name='GCCSA',
#     var_name='t2m',
#     time_freq_name='daily_1400UTC'
# )

In [26]:
# Regional means of monthly-averaged runoff over the REZ regions.
calculate_and_write(
    ro['ro'], rez_mask,
    mask_name='REZ', var_name='ro', time_freq_name='monthly-averaged',
)

# Close cluster

In [27]:
# Shut down the client first, then the PBS cluster.
for dask_resource in (client, cluster):
    dask_resource.close()