In [None]:
import xarray as xr
import matplotlib.pyplot as plt

import sys
sys.path.append("../../")

from config.config import Config

# Project configuration. All paths are relative to this notebook; judging by
# the keyword names they are the two raw ISMN archives and the ERA5 file.
config_kwargs = dict(
    raw_ismn_global_path='../../data/raw/Data_separate_files_header_20000101_20201231_9562_Crun_20230723.zip',
    raw_ismn_tibetan_path='../../data/raw/Data_separate_files_header_19500101_20230826_9562_asrG_20230826.zip',
    era5_path='../../settings/data.nc',
)
opt = Config(**config_kwargs)

# Target grid (indexed by "lon" / "lat") that the mask is resampled onto.
tibetan_coords = opt.tibetan_coords

# Read the mask and interpolate it onto the target grid so it can later be
# combined with the prediction datasets.
ds_mask = xr.open_dataset("../../data/raw/tibetan_mask.nc")
ds_mask_interp = ds_mask.interp(
    coords={"lon": tibetan_coords["lon"], "lat": tibetan_coords["lat"]}
)

In [None]:
# Monthly comparison (months 1, 4, 7, 10).
# For every layer 1..5 this draws one figure with a row per month and a column
# per model variant (base / best / ensemble), masks each field with "Band1"
# from `ds_mask_interp`, and saves the figure as layer<N>.png under `save_root`.
import os

from tqdm import tqdm

months = ["01", "04", "07", "10"]

# discussion
save_root = "../../data/plot/discussion_compare/month"
os.makedirs(save_root, exist_ok=True)

month_map = {
    "01": "January",
    "04": "April",
    "07": "July",
    "10": "October",
}

# Multiplier applied to the font sizes; the yearly cell further down reads
# this name as well, so keep it defined here.
factor = 1.5

# One entry per column: (sub-directory under data/compare/monthly,
# variable name used after renaming "sm", column title).
variants = [
    ("pred", "sm_base", "AMSMQTP_base"),
    ("discussion_best", "sm_best", "AMSMQTP_best"),
    ("discussion_ensemble", "sm_ensemble", "AMSMQTP_ensemble"),
]

for layer in tqdm(range(1, 6)):
    num_cols = len(variants)
    fig, axes = plt.subplots(
        nrows=len(months),
        ncols=num_cols,
        figsize=(num_cols * 6, len(months) * 3),
    )
    last_row = len(months) - 1

    for i, (month, ax) in enumerate(zip(months, axes)):
        # Scale title, axis-label and tick-label fonts of every panel in the row.
        for panel in ax:
            texts = [panel.title, panel.xaxis.label, panel.yaxis.label]
            texts += panel.get_xticklabels() + panel.get_yticklabels()
            for text in texts:
                text.set_fontsize(text.get_fontsize() * factor)

        # load_dataset reads the file into memory and closes it again, so no
        # file handles stay open across the 5 x 4 x 3 files touched here.
        renamed = [
            xr.load_dataset(
                f"../../data/compare/monthly/{subdir}/layer{layer}/{month}.nc"
            ).rename({"sm": var})
            for subdir, var, _ in variants
        ]
        ds = xr.merge(renamed + [ds_mask_interp])

        # Multiply by the mask band before plotting; colorbars are added
        # manually below, hence add_colorbar=False.
        meshes = [
            (ds[var] * ds["Band1"]).plot.pcolormesh(
                ax=panel, add_colorbar=False, ylim=(25, 42.5)
            )
            for (_, var, _), panel in zip(variants, ax)
        ]

        # Column titles only on the first row.
        if i == 0:
            for (_, _, title), panel in zip(variants, ax):
                panel.set_title(title, fontsize=12 * factor)

        # Drop the automatic axis labels, then use the y label of the first
        # panel as the row caption.
        for panel in ax:
            panel.set_xlabel("")
            panel.set_ylabel("")
        ax[0].set_ylabel(month_map[month])

        # y ticks only on the first column, x ticks only on the last row.
        for panel in ax[1:]:
            panel.set_yticks([])
        if i != last_row:
            for panel in ax:
                panel.set_xticks([])

        # Pass the panel explicitly so each colorbar takes its space from the
        # panel it describes, independent of which axes is "current".
        for mesh, panel in zip(meshes, ax):
            cb = fig.colorbar(mesh, ax=panel, orientation="vertical")
            cb.set_label(label="SM[$m^3/m^3$]")  # colorbar label
            cb.ax.tick_params(labelsize=10 * factor)

    fig.tight_layout()

    # Figures are intentionally left open so they can still be shown inline.
    fig.savefig(os.path.join(save_root, f"layer{layer}.png"), dpi=300)


In [None]:
# Yearly comparison (2020, 2015, 2010, 2005).
# For every layer 1..5 this draws one figure with a row per year and a column
# per model variant (base / best / ensemble), masks each field with "Band1"
# from `ds_mask_interp`, and saves the figure as layer<N>.png under `save_root`.
# NOTE: `factor` (font-size multiplier) is defined in the monthly cell above
# and must already be in the kernel namespace when this cell runs.
import os

from tqdm import tqdm  # imported here so the cell does not rely on the monthly cell for it

years = [2020, 2015, 2010, 2005]

# discussion
save_root = "../../data/plot/discussion_compare/year"
os.makedirs(save_root, exist_ok=True)

# One entry per column: (sub-directory under data/compare/yearly,
# variable name used after renaming "sm", column title).
variants = [
    ("pred", "sm_base", "AMSMQTP_base"),
    ("discussion_best", "sm_best", "AMSMQTP_best"),
    ("discussion_ensemble", "sm_ensemble", "AMSMQTP_ensemble"),
]

for layer in tqdm(range(1, 6)):
    num_cols = len(variants)
    fig, axes = plt.subplots(
        nrows=len(years),
        ncols=num_cols,
        figsize=(num_cols * 6, len(years) * 3),
    )
    last_row = len(years) - 1

    for i, (year, ax) in enumerate(zip(years, axes)):
        # Scale title, axis-label and tick-label fonts of every panel in the row.
        for panel in ax:
            texts = [panel.title, panel.xaxis.label, panel.yaxis.label]
            texts += panel.get_xticklabels() + panel.get_yticklabels()
            for text in texts:
                text.set_fontsize(text.get_fontsize() * factor)

        # load_dataset reads the file into memory and closes it again, so no
        # file handles stay open across the 5 x 4 x 3 files touched here.
        renamed = [
            xr.load_dataset(
                f"../../data/compare/yearly/{subdir}/layer{layer}/{year}.nc"
            ).rename({"sm": var})
            for subdir, var, _ in variants
        ]
        ds = xr.merge(renamed + [ds_mask_interp])

        # Multiply by the mask band before plotting; colorbars are added
        # manually below, hence add_colorbar=False.
        meshes = [
            (ds[var] * ds["Band1"]).plot.pcolormesh(
                ax=panel, add_colorbar=False, ylim=(25, 42.5)
            )
            for (_, var, _), panel in zip(variants, ax)
        ]

        # Column titles only on the first row.
        if i == 0:
            for (_, _, title), panel in zip(variants, ax):
                panel.set_title(title, fontsize=12 * factor)

        # Drop the automatic axis labels, then use the y label of the first
        # panel as the row caption.
        for panel in ax:
            panel.set_xlabel("")
            panel.set_ylabel("")
        ax[0].set_ylabel(str(year))

        # y ticks only on the first column, x ticks only on the last row.
        for panel in ax[1:]:
            panel.set_yticks([])
        if i != last_row:
            for panel in ax:
                panel.set_xticks([])

        # Pass the panel explicitly so each colorbar takes its space from the
        # panel it describes, independent of which axes is "current".
        for mesh, panel in zip(meshes, ax):
            cb = fig.colorbar(mesh, ax=panel, orientation="vertical")
            cb.set_label(label="SM[$m^3/m^3$]")  # colorbar label
            cb.ax.tick_params(labelsize=10 * factor)

    fig.tight_layout()

    # Figures are intentionally left open so they can still be shown inline.
    fig.savefig(os.path.join(save_root, f"layer{layer}.png"), dpi=300)
