+{{ question | replace("+", " ") | replace("%3F", "?") }}
+To verify this information, please check out:
+ {{ fact_link | safe }}

+ # Strip inline HTML tags but keep their inner text.
+ text = re.sub(r"<b>(.*?)</b>", "\\1", text)
+ text = re.sub(r"<code>(.*?)</code>", "\\1", text)
+ text = re.sub(r"(?m)^<p>(.*?)</p>$", "\\1", text)
+ text = re.sub(
+ r"(?m)^(Important|Note|Caution|Tip|Warning|Key Point|Key Term):\s?",
+ "",
+ text,
+ )
+ text = re.sub(
+ r"(?m)^(Objective|Success|Beta|Preview|Deprecated):\s?",
+ "",
+ text,
+ )
+ text = re.sub(r"(Project|Book):(.*)\n", "", text)
+ text = text.strip() + "\n"
+ return text
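+
+# For example (illustrative input): the substitutions above turn
+# "Note: Back up your files.\nProject: /docs/_project.yaml\n"
+# into "Back up your files.\n", stripping the notice label and
+# removing the Project: line entirely.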
+
+
+# Function to read a Markdown include file and return its content
+def read_markdown(file):
+ try:
+ with open(file, "r", encoding="utf-8") as mdfile:
+ output = mdfile.read()
+ return output
+ except FileNotFoundError:
+ print("[FileNotFound] Missing the include file: " + file)
+
+
+# This function converts Markdown page (#), section (##), and subsection (###)
+# headings into plain English.
+def process_page_and_section_titles(markdown_text):
+ updated_markdown = ""
+ page_title = ""
+ section_title = ""
+ subsection_title = ""
+ new_line = ""
+ metadata = {}
+ # Processes the frontmatter in a markdown file
+ data = frontmatter.loads(markdown_text)
+ if "title" in data:
+ page_title = data["title"]
+ markdown_text = data.content
+ metadata = data.metadata
+ if "URL" in data:
+ final_url = data["URL"]
+ metadata["URL"] = final_url
+ for line in markdown_text.split("\n"):
+ new_line = ""
+ skip_this_line = False
+ if line.startswith("#"):
+ match = re.search(r"^(\#*)\s+(.*)$", line)
+ heading = ""
+ captured_title = ""
+ if match:
+ heading = match[1]
+ # Remove {: } in devsite Markdown
+ captured_title = re.sub(r"\{:(.*?)\}", "", match[2])
+ # Special case of RFC pages.
+ if re.search(r"^\{\{\s+(.*)\.(.*)\s+\}\}$", captured_title):
+ heading = ""
+ page_title = "RFC"
+ skip_this_line = True
+
+ # Detect Markdown heading levels
+ if heading == "#":
+ page_title = captured_title.strip()
+ metadata["title"] = page_title
+ subsection_title = ""
+ section_title = ""
+ elif heading == "##":
+ section_title = captured_title.strip()
+ subsection_title = ""
+ elif heading == "###":
+ subsection_title = captured_title.strip()
+
+ # Convert Markdown headings into plain English
+ # (but keep `#` for the `process_document_into_sections()`
+ # function to detect these headings for splitting).
+ if page_title:
+ new_line = (
+ '# The "'
+ + page_title
+ + '" page contains the following content:\n\n'
+ )
+
+ if section_title:
+ new_line = (
+ '# The "'
+ + page_title
+ + '" page has the "'
+ + section_title
+ + '" section that contains the following content:\n'
+ )
+
+ if subsection_title:
+ new_line = (
+ '# On the "'
+ + page_title
+ + '" page, the "'
+ + section_title
+ + '" section has the "'
+ + subsection_title
+ + '" subsection that contains the following content:\n'
+ )
+
+ if skip_this_line is False:
+ if new_line:
+ updated_markdown += new_line + "\n"
+ else:
+ updated_markdown += line + "\n"
+ return updated_markdown, metadata
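+
+# For example (hypothetical input): given frontmatter title "Setup" and the
+# line `## Install`, the heading is rewritten as:
+#
+# # The "Setup" page has the "Install" section that contains the following content:
+#
+# keeping a leading `#` so process_document_into_sections() can still split on it.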
+
+
+# This function replaces Markdown `<<...>>` include lines with the content of the included files.
+def process_includes(markdown_text, root):
+ updated_markdown = ""
+ for line in markdown_text.split("\n"):
+ new_line = ""
+ # Replaces Markdown includes with content
+ if line.startswith("<<"):
+ include_match = re.search("^<<(.*?)>>", line)
+ if include_match:
+ include_file = os.path.abspath(root + "/" + include_match[1])
+ new_line = read_markdown(include_file)
+ if new_line:
+ updated_markdown += new_line + "\n"
+ else:
+ updated_markdown += line + "\n"
+ return updated_markdown
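+
+# For example (hypothetical path): a line `<<_shared/setup.md>>` is replaced
+# with the contents of `<root>/_shared/setup.md`.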
+
+
+# This function divides Markdown content into sections and
+# returns an array containing these sections.
+# But this function requires pre-processed Markdown headings from
+# the `process_page_and_section_titles()` function, which simplifies
+# three levels of Markdown headings (#, ##, and ###) into just a single #.
+def process_document_into_sections(markdown_text):
+ sections = []
+ buffer = ""
+ first_section = True
+ for line in markdown_text.split("\n"):
+ if line.startswith("#"):
+ match = re.search(r"^(\#*)\s+(.*)$", line)
+ heading = ""
+ if match:
+ heading = match[1]
+ if heading == "#":
+ if first_section is True:
+ # Ignore the first detection of `#`.
+ first_section = False
+ else:
+ # When a new `#` is detected, store the text in `buffer` into
+ # an array entry and clear the buffer for the next section.
+ sections.append(buffer)
+ buffer = ""
+ buffer += line + "\n"
+ # Add the last section on the page.
+ sections.append(buffer)
+ return sections
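+
+# For example: input with two simplified `#` headings, each followed by body
+# text, is returned as a two-entry list, one string per heading and its body.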
+
+
+# This function processes Markdown files in the `input_path` directory
+# into plain text files.
+def process_markdown_files_from_source(configs, inputpath, counter, excludepath):
+ f_count = 0
+ for root, dirs, files in os.walk(resolve_path(inputpath)):
+ if IS_CONFIG_FILE:
+ if "exclude_path" in configs[counter]:
+ dirs[:] = [d for d in dirs if d not in excludepath]
+ if "url_prefix" in configs[counter]:
+ namespace_uuid = uuid.uuid3(
+ uuid.NAMESPACE_DNS, configs[counter]["url_prefix"]
+ )
+ for file in files:
+ f_count += 1
+ # Process only Markdown files
+ if file.endswith(".md"):
+ with open(os.path.join(root, file), "r", encoding="utf-8") as auto:
+ # Construct a new sub-directory for storing output plain text files
+ new_path = MY_OUTPUT_PATH + re.sub(
+ resolve_path(inputpath), "", os.path.join(root, "")
+ )
+ is_exist = os.path.exists(new_path)
+ if not is_exist:
+ os.makedirs(new_path)
+ # Grab the filename without the .md extension
+ new_filename = os.path.join(new_path, file)
+ # Add filename to a list
+ file_slash = "/" + file
+ relative_path = os.path.relpath(root + file_slash, inputpath)
+ file_index.append(relative_path)
+ match = re.search(r"(.*)\.md$", new_filename)
+ new_filename_no_ext = match[1]
+ # Read the input Markdown content
+ to_file = auto.read()
+ # Reformat the page and section titles
+ to_file, metadata = process_page_and_section_titles(to_file)
+ # Process includes lines in Markdown
+ to_file = process_includes(to_file, root)
+ doc = []
+ if USE_CUSTOM_MARKDOWN_SPLITTER is True:
+ # Use a custom splitter to split into small chunks
+ docs = process_document_into_sections(to_file)
+ else:
+ # Use the Markdown splitter to split into small chunks
+ docs = markdown_splitter.create_documents([to_file])
+ i = 0
+ for doc in docs:
+ # Clean up Markdown and HTML syntax
+ if USE_CUSTOM_MARKDOWN_SPLITTER is True:
+ content = markdown_to_text(doc)
+ else:
+ content = markdown_to_text(doc.page_content)
+ # Save clean plain text to a new filename appended with an index
+ filename_to_save = new_filename_no_ext + "_" + str(i) + ".md"
+ # Generate UUID for each plain text chunk and collect its metadata,
+ # which will be written to the top-level `file_index.json` file.
+ md_hash = uuid.uuid3(namespace_uuid, content)
+ uuid_file = uuid.uuid3(namespace_uuid, filename_to_save)
+ if bool(metadata):
+ full_file_metadata[filename_to_save] = {
+ "UUID": str(uuid_file),
+ "source": input_path,
+ "source_file": relative_path,
+ "source_id": counter,
+ "URL": url_pre,
+ "md_hash": str(md_hash),
+ "metadata": metadata,
+ }
+ else:
+ full_file_metadata[filename_to_save] = {
+ "UUID": str(uuid_file),
+ "source": input_path,
+ "source_file": relative_path,
+ "source_id": counter,
+ "URL": url_pre,
+ "md_hash": str(md_hash),
+ }
+ with open(filename_to_save, "w", encoding="utf-8") as new_file:
+ new_file.write(content)
+ i = i + 1
+ print("Processed " + str(f_count) + " Markdown files from the source: " + inputpath)
+ return f_count
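+
+# For example (hypothetical file): an input `setup.md` that splits into three
+# chunks is written out as `setup_0.md`, `setup_1.md`, and `setup_2.md` under
+# the mirrored sub-directory of MY_OUTPUT_PATH.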
+
+
+# Write the recorded input variables into a file: `file_index.json`
+def save_file_index_json(src_file_index):
+ json_out_file = MY_OUTPUT_PATH + "/file_index.json"
+ with open(json_out_file, "w", encoding="utf-8") as outfile:
+ json.dump(src_file_index, outfile)
+ print(
+ "Created " + json_out_file + " to store the complete list of processed files."
+ )
+
+
+#### Main ####
+source_file_index = {}
+input_counter = 0
+total_file_count = 0
+
+# Main for-loop
+for input_counter in range(input_len):
+ full_file_metadata = {}
+ file_index = []
+ exclude = []
+ # Process `input-values.yaml` into input variables.
+ if IS_CONFIG_FILE:
+ # Reads all the input values defined in the configuration file
+ config_values = config.returnConfigValue("input")
+ if "path" in config_values[input_counter]:
+ input_path = config_values[input_counter]["path"]
+ if "url_prefix" in config_values[input_counter]:
+ url_pre = config_values[input_counter]["url_prefix"]
+ if "exclude_path" in config_values[input_counter]:
+ exclude = config_values[input_counter]["exclude_path"]
+ else:
+ input_path = MY_INPUT_PATH[input_counter]
+ url_pre = URL_PREFIX[input_counter]
+
+ # Process Markdown files in the `input` path
+ file_count = process_markdown_files_from_source(
+ config_values, input_path, input_counter, exclude
+ )
+ if not input_path.endswith("/"):
+ input_path = input_path + "/"
+ input_path = resolve_path(input_path)
+ # Record the input variables used in this path.
+ source_file_index[input_counter] = full_file_metadata
+ total_file_count += file_count
+
+# Write the recorded input variables into `file_index.json`.
+save_file_index_json(source_file_index)
+
+print(
+ "Processed a total of "
+ + str(total_file_count)
+ + " Markdown files from "
+ + str(input_len)
+ + " sources."
+)
diff --git a/demos/palm/python/docs-agent/scripts/populate_vector_database.py b/demos/palm/python/docs-agent/scripts/populate_vector_database.py
new file mode 100644
index 000000000..51eeb0ce8
--- /dev/null
+++ b/demos/palm/python/docs-agent/scripts/populate_vector_database.py
@@ -0,0 +1,307 @@
+#
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Populate the vector database with embeddings generated from text chunks"""
+
+import os
+import sys
+import re
+import json
+import chromadb
+import flatdict
+import uuid
+from chromadb.config import Settings
+from chromadb.utils import embedding_functions
+from chromadb.api.types import Documents, Embeddings
+import google.generativeai as palm
+from ratelimit import limits, sleep_and_retry
+import read_config
+
+### Notes on how to use this script ###
+#
+# Prerequisites:
+# - Have plain text files stored in the PLAIN_TEXT_DIR directory
+# (see `markdown_to_plain_text.py`)
+#
+# Do the following:
+# 1. If you are not using an `input-values.yaml` file,
+# edit PLAIN_TEXT_DIR in this script (see below).
+# 2. Run:
+# $ python3 ./scripts/populate_vector_database.py
+#
+# To test, run:
+# $ python3 ./scripts/test_vector_database.py
+#
+
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+### Select the input directory of plain text files; this is overridden by
+### `input-values.yaml` when a configuration file is used.
+### Set up the path to the local vector database ###
+LOCAL_VECTOR_DB_DIR = os.path.join(BASE_DIR, "vector_stores/chroma")
+COLLECTION_NAME = "docs_collection"
+
+IS_CONFIG_FILE = True
+if IS_CONFIG_FILE:
+ config_values = read_config.ReadConfig()
+ PLAIN_TEXT_DIR = config_values.returnConfigValue("output_path")
+ input_len = config_values.returnInputCount()
+ LOCAL_VECTOR_DB_DIR = config_values.returnConfigValue("vector_db_dir")
+ COLLECTION_NAME = config_values.returnConfigValue("collection_name")
+
+### Select the file index generated alongside your plain text files (same directory)
+INPUT_FILE_INDEX = "file_index.json"
+
+# Select the type of embeddings to use, PALM or LOCAL
+EMBEDDINGS_TYPE = "PALM"
+
+### Set up the PaLM API key from the environment ###
+API_KEY = os.getenv("PALM_API_KEY")
+if API_KEY is None:
+ sys.exit("Please set the environment variable PALM_API_KEY to be your API key.")
+
+# The PaLM API allows up to 300 calls per minute, so stay just under that limit.
+API_CALLS = 280
+API_CALL_PERIOD = 60
+
+# Normalize directory paths to end with a trailing "/".
+if not BASE_DIR.endswith("/"):
+ BASE_DIR = BASE_DIR + "/"
+
+if not PLAIN_TEXT_DIR.endswith("/"):
+ PLAIN_TEXT_DIR = PLAIN_TEXT_DIR + "/"
+
+FULL_BASE_DIR = BASE_DIR + PLAIN_TEXT_DIR
+print("Plain text directory: " + FULL_BASE_DIR + "\n")
+
+FULL_INDEX_PATH = PLAIN_TEXT_DIR + INPUT_FILE_INDEX
+try:
+ with open(FULL_INDEX_PATH, "r", encoding="utf-8") as index_file:
+ index = json.load(index_file)
+except FileNotFoundError:
+ sys.exit("The file " + FULL_INDEX_PATH + " does not exist.")
+
+if EMBEDDINGS_TYPE == "PALM":
+ palm.configure(api_key=API_KEY)
+ # This returns models/embedding-gecko-001"
+ models = [
+ m for m in palm.list_models() if "embedText" in m.supported_generation_methods
+ ]
+ # MODEL = "models/embedding-gecko-001"
+ MODEL = models[0]
+elif EMBEDDINGS_TYPE == "LOCAL":
+ MODEL = os.path.join(BASE_DIR, "models/all-mpnet-base-v2")
+ emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=MODEL)
+else:
+ MODEL = os.path.join(BASE_DIR, "models/all-mpnet-base-v2")
+ emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=MODEL)
+
+chroma_client = chromadb.PersistentClient(path=LOCAL_VECTOR_DB_DIR)
+
+# Create the embed function for PaLM,
+# rate-limited to stay under the API quota (about 5 QPS).
+@sleep_and_retry
+@limits(calls=API_CALLS, period=API_CALL_PERIOD)
+def embed_function(texts: Documents) -> Embeddings:
+ # Embed the documents using any supported method
+ return [
+ palm.generate_embeddings(model=MODEL, text=text)["embedding"] for text in texts
+ ]
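+
+# For example, embed_function(["apples", "oranges"]) returns one embedding
+# (a list of floats) per input string, throttled by the decorators above.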
+
+
+if EMBEDDINGS_TYPE == "PALM":
+ collection = chroma_client.get_or_create_collection(
+ name=COLLECTION_NAME, embedding_function=embed_function
+ )
+elif EMBEDDINGS_TYPE == "LOCAL":
+ collection = chroma_client.get_or_create_collection(
+ name=COLLECTION_NAME, embedding_function=emb_fn
+ )
+else:
+ collection = chroma_client.get_or_create_collection(
+ name=COLLECTION_NAME, embedding_function=emb_fn
+ )
+
+documents = []
+metadatas = []
+ids = []
+i = 0
+updated_count = 0
+new_count = 0
+unchanged_count = 0
+
+# Read plain text files (.md) from the PLAIN_TEXT_DIR dir and
+# add their content to the vector database.
+# Embeddings are generated automatically as they are added to the database.
+file_update = False
+for root, dirs, files in os.walk(PLAIN_TEXT_DIR):
+ for file in files:
+ # Persist every 100th entry, but only if a file was actually added or
+ # updated, since persisting the database takes time.
+ if i % 100 == 0 and file_update:
+ chroma_client.persist()
+ file_update = False
+ if file.endswith(".md"):
+ with open(os.path.join(root, file), "r", encoding="utf-8") as auto:
+ print("Process an entry into the database: " + str(i))
+ print("Opening a file: " + file)
+ # Extract the original filename used (without a file extension)
+ match = re.search(r"(.*)\.md$", file)
+ filename_no_ext = match[1]
+ toFile = auto.read()
+ # Construct the URL
+ match2 = re.search(r"(.*)_\d*$", filename_no_ext)
+ filename_for_url = match2[1]
+ clean_filename = re.sub(PLAIN_TEXT_DIR, "", os.path.join(root, ""))
+ url = clean_filename + filename_for_url + ".md"
+ url_path = ""
+ md_hash = ""
+ uuid_file = ""
+ # Build the full filename to match entries in file_index.json
+ # Using the full path avoids mismatches
+ full_file_name = FULL_BASE_DIR + clean_filename + file
+ metadata_dict_extra = {}
+ # Flag to see if there is a predefined URL from frontmatter
+ final_url = False
+ # Reads the metadata associated with files
+ for key in index:
+ if full_file_name in index[key]:
+ if (
+ "URL" in index[key][full_file_name]
+ and "source_id" in index[key][full_file_name]
+ ):
+ # This ensures the URL is retrieved from the correct file.
+ # Avoids issues with common file names such as README.md
+ if int(key) == index[key][full_file_name]["source_id"]:
+ if index[key][full_file_name]["URL"]:
+ url_path = index[key][full_file_name]["URL"]
+ else:
+ print("No valid URL value for: " + file)
+ # If metadata exists, add these to a dictionary that is then
+ # merged with other metadata values
+ if "metadata" in index[key][full_file_name]:
+ # Save and flatten dictionary
+ metadata_dict_extra = flatdict.FlatterDict(
+ index[key][full_file_name]["metadata"], delimiter="_"
+ )
+ metadata_dict_extra = dict(metadata_dict_extra)
+ # Extracts the user-specified URL
+ if "URL" in metadata_dict_extra:
+ final_url = True
+ final_url_value = metadata_dict_extra["URL"]
+ else:
+ metadata_dict_extra = {}
+ if "UUID" in index[key][full_file_name]:
+ uuid_file = index[key][full_file_name]["UUID"]
+ if "md_hash" in index[key][full_file_name]:
+ md_hash = str(index[key][full_file_name]["md_hash"])
+ # Add a trailing "/" to the url path in case the configuration file
+ # didn't have it.
+ # Do not add slashes to PSAs.
+ if (
+ not url_path.endswith("/")
+ and not url_path.startswith("PSA")
+ and not url.startswith("/")
+ ):
+ url_path = url_path + "/"
+ url = url_path + url
+ # Remove .md at the end of URLs by default.
+ match3 = re.search(r"(.*)\.md$", url)
+ url = match3[1]
+ # Replaces the URL if it comes from frontmatter
+ if final_url:
+ url = final_url_value
+ # Creates a dictionary with basic metadata values
+ # (i.e. source, URL, and md_hash)
+ metadata_dict_main = {
+ "source": filename_no_ext,
+ "url": url,
+ "md_hash": md_hash,
+ }
+ # Merges dictionaries with main metadata and additional metadata
+ metadata_dict_final = metadata_dict_main | metadata_dict_extra
+ str_uuid_file = str(uuid_file)
+ print("UUID: " + str_uuid_file)
+ print("Markdown hash: " + str(md_hash))
+ print("URL: " + url)
+ if toFile and toFile.strip():
+ # Skip if the content is larger than 10,000 characters (PaLM API limit).
+ filesize = len(toFile)
+ if filesize < 10000:
+ if md_hash != "" and str_uuid_file != "":
+ query = {}
+ # The query looks for the UUID, which is unique and
+ # compares to see if the hash has changed
+ query = collection.get(
+ include=["metadatas"],
+ ids=str_uuid_file,
+ where={"md_hash": {"$ne": md_hash}},
+ )
+ # Extract any id whose content may have changed
+ id_to_remove = query["ids"]
+ if id_to_remove != []:
+ print("Out of date content.")
+ # Delete the existing entry
+ collection.delete(ids=id_to_remove)
+ # Add a new entry
+ collection.add(
+ documents=toFile,
+ metadatas=metadata_dict_final,
+ ids=str_uuid_file,
+ )
+ print("Updated.")
+ updated_count += 1
+ file_update = True
+ else:
+ query_2 = collection.get(
+ include=["metadatas"],
+ ids=str_uuid_file,
+ where={"md_hash": {"$eq": md_hash}},
+ )
+ id_up_to_date = query_2["ids"]
+ if id_up_to_date != []:
+ print("Up to date content.")
+ unchanged_count += 1
+ else:
+ collection.add(
+ documents=toFile,
+ metadatas=metadata_dict_final,
+ ids=str_uuid_file,
+ )
+ print("Added content.")
+ new_count += 1
+ file_update = True
+ i += 1
+ else:
+ print(
+ "[Warning] Skipped "
+ + file
+ + " because the file size is too large!"
+ )
+ else:
+ print("[Warning] Empty file!")
+ print("")
+# results = collection.query(
+# query_texts=["What are some differences between apples and oranges?"],
+# n_results=3,
+# )
+# print("\nTesting:")
+# print(results)
+
+print("")
+print("Total number of entries: " + str(i))
+print("New entries: " + str(new_count))
+print("Updated entries: " + str(updated_count))
+print("Unchanged entries: " + str(unchanged_count))
diff --git a/demos/palm/python/docs-agent/scripts/read_config.py b/demos/palm/python/docs-agent/scripts/read_config.py
new file mode 100644
index 000000000..852ab65e2
--- /dev/null
+++ b/demos/palm/python/docs-agent/scripts/read_config.py
@@ -0,0 +1,96 @@
+#
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Read the configuration file to import user settings"""
+
+import os
+import sys
+import yaml
+
+# The configuration file config.yaml exists in the root of the project
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+INPUT_YAML = os.path.join(BASE_DIR, "config.yaml")
+### Set up the path to the local LLM ###
+LOCAL_VECTOR_DB_DIR = os.path.join(BASE_DIR, "vector_stores/chroma")
+
+# Define the required keys to run scripts and chatbot
+required_keys = ["output_path", "input", "product_name", "vector_db_dir"]
+# Define any supported optional keys to run scripts and chatbot
+optional_keys = []
+# Define any required keys that define the properties of input paths
+required_input_keys = ["path", "url_prefix"]
+# Define any optional keys that define the properties of input paths
+optional_input_keys = ["md_extension", "exclude_path"]
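+
+# Based on the keys above, a minimal config.yaml might look like this
+# (values illustrative; collection_name is read by the database scripts):
+#
+# output_path: output/plain_docs
+# product_name: My Docs
+# vector_db_dir: vector_stores/chroma
+# collection_name: docs_collection
+# input:
+#   - path: docs/
+#     url_prefix: https://example.com/docs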
+
+
+class ReadConfig:
+ # Tries to ingest the configuration file and validate its keys
+ def __init__(self):
+ try:
+ with open(INPUT_YAML, "r", encoding="utf-8") as inp_yaml:
+ self.config_values = yaml.safe_load(inp_yaml)
+ self.IS_CONFIG_FILE = True
+ print("Configuration defined in: " + INPUT_YAML)
+ # Check that the required keys exist
+ self.validateKeys()
+ except FileNotFoundError:
+ print("The file " + INPUT_YAML + " does not exist.")
+ # Exits the script if there is no valid config file
+ sys.exit(1)
+
+ # Function to return the full configuration file
+ def returnFullConfig(self):
+ return self.config_values
+
+ # Function to return the path of the configuration file
+ def returnConfigFile(self):
+ # INPUT_YAML is already an absolute path built from BASE_DIR.
+ return INPUT_YAML
+
+ # Function to count the quantity of input paths
+ def returnInputCount(self):
+ count = len(self.returnConfigValue("input"))
+ return count
+
+ # Validates that a configuration file contains the required and optional keys
+ def validateKeys(self):
+ for key in required_keys:
+ if key in self.config_values:
+ # Validates lists such as input with their respective keys
+ if key == "input":
+ count = 0
+ for input in self.config_values["input"]:
+ count += 1
+ for required_key in required_input_keys:
+ if required_key not in input:
+ print(
+ "Missing input configuration key: "
+ + required_key
+ + " from input source "
+ + str(count)
+ )
+ else:
+ print("Missing required configuration key: " + key)
+ for key in optional_keys:
+ if key not in self.config_values:
+ print("Missing optional configuration key: " + key)
+
+ # Checks if a key exists and returns its value
+ def returnConfigValue(self, key):
+ if key in self.config_values:
+ return self.config_values[key]
+ else:
+ print("Error: " + key + " does not exist in the " + INPUT_YAML + " file.")
diff --git a/demos/palm/python/docs-agent/scripts/test_vector_database.py b/demos/palm/python/docs-agent/scripts/test_vector_database.py
new file mode 100644
index 000000000..075bda6e3
--- /dev/null
+++ b/demos/palm/python/docs-agent/scripts/test_vector_database.py
@@ -0,0 +1,125 @@
+#
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Test the vector database"""
+
+import os
+import sys
+import google.generativeai as palm
+import chromadb
+from chromadb.config import Settings
+from chromadb.utils import embedding_functions
+from chromadb.api.types import Document, Embedding, Documents, Embeddings
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.panel import Panel
+from ratelimit import limits, sleep_and_retry
+import read_config
+
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+# Set the directory path to locate the Chroma vector database
+LOCAL_VECTOR_DB_DIR = os.path.join(BASE_DIR, "vector_stores/chroma")
+COLLECTION_NAME = "docs_collection"
+
+IS_CONFIG_FILE = True
+if IS_CONFIG_FILE:
+ config_values = read_config.ReadConfig()
+ LOCAL_VECTOR_DB_DIR = config_values.returnConfigValue("vector_db_dir")
+ COLLECTION_NAME = config_values.returnConfigValue("collection_name")
+
+# Set a test question
+QUESTION = "What are some differences between apples and oranges?"
+NUM_RETURNS = 5
+
+# Set up the PaLM API key from the environment
+API_KEY = os.getenv("PALM_API_KEY")
+if API_KEY is None:
+ sys.exit("Please set the environment variable PALM_API_KEY to be your API key.")
+
+# Select your PaLM API endpoint
+PALM_API_ENDPOINT = "generativelanguage.googleapis.com"
+palm.configure(api_key=API_KEY, client_options={"api_endpoint": PALM_API_ENDPOINT})
+
+# Set up the path to the local LLM
+# This value is used only when `EMBEDDINGS_TYPE` is set to `LOCAL`
+LOCAL_LLM = os.path.join(BASE_DIR, "models/all-mpnet-base-v2")
+
+# Use the PaLM API for generating embeddings by default
+EMBEDDINGS_TYPE = "PALM"
+
+# The PaLM API allows up to 300 calls per minute, so stay just under that limit.
+API_CALLS = 280
+API_CALL_PERIOD = 60
+
+
+# Create the embed function for PaLM,
+# rate-limited to stay under the API quota (about 5 QPS).
+@sleep_and_retry
+@limits(calls=API_CALLS, period=API_CALL_PERIOD)
+def embed_palm_api_call(text: Document) -> Embedding:
+ return palm.generate_embeddings(model=PALM_EMBEDDING_MODEL, text=text)["embedding"]
+
+
+def embed_palm(texts: Documents) -> Embeddings:
+ # Embed the documents using any supported method
+ return [embed_palm_api_call(text) for text in texts]
+
+
+# Initialize Rich console
+ai_console = Console(width=160)
+ai_console.rule("Fold")
+
+chroma_client = chromadb.PersistentClient(path=LOCAL_VECTOR_DB_DIR)
+
+if EMBEDDINGS_TYPE == "PALM":
+ PALM_EMBEDDING_MODEL = "models/embedding-gecko-001"
+ emb_fn = embed_palm
+elif EMBEDDINGS_TYPE == "LOCAL":
+ emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
+ model_name=LOCAL_LLM
+ )
+else:
+ emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
+ model_name=LOCAL_LLM
+ )
+
+collection = chroma_client.get_collection(
+ name=COLLECTION_NAME, embedding_function=emb_fn
+)
+
+results = collection.query(query_texts=[QUESTION], n_results=NUM_RETURNS)
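+
+# `results` is a dict of per-query lists: results["documents"][0] holds the
+# NUM_RETURNS matched chunks for QUESTION, parallel to results["ids"][0],
+# results["metadatas"][0], and results["distances"][0].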
+
+print("")
+ai_console.print(Panel.fit(Markdown("Question: " + QUESTION)))
+print("Results:")
+print(results)
+print("")
+
+i = 0
+for document in results["documents"]:
+ for content in document:
+ print("Content " + str(i) + ": ")
+ ai_console.print(Panel.fit(Markdown(content)))
+ source = results["metadatas"][0][i]
+ this_id = results["ids"][0][i]
+ distance = results["distances"][0][i]
+ print(" source: " + source["source"])
+ print(" URL: " + source["url"])
+ print(" ID: " + this_id)
+ print(" Distance: " + str(distance))
+ print("")
+ i += 1
diff --git a/demos/palm/python/docs-agent/setup.py b/demos/palm/python/docs-agent/setup.py
new file mode 100644
index 000000000..1d9261f80
--- /dev/null
+++ b/demos/palm/python/docs-agent/setup.py
@@ -0,0 +1,24 @@
+#
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""For setting up local packages in the project."""
+from setuptools import setup
+
+setup(
+ name="docs_agent",
+ packages=["chatbot", "docs_agent"],
+ install_requires=["flask"],
+)
diff --git a/demos/palm/python/docs-agent/third_party/css/chatbox.css b/demos/palm/python/docs-agent/third_party/css/chatbox.css
new file mode 100644
index 000000000..e7d562c98
--- /dev/null
+++ b/demos/palm/python/docs-agent/third_party/css/chatbox.css
@@ -0,0 +1,42 @@
+/**
+ * Copyright (c) 2023 by Landgreen (https://codepen.io/lilgreenland/pen/pyVvqB)
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of
+ * this software and associated documentation files (the "Software"), to deal in
+ * the Software without restriction, including without limitation the rights to
+ * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+ * of the Software, and to permit persons to whom the Software is furnished to do
+ * so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#body-box {
+ margin: auto;
+ max-width: 800px;
+ font: 15px arial, sans-serif;
+ background-color: white;
+ border-style: solid;
+ border-width: 1px;
+ padding: 20px 25px 25px;
+ box-shadow: 5px 5px 5px grey;
+ border-radius: 15px;
+}
+
+#chat-border {
+ border-style: solid;
+ background-color: #f6f9f6;
+ border-width: 3px;
+ margin: 20px;
+ padding: 10px 20px 15px 15px;
+ border-radius: 15px;
+}
+
diff --git a/demos/palm/python/docs-agent/vector_stores/.gitkeep b/demos/palm/python/docs-agent/vector_stores/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/demos/palm/web/quick-prompt/package-lock.json b/demos/palm/web/quick-prompt/package-lock.json
index c5db3cb46..d695ab416 100644
--- a/demos/palm/web/quick-prompt/package-lock.json
+++ b/demos/palm/web/quick-prompt/package-lock.json
@@ -1318,9 +1318,9 @@
}
},
"node_modules/@grpc/grpc-js/node_modules/protobufjs": {
- "version": "7.2.3",
- "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.3.tgz",
- "integrity": "sha512-TtpvOqwB5Gdz/PQmOjgsrGH1nHjAQVCN7JG4A6r1sXRWESL5rNMAiRcBQlCAdKxZcAbstExQePYG8xof/JVRgg==",
+ "version": "7.2.5",
+ "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz",
+ "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==",
"hasInstallScript": true,
"dependencies": {
"@protobufjs/aspromise": "^1.1.2",
@@ -4077,9 +4077,9 @@
}
},
"node_modules/protobufjs": {
- "version": "6.11.3",
- "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.3.tgz",
- "integrity": "sha512-xL96WDdCZYdU7Slin569tFX712BxsxslWwAfAhCYjQKGTq7dAU91Lomy6nLLhh/dyGhk/YH4TwTSRxTzhuHyZg==",
+ "version": "6.11.4",
+ "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz",
+ "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==",
"hasInstallScript": true,
"dependencies": {
"@protobufjs/aspromise": "^1.1.2",
@@ -5876,9 +5876,9 @@
}
},
"protobufjs": {
- "version": "7.2.3",
- "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.3.tgz",
- "integrity": "sha512-TtpvOqwB5Gdz/PQmOjgsrGH1nHjAQVCN7JG4A6r1sXRWESL5rNMAiRcBQlCAdKxZcAbstExQePYG8xof/JVRgg==",
+ "version": "7.2.5",
+ "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz",
+ "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==",
"requires": {
"@protobufjs/aspromise": "^1.1.2",
"@protobufjs/base64": "^1.1.2",
@@ -7929,9 +7929,9 @@
}
},
"protobufjs": {
- "version": "6.11.3",
- "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.3.tgz",
- "integrity": "sha512-xL96WDdCZYdU7Slin569tFX712BxsxslWwAfAhCYjQKGTq7dAU91Lomy6nLLhh/dyGhk/YH4TwTSRxTzhuHyZg==",
+ "version": "6.11.4",
+ "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz",
+ "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==",
"requires": {
"@protobufjs/aspromise": "^1.1.2",
"@protobufjs/base64": "^1.1.2",
diff --git a/demos/palm/web/quick-prompt/yarn.lock b/demos/palm/web/quick-prompt/yarn.lock
index 5551e9046..04289c327 100644
--- a/demos/palm/web/quick-prompt/yarn.lock
+++ b/demos/palm/web/quick-prompt/yarn.lock
@@ -2445,9 +2445,9 @@ prop-types@^15.8.1:
react-is "^16.13.1"
protobufjs@^6.11.3:
- version "6.11.3"
- resolved "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.3.tgz"
- integrity sha512-xL96WDdCZYdU7Slin569tFX712BxsxslWwAfAhCYjQKGTq7dAU91Lomy6nLLhh/dyGhk/YH4TwTSRxTzhuHyZg==
+ version "6.11.4"
+ resolved "https://registry.yarnpkg.com/protobufjs/-/protobufjs-6.11.4.tgz#29a412c38bf70d89e537b6d02d904a6f448173aa"
+ integrity sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==
dependencies:
"@protobufjs/aspromise" "^1.1.2"
"@protobufjs/base64" "^1.1.2"
@@ -2464,9 +2464,9 @@ protobufjs@^6.11.3:
long "^4.0.0"
protobufjs@^7.0.0:
- version "7.2.3"
- resolved "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.3.tgz"
- integrity sha512-TtpvOqwB5Gdz/PQmOjgsrGH1nHjAQVCN7JG4A6r1sXRWESL5rNMAiRcBQlCAdKxZcAbstExQePYG8xof/JVRgg==
+ version "7.2.5"
+ resolved "https://registry.yarnpkg.com/protobufjs/-/protobufjs-7.2.5.tgz#45d5c57387a6d29a17aab6846dcc283f9b8e7f2d"
+ integrity sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==
dependencies:
"@protobufjs/aspromise" "^1.1.2"
"@protobufjs/base64" "^1.1.2"
diff --git a/third_party/docs-agent b/third_party/docs-agent
new file mode 120000
index 000000000..8ddc11904
--- /dev/null
+++ b/third_party/docs-agent
@@ -0,0 +1 @@
+../demos/palm/python/docs-agent/third_party
\ No newline at end of file