# Notebook Setup

In [None]:
# Treat the session as Colab when 'google.colab' appears in the string form
# of the active IPython shell object.
IN_COLLAB = 'google.colab' in str(get_ipython())

if not IN_COLLAB:
    # Non-Colab checkout location of the repository.
    MY_HOME_ABS_PATH = "/root/co2-flux-hourly-gpp-modeling"
else:
    #TODO: CHANGE THIS BASED ON YOUR OWN LOCAL SETTINGS
    MY_HOME_ABS_PATH = "/content/drive/MyDrive/W210/co2-flux-hourly-gpp-modeling"
    # The project path sits under /content/drive, so Drive is mounted there.
    from google.colab import drive
    drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


## Import Modules

In [None]:
required_packages = ['azure-storage-blob', 'kaleido', 'nbformat']  # Might need to restart kernel after installing nbformat

for p in required_packages: 
    try:
         __import__(p)
    except ImportError:
          %pip install {p} --quiet

In [None]:
# Switch to the project root first; the relative paths used further down
# ('./.cred', './code/src/tools') are written relative to it.
import os
os.chdir(MY_HOME_ABS_PATH)

import sys
import warnings
# Suppress all warnings for the rest of the session.
warnings.filterwarnings("ignore")
import copy
import json
from pathlib import Path
import numpy as np
import pandas as pd

# required plotly libs
# NOTE(review): kaleido is imported but never called in the visible cells;
# presumably it is needed for plotly static image export - confirm.
import kaleido
import matplotlib.pyplot as plt
import plotly.express as px
from plotly.express.colors import sample_colorscale
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio

# from pytorch_forecasting import TemporalFusionTransformer

from datetime import datetime
import gc
import pickle

# Load local custom modules
# The working directory is reset to the project root before sys.path is
# extended, because some of the entries added below are relative paths.
os.chdir(MY_HOME_ABS_PATH)
if IN_COLLAB:
     # On Colab the tools directory is placed at the front of sys.path.
     sys.path.insert(0,os.path.abspath("./code/src/tools"))
else:
    # Elsewhere the credentials and tools directories are appended; the tools
    # directory is added both as a relative and as an absolute path.
    sys.path.append('./.cred')
    sys.path.append('./code/src/tools')
    sys.path.append(os.path.abspath("./code/src/tools"))

# Presumably resolved through the tools directory added to sys.path above - verify.
from CloudIO.AzStorageClient import AzStorageClient

# Show every column when displaying a frame, and format floats with five decimals.
pd.set_option('display.max_columns', None)
pd.set_option('display.float_format', lambda x: '%.5f' % x)

## Define Local File System Constants

In [None]:
# Directory layout, every entry anchored at the project root.
root_dir = MY_HOME_ABS_PATH
tmp_dir = os.sep.join([root_dir, '.tmp'])
raw_data_dir = tmp_dir  # raw data is read from the same directory as tmp_dir
data_dir = os.sep.join([root_dir, 'data'])
img_dir = os.sep.join([data_dir, 'figures'])
cred_dir = os.sep.join([root_dir, '.cred'])
az_cred_file = os.sep.join([cred_dir, 'azblobcred.json'])
model_objects_dir = os.sep.join([root_dir, 'code/src/modeling/model_objects'])

# Plot

In [None]:
# Encoder window length in time steps (24 x 14 = 336).
ENCODER_LEN = 24 * 14
# CSV with the per-encoder-step values plotted below, read from tmp_dir.
filname = "encoder_fi_df_AU-DaP_GPP_TFT_EN14_month1.csv"
plot_df = pd.read_csv(os.sep.join([tmp_dir, filname]))
print("Encoder Length({})".format(ENCODER_LEN))


Encoder Length(336)


In [None]:
# Preview the first five rows to check the loaded columns.
plot_df.head(n=5)

Unnamed: 0,encoder_index,month,day,hour,gap_flag_month,gap_flag_hour,timestep_idx_global,TA_ERA,SW_IN_ERA,LW_IN_ERA,VPD_ERA,P_ERA,PA_ERA,EVI,NDVI,NIRv,b1,b2,b3,b4,b5,b6,b7,BESS-PAR,BESS-PARdiff,BESS-RSDN,CSIF-SIFdaily,PET,Ts,ESACCI-sm,NDWI,Percent_Snow,Fpar,Lai,LST_Day,LST_Night,relative_time_idx,GPP_NT_VUT_REF,encoder_attention
0,-336,0.00676,0.0029,0.0205,0.0215,0.05365,0.02157,0.03694,0.02914,0.01535,0.00568,0.01712,0.01996,0.00894,0.00333,0.01465,0.00841,0.00198,0.02084,0.36876,0.01639,0.01659,0.02135,0.01901,0.0192,0.01579,0.01178,0.01823,0.00149,0.00566,0.03099,0.00531,0.0121,0.02289,0.00735,0.01603,0.00489,0.07696,0.0008
1,-335,0.00765,0.00281,0.02278,0.02134,0.06404,0.02432,0.04618,0.02807,0.01659,0.00693,0.0184,0.02389,0.01059,0.00408,0.01567,0.01039,0.00203,0.02164,0.30291,0.01726,0.01799,0.02347,0.02099,0.02101,0.01699,0.01239,0.01987,0.0015,0.00551,0.03482,0.00447,0.01355,0.02673,0.00688,0.01736,0.00522,0.08367,0.00071
2,-334,0.00596,0.00268,0.01795,0.0178,0.05067,0.01987,0.03435,0.02419,0.01428,0.00478,0.01577,0.0187,0.00773,0.00329,0.01423,0.00721,0.00192,0.01917,0.4119,0.01545,0.01534,0.01944,0.01774,0.01759,0.01468,0.01118,0.01674,0.00154,0.00532,0.02733,0.00748,0.01123,0.02137,0.00763,0.01482,0.00461,0.07809,0.00162
3,-333,0.01411,0.00436,0.02114,0.01838,0.10716,0.03572,0.06819,0.03335,0.01787,0.01212,0.03288,0.0271,0.01729,0.00478,0.02372,0.01845,0.00274,0.01908,0.12195,0.01794,0.02029,0.02928,0.0255,0.0265,0.01798,0.01316,0.0244,0.00149,0.00211,0.03473,0.00269,0.02169,0.03195,0.00856,0.01953,0.00756,0.09427,8e-05
4,-332,0.00872,0.0028,0.02373,0.01841,0.07438,0.02658,0.04988,0.02897,0.01718,0.00768,0.01967,0.02598,0.01089,0.00462,0.01836,0.01125,0.00211,0.02203,0.25713,0.01801,0.01879,0.02545,0.02292,0.02246,0.0176,0.01276,0.02112,0.00154,0.00426,0.0355,0.00446,0.01505,0.02711,0.0073,0.01811,0.00538,0.0918,0.00049


In [32]:
# Columns of plot_df to draw, in the order their traces are added.
features = [
    'GPP_NT_VUT_REF',
    'VPD_ERA',
    'P_ERA',
    'TA_ERA',
    'ESACCI-sm',
    'BESS-PAR',
    'b4',
]  #TODO: Put in preferred order

# One x-axis tick every 24 steps across the encoder window.
xticks = list(range(-ENCODER_LEN, 0, 24))

# Legend label for each plotted column.
readable_feature_names = dict([
    ('GPP_NT_VUT_REF', 'GPP'),
    ('BESS-PAR', 'Photosynthetic Active Radiation'),
    ('ESACCI-sm', 'Soil Moisture'),
    ('b4', 'MODIS Band 4'),
    ('VPD_ERA', 'Vapor Pressure Deficit'),
    ('P_ERA', 'Precipitation '),
    ('TA_ERA', 'Air Temperature'),
])

# One fill colour per feature, sampled evenly from part of the 'tempo' scale.
# Other scales: https://plotly.com/python/builtin-colorscales/#builtin-sequential-color-scales
colors = sample_colorscale(
    'tempo',
    np.linspace(start=0.2, stop=0.85, num=len(features)),
)

# Feature importance time series: one stacked area per feature on the
# primary y axis, labelled with its readable name from the start.
fig = make_subplots(specs=[[dict(secondary_y=True)]])
for fill_color, feature in zip(colors, features):
    label = readable_feature_names[feature]
    fig.add_trace(
        go.Scatter(
            x=plot_df['encoder_index'],
            y=plot_df[feature],
            name=label,
            legendgroup=label,
            mode='lines',
            line_color='#AAA',
            line_width=1,
            fillcolor=fill_color,
            stackgroup='one',  # all feature traces share one stack
            hovertemplate='%{y:.4f}',
        ),
        secondary_y=False,
    )

# Attention curve on the secondary y axis, kept out of the legend.
fig.add_trace(
    go.Scatter(
        x=plot_df['encoder_index'],
        y=plot_df['encoder_attention'],
        name='Average Attention',
        mode='lines',
        line_color='white',
        line_width=3,
        showlegend=False,
        hovertemplate='%{y:.4e}',
    ),
    secondary_y=True,
)

# Figure-level formatting: title, margins, size, legend, hover mode and theme.

# Available plotly template/theme: https://plotly.com/python/templates/
# ['ggplot2', 'seaborn', 'simple_white', 'plotly', 'plotly_white', 'plotly_dark', 'presentation', 'xgridoff', 'ygridoff', 'gridon', 'none']
fig.update_layout(
    title={'text': "Feature Importance by Time Steps Before Prediction", 'y': 0.965, 'x': 0.5},
    margin={"r": 10, "t": 75, "l": 60, "b": 50},
    height=500, width=1000,
    legend={
        'title': {'text': "Features"}, 'orientation': "h",
        'y': 0.95, 'itemwidth': 50,
        #'x': 0.1,  'itemwidth':30,
    },
    hovermode="x unified",  # or just "x"
    template='plotly_white',
)
fig.update_xaxes(
    title={'text': "Encoder Index: # of Time Step (Hour) Before Prediction",
           'font_size': 14, 'standoff': 0},   #autorange='reversed',
    tickvals=xticks,
    rangeslider_visible=True,  # show time sliders
)
# Primary axis carries the stacked importances, secondary axis the attention line.
fig.update_yaxes(title={'text': "Importance", 'font_size': 14, 'standoff': 0}, secondary_y=False)
# Axis title spelling corrected (was "Avergae Attention").
fig.update_yaxes(title={'text': "Average Attention", 'font_size': 14, 'standoff': 0},
                 showgrid=False, secondary_y=True)


# y-scaling button (removable)
# Two buttons restyle the stacked feature traces between percent-normalised
# and absolute values; the attention trace is not in the target list.
showToggle = True
# Indices of the feature traces, which were added to the figure first.
target_trace_ids = list(range(len(features)))
if showToggle:
    fig.update_layout(
        updatemenus=[
            dict(
                type="buttons",
                direction="left",
                buttons=[
                    dict(
                        args=[{"groupnorm": "percent", 'hovertemplate': '%{y:.2f}%'}, target_trace_ids],
                        label="Relative(%)",  # label spelling corrected (was "Relvative(%)")
                        method="restyle",
                    ),
                    dict(
                        args=[{"groupnorm": "", 'hovertemplate': '%{y:.4f}'}, target_trace_ids],
                        label="Absolute",
                        method="restyle",
                    ),
                ],
                pad={"r": 0, "t": 0},
                active=1,  # index 1 = "Absolute", matching how the traces were created
                x=-0.05, xanchor="left",
                y=1.3,
                font_size=10,
            ),
        ]
    )

fig.show()

In [33]:
# for website
# Re-style the existing figure for embedding: narrower canvas, dark theme,
# legend moved above the plot area.
MAX_WIDTH = 740

fig.update_layout(
    title={'text': "Feature Importance by Time Steps Before Prediction", 'y': 0.94, 'x': 0.5},
    margin={"r": 10, "t": 150, "l": 10, "b": 50},
    height=int(MAX_WIDTH * 0.75), width=MAX_WIDTH,
    legend={'tracegroupgap': 0.5, 'traceorder': "reversed", 'y': 1.025},
    # 'title':{'text' :"Features"}, 'orientation':"h", 'x': 0.95, 'xanchor': "left", 'itemwidth':50,
    hovermode="x unified",  # or just "x"
    hoverlabel={'bgcolor': "#333"},
    template='plotly_dark',
)
fig.update_yaxes(title={'text': "Importance", 'font_size': 14, 'standoff': 0}, secondary_y=False)
# Axis title spelling corrected (was "Avergae Attention").
fig.update_yaxes(title={'text': "Average Attention", 'font_size': 14, 'standoff': 0},
                 showgrid=False, secondary_y=True)

# y-scaling button (removable)
# Compact "%" / "#" variant of the toggle; it restyles the traces listed in
# target_trace_ids.
showToggle = True
if showToggle:
    fig.update_layout(
        updatemenus=[
            {
                "type": "buttons",
                "direction": "left",
                "buttons": [
                    {
                        "args": [{"groupnorm": "percent", 'hovertemplate': '%{y:.2f}%'}, target_trace_ids],
                        "label": "%",
                        "method": "restyle",
                    },
                    {
                        "args": [{"groupnorm": "", 'hovertemplate': '%{y:.4f}'}, target_trace_ids],
                        "label": "#",
                        "method": "restyle",
                    },
                ],
                "pad": {"r": 0, "t": 0, "l": 0, "b": 0},
                "active": 1,
                "x": -0.05,
                "xanchor": "left",
                "y": 1.5,
                "font_color": 'teal',
                "bgcolor": "rgba(0,0,0,0)",
            },
        ]
    )

# For dark background: make the paper area fully transparent.
fig.update_layout(paper_bgcolor="rgba(0,0,0,0)")

fig.show()

# Export to HTML
from plotly.offline import get_plotlyjs_version

file_name = img_dir + os.sep + "FeatureImportance_plot.html" # TODO: Update name if there are multiple plots
#pio.write_html(fig, file = file_name, include_plotlyjs = 'cdn', include_mathjax='cdn')

# Make sure the figures directory exists before writing into it.
os.makedirs(img_dir, exist_ok=True)

# use this way to save to avoid encoding issues on negative sign
fig_json = fig.to_json()

# Load the plotly.js release bundled with the installed plotly package.
# The "plotly-latest" CDN alias is frozen at an old 1.x release, so it may
# not render figures produced by a newer plotly.
plotlyjs_url = "https://cdn.plot.ly/plotly-{0}.min.js".format(get_plotlyjs_version())

# Write as UTF-8 and declare the charset in the page, so the file reads the
# same way regardless of platform default encoding.
with open(file_name, 'w', encoding='utf-8') as html_file:
    html_file.write('<html><head><meta charset="utf-8" />')
    html_file.write('<script src="{0}"></script></head><body>'.format(plotlyjs_url))
    html_file.write('<div id="plot"></div>')
    html_file.write('<script>Plotly.newPlot("plot", {0});</script>'.format(fig_json))
    html_file.write('</body></html>')