In [1]:
# One-time setup: uncomment to install the pinned explainer package (presumably
# provides the `anchor_explainer` module imported below — confirm).
# %pip installs into the running kernel's environment, unlike !pip.
#%pip install anchor_exp==0.0.0.5

In [3]:
# One-time setup: uncomment to authenticate gcloud with a service-account key file.
# Keep the key file out of version control and out of shared notebook outputs.
#!gcloud auth activate-service-account --key-file energy-forecasting.json

In [5]:
import itertools
import subprocess

import anchor_explainer as anchor

In [6]:
# Column names for MLDataTrain.csv; they must match the CSV column order because
# they are passed as feature_names to explainer.load_data (index 0, 'price', is the target).
base_names = [
    'price',
    'date_utc',
    'day',
    'hour',
    'prev_week_min',
    'prev_week_25th',
    'prev_week_50th',
    'prev_week_75th',
    'prev_week_max',
]

# Weather variables, repeated once per location in the column list.
weather_vars = [
    'temperature',
    'wind_speed_100m',
    'wind_direction_100m',
    'air_density',
    'precipitation',
    'wind_gust',
    'radiation',
    'wind_speed',
    'wind_direction',
    'pressure',
]
N_WEATHER_LOCATIONS = 18  # locations loc0 .. loc17

# Location-major order: every weather variable for loc0, then loc1, ...
# (e.g. 'loc0_temperature', ..., 'loc0_pressure', 'loc1_temperature', ...).
# Built in one pass so `w_names` has a single meaning: the flat list of weather column names.
w_names = [
    'loc{}_{}'.format(loc, var)
    for loc in range(N_WEATHER_LOCATIONS)
    for var in weather_vars
]

names = base_names + w_names

# A duplicated or mistyped name would silently mislabel features in the explanations.
assert len(set(names)) == len(names), 'duplicate column names'

In [7]:
# Create the explainer. Training data (load_data) and the deployed-model client
# (create_cmle_client) are attached to it in the following cells.
explainer = anchor.Explainer()

In [8]:
# Load the training CSV from GCS into the explainer.
# Indices are derived from `names` instead of hard-coded (previously 0, range(2, 189)
# and [2, 3]) so they stay in sync if the column list changes:
#   - target: 'price'
#   - features: every column from 'day' onward, i.e. all but 'price' and 'date_utc'
#   - categorical: 'day' and 'hour'
# NOTE(review): skip_first presumably skips the CSV header row — confirm in anchor_explainer.
explainer.load_data(
    gcs_path='gs://energyforecast/data/csv/MLDataTrain.csv',
    target_idx=names.index('price'),
    features_to_use=range(names.index('day'), len(names)),
    categorical_features=[names.index('day'), names.index('hour')],
    feature_names=names,
    skip_first=True,
)

In [9]:
# Fetch an access token for the CMLE client (token[0] is used by create_cmle_client).
# subprocess is used instead of `token = !gcloud ...` because IPython's `!` capture
# merges stderr into the result, so a gcloud warning could end up in token[0].
# check=True raises CalledProcessError if gcloud fails, instead of passing its
# error text on as the token.
token = subprocess.run(
    ['gcloud', 'auth', 'print-access-token'],
    stdout=subprocess.PIPE,
    check=True,
    universal_newlines=True,
).stdout.strip().splitlines()

if not token:
    raise RuntimeError('gcloud auth print-access-token returned no token')

In [10]:
# Connect the explainer to the deployed Cloud ML Engine model version used for predictions.
explainer.create_cmle_client(
    gcp_project='energy-forecasting',
    gcp_model='energyforecaster',
    gcp_model_version='new_energy',
    # NOTE(review): presumably pads each record with 2 leading / 0 trailing columns so it
    # matches the model's input layout (features start at column 2) — confirm in anchor_explainer.
    padding=(2, 0),
    access_token=token[0],
)

In [11]:
# Ask the explainer to assess the connected model; the result is a dict of metrics
# (e.g. {'accuracy': ...}) shown as the cell output.
model_assessment = explainer.assess_model()
model_assessment

{'accuracy': 0.66}

In [12]:
# Explain one randomly chosen record. Nothing in this notebook seeds the draw, so the
# record (and the what-if cell that reuses it) may differ between runs.
rand_expl = explainer.explain_random_record()
rand_expl

{'anchor': 'day = 0 AND loc8_wind_speed_100m <= 2.20 AND loc0_radiation > 330.55 AND 3.80 < loc2_wind_speed_100m <= 5.80',
 'coverage': 0.0,
 'precision': 1.0,
 'prediction': b'0.43 < price <= 0.51',
 'record': ['0',
  '13',
  0.44320000000000004,
  0.51,
  0.6089,
  0.64295,
  0.7000000000000001,
  20.2,
  6.1,
  226.3,
  1.14,
  0.3,
  7.1,
  598.9,
  4.7,
  222,
  952.6,
  23.7,
  3.7,
  274.3,
  1.1,
  0.3,
  3.3,
  631.95,
  3.2,
  269.5,
  932.8,
  33.5,
  4.8,
  305.1,
  1.14,
  0.3,
  4.1,
  683.9,
  4.2,
  299.1,
  1001,
  25.6,
  2.9,
  275,
  1.19,
  0.3,
  2.7,
  718.6500000000001,
  3.9,
  279.05,
  1014.5,
  18.4,
  2.5,
  330.65,
  1.03,
  0.3,
  2.3,
  153.75,
  1.9,
  335.1,
  853.5,
  31.6,
  5.25,
  295.4,
  1.06,
  0.3,
  4.4,
  660.3,
  5.2,
  293.4,
  928.7,
  37.3,
  5.4,
  266.4,
  1.11,
  0.3,
  4.6,
  699.8499999999999,
  5,
  264.5,
  989.7,
  35.6,
  3.1,
  288.3,
  1.15,
  0.3,
  1.6,
  695.65,
  2.7,
  285.1,
  1012.3,
  22.9,
  1.5,
  16.2,
  1.07,
  0.3,

In [14]:
# What-if: reuse the random record explained above but override its calendar features,
# then compare the anchor and predicted price band.
# Record positions follow features_to_use, so index 0 is 'day' and index 1 is 'hour';
# categorical values are strings, as in the random record.
WHAT_IF_DAY = '5'
WHAT_IF_HOUR = '12'

mod_record = [WHAT_IF_DAY, WHAT_IF_HOUR] + rand_expl['record'][2:]

# No separate print: the returned explanation already echoes the full record under 'record'.
explainer.explain_record(mod_record)

['5', '12', 0.44320000000000004, 0.51, 0.6089, 0.64295, 0.7000000000000001, 20.2, 6.1, 226.3, 1.14, 0.3, 7.1, 598.9, 4.7, 222, 952.6, 23.7, 3.7, 274.3, 1.1, 0.3, 3.3, 631.95, 3.2, 269.5, 932.8, 33.5, 4.8, 305.1, 1.14, 0.3, 4.1, 683.9, 4.2, 299.1, 1001, 25.6, 2.9, 275, 1.19, 0.3, 2.7, 718.6500000000001, 3.9, 279.05, 1014.5, 18.4, 2.5, 330.65, 1.03, 0.3, 2.3, 153.75, 1.9, 335.1, 853.5, 31.6, 5.25, 295.4, 1.06, 0.3, 4.4, 660.3, 5.2, 293.4, 928.7, 37.3, 5.4, 266.4, 1.11, 0.3, 4.6, 699.8499999999999, 5, 264.5, 989.7, 35.6, 3.1, 288.3, 1.15, 0.3, 1.6, 695.65, 2.7, 285.1, 1012.3, 22.9, 1.5, 16.2, 1.07, 0.3, 1.4, 600, 2.3, 16.2, 910.2, 27.2, 4.7, 263.4, 1.03, 0.3, 4.3, 660, 3.4, 307.7, 881.5, 37.6, 5.7, 254.1, 1.06, 0.3, 4.9, 686.85, 6.3, 251.4, 945.1, 30.2, 2.6, 203.2, 1.02, 0.3, 2.4, 683.3, 1.9, 203.4, 890.1, 22.5, 3, 17.5, 1.11, 0.3, 1.6, 550, 2.4, 19.6, 942.8, 27.2, 7, 233.85000000000002, 1.03, 0.3, 6.8, 650, 4.4, 231.8, 881.6, 33.6, 3, 302.95, 1.07, 0.3, 5.3, 679, 3.4, 298.95, 937.5, 40.4

{'anchor': 'hour = 12 AND 908.50 < loc8_pressure <= 911.70 AND 5.80 < loc13_wind_speed_100m <= 8.20 AND 1.80 < loc16_wind_speed_100m <= 2.80',
 'coverage': 0.0,
 'precision': 1.0,
 'prediction': b'0.51 < price <= 0.60',
 'record': ['5',
  '12',
  0.44320000000000004,
  0.51,
  0.6089,
  0.64295,
  0.7000000000000001,
  20.2,
  6.1,
  226.3,
  1.14,
  0.3,
  7.1,
  598.9,
  4.7,
  222,
  952.6,
  23.7,
  3.7,
  274.3,
  1.1,
  0.3,
  3.3,
  631.95,
  3.2,
  269.5,
  932.8,
  33.5,
  4.8,
  305.1,
  1.14,
  0.3,
  4.1,
  683.9,
  4.2,
  299.1,
  1001,
  25.6,
  2.9,
  275,
  1.19,
  0.3,
  2.7,
  718.6500000000001,
  3.9,
  279.05,
  1014.5,
  18.4,
  2.5,
  330.65,
  1.03,
  0.3,
  2.3,
  153.75,
  1.9,
  335.1,
  853.5,
  31.6,
  5.25,
  295.4,
  1.06,
  0.3,
  4.4,
  660.3,
  5.2,
  293.4,
  928.7,
  37.3,
  5.4,
  266.4,
  1.11,
  0.3,
  4.6,
  699.8499999999999,
  5,
  264.5,
  989.7,
  35.6,
  3.1,
  288.3,
  1.15,
  0.3,
  1.6,
  695.65,
  2.7,
  285.1,
  1012.3,
  22.9,
  1.5,
  

In [9]:
# Explain a small sample of 3 records; the result is a table with one row per record
# (anchor, coverage, precision, prediction).
# NOTE(review): the sample is not seeded, and this cell's saved output has execution count 9,
# i.e. it came from an earlier kernel session than the cells above — re-run top to bottom.
model_explanations = explainer.explain_model(sample=3)
model_explanations

Unnamed: 0,anchor,coverage,precision,prediction
0,hour = 11 AND 1.40 < loc4_wind_speed <= 2.20 A...,0.0035,1.0,b'0.51 < price <= 0.60'
1,day = 0 AND loc1_wind_speed_100m <= 2.10 AND 1...,0.0021,1.0,b'0.43 < price <= 0.51'
2,hour = 6 AND loc1_wind_speed <= 1.20 AND day =...,0.0,1.0,b'0.51 < price <= 0.60'
