In [None]:
!pip install --upgrade --user google-cloud-aiplatform>=1.29.0 google-cloud-storage langchain chromadb

In [None]:
# Restart the kernel so the freshly installed packages become importable
import IPython

# do_shutdown(True): shut the running kernel down with the restart flag set
IPython.Application.instance().kernel.do_shutdown(True)

# Set Variables

In [None]:
# get project ID
PROJECT_ID = ! gcloud config get-value project
PROJECT_ID = PROJECT_ID[0]
LOCATION = "us-central1" # @param {type:"string"}
DOCUMENT_URL = "https://www.gutenberg.org/cache/epub/55/pg55.txt" # @param {type:"string"}

# define project information manually if the above code didn't work
if PROJECT_ID == "(unset)":
  PROJECT_ID = "[your-project-id]" # @param {type:"string"}

print(PROJECT_ID)

In [None]:
import getpass
import os

# Ask for the API key interactively rather than hard-coding it in the notebook.
# The key is handed to the Chroma DB embedding function further down.
api_key_prompt = "Provide your Google API Key"
API_KEY = getpass.getpass(api_key_prompt)

## Initialize Vertex AI

In [None]:
# Point the Vertex AI SDK at the project and region chosen above
from google.cloud import aiplatform

vertex_settings = {"project": PROJECT_ID, "location": LOCATION}
aiplatform.init(**vertex_settings)

## Just test the Embeddings model

In [None]:
from vertexai.language_models import TextEmbeddingModel

def text_embedding(text_to_embed, model_name: str = "textembedding-gecko@002") -> list:
    """Embed a single input with a Vertex AI text embedding model.

    Args:
        text_to_embed: The input handed to the model, e.g. a plain string or a
            ``TextEmbeddingInput``.
        model_name: Name of the pretrained embedding model to load.

    Returns:
        The embedding vector (``embedding.values``) for the input.

    Raises:
        ValueError: If the model returns no embedding for the input.
    """
    model = TextEmbeddingModel.from_pretrained(model_name)
    embeddings = model.get_embeddings([text_to_embed])

    # Fail with a clear message instead of an UnboundLocalError further down.
    if not embeddings:
        raise ValueError("The embedding model returned no embedding for the input.")

    # Exactly one input was sent, so the first result is the one we want.
    vector = embeddings[0].values
    print(f"Length of Embedding Vector: {len(vector)}")
    return vector

In [None]:
emb1 = text_embedding("Hello World")
# Show only the first few values; text_embedding already prints the vector length
print(emb1[:10])

In [None]:
from vertexai.language_models import TextEmbeddingInput


def embed_with_task_type(text, task_type):
    """Embed ``text`` wrapped in a ``TextEmbeddingInput`` carrying ``task_type``."""
    return text_embedding(TextEmbeddingInput(text=text, task_type=task_type))


# Same text, two task types, so the resulting vectors can be compared below.
# Only a preview is printed; text_embedding already prints the vector length.
emb2 = embed_with_task_type("Hello World", "RETRIEVAL_QUERY")
print(emb2[:10])

emb3 = embed_with_task_type("Hello World", "RETRIEVAL_DOCUMENT")
print(emb3[:10])

In [None]:
# Compare the vectors pairwise: plain text vs. RETRIEVAL_QUERY, then
# RETRIEVAL_QUERY vs. RETRIEVAL_DOCUMENT
for left, right in ((emb1, emb2), (emb2, emb3)):
    print(left == right)

## Load the document that you want to analyze with the LLM

In [None]:
from langchain_community.document_loaders import WebBaseLoader

# Fetch the page behind DOCUMENT_URL
loader = WebBaseLoader(DOCUMENT_URL)
data = loader.load()

# Keep the text of the first loaded item
first_item = data[0]
document = first_item.page_content

# Print first 50 characters
preview_length = 50
print(document[:preview_length])

## The document needs to be split into chunks.

Chunk size and chunk overlap are the interesting variables. Each chunk will be represented by an embedding. Larger chunks will mean there is more data for the LLM to analyze. Smaller chunks will mean less data will be passed to the LLM at inference time.

Smaller chunks also mean more embeddings.

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Both sizes are measured with len(), i.e. in characters (see length_function)
CHUNK_SIZE = 10000
CHUNK_OVERLAP = 20

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    length_function=len,
    is_separator_regex=False,
)

# Keep the splitter's output under its own name so that `chunks` only ever
# holds strings, also when this cell is re-run
chunk_documents = text_splitter.create_documents([document])

# Convert chunks into list of strings
chunks = [chunk_document.page_content for chunk_document in chunk_documents]

print(len(chunks))

# Preview one chunk; a short document may have fewer than 11 chunks, so fall
# back to the last one instead of raising IndexError
if chunks:
    preview_index = min(10, len(chunks) - 1)
    print(chunks[preview_index][:50])

## Generate the embeddings in batches

Below are helper functions that are used to rate limit and batch the generation of the embeddings.

In [None]:
from typing import Generator, List, Optional, Tuple
import functools
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from tqdm import tqdm
import math


# Define an embedding method that uses the model
def encode_texts_to_embeddings(chunks: List[str]) -> List[Optional[List[float]]]:
    """Embed one batch of text chunks as retrieval documents.

    Args:
        chunks: The texts of a single batch.

    Returns:
        One entry per input chunk: the embedding values on success, or ``None``
        for every chunk of the batch when the request failed.
    """
    import logging

    # Nothing to embed: skip the model call altogether.
    if not chunks:
        return []

    try:
        model = TextEmbeddingModel.from_pretrained("textembedding-gecko@002")

        # convert chunks into list[TextEmbeddingInput]
        inputs = [TextEmbeddingInput(text=chunk, task_type="RETRIEVAL_DOCUMENT") for chunk in chunks]
        embeddings = model.get_embeddings(inputs)

        # You could also generate the embeddings without the task_type.
        # Then, you are just passing a collection of strings. In a real app
        # test it multiple ways.
        # The alternative would be as follows
        # embeddings = model.get_embeddings(chunks)

        return [embedding.values for embedding in embeddings]
    except Exception as error:
        # Best effort: a failed batch must not abort the whole run, but the
        # failure is reported so missing embeddings do not go unnoticed.
        logging.getLogger(__name__).warning(
            "Embedding a batch of %d chunk(s) failed: %s", len(chunks), error
        )
        return [None for _ in chunks]


# Generator function that yields consecutive batches of chunks
def generate_batches(
    chunks: List[str], batch_size: int
) -> Generator[List[str], None, None]:
    """Yield successive slices of ``chunks`` holding at most ``batch_size`` items."""
    batch_starts = range(0, len(chunks), batch_size)
    yield from (chunks[start : start + batch_size] for start in batch_starts)


def encode_text_to_embedding_batched(
    chunks: List[str], api_calls_per_minute: int = 20, batch_size: int = 5
) -> Tuple[List[bool], np.ndarray]:
    """Embed ``chunks`` in rate-limited batches.

    Batches are submitted to a thread pool, one submission every
    ``60 / api_calls_per_minute`` seconds.

    Args:
        chunks: The texts to embed.
        api_calls_per_minute: Upper bound on batch submissions per minute.
        batch_size: Number of chunks sent per call.

    Returns:
        A tuple ``(is_successful, embeddings)``. ``is_successful`` holds one
        flag per input chunk. ``embeddings`` is a 2-D array with one row per
        successfully embedded chunk, in input order; chunks that failed have
        no row. When nothing was embedded the array has shape ``(0, 0)``.
    """
    embeddings_list: List[Optional[List[float]]] = []

    seconds_per_job = 60 / api_calls_per_minute

    # Prepare the batches using a generator
    batches = generate_batches(chunks, batch_size)
    total_batches = math.ceil(len(chunks) / batch_size)

    with ThreadPoolExecutor() as executor:
        futures = []
        for batch_number, batch in enumerate(
            tqdm(batches, total=total_batches, position=0)
        ):
            # Space the submissions out; no pause is needed before the first
            # batch, and none is wasted after the last one.
            if batch_number:
                time.sleep(seconds_per_job)
            futures.append(executor.submit(encode_texts_to_embeddings, batch))

        # Collect in submission order so results line up with the chunks.
        for future in futures:
            embeddings_list.extend(future.result())

    is_successful = [embedding is not None for embedding in embeddings_list]
    successful_embeddings = [
        embedding for embedding in embeddings_list if embedding is not None
    ]

    # np.stack rejects an empty list (no input, or every batch failed).
    if not successful_embeddings:
        return is_successful, np.empty((0, 0))

    # No np.squeeze here: it would collapse a single embedding to a 1-D array
    # and break callers that index the result by row.
    return is_successful, np.stack(successful_embeddings)


In [None]:
# Embed every chunk, submitting at most this many batches per minute
API_CALLS_PER_MINUTE = 20
embeddings_response = encode_text_to_embedding_batched(
    chunks, api_calls_per_minute=API_CALLS_PER_MINUTE
)

In [None]:
# The vector database collection needs one ID per chunk;
# the chunk's position, as a string, is enough here
ids = list(map(str, range(len(chunks))))

# chunks has the text, embeddings_response has the embeddings
embedding_vectors = embeddings_response[1]
print(chunks[10][:50])
print(embedding_vectors[10][:10])

# All the collections need the same number of items.
for collection in (chunks, embedding_vectors, ids):
    print(len(collection))


# Chroma DB is an open source database

Google Vector Search may be a better production solution, but Chroma DB is free. So, it makes for a better demo.

In [None]:
import chromadb
import chromadb.utils.embedding_functions as embedding_functions

COLLECTION_NAME = "document-collection"

palm_embedding = embedding_functions.GooglePalmEmbeddingFunction(
    api_key=API_KEY)

chroma_client = chromadb.Client()

# Make sure the collection does not exist, so this cell can be re-run.
# delete_collection may raise when there is nothing to delete (e.g. first run).
try:
  chroma_client.delete_collection(COLLECTION_NAME)
except Exception:
  # `except Exception` rather than a bare `except`, so that KeyboardInterrupt
  # and SystemExit still propagate.
  pass

# Create the collection
chroma_collection = chroma_client.create_collection(COLLECTION_NAME, embedding_function=palm_embedding)

# embeddings_response[1] only contains the embeddings that succeeded, so the
# ids and documents are filtered with the success flags to keep all three
# lists aligned item by item.
is_successful, successful_embeddings = embeddings_response
embedded_ids = [item_id for item_id, ok in zip(ids, is_successful) if ok]
embedded_chunks = [chunk for chunk, ok in zip(chunks, is_successful) if ok]

if not embedded_ids:
  raise RuntimeError("No chunk was embedded successfully; nothing to add to the collection.")

# Add the data to the collection. atleast_2d keeps a single embedding as one
# row; tolist() hands Chroma plain Python lists.
chroma_collection.add(
    ids=embedded_ids,
    documents=embedded_chunks,
    embeddings=np.atleast_2d(successful_embeddings).tolist(),
)
chroma_collection.count()


## Helper functions

In [None]:
import vertexai
from vertexai.preview.generative_models import GenerativeModel, Part

# Prompt template: a plain string template used to
# format the prompt for the LLM.
# It is filled positionally with str.format: {0} receives the retrieved data
# and {1} receives the user's question.
prompt_template = """

Answer the users question using the following data.

Data: {0}

Question: {1}

Answer:
"""

# Send the prompt to Gemini
def generate(prompt):
  """Return the text of the Gemini response for ``prompt``."""
  generation_settings = {
      "max_output_tokens": 2048,
      "temperature": 0.2,
      "top_p": 1,
  }
  gemini = GenerativeModel("gemini-pro")
  response = gemini.generate_content(prompt, generation_config=generation_settings)
  return response.text

# The question is converted to an embedding with the same Vertex AI model
# that embedded the chunks. Then, that embedding is used to query the vector DB.
# The n_results closest chunks are returned, joined into one string.
def query_vector_db(question, n_results=5):
  """Return the text of the chunks closest to ``question``.

  Args:
    question: The user's question.
    n_results: How many chunks to retrieve from the collection.

  Returns:
    The retrieved chunk texts joined by single spaces.
  """
  # Embed the query explicitly instead of passing query_texts: with query_texts
  # the collection's own embedding function would embed the question, and that
  # is a different model from the one that produced the stored vectors.
  model = TextEmbeddingModel.from_pretrained("textembedding-gecko@002")
  query_input = TextEmbeddingInput(text=question, task_type="RETRIEVAL_QUERY")
  query_embedding = model.get_embeddings([query_input])[0].values

  results = chroma_collection.query(query_embeddings=[query_embedding], n_results=n_results)
  # One query was sent, so its matches are the first entry of 'documents'
  retrieved_documents = results['documents'][0]

  # concatenate all the strings in the retrieved_documents collection
  DATA = " ".join(retrieved_documents)
  return DATA

def ask_question(question):
  """Answer ``question`` with the LLM, using data retrieved from the vector DB."""
  context = query_vector_db(question)
  return generate(prompt_template.format(context, question))



In [None]:
from IPython.display import display

QUESTION1 = "Who wrote the Wizard of Oz?"
QUESTION2 = "Who were the main characters in the Wizard of Oz?"
QUESTION3 = "What was the plot of the Wizard of Oz?"
QUESTION4 = "Tell me about the Scarecrow?"

questions = [QUESTION1, QUESTION2, QUESTION3, QUESTION4]
separator = "-----------------------"

# Ask each question in turn and show its answer, followed by a separator line
for question in questions:
  answer = ask_question(question)
  display(answer)
  print(separator)
