Add support for vector search and document level access control #180

Merged: 1 commit, merged on Aug 21, 2023
app.py: 76 changes (69 additions & 7 deletions)
@@ -36,10 +36,14 @@ def assets(path):
AZURE_SEARCH_FILENAME_COLUMN = os.environ.get("AZURE_SEARCH_FILENAME_COLUMN")
AZURE_SEARCH_TITLE_COLUMN = os.environ.get("AZURE_SEARCH_TITLE_COLUMN")
AZURE_SEARCH_URL_COLUMN = os.environ.get("AZURE_SEARCH_URL_COLUMN")
AZURE_SEARCH_VECTOR_COLUMNS = os.environ.get("AZURE_SEARCH_VECTOR_COLUMNS")
AZURE_SEARCH_QUERY_TYPE = os.environ.get("AZURE_SEARCH_QUERY_TYPE")
AZURE_SEARCH_PERMITTED_GROUPS_COLUMN = os.environ.get("AZURE_SEARCH_PERMITTED_GROUPS_COLUMN")

# AOAI Integration Settings
AZURE_OPENAI_RESOURCE = os.environ.get("AZURE_OPENAI_RESOURCE")
AZURE_OPENAI_MODEL = os.environ.get("AZURE_OPENAI_MODEL")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_KEY = os.environ.get("AZURE_OPENAI_KEY")
AZURE_OPENAI_TEMPERATURE = os.environ.get("AZURE_OPENAI_TEMPERATURE", 0)
AZURE_OPENAI_TOP_P = os.environ.get("AZURE_OPENAI_TOP_P", 1.0)
@@ -49,6 +53,9 @@ def assets(path):
AZURE_OPENAI_PREVIEW_API_VERSION = os.environ.get("AZURE_OPENAI_PREVIEW_API_VERSION", "2023-06-01-preview")
AZURE_OPENAI_STREAM = os.environ.get("AZURE_OPENAI_STREAM", "true")
AZURE_OPENAI_MODEL_NAME = os.environ.get("AZURE_OPENAI_MODEL_NAME", "gpt-35-turbo") # Name of the model, e.g. 'gpt-35-turbo' or 'gpt-4'
AZURE_OPENAI_EMBEDDING_ENDPOINT = os.environ.get("AZURE_OPENAI_EMBEDDING_ENDPOINT")
AZURE_OPENAI_EMBEDDING_KEY = os.environ.get("AZURE_OPENAI_EMBEDDING_KEY")


SHOULD_STREAM = True if AZURE_OPENAI_STREAM.lower() == "true" else False

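For context, a minimal sketch (not part of the diff) of how the new settings might be supplied to exercise vector search and group-based document trimming. The variable names come from the diff above; every value below is a placeholder assumption.

```python
import os

# Hypothetical local configuration; all values are placeholders, not real resources.
os.environ["AZURE_SEARCH_QUERY_TYPE"] = "vector"  # accepted values depend on the preview API version
os.environ["AZURE_SEARCH_VECTOR_COLUMNS"] = "contentVector|titleVector"  # pipe-delimited, matching the .split("|") below
os.environ["AZURE_SEARCH_PERMITTED_GROUPS_COLUMN"] = "group_ids"  # index field holding permitted AAD group ids
os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"] = "<embedding-endpoint-url>"
os.environ["AZURE_OPENAI_EMBEDDING_KEY"] = "<embedding-key>"
```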
@@ -66,10 +73,60 @@ def should_use_data():
def format_as_ndjson(obj: dict) -> str:
return json.dumps(obj, ensure_ascii=False) + "\n"

def fetchUserGroups(userToken, nextLink=None):
# Recursively fetch group membership
if nextLink:
endpoint = nextLink
else:
endpoint = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id"

headers = {
'Authorization': "bearer " + userToken
}
try :
r = requests.get(endpoint, headers=headers)
if r.status_code != 200:
return []

r = r.json()
if "@odata.nextLink" in r:
nextLinkData = fetchUserGroups(userToken, r["@odata.nextLink"])
r['value'].extend(nextLinkData)

return r['value']
except Exception as e:
return []


def generateFilterString(userToken):
# Get list of groups user is a member of
userGroups = fetchUserGroups(userToken)

# Construct filter string
if userGroups:
group_ids = ", ".join([obj['id'] for obj in userGroups])
return f"{AZURE_SEARCH_PERMITTED_GROUPS_COLUMN}/any(g:search.in(g, '{group_ids}'))"

return None


def prepare_body_headers_with_data(request):
request_messages = request.json["messages"]

# Set query type
query_type = "simple"
if AZURE_SEARCH_QUERY_TYPE:
query_type = AZURE_SEARCH_QUERY_TYPE
elif AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true" and AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG:
query_type = "semantic"

# Set filter
filter = None
userToken = None
if AZURE_SEARCH_PERMITTED_GROUPS_COLUMN:
userToken = request.headers.get('X-MS-TOKEN-AAD-ACCESS-TOKEN', "")
filter = generateFilterString(userToken)

body = {
"messages": request_messages,
"temperature": float(AZURE_OPENAI_TEMPERATURE),
@@ -88,13 +145,17 @@ def prepare_body_headers_with_data(request):
"contentFields": AZURE_SEARCH_CONTENT_COLUMNS.split("|") if AZURE_SEARCH_CONTENT_COLUMNS else [],
"titleField": AZURE_SEARCH_TITLE_COLUMN if AZURE_SEARCH_TITLE_COLUMN else None,
"urlField": AZURE_SEARCH_URL_COLUMN if AZURE_SEARCH_URL_COLUMN else None,
"filepathField": AZURE_SEARCH_FILENAME_COLUMN if AZURE_SEARCH_FILENAME_COLUMN else None
"filepathField": AZURE_SEARCH_FILENAME_COLUMN if AZURE_SEARCH_FILENAME_COLUMN else None,
"vectorFields": AZURE_SEARCH_VECTOR_COLUMNS.split("|") if AZURE_SEARCH_VECTOR_COLUMNS else []
},
"inScope": True if AZURE_SEARCH_ENABLE_IN_DOMAIN.lower() == "true" else False,
"topNDocuments": AZURE_SEARCH_TOP_K,
"queryType": "semantic" if AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true" else "simple",
"semanticConfiguration": AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG if AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true" and AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG else "",
"roleInformation": AZURE_OPENAI_SYSTEM_MESSAGE
"queryType": query_type,
"semanticConfiguration": AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG if AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG else "",
"roleInformation": AZURE_OPENAI_SYSTEM_MESSAGE,
"embeddingEndpoint": AZURE_OPENAI_EMBEDDING_ENDPOINT,
"embeddingKey": AZURE_OPENAI_EMBEDDING_KEY,
"filter": filter
}
}
]
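Purely as an illustration (not part of the diff), here is roughly the slice of the `parameters` object assembled above when `AZURE_SEARCH_VECTOR_COLUMNS` is `contentVector|titleVector` and a group filter is in effect; everything not read from the environment is a placeholder.

```python
# Hypothetical resulting fragment, shown only to make the pipe-delimited mapping concrete.
parameters_fragment = {
    "fieldsMapping": {
        "vectorFields": ["contentVector", "titleVector"]  # AZURE_SEARCH_VECTOR_COLUMNS.split("|")
    },
    "queryType": "vector",  # the resolved query_type
    "embeddingEndpoint": "<embedding-endpoint-url>",
    "embeddingKey": "<embedding-key>",
    "filter": "group_ids/any(g:search.in(g, '<group-id-1>, <group-id-2>'))"
}
```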
@@ -103,7 +164,7 @@ def prepare_body_headers_with_data(request):
headers = {
'Content-Type': 'application/json',
'api-key': AZURE_OPENAI_KEY,
"x-ms-useragent": "GitHubSampleWebApp/PublicAPI/1.0.0"
"x-ms-useragent": "GitHubSampleWebApp/PublicAPI/2.0.0"
}

return body, headers
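Stepping back from the diff for a moment: assuming `AZURE_SEARCH_PERMITTED_GROUPS_COLUMN` names a hypothetical index field `group_ids` and the signed-in user belongs to two groups, the filter wired into the body above would come out as sketched below (a walk-through of what `fetchUserGroups` plus `generateFilterString` produce, with made-up ids).

```python
# Hypothetical walk-through of generateFilterString(); field name and group ids are made up.
user_groups = [{"id": "aaaaaaaa-0000-0000-0000-000000000001"},
               {"id": "bbbbbbbb-0000-0000-0000-000000000002"}]  # shape returned by fetchUserGroups()
group_ids = ", ".join(obj["id"] for obj in user_groups)
print(f"group_ids/any(g:search.in(g, '{group_ids}'))")
# group_ids/any(g:search.in(g, 'aaaaaaaa-0000-0000-0000-000000000001, bbbbbbbb-0000-0000-0000-000000000002'))
```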
@@ -152,7 +213,8 @@ def stream_with_data(body, headers, endpoint):

def conversation_with_data(request):
body, headers = prepare_body_headers_with_data(request)
endpoint = f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/openai/deployments/{AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={AZURE_OPENAI_PREVIEW_API_VERSION}"
base_url = AZURE_OPENAI_ENDPOINT if AZURE_OPENAI_ENDPOINT else f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/"
endpoint = f"{base_url}openai/deployments/{AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={AZURE_OPENAI_PREVIEW_API_VERSION}"

if not SHOULD_STREAM:
r = requests.post(endpoint, headers=headers, json=body)
@@ -190,7 +252,7 @@ def stream_without_data(response):

def conversation_without_data(request):
openai.api_type = "azure"
openai.api_base = f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/"
openai.api_base = AZURE_OPENAI_ENDPOINT if AZURE_OPENAI_ENDPOINT else f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/"
openai.api_version = "2023-03-15-preview"
openai.api_key = AZURE_OPENAI_KEY

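One observation on the new endpoint handling, sketched with placeholder values (not part of the diff): because the base URL is concatenated directly with `openai/deployments/...` in `conversation_with_data`, an explicitly configured `AZURE_OPENAI_ENDPOINT` should end with a trailing slash.

```python
# Illustration of the fallback logic with placeholder values.
AZURE_OPENAI_RESOURCE = "my-aoai-resource"
AZURE_OPENAI_ENDPOINT = ""  # leave empty to fall back to the resource-based URL

base_url = AZURE_OPENAI_ENDPOINT if AZURE_OPENAI_ENDPOINT else f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/"
print(base_url)  # https://my-aoai-resource.openai.azure.com/
```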