In [1]:
%load_ext autoreload
%autoreload 2
import yaml

import utils
from utils.utils_units import conv_units
from utils.Plotting import _apply_log10_vals

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import cftime
from tqdm.auto import tqdm
from pathlib import Path
import dask
from dask_jobqueue import SLURMCluster
from distributed import Client, progress, wait



In [2]:
# Create a SLURM-backed Dask cluster; each job is requested with cores=1,
# processes=1 and 40GB of memory.
cluster = SLURMCluster(cores=1, processes=1, memory="40GB")
# Adaptive scaling, bounded to between 34 and 35 jobs.
cluster.adapt(minimum_jobs=34, maximum_jobs=35)
# Connect a client so that subsequent Dask computations run on this cluster.
client = Client(cluster)
# Bare expression: renders the cluster widget as the cell output.
cluster

VBox(children=(HTML(value='<h2>SLURMCluster</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    …

In [3]:
# Root directory for all generated images (relative to the notebook's working
# directory); created up front so later cells can write into sub-directories.
image_dir = Path("../images")
image_dir.mkdir(parents=True, exist_ok=True)

In [4]:
# Load the diagnostics to plot. Later cells iterate over this object and read
# keys such as "varname" and "isel_dict" from each entry, so the YAML file is
# expected to hold a list of mappings.
with open("diag_metadata.yaml", mode="r") as fptr:
    diag_metadata_list = yaml.safe_load(fptr)

In [5]:
def isel_dict_as_string(isel_dict, to_plot):
    """Build a file-name fragment describing the selected coordinate values.

    For every key in ``isel_dict`` the scalar value of ``to_plot[key]`` is
    looked up, and the resulting ``key@value`` pairs are joined with ``+``.

    Parameters
    ----------
    isel_dict : iterable of keys (typically a dict)
        Only the keys are used; they name the entries to look up in ``to_plot``.
    to_plot : mapping-like
        ``to_plot[key].data.item()`` must yield a scalar for every key.

    Returns
    -------
    str
        e.g. ``"z_t@500.0"``; the empty string if ``isel_dict`` is empty.
    """
    selected = {name: to_plot[name].data.item() for name in isel_dict}
    fragments = []
    for name, level in selected.items():
        fragments.append("@".join((str(name), str(level))))
    return "+".join(fragments)


def summary_plot_map(da, diag_metadata, plot_dir, cmap="plasma"):
    """Save map plot(s) of a single time sample of ``da`` as PNG files.

    One image is produced for each value yielded by
    ``_apply_log10_vals(diag_metadata)``. Files are named
    ``{varname}+{time}[+log_10@{apply_log10}][+{isel string}].png`` inside
    ``plot_dir``; an image that already exists is skipped, but the remaining
    variants are still produced.

    The function is written to be used with ``xr.map_blocks``: it always
    returns ``da.time`` and never raises, so one failing plot does not abort
    the whole task graph.

    Parameters
    ----------
    da : xarray.DataArray
        Data to plot; ``da.time`` must hold exactly one value.
    diag_metadata : dict
        Must contain ``"varname"``. Optional keys read here: ``"map_vmin"``,
        ``"map_vmax"``, ``"display_units"`` and ``"isel_dict"``.
    plot_dir : pathlib.Path
        Existing directory the images are written to.
    cmap : str, optional
        Colormap passed to the xarray plotting call.

    Returns
    -------
    xarray.DataArray
        ``da.time``, unchanged.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import xarray as xr

    try:
        time = str(da.time.data.item())
        varname = diag_metadata["varname"]

        for apply_log10 in _apply_log10_vals(diag_metadata):
            filename = f"{varname}+{time}"
            vmin = diag_metadata.get("map_vmin")
            vmax = diag_metadata.get("map_vmax")
            if apply_log10:
                # log10 is undefined for non-positive limits; fall back to
                # automatic colour limits in that case.
                if vmin is not None:
                    vmin = np.log10(vmin) if vmin > 0.0 else None
                if vmax is not None:
                    vmax = np.log10(vmax) if vmax > 0.0 else None
            to_plot = da.copy()

            if "display_units" in diag_metadata:
                to_plot = conv_units(to_plot, diag_metadata["display_units"])
            if apply_log10:
                # Mask non-positive values so log10 does not emit -inf/NaN
                # warnings; they show up as missing data in the map.
                to_plot = np.log10(xr.where(to_plot > 0.0, to_plot, np.nan))
                to_plot.name = f"log10({to_plot.name})"
                filename = f"{filename}+log_10@{apply_log10}"
            if "isel_dict" in diag_metadata:
                isel_dict = diag_metadata["isel_dict"]
                s = isel_dict_as_string(isel_dict, to_plot)
                filename = f"{filename}+{s}.png"
            else:
                filename = f"{filename}.png"
            path = plot_dir / filename
            if path.exists():
                # Skip only this variant. Returning here would leave the
                # remaining variants (e.g. the log10 map) permanently
                # unwritten once the first image exists.
                continue
            fig, ax = plt.subplots()
            try:
                to_plot.plot(ax=ax, cmap=cmap, vmin=vmin, vmax=vmax)
                # Save the figure we drew on rather than pyplot's notion of
                # the "current" figure.
                fig.savefig(path.as_posix(), dpi=300)
            finally:
                # Always release the figure, even if plotting or saving
                # failed, so long-lived workers do not accumulate figures.
                plt.close(fig)
    except Exception as exc:
        # Deliberate best effort: report the problem and still return the
        # expected result so one failing plot does not abort the others.
        print(exc)
    return da.time

In [6]:
# Case (model run) to analyse, restricted to the given start/end dates.
case = utils.CaseClass(
    "g.e22.G1850ECO_JRA_HR.TL319_t13.004",
    start_date="0001-01",
    end_date="0005-12",
)
# History stream whose output is plotted by the cells below.
stream = "pop.h"
# NOTE(review): this calls an underscore-prefixed method of CaseClass;
# presumably it populates case.history_contents[stream], which the following
# cells read -- confirm, and prefer a public API if one exists.
case._open_history_files(stream)

Datasets contain a total of 60 time samples
Last average written at 0006-01-01 00:00:00


In [7]:
# Build one lazy map_blocks task set per diagnostic: every single-time chunk of
# the selected variable is handed to summary_plot_map, which writes the image
# and returns the chunk's time value (matching `template`).
plot_type = "timestep-global-map"
plot_dir = image_dir / plot_type
plot_dir.mkdir(parents=True, exist_ok=True)

plots = []
v = []  # per-diagnostic bookkeeping: metadata copy, input data and template
for diag_metadata in diag_metadata_list:
    varname = diag_metadata["varname"]
    isel_dict = diag_metadata.get("isel_dict")

    # One chunk per time sample, so each block yields exactly one map.
    data = case.history_contents[stream][varname]
    data = data.isel(isel_dict)
    data = data.chunk({"time": 1})

    # map_blocks needs to know the shape of the result: one value per time.
    template = xr.zeros_like(data.time).chunk({"time": 1})

    z = {"diag": diag_metadata.copy(), "data": data, "template": template}
    p = xr.map_blocks(
        summary_plot_map,
        data,
        args=[z["diag"], plot_dir],
        template=template,
    )
    v.append(z)
    plots.append(p)

In [8]:
# Stack the per-diagnostic results along a new "diag" dimension so that a
# single compute call can drive all plotting tasks.
# NOTE: this rebinds `plots` from the list built in the previous cell to the
# concatenated object, so the cell is not idempotent -- re-run the previous
# cell before re-running this one.
plots = xr.concat(plots, dim="diag", compat="override", coords="minimal")
# Bare expression: renders the (still lazy) array summary as the cell output.
plots

Unnamed: 0,Array,Chunk
Bytes,11.52 kB,8 B
Shape,"(24, 60)","(1, 1)"
Count,19644 Tasks,1440 Chunks
Type,object,numpy.ndarray
"Array Chunk Bytes 11.52 kB 8 B Shape (24, 60) (1, 1) Count 19644 Tasks 1440 Chunks Type object numpy.ndarray",60  24,

Unnamed: 0,Array,Chunk
Bytes,11.52 kB,8 B
Shape,"(24, 60)","(1, 1)"
Count,19644 Tasks,1440 Chunks
Type,object,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4 B,4 B
Shape,(),()
Count,793 Tasks,1 Chunks
Type,float32,numpy.ndarray
Array Chunk Bytes 4 B 4 B Shape () () Count 793 Tasks 1 Chunks Type float32 numpy.ndarray,,

Unnamed: 0,Array,Chunk
Bytes,4 B,4 B
Shape,(),()
Count,793 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4 B,4 B
Shape,(),()
Count,793 Tasks,1 Chunks
Type,float32,numpy.ndarray
Array Chunk Bytes 4 B 4 B Shape () () Count 793 Tasks 1 Chunks Type float32 numpy.ndarray,,

Unnamed: 0,Array,Chunk
Bytes,4 B,4 B
Shape,(),()
Count,793 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [9]:
# Submit all plotting tasks to the cluster (each task may be retried twice).
# Bind the result to a name: Dask releases work whose futures are no longer
# referenced, so the computation should not depend on IPython's output cache
# keeping the only reference alive.
plot_futures = client.compute(plots, retries=2)
# Show task progress so completion (or a stall) is visible in the notebook.
progress(plot_futures)

In [12]:
def summary_plot_global_ts(
    ds, diag_metadata, plot_dir=None, time_coarsen_len=None
):
    """Plot the time series of an area-weighted global mean or integral.

    The variable ``diag_metadata["varname"]`` is reduced over its last two
    dimensions using ``ds["TAREA"]`` (missing values treated as zero) as
    weights, and the result is plotted against time. Optionally the mean over
    consecutive blocks of ``time_coarsen_len`` samples is overlaid and the
    last block mean is appended to the title.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain the variable named by ``diag_metadata["varname"]`` and
        ``"TAREA"``.
    diag_metadata : dict
        Must contain ``"varname"``. Optional keys read here: ``"spatial_op"``
        (``"average"``, the default, or ``"integrate"``), ``"display_units"``,
        ``"integral_display_units"``, ``"integral_unit_conv"`` and
        ``"isel_dict"`` (used only for the output file name).
    plot_dir : pathlib.Path, optional
        Directory the PNG is written to. If ``None`` nothing is written.
        An existing file is never overwritten.
    time_coarsen_len : int, optional
        Block length, in time samples, of the overlaid coarsened series.
        Trailing samples that do not fill a whole block are dropped; if there
        is not even one whole block the overlay is skipped.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already closed with respect to pyplot so that repeated
        calls do not accumulate open figures.

    Raises
    ------
    ValueError
        If ``spatial_op`` is neither ``"average"`` nor ``"integrate"``.
    """
    varname = diag_metadata["varname"]
    da = ds[varname]
    reduce_dims = da.dims[-2:]
    weights = ds["TAREA"].fillna(0)
    da_weighted = da.weighted(weights)
    spatial_op = diag_metadata.get("spatial_op", "average")
    filename = f"{varname}+spatial_op@{spatial_op}"
    if spatial_op == "average":
        to_plot = da_weighted.mean(dim=reduce_dims)
        to_plot.attrs = da.attrs
        if "display_units" in diag_metadata:
            to_plot = conv_units(to_plot, diag_metadata["display_units"])
    elif spatial_op == "integrate":
        to_plot = da_weighted.sum(dim=reduce_dims)
        to_plot.attrs = da.attrs
        # An area integral carries the units of the field times the area units.
        to_plot.attrs["units"] += f" {weights.attrs['units']}"
        if "integral_display_units" in diag_metadata:
            to_plot = conv_units(
                to_plot,
                diag_metadata["integral_display_units"],
                units_scalef=diag_metadata.get("integral_unit_conv"),
            )
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError(
            f"unknown spatial_op {spatial_op!r} for {varname}; "
            "expected 'average' or 'integrate'"
        )

    # Evaluate the (possibly lazy) reduction once; otherwise every `.values`
    # access below, including those on the coarsened series, recomputes it.
    to_plot = to_plot.load()

    # do not use to_plot.plot.line("-o") because of incorrect time axis values
    # https://github.com/pydata/xarray/issues/4401
    fig, ax = plt.subplots()
    try:
        ax.plot(
            utils.utils.time_year_plus_frac(to_plot, "time"),
            to_plot.values,
            "-o",
        )
        ax.set_xlabel(xr.plot.utils.label_from_attrs(to_plot["time"]))
        ax.set_ylabel(xr.plot.utils.label_from_attrs(to_plot))
        ax.set_title(to_plot._title_for_slice())
        if time_coarsen_len is not None:
            filename = f"{filename}+time_coarsen_len@{time_coarsen_len}"
            tlen = len(to_plot.time)
            # Drop trailing samples that do not fill a complete block.
            tlen_trunc = (tlen // time_coarsen_len) * time_coarsen_len
            # With fewer samples than one block there is nothing to overlay
            # (and indexing the last block mean would fail).
            if tlen_trunc > 0:
                to_plot_trunc = to_plot.isel(time=slice(0, tlen_trunc))
                to_plot_coarse = to_plot_trunc.coarsen(
                    {"time": time_coarsen_len}
                ).mean()
                coarse_values = to_plot_coarse.values
                ax.plot(
                    utils.utils.time_year_plus_frac(to_plot_coarse, "time"),
                    coarse_values,
                    "-o",
                )
                title = ax.get_title()
                if title != "":
                    title += ", "
                title += f"last mean value={utils.utils.round_sig(coarse_values[-1],4)}"
                ax.set_title(title)
        if plot_dir is not None:
            if "isel_dict" in diag_metadata:
                isel_dict = diag_metadata["isel_dict"]
                s = isel_dict_as_string(isel_dict, to_plot)
                filename = f"{filename}+{s}.png"
            else:
                filename = f"{filename}.png"
            path = plot_dir / filename
            if not path.exists():
                # Save this figure explicitly rather than pyplot's current one.
                fig.savefig(path.as_posix(), dpi=300)
    finally:
        # Close unconditionally: the figure used to stay open whenever the
        # output file already existed, leaking one figure per call.
        plt.close(fig)
    return fig

In [13]:
# Write one global time-series image per diagnostic into its own
# sub-directory of image_dir.
plot_type = "global-timeseries"
plot_dir = image_dir / plot_type
plot_dir.mkdir(parents=True, exist_ok=True)
for diag_metadata in tqdm(diag_metadata_list):
    varname = diag_metadata["varname"]
    # isel_dict may be absent (None), in which case no selection is applied.
    isel_dict = diag_metadata.get("isel_dict")
    # TAREA is selected alongside the variable because the plotting function
    # uses it as the weights of the spatial reduction.
    data = case.history_contents[stream][[varname, "TAREA"]].isel(isel_dict)
    # time_coarsen_len=12 overlays means over blocks of 12 time samples
    # (presumably annual means of monthly output -- confirm against the stream).
    summary_plot_global_ts(
        data, diag_metadata, plot_dir=plot_dir, time_coarsen_len=12
    )

HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))


