In [417]:
import xarray as xr
import netCDF4 as nc
import plotly
import chart_studio.plotly as py
import plotly.offline as py_off
#from plotly.graph_objs import *
import numpy as np
from scipy.io import netcdf
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default = 'iframe'
import cartopy.feature as cf
#from mpl_toolkits.basemap import Basemap

In [401]:
file = '/data/bnb32/gfdl_data/held_suarez_default/run0001/atmos_monthly.nc'

In [442]:
def plot_field(file,field,level=None):

    ds = xr.open_dataset(xr.backends.NetCDF4DataStore(nc.Dataset(file)))
    print(ds)
    df = ds[['lon','lat',f'{field}']].to_dataframe()
    
    if level is not None:
        levels = df.index.get_level_values('pfull')
        level = levels[np.argmin(np.abs(levels-level))]
        df = df[df.index.get_level_values('pfull') == level]
    
    step = 1.0
    to_bin = lambda x: np.floor(x / step) * step
    df["lat"] = df.index.get_level_values('lat').map(to_bin)
    df["lon"] = df.index.get_level_values('lon').map(to_bin)
    
    df.reset_index(drop=True, inplace=True)
    
    groups = df.groupby(["lat", "lon"])
    df_flat = df.drop_duplicates(subset=['lat', 'lon'])
    df = df_flat[np.isfinite(df_flat[f'{field}'])]
    df = df[(df.lat <= 90.0) & 
            (df.lat >= -90.0) & 
            (df.lon <= 360.0) & 
            (df.lon >= 0)]
    
    df['lon'] = df['lon'].apply(lambda x: x if x <= 180.0 else x-360.0)
    
    x_coords = []
    y_coords = []
    traces = []
    for coord_seq in cf.COASTLINE.geometries():
        x_coords.extend([k[0] for k in coord_seq.coords] + [np.nan])
        y_coords.extend([k[1] for k in coord_seq.coords] + [np.nan])  
    
    ## in your app callback for dash
    trace = go.Scatter(x = x_coords,
                       y = y_coords,
                       mode = 'lines',
                       line=go.Line(color="black"))
    traces.append(trace)
    
    contours = go.Contour(z=df[f'{field}'],
                          x=df['lon'],
                          y=df['lat'],
                          colorscale="RdBu",
                          zauto=False,  # custom contour levels
                          zmin=min(df[f'{field}']),      # first contour level
                          zmax=max(df[f'{field}'])        # last contour level  => colorscale is centered about 0
                          )
    data = go.Data([contours]+traces)
    
    layout = Layout(
    paper_bgcolor='rgba(0,0,0,0)',
    #plot_bgcolor='rgba(0,0,0,0)'
    )

    fig = Figure(data=data,layout=layout)
    fig.show()
    

In [444]:
plot_field(file,'div',250)

<xarray.Dataset>
Dimensions:      (lon: 128, lonb: 129, lat: 64, latb: 65, time: 1, nv: 2, phalf: 26, pfull: 25)
Coordinates:
  * lon          (lon) float64 0.0 2.812 5.625 8.438 ... 348.8 351.6 354.4 357.2
  * lonb         (lonb) float64 -1.406 1.406 4.219 7.031 ... 353.0 355.8 358.6
  * lat          (lat) float64 -87.86 -85.1 -82.31 -79.53 ... 82.31 85.1 87.86
  * latb         (latb) float64 -90.0 -86.58 -83.76 -80.96 ... 83.76 86.58 90.0
  * time         (time) object 2000-01-16 00:00:00
  * nv           (nv) float64 1.0 2.0
  * phalf        (phalf) float64 0.0 6.165 12.71 22.59 ... 786.6 886.9 1e+03
  * pfull        (pfull) float64 2.268 9.244 17.42 28.92 ... 741.7 836.3 942.9
Data variables:
    ps           (time, lat, lon) float32 ...
    bk           (phalf) float32 ...
    pk           (phalf) float32 ...
    ucomp        (time, pfull, lat, lon) float32 ...
    vcomp        (time, pfull, lat, lon) float32 ...
    temp         (time, pfull, lat, lon) float32 ...
    vor        