In [3]:
import itertools
import lime_explainer as lime

In [4]:
names = [
  'price',
  'date_utc',
  'day',
  'hour',
  'prev_week_min',
  'prev_week_25th',
  'prev_week_50th',
  'prev_week_75th',
  'prev_week_max'
]
w_names = [
  'temperature',
  'wind_speed_100m',
  'wind_direction_100m',
  'air_density',
  'precipitation',
  'wind_gust',
  'radiation',
  'wind_speed',
  'wind_direction',
  'pressure'
]

w_names = [ ['loc'+str(i)+"_"+n for n in w_names] for i in range(18)]

w_names = list(itertools.chain.from_iterable(w_names))
feature_names = names + w_names

In [5]:
explainer = lime.Explainer(model_type='regression')

In [6]:
explainer.load_data(
  gcs_path = 'gs://energyforecast/data/csv/MLDataTrain.csv',
  target_idx = 0,
  features_to_use = range(2,189),
  categorical_features = [2,3],
  feature_names = feature_names,
  skip_first=True,
  integer_rows=[]
)

In [7]:
explainer.create_cmle_client(
  gcp_project = 'energy-forecasting',
  gcp_model = 'energyforecaster',
  gcp_model_version = 'new_energy',
  padding = (2,0)
)

In [8]:
record, df = explainer.explain_random_record()

Intercept 0.5174112846477084
Prediction_local [0.56253181]
Right: 0.575685977935791


In [9]:
print(record)

[1.0, 6.0, 0.44320000000000004, 0.517, 0.6, 0.6409999999999999, 0.7069, 29.2, 4.0, 285.1, 1.1, 0.0, 3.4, 828.8, 3.0, 289.4, 952.6, 32.2, 3.3, 248.8, 1.06, 0.0, 2.9, 849.2, 2.5, 251.8, 931.0, 44.3, 3.3, 243.7, 1.09, 0.0, 3.7, 826.7, 2.9, 246.6, 997.3, 29.8, 4.9, 209.1, 1.17, 0.0, 4.9, 854.3, 4.5, 208.3, 1014.0, 29.3, 2.5, 205.7, 0.98, 0.0, 4.3, 843.5, 1.7, 209.4, 853.4, 43.3, 9.8, 220.7, 1.02, 0.0, 10.7, 852.0, 8.4, 222.2, 926.8, 48.9, 5.2, 224.1, 1.07, 0.0, 7.4, 841.0, 4.3, 227.6, 985.9, 45.9, 6.0, 227.1, 1.1, 0.0, 5.9, 840.8, 4.8, 232.2, 1010.5, 34.7, 6.3, 225.0, 1.03, 0.0, 8.9, 829.3, 4.8, 230.0, 911.3, 36.8, 10.0, 207.2, 0.99, 0.0, 11.0, 846.7, 8.1, 206.7, 881.8, 51.3, 5.9, 210.9, 1.01, 0.0, 6.6, 847.6, 5.2, 211.4, 943.6, 40.6, 4.5, 164.7, 0.99, 0.0, 5.3, 865.6, 3.8, 167.7, 891.0, 34.7, 3.2, 141.9, 1.07, 0.0, 4.5, 811.8, 2.6, 134.9, 943.3, 38.4, 8.9, 227.7, 0.99, 0.0, 8.3, 807.2, 7.2, 228.5, 884.6, 44.7, 8.7, 162.5, 1.03, 0.0, 7.0, 855.2, 7.3, 160.7, 939.5, 50.0, 7.2, 63.5, 1.07, 0.

In [10]:
df

Unnamed: 0,representation,weight
0,2.80 < loc4_wind_gust <= 4.40,0.043878
1,loc17_wind_gust > 3.90,-0.028451
2,loc2_radiation > 473.12,0.025402
3,loc12_precipitation <= 0.00,0.004669
4,loc3_air_density <= 1.20,-0.000378


In [11]:
mod_record = ','.join(['5','12'] + [str(e) for e in record[2:]])
print('Mod Record: {}\n\n'.format(mod_record))
_, mod_df = explainer.explain_record(mod_record)

Mod Record: 5,12,0.44320000000000004,0.517,0.6,0.6409999999999999,0.7069,29.2,4.0,285.1,1.1,0.0,3.4,828.8,3.0,289.4,952.6,32.2,3.3,248.8,1.06,0.0,2.9,849.2,2.5,251.8,931.0,44.3,3.3,243.7,1.09,0.0,3.7,826.7,2.9,246.6,997.3,29.8,4.9,209.1,1.17,0.0,4.9,854.3,4.5,208.3,1014.0,29.3,2.5,205.7,0.98,0.0,4.3,843.5,1.7,209.4,853.4,43.3,9.8,220.7,1.02,0.0,10.7,852.0,8.4,222.2,926.8,48.9,5.2,224.1,1.07,0.0,7.4,841.0,4.3,227.6,985.9,45.9,6.0,227.1,1.1,0.0,5.9,840.8,4.8,232.2,1010.5,34.7,6.3,225.0,1.03,0.0,8.9,829.3,4.8,230.0,911.3,36.8,10.0,207.2,0.99,0.0,11.0,846.7,8.1,206.7,881.8,51.3,5.9,210.9,1.01,0.0,6.6,847.6,5.2,211.4,943.6,40.6,4.5,164.7,0.99,0.0,5.3,865.6,3.8,167.7,891.0,34.7,3.2,141.9,1.07,0.0,4.5,811.8,2.6,134.9,943.3,38.4,8.9,227.7,0.99,0.0,8.3,807.2,7.2,228.5,884.6,44.7,8.7,162.5,1.03,0.0,7.0,855.2,7.3,160.7,939.5,50.0,7.2,63.5,1.07,0.0,7.2,844.5,6.1,67.0,991.9,31.6,1.8,262.6,1.01,0.0,6.1,823.7,1.2,277.5,884.7,31.6,1.8,262.6,1.01,0.0,6.1,823.7,1.2,277.5,884.7


Intercept 0.455770497968

In [12]:
mod_df

Unnamed: 0,representation,weight
0,loc9_air_density <= 1.05,0.034483
1,929.10 < loc1_pressure <= 931.50,0.02367
2,prev_week_25th > 0.49,0.017713
3,loc5_precipitation <= 0.00,0.012068
4,941.27 < loc10_pressure <= 943.70,0.009891


In [13]:
explainer.explain_model()

Intercept 0.5138182366793298
Prediction_local [0.44965706]
Right: 0.4278600215911865
Intercept 0.5270374182415
Prediction_local [0.38819119]
Right: 0.33764752745628357
Intercept 0.5438415988516622
Prediction_local [0.42794255]
Right: 0.4092845320701599
Intercept 0.5002159958963835
Prediction_local [0.59535508]
Right: 0.6190247535705566
Intercept 0.46202472585644416
Prediction_local [0.45313694]
Right: 0.4216153919696808


[[('prev_week_50th <= 0.45', -0.07625936318367244),
  ('88.50 < loc2_radiation <= 473.12', 0.039975791852310365),
  ('1.09 < loc16_air_density <= 1.12', 0.02615280555526107),
  ('loc5_precipitation <= 0.00', 0.01981780327905437),
  ('loc14_precipitation <= 0.00', -0.01857481945179395)],
 [('day=0', -0.05179403300297597),
  ('0.65 < prev_week_max <= 0.69', -0.05054366099387933),
  ('loc0_temperature > 15.90', -0.03711332952280304),
  ('loc12_pressure <= 941.18', 0.008245175174062662),
  ('101.00 < loc6_radiation <= 495.25', -0.007640381445741393)],
 [('loc6_wind_direction > 276.90', -0.040529810155231986),
  ('loc11_wind_gust > 3.60', -0.030148437533964314),
  ('loc11_wind_speed_100m > 4.00', -0.02583698316663189),
  ('6.60 < loc0_temperature <= 11.00', -0.012743165163056755),
  ('loc11_air_density > 1.10', -0.0066406575377021)],
 [('6.60 < loc14_wind_gust <= 9.50', 0.04462647831068589),
  ('924.70 < loc5_pressure <= 927.30', 0.02356861482991265),
  ('1.80 < loc4_wind_gust <= 2.80', 0.0