# xhycom for HYCOM OPeNDAP and plotting demonstration
**Author: Jun Sasaki, Coded on July 23, 2020, Updated on August 15, 2020**<br>
- See xhycom/utils.py
- If you encounter `urlopen error [Errno -2] Name or service not known`, reduce the tile `zoom` level.
- To avoid a decoding problem in the time coordinate, the variable `tau` is deleted. If `tau` is needed, it must be decoded manually.

In [None]:
### Notebook setup: imports, warning filter, and inline plotting backend.
from xhycom import utils as xh
import numpy as np
import pandas as pd
from datetime import datetime, timedelta, timezone
import xarray as xr
from netCDF4 import num2date, date2num
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.io.img_tiles import Stamen
from cartopy.io.img_tiles import OSM
from cartopy.mpl.ticker import LatitudeFormatter,LongitudeFormatter
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import hvplot.xarray  # registers the .hvplot accessor used on xarray objects below
import warnings
import os
import sys
# NOTE(review): the next line rebinds the name `datetime` to the module, shadowing the
# `datetime` class imported from the datetime module above. After this cell, `datetime`
# refers to the module (use `datetime.datetime` for the class).
import datetime
from matplotlib.animation import FuncAnimation
# This is needed to display graphics calculated outside of jupyter notebook
from IPython.display import HTML, display
# NOTE(review): this silences every warning, including deprecation warnings from libraries.
warnings.filterwarnings('ignore')
%matplotlib inline
#%matplotlib widget  ### fig.savefig() does not work

# HYCOM data download
## Downloading one time and creating netcdf
- Set `opendap=False` if the data have already been downloaded.
- In general, use `xh.run_opendap()` (see below).

In [None]:
### Download one time step of HYCOM data into `ncfile` (if requested) and load it as `ds`.
opendap = True  ### True: Download HYCOM data using OPeNDAP, False: Use existing netcdf
extent = (139.2, 140.2, 34.8, 35.8)  ### (lon_min, lon_max, lat_min, lat_max)
time = '2012-10-28 18:00:00'
ncfile = "hycom_tokyo_bay.nc"

if opendap:
    if os.path.exists(ncfile):
        ### Never overwrite an existing download; the existing file is loaded below instead.
        print('ERROR: Netcdf file already exists. Delete it before downloading.')
    else:
        ds = xh.run_hycom_gofs3_1_region_ymdh(extent=extent, time=time)
        ds.to_netcdf(ncfile, mode="w")
if os.path.exists(ncfile):
    ds = xr.open_dataset(ncfile)
    ### Jupyter only renders the last top-level expression of a cell, so a bare `ds`
    ### inside this `if` would show nothing; display it explicitly.
    display(ds)
else:
    print('ERROR: Netcdf file does not exist.')

In [None]:
### Inspect the metadata attributes of the time coordinate.
ds['time'].attrs

## Same as above but using xh.run_opendap()
- This cell has no `opendap` switch; skip it if the data have already been downloaded.
- `xh.run_opendap(extent, time_start, time_end=None, dtime=3, tz='utc')`
- Netcdf file name is automatically generated as "hycom_" + datetime + ".nc" 
- This method should be used in general.

In [None]:
### Download one time step with the recommended helper; it names the file "hycom_" + datetime + ".nc".
extent = (139.2, 140.2, 34.8, 35.8)  ### (lon_min, lon_max, lat_min, lat_max)
time = '2012-10-29 00:00:00'

xh.run_opendap(extent=extent, time_start=time)

## Downloading and creating multi-time netcdf
- `xh.run_opendap(extent, time_start, time_end=None, dtime=3, tz='utc')`
- Netcdf file name is automatically generated as "hycom_" + datetime + ".nc" 
- This method should be used in general.

In [None]:
### Download the time steps from time_start to time_end at dtime-hour intervals.
extent = (139.2, 140.2, 34.8, 35.8)  ### (lon_min, lon_max, lat_min, lat_max)
time_start, time_end, dtime = ('2012-10-28 12:00:00', '2012-10-29 00:00:00', 3)  ### dtime (int): time interval in hours

xh.run_opendap(extent=extent, time_start=time_start, time_end=time_end, dtime=dtime)

# Plotting HYCOM netcdf with [matplotlib](https://matplotlib.org/#)
- NetCDF is loaded as xarray.Dataset.
- matplotlib is used for static plotting, which is suitable for publication.

## Loading netcdf into xarray.Dataset

In [None]:
ncfile = "hycom_2012-10-29_00.nc"
### Read the whole (small, single-time) file into memory and release the file handle.
### Later cells keep using `ds` after this cell, so `ds` must not depend on a file
### that the `with` statement has already closed.
with xr.open_dataset(ncfile) as nc:
    ds = nc.load()
print(ds)
#ds.salinity
#ds["salinity"]
#ds[["salinity"]]

In [None]:
### Examples of indexing and slicing
#ds["salinity"]
#ds.salinity
#ds.salinity.isel(depth=0, time=0)
#ds.salinity.isel(depth=slice(0,3), time=[0])
### Label-based selection: the depth levels nearest to 0.1 and 15.0
ds.salinity.sel(depth=[0.1, 15.0], method="nearest")

## Simple plotting
Very simple, but the coastline is coarse, the x- and y-axis labels are missing, and the colorbar label is too long by default.

In [None]:
### Start setting by users
depth = 0 # depth index in int
extent = (139.2, 140.2, 34.8, 35.8)  ### (lon_min, lon_max, lat_min, lat_max); only used for the projection center here
figsize = (6,6)
png = 'hycom_tokyo_bay1.png'
## Default font size
plt.rcParams['font.size'] = 12
### End setting by users

### Center the map projection on the middle longitude of the extent
central_longitude = np.mean(extent[0:2])
### NOTE: central_latitude is not defined in this cell; define it before enabling the line below
#proj = ccrs.LambertConformal(central_longitude=central_longitude, central_latitude=central_latitude)
proj = ccrs.PlateCarree(central_longitude=central_longitude)
fig = plt.figure(figsize=figsize)
ax = plt.axes(projection=proj)
### The data are given on lon/lat, so the transform is an unshifted PlateCarree
ds.salinity.isel(depth=depth).plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree())
ax.coastlines()
fig.savefig(png, dpi=600, bbox_inches='tight')

## Simple plotting with minimum axes adjustment

In [None]:
### Start setting by users
depth = 0 # depth index in int
extent = (139.2, 140.2, 34.8, 35.8)
figsize = (6,5)
png = 'hycom_tokyo_bay2.png'
## Default font size
plt.rcParams['font.size'] = 12
## Adjust surrounding margins (offsets in degrees added to the extent)
lon_min, lon_max = extent[0]-0.1, extent[1]+0.15
lat_min, lat_max = extent[2]-0.1, extent[3]+0.1
### End setting by users

central_longitude = np.mean(extent[0:2])
### NOTE: central_latitude is not defined in this cell; define it before enabling the line below
#proj = ccrs.LambertConformal(central_longitude=central_longitude, central_latitude=central_latitude)
proj = ccrs.PlateCarree(central_longitude=central_longitude)
fig = plt.figure(figsize=figsize)
ax = plt.axes(projection=proj)
ds.salinity.isel(depth=depth).plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree())
### Zoom to the extent plus margins, given in lon/lat
ax.set_extent([lon_min,lon_max,lat_min,lat_max], crs=ccrs.PlateCarree())
ax.coastlines()
### Labelled gridlines, with labels on the left and bottom edges only
gl=ax.gridlines(draw_labels=True, linewidth=0.5, color='k', alpha=0.8)
gl.right_labels=False
gl.top_labels=False
fig.savefig(png, dpi=600, bbox_inches='tight')

## Simple plotting with manual axes adjustment

In [None]:
### Start settings by users
extent = (139.2, 140.2, 34.8, 35.8)
depth = 0 # depth index in int
figsize = (6,5)
png = 'hycom_tokyo_bay3.png'
## Default font size
plt.rcParams['font.size'] = 12
## Adjust surrounding margins (offsets in degrees added to the extent)
lon_min, lon_max = extent[0]-0.1, extent[1]+0.15
lat_min, lat_max = extent[2]-0.1, extent[3]+0.1
## Ticks intervals for lon and lat axes
dlon, dlat = (0.4, 0.2)
### End settings by users

central_longitude = np.mean(extent[0:2])
### Tick positions start at the lower-left corner of the padded extent
xticks = np.arange(lon_min, lon_max, dlon)
yticks = np.arange(lat_min, lat_max, dlat)
proj = ccrs.PlateCarree(central_longitude=central_longitude)
fig = plt.figure(figsize=figsize)
ax = plt.axes(projection=proj)
ds.salinity.isel(depth=depth).plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree())
ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
ax.coastlines()
### Unlabelled gridlines at the tick positions; labels come from the axis ticks below
gl=ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1, color='k', alpha=0.8)
gl.xlocator = mticker.FixedLocator(xticks)
gl.ylocator = mticker.FixedLocator(yticks)
ax.set_xticks(xticks,crs=ccrs.PlateCarree())
ax.set_yticks(yticks,crs=ccrs.PlateCarree())
### Format tick labels as longitudes/latitudes
latfmt=LatitudeFormatter()
lonfmt=LongitudeFormatter(zero_direction_label=True)
ax.xaxis.set_major_formatter(lonfmt)
ax.yaxis.set_major_formatter(latfmt)
ax.axes.tick_params(labelsize=12)
fig.savefig(png, dpi=600, bbox_inches='tight')

## Simple plotting with manual axes adjustment and background map (tiler)

In [None]:
### Start settings by users
# Background map
## NOTE(review): Stamen tiles may no longer be served by newer cartopy versions / upstream;
## if tile loading fails, try OSM() instead.
tiler = Stamen('terrain-background')
#tiler = OSM()
zoom=9  ### Tile zoom level; reduce it if tile download fails
extent = (139.2, 140.2, 34.8, 35.8)
figsize = (6,5)
depth = 0 # depth index in int
png = 'hycom_tokyo_bay4.png'
## Default font size
plt.rcParams['font.size'] = 14
## Override axes label size
axes_label_size = 12
## Adjust surrounding margins (offsets in degrees added to the extent)
lon_min, lon_max = extent[0]-0.1, extent[1]+0.15
lat_min, lat_max = extent[2]-0.1, extent[3]+0.1
## Ticks intervals
dlon, dlat = (0.4, 0.2)
### End settings by users

xticks = np.arange(lon_min, lon_max, dlon)
yticks = np.arange(lat_min, lat_max, dlat)
central_longitude = np.mean(extent[0:2])
proj = ccrs.PlateCarree(central_longitude=central_longitude)
fig = plt.figure(figsize=figsize)
ax = plt.axes(projection=proj)
### Plot the layer selected by the `depth` setting above (not a hard-coded index)
ds.salinity.isel(depth=depth).plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree())
ax.set_extent([lon_min,lon_max,lat_min,lat_max], crs=ccrs.PlateCarree())
### Labelled gridlines at the tick positions, labels on the left and bottom edges only
gl=ax.gridlines(draw_labels=True, xlocs=xticks, ylocs=yticks, linestyle=':', linewidth=1, color='k', alpha=0.8)
gl.right_labels=False
gl.top_labels=False
gl.xlabel_style={'size':axes_label_size}
gl.ylabel_style={'size':axes_label_size}
gl.xlocator = mticker.FixedLocator(xticks)
ax.add_image(tiler, zoom)
fig.savefig(png, dpi=600, bbox_inches='tight')

## Function ax_lonlat_axes for adjusting axes
`ax_lonlat_axes(ax, extent, grid_linestyle=':', grid_linewidth=0.5, grid_color='k', \
                grid_alpha=0.8, xticks=None, yticks=None, label_size=12, tiler=None, zoom=8)`<br>
### Args:
-    ax(Axis)                 : gets current Axis
-    extent (tuple)           : (lon_min, lon_max, lat_min, lat_max)
-    grid_linestyle(str)      : linestyle (default: ':')
-    grid_linewidth(float/int): linewidth (default: 0.5)
-    grid_color(str)          : color (default: 'k')
-    grid_alpha(float)        : opacity (default: 0.8)
-    xticks(list)             : list of xticks (default: None)
-    yticks(list)             : list of yticks (default: None)
-    label_size(int)          : label size in pt (default: 12)
-    tiler(cartopy.io.img_tiles): background map tiler, e.g. `Stamen('terrain-background')` or `OSM()`, or None (default: None)
-    zoom(int)                : zoom in tiler
### Returns
-    ax(Axis)

## Plotting using ax_lonlat_axes and background map

In [None]:
### Start settings by users
extent = (139.2, 140.2, 34.8, 35.8)
tiler = Stamen('terrain-background')  ### tiler = Stamen('terrain-background')/OSM()/None
zoom = 9
## Default font size
plt.rcParams['font.size'] = 14
## Override axes label size
axes_label_size = 12
grid_linewidth=1
figsize = (6,5)
depth = 0 # depth index in int
png = 'hycom_tokyo_bay5.png'
## Adjust surrounding margins (offsets in degrees added to the extent)
lon_min, lon_max = extent[0]-0.1, extent[1]+0.15
lat_min, lat_max = extent[2]-0.1, extent[3]+0.1
## Tick intervals
dlon, dlat = (0.4, 0.2)
### End settings by users

xticks = np.arange(lon_min, lon_max, dlon)
yticks = np.arange(lat_min, lat_max, dlat)
central_longitude, central_latitude = np.mean(extent[0:2]), np.mean(extent[2:4])

proj = ccrs.PlateCarree(central_longitude=central_longitude)
fig = plt.figure(figsize=figsize)
ax = plt.axes(projection=proj)

### Plot the layer selected by the `depth` setting above (not a hard-coded index)
ds.salinity.isel(depth=depth).plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree())

### Apply extent, grid, ticks, label size and background tiles in one call
ax = xh.ax_lonlat_axes(ax, extent=(lon_min, lon_max, lat_min, lat_max), \
               grid_linewidth=grid_linewidth, grid_color='k', \
               xticks=xticks, yticks=yticks, label_size=axes_label_size, tiler=tiler, zoom=zoom)
fig.savefig(png, dpi=600, bbox_inches='tight')

# Interactive plotting with [hvPlot](https://hvplot.holoviz.org/#)
- Useful for checking data values of `(x, y, z)` interactively.
- Slower than matplotlib; may not be suitable for publication.
- [Geoviews](https://geoviews.org/#) may also be useful.

[hvPlot geographic data](https://hvplot.holoviz.org/user_guide/Geographic_Data.html) and [hvPlot options](https://hvplot.holoviz.org/user_guide/Customization.html)<br>

In [None]:
### Geoviews
### Disabled example: the code below is wrapped in a string literal, so running this cell only
### evaluates the string. Remove the surrounding quotes to try it (requires geoviews).
'''
import warnings
warnings.filterwarnings('ignore')
import geoviews as gv
import geoviews.feature as gf
from geoviews import opts
gv.extension('bokeh')
gv.output(size=100)

gvds=gv.Dataset(ds.salinity.isel(time=0))

qmesh = gvds.to(gv.QuadMesh, ['lon', 'lat'], dynamic=True).redim.range(salinity=(34.0, 34.3))
qmesh.opts(colorbar=True, cmap='jet', projection=ccrs.PlateCarree(), width=400, height=300 ,tools=['hover']) \
* gf.coastline.options(scale='10m')
'''

## Plan view
`xarray.Dataset.hvplot.quadmesh(x, y, z, project, coastline, frame_height, cmap)`<br>
- [cmap](https://matplotlib.org/gallery/color/colormap_reference.html?highlight=cmap): matplotlib colormap (for details see [here](https://matplotlib.org/tutorials/colors/colormaps.html))
- `coastline = ['10m'/'50m'/'110m']`
- When `tiles` is specified (e.g., `tiles="OSM"`), axes labels are ignored -> maybe a bug.


In [None]:
### When tiles is not None (e.g., tiles="OSM"), axis labels are not shown (possibly a bug).
### Interactive lon-lat map of variable `z` with 10m-resolution coastlines.
z='salinity'
cmap='magma_r' ## '_r' suffix: reversed colormap
frame_height=300
project = ccrs.PlateCarree()
ds.hvplot.quadmesh(x='lon', y='lat', z=z, project=project, tiles=None, coastline='10m', \
                   frame_height=frame_height, cmap=cmap)

## Vertical sectional views
`xarray.Dataset.hvplot.quadmesh(x, y, z, flip_yaxis, frame_width, frame_height, cmap)`<br>
- `flip_yaxis=True`: reverse yaxis

In [None]:
### Vertical section: longitude vs depth; flip_yaxis=True reverses the y (depth) axis.
x = 'lon'
z = 'salinity'
cmap = 'magma_r'
frame_width = 300
frame_height = 200

ds.hvplot.quadmesh(x=x, y='depth', z=z, flip_yaxis=True, frame_width=frame_width, frame_height=frame_height, \
                   cmap=cmap)

In [None]:
### Same section against latitude; reuses z, cmap, frame_width and frame_height from the previous cell.
x = 'lat'  ### x = 'lat'|'lon'

ds.hvplot.quadmesh(x=x, y='depth', z=z, flip_yaxis=True, frame_width=frame_width, frame_height=frame_height, \
                   cmap=cmap)

# Interactive time series plotting with hvPlot
- Using dask for large data handling
- See [Reading and writing files - xarray](http://xarray.pydata.org/en/stable/io.html)

## Loading netcdf files
- Dataset time series

In [None]:
ncfiles = 'hycom_2012*.nc'
### Open all matching files lazily (dask) as one Dataset concatenated along "time".
### - combine="nested" is required when concat_dim is given; with the default
###   combine="by_coords", current xarray rejects concat_dim.
### - The files are concatenated in the order xarray returns them for the glob. For the
###   "hycom_" + datetime + ".nc" naming this is presumably chronological; check dsts.time.
### - No `with` block: later cells keep using `dsts`, so it must stay open.
dsts = xr.open_mfdataset(ncfiles, parallel=True, combine="nested", concat_dim="time", \
                         data_vars='minimal', coords='minimal', compat='override')
print(dsts)


## Plan view
`xarray.Dataset.hvplot.quadmesh(x, y, z, project, coastline, frame_height, cmap)`

In [None]:
### Interactive plan view of the multi-time dataset `dsts`.
z = 'salinity'
project = ccrs.PlateCarree()
frame_height = 300
cmap = 'magma_r'

dsts.hvplot.quadmesh(x='lon', y='lat', z=z, project=project, coastline='10m', \
                     frame_height=frame_height, cmap=cmap)

## Vertical sectional views

In [None]:
### Vertical section of `dsts`: longitude vs depth, with the depth axis reversed.
x = 'lon'
z = 'salinity'
frame_width  = 300
frame_height = 200
cmap = 'magma_r'

dsts.hvplot.quadmesh(x=x, y='depth', z=z, flip_yaxis=True, frame_height=frame_height, frame_width=frame_width, \
                     cmap=cmap)

In [None]:
### Same section against latitude; reuses z, cmap and frame sizes from the previous cell.
x = 'lat'
dsts.hvplot.quadmesh(x=x, y='depth', z=z, flip_yaxis=True, frame_height=frame_height, frame_width=frame_width, \
                     cmap=cmap)

# Create GIF and MP4 animation with matplotlib
- There is a known issue when plotting an array that contains NaN with the `set_array()` method, which is required for creating an animation: NaN values are shown in the lowest color of the colormap instead of being left blank. See [this page](https://stackoverflow.com/questions/58117358/matplotlib-image-plot-nan-values-shown-as-lowest-color-of-colormaps-instead-of) for details.
- To resolve this problem, pass a masked array instead of an ndarray. See [this page](http://xarray.pydata.org/en/stable/generated/xarray.DataArray.to_masked_array.html).

In [None]:
### Check the concatenated time coordinate (number and values of time steps).
dsts.time

In [None]:
### Start settings by users
## Set params for animation
frames = 3  ### Number of frames (must not exceed the number of time steps in dsts)
interval = 200  ### Delay between frames in milliseconds
filename = "anim.gif"
dpi = 200
## Select indexes for depth and time, and specification of map
depth=0
time=0  ### NOTE(review): not used below; the first panel always uses time index 0
extent = (139.2, 140.2, 34.8, 35.8)
tiler = Stamen('terrain-background')  ### tiler = Stamen('terrain-background')/OSM()/None 
zoom = 9
figsize = (6,5)

## Default font size
plt.rcParams['font.size'] = 11
## Override axes label size
axes_label_size = 11
grid_linewidth=1
#png = 'hycom_tokyo_bay5.png'
## Adjust surrounding margins (offsets in degrees added to the extent)
lon_min, lon_max = extent[0]-0.1, extent[1]+0.15
lat_min, lat_max = extent[2]-0.1, extent[3]+0.1
## Tick intervals
dlon, dlat = (0.4, 0.2)
### End settings by users

xticks = np.arange(lon_min, lon_max, dlon)
yticks = np.arange(lat_min, lat_max, dlat)
#central_longitude, central_latitude = np.mean(extent[0:2]), np.mean(extent[2:4])

### Unshifted PlateCarree: the same object is used as the map projection and as the data transform
#proj = ccrs.PlateCarree(central_longitude=central_longitude)
proj = ccrs.PlateCarree()

fig = plt.figure(figsize=figsize)
ax = plt.axes(projection=proj)

fig.subplots_adjust(left=0.1, bottom=0.05, right=0.95, top=0.95, wspace=None, hspace=None)

### Apply extent, grid, ticks, label size and background tiles before drawing the data
ax = xh.ax_lonlat_axes(ax, extent=(lon_min, lon_max, lat_min, lat_max), \
           grid_linewidth=grid_linewidth, grid_color='k', \
           xticks=xticks, yticks=yticks, label_size=axes_label_size, tiler=tiler, zoom=zoom)

### Create the first panel; `cax` is the plot artist whose data update() replaces
cax = dsts.salinity.isel(depth=depth, time=0).plot(ax=ax, x='lon', y='lat', transform=proj,\
                                                   vmin=34.0, vmax=34.2, levels=20, extend='both', \
                                                   cbar_kwargs={'shrink':0.85})

def update(i):
    '''
    Update panel at each time step i by replacing data and title.

    A masked array is passed to set_array() so that NaN cells are not drawn
    in the lowest colormap color (see the note above this cell).
    '''
    masked_array = dsts.salinity.isel(depth=depth, time=i).to_masked_array()
    cax.set_array(masked_array.flatten())
    ax.set_title("Depth = " + str(dsts['depth'].values[depth]) + "  Time = " \
                 + str(dsts.coords['time'].values[i])[0:16])

### Call update() for frames 0 .. frames-1 and save the result as a GIF with the Pillow writer
ani = FuncAnimation(fig, update, frames=frames, interval=interval, repeat=True)
ani.save(filename, writer="pillow", dpi=dpi)

In [None]:
### Display the animation on this Jupyter Lab notebook as interactive JavaScript/HTML
### (rendered from `ani` itself, not read from the saved GIF file)
HTML(ani.to_jshtml())

In [None]:
### Create MP4 animation and display on this Jupyter Lab notebook
### NOTE(review): no writer is given, so matplotlib's default movie writer (normally ffmpeg) must be installed
ani.save('./Python_Animation_04.mp4', dpi=100)
### Embed the saved file with an HTML5 <video> element
display(HTML("<video controls><source src='./Python_Animation_04.mp4' type='video/mp4'></video>"))

# Classes for plot configuration, data management, and plotting
The classes below help manage the plot configuration (`PlotConfig`) and the data (`Data`). The class `Plotter` is useful for producing many figures with a uniform configuration. Complicated, sophisticated figures still require manual customization without `Plotter`, although `PlotConfig` and `Data` may remain useful.

In [None]:
class PlotConfig:
    '''
    Plotting configuration for 2D map
    '''

    ## Sentinel meaning "use the default background tiler". The default Stamen tiler is
    ## created per instance in __init__ rather than once when the class is defined.
    _DEFAULT_TILER = object()

    def __init__(self, figsize=(6,5), tiler=_DEFAULT_TILER, zoom=9, cmap='magma_r', \
                 title_size=12, label_size=11, plot_coast='10m',\
                 proj='plate_carree', margins=(-0.1, 0.15, -0.1, 0.1), \
                 subplots_adjust=(0.1, 0.05, 0.95, 0.95, None, None), \
                 grid_linewidth=1, grid_linestyle=':', grid_color = 'k', grid_alpha=0.8, \
                 extend='both', cbar_kwargs=None, title=None, \
                 vmin=None, vmax=None, levels=None):
        '''
        Instance method for class PlotConfig. Only stores the settings; nothing is drawn here.

        Args:
        figsize(tuple): (x_inches, y_inches)
        tiler(cartopy.io.img_tiles tiler | None): Background map tiles (e.g., OSM()) or None for
                     no tiles. If omitted, Stamen('terrain-background') is used.
        zoom(int): Zoom for tile between 1 and 9 when tiler is specified.
        cmap(str): matplotlib colormap
        title_size(int): Font size for title or None for default
        label_size(int): Font size for label or None for default
        plot_coast(str): [ '10m' | '50m' | '110m' | None ] Plotting coastline
        proj(str): [ 'plate_carree' | None ] Set in case of lon-lat coordinates
        margins(tuple): Margins around the plotted extent.
                        NOTE(review): the default contains negative values, so these are
                        presumably offsets added to (lon_min, lon_max, lat_min, lat_max) in
                        degrees (as in the notebook cells above), not normalized coordinates
                        between 0 and 1 -- confirm where this is consumed.
        subplots_adjust(tuple): (L, B, R, T, wspace=None, hspace=None) for Figure.subplots_adjust()
        grid_linewidth(int|float): Grid line width
        grid_linestyle(str): Grid line style
        grid_color(str): Grid color
        grid_alpha(float): Grid alpha (opacity)
        extend(str): [ 'neither' | 'both' | 'min' | 'max' ] If not 'neither',
                     make pointed end(s) for out-of-range values.
                     These are set for a given colormap using the colormap set_under and set_over methods.
        cbar_kwargs(dict | None): kwargs for color bar; 'shrink' adjusts the length of color bar.
                     None (default) means {'shrink': 0.9}. A copy of the given dict is stored,
                     so instances never share (and mutate) the same dict.
        title(str | None): Title, or None to let the plotting code build one
        vmin(float): vmin for cbar or None for default
        vmax(float): vmax for cbar or None for default
        levels(int | array-like): levels for cbar (int: num of levels, array-like: level values) or None for default
        '''

        if tiler is PlotConfig._DEFAULT_TILER:
            tiler = Stamen('terrain-background')
        ## Avoid a shared mutable default: each instance gets its own dict.
        cbar_kwargs = {'shrink': 0.9} if cbar_kwargs is None else dict(cbar_kwargs)

        self.figsize = figsize
        self.tiler = tiler
        self.zoom = zoom
        self.cmap = cmap
        self.title_size = title_size
        self.label_size = label_size
        ## Tick intervals depend on each plot, so they are handled by Plotter, not here.
        self.plot_coast = plot_coast
        self.proj = proj
        self.margins = margins
        self.subplots_adjust = subplots_adjust
        self.grid_linewidth = grid_linewidth
        self.grid_linestyle = grid_linestyle
        self.grid_color = grid_color
        self.grid_alpha = grid_alpha
        self.extend = extend
        self.cbar_kwargs = cbar_kwargs
        self.vmin = vmin
        self.vmax = vmax
        self.levels = levels
        self.title = title

class Data:
    '''
    Managing Xarray.DataArray in 4D(time, z, y, x) or 3D(time, y, x) prepared by reading netcdf.
    '''

    def __init__(self, da, vname=None, unit=None, xrange=None, yrange=None, zrange=None, trange=None, \
                 xlabel=None, ylabel=None, zlabel=None, tlabel=None):
        '''
        Instance method for class Data.

        Dimensions are taken by position: the last dimension is x, the second last is y,
        the first is t, and (only for 4D data) the third last is z.

        Args:
        da(Xarray.DataArray): Data to be managed
        vname(str): Name of variable. If None, taken from the 'long_name' attribute
                    (falling back to the DataArray name when that attribute is missing).
        unit(str) : Name of unit. If None, taken from the 'units' attribute (None if missing).
        xrange(tuple): Range of x (lon)   dimension. If None, min/max of its coordinate.
        yrange(tuple): Range of y (lat)   dimension. If None, min/max of its coordinate.
        zrange(tuple): Range of z (depth) dimension. If None, min/max of its coordinate. Only in 4D data
        trange(tuple): Range of t (time)  dimension. If None, min/max of its coordinate.
        xlabel(str)  : Label of x. If None, the dimension name.
        ylabel(str)  : Label of y. If None, the dimension name.
        zlabel(str)  : Label of z. If None, the dimension name. Only in 4D data
        tlabel(str)  : Label of t. If None, the dimension name.
        '''

        self.da = da                      ## xarray.DataArray
        self.vmax = self.da.max().values
        self.vmin = self.da.min().values
        self.vname = vname if vname is not None else self.da.attrs.get('long_name', self.da.name)
        self.unit = unit if unit is not None else self.da.attrs.get('units')

        dims = self.da.dims

        ### 3D and 4D: x is the last dimension, y the second last
        self.xrange = self._coord_range(-1) if xrange is None else xrange
        self.xmin, self.xmax = self.xrange
        self.yrange = self._coord_range(-2) if yrange is None else yrange
        self.ymin, self.ymax = self.yrange
        self.xlabel = dims[-1] if xlabel is None else xlabel
        self.ylabel = dims[-2] if ylabel is None else ylabel

        ### 4D (t, z, y, x): z is the third last dimension
        if len(dims) == 4:
            self.zrange = self._coord_range(-3) if zrange is None else zrange
            self.zmin, self.zmax = self.zrange
            self.zlabel = dims[-3] if zlabel is None else zlabel

        ### Time: the first dimension
        self.trange = self._coord_range(0) if trange is None else trange
        self.tmin, self.tmax = self.trange
        self.tlabel = dims[0] if tlabel is None else tlabel

    def _coord_range(self, axis):
        '''
        Return (min, max) of the coordinate values of dimension number `axis`
        (negative numbers count from the last dimension).
        '''
        values = self.da[self.da.dims[axis]].values
        return (values.min(), values.max())
                
class Plotter:
    '''
    Plotting 2D maps (top view or vertical view) invoking the instances of PlotConfig and Data
    '''
    def __init__(self, plot_config, data, x='lon', y='lat', z='depth', t='time'):
        '''
        Instance method for Plotter. Type of plot is determined by args of x, y, z, and t.
        The two arguments given as dimension names ('lon', 'lat', 'depth', 'time') become the
        plot axes; the others are indexes used to select data along the remaining dimensions.
        For example, if x='lon', y='lat', z=0, and t=0, a horizontal x-y plot is selected
        with indexing of z=0 and t=0.
        When t=slice(None), slicing all t values, used for creating animation.

        Args:
        plot_config(PlotConfig): Gets information about configuration
        data(Data): Gets data and its specification.
        x(str | int): Name or index of x dimension
        y(str | int): Name or index of y dimension
        z(str | int): Name or index of z dimension
        t(str | int): Name or index of t dimension

        Raises:
        ValueError: if the combination of x, y, z, and t matches no supported plot type.
        '''

        self.cfg = plot_config
        self.data = data
        ### Projection for (lon, lat) coordinates. Only 'plate_carree' is supported so far;
        ### any other cfg.proj value also falls back to PlateCarree.
        self.proj = ccrs.PlateCarree()
        self.x = x
        self.y = y
        self.z = z
        self.t = t
        ### Determine the type of plot and projection. This is a single if/elif chain, so
        ### a later test cannot overwrite an earlier match (e.g. 'xy_view').
        if self.x == 'lon' and self.y == 'lat':
            self.plot_type = 'xy_view'
            self.indexing = dict(time=self.t, depth=self.z)
            self.x_axis = self.x
            self.y_axis = self.y
        else:
            ## Vertical sections and time series are drawn on plain (non-map) Axes.
            self.proj = None
            if self.x == 'lon' and self.z == 'depth':
                self.plot_type = 'xz_view'
                self.indexing = dict(time=self.t, lat=self.y)
                self.x_axis = self.x
                self.y_axis = self.z
            elif self.y == 'lat' and self.z == 'depth':
                self.plot_type = 'yz_view'
                self.indexing = dict(time=self.t, lon=self.x)
                self.x_axis = self.y
                self.y_axis = self.z
            elif self.t == 'time' and self.x == 'lon':
                self.plot_type = 'tx_view'
                self.indexing = dict(depth=self.z, lat=self.y)
                self.x_axis = self.t
                self.y_axis = self.x
            elif self.t == 'time' and self.y == 'lat':
                self.plot_type = 'ty_view'
                self.indexing = dict(depth=self.z, lon=self.x)
                self.x_axis = self.t
                self.y_axis = self.y
            elif self.t == 'time' and self.z == 'depth':
                self.plot_type = 'tz_view'
                self.indexing = dict(lat=self.y, lon=self.x)
                self.x_axis = self.t
                self.y_axis = self.z
            else:
                raise ValueError(
                    f"Cannot determine plot type from x={x!r}, y={y!r}, z={z!r}, t={t!r}")

        self.fig = plt.figure(figsize=self.cfg.figsize)
        
        
    def get_ax(self):
        '''
        Add a single subplot to self.fig and return it.

        The subplot is created with self.proj as its projection when a projection
        is set (map views), and as a plain Axes otherwise.
        '''

        print("projection = ", self.proj)
        extra = {} if self.proj is None else {'projection': self.proj}
        return self.fig.add_subplot(1, 1, 1, **extra)

    def update_ax(self, ax):
        '''
        Apply extent, gridlines, ticks and tick labels to `ax` according to the plot type.

        'xy_view': sets the map extent. Without a tiler it draws coastlines and labelled
        gridlines. With a tiler it draws unlabelled gridlines, axis ticks with lon/lat
        formatters, and the background tiles.
        Other views: sets axis limits (not for time axes), a grid and ticks, and inverts the
        y-axis when it is depth.

        Requires self.extent, self.xticks and self.yticks (set by _set_plot_extent).

        Args:
        ax(Axes | GeoAxes): Axes to update.

        Returns:
        ax: the same Axes.
        '''

        ## NOTE: this changes the global matplotlib default font size. title_size may be
        ## None ("default"); rcParams does not accept None, so keep the current value then.
        if self.cfg.title_size is not None:
            plt.rcParams['font.size'] = self.cfg.title_size

        if self.plot_type == 'xy_view':
            ax.set_extent(self.extent, crs=self.proj)
            if self.cfg.tiler is None:
                ax.coastlines()
                gl=ax.gridlines(draw_labels=True, xlocs=self.xticks, ylocs=self.yticks, \
                                linestyle=self.cfg.grid_linestyle, linewidth=self.cfg.grid_linewidth, \
                                color=self.cfg.grid_color, alpha=self.cfg.grid_alpha)
                gl.right_labels=False
                gl.top_labels=False
                gl.xlabel_style={'size':self.cfg.label_size}
                gl.ylabel_style={'size':self.cfg.label_size}
                if self.xticks is not None:
                    gl.xlocator = mticker.FixedLocator(self.xticks)
                if self.yticks is not None:
                    gl.ylocator = mticker.FixedLocator(self.yticks)
            else:
                ## With tiles, labels come from the axis ticks instead of the gridliner.
                gl=ax.gridlines(crs=self.proj, draw_labels=False, linestyle=self.cfg.grid_linestyle, \
                                linewidth=self.cfg.grid_linewidth, color=self.cfg.grid_color,\
                                alpha=self.cfg.grid_alpha)
                if self.xticks is not None:
                    gl.xlocator = mticker.FixedLocator(self.xticks)
                    ax.set_xticks(self.xticks,crs=self.proj)
                if self.yticks is not None:
                    gl.ylocator = mticker.FixedLocator(self.yticks)
                    ax.set_yticks(self.yticks,crs=self.proj)
                if self.proj is not None:
                    latfmt=LatitudeFormatter()
                    lonfmt=LongitudeFormatter(zero_direction_label=True)
                    ax.xaxis.set_major_formatter(lonfmt)
                    ax.yaxis.set_major_formatter(latfmt)
                ax.axes.tick_params(labelsize=self.cfg.label_size)
                ax.add_image(self.cfg.tiler, self.cfg.zoom)
        else:
            if self.plot_type == 'tx_view' or self.plot_type == 'ty_view' or \
               self.plot_type == 'tz_view':
                ### set_xlim() does not work for time series axis
                #ax.set_xlim(pd.to_datetime(self.extent[0]).to_pydatetime(), \
                #            pd.to_datetime(self.extent[1]).to_pydatetime())
                print('Notice: For time series axis, set_xlim() does not work.')
            else:
                ax.set_xlim(self.extent[0], self.extent[1])
            ax.set_ylim(self.extent[2], self.extent[3])
            ax.grid(linestyle=self.cfg.grid_linestyle, linewidth=self.cfg.grid_linewidth, \
                    color=self.cfg.grid_color, alpha=self.cfg.grid_alpha)
            if self.xticks is not None:
                ax.set_xticks(self.xticks)
            if self.yticks is not None:
                ax.set_yticks(self.yticks)
            ax.tick_params(labelsize=self.cfg.label_size)
            if self.plot_type == 'xz_view' or self.plot_type == 'yz_view' or \
               self.plot_type == 'tz_view':
                ax.invert_yaxis()  ## reverse y-axis when y axis is depth.

        return ax

    def update_title(self, ax):
        '''
        Set the title of `ax` using ax.set_title().

        Uses cfg.title when it is given. Otherwise the title is built from the
        coordinate values at the fixed indexes of the dimensions that are not
        plot axes: time strings are truncated to 16 characters, other values to 5.

        Args:
        ax(Axes): Axes whose title is set.

        Returns:
        ax: the same Axes.
        '''

        title = self.cfg.title
        if title is None:
            da = self.data.da
            if self.plot_type == 'xy_view':
                title = str(da['time'].values[self.t])[0:16] + \
                        "  Depth = " + str(da['depth'].values[self.z])[0:5]
            elif self.plot_type == 'xz_view':
                title = str(da['time'].values[self.t])[0:16] + \
                        "  Lat = " + str(da['lat'].values[self.y])[0:5]
            elif self.plot_type == 'yz_view':
                ## yz_view fixes time and lon (see the indexing set in __init__)
                title = str(da['time'].values[self.t])[0:16] + \
                        "  Lon = " + str(da['lon'].values[self.x])[0:5]
            elif self.plot_type == 'tx_view':
                title = "Lat = " + str(da['lat'].values[self.y])[0:5] + \
                        "  Depth = " + str(da['depth'].values[self.z])[0:5]
            elif self.plot_type == 'ty_view':
                title = "Lon = " + str(da['lon'].values[self.x])[0:5] + \
                        "  Depth = " + str(da['depth'].values[self.z])[0:5]
            elif self.plot_type == 'tz_view':
                title = "Lon = " + str(da['lon'].values[self.x])[0:5] + \
                        "  Lat = " + str(da['lat'].values[self.y])[0:5]
        ax.set_title(title)

        return ax

    def make_2d_plot(self, ax, **kwargs):
        '''
        Draw the selected 2D panel (horizontal map, vertical section, or time series) on `ax`.

        Data are selected with self.indexing.
        - Map views use DataArray.plot() with the projection as transform.
        - Vertical sections use DataArray.plot().
        - Time series use DataArray.plot.contourf(), because the default da.plot() does not
          work for the datetime64 time axis (ax.set_xlim() cannot be applied there either).

        Args:
        ax(Axes | GeoAxes): Axes to draw on.
        **kwargs: extra keyword arguments forwarded to the xarray plotting call.

        Returns:
        ax: the same Axes.
        '''

        kind = self.plot_type
        cfg = self.cfg
        ## Color-scale options shared by every plot type
        shared = dict(cmap=cfg.cmap, extend=cfg.extend, vmin=cfg.vmin, vmax=cfg.vmax,
                      levels=cfg.levels, cbar_kwargs=cfg.cbar_kwargs)

        if kind == 'xy_view':
            ### Projection required: data are given on lon/lat
            self.data.da[self.indexing].plot(ax=ax, x=self.x_axis, y=self.y_axis,
                                             transform=self.proj, **shared, **kwargs)
        elif kind in ('xz_view', 'yz_view'):
            ### No projection
            self.data.da[self.indexing].plot(ax=ax, x=self.x_axis, y=self.y_axis,
                                             **shared, **kwargs)
        elif kind in ('tx_view', 'ty_view', 'tz_view'):
            ### contourf() required for the time axis
            self.data.da[self.indexing].plot.contourf(ax=ax, x=self.x_axis, y=self.y_axis,
                                                      **shared, **kwargs)
        else:
            print('ERROR: No such plot type of ', kind)

        return ax

    def _set_plot_extent(self, extent, ticks_intervals, vmin, vmax, levels):
        '''
        Private method invoked by the plot and frame methods for setting the plot extent,
        ticks intervals, variable range of vmin and vmax, and contour levels.

        Args:
        extent(tuple | None)           : (x_min, x_max, y_min, y_max) of the panel. None takes the
                                         data range (self.data.xmin, ..., self.data.tmax) of the
                                         two axes of plot_type.
        ticks_intervals(tuple | None)  : (x_interval, y_interval). Ticks start at x_min / y_min and,
                                         as with np.arange, exclude x_max / y_max.
                                         None sets self.xticks and self.yticks to None.
        vmin(float | None)             : Overrides self.cfg.vmin when not None.
        vmax(float | None)             : Overrides self.cfg.vmax when not None.
        levels(int | array-like | None): Overrides self.cfg.levels when not None.

        Raises:
        ValueError: extent is None and plot_type is not one of xy_view, xz_view, yz_view,
                    tx_view, ty_view, or tz_view.
        '''

        self.extent = extent
        self.ticks_intervals = ticks_intervals
        if self.extent is None:
            data = self.data
            if self.plot_type == 'xy_view':
                self.extent = (data.xmin, data.xmax, data.ymin, data.ymax)
            elif self.plot_type == 'xz_view':
                self.extent = (data.xmin, data.xmax, data.zmin, data.zmax)
            elif self.plot_type == 'yz_view':
                self.extent = (data.ymin, data.ymax, data.zmin, data.zmax)
            elif self.plot_type == 'tx_view':
                self.extent = (data.tmin, data.tmax, data.xmin, data.xmax)
            elif self.plot_type == 'ty_view':
                self.extent = (data.tmin, data.tmax, data.ymin, data.ymax)
            elif self.plot_type == 'tz_view':
                self.extent = (data.tmin, data.tmax, data.zmin, data.zmax)
            else:
                ### Fail here with a clear message; otherwise slicing None below raises an opaque TypeError.
                raise ValueError("plot_type '{}' is not defined in _set_plot_extent".format(self.plot_type))
        self.xmin, self.xmax = self.extent[0:2]
        self.ymin, self.ymax = self.extent[2:4]

        if self.ticks_intervals is None:
            self.xticks = None
            self.yticks = None
        else:
            self.xticks = np.arange(self.xmin, self.xmax, self.ticks_intervals[0])
            self.yticks = np.arange(self.ymin, self.ymax, self.ticks_intervals[1])

        ### Non-None arguments override the values held in self.cfg
        if vmin is not None:
            self.cfg.vmin = vmin
        if vmax is not None:
            self.cfg.vmax = vmax
        if levels is not None:
            self.cfg.levels = levels

    def _plot(self, **kwargs):
        '''
        Private method invoked by the plot method: draw the 2D panel on the Axes, then set its
        title, and finally update the Axes (update_ax must come last).

        Args:
        **kwargs: Forwarded to make_2d_plot.

        Returns:
        The Axes returned by update_ax.
        '''
        panel_ax = self.make_2d_plot(self.get_ax(), **kwargs)
        titled_ax = self.update_title(panel_ax)
        return self.update_ax(titled_ax)  ## Needs to invoke at the last

    def plot(self, extent=None, ticks_intervals=None, vmin=None, vmax=None, levels=None, **kwargs):
        '''
        Public method for plotting on a screen. Non-None arguments override the values
        set in PlotConfig.

        Args:
        extent(tuple)           : Extent of plot. Default is None which derives it from the data.
        ticks_intervals(tuple)  : (x_interval, y_interval). Default is None (no explicit ticks).
        vmin(float)             : Min variable value. Default is None which gets from PlotConfig.
        vmax(float)             : Max variable value. Default is None which gets from PlotConfig.
        levels(int | array-like): Number of levels in int or list of levels.
                                  Default is None which gets from PlotConfig.
        **kwargs                : Forwarded to the DataArray plotting call.

        Returns:
        self, so that save() can be chained: Plotter(...).plot(...).save(filename).
        '''

        overrides = (extent, ticks_intervals, vmin, vmax, levels)
        self._set_plot_extent(*overrides)
        self._plot(**kwargs)
        return self

    def save(self, filename, **kwargs):
        '''
        Write the current figure to a graphic file. Typically chained after plot():
        Plotter(...).plot(...).save(filename).

        Args:
        filename(str): File name with an extension of graphic format, e.g., png.
        **kwargs     : Forwarded unchanged to Figure.savefig() (e.g., dpi, bbox_inches).
        '''

        figure = self.fig
        figure.savefig(filename, **kwargs)
        print("Saved {}.".format(filename))

    def frame(self, extent=None, ticks_intervals=None, vmin=None, vmax=None, levels=None, \
              subplot_adjust=(0.15, 0.05, 0.9, 0.95), **kwargs):
        '''
        Creating an initial frame of animation (lon-lat panel at depth index self.z, time index 0).
        Chain anim() to write the animation: Plotter(...).frame(...).anim(filename).

        Args:
        extent(tuple)           : Extent of plot. Default is None which derives it from the data.
        ticks_intervals(tuple)  : (x_interval, y_interval). Default is None (no explicit ticks).
        vmin(float)             : Min variable value. Default is None which gets from PlotConfig.
        vmax(float)             : Max variable value. Default is None which gets from PlotConfig.
        levels(int | array-like): Number of levels in int or list of levels.
                                  Default is None which gets from PlotConfig.
        subplot_adjust(tuple)   : Tuple of (left, bottom, right, top) in normalized coords.
        **kwargs                : DataArray.plot(**kwargs), e.g., cbar_kwargs.

        Returns:
        self, with self.cax (the artist returned by DataArray.plot()) and self.ax_anim (its Axes)
        stored for use by anim().
        '''

        left, bottom, right, top = subplot_adjust[:4]
        self.fig.subplots_adjust(left=left, bottom=bottom, right=right, top=top, wspace=None, hspace=None)
        self._set_plot_extent(extent, ticks_intervals, vmin, vmax, levels)
        ax = self.get_ax()
        ax = self.update_ax(ax)
        ### Use self.cfg (PlotConfig values with the non-None arguments applied by _set_plot_extent),
        ### as make_2d_plot does; the raw arguments would discard the PlotConfig values when None.
        self.cax = self.data.da.isel(depth=self.z, time=0).plot(ax=ax, x='lon', y='lat', transform=self.proj,
                                                              cmap=self.cfg.cmap, extend=self.cfg.extend,
                                                              vmin=self.cfg.vmin, vmax=self.cfg.vmax,
                                                              levels=self.cfg.levels, **kwargs)
        self.ax_anim = ax
        return self


    def anim(self, filename, frames=None, interval=200, blit=False, repeat=True, **kwargs):
        '''
        Method chained after frame() for creating GIF or MP4 animation:
        Plotter(...).frame(...).anim(filename).
        There is a kind of bug for plotting an array including nan with set_array() method which
        is required when creating animation; thus masked_array needs to be used instead of ndarray.
        https://stackoverflow.com/questions/58117358/matplotlib-image-plot-nan-values-shown-as-lowest-color-of-colormaps-instead-of
        http://xarray.pydata.org/en/stable/generated/xarray.DataArray.to_masked_array.html

        Args:
        filename(str): GIF or MP4 file name.
        frames(int)  : Number of frames. Default is None, which animates every time step
                       selected by self.t (e.g., t=slice(None) for all time steps).
        interval(int): Interval between frames in ms.
        blit(bool)   : Passed to FuncAnimation.
        repeat(bool) : True: repeat
        **kwargs     : kwargs for Animation.save, e.g., writer='pillow', dpi=200.
        '''

        print(self.data.da['time'][self.t])
        ### Positional time indices selected by self.t; frame i shows time_indices[i], so that
        ### t=slice(4, 10) animates time steps 4..9 rather than 0..5.
        time_indices = np.arange(self.data.da.sizes['time'])[self.t]
        times = self.data.da.coords['time'].values
        depth = self.data.da['depth'].values[self.z]

        def anim_update(i):
            it = int(time_indices[i])
            masked_array = self.data.da.isel(depth=self.z, time=it).to_masked_array()
            self.cax.set_array(masked_array.flatten())
            self.ax_anim.set_title("Depth = " + str(depth) + \
                                   "  Time = " + str(times[it])[0:16])

        if frames is None:
            frames = len(time_indices)
        ani = FuncAnimation(self.fig, anim_update, frames=frames, interval=interval, blit=blit, repeat=repeat)
        ani.save(filename, **kwargs)

## 2D Plotting for plan view, vertical views, and timeseries
The save method may be overridden by the following cell. You may add text with self.fig.text(x, y, s), where x and y are normalized figure coordinates (from 0 to 1) and s is the string to be plotted. 

In [None]:
### Override "save" method of class Plotter
### Remove the enclosing ''' to activate this override of the save method.
'''
def save(self, filename, extent=None, ticks_intervals=None, 
         vmin=None, vmax=None, levels=None, **kwargs):

    self._set_plot_extent(extent, ticks_intervals, vmin, vmax, levels)
    self._plot()
    ### Start edit
    self.fig.text(0.5, 0.5, 'Test')
    ### End edit
    self.fig.savefig(filename, **kwargs)
    print("Saved {}.".format(filename))
Plotter.save = save
'''

### Specify plot type and variable
plot_type = 'plot_xy'
var = 'salinity' #'water_temp'

### Specify axes range or None
#extent_xy = (139.2, 140.2, 34.8, 35.8)
extent_xy = (139.0, 140.4, 34.6, 36.0)
extent_xz = (139.0, 140.4, 0, 3000) # None
extent_yz = None #(139.0, 140.4, 0, 3000)
extent_tx = None
extent_ty = None
extent_tz = None

### Specify axes' ticks intervals or None
ticks_intervals_xy = (0.4, 0.2)
ticks_intervals_xz = None
ticks_intervals_yz = None
ticks_intervals_tx = None
ticks_intervals_ty = None
ticks_intervals_tz = None

### Create Data instance and PlotConfig instance
da_var = Data(da=dsts[var])
cfg_sal = PlotConfig()

### Create Plotter instance, draw the panel with plot(), and write the file with save().
### Plotter.save() forwards its kwargs only to Figure.savefig(), so extent and
### ticks_intervals must be given to plot().

vmin = 34.0
vmax = 34.2
levels = 10

if   plot_type == 'plot_xy':
    ### xy plot (lon, lat)
    plot_xy=Plotter(plot_config=cfg_sal, data=da_var, x='lon', y='lat', z=0, t=0)
    plot_xy.plot(extent=extent_xy, ticks_intervals=ticks_intervals_xy, \
                 vmin=vmin, vmax=vmax, levels=levels).save("./class_test_xy.png", dpi=300, bbox_inches='tight')
    ### Alternative when the save override above is activated (it also adds the edited text):
    #plot_xy.save("./class_test_xy.png", extent=extent_xy, ticks_intervals=ticks_intervals_xy, \
    #             vmin=vmin, vmax=vmax, levels=levels, dpi=300, bbox_inches='tight')
elif plot_type == 'plot_xz':
    ### xz plot (lon, depth)
    plot_xz=Plotter(plot_config=cfg_sal, data=da_var, x='lon', y=0, z='depth', t=0)
    plot_xz.plot(extent=extent_xz, ticks_intervals=ticks_intervals_xz) \
           .save("./class_test_xz.png", dpi=300, bbox_inches='tight')
elif plot_type == 'plot_yz':
    ### yz plot (lat, depth)
    plot_yz=Plotter(plot_config=cfg_sal, data=da_var, x=2, y='lat', z='depth', t=0)
    plot_yz.plot(extent=extent_yz, ticks_intervals=ticks_intervals_yz) \
           .save("./class_test_yz.png", dpi=300, bbox_inches='tight')
elif plot_type == 'plot_tx':
    ### tx plot (time, lon)
    plot_tx=Plotter(plot_config=cfg_sal, data=da_var, x='lon', y=0, z=0, t='time')
    plot_tx.plot(extent=extent_tx, ticks_intervals=ticks_intervals_tx) \
           .save("./class_test_tx.png", dpi=300, bbox_inches='tight')
elif plot_type == 'plot_ty':
    ### ty plot (time, lat)
    plot_ty=Plotter(plot_config=cfg_sal, data=da_var, x=0, y='lat', z=0, t='time')
    plot_ty.plot(extent=extent_ty, ticks_intervals=ticks_intervals_ty) \
           .save("./class_test_ty.png", dpi=300, bbox_inches='tight')
elif plot_type == 'plot_tz':
    ### tz plot (time, depth)
    plot_tz=Plotter(plot_config=cfg_sal, data=da_var, x=0, y=0, z='depth', t='time')
    plot_tz.plot(extent=extent_tz, ticks_intervals=ticks_intervals_tz) \
           .save("./class_test_tz.png", dpi=300, bbox_inches='tight')
else:
    print('ERROR: No such plot_type of', plot_type)

## Creating GIF or MP4 animation
Only `plot_type` of `plot_xy` is supported. It is recommended to write dedicated animation code without using the class, because customizing panels is often required.

In [None]:
### Animation supports only plot_type of 'plot_xy' (Plotter.frame draws a lon-lat panel).
plot_type = 'plot_xy'
var = 'salinity' #'water_temp'
extent_xy = (139.0, 140.4, 34.6, 36.0)
ticks_intervals_xy = (0.4, 0.2)
### Create Data instance and PlotConfig instance
da_var = Data(da=dsts[var])
cfg_sal = PlotConfig()

if plot_type == 'plot_xy':
    ### xy plot (lon, lat); t=slice(None) selects all time steps, which become the animation frames
    plot_xy=Plotter(plot_config=cfg_sal, data=da_var, x='lon', y='lat', z=0, t=slice(None))

    ### GIF
    plot_xy.frame(extent=extent_xy, ticks_intervals=ticks_intervals_xy, vmin=34.0, vmax=34.2, levels=20, \
                  cbar_kwargs={'shrink':0.8}).anim("./class_test_anim.gif", frames=None, \
                                                   interval=500, writer='pillow', dpi=200)
    ### MP4
    #plot_xy.frame(extent=extent_xy, ticks_intervals=ticks_intervals_xy, vmin=34.0, vmax=34.2, levels=20, \
    #              cbar_kwargs={'shrink':0.8}).anim("./class_test_anim.mp4", frames=None, \
    #                                               interval=500, writer='ffmpeg', dpi=200)
else:
    print('ERROR: Animation supports only plot_type of plot_xy, not', plot_type)

# Test for time series 2D plot
pcolormesh() and ax.set_xlim() do not work for time series. Thus pcolormesh() is replaced with contourf() and ax.set_xlim() is deactivated. The following is an attempt to resolve this problem, which has not yet been resolved.

In [None]:
### Test for time series axis
### pcolormesh and ax.set_xlim() do not work for time series.
### Thus replace it with contourf() and deactivate ax.set_xlim().
### Start setting by users
depth = 0 # depth index in int
lat = 0   # lat index in int
figsize = (6,6)
png = 'hycom_tokyo_bay1.png'
## Default font size
plt.rcParams['font.size'] = 12
### End setting by users

fig = plt.figure(figsize=figsize)
ax = plt.axes()
### Time-lon section at the user-specified depth and lat indices
dsts.salinity.isel(depth=depth, lat=lat).plot.contourf(ax=ax, x='time', y='lon', extend='both')
### Attempts to limit the time axis (not working yet):
#ax.set_xlim(pd.to_datetime('2012-10-28 12:00:00'), pd.to_datetime('2012-10-31 18:00:00'))
#ax.set_xlim([datetime.datetime(2012, 10, 29, 18), datetime.datetime(2012, 10, 31, 18)])
#ax.set_xbound([datetime.datetime(2012, 10, 28, 12), datetime.datetime(2012, 10, 31, 18)])
fig.savefig(png, dpi=600, bbox_inches='tight')

In [None]:
def test_pcolormesh_datetime_axis():
    '''
    Minimal reproduction of pcolormesh() with a datetime x axis on a new figure.

    The x axis has 21 daily datetimes starting 2013-01-01 and the y axis 21 integers.
    Both are used as cell edges around a 20x20 grid whose value at [i, j] is i * j.
    Attempts that did not work are kept as comments.
    '''
    plt.figure()
    origin = datetime.datetime(2013, 1, 1)
    day_edges = np.array([origin + datetime.timedelta(days=n) for n in range(21)])
    ##day_edges = pd.to_datetime(day_edges).to_datetime64()  ### not work
    level_edges = np.arange(21)
    grid = np.outer(np.arange(20), np.arange(20))
    plt.subplot(111)
    plt.pcolormesh(day_edges, level_edges, grid)
    ##plt.xlim('2013-1-1 00:00:00', '2013-2-1 00:00:00')  ### not work

In [None]:
### Close the current matplotlib figure, then run the pcolormesh datetime-axis test,
### which creates its own figure with plt.figure().
plt.close()
test_pcolormesh_datetime_axis()