In [None]:
import yaml

import utils
from utils.utils_units import conv_units
from utils.Plotting import _apply_log10_vals

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import cftime
from tqdm.auto import tqdm
from pathlib import Path
import dask
from dask_jobqueue import SLURMCluster
from distributed import Client, progress, wait
%matplotlib inline
%load_ext autoreload
%autoreload 2


In [None]:
# One SLURM job = 10 cores split over 10 worker processes (so one thread per
# worker) sharing 50 GB; asking for 30 workers presumably maps to 3 such jobs.
cluster = SLURMCluster(
    cores=10,
    processes=10,
    memory="50GB",
)
cluster.scale(n=30)
client = Client(cluster)
cluster

In [None]:
# Output directory for the PNGs written by `summary_plot_maps`. Those plots are
# produced inside dask tasks on worker processes, so resolve the path here on
# the client: a relative path would be interpreted against each worker's own
# working directory.
image_dir = Path("../images").resolve()
image_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Diagnostics configuration: a list of per-variable dicts (each consumed by
# `summary_plots`, which requires a "varname" key).
with Path("diag_metadata.yaml").open(mode="r") as yaml_stream:
    diag_metadata_list = yaml.safe_load(yaml_stream)

In [None]:
def summary_plot_maps(da, varname, isel_dict, diag_metadata, cmap="plasma"):
    """Save one map PNG per time level of ``da`` into the module-level ``image_dir``.

    For every value yielded by ``_apply_log10_vals(diag_metadata)`` and every
    index along the ``time`` dimension, the time slice is optionally converted
    to ``diag_metadata["display_units"]``, optionally log10-transformed, drawn
    with xarray's ``DataArray.plot`` on a fresh figure and saved at 300 dpi.

    Parameters
    ----------
    da : xarray.DataArray
        Field to plot; must have a ``time`` dimension. Assumed to be 2-D per
        time level once ``isel_dict`` has been applied -- TODO confirm.
    varname : str
        Variable name, used as the filename prefix.
    isel_dict : dict or None
        Indexers already applied to ``da``; the scalar coordinate value left
        for each key is embedded in the filename as ``key@value``.
    diag_metadata : dict
        Per-variable settings. Reads ``map_vmin``, ``map_vmax`` and
        ``display_units`` when present.
    cmap : str, optional
        Matplotlib colormap name (default ``"plasma"``).

    A failure for one time level is printed and the loop moves on to the next
    level; the figure is closed in every case.
    """
    # Colour limits in linear space; per-pass log10 variants are derived below.
    vmin_linear = diag_metadata.get("map_vmin")
    vmax_linear = diag_metadata.get("map_vmax")

    for apply_log10 in _apply_log10_vals(diag_metadata):
        vmin, vmax = vmin_linear, vmax_linear
        if apply_log10:
            # log10 of a non-positive limit is undefined: fall back to autoscaling.
            if vmin is not None:
                vmin = np.log10(vmin) if vmin > 0.0 else None
            if vmax is not None:
                vmax = np.log10(vmax) if vmax > 0.0 else None

        for t_ind in range(len(da["time"])):
            fig = None
            try:
                to_plot = da.isel(time=t_ind)
                time_str = str(to_plot.time.data.item())
                filename = f"{varname}+{time_str}"
                if "display_units" in diag_metadata:
                    to_plot = conv_units(to_plot, diag_metadata["display_units"])
                if apply_log10:
                    # Mask non-positive values to NaN before taking the log.
                    to_plot = np.log10(xr.where(to_plot > 0.0, to_plot, np.nan))
                    to_plot.name = f"log10({to_plot.name})"
                    filename = f"{filename}+log_10@{apply_log10}"
                if isel_dict is not None:
                    # Tag the file with the coordinate value of each selected index.
                    level_tags = "+".join(
                        f"{key}@{to_plot[key].data.item()}" for key in isel_dict
                    )
                    filename = f"{filename}+{level_tags}"
                path = (image_dir / f"{filename}.png").as_posix()

                # Draw on an explicitly created figure instead of pyplot's
                # "current" one, so a figure leaked by an earlier failure can
                # never be drawn over.
                fig, ax = plt.subplots()
                to_plot.plot(ax=ax, cmap=cmap, vmin=vmin, vmax=vmax)
                fig.savefig(path, dpi=300)
            except Exception as e:  # TODO: Figure out what to do in case of a failure
                print(f"{varname} (time index {t_ind}, log10={apply_log10}): {e!r}")
            finally:
                # Release the figure even when plotting or saving failed, so a
                # long loop does not accumulate open figures.
                if fig is not None:
                    plt.close(fig)

In [None]:
def summary_plots(case, stream, diag_metadata):
    """Select one diagnostic variable from a case's history stream and plot it.

    Looks up ``diag_metadata["varname"]`` in ``case.history_contents[stream]``,
    applies the optional ``diag_metadata["isel_dict"]`` indexers, and hands the
    result to ``summary_plot_maps``. Returns ``None``.
    """
    history_ds = case.history_contents[stream]
    varname = diag_metadata["varname"]
    isel_dict = diag_metadata.get("isel_dict")
    selected = history_ds[varname].isel(isel_dict)
    summary_plot_maps(selected, varname, isel_dict, diag_metadata)


# Lazy wrapper: each call returns a dask Delayed task instead of plotting.
summary_plots_delayed = dask.delayed(summary_plots)

In [None]:
case = utils.CaseClass(
    "g.e22.G1850ECO_JRA_HR.TL319_t13.004",
    start_date="0001-01",
    end_date="0003-12",
)
stream = "pop.h"
case._open_history_files(stream)

# One lazy plotting task per diagnostic, in the order of diag_metadata_list;
# nothing executes until the tasks are handed to client.compute().
tasks = [
    summary_plots_delayed(case, stream, diag_metadata)
    for diag_metadata in tqdm(diag_metadata_list)
]

In [None]:
futures = client.compute(tasks)  # start computation in the background
progress(futures)  # watch progress
wait(futures)

# `wait` returns once every future is done, whether it succeeded or raised, so
# report failures explicitly. `tasks` was built in diag_metadata_list order.
for failed_meta, future in zip(diag_metadata_list, futures):
    label = failed_meta.get("varname", "<missing varname>")
    if future.status == "error":
        print(f"{label}: {future.exception()!r}")
    elif future.status != "finished":
        print(f"{label}: task ended with status {future.status!r}")

In [None]:
# Shut down the client first, then the SLURM cluster it is connected to;
# closing only the cluster leaves `client` attached to a scheduler that is gone.
client.close()
cluster.close()