In [13]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
# import seaborn as sns

In [14]:
from molten.data_drift.kdq_tree import KdqTree

In [20]:
# Load the example dataset, using its 'id' column as the row index.
# NOTE(review): the path is relative, so this assumes the kernel's working
# directory is the repository root — TODO confirm or make it configurable.
example_data_path = os.path.join("src", "molten", "tools", "artifacts", "example_data.csv")
df_orig = pd.read_csv(example_data_path, index_col='id')

In [21]:
# Build the feature frame fed to the detector: one-hot encode `cat`, then drop
# `cat` itself plus `confidence` and `drift`, which are not used as features
# (NOTE(review): presumably label/metadata columns of the example data — confirm).
# dtype=np.uint8 keeps the dummies numeric 0/1. pandas >= 2.0 defaults to bool,
# and a frame mixing float64 and bool columns yields an object-dtype `.values`
# array, which is what gets passed to the detector later.
df = pd.concat(
    [df_orig, pd.get_dummies(df_orig['cat'], prefix='cat', dtype=np.uint8)],
    axis=1,
).drop(columns=['cat', 'confidence', 'drift'])

In [22]:
# Preview the first five rows of the feature frame.
df.head(n=5)

Unnamed: 0_level_0,year,a,b,c,d,e,f,g,h,i,j,cat_1,cat_2,cat_3,cat_4,cat_5,cat_6,cat_7
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
1,2007,6043.607465,206.843471,11079.264098,75455.714279,1.559448,-0.021553,161658.917829,7256.384898,8774.638131,911.289707,0,1,0,0,0,0,0
2,2007,11113.663042,197.908656,7548.574555,78957.689468,0.214995,-0.078272,83044.306589,8067.669091,6471.199987,1641.692884,0,0,0,0,1,0,0
3,2007,3586.610839,197.208754,10694.642849,99401.964207,1.790866,0.606395,145751.934456,10196.104571,6505.145608,744.610919,1,0,0,0,0,0,0
4,2007,7858.247874,203.033535,6025.428771,97933.752903,-0.033178,-1.116475,58802.134349,13417.866666,9850.291597,909.445086,1,0,0,0,0,0,0
5,2007,12932.260882,218.739229,7100.66148,114541.129273,1.808175,1.847939,106241.702624,4275.532336,5175.73494,786.073604,0,0,0,0,1,0,0


In [23]:
# Run the kdq-tree drift detector over the data one year (batch) at a time.
# Each year's drift state is recorded in `status`. When drift is flagged, the
# detector's plotting frame is saved in `plot_data` (keyed by year), and the same
# batch is fed again to update the reference window.
plot_data = {}
np.random.seed(123)  # fixed seed so any numpy-based randomness is reproducible
det = KdqTree(input_type="batch")
status_rows = []
for group, subdf in df.groupby('year'):
    batch = subdf.drop(columns=['year']).values  # this year's feature matrix
    det.update(batch)
    status_rows.append({'year': group, 'drift': det.drift_state})
    if det.drift_state is not None:
        # capture the visualization data
        plot_data[group] = det.to_plotly_dataframe()
        # update the reference window if drift is detected
        det.update(batch)
# Build the summary frame once. Growing it with pd.concat inside the loop is
# quadratic, and concatenating onto an empty frame is deprecated in pandas >= 2.1.
status = pd.DataFrame(status_rows, columns=['year', 'drift'])

In [24]:
# Compare detected drift against the known drift years of the example data.
# It's only in year 2018 where, for this test data, we don't detect drift
# immediately; it does get picked up in the following year.
# NOTE(review): the displayed output below ends at 2016, so the 2018 remark
# cannot be checked against it.
# TODO: `drift_years` is not defined in any visible cell (execution counts jump
# from 14 to 20). Restore its definition so Restart & Run All works.
pd.merge(status, drift_years, how='left', on='year')

Unnamed: 0,year,drift,drift_true
0,2007,,
1,2008,,
2,2009,drift,drift
3,2010,drift,drift
4,2011,,
5,2012,drift,drift
6,2013,drift,drift
7,2014,,
8,2015,drift,drift
9,2016,drift,drift


In [26]:
# If we save off the dataframes at each drift detection, we can subsequently
# display the KSS (the `kss` column of each saved frame) as one treemap per year.
# Column names are not yet carried through to the plot, so the tree's axes are
# numbered in the order of the feature columns passed to the detector,
# i.e. df.drop(columns=['year']).columns:
# ax 0 - column a
# ax 1 - column b
# ax 2 - column c
# ax 3 - column d
# ax 4 - column e
# ax 5 - column f
# ax 6 - column g
# ax 7 - column h
# ax 8 - column i
# ax 9 - column j
# ax 10-16 - one-hot columns cat_1 ... cat_7
#
# We can see that the regions of greatest drift do line up with at least one of
# the items that were modified in a given year.
# For reference, the drift injected into the example data:
#   Drift 1: change the mean & var of item b in 2009; means revert from 2010 on
#   Drift 2: change the variance of items c and d in 2012 by replacing some
#            values with the mean (keeping the same mean as other years);
#            reverts by 2013
#   Drift 3: change the correlation of items e and f in 2015
#            (from a correlation of 0 to a correlation of 0.5)
#   Drift 4: change mean and var of h and persist it from 2018 on
#   Drift 5: change mean and var of j for just one year, 2021
for year, df_plot in plot_data.items():
    fig = px.treemap(
        data_frame=df_plot,
        names='name',
        ids='idx',
        parents='parent_idx',
        color='kss',
        color_continuous_scale='blues',
        title=str(year),  # px documents `title` as a string
    )
    fig.update_traces(root_color='lightgrey')
    fig.show()