## Project Overview

This notebook implements The Talking Tree. It trains a boosted tree model with BigQuery ML on the census_adult_income dataset to predict whether an individual earns more than $50K per year. Using BigQuery’s ML.EXPLAIN_PREDICT, it extracts the feature attributions behind a prediction, then uses ML.GENERATE_TEXT to turn them into a human-readable explanation (e.g., “High income predicted due to advanced education and long work hours”). The code runs entirely within BigQuery’s ecosystem, showcasing its ML and generative AI capabilities for interpretable AI.



## Setup

Create a new dedicated GCP project with a name like `treetalk`.

**Create a Service Account:**
Go to `IAM & Admin > Service Accounts`.
Click `+ Create Service Account` at the top. Name it `treetalk`.
Grant it the `BigQuery Admin` and `Vertex AI Administrator` roles.

**Generate the JSON Key:**
On the Service Accounts page, find your new service account (`treetalk@...`).
Click the account, then go to the Keys tab.
Click `Add Key > Create new key`.
Select JSON as the key type and click Create.
A JSON file will download to your computer. This is your service account key. Store it securely.


**Upload the JSON Key as a Dataset:**
In the notebook, locate the Data panel on the right and click Upload to create a new dataset. Name it something like “Service Account Key” and keep it private for security. Drag and drop your JSON key file (e.g., treetalk-abcdef123456.json), or click to browse and select it, then click Create. The dataset appears under Data > Your Datasets in the notebook sidebar, mounted at a path such as /kaggle/input/sa-key/. Set `sa_key_file_path` in the config cell below to the full path of the key file.
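
Before pointing the notebook at the uploaded file, it can help to confirm that it really is a service account key. The following is a minimal sketch and not part of the original notebook: the helper name `check_sa_key` is hypothetical, and the field list is only the subset of a standard service account key that this sketch assumes is needed.

```python
import json
from pathlib import Path

# Fields this sketch assumes every service account key file contains.
REQUIRED_FIELDS = ("type", "project_id", "private_key", "client_email")


def check_sa_key(path):
    """Return a list of problems found in a key file (empty list means it looks fine)."""
    key_path = Path(path)
    if not key_path.is_file():
        return [f"file not found: {key_path}"]
    try:
        data = json.loads(key_path.read_text())
    except json.JSONDecodeError as exc:
        return [f"not valid JSON: {exc}"]
    if not isinstance(data, dict):
        return ["top-level JSON value is not an object"]
    # Report every missing field rather than stopping at the first one.
    problems = [f"missing field: {name}" for name in REQUIRED_FIELDS if name not in data]
    if data.get("type") != "service_account":
        problems.append("field 'type' is not 'service_account'")
    return problems
```

Once `sa_key_file_path` is set, calling `check_sa_key(sa_key_file_path)` and printing the result gives a quick sanity check before any credentials are loaded.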

In [None]:
# Configs
sa_key_file_path = '/kaggle/input/sa-key/treetalk-470016-9e5b0e9489cd.json'
gcp_region = 'US'
llm_model = 'gemini-2.0-flash'

In [None]:
%%capture --no-stderr

!pip install google-cloud-bigquery-storage


from google.cloud import bigquery
from google.oauth2 import service_account
from google.cloud import resourcemanager_v3
from google.api_core import exceptions
import pandas as pd
import warnings
import json
import numpy as np
import os

# Treat UserWarning as an error so that no warning goes unnoticed
warnings.simplefilter('error', UserWarning)

In [None]:
credentials = service_account.Credentials.from_service_account_file(
    sa_key_file_path,
    scopes=['https://www.googleapis.com/auth/cloud-platform']
)

gcp_project = credentials.project_id
bq_client = bigquery.Client(credentials=credentials, project=gcp_project)
print(gcp_project)

## Step 1: Explore the Dataset

Query the `census_adult_income` dataset to understand its structure and select a sample record for prediction.

In [None]:
# Query to preview the dataset
query_explore = """
SELECT *
FROM `bigquery-public-data.ml_datasets.census_adult_income`
LIMIT 5
"""
df_explore = bq_client.query(query_explore).to_dataframe()
print(df_explore)

In [None]:
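# Select one sample record by row offset for prediction and explanation.
# Note: ORDER BY (SELECT NULL) imposes no real ordering, so BigQuery may
# return a different row for the same offset on different runs.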
sample_index = 23300
query_sample = f"""
SELECT *
FROM `bigquery-public-data.ml_datasets.census_adult_income`
ORDER BY (SELECT NULL)
LIMIT 1 OFFSET {sample_index}
"""
df_sample = bq_client.query(query_sample).to_dataframe()
print("\nSample Record for Prediction:")
print(df_sample)
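
If the same record needs to come back every time the offset is used, one option is to order by a hash of each row's content. The sketch below is an assumption-laden alternative, not what the notebook itself does: `build_sample_query` is a hypothetical helper, while `FARM_FINGERPRINT` and `TO_JSON_STRING` are standard BigQuery functions.

```python
def build_sample_query(table, offset):
    """Build a query that returns one row at a stable, content-based position."""
    if not isinstance(offset, int) or offset < 0:
        raise ValueError("offset must be a non-negative integer")
    return (
        "SELECT t.*\n"
        f"FROM `{table}` AS t\n"
        # Hashing the serialized row gives an ordering that depends only on row content.
        "ORDER BY FARM_FINGERPRINT(TO_JSON_STRING(t))\n"
        f"LIMIT 1 OFFSET {offset}"
    )


query = build_sample_query(
    "bigquery-public-data.ml_datasets.census_adult_income", 23300
)
print(query)
```

Sorting the full table by a hash scans every row, which is acceptable for a dataset of this size but worth reconsidering for much larger tables.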

## Step 2: Train the Boosted Tree Model

Train a boosted tree classifier using BigQuery ML to predict income (> $50K or not).

In [None]:
dataset_id = f"{gcp_project}.treetalk"
dataset = bigquery.Dataset(dataset_id)
dataset.location = gcp_region
bq_client.create_dataset(dataset, exists_ok=True)
print(f"Dataset {dataset_id} is ready.")

In [None]:
# Check if model already exists
model_id = f"{gcp_project}.treetalk.income_predictor"
try:
    bq_client.get_model(model_id)
    print(f"Model {model_id} already exists. Skipping training.")
except exceptions.NotFound:
    # Model doesn't exist, so create and train it
    print("Creating the model.")
    query_train = f"""
    CREATE MODEL `{model_id}`
    OPTIONS(
      model_type='BOOSTED_TREE_CLASSIFIER',
      input_label_cols=['income_bracket'],
      max_iterations=50
    ) AS
    SELECT *
    FROM `bigquery-public-data.ml_datasets.census_adult_income`
    WHERE income_bracket IS NOT NULL
    """
    job = bq_client.query(query_train)
    job.result()  # Wait for the query to complete
    print("Model training completed.")

## Step 3: Generate Prediction and Explanation

Use ML.EXPLAIN_PREDICT to predict income for the sample record and extract feature attributions for the decision path.

In [None]:
query_explain = f"""
SELECT *
FROM ML.EXPLAIN_PREDICT(
  MODEL `{gcp_project}.treetalk.income_predictor`,
  (SELECT * FROM `bigquery-public-data.ml_datasets.census_adult_income` 
   ORDER BY (SELECT NULL)
   LIMIT 1 OFFSET {sample_index}),
  STRUCT(3 AS top_k_features)
)
"""
print("Generating explanation data.")
df_explain = bq_client.query(query_explain).to_dataframe()
print("Explanation:")
print(df_explain)
explanation_str = df_explain.to_string(index=False)
print("Explanation string:")
print(explanation_str)
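
`df_explain.to_string()` hands the LLM a wide, column-aligned table. A more compact prompt can be built from the individual fields instead. The sketch below is illustrative only: `summarize_explanation` is a hypothetical helper, and it assumes the row exposes `predicted_income_bracket`, `probability`, and `top_feature_attributions` (a list of `feature`/`attribution` pairs), which is the output shape ML.EXPLAIN_PREDICT documents for classifiers. The values in `example_row` are placeholders, not real model output.

```python
def summarize_explanation(row):
    """Turn one ML.EXPLAIN_PREDICT row (as a dict) into a short text summary."""
    lines = [
        f"Predicted income bracket: {str(row['predicted_income_bracket']).strip()}",
        f"Probability: {row['probability']:.2f}",
        "Top feature attributions:",
    ]
    # Largest absolute attribution first, keeping the sign to show direction.
    ranked = sorted(
        row["top_feature_attributions"],
        key=lambda item: abs(item["attribution"]),
        reverse=True,
    )
    for item in ranked:
        lines.append(f"- {item['feature']}: {item['attribution']:+.3f}")
    return "\n".join(lines)


# Placeholder values for illustration only, not real model output.
example_row = {
    "predicted_income_bracket": " <=50K",
    "probability": 0.91,
    "top_feature_attributions": [
        {"feature": "age", "attribution": -0.25},
        {"feature": "capital_gain", "attribution": -0.5},
        {"feature": "education", "attribution": 0.125},
    ],
}
print(summarize_explanation(example_row))
```

With the real result, something like `summarize_explanation(df_explain.iloc[0].to_dict())` should supply the row, provided the column names match the assumptions above.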

## Step 4: Generate Human-Readable Narrative

Use BigQuery’s generative AI (ML.GENERATE_TEXT) to create a plain-English explanation of the decision path.

In [None]:
"""Create a Vertex AI connection for BigQuery ML"""

# This requires the BigQuery Connection API
from google.cloud import bigquery_connection_v1 as bq_connection

client = bq_connection.ConnectionServiceClient(credentials=credentials)
parent = f"projects/{gcp_project}/locations/{gcp_region}"

connection = bq_connection.Connection()
connection.cloud_resource = bq_connection.CloudResourceProperties()

request = bq_connection.CreateConnectionRequest(
    parent=parent,
    connection_id="vertex-ai-connection",
    connection=connection,
)


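try:
    result = client.create_connection(request=request)
    print(f"Created connection: {result.name}")
    # The connection's own service account needs the Vertex AI User role
    # before the remote model can call Gemini.
    print(f"Connection service account: {result.cloud_resource.service_account_id}")
except exceptions.AlreadyExists:
    print("Connection already exists. Skipping creation.")
except Exception as e:
    print(f"Error creating connection: {e}")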


In [None]:
# Create dataset
dataset_id = f"{gcp_project}.llm"
dataset = bigquery.Dataset(dataset_id)
dataset.location = gcp_region
bq_client.create_dataset(dataset, exists_ok=True)
print(f"Dataset {dataset_id} is ready.")

# Then create/deploy your Gemini model
create_model_query = f"""
CREATE OR REPLACE MODEL `{gcp_project}.llm.gemini`
REMOTE WITH CONNECTION `{gcp_project}.{gcp_region}.vertex-ai-connection`
OPTIONS (
  ENDPOINT = '{llm_model}'
)
"""

print("Creating model...")
bq_client.query(create_model_query).result()
print("Model created!")

In [None]:
query_narrative = f"""
SELECT ml_generate_text_result AS narrative
FROM ML.GENERATE_TEXT(
  MODEL `{gcp_project}.llm.gemini`,
  (SELECT CONCAT('Explain in simple terms and in short, why the model made this prediction: ', @explanation) AS prompt),
  STRUCT(0.7 AS temperature, 500 AS max_output_tokens)
)
"""

job_config = bigquery.QueryJobConfig(
    query_parameters=[
        bigquery.ScalarQueryParameter("explanation", "STRING", explanation_str)
    ]
)

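# Execute the query; start from an empty DataFrame so the check below
# still works if the query fails
df_narrative = pd.DataFrame()
try:
    df_narrative = bq_client.query(query_narrative, job_config=job_config).to_dataframe()
except Exception as e:
    print(f"❌ Error executing query: {e}")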

if not df_narrative.empty:
    try:
        # 1. Get the JSON string from the DataFrame
        json_string = df_narrative['narrative'].iloc[0]
        
        # 2. Parse the string into a Python dictionary
        response_data = json.loads(json_string)
        
        # 3. Extract the text from the correct path in the dictionary
        narrative_text = response_data['candidates'][0]['content']['parts'][0]['text']
        
        # 4. Print the clean text
        print(narrative_text.strip())

    except (json.JSONDecodeError, KeyError, IndexError) as e:
        print(f"❌ Error parsing the model's JSON response: {e}")
        
else:
    print("❌ No explanation generated.")
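
The parsing steps above can also be wrapped in a small function that tolerates the result arriving either as a JSON string or as an already-parsed dictionary. This is a hedged sketch: `extract_narrative` is a hypothetical name, and the `candidates[0].content.parts[0].text` path is simply the one the cell above relies on.

```python
import json


def extract_narrative(raw):
    """Return the generated text from an ML.GENERATE_TEXT result, or None if absent."""
    try:
        # Accept both a JSON string and an already-parsed dictionary.
        data = json.loads(raw) if isinstance(raw, (str, bytes)) else raw
        text = data["candidates"][0]["content"]["parts"][0]["text"]
    except (json.JSONDecodeError, KeyError, IndexError, TypeError):
        return None
    return text.strip() if isinstance(text, str) else None


# Hand-written stand-in for a response, used only to exercise the function.
fake_response = json.dumps(
    {"candidates": [{"content": {"parts": [{"text": "  Example narrative.  "}]}}]}
)
print(extract_narrative(fake_response))        # string input
print(extract_narrative({"candidates": []}))   # malformed input returns None
```

Returning `None` instead of raising keeps the calling cell short: a single `if narrative is None` check can replace the nested `try`/`except`.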