Skip to content

Commit

Permalink
Add support for vector search and document level access control (#180)
Browse files Browse the repository at this point in the history
Co-authored-by: Sarah Widder <sawidder@microsoft.com>
  • Loading branch information
sarah-widder and sarah-widder committed Aug 21, 2023
1 parent 594ac86 commit 0e14412
Showing 1 changed file with 69 additions and 7 deletions.
76 changes: 69 additions & 7 deletions app.py
Expand Up @@ -36,10 +36,14 @@ def assets(path):
AZURE_SEARCH_FILENAME_COLUMN = os.environ.get("AZURE_SEARCH_FILENAME_COLUMN")
AZURE_SEARCH_TITLE_COLUMN = os.environ.get("AZURE_SEARCH_TITLE_COLUMN")
AZURE_SEARCH_URL_COLUMN = os.environ.get("AZURE_SEARCH_URL_COLUMN")
AZURE_SEARCH_VECTOR_COLUMNS = os.environ.get("AZURE_SEARCH_VECTOR_COLUMNS")
AZURE_SEARCH_QUERY_TYPE = os.environ.get("AZURE_SEARCH_QUERY_TYPE")
AZURE_SEARCH_PERMITTED_GROUPS_COLUMN = os.environ.get("AZURE_SEARCH_PERMITTED_GROUPS_COLUMN")

# AOAI Integration Settings
AZURE_OPENAI_RESOURCE = os.environ.get("AZURE_OPENAI_RESOURCE")
AZURE_OPENAI_MODEL = os.environ.get("AZURE_OPENAI_MODEL")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_KEY = os.environ.get("AZURE_OPENAI_KEY")
AZURE_OPENAI_TEMPERATURE = os.environ.get("AZURE_OPENAI_TEMPERATURE", 0)
AZURE_OPENAI_TOP_P = os.environ.get("AZURE_OPENAI_TOP_P", 1.0)
Expand All @@ -49,6 +53,9 @@ def assets(path):
AZURE_OPENAI_PREVIEW_API_VERSION = os.environ.get("AZURE_OPENAI_PREVIEW_API_VERSION", "2023-06-01-preview")
AZURE_OPENAI_STREAM = os.environ.get("AZURE_OPENAI_STREAM", "true")
AZURE_OPENAI_MODEL_NAME = os.environ.get("AZURE_OPENAI_MODEL_NAME", "gpt-35-turbo") # Name of the model, e.g. 'gpt-35-turbo' or 'gpt-4'
AZURE_OPENAI_EMBEDDING_ENDPOINT = os.environ.get("AZURE_OPENAI_EMBEDDING_ENDPOINT")
AZURE_OPENAI_EMBEDDING_KEY = os.environ.get("AZURE_OPENAI_EMBEDDING_KEY")


SHOULD_STREAM = True if AZURE_OPENAI_STREAM.lower() == "true" else False

Expand All @@ -66,10 +73,60 @@ def should_use_data():
def format_as_ndjson(obj: dict) -> str:
return json.dumps(obj, ensure_ascii=False) + "\n"

def fetchUserGroups(userToken, nextLink=None):
# Recursively fetch group membership
if nextLink:
endpoint = nextLink
else:
endpoint = "https://graph.microsoft.com/v1.0/me/transitiveMemberOf?$select=id"

headers = {
'Authorization': "bearer " + userToken
}
try :
r = requests.get(endpoint, headers=headers)
if r.status_code != 200:
return []

r = r.json()
if "@odata.nextLink" in r:
nextLinkData = fetchUserGroups(userToken, r["@odata.nextLink"])
r['value'].extend(nextLinkData)

return r['value']
except Exception as e:
return []


def generateFilterString(userToken):
# Get list of groups user is a member of
userGroups = fetchUserGroups(userToken)

# Construct filter string
if userGroups:
group_ids = ", ".join([obj['id'] for obj in userGroups])
return f"{AZURE_SEARCH_PERMITTED_GROUPS_COLUMN}/any(g:search.in(g, '{group_ids}'))"

return None


def prepare_body_headers_with_data(request):
request_messages = request.json["messages"]

# Set query type
query_type = "simple"
if AZURE_SEARCH_QUERY_TYPE:
query_type = AZURE_SEARCH_QUERY_TYPE
elif AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true" and AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG:
query_type = "semantic"

# Set filter
filter = None
userToken = None
if AZURE_SEARCH_PERMITTED_GROUPS_COLUMN:
userToken = request.headers.get('X-MS-TOKEN-AAD-ACCESS-TOKEN', "")
filter = generateFilterString(userToken)

body = {
"messages": request_messages,
"temperature": float(AZURE_OPENAI_TEMPERATURE),
Expand All @@ -88,13 +145,17 @@ def prepare_body_headers_with_data(request):
"contentFields": AZURE_SEARCH_CONTENT_COLUMNS.split("|") if AZURE_SEARCH_CONTENT_COLUMNS else [],
"titleField": AZURE_SEARCH_TITLE_COLUMN if AZURE_SEARCH_TITLE_COLUMN else None,
"urlField": AZURE_SEARCH_URL_COLUMN if AZURE_SEARCH_URL_COLUMN else None,
"filepathField": AZURE_SEARCH_FILENAME_COLUMN if AZURE_SEARCH_FILENAME_COLUMN else None
"filepathField": AZURE_SEARCH_FILENAME_COLUMN if AZURE_SEARCH_FILENAME_COLUMN else None,
"vectorFields": AZURE_SEARCH_VECTOR_COLUMNS.split("|") if AZURE_SEARCH_VECTOR_COLUMNS else []
},
"inScope": True if AZURE_SEARCH_ENABLE_IN_DOMAIN.lower() == "true" else False,
"topNDocuments": AZURE_SEARCH_TOP_K,
"queryType": "semantic" if AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true" else "simple",
"semanticConfiguration": AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG if AZURE_SEARCH_USE_SEMANTIC_SEARCH.lower() == "true" and AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG else "",
"roleInformation": AZURE_OPENAI_SYSTEM_MESSAGE
"queryType": query_type,
"semanticConfiguration": AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG if AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG else "",
"roleInformation": AZURE_OPENAI_SYSTEM_MESSAGE,
"embeddingEndpoint": AZURE_OPENAI_EMBEDDING_ENDPOINT,
"embeddingKey": AZURE_OPENAI_EMBEDDING_KEY,
"filter": filter
}
}
]
Expand All @@ -103,7 +164,7 @@ def prepare_body_headers_with_data(request):
headers = {
'Content-Type': 'application/json',
'api-key': AZURE_OPENAI_KEY,
"x-ms-useragent": "GitHubSampleWebApp/PublicAPI/1.0.0"
"x-ms-useragent": "GitHubSampleWebApp/PublicAPI/2.0.0"
}

return body, headers
Expand Down Expand Up @@ -152,7 +213,8 @@ def stream_with_data(body, headers, endpoint):

def conversation_with_data(request):
body, headers = prepare_body_headers_with_data(request)
endpoint = f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/openai/deployments/{AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={AZURE_OPENAI_PREVIEW_API_VERSION}"
base_url = AZURE_OPENAI_ENDPOINT if AZURE_OPENAI_ENDPOINT else f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/"
endpoint = f"{base_url}openai/deployments/{AZURE_OPENAI_MODEL}/extensions/chat/completions?api-version={AZURE_OPENAI_PREVIEW_API_VERSION}"

if not SHOULD_STREAM:
r = requests.post(endpoint, headers=headers, json=body)
Expand Down Expand Up @@ -190,7 +252,7 @@ def stream_without_data(response):

def conversation_without_data(request):
openai.api_type = "azure"
openai.api_base = f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/"
openai.api_base = AZURE_OPENAI_ENDPOINT if AZURE_OPENAI_ENDPOINT else f"https://{AZURE_OPENAI_RESOURCE}.openai.azure.com/"
openai.api_version = "2023-03-15-preview"
openai.api_key = AZURE_OPENAI_KEY

Expand Down

0 comments on commit 0e14412

Please sign in to comment.