feat: add user toggleable web search #2004

Merged · 31 commits · May 27, 2024

Commits
501ff7a
feat: backend implementation of various search APIs
cheahjs May 6, 2024
99e4edd
feat: add websearch endpoint to RAG API
cheahjs May 6, 2024
83f086c
fix: do not return raw search exception due to API keys in URLs
cheahjs May 6, 2024
8b3e370
fix: run formatter
cheahjs May 6, 2024
635951b
Merge branch 'dev' into feat/backend-web-search
tjbck May 6, 2024
fb80691
feat: add WEB_SEARCH_RESULT_COUNT to control max number of results
cheahjs May 11, 2024
14a902f
feat: add web search toggle on chat
cheahjs May 11, 2024
619c2f9
fix: toggle style
cheahjs May 11, 2024
2660a6e
feat: prototype frontend web search integration
cheahjs May 11, 2024
77928ae
Merge branch 'dev' of https://github.com/open-webui/open-webui into f…
cheahjs May 11, 2024
7538dc0
feat: use url as source name for citations
cheahjs May 11, 2024
9ed1a31
fix: continue with failures when bulk loading urls with WebBaseLoader
cheahjs May 12, 2024
d45804d
feat: web search available is inferred from env vars
cheahjs May 12, 2024
3baeda7
feat: add in-message progress indicator for web search
cheahjs May 12, 2024
d980518
feat: mark websearch docs differently from standard docs
cheahjs May 12, 2024
654cc09
feat: run i18next
cheahjs May 12, 2024
466b3e3
feat: add support for using previous messages for query generation
cheahjs May 12, 2024
44c8b0b
feat: rename title generation model to task model
cheahjs May 12, 2024
f49e1af
feat: inject search result doc in the response, not the query
cheahjs May 12, 2024
5e1c408
Merge branch 'dev' into feat/backend-web-search
cheahjs May 14, 2024
f946903
chore: formatting
cheahjs May 14, 2024
81a3c97
Merge branch 'dev' into feat/backend-web-search
cheahjs May 14, 2024
9021f06
Merge remote-tracking branch 'origin/dev' into feat/backend-web-search
cheahjs May 16, 2024
b95027f
feat: add searched urls to document
cheahjs May 16, 2024
eb509c4
Merge remote-tracking branch 'origin/dev' into feat/backend-web-search
cheahjs May 20, 2024
69bac2a
feat: use the conversation's model instead of the first model for que…
cheahjs May 20, 2024
224a578
Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search
cheahjs May 20, 2024
6043385
Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search
cheahjs May 22, 2024
b1265c9
Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search
cheahjs May 25, 2024
276b7b9
Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search
cheahjs May 26, 2024
bced907
Merge branch 'websearch' into feat/backend-web-search
tjbck May 27, 2024
89 changes: 72 additions & 17 deletions backend/apps/rag/main.py
@@ -11,7 +11,7 @@
import os, shutil, logging, re

from pathlib import Path
-from typing import List
+from typing import List, Union, Sequence

from chromadb.utils.batch_utils import create_batches

@@ -59,6 +59,7 @@
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
+    search_web,
)

from utils.misc import (
@@ -95,6 +96,7 @@
    RAG_TEMPLATE,
    ENABLE_RAG_LOCAL_WEB_FETCH,
    YOUTUBE_LOADER_LANGUAGE,
+    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
    AppConfig,
)

@@ -201,6 +203,10 @@ class UrlForm(CollectionNameForm):
    url: str


+class SearchForm(CollectionNameForm):
+    query: str
+
+
@app.get("/")
async def get_status():
    return {
@@ -589,24 +595,40 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
        )


-def get_web_loader(url: str, verify_ssl: bool = True):
+def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    # Check if the URL is valid
-    if isinstance(validators.url(url), validators.ValidationError):
+    if not validate_url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    if not ENABLE_RAG_LOCAL_WEB_FETCH:
-        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
-        parsed_url = urllib.parse.urlparse(url)
-        # Get IPv4 and IPv6 addresses
-        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
-        # Check if any of the resolved addresses are private
-        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
-        for ip in ipv4_addresses:
-            if validators.ipv4(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-        for ip in ipv6_addresses:
-            if validators.ipv6(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    return WebBaseLoader(url, verify_ssl=verify_ssl)
+    return WebBaseLoader(
+        url,
+        verify_ssl=verify_ssl,
+        requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+        continue_on_failure=True,
+    )
+
+
+def validate_url(url: Union[str, Sequence[str]]):
+    if isinstance(url, str):
+        if isinstance(validators.url(url), validators.ValidationError):
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        if not ENABLE_RAG_LOCAL_WEB_FETCH:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        return True
+    elif isinstance(url, Sequence):
+        return all(validate_url(u) for u in url)
+    else:
+        return False


def resolve_hostname(hostname):
@@ -620,6 +642,39 @@ def resolve_hostname(hostname):
    return ipv4_addresses, ipv6_addresses


+@app.post("/websearch")
+def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
+    try:
+        try:
+            web_results = search_web(form_data.query)
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
+            )
+        urls = [result.link for result in web_results]
+        loader = get_web_loader(urls)
+        data = loader.aload()
+
+        collection_name = form_data.collection_name
+        if collection_name == "":
+            collection_name = calculate_sha256_string(form_data.query)[:63]
+
+        store_data_in_vector_db(data, collection_name, overwrite=True)
+        return {
+            "status": True,
+            "collection_name": collection_name,
+            "filenames": urls,
+        }
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )


def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:

    text_splitter = RecursiveCharacterTextSplitter(
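
For context on how the pieces above fit together, here is a minimal client-side sketch of calling the new /websearch route. The mount path, port, and token are assumptions for illustration; only the request and response shape come from the diff above.

import requests

# Assumed deployment details: the RAG router mounted under /rag/api/v1
# on a local instance, with a placeholder API token.
BASE_URL = "http://localhost:8080/rag/api/v1"
TOKEN = "sk-placeholder-token"

# An empty collection_name asks the endpoint to derive one from the
# first 63 characters of the SHA-256 hash of the query.
resp = requests.post(
    f"{BASE_URL}/websearch",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={"query": "open webui web search", "collection_name": ""},
)
resp.raise_for_status()
body = resp.json()
print(body["collection_name"])  # hash-derived collection name
print(body["filenames"])        # the URLs that were fetched and stored
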
37 changes: 37 additions & 0 deletions backend/apps/rag/search/brave.py
@@ -0,0 +1,37 @@
import logging

import requests

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_brave(api_key: str, query: str) -> list[SearchResult]:
    """Search using Brave's Search API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A Brave Search API key
        query (str): The query to search for
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {
        "Accept": "application/json",
        "Accept-Encoding": "gzip",
        "X-Subscription-Token": api_key,
    }
    params = {"q": query, "count": RAG_WEB_SEARCH_RESULT_COUNT}

    response = requests.get(url, headers=headers, params=params)
    response.raise_for_status()

    json_response = response.json()
    results = json_response.get("web", {}).get("results", [])
    return [
        SearchResult(
            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
        )
        for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
    ]
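
A minimal sketch of exercising this helper directly, assuming a Brave Search API key is available in an environment variable (the variable name is illustrative, not necessarily the config key the PR wires up):

import os

from apps.rag.search.brave import search_brave

# BRAVE_API_KEY is an illustrative variable name for this sketch.
results = search_brave(api_key=os.environ["BRAVE_API_KEY"], query="open webui")
for r in results:
    print(r.link, r.title)
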
45 changes: 45 additions & 0 deletions backend/apps/rag/search/google_pse.py
@@ -0,0 +1,45 @@
import json
import logging

import requests

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_google_pse(
    api_key: str, search_engine_id: str, query: str
) -> list[SearchResult]:
    """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A Programmable Search Engine API key
        search_engine_id (str): A Programmable Search Engine ID
        query (str): The query to search for
    """
    url = "https://www.googleapis.com/customsearch/v1"

    headers = {"Content-Type": "application/json"}
    params = {
        "cx": search_engine_id,
        "q": query,
        "key": api_key,
        "num": RAG_WEB_SEARCH_RESULT_COUNT,
    }

    response = requests.request("GET", url, headers=headers, params=params)
    response.raise_for_status()

    json_response = response.json()
    results = json_response.get("items", [])
    return [
        SearchResult(
            link=result["link"],
            title=result.get("title"),
            snippet=result.get("snippet"),
        )
        for result in results
    ]
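
Same pattern for the Google helper, which needs two credentials: an API key and a Programmable Search Engine ID (the cx parameter). A hedged usage sketch; the environment variable names are assumptions:

import os

from apps.rag.search.google_pse import search_google_pse

# Both credentials come from Google's Programmable Search Engine console;
# the variable names below are illustrative.
results = search_google_pse(
    api_key=os.environ["GOOGLE_PSE_API_KEY"],
    search_engine_id=os.environ["GOOGLE_PSE_ENGINE_ID"],
    query="open webui",
)
print([r.link for r in results])
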
9 changes: 9 additions & 0 deletions backend/apps/rag/search/main.py
@@ -0,0 +1,9 @@
from typing import Optional

from pydantic import BaseModel


class SearchResult(BaseModel):
    link: str
    title: Optional[str]
    snippet: Optional[str]
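
Every provider module normalizes its vendor-specific JSON into this one pydantic shape, so downstream code only ever handles SearchResult. A quick construction sketch (assuming pydantic v2; on v1 the serializer would be .dict()):

from apps.rag.search.main import SearchResult

result = SearchResult(
    link="https://example.com/",
    title="Example Domain",
    snippet=None,  # providers may not return a snippet
)
print(result.model_dump())  # use .dict() instead on pydantic v1
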
44 changes: 44 additions & 0 deletions backend/apps/rag/search/searxng.py
@@ -0,0 +1,44 @@
import logging

import requests

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_searxng(query_url: str, query: str) -> list[SearchResult]:
    """Search a SearXNG instance for a query and return the results as a list of SearchResult objects.

    Args:
        query_url (str): The URL of the SearXNG instance to search. Must contain "<query>" as a placeholder
        query (str): The query to search for
    """
    url = query_url.replace("<query>", query)
    if "&format=json" not in url:
        url += "&format=json"
    log.debug(f"searching {url}")

    r = requests.get(
        url,
        headers={
            "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
            "Accept": "text/html",
            "Accept-Encoding": "gzip, deflate",
            "Accept-Language": "en-US,en;q=0.5",
            "Connection": "keep-alive",
        },
    )
    r.raise_for_status()

    json_response = r.json()
    results = json_response.get("results", [])
    sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
    return [
        SearchResult(
            link=result["url"], title=result.get("title"), snippet=result.get("content")
        )
        for result in sorted_results[:RAG_WEB_SEARCH_RESULT_COUNT]
    ]
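
The query_url contract here is easy to miss: the configured URL must contain a literal <query> placeholder, and &format=json is appended when absent, which only works if the target instance enables the JSON output format. A sketch against an assumed local instance (the URL is illustrative):

from apps.rag.search.searxng import search_searxng

# Assumes a SearXNG instance whose settings.yml enables the JSON format
# (search: formats: [html, json]).
query_url = "http://localhost:8888/search?q=<query>"
results = search_searxng(query_url, "open webui")
for r in results:
    print(r.link, r.snippet)
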
39 changes: 39 additions & 0 deletions backend/apps/rag/search/serper.py
@@ -0,0 +1,39 @@
import json
import logging

import requests

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_serper(api_key: str, query: str) -> list[SearchResult]:
    """Search using serper.dev's API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A serper.dev API key
        query (str): The query to search for
    """
    url = "https://google.serper.dev/search"

    payload = json.dumps({"q": query})
    headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}

    response = requests.request("POST", url, headers=headers, data=payload)
    response.raise_for_status()

    json_response = response.json()
    results = sorted(
        json_response.get("organic", []), key=lambda x: x.get("position", 0)
    )
    return [
        SearchResult(
            link=result["link"],
            title=result.get("title"),
            snippet=result.get("description"),
        )
        for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
    ]
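
Same calling pattern for the serper.dev helper, which sends the query as a JSON POST body rather than query parameters; the environment variable name is an assumption:

import os

from apps.rag.search.serper import search_serper

# SERPER_API_KEY is an illustrative variable name.
results = search_serper(api_key=os.environ["SERPER_API_KEY"], query="open webui")
print([(r.title, r.link) for r in results])
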
43 changes: 43 additions & 0 deletions backend/apps/rag/search/serpstack.py
@@ -0,0 +1,43 @@
import json
import logging

import requests

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_serpstack(
    api_key: str, query: str, https_enabled: bool = True
) -> list[SearchResult]:
    """Search using serpstack.com's API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A serpstack.com API key
        query (str): The query to search for
        https_enabled (bool): Whether to use HTTPS or HTTP for the API request
    """
    url = f"{'https' if https_enabled else 'http'}://api.serpstack.com/search"

    headers = {"Content-Type": "application/json"}
    params = {
        "access_key": api_key,
        "query": query,
    }

    response = requests.request("POST", url, headers=headers, params=params)
    response.raise_for_status()

    json_response = response.json()
    results = sorted(
        json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
    )
    return [
        SearchResult(
            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
        )
        for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
    ]
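
Finally, a sketch for the serpstack helper. The https_enabled flag presumably exists because some serpstack plans only permit plain-HTTP requests; the environment variable name is an assumption:

import os

from apps.rag.search.serpstack import search_serpstack

# https_enabled=False downgrades the endpoint to http://, which some
# serpstack plans require; SERPSTACK_API_KEY is an illustrative name.
results = search_serpstack(
    api_key=os.environ["SERPSTACK_API_KEY"],
    query="open webui",
    https_enabled=False,
)
print([r.link for r in results])
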