In [None]:
import os
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from odc.geo.xr import assign_crs

import warnings
warnings.simplefilter(action='ignore')

In [None]:
import sys

# Make the project-local AusEFlux helper modules importable.
AUSEFLUX_SRC = '/g/data/os22/chad_tmp/AusEFlux/src/'
sys.path.append(AUSEFLUX_SRC)
from _utils import start_local_dask, round_coords

# Start a local dask cluster, keeping a 2Gb memory safety margin, and
# display the client so the dashboard link is visible.
client = start_local_dask(mem_safety_margin='2Gb')
client

In [None]:
# Years of OzWALD data to process (2003 to 2023 inclusive).
years = [str(i) for i in range(2003, 2023+1)]

# Grab a common grid to reproject to, and create a land mask on that grid.
# The file is opened once and closed afterwards; both results are in memory
# by then (the GeoBox is a plain object, and the reduction loads the data).
p = '/g/data/os22/chad_tmp/NEE_modelling/data/1km/kNDVI_1km_monthly_2002_2022.nc'
with xr.open_dataset(p) as kndvi_ds:
    gbox = kndvi_ds.odc.geobox

    # A pixel counts as part of the Australian extent if kNDVI has any valid
    # value in 2002-2005: mean() skips NaN, and an all-NaN pixel stays NaN,
    # which compares False. Kept boolean, as DataArray.where expects.
    mask = kndvi_ds['kNDVI'].sel(time=slice('2002', '2005')).mean('time') > -99

# The reprojected GPP arrays are passed through round_coords before masking,
# and `.where(mask)` inner-joins on coordinates. Round the mask the same way so
# the grids match pixel for pixel.
# NOTE(review): assumes round_coords only rounds coordinate values — confirm.
mask = round_coords(mask)

In [None]:
# OzWALD variables to process: output name -> directory holding one
# `OzWALD.<name>.<year>.nc` file per year of 8-day data.
ozwald_vars = {
    'GPP': '/g/data/ub8/au/OzWALD/8day/GPP/',
}


def _ozwald_to_monthly_1km(path, name, gbox, mask):
    """Convert one yearly OzWALD 8-day file to monthly totals on the common grid.

    Parameters
    ----------
    path : str
        Path to a yearly OzWALD netCDF file.
    name : str
        Name given to the returned DataArray.
    gbox : odc.geo GeoBox
        Target grid for the spatial reprojection.
    mask : xarray.DataArray
        Land mask on ``gbox``; pixels where it is False become NaN.

    Returns
    -------
    xarray.DataArray
        In-memory (computed) array with one time step per month, time-stamped
        on the 15th of each month.
    """
    with xr.open_dataset(path) as src:
        da = src.chunk(dict(latitude=1000, longitude=1000, time=-1))
        da = da.transpose('time', 'latitude', 'longitude')

        # Tidy up: attach CRS and collapse the single-variable Dataset to a DataArray.
        da = assign_crs(da, crs='epsg:4326')
        da = da.to_array().squeeze().drop_vars('variable')
        da.attrs['nodata'] = np.nan

        # Monthly mean of the 8-day composites. The month-start labels are
        # shifted to the 15th with explicit offset arithmetic, replacing the
        # deprecated `loffset` argument of `resample`.
        da = da.resample(time='MS').mean()
        da = da.assign_coords(time=da.get_index('time') + pd.Timedelta(14, 'd'))
        da = da.persist()
        da = da * da.time.dt.days_in_month  # convert to /month

        # Average-aggregate onto the common grid. compute() materialises the
        # result before the source file is closed at the end of this block.
        da = da.odc.reproject(gbox, resampling='average').compute()

    # Tidy up.
    da = round_coords(da)
    da.attrs['nodata'] = np.nan
    da = da.rename(name)

    # Mask to the Australian land extent.
    return da.where(mask)


arrs = []
for year in years:
    print(year)
    for var_name, var_dir in ozwald_vars.items():
        # Build the filename from the variable key rather than hard-coding GPP,
        # so adding entries to `ozwald_vars` reads the matching files.
        src_path = var_dir + f'OzWALD.{var_name}.{year}.nc'
        arrs.append(_ozwald_to_monthly_1km(src_path, var_name, gbox, mask))


In [None]:
# Stack the per-year monthly arrays along time, then put them in chronological order.
dss = xr.concat(arrs, dim='time')
dss = dss.sortby('time')

In [None]:
# Show the combined array's repr (dims, coords, time span) before writing it out.
dss

In [None]:
# Write the monthly 1 km GPP stack. The year span in the filename is derived
# from `years`, so it stays in sync with what was actually processed
# (currently ..._2003_2023.nc).
out_dir = '/g/data/os22/chad_tmp/AusEFlux/data/'
out_path = os.path.join(out_dir, f'OzWALD_GPP_1km_{years[0]}_{years[-1]}.nc')
dss.to_netcdf(out_path)

In [None]:
# Per-pixel maximum over the monthly record, as a quick visual sanity check.
gpp_max_img = dss.max('time').plot.imshow(size=6, robust=True)
gpp_max_img.axes.set_title('OzWALD GPP: per-pixel maximum of monthly values');