In [3]:
import itertools
import lime_explainer as lime

In [4]:
# Leading non-weather column names. They are prepended to the weather
# columns below to form `feature_names`.
names = [
  'price',
  'date_utc',
  'day',
  'hour',
  'prev_week_min',
  'prev_week_25th',
  'prev_week_50th',
  'prev_week_75th',
  'prev_week_max',
]

# Weather measurements that are repeated once per location.
weather_measurements = [
  'temperature',
  'wind_speed_100m',
  'wind_direction_100m',
  'air_density',
  'precipitation',
  'wind_gust',
  'radiation',
  'wind_speed',
  'wind_direction',
  'pressure',
]

# Flat list of 'loc<i>_<measurement>' names: all measurements for loc0,
# then all measurements for loc1, ... up to loc17 (location-major order).
w_names = [
  'loc{}_{}'.format(location, measurement)
  for location in range(18)
  for measurement in weather_measurements
]

# Full column list: the leading columns followed by every weather column.
feature_names = names + w_names

In [5]:
# Build the explainer from the project-local `lime_explainer` module,
# configured for a regression model.
explainer = lime.Explainer(model_type='regression')

In [6]:
# Hand the training CSV (stored on GCS) to the explainer.
# Index 0 ('price' in feature_names) is passed as the target and indices
# 2..188 as the features, so indices 0 and 1 ('price', 'date_utc') are not
# in the feature set.
# NOTE(review): presumably all indices refer to CSV column positions that
# line up with `feature_names` -- confirm against lime_explainer.load_data.
feature_columns = range(2, 189)
# Positions 2 and 3 in feature_names are 'day' and 'hour'.
categorical_columns = [2, 3]

explainer.load_data(
  gcs_path='gs://energyforecast/data/csv/MLDataTrain.csv',
  target_idx=0,
  features_to_use=feature_columns,
  categorical_features=categorical_columns,
  feature_names=feature_names,
  skip_first=True,
  integer_rows=[],
)

In [7]:
# Configure the explainer's prediction client with the GCP project, model
# name and model version to query.
# NOTE(review): `padding = (2,0)` presumably adds two placeholder values in
# front of each record (for the two columns excluded from the features) and
# none after -- confirm against lime_explainer.create_cmle_client.
explainer.create_cmle_client(
  gcp_project = 'energy-forecasting',
  gcp_model = 'energyforecaster',
  gcp_model_version = 'new_energy',
  padding = (2,0)
)

In [8]:
# Explain a single record chosen by the explainer (presumably sampled at
# random from the loaded data -- confirm in lime_explainer).
# Returns the record's feature values and a table of explanation weights;
# the call itself prints the Intercept / Prediction_local / Right lines.
record, df = explainer.explain_random_record()

Intercept 0.5174112846477084
Prediction_local [0.56253181]
Right: 0.575685977935791


In [9]:
# Show the raw feature values of the record that was just explained.
print(record)

[1.0, 6.0, 0.44320000000000004, 0.517, 0.6, 0.6409999999999999, 0.7069, 29.2, 4.0, 285.1, 1.1, 0.0, 3.4, 828.8, 3.0, 289.4, 952.6, 32.2, 3.3, 248.8, 1.06, 0.0, 2.9, 849.2, 2.5, 251.8, 931.0, 44.3, 3.3, 243.7, 1.09, 0.0, 3.7, 826.7, 2.9, 246.6, 997.3, 29.8, 4.9, 209.1, 1.17, 0.0, 4.9, 854.3, 4.5, 208.3, 1014.0, 29.3, 2.5, 205.7, 0.98, 0.0, 4.3, 843.5, 1.7, 209.4, 853.4, 43.3, 9.8, 220.7, 1.02, 0.0, 10.7, 852.0, 8.4, 222.2, 926.8, 48.9, 5.2, 224.1, 1.07, 0.0, 7.4, 841.0, 4.3, 227.6, 985.9, 45.9, 6.0, 227.1, 1.1, 0.0, 5.9, 840.8, 4.8, 232.2, 1010.5, 34.7, 6.3, 225.0, 1.03, 0.0, 8.9, 829.3, 4.8, 230.0, 911.3, 36.8, 10.0, 207.2, 0.99, 0.0, 11.0, 846.7, 8.1, 206.7, 881.8, 51.3, 5.9, 210.9, 1.01, 0.0, 6.6, 847.6, 5.2, 211.4, 943.6, 40.6, 4.5, 164.7, 0.99, 0.0, 5.3, 865.6, 3.8, 167.7, 891.0, 34.7, 3.2, 141.9, 1.07, 0.0, 4.5, 811.8, 2.6, 134.9, 943.3, 38.4, 8.9, 227.7, 0.99, 0.0, 8.3, 807.2, 7.2, 228.5, 884.6, 44.7, 8.7, 162.5, 1.03, 0.0, 7.0, 855.2, 7.3, 160.7, 939.5, 50.0, 7.2, 63.5, 1.07, 0.

In [10]:
# Display the explanation for the sampled record: one row per feature
# condition ("representation") with its weight.
df

Unnamed: 0,representation,weight
0,2.80 < loc4_wind_gust <= 4.40,0.043878
1,loc17_wind_gust > 3.90,-0.028451
2,loc2_radiation > 473.12,0.025402
3,loc12_precipitation <= 0.00,0.004669
4,loc3_air_density <= 1.20,-0.000378


In [11]:
# Build a modified copy of the sampled record as a comma-separated string:
# its first two values are replaced by '5' and '12' (presumably day and
# hour, the first two selected features -- confirm), the rest are kept.
mod_fields = ['5', '12']
mod_fields.extend(str(value) for value in record[2:])
mod_record = ','.join(mod_fields)

print('Mod Record: {}\n\n'.format(mod_record))

# Explain the modified record; only the explanation table is kept.
_, mod_df = explainer.explain_record(mod_record)

Mod Record: 5,12,0.44320000000000004,0.517,0.6,0.6409999999999999,0.7069,29.2,4.0,285.1,1.1,0.0,3.4,828.8,3.0,289.4,952.6,32.2,3.3,248.8,1.06,0.0,2.9,849.2,2.5,251.8,931.0,44.3,3.3,243.7,1.09,0.0,3.7,826.7,2.9,246.6,997.3,29.8,4.9,209.1,1.17,0.0,4.9,854.3,4.5,208.3,1014.0,29.3,2.5,205.7,0.98,0.0,4.3,843.5,1.7,209.4,853.4,43.3,9.8,220.7,1.02,0.0,10.7,852.0,8.4,222.2,926.8,48.9,5.2,224.1,1.07,0.0,7.4,841.0,4.3,227.6,985.9,45.9,6.0,227.1,1.1,0.0,5.9,840.8,4.8,232.2,1010.5,34.7,6.3,225.0,1.03,0.0,8.9,829.3,4.8,230.0,911.3,36.8,10.0,207.2,0.99,0.0,11.0,846.7,8.1,206.7,881.8,51.3,5.9,210.9,1.01,0.0,6.6,847.6,5.2,211.4,943.6,40.6,4.5,164.7,0.99,0.0,5.3,865.6,3.8,167.7,891.0,34.7,3.2,141.9,1.07,0.0,4.5,811.8,2.6,134.9,943.3,38.4,8.9,227.7,0.99,0.0,8.3,807.2,7.2,228.5,884.6,44.7,8.7,162.5,1.03,0.0,7.0,855.2,7.3,160.7,939.5,50.0,7.2,63.5,1.07,0.0,7.2,844.5,6.1,67.0,991.9,31.6,1.8,262.6,1.01,0.0,6.1,823.7,1.2,277.5,884.7,31.6,1.8,262.6,1.01,0.0,6.1,823.7,1.2,277.5,884.7


Intercept 0.455770497968

In [12]:
# Display the explanation table (representation, weight) for the modified
# record, for comparison with the table shown for the original record.
mod_df

Unnamed: 0,representation,weight
0,loc9_air_density <= 1.05,0.034483
1,929.10 < loc1_pressure <= 931.50,0.02367
2,prev_week_25th > 0.49,0.017713
3,loc5_precipitation <= 0.00,0.012068
4,941.27 < loc10_pressure <= 943.70,0.009891


In [14]:
# Produce explanations for several records in one call; each explained
# record prints its own Intercept / Prediction_local / Right lines.
# NOTE(review): how the records are selected is not visible here -- see
# lime_explainer.explain_model.
text_explanations = explainer.explain_model()

Intercept 0.5107759280541668
Prediction_local [0.49062331]
Right: 0.4652999937534332
Intercept 0.51963179783581
Prediction_local [0.44109417]
Right: 0.4230053424835205
Intercept 0.537327159347436
Prediction_local [0.47949718]
Right: 0.4723219871520996
Intercept 0.46434539537682107
Prediction_local [0.571108]
Right: 0.5940424203872681
Intercept 0.48062195348895714
Prediction_local [0.56869981]
Right: 0.5912395715713501


In [15]:
# Display the collected explanations: a list with one entry per explained
# record, each a list of (feature condition, weight) tuples.
text_explanations

[[('loc13_air_density <= 1.05', 0.05576553893867203),
  ('107.10 < loc10_radiation <= 493.02', 0.03495892959493519),
  ('loc8_precipitation <= 0.00', 0.013908066613298428),
  ('day=6', 0.01068581163367),
  ('loc2_wind_speed_100m <= 3.80', -0.00855573833992803)],
 [('loc8_wind_speed_100m > 5.70', -0.03785857445732849),
  ('loc5_wind_speed_100m > 9.00', -0.021873830713285945),
  ('loc10_wind_speed_100m > 6.90', -0.018935217771014052),
  ('11.20 < loc6_temperature <= 19.70', 0.0053673148647476136),
  ('loc14_wind_speed_100m > 8.80', -0.005237316877630178)],
 [('117.00 < loc14_radiation <= 480.00', 0.04417734367466724),
  ('loc10_wind_direction <= 90.88', 0.02755047288042857),
  ('loc7_air_density > 1.24', 0.020005734557573512),
  ('prev_week_max <= 0.65', -0.0021524500136895527),
  ('2.30 < loc2_wind_speed <= 3.50', -0.0015032429651478355)],
 [('216.75 < loc12_wind_direction <= 334.20', -0.05789219651818054),
  ('loc8_precipitation <= 0.00', 0.030630263480528174),
  ('loc6_radiation <= 0.