In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import sys
if '../' not in sys.path:
    sys.path.append('../')

import utils.process as process
import utils.params as params

In [None]:
# Load the pre-computed integrated-gradient attributions for every site.
# Each file is expected at <basefolder_data>/<site>/integrated_gradients.p.
basefolder_data = '../datasets'
sites = ['temperate', 'boreal', 'tropical']
# Human-readable site names used as row labels in the figure below.
plotnames = dict(zip(sites, ['Temperate', 'Boreal', 'Tropical']))

# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# these files if they were produced by a trusted source (e.g. this project's pipeline).
integrated_gradients = {}
for site in sites:
    # `with` guarantees the file handle is closed even if unpickling fails.
    with open(f"{basefolder_data}/{site}/integrated_gradients.p", "rb") as fh:
        integrated_gradients[site] = pickle.load(fh)

In [None]:
# Sum the attributions of each site into monthly totals, then rearrange every
# site's array so that axis 1 indexes the 17 months and axis 2 the 3 drivers.
# NOTE(review): assumes divide_into_months returns one row per sample whose
# remaining entries are ordered driver-major (3 drivers x 17 months) — confirm
# against utils.process.divide_into_months.
grads_per_month = {}
for site in sites:
    monthly_sums = process.divide_into_months(integrated_gradients[site], mode='sum')
    n_samples = monthly_sums.shape[0]
    driver_by_month = monthly_sums.reshape(n_samples, 3, 17)
    grads_per_month[site] = driver_by_month.transpose(0, 2, 1)

In [None]:
# Figure: for every site (rows) and meteorological driver (columns), draw the
# mean absolute integrated gradient per month. Values are normalised per site so
# that all bars of one site (all months, all drivers) sum to 1.
# Expects grads_per_month[site] to be indexed as (sample, month, driver).
fontsize = 26
fontsize_ticks = 18

# Column order must match the last axis of grads_per_month.
# (Named `driver_names`, not `vars`, to avoid shadowing the builtin `vars`.)
driver_names = ['Precipitation', 'Temperature', 'Radiation']
months_labels = ['Aug', 'Sep', 'Oct', 'Nov', 'Dec',
                 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

n_sites, n_drivers = len(sites), len(driver_names)
n_months = len(months_labels)
x = np.arange(n_months)

# Mean |gradient| over samples -> one (month, driver) matrix per site; its total
# is the per-site normalisation constant.
mean_abs_grads = {site: np.abs(grads_per_month[site]).mean(axis=0) for site in sites}
for site in sites:
    assert mean_abs_grads[site].shape == (n_months, n_drivers), (
        f"{site}: expected (month, driver) shape {(n_months, n_drivers)}, "
        f"got {mean_abs_grads[site].shape}")
norm_factors = {site: np.sum(mean_abs_grads[site]) for site in sites}

# squeeze=False keeps axs 2-D even with a single site or driver.
fig, axs = plt.subplots(n_sites, n_drivers, figsize=(15, 10), sharey='row', squeeze=False)
fig.suptitle('Importances of Meteorological Drivers for yearly GPP Production', fontsize=fontsize, y=0.98)

for i, site in enumerate(sites):
    importances_site = mean_abs_grads[site] / norm_factors[site]
    for j, driver in enumerate(driver_names):
        ax = axs[i, j]
        ax.bar(x, importances_site[:, j], width=1.0, edgecolor='white', linewidth=0.7, color='blue')

        if i == 0:
            ax.set_title(driver, fontsize=fontsize)
        if i == n_sites - 1:
            ax.set_xlabel('month', fontsize=fontsize)
        if j == 0:
            ax.set_ylabel(f'{plotnames[site]} Site', fontsize=fontsize)
        # Label every other month to keep the rotated ticks readable.
        ax.set_xticks(x[::2])
        ax.set_xticklabels(months_labels[::2], rotation=90, fontsize=fontsize_ticks)
        # Only relative heights matter (rows share a y-axis), so hide y ticks.
        ax.set_yticks([])

    # Sanity check: by construction this prints 1.00 for every site.
    print(f"Sum of importances for {site}: {importances_site.sum():.2f}")

# wspace=0.0 butts the driver columns together within a row; hspace leaves room
# for the rotated month tick labels drawn under every row.
fig.subplots_adjust(hspace=0.4, wspace=0.0)
fig.savefig('./feature_importance.pdf', dpi=300)
