In [None]:
import xarray as xr
import netCDF4 as nc
import plotly
import chart_studio.plotly as py
import plotly.offline as py_off
import numpy as np
from scipy.io import netcdf
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default = 'iframe'
import cartopy.feature as cf
import dash
import dash_core_components
import matplotlib
from __future__ import print_function
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib.animation as anim
import cartopy
from IPython.display import HTML
from PIL import Image

In [None]:
def get_plotly_data(file,field,level=None,step=1.0):
    """Load one field from a NetCDF file as a flat, coarsened lat/lon table.

    The field is converted to a DataFrame, optionally restricted to the
    ``pfull`` value nearest to ``level``, snapped onto a regular grid of
    ``step`` degrees, and reduced to the first record found in each grid
    cell. Cells whose first record is not finite are dropped. Longitudes
    above 180 are shifted by -360 so the result uses the -180..180
    convention.

    Parameters
    ----------
    file : str or path-like
        Path handed to ``netCDF4.Dataset``.
    field : str
        Name of the variable to extract.
    level : float, optional
        Target value on the ``pfull`` coordinate (presumably full-level
        pressure -- confirm against the model output). The closest
        available value is used. ``None`` keeps every level.
    step : float, optional
        Size of the lat/lon grid cells in degrees. Defaults to 1.0.

    Returns
    -------
    (pandas.DataFrame, xarray.Dataset)
        The table with columns ``field``, ``lat`` and ``lon``, and the
        opened dataset (kept open so callers can read variable attributes).

    Raises
    ------
    ValueError
        If ``step`` is not positive.
    KeyError
        If ``level`` is given but the data has no ``pfull`` index level.
    """
    if step <= 0:
        raise ValueError("step must be a positive number of degrees")

    column = f'{field}'

    # NOTE: the underlying netCDF handle is intentionally left open because
    # the dataset is returned to the caller, who still reads from it.
    ds = xr.open_dataset(xr.backends.NetCDF4DataStore(nc.Dataset(file)))
    df = ds[['lon','lat',column]].to_dataframe()

    if level is not None:
        if 'pfull' not in df.index.names:
            raise KeyError("level was requested but the data has no 'pfull' index level")
        pfull = df.index.get_level_values('pfull')
        # Search the distinct levels only; the full index repeats each level
        # once per grid point.
        available = np.asarray(pfull.unique())
        nearest = available[np.argmin(np.abs(available - level))]
        df = df[pfull == nearest]

    # Snap coordinates to the lower edge of their grid cell in one vectorised
    # step instead of mapping a lambda over every index entry.
    lat_bins = np.floor(np.asarray(df.index.get_level_values('lat'), dtype=float) / step) * step
    lon_bins = np.floor(np.asarray(df.index.get_level_values('lon'), dtype=float) / step) * step

    # assign() builds a new frame, so nothing is written into a filtered view.
    binned = df.reset_index(drop=True).assign(lat=lat_bins, lon=lon_bins)

    # Keep the first record of each grid cell, then discard unusable cells.
    first_per_cell = binned.drop_duplicates(subset=['lat', 'lon'])
    usable = (np.isfinite(first_per_cell[column])
              & first_per_cell['lat'].between(-90.0, 90.0)
              & first_per_cell['lon'].between(0.0, 360.0))
    result = first_per_cell[usable].copy()

    # Convert 0..360 longitudes to the -180..180 convention.
    result['lon'] = result['lon'].where(result['lon'] <= 180.0, result['lon'] - 360.0)
    return result,ds
    

In [None]:
def get_fig_data(file,field,level=None):
    """Build the plotly traces and layout for a contour map of one field.

    The gridded values come from ``get_plotly_data``; the coastline outline
    from cartopy's ``COASTLINE`` feature is drawn on top as a single black
    line trace.

    Parameters
    ----------
    file : str or path-like
        NetCDF file to read.
    field : str
        Name of the variable to plot.
    level : float, optional
        Passed through to ``get_plotly_data`` to select a vertical level.

    Returns
    -------
    (list, plotly.graph_objects.Layout)
        ``[contour_trace, coastline_trace]`` and the figure layout, ready to
        be passed to ``go.Figure(data=..., layout=...)``.

    Raises
    ------
    ValueError
        If no plottable data points remain for the requested field.
    """
    df,ds = get_plotly_data(file,field,level=level)
    column = f'{field}'
    if df.empty:
        raise ValueError(f"no finite data points to plot for field '{column}'")

    # Flatten every coastline segment into one x/y list; the NaN after each
    # segment is a gap marker so separate segments are not joined together.
    x_coords = []
    y_coords = []
    for geometry in cf.COASTLINE.geometries():
        # Multi-part geometries expose their pieces through .geoms and have
        # no .coords of their own.
        for part in getattr(geometry, 'geoms', [geometry]):
            points = list(part.coords)
            x_coords.extend([pt[0] for pt in points] + [np.nan])
            y_coords.extend([pt[1] for pt in points] + [np.nan])

    # Plain dicts/lists are used instead of the legacy go.Line / go.Data /
    # go.XAxis / go.YAxis / go.Annotation(s) aliases.
    coastline = go.Scatter(x = x_coords,
                           y = y_coords,
                           mode = 'lines',
                           line=dict(color="black"))

    values = df[column]
    contours = go.Contour(z=values,
                          x=df['lon'],
                          y=df['lat'],
                          colorscale="RdBu",
                          zauto=False,          # explicit colour range
                          zmin=values.min(),    # lower end of the colour range
                          zmax=values.max()     # upper end; the range spans the data, it is not centred on 0
                          )
    data = [contours, coastline]

    # Fall back gracefully when the variable carries no descriptive metadata.
    attrs = ds[column].attrs
    long_name = attrs.get('long_name', column)
    units = attrs.get('units')
    if units:
        title = u"{0} field ({1}) <br>".format(long_name,units)
    else:
        title = u"{0} field <br>".format(long_name)

    anno_text = "Data from ISCA simulation"

    axis_style = dict(zeroline=False,
                      showline=False,
                      showgrid=False,
                      ticks='',
                      showticklabels=False)

    layout = go.Layout(title=title,
                       showlegend=False,
                       hovermode="closest",        # highlight closest point on hover
                       xaxis=dict(axis_style,
                                  range=[df['lon'].min(),
                                         df['lon'].max()]),  # restrict x-axis to range of lon
                       yaxis=dict(axis_style,
                                  range=[df['lat'].min(),
                                         df['lat'].max()]),  # restrict y-axis to range of lat
                       annotations=[dict(text=anno_text,
                                         xref='paper',
                                         yref='paper',
                                         x=0,y=1,
                                         yanchor='bottom',
                                         showarrow=False)],
                       autosize=False,
                       width=1000,
                       height=600)

    return data,layout
    

In [None]:
def plot_field(file,field,level=None):
    """Render a contour map of ``field`` from ``file`` in the active renderer.

    Parameters
    ----------
    file : str or path-like
        NetCDF file to read.
    field : str
        Name of the variable to plot.
    level : float, optional
        Vertical level forwarded to ``get_fig_data``.

    Returns
    -------
    None
        The figure is displayed as a side effect.
    """
    traces,page_layout = get_fig_data(file,field,level=level)
    go.Figure(data=traces,layout=page_layout).show()

In [1]:
def get_animation(files,field,level=None):
    """Show an animated contour map with one frame per input file.

    Each file is turned into a set of traces by ``get_fig_data``. The axis
    limits are taken from the first file and kept fixed for the whole
    animation. A "Play" button steps through the frames.

    Parameters
    ----------
    files : iterable of str or path-like
        NetCDF files, in the order the frames should be shown.
    field : str
        Name of the variable to plot.
    level : float, optional
        Vertical level forwarded to the data loaders.

    Returns
    -------
    None
        The figure is displayed as a side effect.

    Raises
    ------
    ValueError
        If ``files`` is empty.
    """
    files = list(files)
    if not files:
        raise ValueError("get_animation needs at least one file")

    # Load the first file once for the axis limits (previously it was read
    # separately for every min/max, through a helper name that is not
    # defined in this notebook).
    first_df = get_plotly_data(files[0],field,level=level)[0]

    axis_style = dict(zeroline=False,
                      showline=False,
                      showgrid=False,
                      ticks='',
                      showticklabels=False)

    play_button = dict(label="Play",
                       method="animate",
                       args=[None, {"frame": {"duration": 1000,
                                              "redraw": True},
                                    "fromcurrent": True,
                                    "transition": {"duration": 0,
                                                   "easing": "quadratic-in-out"}}])

    # Plain dicts/lists replace the legacy go.XAxis / go.YAxis /
    # go.Annotation(s) aliases.
    layout = go.Layout(title='',
                       showlegend=False,
                       hovermode="closest",        # highlight closest point on hover
                       xaxis=dict(axis_style,
                                  range=[first_df['lon'].min(),
                                         first_df['lon'].max()]),  # restrict x-axis to range of lon
                       yaxis=dict(axis_style,
                                  range=[first_df['lat'].min(),
                                         first_df['lat'].max()]),  # restrict y-axis to range of lat
                       annotations=[dict(text='',
                                         xref='paper',
                                         yref='paper',
                                         x=0,y=1,
                                         yanchor='bottom',
                                         showarrow=False)],
                       autosize=False,
                       width=1000,
                       height=600,
                       updatemenus=[dict(type="buttons",
                                         buttons=[play_button])])

    # Build the traces for every file once; the first set doubles as the
    # figure's initial state.
    frame_traces = [get_fig_data(f,field,level=level)[0] for f in files]
    frames = [go.Frame(data=traces) for traces in frame_traces]

    fig = go.Figure(
        data=frame_traces[0],
        layout=layout,
        frames=frames
    )

    fig.show()