In [1]:
import requests
from vertexai.preview.generative_models import (
    Content,
    FunctionDeclaration,
    GenerativeModel,
    Part,
    Tool,
)
from google.cloud import bigquery

# NOTE(review): this tool-less model is rebound in a later cell (with a
# generation config and tools) before anything in view uses it.
model = GenerativeModel("gemini-pro")

In [2]:
# Function declaration the model can call to run SQL: it takes a single
# required string argument, `query`.
sql_query_func = FunctionDeclaration(
    name="sql_query",
    description="Get information from data in BigQuery using SQL queries",
    parameters={
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "SQL query that will help answer the user's question when run on a BigQuery dataset and table",
            },
        },
        "required": ["query"],
    },
)

# Function declaration for listing datasets; it declares no arguments, so
# the schema is an object with an empty property map.
list_datasets_func = FunctionDeclaration(
    name="list_datasets",
    description="Get a list of datasets",
    parameters={"type": "object", "properties": {}},
)

# Function declaration for listing the tables of one dataset; the single
# required string argument is `dataset_id`.
list_tables_func = FunctionDeclaration(
    name="list_tables",
    description="List tables in a dataset",
    parameters={
        "type": "object",
        "properties": {
            "dataset_id": {
                "type": "string",
                "description": "ID of the dataset to fetch tables from",
            },
        },
        "required": ["dataset_id"],
    },
)

# Function declaration for fetching metadata about one table. The model
# supplies a single string argument, `table_id`, which the notebook later
# reads back as params["table_id"] when calling the BigQuery client.
get_table_func = FunctionDeclaration(
    name="get_table",
    description="Get information about a table, including the description, schema, and number of rows",
    parameters={
        "type": "object",
        "properties": {
            "table_id": {
                "type": "string",
                "description": "ID of the table to get information about",
            },
        },
        # "required" must name a property declared above; this function
        # defines only "table_id" (it has no "query" argument).
        "required": ["table_id"],
    },
)

# Bundle the four function declarations into the single Tool that is
# handed to the model.
sql_query_tool = Tool(
    function_declarations=[sql_query_func, list_datasets_func, list_tables_func, get_table_func],
)

In [3]:
# Recreate the model with temperature 0 and the BigQuery tool attached,
# then open the chat session used by every following cell.
model = GenerativeModel(
    "gemini-pro",
    generation_config={"temperature": 0},
    tools=[sql_query_tool],
)
chat = model.start_chat()

# BigQuery client used below to carry out the model's function calls.
client = bigquery.Client()

In [4]:
# Opening question for the chat session.
prompt = "\nWhat type of data is in this database?\n"
response = chat.send_message(prompt)

# Display the first part of the top candidate (text or a function call).
response.candidates[0].content.parts[0]

function_call {
  name: "list_datasets"
  args {
  }
}

In [5]:
# Answer the model's `list_datasets` call: gather the ID of every dataset
# returned by the BigQuery client and send the stringified list back.
dataset_ids = [dataset.dataset_id for dataset in client.list_datasets()]
api_response = str(dataset_ids)

function_result = Part.from_function_response(
    name="list_datasets",
    response={"content": api_response},
)
response = chat.send_message(function_result)

# Display the first part of the model's next turn.
response.candidates[0].content.parts[0]

function_call {
  name: "get_table"
  args {
    fields {
      key: "table_id"
      value {
        string_value: "thelook_ecommerce.products"
      }
    }
  }
}

In [6]:
# Copy the arguments of the model's function call into a plain dict.
function_call = response.candidates[0].content.parts[0].function_call
params = {key: value for key, value in function_call.args.items()}

# Look up the requested table and send its stringified API representation
# back to the model as the `get_table` result.
table = client.get_table(params["table_id"])
api_response = str(table.to_api_repr())

function_result = Part.from_function_response(
    name="get_table",
    response={"content": api_response},
)
response = chat.send_message(function_result)

# Display the first part of the model's next turn.
response.candidates[0].content.parts[0]

text: " The database contains information about products, including their ID, cost, category, name, brand, retail price, department, SKU, and distribution center ID."

In [None]:
# Follow-up question about product kinds and prices.
prompt = "\nWhat kind of products are there? And what is their price range?\n"
response = chat.send_message(prompt)

# Display the first part of the top candidate (text or a function call).
response.candidates[0].content.parts[0]

In [None]:
# Copy the arguments of the model's function call into a plain dict.
function_call = response.candidates[0].content.parts[0].function_call
params = {key: value for key, value in function_call.args.items()}

# Run the model-written SQL and stringify the resulting rows.
# NOTE(review): the query text comes from the model and is executed
# unfiltered; consider restricting it to read-only statements.
api_response = str(list(client.query(params["query"])))

# Return the rows to the model as the `sql_query` result.
function_result = Part.from_function_response(
    name="sql_query",
    response={"content": api_response},
)
response = chat.send_message(function_result)

# Display the first part of the model's next turn.
response.candidates[0].content.parts[0]

In [None]:
# Follow-up question about brand counts and the top five brands.
prompt = (
    "\n"
    "How many different brands do we have? "
    "And what are the top 5 brands based on the number of items we have?"
    "\n"
)
response = chat.send_message(prompt)

# Display the first part of the top candidate (text or a function call).
response.candidates[0].content.parts[0]

In [None]:
# Collect the function-call arguments from the model's latest turn.
call_args = response.candidates[0].content.parts[0].function_call.args
params = dict(call_args.items())

# Execute the model-written SQL; iterating the query yields its rows.
# NOTE(review): the query text comes from the model and is executed
# unfiltered; consider restricting it to read-only statements.
rows = list(client.query(params["query"]))
api_response = str(rows)

# Hand the stringified rows back as the `sql_query` function result.
response = chat.send_message(
    Part.from_function_response(name="sql_query", response={"content": api_response})
)

# Display the first part of the model's next turn.
response.candidates[0].content.parts[0]

In [None]:
# Follow-up question about the number of distribution centers.
prompt = "\nHow many distribution centers do we have?\n"
response = chat.send_message(prompt)

# Display the first part of the top candidate (text or a function call).
response.candidates[0].content.parts[0]

In [None]:
# Copy the arguments of the model's function call into a plain dict.
function_call = response.candidates[0].content.parts[0].function_call
params = {name: arg for name, arg in function_call.args.items()}

# Run the model-written SQL and stringify the resulting rows.
# NOTE(review): the query text comes from the model and is executed
# unfiltered; consider restricting it to read-only statements.
query_rows = list(client.query(params["query"]))
api_response = str(query_rows)

# Return the rows to the model as the `sql_query` result.
function_result = Part.from_function_response(
    name="sql_query",
    response={"content": api_response},
)
response = chat.send_message(function_result)

# Display the first part of the model's next turn.
response.candidates[0].content.parts[0]

In [None]:
# Handle a further function call from the model's latest turn.
# NOTE(review): this cell has no prompt cell of its own; it presumably
# serves a follow-up `sql_query` call made right after the previous
# result — confirm the preceding response really is a function call.
call_args = response.candidates[0].content.parts[0].function_call.args
params = dict(call_args.items())

# Execute the model-written SQL; iterating the query yields its rows.
# NOTE(review): the query text comes from the model and is executed
# unfiltered; consider restricting it to read-only statements.
api_response = str(list(client.query(params["query"])))

# Hand the stringified rows back as the `sql_query` function result.
response = chat.send_message(
    Part.from_function_response(name="sql_query", response={"content": api_response})
)

# Display the first part of the model's next turn.
response.candidates[0].content.parts[0]