In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import argparse
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from tonic.io import read_config, read_netcdf

from general_functions import datestamp

In [None]:
# Command-line parsing is disabled while this runs as a notebook; enable these
# lines (and drop the hard-coded path below) to take the config path as an argument.
# parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
# parser.add_argument("config_file",help="input config file")
# args = parser.parse_args()
# config_file = args.config_file

# Path to the config file (TODO: read it from the command line via the parser above)
config_file = 'spatial_comparisons.cfg'
config_dict = read_config(config_file)

# The 'Global' section holds figure-wide options, including the output
# directory for figures; create that directory if it is missing
global_options = config_dict.pop('Global')
out_dir = global_options['figure_dir']
os.makedirs(out_dir, exist_ok=True)

# Sections whose name contains 'file' describe the input datasets (one or two).
# Move each of them, in config order, out of config_dict and into its own
# ordered mapping of section name -> options (path, description, figure_stub, ...)
files = [section for section in config_dict.keys() if 'file' in section]
file_options_dict = OrderedDict((section, config_dict.pop(section)) for section in files)

# Every section still left in config_dict describes a variable to plot
variables = config_dict

In [None]:
# TODO:
# - Add comparison (difference) plots
# - Convert this notebook to a Python script
#

In [None]:
def calculate_mean_monthly_spatial_field(ds, variable):
    '''
    Average one variable of a dataset into a mean field per calendar month.

    Input:  ds       - xarray Dataset with a 'time' coordinate
            variable - name of the data variable in ds to average
    Output: the groupby-mean of ds[variable] over 'time', grouped by
            calendar month (one field per month present in the data)
    '''
    field = ds[variable]
    grouped_by_month = field.groupby('time.month')
    return grouped_by_month.mean(dim='time')

In [None]:
def _month_field(monthly_arrays, month_number):
    '''
    Return the field for calendar month `month_number` (1-12) from
    `monthly_arrays`, or None when that month is absent.
    '''
    coords = getattr(monthly_arrays, 'coords', None)
    if coords is not None and 'month' in coords:
        # groupby('time.month') only creates groups for months present in the
        # data, so select by label instead of position to keep panels aligned.
        if month_number not in coords['month'].values:
            return None
        return monthly_arrays.sel(month=month_number)
    # Plain sequence of 12 fields: element 0 is January.
    return monthly_arrays[month_number - 1]


def mean_monthly_spatial_plot(monthly_groupby_arrays, variable, title, units, vmin, vmax, figure_stub,
                              extend='max', save_dir=None):
    '''
    Save a figure with 12 spatial panels, one per calendar month, of a
    mean monthly variable field.

    Parameters
    ----------
    monthly_groupby_arrays : xarray.DataArray or sequence of 2-D arrays
        Mean monthly fields, e.g. the output of
        calculate_mean_monthly_spatial_field. If it has a 'month' coordinate,
        panels are selected by month label (1-12) and missing months are left
        blank; otherwise element i is drawn as the i-th month (January = 0).
    variable : str
        Variable name, used in the colorbar label and the output filename.
    title : str
        Prefix of the figure's suptitle.
    units : str
        Units appended to the colorbar label.
    vmin, vmax : float
        Colour limits shared by all 12 panels.
    figure_stub : str
        Tag placed in the output filename.
    extend : str, optional
        Colorbar extension ('neither', 'min', 'max' or 'both'). Defaults to
        'max'; use 'both' for difference fields that can go below vmin.
    save_dir : str, optional
        Directory for the PNG. Defaults to the notebook-level `out_dir`.

    Returns
    -------
    matplotlib.figure.Figure
        The saved figure, so callers can modify or close it.

    Raises
    ------
    ValueError
        If none of the 12 months has data to plot.
    '''
    # Month name -> (row, column) of its panel in the 3 x 4 grid, in calendar order
    monthly_plot_setup = OrderedDict([('January', (0, 0)), ('February', (0, 1)),
                                      ('March', (0, 2)), ('April', (0, 3)),
                                      ('May', (1, 0)), ('June', (1, 1)),
                                      ('July', (1, 2)), ('August', (1, 3)),
                                      ('September', (2, 0)), ('October', (2, 1)),
                                      ('November', (2, 2)), ('December', (2, 3))])

    f, axarr = plt.subplots(3, 4, figsize=(20, 16), sharex=False, sharey=False)

    plot = None
    for month_number, (month, position) in enumerate(monthly_plot_setup.items(), start=1):
        # plt.sca() returns None, so take the target axes straight from the grid
        ax = axarr[position]
        field = _month_field(monthly_groupby_arrays, month_number)
        if field is None:
            ax.set_axis_off()
            ax.set_title(month + ' (no data)', size=20)
            continue
        # No per-panel colorbar: a single shared colorbar is added below
        plot = field.plot(ax=ax, vmin=vmin, vmax=vmax, add_colorbar=False)
        # Overwrite any title set by the plot call with the month name
        ax.set_title(month, size=20)

    if plot is None:
        plt.close(f)
        raise ValueError('No monthly fields to plot for ' + str(variable))

    # Leave room for the suptitle (top) and the colorbar (left)
    f.subplots_adjust(left=0.12, right=0.9, top=0.93)

    # Shared colorbar in its own axes along the left edge of the figure
    cbar_ax = f.add_axes([0.045, 0.15, 0.014, 0.8])
    cbar_ax.tick_params(labelsize=20)
    cbar = f.colorbar(plot, cax=cbar_ax, extend=extend)
    # Label reads "<variable> <units>"; the negative labelpad offsets it relative
    # to the narrow colorbar axes
    cbar.set_label(label=' '.join([variable, units]), size=30, labelpad=-100)

    f.suptitle(' '.join([title, 'mean monthly spatial fields']), size=30)

    if save_dir is None:
        save_dir = out_dir
    f.savefig(os.path.join(save_dir, '_'.join([datestamp(), figure_stub, variable + '.png'])))
    return f

In [None]:
# Loop through the files to plot (the config lists either one or two files)
for file_label, file_options in file_options_dict.items():

    # The data path, title and figure stub are shared by every variable's plot for this file
    file_path = file_options['path']
    title = file_options['description']
    figure_stub = file_options['figure_stub']

    # Open the data file as a context manager so its handle is released once
    # every variable from this file has been plotted
    with xr.open_dataset(file_path) as ds:

        # Create one plot per variable that is listed for this file
        for variable, option_dict in variables.items():

            # Treat a bare string as a single entry so `in` tests list
            # membership rather than substring containment ('file1' in 'file10')
            var_files = option_dict['files']
            if isinstance(var_files, str):
                var_files = [var_files]
            if file_label not in var_files:
                continue

            # Units and colour limits differ for every variable
            units, vmin, vmax = option_dict['units'], option_dict['vmin'], option_dict['vmax']

            # Mean monthly fields (grouped by calendar month)
            monthly = calculate_mean_monthly_spatial_field(ds, variable)

            # Make the plot!
            mean_monthly_spatial_plot(monthly, variable, title, units, vmin, vmax, figure_stub)

In [None]:
# Make difference plots when the config lists more than one netCDF file.
# Count the files themselves: the loop variable `file_options` left over from the
# previous cell always has several keys, so testing it would always pass.
if len(file_options_dict) > 1:
    # Figure stub and title shared by every comparison figure. They are kept
    # separate from the per-file stubs, so difference plots do not get the same
    # filenames as (and overwrite) a single file's own plots.
    comparison_stub = global_options['comparison_stub']
    comparison_title = global_options['comparison_title']

    # The difference is first file minus second file
    labels = list(file_options_dict.keys())
    first_label, second_label = labels[0], labels[1]
    # Record which file is subtracted from which. Built once, outside the variable
    # loop, so the title does not gain an extra line for every variable plotted.
    figure_text = ' - '.join([file_options_dict[first_label]['figure_stub'],
                              file_options_dict[second_label]['figure_stub']])
    title = comparison_title + '\n' + figure_text

    # Open each data file once; the finally clause releases the handles even if a plot fails
    datasets = OrderedDict()
    try:
        for file_label, file_options in file_options_dict.items():
            datasets[file_label] = xr.open_dataset(file_options['path'])

        # Plot every variable that is listed for all of the files
        for variable, option_dict in variables.items():
            # set() on a bare string would split it into characters; treat it as one entry
            var_files = option_dict['files']
            if isinstance(var_files, str):
                var_files = [var_files]
            if set(var_files) != set(file_options_dict.keys()):
                continue

            units = option_dict['units']
            # Optional 'diff_vmin' / 'diff_vmax' entries set the colour range for the
            # difference; otherwise the absolute-field limits are reused
            vmin = option_dict.get('diff_vmin', option_dict['vmin'])
            vmax = option_dict.get('diff_vmax', option_dict['vmax'])

            # Mean monthly fields for the two files being compared, then their difference
            monthly = {label: calculate_mean_monthly_spatial_field(datasets[label], variable)
                       for label in (first_label, second_label)}
            difference = monthly[first_label] - monthly[second_label]

            # Make the plot!
            mean_monthly_spatial_plot(difference, variable, title, units, vmin, vmax, comparison_stub)
    finally:
        for dataset in datasets.values():
            dataset.close()