From aa86fbce682500ce1b94836f26ae388fcdef2d40 Mon Sep 17 00:00:00 2001
From: Torry Yang
Date: Wed, 18 Dec 2019 15:53:30 -0800
Subject: [PATCH] feat(tables): update samples to show explainability
 [(#2523)](https://github.com/GoogleCloudPlatform/python-docs-samples/issues/2523)

* show xai

* local feature importance

* use updated client

* use fixed library

* use new model
---
 samples/tables/automl_tables_dataset.py | 218 ++++++++++++----------
 samples/tables/automl_tables_model.py   |  84 ++++++---
 samples/tables/automl_tables_predict.py |  65 +++++--
 samples/tables/dataset_test.py          |  34 ++--
 samples/tables/model_test.py            |  20 ++-
 samples/tables/predict_test.py          |  11 +-
 samples/tables/requirements.txt         |   1 +
 7 files changed, 262 insertions(+), 171 deletions(-)
 create mode 100644 samples/tables/requirements.txt

diff --git a/samples/tables/automl_tables_dataset.py b/samples/tables/automl_tables_dataset.py
index d4be12b3..d9970510 100644
--- a/samples/tables/automl_tables_dataset.py
+++ b/samples/tables/automl_tables_dataset.py
@@ -79,23 +79,33 @@ def list_datasets(project_id, compute_region, filter_=None):
         print("Dataset id: {}".format(dataset.name.split("/")[-1]))
         print("Dataset display name: {}".format(dataset.display_name))
         metadata = dataset.tables_dataset_metadata
-        print("Dataset primary table spec id: {}".format(
-            metadata.primary_table_spec_id))
-        print("Dataset target column spec id: {}".format(
-            metadata.target_column_spec_id))
-        print("Dataset target column spec id: {}".format(
-            metadata.target_column_spec_id))
-        print("Dataset weight column spec id: {}".format(
-            metadata.weight_column_spec_id))
-        print("Dataset ml use column spec id: {}".format(
-            metadata.ml_use_column_spec_id))
+        print(
+            "Dataset primary table spec id: {}".format(
+                metadata.primary_table_spec_id
+            )
+        )
+        print(
+            "Dataset target column spec id: {}".format(
+                metadata.target_column_spec_id
+            )
+        )
+        print(
+            "Dataset weight column spec id: {}".format(
+                metadata.weight_column_spec_id
+            )
+        )
+        print(
+            "Dataset ml use column spec id: {}".format(
+                metadata.ml_use_column_spec_id
+            )
+        )
         print("Dataset example count: {}".format(dataset.example_count))
         print("Dataset create time:")
         print("\tseconds: {}".format(dataset.create_time.seconds))
         print("\tnanos: {}".format(dataset.create_time.nanos))
         print("\n")
-    # [END automl_tables_list_datasets]
+        # [END automl_tables_list_datasets]
         result.append(dataset)
 
     return result
@@ -119,28 +129,31 @@ def list_table_specs(
 
     # List all the table specs in the dataset by applying filter.
     response = client.list_table_specs(
-        dataset_display_name=dataset_display_name, filter_=filter_)
+        dataset_display_name=dataset_display_name, filter_=filter_
+    )
 
     print("List of table specs:")
     for table_spec in response:
         # Display the table_spec information.
print("Table spec name: {}".format(table_spec.name)) print("Table spec id: {}".format(table_spec.name.split("/")[-1])) - print("Table spec time column spec id: {}".format( - table_spec.time_column_spec_id)) + print( + "Table spec time column spec id: {}".format( + table_spec.time_column_spec_id + ) + ) print("Table spec row count: {}".format(table_spec.row_count)) print("Table spec column count: {}".format(table_spec.column_count)) - # [END automl_tables_list_specs] + # [END automl_tables_list_specs] result.append(table_spec) return result -def list_column_specs(project_id, - compute_region, - dataset_display_name, - filter_=None): +def list_column_specs( + project_id, compute_region, dataset_display_name, filter_=None +): """List all column specs.""" result = [] # [START automl_tables_list_column_specs] @@ -156,7 +174,8 @@ def list_column_specs(project_id, # List all the table specs in the dataset by applying filter. response = client.list_column_specs( - dataset_display_name=dataset_display_name, filter_=filter_) + dataset_display_name=dataset_display_name, filter_=filter_ + ) print("List of column specs:") for column_spec in response: @@ -166,7 +185,7 @@ def list_column_specs(project_id, print("Column spec display name: {}".format(column_spec.display_name)) print("Column spec data type: {}".format(column_spec.data_type)) - # [END automl_tables_list_column_specs] + # [END automl_tables_list_column_specs] result.append(column_spec) return result @@ -227,19 +246,20 @@ def get_table_spec(project_id, compute_region, dataset_id, table_spec_id): # Display the table spec information. print("Table spec name: {}".format(table_spec.name)) print("Table spec id: {}".format(table_spec.name.split("/")[-1])) - print("Table spec time column spec id: {}".format( - table_spec.time_column_spec_id)) + print( + "Table spec time column spec id: {}".format( + table_spec.time_column_spec_id + ) + ) print("Table spec row count: {}".format(table_spec.row_count)) print("Table spec column count: {}".format(table_spec.column_count)) # [END automl_tables_get_table_spec] -def get_column_spec(project_id, - compute_region, - dataset_id, - table_spec_id, - column_spec_id): +def get_column_spec( + project_id, compute_region, dataset_id, table_spec_id, column_spec_id +): """Get the column spec.""" # [START automl_tables_get_column_spec] # TODO(developer): Uncomment and set the following variables @@ -288,7 +308,7 @@ def import_data(project_id, compute_region, dataset_display_name, path): client = automl.TablesClient(project=project_id, region=compute_region) response = None - if path.startswith('bq'): + if path.startswith("bq"): response = client.import_data( dataset_display_name=dataset_display_name, bigquery_input_uri=path ) @@ -297,7 +317,7 @@ def import_data(project_id, compute_region, dataset_display_name, path): input_uris = path.split(",") response = client.import_data( dataset_display_name=dataset_display_name, - gcs_input_uris=input_uris + gcs_input_uris=input_uris, ) print("Processing import...") @@ -321,8 +341,10 @@ def export_data(project_id, compute_region, dataset_display_name, gcs_uri): client = automl.TablesClient(project=project_id, region=compute_region) # Export the dataset to the output URI. - response = client.export_data(dataset_display_name=dataset_display_name, - gcs_output_uri_prefix=gcs_uri) + response = client.export_data( + dataset_display_name=dataset_display_name, + gcs_output_uri_prefix=gcs_uri, + ) print("Processing export...") # synchronous check of operation status. 
@@ -331,12 +348,14 @@ def export_data(project_id, compute_region, dataset_display_name, gcs_uri):
     # [END automl_tables_export_data]
 
 
-def update_dataset(project_id,
-                   compute_region,
-                   dataset_display_name,
-                   target_column_spec_name=None,
-                   weight_column_spec_name=None,
-                   test_train_column_spec_name=None):
+def update_dataset(
+    project_id,
+    compute_region,
+    dataset_display_name,
+    target_column_spec_name=None,
+    weight_column_spec_name=None,
+    test_train_column_spec_name=None,
+):
     """Update dataset."""
     # [START automl_tables_update_dataset]
     # TODO(developer): Uncomment and set the following variables
@@ -354,29 +373,31 @@ def update_dataset(project_id,
     if target_column_spec_name is not None:
         response = client.set_target_column(
             dataset_display_name=dataset_display_name,
-            column_spec_display_name=target_column_spec_name
+            column_spec_display_name=target_column_spec_name,
         )
         print("Target column updated. {}".format(response))
     if weight_column_spec_name is not None:
         response = client.set_weight_column(
             dataset_display_name=dataset_display_name,
-            column_spec_display_name=weight_column_spec_name
+            column_spec_display_name=weight_column_spec_name,
        )
         print("Weight column updated. {}".format(response))
     if test_train_column_spec_name is not None:
         response = client.set_test_train_column(
             dataset_display_name=dataset_display_name,
-            column_spec_display_name=test_train_column_spec_name
+            column_spec_display_name=test_train_column_spec_name,
         )
         print("Test/train column updated. {}".format(response))
 
     # [END automl_tables_update_dataset]
 
 
-def update_table_spec(project_id,
-                      compute_region,
-                      dataset_display_name,
-                      time_column_spec_display_name):
+def update_table_spec(
+    project_id,
+    compute_region,
+    dataset_display_name,
+    time_column_spec_display_name,
+):
     """Update table spec."""
     # [START automl_tables_update_table_spec]
     # TODO(developer): Uncomment and set the following variables
@@ -391,7 +412,7 @@ def update_table_spec(project_id,
 
     response = client.set_time_column(
         dataset_display_name=dataset_display_name,
-        column_spec_display_name=time_column_spec_display_name
+        column_spec_display_name=time_column_spec_display_name,
     )
 
     # synchronous check of operation status.
@@ -399,12 +420,14 @@
-def update_column_spec(project_id,
-                       compute_region,
-                       dataset_display_name,
-                       column_spec_display_name,
-                       type_code,
-                       nullable=None):
+def update_column_spec(
+    project_id,
+    compute_region,
+    dataset_display_name,
+    column_spec_display_name,
+    type_code,
+    nullable=None,
+):
     """Update column spec."""
     # [START automl_tables_update_column_spec]
     # TODO(developer): Uncomment and set the following variables
@@ -423,7 +446,8 @@
     response = client.update_column_spec(
         dataset_display_name=dataset_display_name,
         column_spec_display_name=column_spec_display_name,
-        type_code=type_code, nullable=nullable
+        type_code=type_code,
+        nullable=nullable,
     )
 
     # synchronous check of operation status.
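These update helpers are what the tests later in this patch use to prepare the bank-marketing dataset. A hedged sketch of composing them — `Deposit`, `Balance`, and `Job` are columns of that sample table; the project and dataset names are placeholders:

```python
# Sketch: designate target and weight columns, then retype a column.
# Column names come from the bank-marketing sample used by the tests;
# project and dataset names are placeholders.
from automl_tables_dataset import update_column_spec, update_dataset

update_dataset(
    "my-project",
    "us-central1",
    "my_dataset",
    target_column_spec_name="Deposit",
    weight_column_spec_name="Balance",
)
update_column_spec(
    "my-project",
    "us-central1",
    "my_dataset",
    column_spec_display_name="Job",
    type_code="CATEGORY",
    nullable=False,
)
```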
@@ -546,56 +570,62 @@ def delete_dataset(project_id, compute_region, dataset_display_name):
     if args.command == "list_datasets":
         list_datasets(project_id, compute_region, args.filter_)
     if args.command == "list_table_specs":
-        list_table_specs(project_id,
-                         compute_region,
-                         args.dataset_display_name,
-                         args.filter_)
+        list_table_specs(
+            project_id, compute_region, args.dataset_display_name, args.filter_
+        )
     if args.command == "list_column_specs":
-        list_column_specs(project_id,
-                          compute_region,
-                          args.dataset_display_name,
-                          args.filter_)
+        list_column_specs(
+            project_id, compute_region, args.dataset_display_name, args.filter_
+        )
     if args.command == "get_dataset":
         get_dataset(project_id, compute_region, args.dataset_display_name)
     if args.command == "get_table_spec":
-        get_table_spec(project_id,
-                       compute_region,
-                       args.dataset_display_name,
-                       args.table_spec_id)
+        get_table_spec(
+            project_id,
+            compute_region,
+            args.dataset_display_name,
+            args.table_spec_id,
+        )
     if args.command == "get_column_spec":
-        get_column_spec(project_id,
-                        compute_region,
-                        args.dataset_display_name,
-                        args.table_spec_id,
-                        args.column_spec_id)
+        get_column_spec(
+            project_id,
+            compute_region,
+            args.dataset_display_name,
+            args.table_spec_id,
+            args.column_spec_id,
+        )
     if args.command == "import_data":
-        import_data(project_id,
-                    compute_region,
-                    args.dataset_display_name,
-                    args.path)
+        import_data(
+            project_id, compute_region, args.dataset_display_name, args.path
+        )
     if args.command == "export_data":
-        export_data(project_id,
-                    compute_region,
-                    args.dataset_display_name,
-                    args.gcs_uri)
+        export_data(
+            project_id, compute_region, args.dataset_display_name, args.gcs_uri
+        )
     if args.command == "update_dataset":
-        update_dataset(project_id,
-                       compute_region,
-                       args.dataset_display_name,
-                       args.target_column_spec_name,
-                       args.weight_column_spec_name,
-                       args.ml_use_column_spec_name)
+        update_dataset(
+            project_id,
+            compute_region,
+            args.dataset_display_name,
+            args.target_column_spec_name,
+            args.weight_column_spec_name,
+            args.ml_use_column_spec_name,
+        )
     if args.command == "update_table_spec":
-        update_table_spec(project_id,
-                          compute_region,
-                          args.dataset_display_name,
-                          args.time_column_spec_display_name)
+        update_table_spec(
+            project_id,
+            compute_region,
+            args.dataset_display_name,
+            args.time_column_spec_display_name,
+        )
     if args.command == "update_column_spec":
-        update_column_spec(project_id,
-                           compute_region,
-                           args.dataset_display_name,
-                           args.column_spec_display_name,
-                           args.type_code,
-                           args.nullable)
+        update_column_spec(
+            project_id,
+            compute_region,
+            args.dataset_display_name,
+            args.column_spec_display_name,
+            args.type_code,
+            args.nullable,
+        )
     if args.command == "delete_dataset":
         delete_dataset(project_id, compute_region, args.dataset_display_name)
diff --git a/samples/tables/automl_tables_model.py b/samples/tables/automl_tables_model.py
index cb8d85dd..a77dfe62 100644
--- a/samples/tables/automl_tables_model.py
+++ b/samples/tables/automl_tables_model.py
@@ -25,13 +25,15 @@
 import os
 
 
-def create_model(project_id,
-                 compute_region,
-                 dataset_display_name,
-                 model_display_name,
-                 train_budget_milli_node_hours,
-                 include_column_spec_names=None,
-                 exclude_column_spec_names=None):
+def create_model(
+    project_id,
+    compute_region,
+    dataset_display_name,
+    model_display_name,
+    train_budget_milli_node_hours,
+    include_column_spec_names=None,
+    exclude_column_spec_names=None,
+):
     """Create a model."""
     # [START automl_tables_create_model]
     # TODO(developer): Uncomment and set the following variables
@@ -116,19 +118,28 @@ def list_models(project_id, compute_region, filter_=None):
         print("Model id: {}".format(model.name.split("/")[-1]))
         print("Model display name: {}".format(model.display_name))
         metadata = model.tables_model_metadata
-        print("Target column display name: {}".format(
-            metadata.target_column_spec.display_name))
-        print("Training budget in node milli hours: {}".format(
-            metadata.train_budget_milli_node_hours))
-        print("Training cost in node milli hours: {}".format(
-            metadata.train_cost_milli_node_hours))
+        print(
+            "Target column display name: {}".format(
+                metadata.target_column_spec.display_name
+            )
+        )
+        print(
+            "Training budget in node milli hours: {}".format(
+                metadata.train_budget_milli_node_hours
+            )
+        )
+        print(
+            "Training cost in node milli hours: {}".format(
+                metadata.train_cost_milli_node_hours
+            )
+        )
         print("Model create time:")
         print("\tseconds: {}".format(model.create_time.seconds))
         print("\tnanos: {}".format(model.create_time.nanos))
         print("Model deployment state: {}".format(deployment_state))
         print("\n")
-    # [END automl_tables_list_models]
+        # [END automl_tables_list_models]
         result.append(model)
 
     return result
@@ -156,12 +167,24 @@ def get_model(project_id, compute_region, model_display_name):
     else:
         deployment_state = "undeployed"
 
+    # get features of top importance
+    feat_list = [
+        (column.feature_importance, column.column_display_name)
+        for column in model.tables_model_metadata.tables_model_column_info
+    ]
+    feat_list.sort(reverse=True)
+    if len(feat_list) < 10:
+        feat_to_show = len(feat_list)
+    else:
+        feat_to_show = 10
+
     # Display the model information.
     print("Model name: {}".format(model.name))
     print("Model id: {}".format(model.name.split("/")[-1]))
     print("Model display name: {}".format(model.display_name))
-    print("Model metadata:")
-    print(model.tables_model_metadata)
+    print("Features of top importance:")
+    for feat in feat_list[:feat_to_show]:
+        print(feat)
     print("Model create time:")
     print("\tseconds: {}".format(model.create_time.seconds))
     print("\tnanos: {}".format(model.create_time.nanos))
@@ -191,21 +214,23 @@ def list_model_evaluations(
 
     # List all the model evaluations in the model by applying filter.
     response = client.list_model_evaluations(
-        model_display_name=model_display_name,
-        filter_=filter_
+        model_display_name=model_display_name, filter_=filter_
     )
 
     print("List of model evaluations:")
     for evaluation in response:
         print("Model evaluation name: {}".format(evaluation.name))
         print("Model evaluation id: {}".format(evaluation.name.split("/")[-1]))
-        print("Model evaluation example count: {}".format(
-            evaluation.evaluated_example_count))
+        print(
+            "Model evaluation example count: {}".format(
+                evaluation.evaluated_example_count
+            )
+        )
         print("Model evaluation time:")
         print("\tseconds: {}".format(evaluation.create_time.seconds))
         print("\tnanos: {}".format(evaluation.create_time.nanos))
         print("\n")
-    # [END automl_tables_list_model_evaluations]
+        # [END automl_tables_list_model_evaluations]
         result.append(evaluation)
 
     return result
@@ -304,12 +329,15 @@ def display_evaluation(
     regression_metrics = model_evaluation.regression_evaluation_metrics
     if str(regression_metrics):
         print("Model regression metrics:")
-        print("Model RMSE: {}".format(
-            regression_metrics.root_mean_squared_error
-        ))
+        print(
+            "Model RMSE: {}".format(regression_metrics.root_mean_squared_error)
+        )
         print("Model MAE: {}".format(regression_metrics.mean_absolute_error))
-        print("Model MAPE: {}".format(
-            regression_metrics.mean_absolute_percentage_error))
+        print(
+            "Model MAPE: {}".format(
+                regression_metrics.mean_absolute_percentage_error
+            )
+        )
         print("Model R^2: {}".format(regression_metrics.r_squared))
 
     # [END automl_tables_display_evaluation]
@@ -391,7 +419,7 @@ def delete_model(project_id, compute_region, model_display_name):
     create_model_parser.add_argument("--dataset_display_name")
     create_model_parser.add_argument("--model_display_name")
     create_model_parser.add_argument(
-        "--train_budget_milli_node_hours", type=int,
+        "--train_budget_milli_node_hours", type=int
     )
 
     get_operation_status_parser = subparsers.add_parser(
@@ -472,7 +500,7 @@
             project_id,
             compute_region,
             args.model_display_name,
-            args.model_evaluation_id
+            args.model_evaluation_id,
         )
     if args.command == "display_evaluation":
         display_evaluation(
diff --git a/samples/tables/automl_tables_predict.py b/samples/tables/automl_tables_predict.py
index 069faac1..4a3423e3 100644
--- a/samples/tables/automl_tables_predict.py
+++ b/samples/tables/automl_tables_predict.py
@@ -25,10 +25,13 @@
 import os
 
 
-def predict(project_id,
-            compute_region,
-            model_display_name,
-            inputs):
+def predict(
+    project_id,
+    compute_region,
+    model_display_name,
+    inputs,
+    feature_importance=None,
+):
     """Make a prediction."""
     # [START automl_tables_predict]
     # TODO(developer): Uncomment and set the following variables
@@ -41,23 +44,50 @@ def predict(project_id,
 
     client = automl.TablesClient(project=project_id, region=compute_region)
 
-    response = client.predict(
-        model_display_name=model_display_name,
-        inputs=inputs)
+    if feature_importance:
+        response = client.predict(
+            model_display_name=model_display_name,
+            inputs=inputs,
+            feature_importance=True,
+        )
+    else:
+        response = client.predict(
+            model_display_name=model_display_name, inputs=inputs
+        )
+
     print("Prediction results:")
     for result in response.payload:
-        print("Predicted class name: {}".format(result.display_name))
-        print("Predicted class score: {}".format(
-            result.classification.score))
+        print(
+            "Predicted class name: {}".format(result.tables.value.string_value)
+        )
+        print("Predicted class score: {}".format(result.tables.score))
+
+        if feature_importance:
+            # get features of top importance
+            feat_list = [
+                (column.feature_importance, column.column_display_name)
+                for column in result.tables.tables_model_column_info
+            ]
+            feat_list.sort(reverse=True)
+            if len(feat_list) < 10:
+                feat_to_show = len(feat_list)
+            else:
+                feat_to_show = 10
+
+            print("Features of top importance:")
+            for feat in feat_list[:feat_to_show]:
+                print(feat)
 
     # [END automl_tables_predict]
 
 
-def batch_predict(project_id,
-                  compute_region,
-                  model_display_name,
-                  gcs_input_uris,
-                  gcs_output_uri):
+def batch_predict(
+    project_id,
+    compute_region,
+    model_display_name,
+    gcs_input_uris,
+    gcs_output_uri,
+):
     """Make a batch of predictions."""
     # [START automl_tables_batch_predict]
     # TODO(developer): Uncomment and set the following variables
@@ -107,10 +137,7 @@ def batch_predict(
 
     if args.command == "predict":
         predict(
-            project_id,
-            compute_region,
-            args.model_display_name,
-            args.file_path,
+            project_id, compute_region, args.model_display_name, args.file_path
         )
 
     if args.command == "batch_predict":
diff --git a/samples/tables/dataset_test.py b/samples/tables/dataset_test.py
index 0b031850..f25239fb 100644
--- a/samples/tables/dataset_test.py
+++ b/samples/tables/dataset_test.py
@@ -26,7 +26,7 @@
 PROJECT = os.environ["GCLOUD_PROJECT"]
 REGION = "us-central1"
 
-STATIC_DATASET = "test_dataset_do_not_delete"
+STATIC_DATASET = "do_not_delete_this_table"
 GCS_DATASET = "gs://cloud-ml-tables-data/bank-marketing.csv"
 
 ID = "{rand}_{time}".format(
@@ -50,16 +50,14 @@ def ensure_dataset_ready():
         dataset = automl_tables_dataset.create_dataset(PROJECT, REGION, name)
 
     if dataset.example_count is None or dataset.example_count == 0:
-        automl_tables_dataset.import_data(
-            PROJECT, REGION, name, GCS_DATASET
-        )
+        automl_tables_dataset.import_data(PROJECT, REGION, name, GCS_DATASET)
         dataset = automl_tables_dataset.get_dataset(PROJECT, REGION, name)
 
     automl_tables_dataset.update_dataset(
         PROJECT,
         REGION,
         dataset.display_name,
-        target_column_spec_name='Deposit',
+        target_column_spec_name="Deposit",
     )
 
     return dataset
@@ -72,9 +70,7 @@ def test_dataset_create_import_delete(capsys):
     assert dataset is not None
     assert dataset.display_name == name
 
-    automl_tables_dataset.import_data(
-        PROJECT, REGION, name, GCS_DATASET
-    )
+    automl_tables_dataset.import_data(PROJECT, REGION, name, GCS_DATASET)
     out, _ = capsys.readouterr()
     assert "Data imported." in out
@@ -91,8 +87,8 @@ def test_dataset_update(capsys):
         PROJECT,
         REGION,
         dataset.display_name,
-        target_column_spec_name='Deposit',
-        weight_column_spec_name='Balance'
+        target_column_spec_name="Deposit",
+        weight_column_spec_name="Balance",
     )
 
     out, _ = capsys.readouterr()
@@ -106,9 +102,9 @@ def test_column_update(capsys):
         PROJECT,
         REGION,
         dataset.display_name,
-        column_spec_display_name='Job',
-        type_code='CATEGORY',
-        nullable=False
+        column_spec_display_name="Job",
+        type_code="CATEGORY",
+        nullable=False,
     )
 
     out, _ = capsys.readouterr()
@@ -117,13 +113,17 @@ def test_list_datasets():
     ensure_dataset_ready()
-    assert next(
+    assert (
+        next(
             (
                 d
-                for d
-                in automl_tables_dataset.list_datasets(PROJECT, REGION)
+                for d in automl_tables_dataset.list_datasets(PROJECT, REGION)
                 if d.display_name == STATIC_DATASET
-            ), None) is not None
+            ),
+            None,
+        )
+        is not None
+    )
 
 
 def test_list_table_specs():
diff --git a/samples/tables/model_test.py b/samples/tables/model_test.py
index dde7d1f6..11afdead 100644
--- a/samples/tables/model_test.py
+++ b/samples/tables/model_test.py
@@ -26,7 +26,7 @@
 PROJECT = os.environ["GCLOUD_PROJECT"]
 REGION = "us-central1"
 
-STATIC_MODEL = "test_model_do_not_delete"
+STATIC_MODEL = "do_not_delete_this_model_0"
 GCS_DATASET = "gs://cloud-ml-tables-data/bank-marketing.csv"
 
 ID = "{rand}_{time}".format(
@@ -43,13 +43,17 @@ def _id(name):
 
 def test_list_models():
     ensure_model_ready()
-    assert next(
+    assert (
+        next(
             (
                 m
-                for m
-                in automl_tables_model.list_models(PROJECT, REGION)
+                for m in automl_tables_model.list_models(PROJECT, REGION)
                 if m.display_name == STATIC_MODEL
-            ), None) is not None
+            ),
+            None,
+        )
+        is not None
+    )
 
 
 def test_list_model_evaluations():
@@ -70,8 +74,8 @@ def test_get_model_evaluations():
     mep = automl_tables_model.get_model_evaluation(
         PROJECT,
         REGION,
-        model.name.rpartition('/')[2],
-        me.name.rpartition('/')[2]
+        model.name.rpartition("/")[2],
+        me.name.rpartition("/")[2],
     )
 
     assert mep.name == me.name
@@ -86,5 +90,5 @@ def ensure_model_ready():
     dataset = dataset_test.ensure_dataset_ready()
     return automl_tables_model.create_model(
-       PROJECT, REGION, dataset.display_name, name, 1000
+        PROJECT, REGION, dataset.display_name, name, 1000
     )
diff --git a/samples/tables/predict_test.py b/samples/tables/predict_test.py
index 5f65be2f..4ded386e 100644
--- a/samples/tables/predict_test.py
+++ b/samples/tables/predict_test.py
@@ -33,7 +33,7 @@ def test_predict(capsys):
         "Balance": 200,
         "Campaign": 2,
         "Contact": "cellular",
-        "Day": 4,
+        "Day": "4",
         "Default": "no",
         "Duration": 12,
         "Education": "primary",
@@ -43,15 +43,16 @@ def test_predict(capsys):
         "MaritalStatus": "divorced",
         "Month": "jul",
         "PDays": 4,
-        "POutcome": '0',
+        "POutcome": "0",
         "Previous": 12,
     }
 
     ensure_model_online()
-    automl_tables_predict.predict(PROJECT, REGION, STATIC_MODEL, inputs)
+    automl_tables_predict.predict(PROJECT, REGION, STATIC_MODEL, inputs, True)
 
     out, _ = capsys.readouterr()
-    assert 'Predicted class name:' in out
-    assert 'Predicted class score:' in out
+    assert "Predicted class name:" in out
+    assert "Predicted class score:" in out
+    assert "Features of top importance:" in out
 
 
 def ensure_model_online():
diff --git a/samples/tables/requirements.txt b/samples/tables/requirements.txt
new file mode 100644
index 00000000..77734522
--- /dev/null
+++ b/samples/tables/requirements.txt
@@ -0,0 +1 @@
+google-cloud-automl==0.9