In [1]:
import xarray as xr
import matplotlib.pyplot as plt

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import xesmf as xe

import os

import sys
sys.path.append('../')
from utils import DataFinder

In [None]:
# Plot-styling constants for every (variable, resolution, plot type) combination.
#
# Structure: const_dict[variable][resolution][plot_type] -> dict with keys
#   'cmap'       : matplotlib colormap name
#   'vmin'/'vmax': colour-scale limits passed to the map plot
#   'cbar_ticks' : colorbar tick positions; every tick must lie within [vmin, vmax],
#                  otherwise matplotlib silently drops it
# plus const_dict[variable]['obs_path'] -> path of the observational zarr store.
#
# 'data' limits apply to the plotted model field (tas is converted from K to degC before
# plotting); 'error' limits apply to model-minus-observation differences, so they are
# symmetric around zero and use a diverging colormap.

const_dict = {
    'tas': {
        "monthly":{
            'data':{
                'cmap':'viridis',
                'vmin':-40,
                'vmax':40,
                'cbar_ticks':[-40,-20,0,20,40],
            },
            'error':{
                'cmap':'RdBu_r',
                'vmin':-20,
                'vmax':20,
                'cbar_ticks':[-20,-10,0,10,20],
            },
        },
        "annual":{
            'data':{
                'cmap':'viridis',
                'vmin':-40,
                'vmax':40,
                'cbar_ticks':[-40,-20,0,20,40],
            },
            'error':{
                'cmap':'RdBu_r',
                'vmin':-20,
                'vmax':20,
                'cbar_ticks':[-20,-10,0,10,20],
            },
        },
        'obs_path':'../observational_data/tas_HadCRUT5.zarr',
    },
    'pr': {
        # NOTE(review): annual pr is a sum of the monthly values (see the plotting cell),
        # hence the larger annual limits -- units presumably kg m-2 s-1, confirm.
        "monthly":{
            'data':{
                'cmap':'viridis',
                'vmin':0,
                'vmax':0.0005,
                'cbar_ticks':[0,0.00025,0.0005],
            },
            'error':{
                'cmap':'RdBu',
                'vmin':-0.00025,
                'vmax':0.00025,
                'cbar_ticks':[-0.00025,0,0.00025],
            },
        },
        "annual":{
            'data':{
                'cmap':'viridis',
                'vmin':0,
                'vmax':0.002,
                'cbar_ticks':[0,0.001,0.002],
            },
            'error':{
                'cmap':'RdBu',
                'vmin':-0.001,
                'vmax':0.001,
                # intermediate ticks are half of vmax (previously +/-0.005, which fell
                # outside the colour range and were never drawn)
                'cbar_ticks':[-0.0010,-0.0005,0,0.0005,0.0010],
            },
        },
        'obs_path':'../observational_data/pr_noaa_gpcp.zarr',
    },
    'clt': {
        "monthly":{
            'data':{
                'cmap':'viridis',
                'vmin':0,
                'vmax':100,
                'cbar_ticks':[0,50,100],
            },
            'error':{
                'cmap':'RdBu_r',
                'vmin':-50,
                'vmax':50,
                'cbar_ticks':[-50,0,50],
            },
        },
        "annual":{
            'data':{
                'cmap':'viridis',
                'vmin':0,
                'vmax':100,
                'cbar_ticks':[0,50,100],
            },
            'error':{
                'cmap':'RdBu_r',
                'vmin':-50,
                'vmax':50,
                'cbar_ticks':[-50,0,50],
            },
        },
        'obs_path':'../observational_data/clt_nasa_modis.zarr',
    },
    'od550aer': {
        "monthly":{
            'data':{
                'cmap':'viridis',
                'vmin':0,
                'vmax':1,
                'cbar_ticks':[0,0.5,1],
            },
            'error':{
                'cmap':'RdBu_r',
                'vmin':-0.5,
                'vmax':0.5,
                'cbar_ticks':[-0.5,0,0.5],
            },
        },
        "annual":{
            'data':{
                'cmap':'viridis',
                'vmin':0,
                'vmax':1,
                'cbar_ticks':[0,0.5,1],
            },
            'error':{
                'cmap':'RdBu_r',
                'vmin':-0.5,
                'vmax':0.5,
                'cbar_ticks':[-0.5,0,0.5],
            },
        },
        'obs_path':'../observational_data/od550aer_nasa_modis.zarr',
    },
    'tos': {
        "monthly":{
            'data':{
                'cmap':'viridis',
                'vmin':0,
                'vmax':30,
                'cbar_ticks':[0,10,20,30],
            },
            'error':{
                'cmap':'RdBu_r',
                'vmin':-10,
                'vmax':10,
                'cbar_ticks':[-10,0,10],
            },
        },
        "annual":{
            'data':{
                'cmap':'viridis',
                'vmin':0,
                'vmax':30,
                'cbar_ticks':[0,10,20,30],
            },
            'error':{
                'cmap':'RdBu_r',
                'vmin':-10,
                'vmax':10,
                'cbar_ticks':[-10,0,10],
            },
        },
        'obs_path':'../observational_data/tos_noaa_oisst.zarr',
    },
}

In [None]:
# Render one PNG map per time step for every (model, variable, resolution, plot type)
# combination. 'data' maps show the model field; 'error' maps show model minus the
# observations regridded onto the model grid. Styling comes from `const_dict`.

MODELS = ["CESM2-WACCM", "CanESM5", "MPI-ESM1-2-LR", "IPSL-CM6A-LR"]
VARIABLES = ['tas', 'pr', 'clt', 'od550aer', 'tos']
RESOLUTIONS = ['annual', 'monthly']
PLOT_TYPES = ['data', 'error']
START_YEAR, END_YEAR = 2005, 2024

# using "ClimateBench_app2" to not overwrite any production maps
OUTPUT_ROOT = '../../../ClimateBench_app2/data/images'


def _load_model_and_obs(model, variable):
    """Load the model dataset and the observations bilinearly regridded onto its grid.

    Returns (model_ds, obs_rg_ds): the Dataset from DataFinder.load_model_ds() and the
    regridded observation DataArray for `variable`, restricted to START_YEAR..END_YEAR.
    """
    data_finder = DataFinder(variable=variable, model=model,
                             start_year=START_YEAR, end_year=END_YEAR)
    model_ds = data_finder.load_model_ds()

    obs_ds = (xr.open_zarr(const_dict[variable]['obs_path'])
              .sel(time=slice(f'{START_YEAR}-01-01', f'{END_YEAR}-12-31')))

    regridder = xe.Regridder(obs_ds, model_ds[["lat", "lon"]], "bilinear", periodic=True)
    obs_rg_ds = regridder(obs_ds[variable], keep_attrs=True)
    return model_ds, obs_rg_ds


def _build_plot_ds(model_ds, obs_rg_ds, variable, resolution, plot_type):
    """Return the dataset to map, always indexed by a 'time' coordinate.

    For 'annual', pr is summed per calendar year and every other variable is averaged;
    the resulting 'year' dimension is renamed to 'time' (values are integer years).
    'error' is model minus observations on the observation time steps; 'data' is the
    model field, with tas converted from K to degC.
    """
    if resolution == 'annual':
        if variable == 'pr':
            model_agg = model_ds.groupby('time.year').sum()
            obs_agg = obs_rg_ds.groupby('time.year').sum()
        else:
            model_agg = model_ds.groupby('time.year').mean()
            obs_agg = obs_rg_ds.groupby('time.year').mean()
        time_dim = 'year'
    else:
        model_agg, obs_agg = model_ds, obs_rg_ds
        time_dim = 'time'

    if plot_type == 'error':
        # only compare on time steps the observations actually cover
        plot_ds = model_agg.sel({time_dim: obs_agg[time_dim].values}) - obs_agg
    elif variable == 'tas':
        plot_ds = model_agg - 273.15
    else:
        plot_ds = model_agg

    if time_dim == 'year':
        plot_ds = plot_ds.rename({'year': 'time'})
    return plot_ds


def _save_map(field, variable, resolution, plot_type, model, file_path):
    """Plot a single 2-D field on a map with coastlines and save it as a transparent PNG."""
    style = const_dict[variable][resolution][plot_type]
    if variable != "tos":
        fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()}, figsize=(10, 5))
    else:
        # ocean temperature uses a Pacific-centred Robinson projection
        fig, ax = plt.subplots(1, 1, figsize=(9, 5),
                               subplot_kw=dict(projection=ccrs.Robinson(central_longitude=180)))
    try:
        img = field.plot(
            ax=ax, transform=ccrs.PlateCarree(),
            vmin=style['vmin'], vmax=style['vmax'], cmap=style['cmap'],
            add_colorbar=False,
        )
        if variable != "tos":
            ax.coastlines()
            cbar = fig.colorbar(img, label=' ', ticks=style['cbar_ticks'])
            cbar.ax.tick_params(labelsize=15)
        else:
            fig.colorbar(img, ticks=style['cbar_ticks'])
            ax.coastlines()  # cartopy function

        # previously `axis.set_title` here raised NameError for every non-tos variable
        ax.set_title(model, fontsize=20)
        fig.savefig(file_path, transparent=True, bbox_inches='tight')
    finally:
        # always release the figure, even if plotting fails, to avoid leaking memory
        plt.close(fig)


for model in MODELS:
    for variable in VARIABLES:
        # loading + regridding depends only on (model, variable): do it once here instead
        # of once per resolution and plot type
        model_ds, obs_rg_ds = _load_model_and_obs(model, variable)

        for resolution in RESOLUTIONS:
            for plot_type in PLOT_TYPES:
                print(f"running for combination {model}, {variable}, {resolution}, {plot_type}")

                plot_ds = _build_plot_ds(model_ds, obs_rg_ds, variable, resolution, plot_type)

                path = f'{OUTPUT_ROOT}/{model}/{resolution}/{variable}/{plot_type}/'
                os.makedirs(path, exist_ok=True)

                for date in plot_ds.time.values:
                    if resolution == 'annual':
                        # annual time values are integer years: stamp each as January 1st
                        file_name = f'{str(date)[:10]}-01-01.png'
                    else:
                        file_name = f'{str(date)[:10]}.png'
                    _save_map(plot_ds[variable].sel(time=date), variable, resolution,
                              plot_type, model, path + file_name)
