In [None]:
import asyncio
import os
import time
import traceback

import IPython.display as display
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

import dask
from dask.distributed import LocalCluster, Client, progress

from wrf_eke_example import get_data, crop_lat_lon, calc_averages, plot_coords, plot_eke_avg

In [None]:
# Root of the WRF input tree, resolved from the $SCRATCH environment variable.
scratch_path = os.path.expandvars("$SCRATCH")
input_path = os.path.join(scratch_path, "scira/wrf_in/")
scenario_type = "Historical"

# Reference wrfout file; crop_lat_lon reads the lat/lon grid and crop indices from it.
dataset = "wrfout_d01_2008-07-01_00_00_00"
lat_lon_path = os.path.join(input_path, scenario_type, dataset)

# Years to process (2001 through 2010 inclusive) and the chunk spec handed to get_data for each year.
years = list(range(2001, 2011))
chunks = {"time": -1, "lev": 1}

# Per-year inputs: the "<year>-<year+1>/Variables" directory, the file-name suffix,
# and the (shared) chunk spec.
file_locations = []
file_suffixes = []
yearly_chunks = []
for year in years:
    file_locations.append(os.path.join(input_path, scenario_type, f"{year}-{year + 1}", "Variables"))
    file_suffixes.append(f"{scenario_type}_{year}.nc")
    yearly_chunks.append(chunks)

## Preview the data extent and region of interest

In [None]:
# Lat/lon arrays plus the north/south/west/east crop indices, read from the reference wrfout file.
(lat, lon,
 lat_index_north, lat_index_south,
 lon_index_west, lon_index_east) = crop_lat_lon(lat_lon_path)

# Preview the data extent together with the region of interest used by the eke_avg calculation.
preview_bbox = (-20, 0, 20, 20)
cfig, cax = plot_coords(lat, lon, bbox=preview_bbox)

## Start a local dask cluster and connect to it (the client display includes a link to the status dashboard)

In [None]:
#dask.config.config["distributed"]["dashboard"]["link"] = "{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status" 

# Directory the local workers use for their on-disk files. The default is the original
# machine-specific path; set $DASK_LOCAL_DIRECTORY to run elsewhere without editing the notebook.
dask_local_directory = os.environ.get("DASK_LOCAL_DIRECTORY", "/Volumes/T7/tmp/")

cluster = LocalCluster(n_workers=3, threads_per_worker=1, memory_limit='4G', local_directory=dask_local_directory)
dask_client = Client(cluster)
display.display(dask_client)

# Alternative (unused): connect to an already-running scheduler whose scheduler-file path is
# stored in ../dask_client instead of starting a LocalCluster. Note `info` is only set when that
# file exists.
#file_path = os.path.abspath('../dask_client')

#if os.path.exists(file_path):
#    with open(file_path, 'r') as location:
#        info = location.read().strip()

#dask_client = Client(scheduler_file=info)
#display.display(dask_client)

## Scatter data to the dask workers

In [None]:
# Get per-year file handles on the workers, gather them to the client, and scatter them back
# to the cluster; the data itself will be loaded lazily during computation.
# (lat/lon and the crop indices were already computed by the preview cell and are reused.)
data_refs = dask_client.gather(
    dask_client.map(get_data, file_locations, file_suffixes, yearly_chunks))
data_futures = dask_client.scatter(data_refs)
data = data_futures

# uncomment this if you have a cluster with ample memory to store data
# persist will keep this data in memory on the workers after being read from disk the first time, this will make multiple runs or additional computation faster
#data_scattered_refs = [(x[0].persist(),x[1].persist()) for x in dask_client.gather(data_futures)]
#data = data_scattered_refs

## Submit the per-year calculations to dask (they run asynchronously on the workers)

In [None]:
# Submit one calc_averages task per year. The futures are collected by the next cell.
print("Calculating eke_avg and total_eke_avg for {}".format(years))
eke_futures = []
try:
    for i in range(len(years)):
        eke_futures.append(
            dask_client.submit(
                calc_averages,
                [data[i]],
                lat_index_north,
                lat_index_south,
                lon_index_west,
                lon_index_east,
                # higher priority for earlier years, so the scheduler prefers them when workers are busy
                priority=len(years) - i))
except Exception as e:
    print("Exception for years: {}, chunks: {}".format(years, chunks))
    print(e)
    traceback.print_tb(e.__traceback__)
    # re-raise so Run All stops here instead of waiting on an incomplete list of futures
    raise

## Wait for tasks to complete, display the data for each year as it arrives

In [None]:
display.display(progress(eke_futures, notebook=True))

eke_averages = []
total_eke_averages = []

waiting = [i for i in range(len(eke_futures))]
while len(waiting) > 0:
    completed = []

    for i in waiting:
        if eke_futures[i].done():
            # collect the data and show it
            eke_avg_year, total_avg_year = eke_futures[i].result()
            eke_averages.append((eke_avg_year, i))
            total_eke_averages.append((total_avg_year, i))
            print("Total EKE {} - {}".format(years[i], total_avg_year))
            fig, ax = plot_eke_avg(
                eke_avg_year, 
                lat, 
                title='WRF_TCM_M-O_{}-{}_avg_{}_EKE'.format(years[i], years[i]+1, scenario_type), 
                size=(10,8))
            # bypass plot.show() so we can make sure plots show as data comes in
            display.display(fig)
            # close the figure for memory cleanup but also to prevent matplotlib from displaying twice
            plt.close(fig)
            # add some vertical space
            print("\n\n")
            # book keeping
            completed.append(i)

    # stop checking futures that we have collected results for
    for i in completed:
        waiting.remove(i)

    # sleep and give time back to the kernel without blocking
    if len(waiting) > 0:
        await asyncio.sleep(10)

# sort the results by year
sorted(eke_averages, key=lambda x: x[1])
sorted(total_eke_averages, key=lambda x: x[1])

# remove the index values
eke_averages = [x[0] for x in eke_averages]
total_eke_averages = [x[0] for x in total_eke_averages]

## Show the yearly averages together, in chronological order

In [None]:
figsize = (20, 10)

# Total EKE average per year.
tfig, ax = plt.subplots(figsize=figsize)
ax.set_title('Total EKE Average by year')
ax.set_xlabel("Year")
ax.set_ylabel("Total EKE Average")
ax.minorticks_on()
# pad the x-axis by one year on each side (2000-2011 for the default 2001-2010 range)
ax.set_xlim((min(years) - 1, max(years) + 1))
ax.scatter(years, total_eke_averages)

# eke_avg per year, one panel each, five panels per row (2x5 for the default ten years).
ncols = 5
nrows = max(1, (len(eke_averages) + ncols - 1) // ncols)
# squeeze=False keeps axs 2-D even for a single row/panel, so axs.flat always works
efig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

for i in range(len(eke_averages)):
    plot_eke_avg(
        eke_averages[i],
        lat,
        title='WRF_TCM_M-O_{}-{}_avg_{}_EKE'.format(years[i], years[i] + 1, scenario_type),
        size=figsize,
        fig=efig,
        ax=axs.flat[i])

# hide leftover panels when the number of years is not a multiple of ncols
for unused_ax in axs.flat[len(eke_averages):]:
    unused_ax.set_visible(False)

# lay out after drawing so the panel titles are taken into account
efig.tight_layout()

## Show the 10-year average

In [None]:
# Average the per-year eke_avg fields element-wise, and the per-year totals, over all years.
eke_avg = np.mean(np.stack(eke_averages), axis=0)
total_eke_avg = np.mean(total_eke_averages)
print("Total EKE - {}".format(total_eke_avg))

period_title = 'WRF_TCM_M-O_{}-{}_avg_{}_EKE'.format(years[0], years[-1] + 1, scenario_type)
fig, ax = plot_eke_avg(eke_avg, lat, title=period_title, size=(20, 15))

## Save the computation result to a file

In [None]:
import datetime
results_filename = 'yearly_eke_averages_{}-{}.nc'.format(years[0], years[-1] + 1)

# The ua_/va_ input files for each year, recorded alongside the results for provenance.
n_files = len(file_locations)
u_filenames = [os.path.join(file_locations[i], "ua_" + file_suffixes[i]) for i in range(n_files)]
v_filenames = [os.path.join(file_locations[i], "va_" + file_suffixes[i]) for i in range(n_files)]

years_array = np.asarray(years)
eke_array = np.stack(eke_averages)
total_eke_array = np.asarray(total_eke_averages)

# One entry per year along "year"; assumes each yearly eke_avg is 2-D, matching the (lev, lat)
# dims declared here — TODO confirm against calc_averages.
results_vars = {
    "years": xr.DataArray(data=years_array, dims=("year",)),
    "eke_avg": xr.DataArray(data=eke_array, dims=("year", "lev", "lat")),
    "total_eke_avg": xr.DataArray(data=total_eke_array, dims=("year",)),
    "u_source_files": xr.DataArray(data=u_filenames, dims=("year",)),
    "v_source_files": xr.DataArray(data=v_filenames, dims=("year",)),
}
created_at = datetime.datetime.now().astimezone().isoformat()
results = xr.Dataset(data_vars=results_vars, attrs={"created": created_at})
results.to_netcdf(results_filename)

In [None]:
# Read the saved file back to check that it round-trips. load_dataset reads everything into memory
# and closes the file, and the separate name avoids rebinding `results` to a closed dataset.
saved_results = xr.load_dataset(results_filename)
display.display(saved_results)

## Save the final plot

In [None]:
# Use the same period label as the plot title and the results filename (e.g. 2001-2011).
fig.savefig('WRF_TCM_M-O_{}-{}_avg_{}_EKE.pdf'.format(years[0], years[-1] + 1, scenario_type))

In [None]:
# Release dask resources at the end of the notebook: shut down the scheduler and workers the
# client is connected to (the LocalCluster started earlier), then close the client itself.
# NOTE(review): shutdown() may already close the client, making close() a harmless no-op — confirm.
dask_client.shutdown()
dask_client.close()