# Calculate velocity potential

In [None]:
from dask_jobqueue import PBSCluster
from dask.distributed import Client

In [None]:
# One node on Gadi has 48 cores - fill a whole node before scaling out to multiple nodes (jobs)

cores = 48
memory = f'{cores * 4}GB'  # 4 GB per core
walltime = '03:00:00'

# job_extra lines are extra PBS directives for each job; header_skip is meant to
# drop generated job-header lines containing "select" (see dask-jobqueue docs)
cluster = PBSCluster(
    walltime=walltime,
    cores=cores,
    memory=memory,
    job_extra=[
        f'-l ncpus={cores}',
        f'-l mem={memory}',
        '-P xv83',
        '-l storage=gdata/xv83+gdata/rt52+scratch/xv83',
    ],
    header_skip=["select"],
)

In [None]:
# Ask for 3 PBS jobs (each sized by the `cores`/`memory` settings above) and
# connect a distributed client to the cluster
cluster.scale(jobs=3)
client = Client(cluster)

In [None]:
# Show the client summary as the cell output
client

In [None]:
import xarray as xr
import os
import numpy as np

In [None]:
# Years to process: 1979 to 2020 inclusive (42 years)
years = range(1979, 2021)
# Isobaric levels to keep (presumably hPa - TODO confirm against the ERA5 `level` coordinate)
levels = [150, 850]

# Daily u and v data - use standard dask and xarray tools

- Processing hourly data to daily means, for two isobaric levels, takes around 4 hours (using 3 full nodes; ~576 GB)

In [None]:
def get_files(file_path, var, years):
    """
    Get list of files

    Collects every entry found in ``<file_path>/<var>/<year>/`` for each year
    in `years`, in the order the years are given and sorted by name within
    each year.

    Parameters
    ----------
    file_path : str
        Root directory containing one sub-directory per variable. A trailing
        path separator is optional.
    var : str
        Variable sub-directory name (e.g. 'u').
    years : iterable
        Years to include; each is converted with ``str`` to form the
        per-year directory name.

    Returns
    -------
    list of str
        Full paths of all directory entries found.

    Raises
    ------
    OSError
        Propagated from ``os.listdir`` if a year directory cannot be listed
        (e.g. it does not exist).
    """
    fp_list = []
    for year in years:
        # os.path.join copes with `file_path` given with or without a trailing '/'
        fp_dir = os.path.join(file_path, var, str(year))
        fp_list.extend(os.path.join(fp_dir, fp) for fp in sorted(os.listdir(fp_dir)))
    return fp_list

In [None]:
# True: open the existing daily zarr stores; False: rebuild them from the hourly files
load = True

In [None]:
%%time
# Zarr stores that hold the daily-mean winds for the requested year range
year_span = str(years[0])+'-'+str(years[-1])
u_store = '/g/data/xv83/dr6273/work/data/era5/u/u_era5_daily_'+year_span+'.zarr'
v_store = '/g/data/xv83/dr6273/work/data/era5/v/v_era5_daily_'+year_span+'.zarr'

if not load:
    # Rebuild the daily stores from the hourly ERA5 pressure-level files
    u_files = get_files('/g/data/rt52/era5/pressure-levels/reanalysis/', 'u', years)
    v_files = get_files('/g/data/rt52/era5/pressure-levels/reanalysis/', 'v', years)

    # Using preprocess in open_mfdataset to select desired levels improves performance
    #  versus doing a .sel() afterwards
    def preprocess(ds):
        """Keep only the isobaric levels listed in `levels`."""
        return ds.sel(level=levels)

    # Same options for both components
    # (spatial chunking was tried as well: 'longitude': 500, 'latitude': 70)
    open_kwargs = dict(chunks={'time': 24, 'level': 1},
                       preprocess=preprocess,
                       compat='override',
                       coords='minimal',
                       engine='netcdf4')

    u = xr.open_mfdataset(u_files, **open_kwargs)
    v = xr.open_mfdataset(v_files, **open_kwargs)

    # Daily means
    u = u.resample(time='1D').mean()
    v = v.resample(time='1D').mean()

    # Optional rechunk before writing (left disabled):
    #     u = u.chunk({'time': 125, 'longitude': 300, 'level': 1})
    #     v = v.chunk({'time': 125, 'longitude': 300, 'level': 1})

    u.to_zarr(u_store,
              mode='w',
              consolidated=True,
              encoding={'u': {'dtype': 'float32'}})

    v.to_zarr(v_store,
              mode='w',
              consolidated=True,
              encoding={'v': {'dtype': 'float32'}})

# Always continue from the stored daily data. Without this, after a rebuild
# `u` and `v` would still point at the lazy hourly->daily graph, and every
# later `.values` call would re-read the hourly files instead of the zarr store.
u = xr.open_zarr(u_store, consolidated=True)
v = xr.open_zarr(v_store, consolidated=True)

# Calculate velocity potential using `windspharm`

- Non-lazy, so we do this separately for each year and isobaric level
- Used 10 cores at 40GB

In [None]:
from windspharm.standard import VectorWind
from windspharm.tools import prep_data, recover_data, order_latdim

### For each level and year

- Takes around 3 minutes per level and year
- 2 levels, 42 years takes around 4.5 hours

In [None]:
def write_vpot(u, v, level, year):
    """
    Compute velocity potential for one isobaric level and year and write it to netCDF.

    Parameters
    ----------
    u, v : xarray.Dataset
        Datasets containing the wind components as variables ``u`` and ``v``
        respectively, with dimensions ``time``, ``latitude``, ``longitude``
        and a ``level`` coordinate.
    level : scalar
        Value of the ``level`` coordinate to select.
    year : int or str
        Year to process; converted to ``str`` and used as a time label
        selection.

    Returns
    -------
    None
        The result is written to
        ``/g/data/xv83/dr6273/work/data/era5/vpot/nc/vpot_<level>_era5_daily_<year>.nc``
        as float32 variable ``vpot`` with a length-1 ``level`` dimension.
    """
    year = str(year)

    # Subsample u and v, with time out front so the 'tyx' layout below holds
    u_ = u.u.sel(time=year, level=level).transpose('time', 'latitude', 'longitude')
    v_ = v.v.sel(time=year, level=level).transpose('time', 'latitude', 'longitude')

    # Load values (this is the non-lazy step)
    uwnd = u_.values
    vwnd = v_.values

    # Ensure data is in correct shape for windspharm
    uwnd, uwnd_info = prep_data(uwnd, 'tyx') # 'tyx' because data is in format time, lat, lon
    vwnd, _ = prep_data(vwnd, 'tyx')

    # order_latdim may flip the latitude axis (north-to-south ordering is what
    # windspharm expects). Keep the latitudes it returns so the output
    # coordinates always match the order of the data.
    lats, uwnd, vwnd = order_latdim(u_['latitude'].values, uwnd, vwnd)

    # Create a VectorWind instance to handle computation of streamfunction and velocity potential
    w = VectorWind(uwnd, vwnd)

    # Calculate velocity potential (streamfunction is discarded)
    _, vp = w.sfvp()

    # Re-shape to original (time, lat, lon) layout
    vp = recover_data(vp, uwnd_info)

    # Put into DataArray and format for writing
    vp = xr.DataArray(vp,
                      dims=['time', 'latitude', 'longitude'],
                      coords={'time': u_['time'].values,
                              'latitude': lats,
                              'longitude': u_['longitude'].values})
    # CF-style attribute names; velocity potential has units of m^2 per second
    vp = vp.assign_attrs({'short_name': 'vpot',
                          'long_name': 'velocity potential',
                          'units': 'm^2 s^-1'})
    vp = vp.expand_dims({'level': [level]})
    vp = vp.to_dataset(name='vpot')
    vp_encoding = {'vpot': {'dtype': 'float32'}}

    vp.to_netcdf('/g/data/xv83/dr6273/work/data/era5/vpot/nc/vpot_'+str(level)+'_era5_daily_'+str(year)+'.nc',
                 mode='w',
                 encoding=vp_encoding)

In [None]:
%%time
# The windspharm step is non-lazy, so process one (level, year) pair at a time;
# each call writes its own netCDF file
for level in levels:
    for year in years:
        write_vpot(u, v, level, year)

# Close cluster

In [None]:
# Disconnect the client first, then shut the cluster down
client.close()
cluster.close()

# Regrid VPOT to 2x2

- Takes around 2 minutes but needs quite a few resources. I used two nodes here.

In [None]:
from dask_jobqueue import PBSCluster
from dask.distributed import Client

In [None]:
# One node on Gadi has 48 cores - fill a whole node before scaling out to multiple nodes (jobs)

cores = 48
memory = f'{cores * 4}GB'  # 4 GB per core
walltime = '00:05:00'

# job_extra lines are extra PBS directives for each job; header_skip is meant to
# drop generated job-header lines containing "select" (see dask-jobqueue docs)
cluster = PBSCluster(
    walltime=walltime,
    cores=cores,
    memory=memory,
    job_extra=[
        f'-l ncpus={cores}',
        f'-l mem={memory}',
        '-P xv83',
        '-l storage=gdata/xv83+gdata/rt52+scratch/xv83',
    ],
    header_skip=["select"],
)

In [None]:
# Ask for 2 PBS jobs (each sized by the `cores`/`memory` settings above) and
# connect a distributed client to the cluster
cluster.scale(jobs=2)
client = Client(cluster)

In [None]:
# Show the client summary as the cell output
client

In [None]:
import xarray as xr

# Load daily vpot data

In [None]:
# Open all netCDF files in the vpot/nc directory as a single dataset
vpot = xr.open_mfdataset('/g/data/xv83/dr6273/work/data/era5/vpot/nc/*.nc',
                         coords='minimal',
                         compat='override')

In [None]:
# Rechunk: -1 keeps each full latitude/longitude field in a single chunk,
# with one level and 25 time steps per chunk
vpot = vpot.chunk({'latitude': -1, 'longitude': -1, 'level': 1, 'time': 25})

### Set up 2x2 array

In [None]:
import xesmf as xe

In [None]:
# This section re-imports its own dependencies (dask, xarray) so it can run in
# a fresh kernel; numpy is needed here too, otherwise `np` is undefined.
import numpy as np

# 2-degree target grid: latitude 90 to -90 (north to south), longitude -180 to 178
target_grid = xr.Dataset({'latitude': (['latitude'], np.arange(90, -91, -2)),
                          'longitude': (['longitude'], np.arange(-180, 180, 2))})

In [None]:
# Build a bilinear regridder from the vpot grid to the 2x2 target grid
regridder = xe.Regridder(vpot, target_grid, 'bilinear')
# Show the regridder summary as the cell output
regridder

In [None]:
# Apply the regridder to the vpot dataset
vpot_2 = regridder(vpot)

In [None]:
# Write the regridded data to a consolidated zarr store (mode='w' overwrites an existing store)
vpot_2.to_zarr('/g/data/xv83/dr6273/work/data/era5/vpot/vpot_era5_daily_2x2.zarr',
               mode='w', consolidated=True)

# Close cluster

In [None]:
# Disconnect the client first, then shut the cluster down
client.close()
cluster.close()