In [None]:
import sys

import matplotlib.pyplot as plt
import xarray as xr

# Make the project root importable so `config` resolves relative to this notebook.
sys.path.append("../../")

from config.config import Config  # noqa: E402  (needs the sys.path tweak above)

opt = Config(
    raw_ismn_global_path='../../data/raw/Data_separate_files_header_20000101_20201231_9562_Crun_20230723.zip',
    raw_ismn_tibetan_path='../../data/raw/Data_separate_files_header_19500101_20230826_9562_asrG_20230826.zip',
    era5_path='../../settings/data.nc'
)
# Target lon/lat grid used below; assumed to be the grid of the compared products (TODO confirm).
tibetan_coords = opt.tibetan_coords

# load_dataset reads the mask into memory and closes the file immediately,
# instead of keeping an open handle for the whole notebook session.
ds_mask = xr.load_dataset("../../data/raw/tibetan_mask.nc")
# NOTE(review): Dataset.interp defaults to linear interpolation. If "Band1" is a 0/1 mask,
# boundary cells receive fractional values that would scale (not mask) data downstream;
# method="nearest" may be intended. Left linear here to keep outputs unchanged.
ds_mask_interp = ds_mask.interp(lon=tibetan_coords["lon"], lat=tibetan_coords["lat"])

In [None]:
# Monthly comparison maps (January, April, July, October).
# For each layer, one figure is saved with one row per month. The columns show each
# reference product minus AMSMQTP_base, followed by AMSMQTP_base itself. All panels are
# multiplied by the mask variable "Band1" from `ds_mask_interp` (defined in the first cell).
import os

from tqdm import tqdm

months = ["01", "04", "07", "10"]

save_root = "../../data/plot/compare/month"
os.makedirs(save_root, exist_ok=True)

# Reference variable compared against each AMSMQTP layer (deeper layers reuse a variable).
era5_map = {
    'layer1': 'swvl1',
    'layer2': 'swvl2',
    'layer3': 'swvl3',
    'layer4': 'swvl3',
    'layer5': 'swvl3'
}
gldas_map = {
    "layer1": "SoilMoi0_10cm_inst",
    "layer2": "SoilMoi10_40cm_inst",
    "layer3": "SoilMoi10_40cm_inst",
    "layer4": "SoilMoi40_100cm_inst",
    "layer5": "SoilMoi40_100cm_inst"
}
month_map = {
    "01": "January",
    "04": "April",
    "07": "July",
    "10": "October"
}

# Font scale relative to matplotlib's default sizes (10 pt text, 12 pt titles).
factor = 1.5
lat_lim = (25, 42.5)  # latitude window shown in every panel

for layer in tqdm(range(1, 6)):
    # SMCI is only compared for layers 2-5; layer 1 has no SMCI column.
    has_smci = layer != 1
    num_cols = 4 if has_smci else 3
    fig, axes = plt.subplots(nrows=len(months), ncols=num_cols, figsize=(num_cols * 6, len(months) * 3))
    for i, (month, row_axes) in enumerate(zip(months, axes)):
        # load_dataset reads eagerly and closes each file, so no handles leak across the loop.
        ds_gldas = xr.load_dataset(f"../../data/compare/monthly/gldas/layer{layer}/{month}.nc")
        ds_era5 = xr.load_dataset(f"../../data/compare/monthly/era5/layer{layer}/{month}.nc")
        ds_era5 = ds_era5.rename({"latitude": "lat", "longitude": "lon"})
        ds_pred = xr.load_dataset(f"../../data/compare/monthly/pred/layer{layer}/{month}.nc")
        compare_lst = [ds_era5, ds_gldas]
        if has_smci:
            compare_lst.append(xr.load_dataset(f"../../data/compare/monthly/smci/layer{layer}/{month}.nc"))

        ds = xr.merge([ds_pred] + compare_lst + [ds_mask_interp])
        pred = ds["sm"]
        mask = ds["Band1"]

        # Difference panels: reference product minus AMSMQTP_base.
        meshes = [
            ((ds[era5_map[f"layer{layer}"]] - pred) * mask).plot.pcolormesh(ax=row_axes[0], add_colorbar=False, ylim=lat_lim),
            ((ds[gldas_map[f"layer{layer}"]] - pred) * mask).plot.pcolormesh(ax=row_axes[1], add_colorbar=False, ylim=lat_lim),
        ]
        if has_smci:
            # NOTE(review): SMCI appears to be stored scaled by 1000; * 0.001 brings it to the
            # units of "sm" — confirm against the SMCI1.0 product documentation.
            meshes.append(((ds["SMCI"] * 0.001 - pred) * mask).plot.pcolormesh(ax=row_axes[2], add_colorbar=False, ylim=lat_lim))
        # AMSMQTP_base itself always occupies the last column.
        meshes.append((pred * mask).plot.pcolormesh(ax=row_axes[-1], add_colorbar=False, ylim=lat_lim))

        if i == 0:
            column_titles = ["ERA5-Land - AMSMQTP_base", "GLDAS-2.1 - AMSMQTP_base"]
            if has_smci:
                column_titles.append("SMCI1.0_9km - AMSMQTP_base")
            column_titles.append("AMSMQTP_base")
            for panel, title in zip(row_axes, column_titles):
                panel.set_title(title, fontsize=12 * factor)
        else:
            # xarray may auto-title panels from scalar coordinates; keep headers on the top row only.
            for panel in row_axes:
                panel.set_title("")

        # Applied after plotting so the sizes hold for the tick labels that actually end up drawn.
        for panel in row_axes:
            panel.set_xlabel("")
            panel.set_ylabel("")
            panel.tick_params(labelsize=10 * factor)
        row_axes[0].set_ylabel(month_map[month], fontsize=10 * factor)

        # Y tick labels only on the first column, x tick labels only on the bottom row.
        for panel in row_axes[1:]:
            panel.set_yticks([])
        if i != len(months) - 1:
            for panel in row_axes:
                panel.set_xticks([])

        for mesh in meshes:
            # Explicit ax= makes each colorbar take space from its own panel on any matplotlib version.
            cb = fig.colorbar(mesh, ax=mesh.axes, orientation="vertical")
            cb.set_label(label="SM[$m^3/m^3$]")  # colorbar label
            cb.ax.tick_params(labelsize=10 * factor)

    fig.tight_layout()
    # For a vector copy: fig.savefig(os.path.join(save_root, f"layer{layer}.pdf"))
    fig.savefig(os.path.join(save_root, f"layer{layer}.png"), dpi=300)
    # Inline backend: render this figure now and close it before building the next layer.
    plt.show()

In [None]:
# Yearly comparison maps (2020, 2015, 2010, 2005).
# For each layer, one figure is saved with one row per year. The columns show each
# reference product minus AMSMQTP_base, followed by AMSMQTP_base itself. All panels are
# multiplied by the mask variable "Band1" from `ds_mask_interp` (defined in the first cell).
import os

# Imported here as well so this cell does not depend on the monthly cell having run.
from tqdm import tqdm

years = [2020, 2015, 2010, 2005]

save_root = "../../data/plot/compare/year"
os.makedirs(save_root, exist_ok=True)

# Reference variable compared against each AMSMQTP layer (deeper layers reuse a variable).
era5_map = {
    'layer1': 'swvl1',
    'layer2': 'swvl2',
    'layer3': 'swvl3',
    'layer4': 'swvl3',
    'layer5': 'swvl3'
}
gldas_map = {
    "layer1": "SoilMoi0_10cm_inst",
    "layer2": "SoilMoi10_40cm_inst",
    "layer3": "SoilMoi10_40cm_inst",
    "layer4": "SoilMoi40_100cm_inst",
    "layer5": "SoilMoi40_100cm_inst"
}

# Font scale relative to matplotlib's default sizes (10 pt text, 12 pt titles).
factor = 1.5
lat_lim = (25, 42.5)  # latitude window shown in every panel

for layer in tqdm(range(1, 6)):
    # SMCI is only compared for layers 2-5; layer 1 has no SMCI column.
    has_smci = layer != 1
    num_cols = 4 if has_smci else 3
    fig, axes = plt.subplots(nrows=len(years), ncols=num_cols, figsize=(num_cols * 6, len(years) * 3))
    for i, (year, row_axes) in enumerate(zip(years, axes)):
        # load_dataset reads eagerly and closes each file, so no handles leak across the loop.
        ds_gldas = xr.load_dataset(f"../../data/compare/yearly/gldas/layer{layer}/{year}.nc")
        ds_era5 = xr.load_dataset(f"../../data/compare/yearly/era5/layer{layer}/{year}.nc")
        ds_era5 = ds_era5.rename({"latitude": "lat", "longitude": "lon"})
        ds_pred = xr.load_dataset(f"../../data/compare/yearly/pred/layer{layer}/{year}.nc")
        compare_lst = [ds_era5, ds_gldas]
        if has_smci:
            compare_lst.append(xr.load_dataset(f"../../data/compare/yearly/smci/layer{layer}/{year}.nc"))

        ds = xr.merge([ds_pred] + compare_lst + [ds_mask_interp])
        pred = ds["sm"]
        mask = ds["Band1"]

        # Difference panels: reference product minus AMSMQTP_base.
        meshes = [
            ((ds[era5_map[f"layer{layer}"]] - pred) * mask).plot.pcolormesh(ax=row_axes[0], add_colorbar=False, ylim=lat_lim),
            ((ds[gldas_map[f"layer{layer}"]] - pred) * mask).plot.pcolormesh(ax=row_axes[1], add_colorbar=False, ylim=lat_lim),
        ]
        if has_smci:
            # NOTE(review): SMCI appears to be stored scaled by 1000; * 0.001 brings it to the
            # units of "sm" — confirm against the SMCI1.0 product documentation.
            meshes.append(((ds["SMCI"] * 0.001 - pred) * mask).plot.pcolormesh(ax=row_axes[2], add_colorbar=False, ylim=lat_lim))
        # AMSMQTP_base itself always occupies the last column.
        meshes.append((pred * mask).plot.pcolormesh(ax=row_axes[-1], add_colorbar=False, ylim=lat_lim))

        if i == 0:
            column_titles = ["ERA5-Land - AMSMQTP_base", "GLDAS-2.1 - AMSMQTP_base"]
            if has_smci:
                column_titles.append("SMCI1.0_9km - AMSMQTP_base")
            column_titles.append("AMSMQTP_base")
            for panel, title in zip(row_axes, column_titles):
                panel.set_title(title, fontsize=12 * factor)
        else:
            # xarray may auto-title panels from scalar coordinates; keep headers on the top row only.
            for panel in row_axes:
                panel.set_title("")

        # Applied after plotting so the sizes hold for the tick labels that actually end up drawn.
        for panel in row_axes:
            panel.set_xlabel("")
            panel.set_ylabel("")
            panel.tick_params(labelsize=10 * factor)
        row_axes[0].set_ylabel(str(year), fontsize=10 * factor)

        # Y tick labels only on the first column, x tick labels only on the bottom row.
        for panel in row_axes[1:]:
            panel.set_yticks([])
        if i != len(years) - 1:
            for panel in row_axes:
                panel.set_xticks([])

        for mesh in meshes:
            # Explicit ax= makes each colorbar take space from its own panel on any matplotlib version.
            cb = fig.colorbar(mesh, ax=mesh.axes, orientation="vertical")
            cb.set_label(label="SM[$m^3/m^3$]")  # colorbar label
            cb.ax.tick_params(labelsize=10 * factor)

    fig.tight_layout()
    # For a vector copy: fig.savefig(os.path.join(save_root, f"layer{layer}.pdf"))
    fig.savefig(os.path.join(save_root, f"layer{layer}.png"), dpi=300)
    # Inline backend: render this figure now and close it before building the next layer.
    plt.show()
