In [75]:

def draw_map(lw=0.3):
    """Draw the common map background via the ``wr50a_map.m`` map object.

    Draws, in order:
    - a light-grey map boundary fill;
    - white continents at ``zorder=0``;
    - parallels from -80 to 80 and meridians from -180 to 180, every 20 degrees;
    - black coastlines.

    NOTE(review): the target axes are not passed explicitly. Callers in this
    notebook select the axes with ``plt.sca`` first, so this presumably draws
    on the current axes.

    Parameters
    ----------
    lw : float
        Line width shared by the parallels, meridians and coastlines.
    """
    basemap = wr50a_map.m
    graticule_step = 20.
    basemap.drawmapboundary(fill_color=(0.9, 0.9, 0.9))
    basemap.fillcontinents(color='white', zorder=0)
    basemap.drawparallels(np.arange(-80., 81., graticule_step), linewidth=lw)
    basemap.drawmeridians(np.arange(-180., 181., graticule_step), linewidth=lw)
    basemap.drawcoastlines(color='k', linewidth=lw)
    
def plot_anoms(monthly_means, vmin=-25, vmax=25, smin=0, smax=3, amin=-4.5, amax=4.5,
               cmap='Spectral_r', smap='YlGnBu', amap='RdBu', cbar_extend='both',
               sbar_extend='max', abar_extend='both', cbar_label='Mean', era_interim=True,
               varname=None, units='-'):
    """Plot reference climatologies and the anomalies of further datasets against them.

    ``monthly_means`` is an ordered mapping ``{label: data}``. Each value is
    an xarray object with a ``time`` dimension; the example caller passes a
    variable taken from an opened dataset.

    - The first entry is the RASM baseline and the second is ERA-Interim.
    - Every further entry ``k`` is compared against both of them.

    The figure has one column per entry of the global ``seasons`` plus an
    "Annual" column, and ``2 + 2 * (len(monthly_means) - 2)`` rows:

    * row 0: baseline means; row 1: ERA-Interim means. These use ``vmin``,
      ``vmax`` and ``cmap``.
    * for each further dataset ``k``, two rows: ``baseline - k``, then
      ``ERA-Interim - k``. These use ``amin``, ``amax`` and ``amap``.

    Three horizontal colorbars (mean, std., anomaly) are added below the grid.
    When ``varname`` is given, they are labelled with it and ``units``. No
    std. panels are drawn; the std. colorbar is kept for layout compatibility.
    ``cbar_label`` and ``era_interim`` are accepted for backward compatibility
    but are not used.

    Returns
    -------
    fig, axes
        The matplotlib figure and a 2-D array of axes (rows x columns).

    Raises
    ------
    ValueError
        If fewer than two datasets are supplied.
    """
    keys = list(monthly_means.keys())
    if len(keys) < 2:
        raise ValueError(
            "plot_anoms needs at least two datasets (baseline and ERA-Interim), got %d"
            % len(keys))

    # Layout follows the data: one column per season plus the annual column.
    # Two reference rows, then two anomaly rows per additional dataset.
    ncols = len(seasons) + 1
    nrows = 2 + 2 * (len(keys) - 2)
    width = 11
    height = 1.55 * nrows + 0.6

    with sns.axes_style("white"):
        # Discretised colormaps, norms and tick positions for the three colorbars.
        n_mean = 10
        cmap = cmap_discretize(cmap, n_colors=n_mean)
        cnorm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        cticks = np.linspace(vmin, vmax, num=n_mean + 1)

        n_std = 10
        smap = cmap_discretize(smap, n_colors=n_std)
        snorm = mpl.colors.Normalize(vmin=smin, vmax=smax)
        sticks = np.linspace(smin, smax, num=n_std + 1)

        n_anom = 9
        amap = cmap_discretize(amap, n_colors=n_anom)
        anorm = mpl.colors.Normalize(vmin=amin, vmax=amax)
        aticks = np.round(np.linspace(amin, amax, num=n_anom + 1), 1)

        # squeeze=False guarantees a 2-D ``axes`` array whatever the grid size.
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                                 figsize=(width, height),
                                 sharex=True, sharey=True, squeeze=False)
        plt.subplots_adjust(left=0.125, bottom=0.05,
                            right=0.9, top=0.9,
                            wspace=0.05, hspace=0.05)

        # Reference climatologies: RASM baseline (keys[0]) and ERA-Interim (keys[1]).
        season_means_rasm = _seasonal_mean_fields(monthly_means[keys[0]])
        annual_means_rasm = _annual_mean_field(monthly_means[keys[0]])
        season_means_era = _seasonal_mean_fields(monthly_means[keys[1]])
        annual_means_era = _annual_mean_field(monthly_means[keys[1]])

        _plot_map_row(axes, 0, season_means_rasm, annual_means_rasm, vmin, vmax, cmap)
        _plot_map_row(axes, 1, season_means_era, annual_means_era, vmin, vmax, cmap)

        # Anomalies: each reference minus the dataset, two rows per dataset.
        for j, key in enumerate(keys[2:]):
            print(key)
            # Compute this dataset's means once and reuse them for both anomaly rows.
            season_means = _seasonal_mean_fields(monthly_means[key])
            annual_means = _annual_mean_field(monthly_means[key])
            row = 2 + 2 * j
            _plot_map_row(axes, row,
                          season_means_rasm - season_means,
                          annual_means_rasm - annual_means,
                          amin, amax, amap)
            _plot_map_row(axes, row + 1,
                          season_means_era - season_means,
                          annual_means_era - annual_means,
                          amin, amax, amap)

        plt.tight_layout()

        for ax, title in zip(axes[0], list(seasons) + ['Annual']):
            ax.set_title(title)
        # One label per row, generated in the same order the rows were plotted,
        # so every anomaly row names the reference it was computed against.
        for ax, label in zip(axes[:, 0], _anomaly_row_ylabels(keys)):
            ax.set_ylabel(label)

        # Colorbars, placed in figure coordinates just below the grid.
        cbar_height = 0.02
        cbar_width = .313
        ax1 = fig.add_axes([0.01, -cbar_height, cbar_width, cbar_height])
        cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=cnorm,
                                        orientation='horizontal',
                                        extend=cbar_extend,
                                        ticks=cticks)
        ax2 = fig.add_axes([0.343, -cbar_height, cbar_width, cbar_height])
        cb2 = mpl.colorbar.ColorbarBase(ax2, cmap=smap, norm=snorm,
                                        orientation='horizontal',
                                        extend=sbar_extend,
                                        ticks=sticks)
        ax3 = fig.add_axes([0.673, -cbar_height, cbar_width, cbar_height])
        cb3 = mpl.colorbar.ColorbarBase(ax3, cmap=amap, norm=anorm,
                                        orientation='horizontal',
                                        extend=abar_extend,
                                        ticks=aticks)
        if varname is not None:
            cb1.set_label("{0} ({1}, Mean)".format(varname, units))
            cb2.set_label("{0} ({1}, Std.)".format(varname, units))
            cb3.set_label("{0} ({1}, Anomaly)".format(varname, units))

    return fig, axes


def _seasonal_mean_fields(data):
    """Resample ``data`` to 'QS-SEP' quarters, group by ``time.season`` and average over ``time``."""
    return data.resample('QS-SEP', dim='time').groupby('time.season').mean(dim='time')


def _annual_mean_field(data):
    """Resample ``data`` with the 'AS' (annual) alias and average the result over ``time``."""
    return data.resample('AS', dim='time').mean(dim='time')


def _draw_map_panel(ax, field, vmin, vmax, cmap):
    """Draw the base map on ``ax``, then pcolor ``field`` masked by ``spatial_plot_mask``."""
    plt.sca(ax)
    draw_map()
    sub_plot_pcolor(np.ma.masked_where(spatial_plot_mask, field.to_masked_array()),
                    map_obj=wr50a_map, cbar=None, vmin=vmin, vmax=vmax, cmap=cmap, ax=ax)


def _plot_map_row(axes, row, seasonal, annual, vmin, vmax, cmap):
    """Fill one figure row: a panel per entry of ``seasons``, then the annual panel."""
    for col, season in enumerate(seasons):
        _draw_map_panel(axes[row, col], seasonal.sel(season=season), vmin, vmax, cmap)
    _draw_map_panel(axes[row, len(seasons)], annual, vmin, vmax, cmap)


def _anomaly_row_ylabels(keys):
    """Return the y-axis label for each ``plot_anoms`` row, top to bottom."""
    labels = ["%s" % (keys[0],), "%s" % (keys[1],)]
    for key in keys[2:]:
        labels.append("%s \n - %s" % (keys[0], key))
        labels.append("%s \n - %s" % (keys[1], key))
    return labels

In [None]:
# Example usage, kept disabled: the whole cell body is a triple-quoted string
# literal, so running this cell evaluates the string and plots nothing.
# Remove the surrounding quotes to enable it.
# NOTE(review): `dpi` is not defined inside this snippet. Confirm an earlier
# cell defines it; otherwise `plt.savefig` raises NameError.
# NOTE(review): `ncdata`, `OrderedDict` and `os` are also expected to come
# from earlier cells.
'''var = 'runoff_tot'
monthly_means = OrderedDict()
monthly_means['15_37_b'] = ncdata['baseline_sim_trunc'][var]
monthly_means['ERA-Interim'] = ncdata['era_interim_ts3'][var]

monthly_means['200_32_b'] = ncdata['200_32_b'][var]
monthly_means['200_37_a'] = ncdata['200_37_a'][var]
monthly_means['200_37_a_cfsr'] = ncdata['200_37_a_cfsr'][var]

fig, axes = plot_anoms(monthly_means, vmin=0., vmax=3, smax=1, amin = -0.5, 
                             amax = 0.5, varname='Runoff', units='mm/day', cmap='Blues', amap='RdBu',
                             cbar_extend='max', era_interim=True)

plotname = '%s.png' % var
plot_direc = '/u/home/gergel/rasm_postprocessing/plots_20180425'
if not os.path.exists(plot_direc):
    os.makedirs(plot_direc)
savepath = os.path.join(plot_direc, plotname)
plt.savefig(savepath, format='png', dpi=dpi, bbox_inches='tight')'''