In [None]:
%load_ext autoreload
%autoreload 2

#--------------------IMPORT NECESSARY PACKAGES-------------------#

# OS interactions and calls:
import os
import subprocess as sbp
import sys

# project specific configurations:
from cf_monthly_forecast.config import *
import cf_monthly_forecast.plot_annotations as pla
import cf_monthly_forecast.plot_options_monthly as pom

# data access
import xarray as xr

# data processing
import numpy as np
from scipy import interpolate
from datetime import datetime, timezone
from calendar import monthrange

# plotting
import matplotlib.pyplot as plt
from cf_monthly_forecast.vis_utils import TWOCOLUMN_WIDTH_INCHES,SubplotFigure
from mpl_toolkits.basemap import Basemap

In [None]:
#--------------------SOME INPUT--------------------#

# key used to look up plotting parameters (contour levels, colormaps) in
# plot_options_monthly:
model = 'ens_mean_anom'

# base font size for titles, colorbar labels and annotation text:
FS = 9.

# figure width in inches: 80 % of a two-column page width
figw_inches = .8*TWOCOLUMN_WIDTH_INCHES

# refinement factor of the linear interpolation applied before contouring;
# a falsy value (0/None) plots the native grid instead
nsplines = 4

In [None]:
#--------------------CHECK FOR EXISTENCE OF FORECAST FILE--------------------#

# Date that determines the initialization to plot. It is pinned here to
# reproduce the plots of an earlier initialization; set
# FORECAST_DATE_OVERRIDE = None to use the date at the time of running.
FORECAST_DATE_OVERRIDE = datetime(2022,6,15)
today = FORECAST_DATE_OVERRIDE if FORECAST_DATE_OVERRIDE is not None else datetime.today()
initmonth = today.month
inityear = today.year

# keep the variables whose single-parameter production file exists:
plt_vars = []
for pvar in pom.variables:
    fname = '{0:s}/forecast_production_{3:s}_{1:d}_{2:d}.nc'.format(
        dirs['SFE_forecast'],
        inityear,
        initmonth,
        file_key[pvar]
    )
    if os.path.isfile(fname):
        plt_vars.append(pvar)

# The log lines are labelled "(UTC)", so take the timestamp from the UTC clock
# (a naive datetime.now() is local time). tzinfo is dropped so the printed
# timestamp keeps its previous format.
now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
if len(plt_vars) == 0:
    print('{0:} (UTC)\tNo monthly forecast files exist yet for initialization {1:d}-{2:d}!\n'.format(now_utc,inityear,initmonth))
else:
    print('{0:} (UTC)\tMonthly forecast files exist for initialization {1:d}-{2:d}, creating anomaly plots for {3:}.\n'.format(now_utc,inityear,initmonth,plt_vars))

missing_vars = [pv for pv in pom.variables if pv not in plt_vars]

In [None]:
#--------------------MAKE A DIRECTORY FOR THE FIGURES--------------------#

# figures of this initialization go to <public>/monthly_fc/init_YYYY-MM/anomalies/
figdir = '{0:s}/monthly_fc/init_{1:s}-{2:s}/anomalies/'.format(
    dirs['public'],str(inityear).zfill(4),str(initmonth).zfill(2)
)

# Create the folder and any missing parents. exist_ok=True makes the cell safe
# to re-run and avoids the race between a separate existence check and the
# creation.
os.makedirs(figdir,exist_ok=True)

In [None]:
#--------------------HELPERS FOR THE ANOMALY MAPS--------------------#
# Maps are produced for every variable x forecast month x domain x language.
# Everything that does not depend on the language (data extraction,
# interpolation, clipping) is done once per month and domain; the language
# loop only draws the prepared field with translated labels.

if model != 'ens_mean_anom':
    # the data extraction below only knows the ensemble-mean anomaly fields
    raise ValueError("Unsupported model {0!r}; this cell only plots 'ens_mean_anom'".format(model))

# title templates: {0} month name, {1} year, {2} long variable name
_TITLE_FORMATS = {
    'abs': {
        'en': '{2:s} Ensemble Mean Anomaly {0:s} {1:d}',
        'no': '{2:s} gjennomsnittige anomali {0:s} {1:d}',
    },
    'rel': {
        'en': '{2:s} Ensemble Mean Percent Deviation from Climatology {0:s} {1:d}',
        'no': '{2:s} gjennomsnittige anomali {0:s} {1:d}',
    },
}


def _domain_setup(area):
    """Map projection and bounding box of one plotting domain.

    Returns (lon0, lon1, lat0, lat1, bm, aspectratio): the lon/lat box selects
    the grid points of the printed area average, `bm` is the Basemap used for
    projecting and drawing, and `aspectratio` is passed to SubplotFigure.

    Raises ValueError for a domain not handled here, instead of silently
    reusing the map of a previous loop iteration.
    """
    if area == 'EUROPE_SMALL':
        bm = Basemap(
            resolution = 'i',
            projection = 'gall',
            llcrnrlon = 0,
            llcrnrlat = 54,
            urcrnrlon = 30,
            urcrnrlat = 71.5,
            fix_aspect = False
        )
        return 0, 30, 54, 71.5, bm, 1.1
    if area == 'EUROPE':
        bm = Basemap(
            resolution = 'l',
            projection = 'gall',
            llcrnrlon = -20.,
            llcrnrlat = 35.,
            urcrnrlon = 50.,
            urcrnrlat = 72.,
            fix_aspect = False
        )
        return -20, 50, 35, 72, bm, 1.5
    if area == 'GLOBAL':
        bm = Basemap(projection='moll',lon_0=0,resolution='c')
        return -180, 180, -90, 90, bm, 2.05
    raise ValueError('Unknown plotting domain: {0:}'.format(area))


def _cos_lat_weights(LON, LAT, lon0, lon1, lat0, lat1):
    """Flat indices of the grid points inside the lon/lat box and their cos(lat) weights (normalized to sum to 1)."""
    lon_flat = LON.ravel()
    lat_flat = LAT.ravel()
    points = np.nonzero((lon_flat>=lon0)&(lon_flat<=lon1)&(lat_flat>=lat0)&(lat_flat<=lat1))[0]
    weights = np.cos(np.radians(lat_flat[points]))
    weights /= np.sum(weights)
    return points, weights


def _anomaly_levels(variable, area):
    """Contour levels from pom.cvs: the most specific of [model][variable][area], [model][variable], [model]."""
    # KeyError: missing dict key; IndexError/TypeError: the entry is already a
    # level array/list and cannot be indexed by a name
    lookup_errors = (KeyError, IndexError, TypeError)
    try:
        return pom.cvs[model][variable][area]
    except lookup_errors:
        pass
    try:
        return pom.cvs[model][variable]
    except lookup_errors:
        return pom.cvs[model]


def _colormap(variable, ncolors):
    """Discrete colormap with `ncolors` colours: pom.cmapnames[model][variable], falling back to pom.cmapnames[model]."""
    try:
        cmapname = pom.cmapnames[model][variable]
    except (KeyError, IndexError, TypeError):
        cmapname = pom.cmapnames[model]
    return plt.get_cmap(cmapname,ncolors)


def _to_plot_grid(field, LON, LAT, lon2, lat2):
    """Linearly interpolate `field` from the LON/LAT grid onto the lat2 x lon2 grid.

    Uses a degree-1 (bilinear) RectBivariateSpline, the replacement for
    scipy.interpolate.interp2d(kind='linear') on rectangular grids (interp2d
    was removed in SciPy 1.14). The spline needs increasing coordinates, so the
    source axes are sorted before fitting and the evaluation points are sorted
    and scattered back. The result therefore follows the order of lat2/lon2,
    which is the order of the projected plotting coordinates. (interp2d
    returned an ascending-latitude result regardless, flipping maps on
    north-to-south grids.)

    Returns a new (len(lat2), len(lon2)) array; `field` is not modified.
    """
    lon_src = LON[0,:]
    lat_src = LAT[:,0]
    ilon = np.argsort(lon_src)
    ilat = np.argsort(lat_src)
    spline = interpolate.RectBivariateSpline(
        lat_src[ilat], lon_src[ilon], np.asarray(field)[np.ix_(ilat,ilon)], kx=1, ky=1, s=0
    )
    jlon = np.argsort(lon2)
    jlat = np.argsort(lat2)
    out = np.empty((len(lat2),len(lon2)))
    out[np.ix_(jlat,jlon)] = spline(lat2[jlat], lon2[jlon])
    return out


def _clip_into_levels(field, cv):
    """Copy of `field` with values at/beyond the outermost levels nudged just inside them (by 1/1000 of the span)."""
    spr = cv[-1]-cv[0]
    clipped = np.array(field, copy=True)
    clipped[clipped<=cv[0]] = cv[0]+spr/1000.
    clipped[clipped>=cv[-1]] = cv[-1]-spr/1000.
    return clipped


def _draw_anomaly_map(bm, xi, yi, field, cv, cmap, fmt, area, aspectratio, lang, title, desc):
    """Draw one filled-contour map of `field` and return the SubplotFigure.

    `field` must already be on the xi/yi grid. Non-GLOBAL maps get the
    forecast attribution text in the top-left corner. The colorbar uses `cv`
    both as levels and as ticks.
    """
    fig = SubplotFigure(
        figw_inches = figw_inches,
        aspectratio = aspectratio,
        marginleft_inches = 0.05,
        marginright_inches = 0.05,
        margintop_inches = 0.45,
        marginbottom_inches = 0.05,
        cbar_height_inches = .15,
        cbar_bottompadding_inches = .25,
        cbar_toppadding_inches = .05,
        cbar_width_percent = 95.
    )
    ax = fig.subplot(0)
    hatches = [None]*(len(cv)-1)
    cf = ax.contourf(xi,yi,field,cv,cmap=cmap,vmin=cv[0],vmax=cv[-1],hatches=hatches,extend='both')
    bm.drawcoastlines(linewidth=.5)
    bm.drawcountries(linewidth=.35,color='.5')
    if area not in ('GLOBAL',):
        tkw = {
            'horizontalalignment':'left',
            'verticalalignment':'top',
            'transform':ax.transAxes
        }
        t = '%s'%{'no':'Varsel fra','en':'Forecast from'}[lang]
        t += ' Climate Futures'
        plt.text(0.01,.99,t,fontweight='bold',fontsize=FS-2,**tkw)
        t = {'no':'Finansiert av Forskningsrådet','en':'Funded by the Research Council of Norway'}[lang]
        t += '\n%s:'%{'no':'Basert på data fra','en':'Based on data from'}[lang]
        t += '\nECMWF (%s)'%{'no':'Europa','en':'Europe'}[lang]
        t += '\nUK Met Office (%s)'%{'no':'Storbritannia','en':'UK'}[lang]
        t += '\nCMCC (%s)'%{'no':'Italia','en':'Italy'}[lang]
        t += '\nMétéo France (%s)'%{'no':'Frankrike','en':'France'}[lang]
        t += '\nDWD (%s)'%{'no':'Tyskland','en':'Germany'}[lang]
        t += '\nBjerknes Centre (%s)'%{'no':'Norge','en':'Norway'}[lang]
        t += '\n{0:s} {1:d} {2:s} {3:d}'.format({'no':'Utarbeidet','en':'Produced'}[lang],today.day,pla.monthnames[lang][today.month-1],today.year)
        plt.text(0.01,.94,t,fontsize=FS-4,**tkw)
    plt.title(title,fontsize = FS-1)
    fig.draw_colorbar(
        mappable = cf,
        fontsize=FS-1,
        cmap = cmap,
        vmin = cv[0],
        vmax = cv[-1],
        desc = desc,
        ticks = cv,
        rightmargin = 0.05 if area in ('GLOBAL',) else 0,
        fmt = fmt,
        extend='both'
    )
    return fig


#--------------------PLOT ANOMALY MAPS--------------------#

# Building a Basemap (in particular with resolution 'i') is slow, so each
# domain's map is built once and reused for all variables, months and languages.
domain_setup = {area: _domain_setup(area) for area in pom.DOMAINS}

for variable in plt_vars:

    #--------------------LOAD FORECAST DATA FOR REQUESTED VARIABLE--------------------#
    varf_name = file_key[variable]
    FILE = '{0:s}/forecast_production_{3:s}_{1:d}_{2:d}.nc'.format(dirs['SFE_forecast'],inityear,initmonth,varf_name)

    ds = xr.open_dataset(FILE)

    # get grid info as arrays for plotting and interpolation:
    LON,LAT = np.meshgrid(ds.lon,ds.lat)

    # plotting grid: nsplines times the native resolution, or the native grid
    if nsplines:
        lon2 = np.linspace(LON[0,0],LON[0,-1],LON.shape[1]*nsplines)
        lat2 = np.linspace(LAT[0,0],LAT[-1,0],LAT.shape[0]*nsplines)
        lon_plot,lat_plot = np.meshgrid(lon2,lat2)
    else:
        lon_plot,lat_plot = LON,LAT

    # per-domain averaging weights and projected plotting coordinates for this grid:
    domain_grid = {}
    for area,(lon0,lon1,lat0,lat1,bm,aspectratio) in domain_setup.items():
        points,weights = _cos_lat_weights(LON,LAT,lon0,lon1,lat0,lat1)
        xi,yi = bm(lon_plot,lat_plot)
        domain_grid[area] = (points,weights,xi,yi)

    #--------------------FORECAST MONTHS TO LOOP OVER--------------------#
    # index 0 is forecast month 1 (e.g. May init, index 0 has June monthly mean);
    # target months before the init month in the calendar fall in the next year
    FCMONTHS = np.array(ds.variables['target_month'][:],dtype=int)
    FCYEARS = [inityear if mm >= initmonth else inityear+1 for mm in FCMONTHS]

    # choose a subset of forecast months to plot:
    subset = slice(0,None)
    FCMONTH = FCMONTHS[subset]
    FCYEAR = FCYEARS[subset]

    # absolute anomalies for every variable; precipitation additionally gets the
    # anomaly relative to the ERA climatology in percent
    anomaly_kinds = ['abs','rel'] if variable == 'pr' else ['abs']

    for kind in anomaly_kinds:
        if kind == 'rel':
            print('relative anomalies')

        for fcmonth,fcyear in zip(FCMONTH,FCYEAR):
            for area in pom.DOMAINS:
                bm,aspectratio = domain_setup[area][4:]
                points,weights,xi,yi = domain_grid[area]

                if kind == 'abs':
                    anom = (ds.mean_standardized_anomaly * ds.sd_era).sel(target_month=fcmonth).values * units_tf_factor[variable]
                    if variable == 'pr':
                        # scale by the number of days in the target month
                        # (presumably daily rate -> monthly total; confirm units)
                        anom = anom * monthrange(fcyear,fcmonth)[1]
                    cv = _anomaly_levels(variable,area)
                else:
                    anom = (ds.mean_standardized_anomaly * ds.sd_era / ds.climatology_era).sel(target_month=fcmonth).values * 100
                    # explicit levels for every domain (EUROPE_SMALL used to inherit stale ones)
                    cv = np.linspace(-100,100,11) if area == 'GLOBAL' else np.linspace(-50,50,11)
                fmt = pom.FMT[variable]
                cmap = _colormap(variable,len(cv)-1)

                # cos(lat)-weighted area average on the native grid:
                avg = np.sum(anom.ravel()[points]*weights)
                print(model,avg)
                print(variable,np.min(anom),np.max(anom))

                # Prepare the plotted field once for all languages (re-interpolating
                # the already interpolated field for the second language used to crash).
                anom_plot = anom
                if nsplines:
                    print('linearly interpolating data to {0:d}x the resolution'.format(nsplines))
                    if kind == 'rel':
                        # a zero climatology gives +-inf, which the spline fit cannot handle
                        if np.isinf(anom_plot).sum() > 0:
                            print('Some points are inf!')
                        anom_plot = np.where(np.isinf(anom_plot),np.sign(anom_plot)*999,anom_plot)
                    anom_plot = _to_plot_grid(anom_plot,LON,LAT,lon2,lat2)
                # nudge out-of-range values just inside the outermost levels
                anom_plot = _clip_into_levels(anom_plot,cv)

                for lang in pom.langs:
                    mstr = pla.monthnames[lang][fcmonth-1]
                    title = _TITLE_FORMATS[kind][lang].format(mstr,fcyear,long_names[variable][lang])
                    if kind == 'abs':
                        desc = {'en':'Anomaly ({0:s})','no':'Anomali ({0:s})'}[lang].format(units_plot[variable])
                    else:
                        desc = {'en':'Anomaly (%)','no':'Anomali (%)'}[lang]

                    fig = _draw_anomaly_map(bm,xi,yi,anom_plot,cv,cmap,fmt,area,aspectratio,lang,title,desc)

                    filename = 'fc_{0:s}_{1:s}_{2:s}{3:s}_{4:s}_{5:s}'.format(
                        variable,
                        str(fcmonth).zfill(2),
                        '' if kind == 'abs' else 'rel_',
                        model,
                        area,
                        lang
                    )

                    print(filename)

                    # Uncomment to write the PNGs; closing each figure also keeps
                    # memory bounded when many maps are produced.
                    # fig.fig.savefig('{0:s}{1:s}.png'.format(figdir,filename),dpi=300)
                    # plt.close(fig.fig)

    # all fields of this file have been extracted; release the file handle
    ds.close()