In [1]:
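# Setup (fresh environments only): uncomment to install the pinned dependency
# into the kernel's own environment (%pip targets the kernel, unlike !pip).
#%pip install anchor_exp==0.0.0.5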

In [3]:
import itertools
import anchor_explainer as anchor

In [4]:
# Column labels handed to the explainer as `feature_names`.
# Index 0 is the target ('price'); presumably the order mirrors the CSV
# column order -- confirm against the training file.
names = [
    'price', 'date_utc', 'day', 'hour',
    'prev_week_min', 'prev_week_25th', 'prev_week_50th',
    'prev_week_75th', 'prev_week_max',
]

# Weather measurements; each one is repeated once per location below.
w_names = [
    'temperature', 'wind_speed_100m', 'wind_direction_100m', 'air_density',
    'precipitation', 'wind_gust', 'radiation', 'wind_speed',
    'wind_direction', 'pressure',
]

# Expand to 'loc0_temperature' ... 'loc17_pressure' for 18 locations.
# product() varies the measurement fastest, so the result is grouped by
# location (all loc0_* first, then loc1_*, ...).
w_names = [
    'loc{}_{}'.format(loc_idx, measurement)
    for loc_idx, measurement in itertools.product(range(18), w_names)
]

# 9 base columns + 18 * 10 weather columns = 189 labels in total.
names.extend(w_names)

In [5]:
# Create the explainer object from the anchor_explainer module. The following
# cells configure it (load_data, create_cmle_client) and then query it
# (assess_model, explain_random_record, explain_record, explain_model).
explainer = anchor.Explainer()

In [6]:
# Point the explainer at the training CSV stored in GCS.
#   target_idx=0               -> names[0], i.e. 'price'
#   features_to_use=2..188     -> every label except 'price' and 'date_utc'
#   categorical_features=[2,3] -> 'day' and 'hour'
# The index-to-label mapping above assumes these indices address positions in
# `names` / CSV columns -- confirm against anchor_explainer.
# NOTE(review): skip_first=True presumably drops a header row -- confirm.
explainer.load_data(
    gcs_path='gs://energyforecast/data/csv/MLDataTrain.csv',
    target_idx=0,
    features_to_use=range(2, 189),
    categorical_features=[2, 3],
    feature_names=names,
    skip_first=True,
)

In [7]:
# IPython shell capture: run the gcloud CLI and keep its output lines in
# `token` (list-like, one entry per line; token[0] is what gets passed as the
# access token when the CMLE client is created).
# NOTE(review): needs an authenticated gcloud CLI on the kernel host. Do not
# display `token` in a cell output -- saved outputs would leak the credential.
token = !gcloud auth print-access-token

In [8]:
# Connect the explainer to the deployed model version, authenticating with
# the first line captured from `gcloud auth print-access-token`.
# NOTE(review): the meaning of padding=(2, 0) is not visible in this notebook;
# it may account for the two leading columns that are not used as features --
# confirm against anchor_explainer.
explainer.create_cmle_client(
    gcp_project='energy-forecasting',
    gcp_model='energyforecaster',
    gcp_model_version='new_energy',
    padding=(2, 0),
    access_token=token[0],
)

In [9]:
# Ask the explainer to score the connected model; the recorded run returned
# {'accuracy': 0.67}.
# NOTE(review): which records are scored and how accuracy is defined over the
# price buckets is decided inside anchor_explainer -- not visible here.
explainer.assess_model()

{'accuracy': 0.67}

In [10]:
# Explain one record picked by explain_random_record(). The result is kept in
# `rand_expl` because the next cell edits its 'record' entry; in the recorded
# run it is a dict with keys anchor, coverage, precision, prediction, record.
# NOTE(review): no seed is set in this notebook, so a re-run will most likely
# return a different record than the output saved below.
rand_expl = explainer.explain_random_record()  # optional kwarg left disabled by the author: show_in_notebook=True (effect not shown here)
rand_expl

{'anchor': 'hour = 2 AND loc15_temperature <= 14.57 AND loc1_wind_gust > 4.20 AND loc9_wind_direction > 282.50',
 'coverage': 0.0,
 'precision': 1.0,
 'prediction': b'price <= 0.43',
 'record': ['5',
  '2',
  '0.18850000000000003',
  '0.395',
  '0.4743',
  '0.551',
  '0.7715000000000001',
  '9',
  '8.9',
  '303.2',
  '1.16',
  '0',
  '11.2',
  '0',
  '4.7',
  '301.9',
  '948.2',
  '9.1',
  '5.5',
  '274.3',
  '1.14',
  '0',
  '6.2',
  '0',
  '3.2',
  '269.5',
  '926.6',
  '14.4',
  '9.3',
  '249.95',
  '1.22',
  '0',
  '10.3',
  '0',
  '6.1',
  '249',
  '998.3',
  '13.6',
  '10.2',
  '275',
  '1.22',
  '0',
  '11.2',
  '0',
  '7.7',
  '279.05',
  '1016.9',
  '4.2',
  '6',
  '330.65',
  '1.05',
  '0',
  '7.4',
  '0',
  '4.1',
  '335.1',
  '844.4',
  '10.9',
  '10.8',
  '295.4',
  '1.14',
  '0',
  '12',
  '0',
  '7.7',
  '293.4',
  '922.4',
  '16.1',
  '7.9',
  '266.4',
  '1.2',
  '0',
  '8.4',
  '0',
  '5',
  '264.5',
  '984.1',
  '8.2',
  '6.9',
  '288.3',
  '1.22',
  '0',
  '6.8',
  '

In [11]:
# Counterfactual probe: re-explain the record from the previous cell with only
# its second field changed to '12' (that field held '2' in the recorded run,
# matching the 'hour = 2' condition of the anchor above).
#
# The first field (the day, '5' in the recorded run) is copied from the record
# instead of being hard-coded: rand_expl comes from a random draw, so a literal
# '5' would silently alter a second feature whenever a re-run draws a record
# with a different day, and the comparison would no longer isolate the hour.
base_record = rand_expl['record']
mod_record = ','.join([base_record[0], '12'] + base_record[2:])

# Show the exact CSV line sent to the explainer, then display its explanation.
print(mod_record)
explainer.explain_record(mod_record)  # optional kwarg left disabled by the author: show_in_notebook=True (effect not shown here)

5,12,0.18850000000000003,0.395,0.4743,0.551,0.7715000000000001,9,8.9,303.2,1.16,0,11.2,0,4.7,301.9,948.2,9.1,5.5,274.3,1.14,0,6.2,0,3.2,269.5,926.6,14.4,9.3,249.95,1.22,0,10.3,0,6.1,249,998.3,13.6,10.2,275,1.22,0,11.2,0,7.7,279.05,1016.9,4.2,6,330.65,1.05,0,7.4,0,4.1,335.1,844.4,10.9,10.8,295.4,1.14,0,12,0,7.7,293.4,922.4,16.1,7.9,266.4,1.2,0,8.4,0,5,264.5,984.1,8.2,6.9,288.3,1.22,0,6.8,0,2.7,238.9,1014.8,6.3,4.7,232.7,1.14,0,9,0,2.3,231.8,905,7.2,7.7,306.7,1.07,0,8.5,0,5.6,307.7,874.2,14.2,8.9,280,1.15,0,9.2,0,6.3,276.6,939.5,3.75,4.9,338.29999999999995,1.09,0,4.8,0,3.6,338,884.4,7.5,7.9,291.4,1.15,0,10.1,0,5.7,305.5,937.5,8.6,10.4,309.8,1.1,0,12.1,0,7,304.6,876.5,14.4,10.6,302.95,1.15,0,11.6,0,7.1,298.95,933.7,9.7,7.3,259.2,1.18,0,4.7,0,4,256.5,986,4.6,5.6,325.7,1.08,0,5.6,0,3.5,320.05,875.2,4.6,5.6,325.7,1.08,0,5.6,0,3.5,320.05,875.2


{'anchor': 'loc5_wind_gust > 9.30 AND loc5_wind_speed_100m > 9.00 AND loc3_radiation <= 0.00 AND 1.07 < loc17_air_density <= 1.09',
 'coverage': 0.0083,
 'precision': 1.0,
 'prediction': b'price <= 0.43',
 'record': ['5',
  '12',
  '0.18850000000000003',
  '0.395',
  '0.4743',
  '0.551',
  '0.7715000000000001',
  '9',
  '8.9',
  '303.2',
  '1.16',
  '0',
  '11.2',
  '0',
  '4.7',
  '301.9',
  '948.2',
  '9.1',
  '5.5',
  '274.3',
  '1.14',
  '0',
  '6.2',
  '0',
  '3.2',
  '269.5',
  '926.6',
  '14.4',
  '9.3',
  '249.95',
  '1.22',
  '0',
  '10.3',
  '0',
  '6.1',
  '249',
  '998.3',
  '13.6',
  '10.2',
  '275',
  '1.22',
  '0',
  '11.2',
  '0',
  '7.7',
  '279.05',
  '1016.9',
  '4.2',
  '6',
  '330.65',
  '1.05',
  '0',
  '7.4',
  '0',
  '4.1',
  '335.1',
  '844.4',
  '10.9',
  '10.8',
  '295.4',
  '1.14',
  '0',
  '12',
  '0',
  '7.7',
  '293.4',
  '922.4',
  '16.1',
  '7.9',
  '266.4',
  '1.2',
  '0',
  '8.4',
  '0',
  '5',
  '264.5',
  '984.1',
  '8.2',
  '6.9',
  '288.3',
  '1.2

In [9]:
# Produce a table of explanations for several records at once; with sample=3
# the recorded output has three rows (anchor, coverage, precision, prediction).
# NOTE(review): how the sampled records are chosen is internal to
# anchor_explainer, and no seed is set here -- expect different rows on re-run.
explainer.explain_model(sample=3)

Unnamed: 0,anchor,coverage,precision,prediction
0,hour = 11 AND 1.40 < loc4_wind_speed <= 2.20 A...,0.0035,1.0,b'0.51 < price <= 0.60'
1,day = 0 AND loc1_wind_speed_100m <= 2.10 AND 1...,0.0021,1.0,b'0.43 < price <= 0.51'
2,hour = 6 AND loc1_wind_speed <= 1.20 AND day =...,0.0,1.0,b'0.51 < price <= 0.60'
