# Text Embeddings using Spanner's DBAPI Driver

Spanner has its own Python Client with a variety of Spanner-specific extensions and concepts.  However, for some applications, it's simpler to use Spanner's standards-compliant DBAPI Driver.

This driver provides the same Python API as is implemented by most other database engines' Python drivers.  So it can be easier to use in a mixed-database environment, or for developers who are coming from other database systems.
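To illustrate what "the same Python API" buys you, here is a minimal sketch of a helper written only against the PEP 249 (DB-API 2.0) interface: `cursor()`, `execute()`, `fetchone()`, `close()`. It uses Python's built-in `sqlite3` module as a stand-in so it runs anywhere. The `first_value` helper and the `messages` table are hypothetical. Placeholder syntax is the one part that is not fully portable. Each driver advertises its style in a module-level `paramstyle` attribute: `sqlite3` uses `?`, while the Spanner examples below use `%(name)s`.

```python
import sqlite3


def first_value(connection, sql, params=()):
    """Run a query through any DB-API 2.0 (PEP 249) connection and return the
    first column of the first row, or None if there are no rows."""
    cursor = connection.cursor()
    try:
        cursor.execute(sql, params)
        row = cursor.fetchone()
        return None if row is None else row[0]
    finally:
        cursor.close()


# Stand-in database: in-memory SQLite, whose placeholder style is "?".
conn = sqlite3.connect(":memory:")
setup = conn.cursor()
setup.execute("CREATE TABLE messages (id TEXT PRIMARY KEY, subject TEXT)")
setup.executemany(
    "INSERT INTO messages (id, subject) VALUES (?, ?)",
    [("a", "hello"), ("b", "world"), ("c", "hello again")],
)
conn.commit()
setup.close()

matches = first_value(
    conn, "SELECT COUNT(*) FROM messages WHERE subject LIKE ?", ("hello%",))
print(matches)  # → 2
conn.close()
```

Only the `connect()` call and the SQL text are driver-specific. The body of `first_value` would work unchanged with a Spanner DBAPI connection.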

## Step 1: Install Dependencies

Spanner's DBAPI Driver is bundled into Spanner's client package.  Let's go ahead and install that package.

Let's also install Pandas, a popular library for manipulating datasets (dataframes) in Python.  We'll also install Scikit Learn because it comes with a useful collection of example datasets for ML purposes.  If you prefer, you can substitute your own data for the example dataset.

In [None]:
!pip install google-cloud-spanner pandas scikit-learn

# Let's go ahead and import Pandas, since we'll use it in several places below.
import pandas as pd



## Step 2:  Authenticate to GCP

Google offers a variety of options for authenticating to GCP.  Please see the [documentation](https://googleapis.dev/python/google-api-core/latest/auth.html) for more details.

If you're running this notebook in Google Colab, Colab provides a convenient built-in authentication method, as illustrated below.  This method opens a pop-up window asking you to authenticate this notebook using your Google credentials.

In [None]:
from google.colab import auth
auth.authenticate_user()
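The `google.colab` module exists only inside Colab. You may also run this notebook elsewhere, for example on a local Jupyter server. A minimal sketch for that case falls back to Application Default Credentials, which the Spanner client library picks up automatically. The `running_in_colab` helper is hypothetical.

```python
import importlib.util


def running_in_colab() -> bool:
    """Return True if the google.colab package is importable."""
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        # Raised when the parent "google" package itself is not installed.
        return False


if running_in_colab():
    from google.colab import auth
    auth.authenticate_user()
else:
    print("Not running in Colab. Run `gcloud auth application-default login` "
          "in a terminal to set up Application Default Credentials.")
```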

## Step 3: Connecting to Cloud Spanner

Now that we're authenticated, let's establish a connection to Cloud Spanner.  This connection goes directly to your production Spanner instance and uses the compute capacity allocated to that instance.  (It doesn't use [Data Boost](https://cloud.google.com/spanner/docs/databoost/databoost-overview), which has great advantages for the queries it supports but does not support DML, DDL, or SELECT queries that aren't root-partitionable.)

Please modify this example to specify your own instance and database IDs.

In [None]:
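import os

from google.cloud.spanner_dbapi import connect

# Set these environment variables (or edit the fallback defaults) to point at
# your own project, instance, and database.
PROJECT_ID = os.environ.get("PROJECT_ID") or "span-cloud-testing"
INSTANCE_ID = os.environ.get("INSTANCE_ID") or "aseering-us-east4"
DATABASE_ID = os.environ.get("DATABASE_ID") or "gsql-test"

connection = connect(INSTANCE_ID, DATABASE_ID, project=PROJECT_ID)
cursor = connection.cursor()

# `autocommit` is a property of the Connection, not the Cursor, so setting it on
# the cursor has no effect.  This notebook leaves autocommit off and calls
# connection.commit() explicitly after each step.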
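The fallback defaults above point at the original author's instance. A minimal sketch of a stricter alternative reads the three identifiers from the environment and fails loudly when any are missing. `SpannerConfig` and `load_spanner_config` are hypothetical names. The variable names match the ones the cell above already reads.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class SpannerConfig:
    project_id: str
    instance_id: str
    database_id: str


def load_spanner_config(env: Mapping[str, str] = os.environ) -> SpannerConfig:
    """Read Spanner identifiers from environment variables, raising instead of
    silently falling back to someone else's instance."""
    names = ("PROJECT_ID", "INSTANCE_ID", "DATABASE_ID")
    missing = [name for name in names if not env.get(name)]
    if missing:
        raise ValueError(f"Set these environment variables first: {', '.join(missing)}")
    return SpannerConfig(env["PROJECT_ID"], env["INSTANCE_ID"], env["DATABASE_ID"])


# Passing an explicit mapping makes the helper easy to try without touching os.environ.
config = load_spanner_config({"PROJECT_ID": "my-project",
                              "INSTANCE_ID": "my-instance",
                              "DATABASE_ID": "my-database"})
print(config.instance_id)
```

You would then call `connect(config.instance_id, config.database_id, project=config.project_id)`.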

## Step 4: Load some Data

Now that we're connected, let's create a table and load some data into it.  We'll use the "20 Newsgroups" dataset, a collection of newsgroup posts that scikit-learn can download for us.

Spanner is often used to host production applications that generate their own data.  If you already have tables, feel free to skip this step and update the following steps to point at your own tables.  Otherwise, this section is a bit dense; hang onto your hats!  Or just run it and skip ahead assuming that you now have some data in your database.
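The densest part of the cell below is the parsing step. A minimal sketch of it on a single made-up message shows how `email.parser` splits headers from the body. Indexing a header that isn't present returns `None`. That matters here because the table declares every column `NOT NULL`. The sample text is invented for illustration.

```python
import email.parser

# A made-up post in the same header-then-body layout as the 20 Newsgroups data.
raw = (
    "From: alice@example.com (Alice)\n"
    "Subject: Re: Spanner embeddings\n"
    "Organization: Example Org\n"
    "\n"
    "Has anyone tried generating embeddings in the database?\n"
)

message = email.parser.Parser().parsestr(raw)

# message["Nntp-Posting-Host"] would return None for this post; .get() with a
# default substitutes an empty string so the value satisfies NOT NULL.
record = {
    "from": message.get("From", ""),
    "subject": message.get("Subject", ""),
    "organization": message.get("Organization", ""),
    "nntp_posting_host": message.get("Nntp-Posting-Host", ""),
    "body": message.get_payload(),
}
print(record["nntp_posting_host"] == "")  # → True
print(record["body"])
```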

In [None]:
# Import a bunch of dependencies
import email.parser
import sklearn.datasets
import uuid

# Download the sklearn "20newsgroups" dataset into a local variable
newsgroups = sklearn.datasets.fetch_20newsgroups(subset="all")

# The newsgroups are NNTP messages.  NNTP messages are structured like e-mails.
# Parse them; then construct a DataFrame from a hardcoded list of fields.
# Spanner needs a unique ID for each record, so add a UUID column.
# Not every message has every header, so use .get() with an empty-string
# default; otherwise missing headers become None and violate NOT NULL below.
parser = email.parser.Parser()
newsgroup_messages = [parser.parsestr(message) for message in newsgroups.data]
df = pd.DataFrame({
    'id': [str(uuid.uuid4()) for _ in newsgroup_messages],
    'from': [message.get('From', '') for message in newsgroup_messages],
    'subject': [message.get('Subject', '') for message in newsgroup_messages],
    'nntp_posting_host': [message.get('Nntp-Posting-Host', '')
                          for message in newsgroup_messages],
    'organization': [message.get('Organization', '')
                     for message in newsgroup_messages],
    'body': [message.get_payload() for message in newsgroup_messages],
})

# Create a table with columns corresponding to common fields above.
# Treat all of the fields as un-bounded strings for now.  There's no cost to
# doing so, and we don't know how long future values might be.
# (Constrain the `id` field; it should be a valid UUID.)
cursor.execute("""
DROP TABLE IF EXISTS spanner_ml_example_20newsgroups;
CREATE TABLE spanner_ml_example_20newsgroups (
  `id` STRING(36) NOT NULL,
  `from` STRING(MAX) NOT NULL,
  `subject` STRING(MAX) NOT NULL,
  `nntp_posting_host` STRING(MAX) NOT NULL,
  `organization` STRING(MAX) NOT NULL,
  `body` STRING(MAX) NOT NULL
) PRIMARY KEY (id)
""")

# Spanner DDL statements don't require a commit on the Spanner backend, but
# calling commit() causes Spanner's DBAPI Driver to flush queued-up DDL changes
# to the backend and to refresh its local transaction pool.
connection.commit()

# Load the data from the DataFrame into Spanner.
# Insert rows in batches of BATCH_SIZE rather than one at a time, to reduce
# round trips and to stay within Spanner's per-transaction size limits.
BATCH_SIZE = 1000
for i in range(0, len(df), BATCH_SIZE):
  cursor.executemany("""
    INSERT INTO spanner_ml_example_20newsgroups (
      `id`,    `from`,   `subject`,   `nntp_posting_host`,   `organization`,   `body`
    ) VALUES (
      %(id)s, %(from)s, %(subject)s, %(nntp_posting_host)s, %(organization)s, %(body)s
    )
    """, df[i:i+BATCH_SIZE].to_dict(orient="records"))
  connection.commit()
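The loop above slices a DataFrame. The same batching idea works on any iterable of records, for example rows streamed from a file that you'd rather not load into memory at once. Here is a minimal, dependency-free sketch. `batched` is a hypothetical helper, not part of the Spanner driver. Each batch could be passed to `cursor.executemany(...)` and followed by `connection.commit()`, exactly as in the cell above.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batched(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield successive lists of at most `size` items from any iterable."""
    if size < 1:
        raise ValueError("size must be at least 1")
    iterator = iter(items)
    # islice pulls the next `size` items; an empty list means we're done.
    while batch := list(islice(iterator, size)):
        yield batch


rows = [{"id": str(n)} for n in range(2500)]
print([len(batch) for batch in batched(rows, 1000)])  # → [1000, 1000, 500]
```

On Python 3.12 and later, `itertools.batched` provides the same behavior but yields tuples.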

## Step 5: Register a Model

Spanner supports both custom and pre-trained ML models.  It uses models that are managed by [Vertex AI](https://cloud.google.com/vertex-ai).

Vertex AI provides a powerful suite of tools for managing models.  For now, let's just use Vertex AI's pre-trained ["Gecko" PaLM embedding model](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings).

This model takes a string (such as a message body) as an argument and returns a struct, `embeddings`, with two fields:

* `values` - the actual embedding
* `statistics` - counts and other metadata, gathered while generating the embedding

It's automatically available in all GCP projects at the endpoint URL used in the example below.  Spanner should be able to call it, but depending on your IAM configuration, you may need to grant additional permissions first.
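The endpoint follows a fixed pattern: project, region (`us-central1` in this notebook), and publisher model name. A small hypothetical helper makes that pattern explicit if you want to try a different region or model. Vertex AI determines which regions and model versions are actually available, so check its documentation before changing the defaults.

```python
def vertex_publisher_model_endpoint(project_id: str,
                                    location: str = "us-central1",
                                    model: str = "textembedding-gecko") -> str:
    """Build the endpoint string used in CREATE MODEL ... REMOTE OPTIONS."""
    return (f"//aiplatform.googleapis.com/projects/{project_id}"
            f"/locations/{location}/publishers/google/models/{model}")


print(vertex_publisher_model_endpoint("my-project"))
```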

In [None]:
cursor.execute("""
CREATE OR REPLACE MODEL spanner_ml_example_textembedding_gecko
INPUT (
  content STRING(MAX)
)
OUTPUT (
  embeddings STRUCT<
    statistics STRUCT<
      truncated BOOL, token_count FLOAT64
    >,
    values ARRAY<FLOAT64>
  >
)
REMOTE OPTIONS (
  endpoint = '//aiplatform.googleapis.com/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/textembedding-gecko'
)
""".format(PROJECT_ID=PROJECT_ID))
connection.commit()



## Step 6:  Add computed embedding column

Let's say you want to build a tool for quickly finding relevant messages.  Now that we have registered a text-embedding model with Spanner, we can use that model to have Spanner automatically calculate and store the embedding for each message in the dataset.

With this approach, the embedding will automatically (and transactionally) be updated whenever messages are inserted or modified.

In [None]:
cursor.execute("""
ALTER TABLE spanner_ml_example_20newsgroups
ADD COLUMN body_embedding ARRAY<FLOAT64> NOT NULL
  AS (spanner_ml_example_textembedding_gecko(body).values) STORED
""")
connection.commit()

## Step 7: Read back embeddings

Typically, once embeddings are generated, you would update your application to use them.  This way, Spanner maintains the embeddings over time, and your application can read an up-to-date embedding whenever it needs one.

As a simple example, let's read back the embeddings that we just generated and add them to our DataFrame.
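PEP 249 drivers, including Spanner's, return each row as a sequence, so values are accessed by position (`row[0]`) rather than by column name. If name-based access is more convenient, a minimal sketch built on the standard `cursor.description` attribute can convert rows to dictionaries. `rows_as_dicts` is a hypothetical helper, and SQLite again stands in for Spanner. Give computed expressions an alias (`AS embedding`) so they have a usable name.

```python
import sqlite3


def rows_as_dicts(cursor):
    """Turn a DB-API cursor's remaining rows into dicts keyed by column name.

    Each cursor.description entry is a sequence whose first item is the
    column name.
    """
    columns = [column[0] for column in cursor.description]
    return [dict(zip(columns, row)) for row in cursor.fetchall()]


# Demo against an in-memory SQLite database standing in for Spanner.
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE t (id TEXT, score REAL)")
cur.executemany("INSERT INTO t VALUES (?, ?)", [("a", 0.5), ("b", 0.25)])
cur.execute("SELECT id, score FROM t ORDER BY id")
records = rows_as_dicts(cur)
print(records)  # → [{'id': 'a', 'score': 0.5}, {'id': 'b', 'score': 0.25}]
conn.close()
```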

In [None]:
# Read embeddings back from Spanner.
# Read back the row ID for each embedding as well, so we can match up the
# returned embeddings with the rows that we already have.
cursor.execute("""
SELECT `id`, `body_embedding` FROM spanner_ml_example_20newsgroups
""")

# Build a dictionary from the result set, mapping each row's ID to its
# embedding.  DB-API rows are sequences, so unpack them by position.
embeddings = {row_id: body_embedding for row_id, body_embedding in cursor}

# For each ID in our DataFrame, insert the corresponding embedding into a new
# "body_embedding" column in the DataFrame.
df["body_embedding"] = [embeddings[x] for x in df["id"]]


And there you have it!  Embeddings generated and maintained automatically by Spanner for your data, accessible in Python via query.

In [None]:
df.head(5)
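Step 6 motivated these embeddings with a message-search tool. Here is a minimal sketch of the ranking step. It assumes an `embeddings`-style dictionary like the one built in Step 7 (row ID → list of floats) and scores every stored vector against a query vector by cosine similarity, a common choice for text embeddings. This is a brute-force, in-memory scan that suits a notebook-sized dataset. `cosine_similarity` and `top_k_similar` are hypothetical helpers, and toy 2-D vectors stand in for real embeddings.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors (0.0 if either is zero)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_similar(query: Sequence[float],
                  embeddings: Dict[str, Sequence[float]],
                  k: int = 5) -> List[Tuple[str, float]]:
    """Return the k (row ID, similarity) pairs most similar to `query`."""
    scored = [(row_id, cosine_similarity(query, vector))
              for row_id, vector in embeddings.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


toy_embeddings = {"msg-1": [1.0, 0.0], "msg-2": [0.7, 0.7], "msg-3": [0.0, 1.0]}
print(top_k_similar([1.0, 0.1], toy_embeddings, k=2))
```

In the notebook, the query vector could be an embedding of a search phrase, such as the `Hello World!` query in Step 7.1.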

## Step 7.1:  Construct and read embeddings dynamically

What if you don't want to store an embedding and just want to generate it on demand?  Spanner can invoke the model as part of a query as well.

In [None]:
cursor.execute("""
SELECT id, spanner_ml_example_textembedding_gecko(body).values AS body_embedding
FROM spanner_ml_example_20newsgroups
""")

# The query recomputed the embeddings; they should match the stored values.
assert {row_id: embedding for row_id, embedding in cursor} == embeddings

# Generate an embedding for an arbitrary string passed as a query parameter.
cursor.execute("""
SELECT spanner_ml_example_textembedding_gecko(%(text)s).values AS embedding
""", {'text': "Hello World!"})
cursor.fetchone()[0]
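Spanner's GoogleSQL documentation also describes calling a registered model through the `ML.PREDICT` table-valued function. That function matches input columns to the model's `INPUT` clause by name (hence `body AS content`) and exposes the `OUTPUT` struct as the `embeddings` column. If your database doesn't accept the scalar-call form used above, here is a sketch that builds the equivalent `ML.PREDICT` query. The helper name is hypothetical. Identifiers can't be passed as query parameters, so interpolate only trusted table and model names.

```python
def ml_predict_embeddings_sql(model: str, table: str,
                              id_column: str = "id",
                              text_column: str = "body") -> str:
    """Build a GoogleSQL query that embeds every row of `table` with `model`.

    The text column is aliased to `content` to match the model's INPUT clause;
    identifiers are interpolated directly, so pass only trusted names.
    """
    return (
        f"SELECT {id_column}, embeddings.values AS embedding\n"
        f"FROM ML.PREDICT(\n"
        f"  MODEL {model},\n"
        f"  (SELECT {id_column}, {text_column} AS content FROM {table})\n"
        f")"
    )


sql = ml_predict_embeddings_sql("spanner_ml_example_textembedding_gecko",
                                "spanner_ml_example_20newsgroups")
print(sql)
# In the notebook (requires the Spanner connection from Step 3):
# cursor.execute(sql)
# predicted = {row_id: embedding for row_id, embedding in cursor}
```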