In [None]:
import os
import glob

import numpy as np
import pandas as pd
import xarray as xr

from PIL import Image

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import cartopy.crs as ccrs
import cartopy.feature as cfeature

import math
import dask.dataframe as dd

import metpy.calc as mpcalc
from metpy.units import units

from scipy.spatial import Delaunay

import holoviews as hv
from holoviews import opts

import geoviews as gv
import geoviews.feature as gf

import datashader as ds
from holoviews.operation.datashader import rasterize as hds_rasterize
from holoviews.operation.datashader import datashade as hds_datashade

from uviz.utils.tools import find_TC_bbox

# Enable the Bokeh and Matplotlib plotting backends for GeoViews/HoloViews.
gv.extension("bokeh", "matplotlib")

# Give raster (Image) and RGB elements the same default 1200x600 canvas.
opts.defaults(*(element_opts(width=1200, height=600) for element_opts in (opts.Image, opts.RGB)))

In [None]:
# Datashader allows for custom colormap, but not custom registered colormaps. Holoviews is the opposite.

# https://cimss.ssec.wisc.edu/satellite-blog/wp-content/uploads/sites/5/2017/09/GOES16_CleanWindow_Landfall-20170920_0957_1136anim.gif
def T_to_FLUT(T, unit='K'):
    """Convert a temperature to an outgoing longwave flux with the Stefan-Boltzmann law.

    Computes ``sigma * T**4`` with ``sigma = 5.6693e-8``.

    Parameters
    ----------
    T : float or numpy.ndarray
        Temperature. It is never modified in place, so arrays can be passed safely.
    unit : {'K', 'C'}
        Unit of ``T``. Celsius values are shifted by 273.15 before applying the law.

    Returns
    -------
    float or numpy.ndarray
        The flux (plotted elsewhere in this notebook as W/m^2).

    Raises
    ------
    ValueError
        If ``unit`` is neither 'K' nor 'C'.
    """
    if unit not in ('K', 'C'):
        raise ValueError(f"unit must be 'K' or 'C', got {unit!r}")
    # Build a new value rather than using ``T += ...``, which would mutate an ndarray argument.
    T_kelvin = T + 273.15 if unit == 'C' else T
    # Stefan-Boltzmann constant; differs slightly from the CODATA value 5.670374419e-8 but is
    # kept so existing colour limits do not shift.
    sigma = 5.6693E-8
    return sigma * (T_kelvin ** 4)

# Normalized CIMSS-style scale, assembled from two sampled segments stacked end to end:
#   * a rainbow segment (184 samples) whose stops sit at the fluxes of the temperature
#     breakpoints between -110 C and -22.5 C, followed by
#   * a light-grey -> black ramp (435 samples).
# The final colormap spaces all 184 + 435 rows evenly, so the 184:435 ratio decides how much
# of the colour range each segment occupies.
bw_colors = [(0, '#BCBCBC'), (1, '#000000')]
bw_cmp = LinearSegmentedColormap.from_list('FLUT bw', bw_colors, N=435)

# Flux at each breakpoint, offset so that fracs / fracs[-1] runs from 0 (-110 C) to 1 (-22.5 C).
levels = np.array([T_to_FLUT(temp_c, 'C')
                   for temp_c in (-110, -105, -87.5, -80, -70, -60, -50, -35, -27.5, -22.5)])
fracs = levels - T_to_FLUT(-110, 'C')

# Interior stops use breakpoints 1..7; the -27.5 C breakpoint (index 8) is not used as a stop.
rainbow_colors = (
    [(0, '#0febff')]                                   # cyan
    + [(frac / fracs[-1], hex_color)
       for frac, hex_color in zip(fracs[1:8], ['#7f007f',   # purple
                                               '#e5e4e5',   # white
                                               '#000000',   # black
                                               '#ff0000',   # red
                                               '#FFFF00',   # yellow
                                               '#00FF00',   # green
                                               '#000073'])] # navy
    + [(1, '#00ffff')]                                 # cyan
)
rainbow_cmp = LinearSegmentedColormap.from_list('FLUT colors', rainbow_colors, N=184)

# Sample each segment into explicit RGBA rows. Calling the Colormap objects directly is what
# plt.get_cmap(<Colormap instance>) would hand back anyway.
bws = bw_cmp
bws_colors = bws(np.linspace(0, 1, 435))
rainbow = rainbow_cmp
r_colors = rainbow(np.linspace(0, 1, 184))

# Rainbow rows first (lowest flux), then the grey ramp.
all_colors = np.vstack((r_colors, bws_colors))
flut_cimss = LinearSegmentedColormap.from_list('FLUT CIMSS', all_colors)
flut_cimss

In [None]:
### This function splits a global mesh along longitude
def unzipMesh(x, tris, t):
    """Keep only the triangles whose x-extent is smaller than ``t``.

    Used to drop triangles that wrap across the longitude seam, whose vertices lie on
    opposite sides of the map.

    Parameters
    ----------
    x : numpy.ndarray
        x coordinate of every vertex (e.g. longitude in degrees).
    tris : numpy.ndarray of int, shape (n_tris, 3)
        Vertex indices of each triangle.
    t : float
        Threshold. A triangle is kept only if every one of its three legs spans less than
        ``t`` in x.

    Returns
    -------
    numpy.ndarray
        The rows of ``tris`` that pass the test, in their original order.
    """
    # x coordinates of each triangle's three vertices, shape (n_tris, 3).
    tri_x = x[tris]
    # The longest leg measured in x equals max - min over the three vertices, so this checks
    # all three legs (v1-v2 included). NaN extents compare False and are dropped.
    extent = tri_x.max(axis=1) - tri_x.min(axis=1)
    return tris[extent < t]

def triArea(x, y, tris):
    """Orientation measure of every triangle in ``tris``.

    Returns the z-component of the cross product (v1 - v0) x (v2 - v0), i.e. TWICE the signed
    area: positive for counter-clockwise vertex order, negative for clockwise, zero if degenerate.
    """
    i0, i1, i2 = tris[:, 0], tris[:, 1], tris[:, 2]
    # Edge vectors from vertex 0 to vertices 1 and 2.
    ex1, ey1 = x[i1] - x[i0], y[i1] - y[i0]
    ex2, ey2 = x[i2] - x[i0], y[i2] - y[i0]
    return ex1 * ey2 - ex2 * ey1

def orderCCW(x, y, tris):
    """Give every triangle counter-clockwise winding, which Datashader and Matplotlib require.

    Triangles with a negative ``triArea`` have their three vertex indices reversed, which flips
    their winding. Degenerate (zero-area) triangles are left as they are.

    Note: ``tris`` is modified IN PLACE and the same array is returned, so callers that keep
    using the array they passed in also see the reordered triangles.
    """
    # Evaluate the orientation once and reuse the mask for both sides of the assignment.
    clockwise = triArea(x, y, tris) < 0.0
    # Boolean indexing on the right-hand side makes a copy, so the in-place write is safe.
    tris[clockwise, :] = tris[clockwise, ::-1]
    return tris

def createHVTriMesh(x, y, triangle_indices, var, var_name, n_workers=1):
    """Build a HoloViews ``TriMesh`` from precomputed triangle connectivity.

    Parameters
    ----------
    x, y : array-like
        Vertex coordinates, stored in the 'Longitude' / 'Latitude' columns.
    triangle_indices : array-like, shape (n_tris, 3)
        Vertex indices of each triangle, stored as 'v0', 'v1', 'v2'.
    var : array-like
        Value at each vertex.
    var_name : str
        Column / value-dimension name used for ``var``.
    n_workers : int
        Number of dask partitions for both the vertex and the triangle tables.

    Returns
    -------
    hv.TriMesh
    """
    # One row per vertex: x, y and the value (assumes 1-D inputs of equal length).
    vertex_table = pd.DataFrame(np.column_stack([x, y, var]),
                                columns=['Longitude', 'Latitude', var_name])
    triangle_table = pd.DataFrame(triangle_indices, columns=['v0', 'v1', 'v2'])

    vertex_ddf, triangle_ddf = (dd.from_pandas(table, npartitions=n_workers)
                                for table in (vertex_table, triangle_table))

    # Nodes key dimensions: x, y and the node index the triangles refer to.
    nodes = hv.Nodes(vertex_ddf, ['Longitude', 'Latitude', 'index'], [var_name])
    return hv.TriMesh((triangle_ddf, nodes))

In [None]:
prototype_dir = r"/gpfs/group/cmz5202/default/cnd5285/synth_events"
storm_1279_dir = os.path.join(prototype_dir, "VR28.NATL.EXT.CAM5.4CLM5.0.dtime900_storm_1279")
# Output folders of the nested runs: 28 km parent and 3 km child.
parent_28km, child_3km = (os.path.join(storm_1279_dir, res) for res in ('28km', '3km'))

In [None]:
# Parent (28 km) output for streams h1-h4. Native and 'remap' files are mixed here and are
# separated when the datasets are opened.
h1p_files, h2p_files, h3p_files, h4p_files = (
    glob.glob(os.path.join(parent_28km, f'*.h{stream}.*.nc')) for stream in (1, 2, 3, 4)
)

# CAM mesh is the "ext" one, based on the original model run. Other options are "ref" or "wat".
p_mesh = r"/gpfs/group/cmz5202/default/cnd5285/maps_and_grids/ne0np4natlanticext.ne30x4.g_scrip.nc"

In [None]:
# Parent (28 km) datasets. Each stream is opened twice: from its native-grid files and from
# the files whose names contain 'remap' (regridded output).
parallel = True


def _open_parent_stream(files, remapped):
    """open_mfdataset on the files whose names do (remapped=True) or do not contain 'remap'."""
    selected = [path for path in files if ('remap' in path) == remapped]
    return xr.open_mfdataset(selected, parallel=parallel)


# Native grid files
h1pn_ds = _open_parent_stream(h1p_files, remapped=False)
h2pn_ds = _open_parent_stream(h2p_files, remapped=False)
h3pn_ds = _open_parent_stream(h3p_files, remapped=False)
h4pn_ds = _open_parent_stream(h4p_files, remapped=False)
p_mesh_ds = xr.open_dataset(p_mesh)

# Regridded files
h1pr_ds = _open_parent_stream(h1p_files, remapped=True)
h2pr_ds = _open_parent_stream(h2p_files, remapped=True)
h3pr_ds = _open_parent_stream(h3p_files, remapped=True)
h4pr_ds = _open_parent_stream(h4p_files, remapped=True)

In [None]:
# Child (3 km) output for streams h1-h4; native and 'remap' files are separated when opened.
h1c_files = glob.glob(os.path.join(child_3km, '*.h1.*.nc'))
h2c_files = glob.glob(os.path.join(child_3km, '*.h2.*.nc'))
h3c_files = glob.glob(os.path.join(child_3km, '*.h3.*.nc'))
h4c_files = glob.glob(os.path.join(child_3km, '*.h4.*.nc'))

# h5 output lives in its own sub-directory. os.listdir returns bare file names (which would only
# resolve relative to the current working directory) in arbitrary order, so join them onto the
# directory and sort them for usable, reproducible paths.
h5c_dir = os.path.join(child_3km, 'h5')
h5c_files = sorted(os.path.join(h5c_dir, name) for name in os.listdir(h5c_dir))

# MPAS mesh file, located by hand from the dataset attributes.
# TODO: make a function that finds it automatically.
c_mesh = "/gpfs/group/cmz5202/default/cnd5285/MPAS_3km/x20.835586.florida.init.CAM.nc"

In [None]:
parallel = True

# Child (3 km) native-grid datasets, one per stream h1-h4.
h1cn_ds, h2cn_ds, h3cn_ds, h4cn_ds = (
    xr.open_mfdataset([path for path in files if 'remap' not in path], parallel=parallel)
    for files in (h1c_files, h2c_files, h3c_files, h4c_files)
)
# TODO: h5 output (h5c_files) is not opened yet.

# MPAS mesh, opened lazily with explicit chunk sizes along its cell / vertex / edge dimensions.
c_mesh_ds = xr.open_dataset(c_mesh, chunks={'nCells': 10000, 'nVertices': 100, 'nEdges': 100})

# Regridded ('remap') child datasets.
h1cr_ds, h2cr_ds, h3cr_ds, h4cr_ds = (
    xr.open_mfdataset([path for path in files if 'remap' in path], parallel=parallel)
    for files in (h1c_files, h2c_files, h3c_files, h4c_files)
)

In [None]:
# MPAS (child) mesh: coordinates converted from radians to degrees, longitudes wrapped to [-180, 180).
lonCell = np.mod(np.rad2deg(c_mesh_ds['lonCell'].values) - 180.0, 360.0) - 180.0
latCell = np.rad2deg(c_mesh_ds['latCell'].values)
lonVertex = np.mod(np.rad2deg(c_mesh_ds['lonVertex'].values) - 180.0, 360.0) - 180.0
latVertex = np.rad2deg(c_mesh_ds['latVertex'].values)

# One triangle per vertex, connecting the cells in cellsOnVertex (1-based -> 0-based indices).
tris = c_mesh_ds.cellsOnVertex.values - 1
# A 0 entry (presumably a missing neighbour, e.g. on a regional-mesh boundary -- TODO confirm)
# becomes -1 here and would silently index the last cell, so drop any triangle containing one.
tris = tris[(tris >= 0).all(axis=1)]
# orderCCW reorders in place and returns the same array; then drop triangles that wrap across
# the longitude seam (see unzipMesh).
tris_ccw = orderCCW(lonCell, latCell, tris)
tris_ccw_flat = unzipMesh(lonCell, tris_ccw, 90.0)


sel_time = 9
primalVar = h3cn_ds['FLUT'].isel(time=sel_time).values
# Storm bounding box in degrees. NOTE(review): computed at time index 23 rather than sel_time
# (the animation cells also use 23) -- confirm this is intended.
lon_range, lat_range = find_TC_bbox(h3cn_ds, 'florida', 23, center_dist=1000)

proj = ccrs.Robinson(central_longitude=0.0)
xPCS, yPCS, _ = proj.transform_points(ccrs.PlateCarree(), lonCell, latCell).T
trimesh = createHVTriMesh(xPCS, yPCS, tris_ccw_flat, primalVar, 'FLUT', n_workers=1)

# Project the bbox into Robinson coordinates. Robinson x depends on latitude as well as
# longitude, so project all four corners and take the extremes instead of pairing
# (lon0, lat0) with (lon1, lat1). Corners suffice unless the box straddles the equator.
corner_lons = np.array([lon_range[0], lon_range[1], lon_range[0], lon_range[1]])
corner_lats = np.array([lat_range[0], lat_range[0], lat_range[1], lat_range[1]])
x_range, y_range, _ = proj.transform_points(ccrs.PlateCarree(), corner_lons, corner_lats).T
# From here on lon_range / lat_range hold projected x / y extents, not degrees.
lon_range = (x_range.min(), x_range.max())
lat_range = (y_range.min(), y_range.max())

rasterized = hds_rasterize(trimesh, aggregator='mean', precompute=True, vdim_prefix='', pixel_ratio=5, x_range=lon_range, y_range=lat_range)

final = rasterized.opts(tools=['hover'], colorbar=True, cmap=flut_cimss, title='Upwelling Longwave Flux MPAS (Native Grid)', 
                        clim=(T_to_FLUT(-110, 'C'), T_to_FLUT(55, 'C')), width=800, clabel='Upwelling Longwave Flux [W/m^2]', 
                        fontsize=dict(title='16pt', cticks='10pt')) * gf.coastline(projection=proj).options(scale='10m')

final

In [None]:
# Frame times: every 3 hours from 2008-08-30 to 2008-09-04 inclusive (41 frames).
# NOTE(review): labels are generated here rather than read from h3cn_ds['time'] -- confirm they match.
times = pd.date_range('2008-08-30', '2008-09-04', freq='3H')


# One fixed storm bounding box (time index 23) so every frame shares the same view.
lon_range, lat_range = find_TC_bbox(h3cn_ds, 'florida', 23, center_dist=1000)

proj = ccrs.Robinson(central_longitude=0.0)
# NOTE(review): relies on lonCell/latCell and tris_ccw_flat still holding the MPAS mesh from the
# mesh cell above; re-run that cell first if these names have been reassigned since.
xPCS, yPCS, _ = proj.transform_points(ccrs.PlateCarree(), lonCell, latCell).T

# Robinson x depends on latitude too, so project all four bbox corners and take the extremes.
corner_lons = np.array([lon_range[0], lon_range[1], lon_range[0], lon_range[1]])
corner_lats = np.array([lat_range[0], lat_range[0], lat_range[1], lat_range[1]])
x_range, y_range, _ = proj.transform_points(ccrs.PlateCarree(), corner_lons, corner_lats).T
# lon_range / lat_range now hold projected x / y extents, not degrees.
lon_range = (x_range.min(), x_range.max())
lat_range = (y_range.min(), y_range.max())

# Loop-invariant pieces: output folder (created if missing), colour limits, coastline overlay.
out_dir = '../figs/flut_gif/MPAS_880'
os.makedirs(out_dir, exist_ok=True)
flut_clim = (T_to_FLUT(-110, 'C'), T_to_FLUT(55, 'C'))
coastline = gf.coastline(projection=proj).options(scale='10m')

for i, frame_time in enumerate(times):
    title = f'Upwelling Longwave Flux MPAS (Native Grid) {frame_time}'
    primalVar = h3cn_ds['FLUT'].isel(time=i).values
    trimesh = createHVTriMesh(xPCS, yPCS, tris_ccw_flat, primalVar, 'FLUT', n_workers=1)
    rasterized = hds_rasterize(trimesh, aggregator='mean', precompute=True, vdim_prefix='',
                               pixel_ratio=5, x_range=lon_range, y_range=lat_range)
    final = rasterized.opts(tools=['hover'], colorbar=True, cmap=flut_cimss, title=title,
                            clim=flut_clim, width=800, clabel='Upwelling Longwave Flux [W/m^2]',
                            fontsize=dict(title='16pt', cticks='10pt')) * coastline
    hv.save(final, os.path.join(out_dir, f'{title}.png'), backend='bokeh', dpi=300)

In [None]:
# Scratch cell: show HoloViews' help for hv.save (holoviews is also imported in the first cell).
import holoviews as hv
hv.help(hv.save)

In [None]:
?Image.open

In [None]:
# Assemble the saved MPAS frames into an animated GIF (os, glob and Image come from the first cell).
output_folder = r"../figs/flut_gif/MPAS_880"

# Frames are the PNGs whose names end in "00.png" (the timestamped titles saved above end in ":00").
imgs = sorted(glob.glob(os.path.join(output_folder, "*00.png")))
if not imgs:
    raise FileNotFoundError(f"No frames matching '*00.png' found in {output_folder!r}")

frames = []
for path in imgs:
    frame = Image.open(path)
    # Read the pixel data now; for single-frame files Pillow then closes the underlying file,
    # so no explicit close() is needed and the frames stay displayable afterwards.
    frame.load()
    frames.append(frame)

# Save into a GIF file that loops forever (loop=0); duration is the time per frame in ms.
duration = 200
frames[0].save(os.path.join(output_folder, f'mpas_FLUT_{duration}.gif'), format='GIF',
               append_images=frames[1:],
               save_all=True,
               duration=duration, loop=0)

In [None]:
# Preview the first frame loaded by the GIF-building cell above.
frames[0]

In [None]:
?frames[0].save

In [None]:
def triangulate(vertices, x="Longitude", y="Latitude"):
    """
    Delaunay-triangulate the points given by columns ``x`` and ``y`` of ``vertices``.

    Prints the number of vertices and triangles; for large inputs there are typically about
    twice as many triangles as vertices.

    Returns a DataFrame with one row per triangle and columns 'v0', 'v1', 'v2' holding
    positional indices into ``vertices``.
    """
    points = vertices[[x, y]].values
    simplices = Delaunay(points).simplices
    print('Given', len(vertices), "vertices, created", len(simplices), 'triangles.')
    return pd.DataFrame(simplices, columns=['v0', 'v1', 'v2'])

def createHVTriMesh(x, y, var, var_name, n_workers=1):
    """Build a HoloViews ``TriMesh`` whose triangles come from a Delaunay triangulation of
    the (x, y) points (via ``triangulate``).

    NOTE(review): this REDEFINES the earlier ``createHVTriMesh(x, y, triangle_indices, var,
    var_name, n_workers)``. After this cell runs, the MPAS cells that pass triangle indices
    will raise TypeError on re-run; consider giving one of the two a different name.

    Parameters
    ----------
    x, y : array-like
        Vertex coordinates, stored in the 'Longitude' / 'Latitude' columns.
    var : array-like
        Value at each vertex.
    var_name : str
        Column / value-dimension name used for ``var``.
    n_workers : int
        Number of dask partitions for both the vertex and the triangle tables.

    Returns
    -------
    hv.TriMesh
    """
    # One row per vertex: x, y and the value (assumes 1-D inputs of equal length).
    vertex_table = pd.DataFrame(np.column_stack([x, y, var]),
                                columns=['Longitude', 'Latitude', var_name])
    triangle_table = triangulate(vertex_table)

    vertex_ddf, triangle_ddf = (dd.from_pandas(table, npartitions=n_workers)
                                for table in (vertex_table, triangle_table))

    # Nodes key dimensions: x, y and the node index the triangles refer to.
    nodes = hv.Nodes(vertex_ddf, ['Longitude', 'Latitude', 'index'], [var_name])
    return hv.TriMesh((triangle_ddf, nodes))



In [None]:
# Frame times: every 3 hours from 2008-08-30 to 2008-09-04 inclusive (41 frames).
times = pd.date_range('2008-08-30', '2008-09-04', freq='3H')

# CAM (parent) grid-cell centres from the SCRIP file, longitudes wrapped into [-180, 180).
# Kept under CAM-specific names so the MPAS lonCell/latCell used by the MPAS cells are not
# overwritten.
lonCell_cam = ((p_mesh_ds.grid_center_lon.values - 180.0) % 360.0) - 180.0
latCell_cam = p_mesh_ds.grid_center_lat.values

# Same storm bounding box as the MPAS frames (from the MPAS output at time index 23), so the
# two animations share a view.
lon_range, lat_range = find_TC_bbox(h3cn_ds, 'florida', 23, center_dist=1000)
proj = ccrs.Robinson(central_longitude=0.0)
xPCS, yPCS, _ = proj.transform_points(ccrs.PlateCarree(), lonCell_cam, latCell_cam).T

# Robinson x depends on latitude too, so project all four bbox corners and take the extremes.
corner_lons = np.array([lon_range[0], lon_range[1], lon_range[0], lon_range[1]])
corner_lats = np.array([lat_range[0], lat_range[0], lat_range[1], lat_range[1]])
x_range, y_range, _ = proj.transform_points(ccrs.PlateCarree(), corner_lons, corner_lats).T
# lon_range / lat_range now hold projected x / y extents, not degrees.
lon_range = (x_range.min(), x_range.max())
lat_range = (y_range.min(), y_range.max())

# Loop-invariant pieces: output folder (created if missing), colour limits, coastline overlay.
out_dir = '../figs/flut_gif/CAM_880'
os.makedirs(out_dir, exist_ok=True)
flut_clim = (T_to_FLUT(-110, 'C'), T_to_FLUT(55, 'C'))
coastline = gf.coastline(projection=proj).options(scale='10m')

for i, frame_time in enumerate(times):
    title = f'Upwelling Longwave Flux CAM (Native Grid) {frame_time}'
    primalVar = h3pn_ds['FLUT'].isel(time=i).values
    # The Delaunay variant of createHVTriMesh re-triangulates on every frame even though the
    # points never change. TODO: triangulate once and reuse the connectivity.
    trimesh = createHVTriMesh(xPCS, yPCS, primalVar, 'FLUT')
    rasterized = hds_rasterize(trimesh, aggregator='mean', precompute=True,
                               x_range=lon_range, y_range=lat_range, vdim_prefix='')
    final = rasterized.opts(tools=['hover'], colorbar=True, cmap=flut_cimss, title=title,
                            clim=flut_clim, width=800, clabel='Upwelling Longwave Flux [W/m^2]',
                            fontsize=dict(title='16pt', cticks='10pt')) * coastline

    hv.save(final, os.path.join(out_dir, f'{title}.png'), backend='bokeh', dpi=300)

In [None]:
# Assemble the storm_0236 OLR frames into an animated GIF (os, glob and Image come from the first cell).
output_folder = r"../figs/new_tracks/storm_0236"

imgs = sorted(glob.glob(os.path.join(output_folder, "OLR*.png")))
# The first frame (in sorted order) is skipped. NOTE(review): reason not recorded -- confirm.
imgs = imgs[1:]
if not imgs:
    raise FileNotFoundError(f"No frames matching 'OLR*.png' (after skipping the first) in {output_folder!r}")

# Create the frames
frames = []
for path in imgs:
    frame = Image.open(path)
    # Read the pixel data now; for single-frame files Pillow then closes the underlying file.
    frame.load()
    frames.append(frame)

# Save into a GIF file that loops forever (loop=0); duration is the time per frame in ms.
duration = 300
frames[0].save(os.path.join(output_folder, f'FLUT_{duration}.gif'), format='GIF',
               append_images=frames[1:],
               save_all=True,
               duration=duration, loop=0)