diff --git a/samples/snippets/classification_boosted_tree_model_test.py b/samples/snippets/classification_boosted_tree_model_test.py index fbc9369dde..707ce16279 100644 --- a/samples/snippets/classification_boosted_tree_model_test.py +++ b/samples/snippets/classification_boosted_tree_model_test.py @@ -14,7 +14,7 @@ def test_boosted_tree_model(random_model_id: str) -> None: - # your_model_id = random_model_id + your_model_id = random_model_id # [START bigquery_dataframes_bqml_boosted_tree_prepare] import bigframes.pandas as bpd @@ -39,4 +39,28 @@ def test_boosted_tree_model(random_model_id: str) -> None: ) del input_data["functional_weight"] # [END bigquery_dataframes_bqml_boosted_tree_prepare] + # [START bigquery_dataframes_bqml_boosted_tree_create] + from bigframes.ml import ensemble + + # input_data is defined in an earlier step. + training_data = input_data[input_data["dataframe"] == "training"] + X = training_data.drop(columns=["income_bracket", "dataframe"]) + y = training_data["income_bracket"] + + # create and train the model + census_model = ensemble.XGBClassifier( + n_estimators=1, + booster="gbtree", + tree_method="hist", + max_iterations=1, # For a more accurate model, try 50 iterations. + subsample=0.85, + ) + census_model.fit(X, y) + + census_model.to_gbq( + your_model_id, # For example: "your-project.census.census_model" + replace=True, + ) + # [END bigquery_dataframes_bqml_boosted_tree_create] assert input_data is not None + assert census_model is not None