In [None]:
#@title LICENSE

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Use Vertex AI Extensions with a Custom Extension

## Overview


Vertex AI Extensions is a platform for creating and managing extensions that connect large language models to external systems via APIs. These external systems can provide LLMs with real-time data and perform data processing actions on their behalf. You can use pre-built or third-party extensions in Vertex AI Extensions.

Learn more about [Vertex AI Extensions](https://cloud.google.com/vertex-ai/docs/generative-ai/extensions/private/overview).

This notebook provides a simple getting started experience for the Vertex AI Extensions framework. This guide assumes that you are familiar with the Vertex AI Python SDK, [LangChain](https://python.langchain.com/docs/get_started/introduction), [OpenAPI specification](https://swagger.io/specification/), and [Cloud Run](https://cloud.google.com/run/docs).

### Objective

In this tutorial, you learn how to create an extension service backend on Cloud Run, register the extension with Vertex AI, and then use the extension in an application.

The steps performed include:

- Creating a simple service running on Cloud Run
- Creating an OpenAPI 3.1 YAML file for the Cloud Run service
- Registering the service as an extension with Vertex AI
- Using the extension to respond to user queries
- Integrating LangChain into the reasoning for an extension

### Additional Information

This tutorial uses the following Google Cloud services and resources:

- Vertex AI Extensions
- Cloud Run

**_NOTE_**: This notebook has been tested in the following environment:

* Python version = 3.11

### Authenticate your Google Cloud account

You must authenticate to Google Cloud to access the pre-release version of the Python SDK and the Vertex AI Extensions feature.

In [None]:
import sys

# Only the Colab runtime needs an explicit interactive login here; outside
# Colab this cell does nothing.
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    from google.colab import auth as colab_auth

    colab_auth.authenticate_user()

### Installation

This tutorial requires a pre-release version of the Python SDK for Vertex AI. You must be logged in with credentials that are registered for the Vertex AI Extensions Private Preview.

Run the following command to download the library as a wheel from a Cloud Storage bucket:

In [None]:
# Copy the private-preview Vertex AI SDK wheel into the current directory;
# the next cell installs it from this local file.
!gsutil cp gs://vertex_sdk_private_releases/llm_extension/google_cloud_aiplatform-1.39.dev20231219+llm.extension-py2.py3-none-any.whl .

Then, install the following packages required to execute this notebook:

In [None]:
!pip install --force-reinstall --quiet google_cloud_aiplatform-1.39.dev20231219+llm.extension-py2.py3-none-any.whl
!pip install --upgrade --quiet "langchain==0.0.331" \
"openapi-schema-pydantic==1.2.4" \
"openapi-pydantic==0.3.2" \
"google-cloud-storage" \
"shapely<2"

Restart the kernel after installing packages:

In [None]:
import IPython

# Restart the kernel (do_shutdown with restart=True) so the packages
# installed above are picked up by the following cells.
IPython.Application.instance().kernel.do_shutdown(True)

## Before you begin

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.
1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).
1. [Enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).
1. If you are running this notebook locally, you need to install the [Cloud SDK](https://cloud.google.com/sdk).
1. Your project must also be allowlisted for the Vertex AI Extension Private Preview.
1. This notebook requires that you have the following permissions for your GCP project:
   - `roles/aiplatform.user`

### Set your project ID

**If you don't know your project ID**, try the following:
* Run `gcloud config list`.
* Run `gcloud projects list`.
* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
PROJECT_ID = "integrations-379317"  # @param {type:"string"}

# Set the project id
!gcloud config set project {PROJECT_ID}

### Region

You can also change the `REGION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations).

In [None]:
# Region used for vertexai.init() and for creating the Cloud Storage bucket.
REGION = "us-central1"  # @param {type: "string"}

### Create a Cloud Storage bucket

Create a storage bucket to store intermediate artifacts, such as the OpenAPI spec for your extension.

In [None]:
# Cloud Storage bucket names are globally unique: change this to a name you own.
BUCKET_NAME = "vai-bucket"  # @param {type:"string"}
BUCKET_URI = f"gs://{BUCKET_NAME}"
# Object prefix ("folder") inside the bucket where the OpenAPI spec is uploaded.
extensions_prefix = "extension"

**Only if your bucket doesn't already exist**: Run the following cell to create your Cloud Storage bucket.

In [None]:
# Create the bucket in REGION under PROJECT_ID; run only if it does not exist yet.
!gsutil mb -l $REGION -p $PROJECT_ID $BUCKET_URI

### Import libraries



In [None]:
import os

import vertexai
from google.cloud.aiplatform.private_preview import llm_extension
from google.cloud import storage

from langchain import PromptTemplate, LLMChain
from langchain.llms import VertexAI
from langchain.tools import OpenAPISpec, APIOperation
from langchain.chains import OpenAPIEndpointChain
from langchain.requests import Requests

### Initialize Vertex AI SDK for Python

Initialize the Vertex AI SDK for Python for your project.

In [None]:
# Default project, region and staging bucket for the Vertex AI SDK calls below.
vertexai.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

## Creating an API backend service

In this tutorial, you create a simple service that runs on Cloud Run and performs create, read, update, and delete (CRUD) operations on an Astra DB collection. Each operation is exposed as a `POST` route (`/readData`, `/insertData`, `/updateData`, `/deleteData`) that your extension calls (more on that later).

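This simple example does not demonstrate best practices for authentication: the service is deployed with `--allow-unauthenticated`, and the extension is registered with `NO_AUTH`. Secure your service before you use it with real data.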

**Note**: Your backend API service does not need to be hosted on Cloud Run.

### Deploy the API service to Cloud Run

In [None]:
# Directory that holds the Cloud Run service source (Dockerfile, app, requirements).
# exist_ok=True makes the cell safe to re-run; it still raises if "extension"
# exists as a regular file, instead of failing later in %%writefile.
os.makedirs("extension", exist_ok=True)

In [None]:
%%writefile extension/Dockerfile

FROM python:3.11-slim

# Emit Python output immediately instead of buffering it.
ENV PYTHONUNBUFFERED=True

ENV APP_HOME=/app
WORKDIR $APP_HOME

# Install dependencies before copying the application source so this layer
# stays cached when only the source code changes.
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt
COPY . ./

# $PORT is supplied by Cloud Run at runtime; exec makes gunicorn the main process.
CMD exec gunicorn --bind :$PORT --workers 1 --threads 8 --timeout 0 extension:app

In [None]:
%%writefile extension/extension.py
from flask import Flask, jsonify, request

from astrapy.db import AstraDB

import os
import uuid

app = Flask(__name__)

# Table (collection) used when a request does not name one.
DEFAULT_TABLE = "test"


def _parse_request():
    """Decode the JSON request body and resolve the Astra DB credentials.

    Credentials come from the ``token`` / ``api_endpoint`` body fields and fall
    back to the ``ASTRA_DB_APPLICATION_TOKEN`` / ``ASTRA_DB_API_ENDPOINT``
    environment variables, so the service can also be configured server-side.

    Returns:
        A ``(params, error)`` tuple. ``params`` is a copy of the JSON body (an
        empty dict if the body is missing, not JSON, or not an object) with
        ``token`` and ``api_endpoint`` resolved. ``error`` is ``None`` on
        success, otherwise a ``(response, 400)`` pair to return directly.
    """
    # silent=True: a missing or malformed JSON body yields None instead of raising.
    body = request.get_json(silent=True)
    params = dict(body) if isinstance(body, dict) else {}

    params["token"] = params.get("token") or os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
    params["api_endpoint"] = params.get("api_endpoint") or os.environ.get("ASTRA_DB_API_ENDPOINT")

    if not params["token"] or not params["api_endpoint"]:
        return params, (jsonify({"error": "token or api_endpoint not provided"}), 400)
    return params, None


def _astra_db(params):
    """Build an AstraDB client from params already validated by _parse_request."""
    return AstraDB(token=params["token"], api_endpoint=params["api_endpoint"])


@app.route("/", methods=["GET"])
def hello_world():
    """Health check: returns a fixed greeting so callers can see the service is up."""
    return jsonify({"message": "Hello, World!"})


@app.route("/readData", methods=["POST"])
def read_astra():
    """Return the documents in ``tableName`` that match the optional ``filter``."""
    params, error = _parse_request()
    if error is not None:
        return error

    collection = _astra_db(params).collection(params.get("tableName", DEFAULT_TABLE))
    return jsonify(collection.find(filter=params.get("filter")))


@app.route("/insertData", methods=["POST"])
def insert_astra():
    """Insert ``data`` (or a sample document) into ``tableName``.

    The collection is created with ``dimension=5``, matching the length of the
    sample document's ``$vector`` below.
    """
    params, error = _parse_request()
    if error is not None:
        return error

    # Sample document inserted when the request carries no ``data``.
    sample_doc = {
        "_id": str(uuid.uuid4()),
        "name": "Coded Cleats Copy",
        "description": "ChatGPT integrated sneakers that talk to you",
        "$vector": [0.25, 0.25, 0.25, 0.25, 0.25],
    }
    document = params.get("data", sample_doc)

    collection = _astra_db(params).create_collection(
        params.get("tableName", DEFAULT_TABLE), dimension=5
    )
    return jsonify(collection.insert_one(document))


@app.route("/updateData", methods=["POST"])
def update_astra():
    """Apply the ``fieldUpdate`` document to one document matching ``filter``.

    ``fieldUpdate`` (e.g. ``{"$set": {...}}``) is required: there is no
    meaningful default update, so a missing or empty one is rejected with 400.
    """
    params, error = _parse_request()
    if error is not None:
        return error

    field_update = params.get("fieldUpdate")
    if not isinstance(field_update, dict) or not field_update:
        return jsonify({"error": "fieldUpdate must be a non-empty object"}), 400

    collection = _astra_db(params).collection(params.get("tableName", DEFAULT_TABLE))
    return jsonify(
        collection.find_one_and_update(filter=params.get("filter"), update=field_update)
    )


@app.route("/deleteData", methods=["POST"])
def delete_astra():
    """Delete the documents in ``tableName`` that match ``filter``.

    ``filter`` is required so that a request which forgets it is rejected with
    400 instead of being forwarded to Astra DB with a null filter.
    """
    params, error = _parse_request()
    if error is not None:
        return error

    query_filter = params.get("filter")
    if query_filter is None:
        return jsonify({"error": "filter not provided"}), 400

    collection = _astra_db(params).collection(params.get("tableName", DEFAULT_TABLE))
    return jsonify(collection.delete_many(filter=query_filter))


if __name__ == "__main__":
    # Local development only; the container serves `app` through gunicorn.
    # The Werkzeug debugger allows arbitrary code execution, so it is opt-in
    # via FLASK_DEBUG=1 rather than always on while listening on 0.0.0.0.
    app.run(
        debug=os.environ.get("FLASK_DEBUG") == "1",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 8080)),
    )

In [None]:
%%writefile extension/requirements.txt
# Runtime dependencies of extension.py, installed by the Dockerfile's pip step.
Flask==3.0.1
gunicorn==21.2.0
astrapy==0.7.0

In [None]:
%%writefile extension/.dockerignore
# Files excluded from the Docker build context.
Dockerfile
README.md
*.pyc
*.pyo
*.pyd
__pycache__
.pytest_cache

Next, you deploy the service to Cloud Run. However, you might need to log in once more to deploy.

In [None]:
# Interactive login so gcloud can deploy to Cloud Run with your user credentials.
!gcloud auth login

In [None]:
!gcloud run deploy extension --region=us-central1 --allow-unauthenticated --source extension --no-user-output-enabled

Look up the URL of the Cloud Run service that you just deployed, then copy it into the next cell:

In [None]:
!gcloud run services list | sort -k 3 | head -2

In [None]:
# @title Copy paste the output from the previous command here
service_url = "https://extension-we3tzahyyq-uc.a.run.app"  # @param {type:"string"}

# Astra DB credentials are read from the environment or prompted for, so they are
# never stored in the notebook or committed with it.
import getpass

token = os.environ.get("ASTRA_DB_APPLICATION_TOKEN") or getpass.getpass("Astra DB application token: ")
api_endpoint = os.environ.get("ASTRA_DB_API_ENDPOINT") or input("Astra DB API endpoint: ")

### Create an OpenAPI spec

Your Vertex AI extension requires an OpenAPI 3.1 YAML file that defines routes, URL, HTTP methods, requests, and responses from your "backend" service. The following code creates a YAML file that you need to upload to your Cloud Storage bucket.

In [None]:
# Directory for the generated OpenAPI spec.
os.makedirs("extension-api", exist_ok=True)

# The spec must match the Flask service: every data route is a POST that takes a
# JSON body (token, api_endpoint, tableName, ...) and replies with a JSON object.
# `servers` tells the extension which base URL to call.
openapi_yaml = f"""
openapi: 3.1.0
info:
  title: Astra Vertex Extension
  description: An extension to perform CRUD actions on data within your Astra Database.
  version: 1.0.0
servers:
  - url: {service_url}
paths:
  /readData:
    post:
      operationId: readData
      summary: Search for data within the database
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              required:
                - tableName
                - filter
              properties:
                token:
                  type: string
                  description: The Astra DB application token
                api_endpoint:
                  type: string
                  description: The Astra DB API endpoint URL
                tableName:
                  type: string
                  description: The name of the table to search
                filter:
                  type: object
                  description: The filter to be applied to the search
      responses:
        '200':
          description: Successful response
          content:
            application/json:
              schema:
                type: object
  /updateData:
    post:
      operationId: updateData
      summary: Update existing data within the database
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                token:
                  type: string
                  description: The Astra DB application token
                api_endpoint:
                  type: string
                  description: The Astra DB API endpoint URL
                tableName:
                  type: string
                  description: The name of the table
                filter:
                  type: object
                  description: Selects the document to update
                fieldUpdate:
                  type: object
                  description: The update to apply, for example a $set document
      responses:
        '200':
          description: Data updated successfully
          content:
            application/json:
              schema:
                type: object
  /insertData:
    post:
      operationId: insertData
      summary: Insert new data into the database
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                token:
                  type: string
                  description: The Astra DB application token
                api_endpoint:
                  type: string
                  description: The Astra DB API endpoint URL
                tableName:
                  type: string
                  description: The name of the table
                data:
                  type: object
                  description: The document to insert
      responses:
        '200':
          description: Data inserted successfully
          content:
            application/json:
              schema:
                type: object
  /deleteData:
    post:
      operationId: deleteData
      summary: Delete existing data within the database
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              required:
                - tableName
                - filter
              properties:
                token:
                  type: string
                  description: The Astra DB application token
                api_endpoint:
                  type: string
                  description: The Astra DB API endpoint URL
                tableName:
                  type: string
                  description: The name of the table
                filter:
                  type: object
                  description: Selects the documents to delete
      responses:
        '200':
          description: Data deleted successfully
          content:
            application/json:
              schema:
                type: object

"""

print(openapi_yaml)

In [None]:
# Write the spec string to extension-api/extension.yaml (IPython %store file redirect).
%store openapi_yaml >extension-api/extension.yaml

Upload the OpenAPI YAML to your Cloud Storage bucket.

In [None]:
# Upload the local spec to gs://<BUCKET_NAME>/<extensions_prefix>/extension.yaml,
# the object the extension manifest points at.
local_spec_path = "extension-api/extension.yaml"
blob_name = "/".join((extensions_prefix, "extension.yaml"))

storage_client = storage.Client()
bucket = storage_client.bucket(BUCKET_NAME)
blob = bucket.blob(blob_name)
blob.upload_from_filename(local_spec_path)

### Test the deployed service with HTTP requests

First, check that your service can accept simple HTTP `GET` requests:

In [None]:
import requests

# Base URL of the deployed service; each route name is appended to it below.
url = f"{service_url}/"
print(url)

# Call the GET / route, which replies with a fixed "Hello, World!" JSON message.
r = requests.get(url)
print(f"Status Code: {r.status_code}, Content: {r.text}")

In [None]:
import uuid

import requests

# Insert one 5-dimensional vector document into the "demo" table.
insert_url = url + "insertData"

doc = {
    "_id": str(uuid.uuid4()),
    "name": "Coded Cleats Copy Test",
    "description": "ChatGPT integrated sneakers that talk to you",
    "$vector": [0.25] * 5,
}

body_arguments = dict(token=token, api_endpoint=api_endpoint, data=doc, tableName="demo")

print(insert_url)
r = requests.post(insert_url, json=body_arguments)

print(f"Status Code: {r.status_code}, Content: {r.text}")


In [None]:
# Read back the documents in the "demo" table (no filter).
read_url = url + "readData"
body_arguments = dict(token=token, api_endpoint=api_endpoint, tableName="demo")

print(read_url)
r = requests.post(read_url, json=body_arguments)
print(f"Status Code: {r.status_code}, Content: {r.text}")

In [None]:
# Rename the document inserted above. The filter targets it by name; without a
# filter, an arbitrary document in the table would be updated instead.
body_arguments = {
    "token": token,
    "api_endpoint": api_endpoint,
    "tableName": "demo",
    "filter": {"name": "Coded Cleats Copy Test"},
    "fieldUpdate": {"$set": {"name": "Coded Cleats Copy Test Update"}},
}

print(url + "updateData")
r = requests.post(url + "updateData", json=body_arguments)

print(f"Status Code: {r.status_code}, Content: {r.text}")

In [None]:
# Delete the documents whose name matches the name set by the update step.
delete_url = f"{url}deleteData"
body_arguments = dict(
    token=token,
    api_endpoint=api_endpoint,
    tableName="demo",
    filter={"name": "Coded Cleats Copy Test Update"},
)

print(delete_url)
r = requests.post(delete_url, json=body_arguments)
print(f"Status Code: {r.status_code}, Content: {r.text}")

In [None]:
# Read the "demo" table again to check the effect of the delete.
endpoint = f"{url}readData"
body_arguments = {"token": token, "api_endpoint": api_endpoint, "tableName": "demo"}

print(endpoint)
r = requests.post(endpoint, json=body_arguments)
print(f"Status Code: {r.status_code}, Content: {r.text}")

## Creating and using a custom extension

### Create the extension

Now that you've set up the service to fulfill extension requests, you can create the extension itself.

First, you'll define selection, invocation, and response examples:

In [None]:
import json

# Include multiple selection, invocation, and response examples for best results.
# Tool names and parameters must match what actually exists: the manifest name
# "astra_tool" and the readData request fields (tableName, filter) in the spec.
extension_selection_examples = [
    {
        "query": "I want to learn about the products",
        "multi_steps": [
            {
                "thought": "I should call astra_tool for this",
                "extension_execution": {
                    "operation_id": "readData",
                    "extension_instruction": "Describe the product that you want to learn about",
                    "observation": "Product descriptions come from the description field",
                },
            },
            {
                "thought": "Since the observation was successful, I should respond back to the user with results",
                "respond_to_user": {},
            },
        ],
    }
]

extension_invocation_examples = [
    {
        "extension_instruction": "Tell me about your product.",
        "operation_id": "readData",
        "thought": "Issue a readData operation request on astra_tool",
        # json.dumps guarantees the parameter string is valid JSON.
        "operation_param": json.dumps({"tableName": "demo", "filter": {}}),
        "parameters_mentioned": ["tableName", "filter"],
    }
]

extension_response_examples = [
    {
        "operation_id": "readData",
        "response_template": "{{ response }}",
    }
]

Then, you'll create your extension and include the examples from the previous cell:

In [None]:
# Register the Cloud Run service as a Vertex AI extension. The manifest points at
# the OpenAPI spec uploaded to Cloud Storage and embeds the examples defined above.
spec_gcs_uri = f"{BUCKET_URI}/{extensions_prefix}/extension.yaml"

manifest = {
    "name": "astra_tool",
    "description": "Access and process data from AstraDB",
    "api_spec": {"open_api_gcs_uri": spec_gcs_uri},
    "auth_config": {"auth_type": "NO_AUTH"},
    "extension_selection_examples": extension_selection_examples,
    "extension_invocation_examples": extension_invocation_examples,
    "extension_response_examples": extension_response_examples,
}

extension_astra = llm_extension.Extension.create(
    display_name="Read Astra",
    description="Loads data from AstraDB and returns it to the user",
    manifest=manifest,
)
extension_astra

Now that you've created your extension, let's confirm that it's registered:

In [None]:
# Show the registered resource's identifiers, then its full dictionary form.
extension_resource = extension_astra.gca_resource
print("Name:", extension_resource.name)
print("Display Name:", extension_astra.display_name)
print("Description:", extension_resource.description)

print(extension_astra.to_dict())

And you can test the functionality of the extension by executing it:

In [None]:
# Call the readData operation directly, without the LLM. The operation ID must be
# one declared in the OpenAPI spec, and the parameter names (string keys) must
# match the JSON fields the Flask service reads.
extension_astra.execute(
    "readData",
    operation_params={
        "token": token,
        "api_endpoint": api_endpoint,
        "tableName": "demo",
    },
)

### Create a controller

The extension controller allows an application developer to specify which extensions to use.

You'll create an extension controller that refers to the extension/tool that you created in the previous section:

In [None]:
# Define the extensions controller service client
client_options = {"api_endpoint": f"{REGION}-aiplatform.googleapis.com"}
controller_client = llm_extension.extensions.services.extension_controller_service.client.ExtensionControllerServiceClient(
    client_options=client_options)

controller_spec = llm_extension.gapic.types.ExtensionControllerSpec()

controller_req = llm_extension.gapic.types.ExtensionController()
controller_req.display_name = "Astra DB Extension Controller"
controller_req.description = "Loads data from Astra DB"
controller_req.extension_controller_spec.extensions = [{"extension": extension_astra.resource_name}]

parent = f"projects/{PROJECT_ID}/locations/{REGION}"

controller_op = controller_client.create_extension_controller(
    parent=f"projects/{PROJECT_ID}/locations/{REGION}",
    extension_controller=controller_req
)
controller = controller_op.result(timeout=300)
print(controller.name)

### Use the controller in a query

Now that you have an extension and an extension controller, you can start using the controller to answer queries.

In [None]:
execution_client = llm_extension.extensions.services.extension_controller_execution_service.client.ExtensionControllerExecutionServiceClient(
    client_options=client_options
)

# Send a natural-language question to the controller created above.
question = "Question: Tell me about the product?"
req = dict(query=dict(query=question), name=controller.name)

response = execution_client.query(req)
print(response)

## Cleaning up

To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud
project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.

Otherwise, you can delete the individual resources you created in this tutorial:

In [None]:
# Delete the controller
# op = controller_client.delete_extension_controller(name=controller.name)
# op.result()

# Delete the extension
# extension_astra.delete()

# Delete Cloud Storage objects that were created
#delete_bucket = False
#if delete_bucket or os.getenv("IS_TESTING"):
#! gsutil -m rm -r $BUCKET_URI