In [1]:
import warnings

# The filter must be installed BEFORE importing lime_explainer: warnings emitted
# while that module (and whatever it imports) is being loaded are only silenced
# if the filter is already active. With the filter placed after the imports, an
# import-time warning still leaked into this cell's output.
# NOTE(review): this silences every warning category for the whole kernel
# session; narrow it (category=..., module=...) if real problems get hidden.
warnings.filterwarnings(action='ignore')

import itertools
import lime_explainer as limee

  from ._conv import register_converters as _register_converters


In [2]:
# Non-weather column names (9 entries).
# NOTE(review): presumably listed in the same order as the CSV columns -- confirm.
names = [
    'price', 'date_utc', 'day', 'hour',
    'prev_week_min', 'prev_week_25th', 'prev_week_50th',
    'prev_week_75th', 'prev_week_max',
]

# Base weather variable names (10 entries); expanded per location just below.
w_names = [
    'temperature', 'wind_speed_100m', 'wind_direction_100m', 'air_density',
    'precipitation', 'wind_gust', 'radiation', 'wind_speed',
    'wind_direction', 'pressure',
]

# Prefix every weather variable with 'loc<i>_' for i = 0..17. product() varies
# the variable name fastest, so all ten 'loc0_*' names come first, then
# 'loc1_*', and so on -- 180 names in total.
w_names = [
    'loc{}_{}'.format(i, n)
    for i, n in itertools.product(range(18), w_names)
]

# Full list: 9 leading names followed by the 180 per-location weather names.
feature_names = names + w_names

In [3]:
# Instantiate the project's explainer wrapper; it is configured with data and a
# model client by the method calls that follow.
explainer = limee.Explainer()

In [4]:
# Point the explainer at the training CSV stored in GCS.
# - features_to_use covers column indices 2..188 (187 columns), i.e. everything
#   in feature_names except its first two entries ('price', 'date_utc').
# - NOTE(review): categorical_features=[2, 3] lines up with 'day' and 'hour'
#   only if the indices refer to the full column list rather than to the
#   already-selected feature subset -- confirm against lime_explainer.
# - NOTE(review): skip_rows=1 presumably skips a header row -- confirm.
explainer.load_data(
    gcs_path='gs://energyforecast/data/csv/MLDataTrain.csv',
    features_to_use=range(2, 189),
    categorical_features=[2, 3],
    feature_names=feature_names,
    skip_rows=1,
)

In [5]:
# Configure the client for the deployed model (project / model / version).
# NOTE(review): padding=(2, 0) presumably pads two leading fields to stand in
# for the two columns excluded from the features -- confirm in lime_explainer.
explainer.create_cmle_client(
    gcp_project='energy-forecasting',
    gcp_model='energyforecaster',
    gcp_model_version='new_energy',
    padding=(2, 0),
)

In [6]:
# Ask the explainer to explain one record (randomly chosen, going by the method
# name) and keep both return values. The call itself prints the
# Intercept / Prediction_local / Right lines seen in the output.
# NOTE(review): the meaning of numeric_rows=[] is defined in lime_explainer.
record, df = explainer.explain_random_record(numeric_rows=[])
# record prints as a flat list of feature values.
print(record)
# NOTE(review): df is presumably a DataFrame of the explanation -- confirm.
print(df)

Intercept 0.5354139563156892
Prediction_local [0.39530005]
Right: 0.35974639654159546
[1.0, 1.0, 0.041299999999999996, 0.33, 0.434, 0.5048, 0.7715000000000001, 2.1, 2.9, 97.2, 1.22, 0.7, 2.3, 0.0, 1.6, 94.6, 960.2, 4.9, 3.0, 76.5, 1.17, 0.7, 2.7, 0.0, 1.8, 79.4, 937.3, 6.1, 3.5, 21.2, 1.25, 0.7, 2.4, 0.0, 1.6, 15.3, 1005.5, 12.4, 2.6, 117.6, 1.25, 0.7, 2.6, 0.0, 2.0, 118.1, 1021.4, -2.8, 1.1, 197.5, 1.1, 0.7, 1.2, 0.0, 0.8, 186.8, 857.2, 3.0, 7.9, 81.9, 1.18, 0.7, 5.1, 0.0, 4.1, 81.7, 934.0, 4.6, 2.1, 254.8, 1.25, 0.7, 1.7, 0.0, 1.2, 247.5, 994.6, 5.0, 1.7, 142.6, 1.28, 0.7, 1.3, 0.0, 0.9, 147.5, 1020.3, -1.5, 1.4, 337.2, 1.18, 0.7, 1.1, 0.0, 0.7, 320.4, 918.2, -2.9, 0.9, 279.5, 1.14, 0.7, 0.9, 0.0, 0.6, 286.6, 885.2, 2.8, 3.9, 13.3, 1.2, 0.7, 3.0, 0.0, 2.2, 12.3, 949.3, 2.4, 2.9, 183.0, 1.13, 0.7, 2.6, 0.0, 1.8, 181.7, 892.9, -1.1, 2.7, 337.1, 1.22, 0.7, 2.1, 0.0, 1.5, 322.6, 951.1, -0.8, 1.5, 30.2, 1.13, 0.7, 0.9, 0.0, 0.6, 144.5, 887.0, 5.7, 4.2, 145.1, 1.18, 0.7, 2.9, 0.0, 2.2, 136

In [7]:
# Build a what-if variant of the record as a comma-separated string: the first
# two values are replaced by '5' and '12', every remaining value is kept and
# serialised with str().
# NOTE(review): the first two positions presumably hold 'day' and 'hour'
# (columns 2 and 3 of feature_names) -- confirm.
mod_record = ','.join(['5', '12'] + list(map(str, record[2:])))
print('Mod Record: {}\n\n'.format(mod_record))
# Left as the cell's final bare expression so its return value is displayed.
explainer.explain_record(mod_record)

Mod Record: 5,12,0.0413,0.33,0.434,0.5048,0.7715,2.1,2.9,97.2,1.22,0.7,2.3,0.0,1.6,94.6,960.2,4.9,3.0,76.5,1.17,0.7,2.7,0.0,1.8,79.4,937.3,6.1,3.5,21.2,1.25,0.7,2.4,0.0,1.6,15.3,1005.5,12.4,2.6,117.6,1.25,0.7,2.6,0.0,2.0,118.1,1021.4,-2.8,1.1,197.5,1.1,0.7,1.2,0.0,0.8,186.8,857.2,3.0,7.9,81.9,1.18,0.7,5.1,0.0,4.1,81.7,934.0,4.6,2.1,254.8,1.25,0.7,1.7,0.0,1.2,247.5,994.6,5.0,1.7,142.6,1.28,0.7,1.3,0.0,0.9,147.5,1020.3,-1.5,1.4,337.2,1.18,0.7,1.1,0.0,0.7,320.4,918.2,-2.9,0.9,279.5,1.14,0.7,0.9,0.0,0.6,286.6,885.2,2.8,3.9,13.3,1.2,0.7,3.0,0.0,2.2,12.3,949.3,2.4,2.9,183.0,1.13,0.7,2.6,0.0,1.8,181.7,892.9,-1.1,2.7,337.1,1.22,0.7,2.1,0.0,1.5,322.6,951.1,-0.8,1.5,30.2,1.13,0.7,0.9,0.0,0.6,144.5,887.0,5.7,4.2,145.1,1.18,0.7,2.9,0.0,2.2,136.5,944.3,6.7,1.8,78.4,1.24,0.7,1.7,0.0,1.3,78.7,999.5,-3.3,2.1,76.5,1.15,0.7,2.8,0.0,2.0,132.6,887.9,-3.3,2.1,76.5,1.15,0.7,2.8,0.0,2.0,132.6,887.9


Intercept 0.5355745389417982
Prediction_local [0.40974288]
Right: 0.36163330078125


([5.0,
  12.0,
  0.0413,
  0.33,
  0.434,
  0.5048,
  0.7715,
  2.1,
  2.9,
  97.2,
  1.22,
  0.7,
  2.3,
  0.0,
  1.6,
  94.6,
  960.2,
  4.9,
  3.0,
  76.5,
  1.17,
  0.7,
  2.7,
  0.0,
  1.8,
  79.4,
  937.3,
  6.1,
  3.5,
  21.2,
  1.25,
  0.7,
  2.4,
  0.0,
  1.6,
  15.3,
  1005.5,
  12.4,
  2.6,
  117.6,
  1.25,
  0.7,
  2.6,
  0.0,
  2.0,
  118.1,
  1021.4,
  -2.8,
  1.1,
  197.5,
  1.1,
  0.7,
  1.2,
  0.0,
  0.8,
  186.8,
  857.2,
  3.0,
  7.9,
  81.9,
  1.18,
  0.7,
  5.1,
  0.0,
  4.1,
  81.7,
  934.0,
  4.6,
  2.1,
  254.8,
  1.25,
  0.7,
  1.7,
  0.0,
  1.2,
  247.5,
  994.6,
  5.0,
  1.7,
  142.6,
  1.28,
  0.7,
  1.3,
  0.0,
  0.9,
  147.5,
  1020.3,
  -1.5,
  1.4,
  337.2,
  1.18,
  0.7,
  1.1,
  0.0,
  0.7,
  320.4,
  918.2,
  -2.9,
  0.9,
  279.5,
  1.14,
  0.7,
  0.9,
  0.0,
  0.6,
  286.6,
  885.2,
  2.8,
  3.9,
  13.3,
  1.2,
  0.7,
  3.0,
  0.0,
  2.2,
  12.3,
  949.3,
  2.4,
  2.9,
  183.0,
  1.13,
  0.7,
  2.6,
  0.0,
  1.8,
  181.7,
  892.9,
  -1.1,
  2.7,
  33