In [1]:
import ipywidgets as widgets
from ipywidgets import HBox, interact, interactive, IntSlider, Label 
from IPython.display import display
# NOTE(review): two matplotlib backends are selected back to back; the later
# `%matplotlib inline` is the one left active, so the `widget` line appears to
# have no lasting effect. Confirm which backend the heat-map cells need and drop the other.
%matplotlib widget
%matplotlib inline
import math
import matplotlib.pyplot as plt
import os
import pandas as pd
import plotly.express as px
import sys
#add methods_exploration/functions to python path
# NOTE(review): 'functions' is a relative path, resolved against the kernel's current
# working directory, so the imports below only work when the notebook is started
# from the folder that contains `functions/`.
sys.path.append('functions')
import data_exploration_functions as dexf
import utils
#sys.path
# Default matplotlib figure size (width, height in inches)
plt.rcParams['figure.figsize'] = [12, 5]

# Correlation and mutual information between sources and sinks: Salt front drivers
***
All data in this notebook are observational, although the hope is that this analysis will be applied to compare model output from COAWST and a machine learning model.

## 1) Input sources and sinks

>Sources should be a CSV file whose first column is named 'datetime' and contains dates in the format YYYY-MM-DD. The other column names should have the form 'site_variable', for example 'Trenton_discharge'.

>Sinks should be a CSV file whose first column is named 'datetime' and contains dates in the format YYYY-MM-DD. Each additional column is a sink variable, for example 'Daily_salt_front' and '7_day_avg_salt_front'.

>The datetime column must be continuous (no skipped time steps). Missing data in the other columns are OK and will be skipped in the analysis.

>Sources and sinks must have the same time step, although they do not need to overlap perfectly in time; the analyses will be run on their overlapping period.



In [4]:
# Input file locations, relative to the notebook's working directory
SOURCES_PATH = 'data/srcs_example_extended.csv'
SINKS_PATH = 'data/snks_example_extended.csv'

sources = pd.read_csv(SOURCES_PATH, index_col='datetime', parse_dates=True)
sinks = pd.read_csv(SINKS_PATH, index_col='datetime', parse_dates=True)

# Enforce the input contract described above: each 'datetime' index must be a
# regular, sorted, duplicate-free sequence (pd.infer_freq returns None otherwise),
# and sources and sinks must share the same time step.
# NOTE(review): assumes the lagged analyses in dexf shift by rows, so a gap or a
# mismatched step would silently misalign the lags -- confirm against dexf.
source_freq = pd.infer_freq(sources.index)
sink_freq = pd.infer_freq(sinks.index)
assert source_freq is not None, "sources: 'datetime' column is not continuous (irregular, unsorted or duplicated dates)"
assert sink_freq is not None, "sinks: 'datetime' column is not continuous (irregular, unsorted or duplicated dates)"
assert source_freq == sink_freq, f"sources ({source_freq}) and sinks ({sink_freq}) have different time steps"

sinks.head()

Unnamed: 0_level_0,Daily_salt_front
datetime,Unnamed: 1_level_1
2000-01-01,
2000-01-02,
2000-01-03,
2000-01-04,
2000-01-05,


## 2) Plot time series of sources and sinks
>These plots are for exploring the sources and sinks; use the interactive controls to pan and zoom on the time series.

In [5]:
# Reshape the wide tables into long format so plotly can colour/facet by column
sources_long = sources.stack().reset_index()
sources_long.columns = ['datetime', 'name', 'value']
# Source columns are named 'site_variable'. Split on the FIRST underscore only:
# without n=1, a name containing more than one underscore (e.g. 'Trenton_water_temp')
# yields 3+ pieces and the two-column assignment raises ValueError.
sources_long[['site', 'var']] = sources_long['name'].str.split('_', n=1, expand=True)

sinks_long = sinks.stack().reset_index()
sinks_long.columns = ['datetime', 'sink_name', 'value']

# Sources: one facet row per variable, one line per site; facets get independent y-ranges
fig = px.line(sources_long, x='datetime', y='value', color='site', facet_row='var',
              height=600, width=800, title="Sources")
fig.update_yaxes(matches=None, title_text='')
# Facet labels default to 'var=<name>'; keep only '<name>'
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.show()

# Sinks: all sink variables on a single axis
# NOTE(review): the 'River Mile' label assumes every sink is a salt-front position
fig = px.line(sinks_long, x='datetime', y='value', color='sink_name',
              height=500, width=800, title="Sinks")
fig.update_yaxes(title_text='River Mile')
fig.show()


## 3) Make interactive heat map for correlation

>Running the following cell will generate: 
> 1. a heat map of lagged correlations between sources and sinks  
> 2. a plot of the correlations between sources and sinks across the time lags considered

> Both plots are reactive to the inputs in the upper left where you can select: 
> 1. the start and end date for the analysis (start_date and end_date)
> 2. the number of lags to consider (n_lags)
> 3. the threshold level you'd like to mask out in the heat map (mask_threshold) 

In [7]:
def f(start_date, end_date, n_lags, mask_threshold):
    """Redraw the lagged-correlation heat map and the correlation-vs-lag plot.

    Called by the `interactive` widget whenever a control changes; the widget
    values are passed straight through to the dexf plotting helpers together
    with the module-level `sources` and `sinks` frames.
    """
    plot_args = (sources, sinks, start_date, end_date, n_lags, mask_threshold)
    dexf.generate_correlation_heatmap(*plot_args)
    dexf.generate_correlation_timeseries(*plot_args)

# Controls, in display order: analysis window, number of lags, heat-map mask level
correlation_controls = dict(
    start_date=widgets.DatePicker(value=pd.to_datetime('2001-01-01')),
    end_date=widgets.DatePicker(value=pd.to_datetime('2020-01-01')),
    n_lags=widgets.IntSlider(min=0, max=9, step=1, value=4, continuous_update=False),
    mask_threshold=widgets.FloatSlider(min=0, max=1, step=0.01, value=0.5, continuous_update=False),
)
interactive_plot = interactive(f, **correlation_controls)
interactive_plot

interactive(children=(DatePicker(value=Timestamp('2001-01-01 00:00:00'), description='start_date'), DatePicker…

## 4) Make interactive heat map for mutual information


>Running the following cell will generate: 
> 1. a heat map of lagged mutual information between sources and sinks  
> 2. a plot of the mutual information between sources and sinks across the time lags considered

>*mutual information* - the amount of information obtained by one variable when observing another. Here mutual information between source and sink variables is normalized by the uncertainty of the sink variables, so that the output can be conceptualized as the fraction of uncertainty in the sink variable that can be explained by the source variable. __[More info](https://github.com/pdirmeyer/l-a-cheat-sheets/blob/main/Coupling_metrics_V30_MI.pdf)__

>Mutual information differs from correlation in a few key ways:
> - it compares the probability distributions of the variables, so it makes no assumption about the functional form of the relationship
> - it considers the entire distribution of the variables and, as such, is more sensitive to outliers and small sample sizes

> Both plots are reactive to the inputs in the upper left where you can select: 
> 1. the start and end date for the analysis (start_date and end_date)
> 2. the number of lags to consider (n_lags)
> 3. the threshold level you'd like to mask out in the heat map (mask_threshold) 

In [8]:
def mi(start_date, end_date, n_lags, mask_threshold):
    """Redraw the lagged mutual-information heat map and the MI-vs-lag plot.

    Called by the `interactive` widget whenever a control changes; the widget
    values are passed straight through to the dexf plotting helpers together
    with the module-level `sources` and `sinks` frames.
    """
    plot_args = (sources, sinks, start_date, end_date, n_lags, mask_threshold)
    dexf.generate_mutual_information_heatmap(*plot_args)
    dexf.generate_mutual_information_timeseries(*plot_args)

# Controls, in display order: analysis window, number of lags, heat-map mask level
mi_controls = dict(
    start_date=widgets.DatePicker(value=pd.to_datetime('2001-01-01')),
    end_date=widgets.DatePicker(value=pd.to_datetime('2020-01-01')),
    n_lags=widgets.IntSlider(min=0, max=9, step=1, value=4, continuous_update=False),
    mask_threshold=widgets.FloatSlider(min=0, max=1, step=0.01, value=0.0, continuous_update=False),
)
interactive_plot = interactive(mi, **mi_controls)
interactive_plot

interactive(children=(DatePicker(value=Timestamp('2001-01-01 00:00:00'), description='start_date'), DatePicker…