In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import pandas as pd
import glob
import os

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

import holoviews as hv
from holoviews import opts

import geoviews.feature as gf # only needed for coastlines
from geoviews.operation import resample_geometry
import geoviews as gv
from datashader.mpl_ext import dsshow, alpha_colormap

#import geocat.datafiles as gdf  # Only for reading-in datasets
from xarray import open_mfdataset
import xarray as xr

import metpy.calc as mpcalc
from metpy.units import units

from mpas_tools.database.utils import choose_forecast
from mpas_tools.plotting.utils import nonlinear_colorbar
from mpas_tools.utils.tools import find_TC_bbox
from mpas_tools.datashader_tools.utils import datashader_wrapper

# Enable both HoloViews/GeoViews plotting backends.
gv.extension("bokeh", "matplotlib")

# Default frame size shared by Image and RGB elements.
opts.defaults(*[element(width=1200, height=600) for element in (opts.Image, opts.RGB)])

# Notebook output settings.
hv.output(dpi=300, fig='png')

In [None]:
import dask.dataframe as dd
import cartopy.crs as ccrs
import xarray as xr
import math

import holoviews as hv
from holoviews.operation.datashader import rasterize as hds_rasterize

from numba import jit

# Split a global mesh along the longitude seam: the +/-180 antimeridian when x
# is longitude wrapped to [-180, 180), as set_up_mesh prepares it.
# A triangle with vertices on both sides of the seam always has at least one
# leg from vertex 0 that spans nearly 360 degrees. Testing only the two legs
# that touch vertex 0 against the threshold `t` therefore drops every such
# triangle. The leg between vertices 1 and 2 is not tested.

def unzipMesh(x,tris,t):
    """Return the rows of `tris` whose legs from vertex 0 span less than `t` in x.

    Parameters
    ----------
    x : 1-D array of vertex x coordinates (longitude in degrees in set_up_mesh).
    tris : (n, 3) integer array of indices into `x`.
    t : threshold applied to |x[v0] - x[v1]| and |x[v0] - x[v2]|.

    Returns
    -------
    ndarray
        The selected rows of `tris`. Boolean indexing returns a new array.
    """
    x_v0 = x[tris[:, 0]]
    near_v1 = np.abs(x_v0 - x[tris[:, 1]]) < t
    near_v2 = np.abs(x_v0 - x[tris[:, 2]]) < t
    return tris[near_v1 & near_v2]

# Signed (doubled) area of each triangle, used to detect the winding order.

def triArea(x,y,tris):
    """Return twice the signed area of every triangle in `tris`.

    For vertices (p0, p1, p2) this is the z-component of the cross product
    (p1 - p0) x (p2 - p0). It is positive for counter-clockwise vertex order
    in the x-y plane, negative for clockwise order, and zero for degenerate
    triangles.

    Parameters
    ----------
    x, y : 1-D arrays of vertex coordinates.
    tris : (n, 3) integer array of indices into `x` and `y`.

    Returns
    -------
    ndarray of shape (n,)
    """
    i0, i1, i2 = tris[:, 0], tris[:, 1], tris[:, 2]
    dx1, dy1 = x[i1] - x[i0], y[i1] - y[i0]
    dx2, dy2 = x[i2] - x[i0], y[i2] - y[i0]
    return dx1 * dy2 - dx2 * dy1

# Reorder triangles as necessary so they all have counter-clockwise winding order.
# NOTE(review): the original note states that CCW is what Datashader and
# Matplotlib require; confirm against their triangulation docs.

def orderCCW(x,y,tris):
    """Flip clockwise triangles in `tris` to counter-clockwise order, in place.

    Triangles with a negative doubled signed area (see triArea) have their
    vertex order reversed, (v0, v1, v2) -> (v2, v1, v0), which flips the
    winding. Degenerate (zero-area) triangles are left unchanged.

    Parameters
    ----------
    x, y : 1-D arrays of vertex coordinates.
    tris : (n, 3) integer array of indices into `x`/`y`; modified in place.

    Returns
    -------
    ndarray
        `tris` itself (the same array object).
    """
    # Evaluate the winding once and reuse the mask on both sides of the
    # assignment, rather than running triArea over the whole mesh twice.
    clockwise = triArea(x, y, tris) < 0.0
    tris[clockwise, :] = tris[clockwise, ::-1]
    return tris

# Create a HoloViews triangle mesh suitable for rendering with Datashader.
# Vertices come from the coordinate arrays 'x' and 'y'; triangles are rows of
# vertex indices into them; 'var' supplies one value per vertex.

def createHVTriMesh(x,y,triangle_indices, var, n_workers=1):
    """Build a HoloViews TriMesh backed by Dask dataframes.

    Parameters
    ----------
    x, y : 1-D arrays of vertex coordinates.
    triangle_indices : (n, 3) array of vertex indices into x/y/var.
    var : 1-D array of per-vertex values, stored in the 'z' column.
    n_workers : number of Dask partitions used for each dataframe.

    Returns
    -------
    hv.TriMesh
        A mesh whose nodes carry the x, y and z columns.
    """
    # Stack the columns into one float array so x, y and z share a dtype.
    node_table = pd.DataFrame(np.column_stack([x, y, var]), columns=['x', 'y', 'z'])
    simplex_table = pd.DataFrame(triangle_indices, columns=['v0', 'v1', 'v2'])

    # 'index' is the node-index key dimension the v0/v1/v2 columns refer to.
    # Presumably HoloViews resolves it from the dataframe index; confirm.
    nodes = hv.Nodes(dd.from_pandas(node_table, npartitions=n_workers), ['x', 'y', 'index'], ['z'])
    return hv.TriMesh((dd.from_pandas(simplex_table, npartitions=n_workers), nodes))

# Triangulate the MPAS primary mesh.
# Each polygon in a heterogeneous mesh of n-gons is fan-triangulated by
# joining its first vertex to every pair of consecutive later vertices. This
# uses the MPAS auxiliary variables verticesOnCell and nEdgesOnCell.
# Numba's just-in-time compiler turns the loops into machine code for speed.
# NOTE(review): this function is not called in the cells shown here. It also
# copies verticesOnCell values unchanged, so MPAS's 1-based indices are NOT
# shifted to 0-based; callers must subtract 1 themselves.

@jit(nopython=True)
def triangulatePoly(verticesOnCell, nEdgesOnCell):
    """Return an (nTriangles, 3) int64 array of fan triangles.

    Cell j contributes nEdgesOnCell[j] - 2 triangles, written in cell order.
    With v = verticesOnCell[j], these are (v[0], v[k], v[k + 1]) for
    k = 1 .. nEdgesOnCell[j] - 2.
    """
    total = np.sum(nEdgesOnCell - 2)
    fan = np.ones((total, 3), dtype=np.int64)
    row = 0
    for cell in range(verticesOnCell.shape[0]):
        for k in range(1, nEdgesOnCell[cell] - 1):
            fan[row, 0] = verticesOnCell[cell, 0]
            fan[row, 1] = verticesOnCell[cell, k]
            fan[row, 2] = verticesOnCell[cell, k + 1]
            row += 1
    return fan

def set_up_mesh(mesh_ds, n_workers=1):
    """Build a Robinson-projected triangle mesh from an MPAS mesh dataset.

    The triangles are the MPAS dual mesh. Each row of ``cellsOnVertex`` lists
    the cells around one vertex, and those cells' centres (lonCell, latCell)
    are used as the triangle corners.

    Parameters
    ----------
    mesh_ds : dataset providing lonCell and latCell (in radians) and
        cellsOnVertex (1-based cell indices).
    n_workers : passed through unchanged, for createHVTriMesh.

    Returns
    -------
    xPCS, yPCS : projected x/y of every cell centre.
    tris : (n, 3) array of 0-based, counter-clockwise triangle indices with
        seam-crossing triangles removed.
    n_workers : the input, unchanged.
    projection : the cartopy Robinson projection (central_longitude=0) used.
    """
    # Cell-centre coordinates in degrees. Longitude is wrapped to [-180, 180)
    # so that the seam falls at +/-180, the edge of the Robinson map below.
    lonCell = mesh_ds['lonCell'].values * 180.0 / math.pi
    latCell = mesh_ds['latCell'].values * 180.0 / math.pi
    lonCell = ((lonCell - 180.0) % 360.0) - 180.0

    # MPAS indexing starts from 1, not 0, so shift to 0-based indices.
    tris = mesh_ds.cellsOnVertex.values - 1
    # An entry of 0 (-1 after the shift) marks a missing cell, e.g. on the
    # boundary of a regional mesh. Without this filter it would silently
    # index the last cell.
    # NOTE(review): assumes the MPAS "0 means no cell" convention; this is
    # presumably a no-op for global meshes.
    tris = tris[(tris >= 0).all(axis=1)]

    # Guarantee a consistent counter-clockwise winding order.
    tris = orderCCW(lonCell,latCell,tris)

    # Unzip the mesh along the +/-180 seam, which matches central_longitude=0.0.
    central_longitude = 0.0
    projection = ccrs.Robinson(central_longitude=central_longitude)
    tris = unzipMesh(lonCell,tris,90.0)

    # Project cell centres from geographic to projected (PCS) coordinates.
    xPCS, yPCS, _ = projection.transform_points(ccrs.PlateCarree(), lonCell, latCell).T

    return xPCS, yPCS, tris, n_workers, projection

In [None]:
def _projected_extent(projection, lon_range, lat_range, n=25):
    """Return ((xmin, xmax), (ymin, ymax)) covering a lon/lat box in `projection`.

    The box [min(lon_range), max(lon_range)] x [min(lat_range), max(lat_range)]
    is sampled on an n x n grid. In pseudo-cylindrical projections such as
    Robinson, the projected x of a given longitude varies with latitude, so
    two opposite corners alone do not bound the projected box.
    """
    lons = np.linspace(np.min(lon_range), np.max(lon_range), n)
    lats = np.linspace(np.min(lat_range), np.max(lat_range), n)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    xyz = projection.transform_points(ccrs.PlateCarree(), lon_grid.ravel(), lat_grid.ravel())
    return (xyz[:, 0].min(), xyz[:, 0].max()), (xyz[:, 1].min(), xyz[:, 1].max())


def datashader_wrapper(mesh_ds, unstructured_ds, primalVarName, time, level=None, 
                       pixel_height=400, pixel_width=400, pixel_ratio=1, x_sampling=None, 
                       y_sampling=None, lon_range=None, lat_range=None):
    """Rasterize one MPAS cell-centred field with Datashader.

    NOTE(review): this definition shadows the ``datashader_wrapper`` imported
    from mpas_tools.datashader_tools.utils in the first cell.

    Parameters
    ----------
    mesh_ds : MPAS mesh dataset, passed to set_up_mesh.
    unstructured_ds : dataset holding ``primalVarName`` on the MPAS cells.
    primalVarName : name of the variable to plot.
    time : integer position along ``time``. It is ignored when the variable
        has no ``time`` dimension.
    level : value on the ``lev`` coordinate (nearest match). It is required
        for fields that are still more than 1-D after the time selection.
    pixel_height, pixel_width, pixel_ratio, x_sampling, y_sampling :
        forwarded to holoviews.operation.datashader.rasterize.
    lon_range, lat_range : optional pairs of geographic bounds in degrees
        (order does not matter). They are used only when both are given.

    Returns
    -------
    The output of ``rasterize`` (mean aggregation, precompute=True).

    Raises
    ------
    ValueError
        If the selected field is not 1-D and ``level`` is None.
    """
    field = unstructured_ds[primalVarName]
    # Pick the time step. Fields without a time dimension (e.g. a time-maximum)
    # make isel raise ValueError and are used as they are.
    try:
        field = field.isel(time=time)
    except ValueError:
        pass

    # Reduce multi-level fields to a single level before loading any data.
    # Selecting on the already time-sliced array also works for fields with
    # no time dimension.
    if field.ndim > 1:
        if level is None:
            raise ValueError('Select a level to knock this down to a 1D array.')
        field = field.sel(lev=level, method='nearest')
    primalVar = field.values

    xPCS, yPCS, tris, n_workers, projection = set_up_mesh(mesh_ds)
    # TODO: xPCS, yPCS and tris could also feed derived quantities (e.g. vorticity).

    trimesh = createHVTriMesh(xPCS,yPCS,tris, primalVar,n_workers=n_workers)

    # Use `is not None` rather than `!= None`. The latter is element-wise for
    # numpy arrays, which makes `and` raise "truth value ... is ambiguous".
    if lon_range is not None and lat_range is not None:
        x_range, y_range = _projected_extent(projection, lon_range, lat_range)
    else:
        x_range = None
        y_range = None

    # Use precompute so it caches the data internally.
    rasterized = hds_rasterize(trimesh, aggregator='mean', precompute=True, height=pixel_height,
                               width=pixel_width, pixel_ratio=pixel_ratio, x_sampling=x_sampling,
                               y_sampling=y_sampling, x_range=x_range, y_range=y_range)

    return rasterized

# Florence MPAS Data

In [None]:
# Location of the Florence MPAS Betacast sample output. Set the environment
# variable MPAS_BETACAST_DIR to read it from somewhere else; the default is
# the original cluster path.
florence_folders = os.environ.get(
    "MPAS_BETACAST_DIR", "/storage/home/cmz5202/group/cnd5285/MPAS_betacast_sample/")
lores_mesh_file = os.path.join(florence_folders, 'FHIST-mpasa120-betacast-ERA5-x001_INIC.nc')

# History datasets for the 0Z initialisation.
# NOTE(review): the h1/h2/h4 naming of positions 0/1/2 is taken from the
# original code; confirm that choose_forecast returns them in that order.
lores_files = choose_forecast(florence_folders, regridded=False, init_time='0Z')
h1_ds = lores_files[0]
h2_ds = lores_files[1]
h4_ds = lores_files[2]

# Mesh / initial-condition file, opened without decoding its time axis.
lores_mesh = xr.open_dataset(lores_mesh_file, decode_times=False)

In [None]:
# Display the h1 history dataset.
h1_ds

In [None]:
# Display the MPAS mesh dataset (lonCell, latCell, cellsOnVertex, ...).
lores_mesh

In [None]:
# TODO - Attempt 1: calculate horizontal vorticity following Davies-Jones (1992). Not implemented yet.


In [None]:
# Store the maximum of U10 over all time steps as the new variable 'maxU10'
# (it has no time dimension), then display the dataset.
h1_ds['maxU10'] = h1_ds.U10.max('time')
h1_ds

In [None]:
# Recompute the time-maximum of U10 directly, as a check against the
# 'maxU10' variable added in the previous cell.
h1_ds['U10'].max(dim='time').values

In [None]:
# Values of the stored maxU10 variable; these should match the previous cell.
h1_ds['maxU10'].values

In [None]:
# Display the h2 history dataset.
h2_ds

In [None]:
# NOTE(review): duplicate of the earlier lores_mesh display cell; can be removed.
lores_mesh

In [None]:
# Print the minimum and then the maximum FLUT value, one per line.
print(h2_ds['FLUT'].values.min(), h2_ds['FLUT'].values.max(), sep='\n')

In [None]:
# Wind-swath colormap: one colour per 1 m/s bin over 0-85 m/s, meant to be
# used with clim=(0, 85). The top end, 85 m/s (190 mph), matches Hurricane
# Allen, the strongest on record.
#   0-17 m/s   white     below tropical-storm strength (not highlighted)
#   18-32 m/s  orange    tropical-storm-strength winds
#   33+ m/s    dark red  hurricane-strength winds
# A ListedColormap places each category change exactly on a bin edge. A
# LinearSegmentedColormap with N=85 samples its stops at i/84 rather than
# i/85, which blends the colours of the 17 m/s and 32 m/s bins.
WS_MAX = 85  # upper end of the wind-swath colour scale, m/s
ws_cmp = ListedColormap(
    ['#FFFFFF'] * 18                 # 0-17 m/s
    + ['#FFA500'] * 15               # 18-32 m/s
    + ['#880808'] * (WS_MAX - 33),   # 33-84 m/s
    name='wind swaths',
)

ws_cmp

In [None]:
# Swath of the time-maximum of U10 (maxU10) around the storm.
# maxU10 has no time dimension, so datashader_wrapper's isel(time=...) falls
# into its ValueError fallback and target_time has no effect there.
# target_time is still passed to find_TC_bbox.
central_longitude = 0.0
projection = ccrs.Robinson(central_longitude=central_longitude)

target_var = 'maxU10'
target_time = 2
pixel_ratio = 5
lon_range, lat_range = find_TC_bbox(h1_ds, 'north atlantic', target_time, center_dist=1000)
rasterized_lowres = datashader_wrapper(lores_mesh, h1_ds, target_var, target_time,
                                       lon_range=lon_range, lat_range=lat_range, pixel_ratio=pixel_ratio)
# clim=(0, 85) matches the 0-85 m/s range that ws_cmp is built for.
rasterized_lowres.opts(tools=['hover'], colorbar=True, cmap=ws_cmp, clim=(0, 85)) * gf.coastline(projection=projection)

In [None]:
# U10 at time steps 1-5, one panel per step.
# The rasters are opaque, so overlaying them with `*` showed only the topmost
# one wherever it has data. They are therefore laid out side by side.
central_longitude = 0.0
projection = ccrs.Robinson(central_longitude=central_longitude)

target_var = 'U10'
target_time = 2
pixel_ratio = 5
time_steps = [1, 2, 3, 4, 5]
lon_range, lat_range = find_TC_bbox(h1_ds, 'north atlantic', target_time, center_dist=1000)

# NOTE(review): ws_cmp's white/orange/red breaks are laid out for clim=(0, 85).
# With clim=(0, 20) they no longer line up with the 18 and 33 m/s thresholds.
u10_rasters = [
    datashader_wrapper(lores_mesh, h1_ds, target_var, step,
                       lon_range=lon_range, lat_range=lat_range, pixel_ratio=pixel_ratio)
    .opts(colorbar=True, cmap=ws_cmp, clim=(0, 20), tools=['hover'], framewise=True)
    for step in time_steps
]
rl1, rl2, rl3, rl4, rl5 = u10_rasters

coastline = gf.coastline(projection=projection)
layout = hv.Layout([raster * coastline for raster in u10_rasters]).cols(3)
layout

In [None]:
# Sea-level pressure (PSL) at time steps 2 and 3, one panel per step.
# The rasters are opaque, so overlaying them with `*` showed only the second
# one wherever it has data. They are therefore laid out side by side.
central_longitude = 0.0
projection = ccrs.Robinson(central_longitude=central_longitude)

target_var = 'PSL'
target_time = 2
pixel_ratio = 5
lon_range, lat_range = find_TC_bbox(h1_ds, 'north atlantic', target_time, center_dist=1000)
rasterized_lowres1 = datashader_wrapper(lores_mesh, h1_ds, target_var, target_time,
                                       lon_range=lon_range, lat_range=lat_range, pixel_ratio=pixel_ratio)
rasterized_lowres2 = datashader_wrapper(lores_mesh, h1_ds, target_var, 3,
                                       lon_range=lon_range, lat_range=lat_range, pixel_ratio=pixel_ratio)

coastline = gf.coastline(projection=projection)
layout = hv.Layout([
    raster.opts(colorbar=True, cmap='RdBu', tools=['hover'], framewise=False) * coastline
    for raster in (rasterized_lowres1, rasterized_lowres2)
])
layout

In [None]:
# Quick look at U10 at time index 2 via hv.Image.
# NOTE(review): isel(time=[2]) (a list) keeps a length-1 time dimension, and
# U10 appears to live on the unstructured MPAS cells (it is rendered through
# the TriMesh path above). If so, hv.Image treats (time, cell) as a regular
# 2-D grid and the result is not a map. Confirm the dims or drop this cell.
hds_rasterize(hv.Image(h1_ds.U10.isel(time=[2]))).opts(cmap='jet', colorbar=True, width=800)