# Prototype for combined model-observation bias and model-intercomparison notebook

This notebook plots the bottom stratification bias for a series of simulations vs. climatology derived from the World Ocean Atlas (WOA) 2018.

It also plots the percent change in bottom stratification of each simulation relative to the first (reference) simulation.

In [None]:
import os
import datetime

In [None]:
# Default analysis settings. Each entry can be overridden by the matching
# MAR_<KEY> environment variable (handled in the next code cell).
config = dict(
    startyr="1980",  # first model year of the analysis window
    endyr="2017",  # last model year of the analysis window
    # Comma-separated Dora experiment IDs. Try with just "odiv-319" as well.
    dora_id="odiv-384, odiv-391, odiv-395, odiv-393, odiv-396",
    pathPP=None,
)

MAR passes four environment variables to the script when it is run via the web engine:

* `MAR_STARTYR`: A `str` of the beginning year of analysis from model
* `MAR_ENDYR`: A `str` of the ending year of analysis from model
* `MAR_DORA_ID`: A `str` of the experiment ID in the database for a single experiment, e.g. `"odiv-1"`, or comma-separated experiments, e.g. `"odiv-1, odiv-2"`
* `MAR_PATHPP`: A `str` of the top-level path to the post-processing experiment directory of the experiment

The block below will use values passed in by Dora but default to the values defined above in `config`. This is useful for interactive usage and debugging.

If executed from Dora, there will also be a `DORA_EXECUTE` variable that is set.

In [None]:
# Override each default in `config` with the matching MAR_<KEY> environment
# variable when it is set, then normalize "dora_id" into a list of IDs.
for key, default in config.items():
    value = os.environ.get(f"MAR_{key.upper()}", default)
    # Split based on the *resolved* value: testing the default instead would
    # leave an env-var string unsplit whenever the default was already a list.
    # Empty entries (e.g. from a trailing comma) are dropped.
    if key == "dora_id" and isinstance(value, str):
        value = [entry.strip() for entry in value.split(",") if entry.strip()]
    config[key] = value

### Import Python Modules

In [None]:
import doralite
import glob
import momlevel
import subprocess 

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

from matplotlib.colors import ListedColormap, BoundaryNorm

In [None]:
# momgrid will use a directory of pre-computed weights that is used for
# calculating basic area-weighted statistics later. Point it at the shared
# weights directory *before* importing momgrid, in case the package reads the
# variable at import time (NOTE(review): not confirmed from here). A location
# the user has already exported is left untouched.
os.environ.setdefault("MOMGRID_WEIGHTS_DIR", "/nbhome/John.Krasting/grid_weights")
import momgrid

### Define Local Parameters

In [None]:
# Pull per-experiment metadata from Dora. These values could also be defined
# locally instead of being looked up.

def get_local_params(config):
    """Fetch Dora metadata for every experiment ID in ``config["dora_id"]``.

    Returns a dict keyed by Dora ID. Each value holds the full metadata
    record under "experiment", plus its "pathPP" and "expName" fields.
    """
    # Fetch every record first, then extract the fields of interest.
    metadata_by_id = {}
    for dora_id in config["dora_id"]:
        metadata_by_id[dora_id] = doralite.dora_metadata(dora_id)

    params = {}
    for dora_id, metadata in metadata_by_id.items():
        params[dora_id] = dict(
            experiment=metadata,
            pathPP=metadata["pathPP"],
            expName=metadata["expName"],
        )
    return params

# Analysis window as integers
start = int(config["startyr"])
end = int(config["endyr"])
params = get_local_params(config)

### Determine What Files to Load

In [None]:
# Determine which files are needed (left up to the developer for flexibility).
# Here: the bottom stratification (N2_b) time series from the annual ocean
# component, stored in 5-year chunks.

component = "ocean_annual"
static = f"{component}/{component}.static.nc"
varname = "N2_b"

chunk = "5yr"

# os.path.join adds a separator only when pathPP lacks a trailing "/", so the
# pattern resolves whether or not the Dora path ends with a slash. Plain string
# concatenation silently matched nothing when the slash was missing.
filelist_dict = {
    odiv: glob.glob(
        os.path.join(p["pathPP"], component, "ts", "**", chunk, f"{component}.*.{varname}.nc"),
        recursive=True,
    )
    for odiv, p in params.items()
}

In [None]:
def is_in_range(file, start, end):
    """Return True if the years covered by ``file`` overlap ``[start, end]``.

    The date span is taken from the second dot-separated field of the file's
    basename, e.g. ``ocean_annual.1980-1984.N2_b.nc`` or
    ``ocean_annual.198001-198412.N2_b.nc``. Only the first four characters
    (the year) of each date are used. A field with a single date (no "-") is
    treated as a one-year span instead of raising IndexError.

    Parameters
    ----------
    file : str
        Path to the file; only its basename is inspected.
    start, end : int or str
        First and last years (inclusive) of the analysis window.
    """
    start, end = int(start), int(end)
    date_field = os.path.basename(file).split(".")[1]
    years = [int(date[:4]) for date in date_field.split("-")]
    first_year, last_year = years[0], years[-1]
    # Two inclusive integer ranges share a year iff the later start is not
    # after the earlier end (this is also False for empty/reversed ranges).
    return max(first_year, start) <= min(last_year, end)

import itertools

# Keep only files whose year span overlaps the requested analysis window.
filelist_dict = {
    odiv: [path for path in files if is_in_range(path, start, end)]
    for odiv, files in filelist_dict.items()
}

# Flatten all experiments' files into one sorted list (used for dmget).
# Every experiment contributes, including those with a single matching file;
# empty lists simply add nothing.
filelist = sorted(itertools.chain.from_iterable(filelist_dict.values()))

staticfiles = [os.path.join(p["pathPP"], static) for p in params.values()]

for path in filelist:
    print(path)

### DMgetting Files

Dora cannot issue calls to `dmget`, so this step is skipped when the `DORA_EXECUTE` environment variable is set.

In [None]:
# Stage the data and static files with dmget before reading them. Skipped under
# Dora (DORA_EXECUTE is set), which cannot call dmget.
if "DORA_EXECUTE" not in os.environ:
    files_to_stage = filelist + staticfiles
    if files_to_stage:
        print("Calling dmget on files ...")
        # check_output raises CalledProcessError if dmget fails.
        _ = subprocess.check_output(["dmget"] + files_to_stage)
    else:
        print("No files to dmget.")

### Load model data and grid

In [None]:
# Open each experiment's time series (via momgrid) and its static grid file.
# Experiments with no matching files are skipped, so `models` may hold fewer
# entries than config["dora_id"].
models = {}
# `exp_files` (not `filelist`) so the global flattened list built earlier,
# which the dmget cell reads, is not overwritten by this loop.
for (odiv, exp_files), staticfile in zip(filelist_dict.items(), staticfiles):
    if len(exp_files) == 0:
        print(f"The required diagnostics are not available for {odiv}. Skipping.")
        continue

    gds = momgrid.Gridset(exp_files, force_symmetric=True, return_corners=True)
    # Restrict to the analysis window, inclusive of both end years.
    gds.data = gds.data.sel(time=slice(f"{str(start).zfill(4)}-01-01", f"{str(end).zfill(4)}-12-31"))

    # Align the static file's tracer-point coordinates with the model data.
    og = xr.open_dataset(staticfile)
    og = og.assign_coords({"xh": gds.data.xh, "yh": gds.data.yh})

    models[odiv] = {"gds": gds, "og": og}

### Define some helper functions for the plots

In [None]:
def set_annotaions(ax, odiv):
    """Remove the ticks from ``ax`` and write the header text for experiment ``odiv``.

    The left side shows the analysis period and the experiment ID/name. The
    right side shows the first and last time stamps of the plotted data, read
    from the module-level ``starttime`` and ``endtime`` variables. The calling
    cell must set those before this function is called.
    """
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])

    ax.text(
        0.0, 1.06, f"Bottom stratification ({start}-{end})",
        weight="bold", fontsize=12, transform=ax.transAxes,
    )
    ax.text(
        0.0, 1.02, f"{odiv} [{params[odiv]['expName']}]",
        style="italic", fontsize=10, transform=ax.transAxes,
    )

    # Right-aligned time stamps, stacked top (start) to bottom (end)
    for y_pos, stamp in ((1.05, starttime), (1.02, endtime)):
        ax.text(1.0, y_pos, str(stamp.values), ha="right", fontsize=8, transform=ax.transAxes)

In [None]:
def add_colorbar(fig, cb):
    """Draw a horizontal colorbar for mappable ``cb`` in a strip along the bottom of ``fig``.

    The bar is labelled in s^-2 and shows extension arrows at both ends.
    """
    colorbar_kwargs = {
        "orientation": "horizontal",
        "extend": "both",
        "label": r"[s$^{-2}$]",
    }
    # Colorbar axes rectangle: [left, bottom, width, height] as figure fractions
    cbar_rect = [0.16, 0.06, 0.7, 0.03]
    fig.colorbar(cb, cax=fig.add_axes(cbar_rect), **colorbar_kwargs)

# Plotting model values

In [None]:
import matplotlib.colors as colors

In [None]:
# One map per experiment of the time-mean bottom stratification (log color scale).
for odiv, model_dict in models.items():
    gds = model_dict["gds"]

    # Record the time span (set_annotaions reads these globals), then
    # average over time.
    n2_bottom = gds.data.N2_b
    starttime, endtime = n2_bottom.time[0], n2_bottom.time[-1]
    model = n2_bottom.mean("time", keep_attrs=True).load()

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(1, 1, 1, facecolor="lightgray")

    # pcolormesh is given the cell-corner longitudes/latitudes
    cb = ax.pcolormesh(
        gds.data.geolon_c,
        gds.data.geolat_c,
        model,
        cmap="cividis",
        norm=colors.LogNorm(vmin=1e-8, vmax=1e-4),
    )

    set_annotaions(ax, odiv)
    add_colorbar(fig, cb)

# Plotting inter-model percent changes

In [None]:
def set_annotaions_ratio(ax, odiv, odiv_ref):
    """Remove the ticks from ``ax`` and write the header for a percent-change map.

    The subtitle spells out the formula 100% x (odiv - odiv_ref) / odiv_ref,
    using each experiment's ID and name from ``params``. The right-side time
    stamps come from the module-level ``starttime`` and ``endtime`` variables,
    which the calling cell must set beforehand.
    """
    _ = ax.set_xticks([])
    _ = ax.set_yticks([])
    exp_name = params[odiv]["expName"]
    ref_name = params[odiv_ref]["expName"]
    # Balanced parentheses: the previous label opened two and closed three.
    formula = (
        rf"relative change = 100% $\times$ ({odiv} [{exp_name}] - {odiv_ref} [{ref_name}])"
        rf" $/$ {odiv_ref} [{ref_name}]"
    )
    _ = ax.text(0.0, 1.06, f"percent change in stratification ({start}-{end})", weight="bold", fontsize=12, transform=ax.transAxes)
    _ = ax.text(0.0, 1.02, formula, style="italic", fontsize=10, transform=ax.transAxes)
    _ = ax.text(1.0, 1.05, str(starttime.values), ha="right", fontsize=8, transform=ax.transAxes)
    _ = ax.text(1.0, 1.02, str(endtime.values), ha="right", fontsize=8, transform=ax.transAxes)

def add_colorbar_relative_change(fig, cb):
    """Draw a horizontal percent-change colorbar for ``cb`` along the bottom of ``fig``.

    The bar shows extension arrows at both ends and is labelled as a percent
    change relative to the reference simulation.
    """
    colorbar_kwargs = {
        "orientation": "horizontal",
        "extend": "both",
        "label": r"% change, relative to reference simulation",
    }
    # Colorbar axes rectangle: [left, bottom, width, height] as figure fractions
    cbar_rect = [0.16, 0.06, 0.7, 0.03]
    fig.colorbar(cb, cax=fig.add_axes(cbar_rect), **colorbar_kwargs)

In [None]:
# Percent change of every experiment relative to the first loaded experiment
# (the reference). Both fields are regridded to a common 0.25-degree grid so
# they can be differenced point by point.
model_items = list(models.items())

if len(model_items) < 2:
    # `models` only holds experiments whose diagnostics were found, so this
    # can trigger even when several experiment IDs were requested.
    print("No difference plot because fewer than two experiments have the required diagnostics!")

else:
    odiv_ref, ref_dict = model_items[0]
    gds_ref = ref_dict["gds"].regrid(resolution=0.25)

    # The reference time mean is the same for every comparison: compute it
    # once instead of reloading it on each loop iteration.
    model_ref = gds_ref.N2_b
    # set_annotaions_ratio reads these globals; they describe the reference
    # data, as before.
    starttime = model_ref.time[0]
    endtime = model_ref.time[-1]
    model_ref = model_ref.mean("time", keep_attrs=True).load()

    for odiv, model_dict in model_items[1:]:
        gds = model_dict["gds"].regrid(resolution=0.25)

        # Time-average the comparison experiment
        model = gds.N2_b.mean("time", keep_attrs=True).load()

        fig = plt.figure(figsize=(12, 6))
        ax = fig.add_subplot(1, 1, 1, facecolor="lightgray")

        # Cell corners of the regridded grid (pcolormesh expects corners)
        x = gds.lon_b
        y = gds.lat_b

        percent_change = 100 * (model - model_ref) / model_ref
        cb = ax.pcolormesh(x, y, percent_change, cmap="RdBu_r", vmin=-100, vmax=100)

        set_annotaions_ratio(ax, odiv, odiv_ref)
        add_colorbar_relative_change(fig, cb)