Python: Add Logit-Bias to TextCompletion abstractions #1880

Merged Jul 14, 2023

Commits (73)
d02d268
added token_selection_biases as added in #1647
sneha-afk Jul 5, 2023
1a12f85
Merge branch 'microsoft:main' into issue1854
am831 Jul 5, 2023
ab18185
added token_selection_biases
am831 Jul 5, 2023
b46f3d3
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
am831 Jul 5, 2023
2f203d0
Merge branch 'microsoft:main' into issue1854
am831 Jul 6, 2023
99d3cad
added for loop for updating token_selection_biases
am831 Jul 6, 2023
cf3137a
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
am831 Jul 6, 2023
190b239
use Dict and default_factory for desired behavior, compatible with Py…
sneha-afk Jul 6, 2023
78e2dd9
Merge branch 'issue1854' of github.com:am831/semantic-kernel into iss…
sneha-afk Jul 6, 2023
04a4f7a
Merge branch 'main' into issue1854
sneha-afk Jul 6, 2023
9e0e8d9
added logit bias example file
am831 Jul 6, 2023
6fe64ae
Merge branch 'microsoft:main' into issue1854
am831 Jul 6, 2023
5f1d407
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
am831 Jul 6, 2023
47b8445
Added loop to check token_selection_biases
SebasX5 Jul 6, 2023
149bb31
Merge branch 'issue1854' of github.com:am831/semantic-kernel into iss…
sneha-afk Jul 6, 2023
512f63b
working on example file
am831 Jul 6, 2023
628b2f9
Merge branch 'issue1854' of github.com:am831/semantic-kernel into iss…
sneha-afk Jul 6, 2023
35e6730
minor changes w/ getting api_key
sneha-afk Jul 6, 2023
11d2815
working on example file
am831 Jul 6, 2023
3cea14c
Merge branch 'microsoft:main' into issue1854
am831 Jul 6, 2023
4ad8a15
working on example
am831 Jul 7, 2023
88f1098
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
am831 Jul 7, 2023
317b015
Merge branch 'microsoft:main' into issue1854
am831 Jul 7, 2023
1affc17
logit_bias example file in progress
am831 Jul 7, 2023
cc71848
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
am831 Jul 7, 2023
9e5eee2
logit_bias example progress
am831 Jul 7, 2023
5ccc763
logit_bias example in progress
am831 Jul 7, 2023
e218cb3
finished example file
sneha-afk Jul 7, 2023
26f5583
resolve merge conflicts
sneha-afk Jul 7, 2023
ed6cfc0
Merge branch 'main' of github.com:am831/semantic-kernel into issue1854
sneha-afk Jul 7, 2023
9b5cef1
mistakenly had .env.example deleted
sneha-afk Jul 7, 2023
f8dd569
fixed naming of .env.example
sneha-afk Jul 7, 2023
843d007
Merge branch 'microsoft:main' into issue1854
am831 Jul 8, 2023
a69173c
Merge branch 'microsoft:main' into issue1854
am831 Jul 10, 2023
abbb73f
check for banned words
am831 Jul 10, 2023
6c397ad
refactored to logit_bias, got more accurate banning of tokens
sneha-afk Jul 10, 2023
1fe3baf
text_complete_request function
am831 Jul 10, 2023
284d7ef
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
am831 Jul 10, 2023
ee648b7
fixed output format
am831 Jul 10, 2023
551e0ee
text completion
am831 Jul 10, 2023
f4b2b1b
Added another example to logit_bias.py
SebasX5 Jul 10, 2023
fe6fbe4
text completion logic
am831 Jul 10, 2023
96e348d
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
SebasX5 Jul 10, 2023
b634033
added more tokens to ban
am831 Jul 10, 2023
be31797
Merge branch 'main' of github.com:am831/semantic-kernel into issue1854
sneha-afk Jul 10, 2023
d01d099
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
am831 Jul 10, 2023
9d309e9
Merge branch 'microsoft:main' into issue1854
am831 Jul 10, 2023
419341f
Merge branch 'issue1854' of github.com:am831/semantic-kernel into iss…
sneha-afk Jul 10, 2023
8d84a9a
fixed format
am831 Jul 10, 2023
a30ca94
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
am831 Jul 10, 2023
42f5fba
Merge branch 'issue1854' of github.com:am831/semantic-kernel into iss…
sneha-afk Jul 10, 2023
3b06615
added clarity to logit_bias.py, removed some redundancies in the comp…
sneha-afk Jul 11, 2023
73ef0a7
fix unit tests to work with added logit_bias field
sneha-afk Jul 11, 2023
b0aaf0f
Merge branch 'main' into issue1854
awharrison-28 Jul 11, 2023
d0e107b
using complete_chat_async in logit_bias, adding back ChatRequestSettings
sneha-afk Jul 11, 2023
7ff0f54
Merge branch 'issue1854' of github.com:am831/semantic-kernel into iss…
sneha-afk Jul 11, 2023
7a24209
added unit test with logit_bias != none for test_azure_text_completion
am831 Jul 11, 2023
c65d13d
Merge branch 'microsoft:main' into issue1854
am831 Jul 11, 2023
5dd457d
added unit test with logit_bias != none
am831 Jul 11, 2023
874572a
Merge branch 'issue1854' of https://github.com/am831/semantic-kernel …
am831 Jul 11, 2023
878f03b
Added Unit test to check a call with logit_bias != None
SebasX5 Jul 11, 2023
9fdf969
Merge branch 'microsoft:main' into issue1854
am831 Jul 11, 2023
77ece2f
Merge branch 'microsoft:main' into issue1854
am831 Jul 11, 2023
a50664c
Merge branch 'main' into issue1854
sneha-afk Jul 11, 2023
620e1b3
Merge branch 'main' into issue1854
sneha-afk Jul 12, 2023
9f412d9
run pre-commit command for proper linting
sneha-afk Jul 12, 2023
c85eba0
Merge branch 'main' into issue1854
sneha-afk Jul 12, 2023
7e99a58
Merge branch 'issue1854' of github.com:am831/semantic-kernel into iss…
sneha-afk Jul 12, 2023
5bdde62
Merge branch 'microsoft:main' into issue1854
am831 Jul 13, 2023
42c5027
Rename logit_bias.py to openai_logit_bias.py
SebasX5 Jul 13, 2023
622326e
Merge branch 'main' into issue1854
SebasX5 Jul 13, 2023
aa41ea9
Merge branch 'main' into issue1854
awharrison-28 Jul 13, 2023
a0b2b60
Merge branch 'main' into issue1854
SebasX5 Jul 13, 2023
206 changes: 206 additions & 0 deletions python/samples/kernel-syntax-examples/logit_bias.py
@@ -0,0 +1,206 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio

import semantic_kernel as sk
import semantic_kernel.connectors.ai.open_ai as sk_oai
from semantic_kernel.connectors.ai.chat_request_settings import ChatRequestSettings
from semantic_kernel.connectors.ai.complete_request_settings import (
    CompleteRequestSettings,
)

"""
Logit bias enables prioritizing certain tokens within a given output.
To utilize the logit bias function, you will need to know the token ids of the words you are using.
See the GPT Tokenizer to obtain token ids: https://platform.openai.com/tokenizer
Read more about logit bias and how to configure output: https://help.openai.com/en/articles/5247780-using-logit-bias-to-define-token-probability
"""


def _config_ban_tokens(settings_type, keys):
    settings = (
        ChatRequestSettings() if settings_type == "chat" else CompleteRequestSettings()
    )

    # Bias values range from -100 (effectively bans the token) to 100
    # (effectively forces its selection). Here every token id in `keys`
    # is mapped to -100 to ban it.
    for k in keys:
        settings.token_selection_biases[k] = -100
    return settings
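
# Illustration (not part of the original sample): after
#     settings = _config_ban_tokens("chat", [2032, 680])
# the resulting mapping is settings.token_selection_biases == {2032: -100, 680: -100},
# which the OpenAI connectors forward to the API as the `logit_bias` parameter.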


async def chat_request_example(kernel, api_key, org_id):
    openai_chat_completion = sk_oai.OpenAIChatCompletion(
        "gpt-3.5-turbo", api_key, org_id
    )
    kernel.add_chat_service("chat_service", openai_chat_completion)

    # Spaces and capitalization affect the token ids.
    # The following are the token ids of basketball-related words.
    keys = [
        2032,
        680,
        9612,
        26675,
        3438,
        42483,
        21265,
        6057,
        11230,
        1404,
        2484,
        12494,
        35,
        822,
        11108,
    ]
    banned_words = [
        "swish",
        "screen",
        "score",
        "dominant",
        "basketball",
        "game",
        "GOAT",
        "Shooting",
        "Dribbling",
    ]

    # The model will try its best to avoid using any of the above words
    settings = _config_ban_tokens("chat", keys)

    prompt_config = sk.PromptTemplateConfig.from_completion_parameters(
        max_tokens=2000, temperature=0.7, top_p=0.8
    )
    prompt_template = sk.ChatPromptTemplate(
        "{{$user_input}}", kernel.prompt_template_engine, prompt_config
    )

    # Set up the chat with the prompt
    prompt_template.add_system_message("You are a basketball expert")
    user_mssg = "I love the LA Lakers, tell me an interesting fact about LeBron James."
    prompt_template.add_user_message(user_mssg)
    function_config = sk.SemanticFunctionConfig(prompt_config, prompt_template)
    kernel.register_semantic_function("ChatBot", "Chat", function_config)

    chat_messages = list()
    chat_messages.append(("user", user_mssg))
    answer = await openai_chat_completion.complete_chat_async(chat_messages, settings)
    chat_messages.append(("assistant", str(answer)))

    user_mssg = "What are his best all-time stats?"
    chat_messages.append(("user", user_mssg))
    answer = await openai_chat_completion.complete_chat_async(chat_messages, settings)
    chat_messages.append(("assistant", str(answer)))

    context_vars = sk.ContextVariables()
    context_vars["chat_history"] = ""
    context_vars["chat_bot_ans"] = ""
    for role, mssg in chat_messages:
        if role == "user":
            context_vars["chat_history"] += f"User:> {mssg}\n"
        elif role == "assistant":
            context_vars["chat_history"] += f"ChatBot:> {mssg}\n"
            context_vars["chat_bot_ans"] += f"{mssg}\n"

    kernel.remove_chat_service("chat_service")
    return context_vars, banned_words


async def text_complete_request_example(kernel, api_key, org_id):
    openai_text_completion = sk_oai.OpenAITextCompletion(
        "text-davinci-002", api_key, org_id
    )
    kernel.add_text_completion_service("text_service", openai_text_completion)

    # Spaces and capitalization affect the token ids.
    # The following are the token ids of pie-related words.
    keys = [
        18040,
        17180,
        16108,
        4196,
        79,
        931,
        5116,
        30089,
        36724,
        47,
        931,
        5116,
        431,
        5171,
        613,
        5171,
        350,
        721,
        272,
        47,
        721,
        272,
    ]
    banned_words = [
        "apple",
        " apple",
        "Apple",
        " Apple",
        "pumpkin",
        " pumpkin",
        " Pumpkin",
        "pecan",
        " pecan",
        " Pecan",
        "Pecan",
    ]

    # The model will try its best to avoid using any of the above words
    settings = _config_ban_tokens("complete", keys)

    user_mssg = "The best pie flavor to have in autumn is"
    answer = await openai_text_completion.complete_async(user_mssg, settings)

    context_vars = sk.ContextVariables()
    context_vars["chat_history"] = f"User:> {user_mssg}\nChatBot:> {answer}\n"
    context_vars["chat_bot_ans"] = str(answer)

    kernel.remove_text_completion_service("text_service")
    return context_vars, banned_words


def _check_banned_words(banned_list, actual_list) -> bool:
    """Return True if none of the banned words appear among the answer's words."""
    passed = True
    for word in banned_list:
        if word in actual_list:
            print(f'The banned word "{word}" was found in the answer')
            passed = False
    return passed


def _format_output(context, banned_words) -> None:
    print(context["chat_history"])
    chat_bot_ans_words = context["chat_bot_ans"].split()
    if _check_banned_words(banned_words, chat_bot_ans_words):
        print("None of the banned words were found in the answer")


async def main() -> None:
    kernel = sk.Kernel()
    api_key, org_id = sk.openai_settings_from_dot_env()

    print("Chat completion example:")
    print("------------------------")
    chat, banned_words = await chat_request_example(kernel, api_key, org_id)
    _format_output(chat, banned_words)

    print("------------------------")

    print("\nText completion example:")
    print("------------------------")
    chat, banned_words = await text_complete_request_example(kernel, api_key, org_id)
    _format_output(chat, banned_words)

    return


if __name__ == "__main__":
    asyncio.run(main())
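
The token ids hard-coded in the sample above can also be generated programmatically. A minimal sketch using the tiktoken package (an assumption on our part; the sample itself relies on the web tokenizer linked in its docstring):

import tiktoken

# Token ids depend on the model's encoding; spaces and capitalization matter,
# so each variant of a word must be encoded separately.
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
for word in ["basketball", " basketball", "Basketball"]:
    print(word, "->", enc.encode(word))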
python/semantic_kernel/connectors/ai/chat_request_settings.py
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.

-from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Dict

 if TYPE_CHECKING:
     from semantic_kernel.semantic_functions.prompt_template_config import (
@@ -17,6 +17,7 @@ class ChatRequestSettings:
     frequency_penalty: float = 0.0
     number_of_responses: int = 1
     max_tokens: int = 256
+    token_selection_biases: Dict[int, int] = field(default_factory=dict)

     def update_from_completion_config(
         self, completion_config: "PromptTemplateConfig.CompletionConfig"
python/semantic_kernel/connectors/ai/complete_request_settings.py
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.

 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Dict, List

 if TYPE_CHECKING:
     from semantic_kernel.semantic_functions.prompt_template_config import (
@@ -19,6 +19,7 @@ class CompleteRequestSettings:
     stop_sequences: List[str] = field(default_factory=list)
     number_of_responses: int = 1
     logprobs: int = 0
+    token_selection_biases: Dict[int, int] = field(default_factory=dict)

     def update_from_completion_config(
         self, completion_config: "PromptTemplateConfig.CompletionConfig"
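
As a quick usage sketch (assuming only what the diffs above show), the new field can be populated directly on either settings class before a request is made:

from semantic_kernel.connectors.ai.complete_request_settings import (
    CompleteRequestSettings,
)

settings = CompleteRequestSettings()
# -100 effectively bans token id 18040; positive values up to 100 favor a token.
settings.token_selection_biases[18040] = -100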
python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion.py
@@ -108,6 +108,7 @@ async def complete_async(
             frequency_penalty=request_settings.frequency_penalty,
             max_tokens=request_settings.max_tokens,
             number_of_responses=request_settings.number_of_responses,
+            token_selection_biases=request_settings.token_selection_biases,
         )
         response = await self._send_chat_request(
             prompt_to_message, chat_settings, False
@@ -129,6 +130,7 @@ async def complete_stream_async(
             frequency_penalty=request_settings.frequency_penalty,
             max_tokens=request_settings.max_tokens,
             number_of_responses=request_settings.number_of_responses,
+            token_selection_biases=request_settings.token_selection_biases,
         )
         response = await self._send_chat_request(prompt_to_message, chat_settings, True)

@@ -208,6 +210,12 @@ async def _send_chat_request(
                 max_tokens=request_settings.max_tokens,
                 n=request_settings.number_of_responses,
                 stream=stream,
+                logit_bias=(
+                    request_settings.token_selection_biases
+                    if request_settings.token_selection_biases is not None
+                    and len(request_settings.token_selection_biases) > 0
+                    else None
+                ),
             )
         except Exception as ex:
             raise AIException(
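
A small side note on the guard above: because an empty dict is falsy in Python, an equivalent and slightly terser form would be logit_bias=(request_settings.token_selection_biases or None). That is a stylistic sketch only, not what the PR ships.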
python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_completion.py
@@ -142,6 +142,12 @@ async def _send_completion_request(
                     and len(request_settings.stop_sequences) > 0
                     else None
                 ),
+                logit_bias=(
+                    request_settings.token_selection_biases
+                    if request_settings.token_selection_biases is not None
+                    and len(request_settings.token_selection_biases) > 0
+                    else None
+                ),
             )
         except Exception as ex:
             raise AIException(
python/tests/unit/ai/open_ai/services/test_azure_chat_completion.py
@@ -154,4 +154,54 @@ async def test_azure_chat_completion_call_with_parameters() -> None:
             frequency_penalty=complete_request_settings.frequency_penalty,
             n=complete_request_settings.number_of_responses,
             stream=False,
+            logit_bias=None,
         )


@pytest.mark.asyncio
async def test_azure_chat_completion_call_with_parameters_and_Logit_Bias_Defined() -> None:
    mock_openai = AsyncMock()
    with patch(
        "semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.openai",
        new=mock_openai,
    ):
        deployment_name = "test_deployment"
        endpoint = "https://test-endpoint.com"
        api_key = "test_api_key"
        api_type = "azure"
        api_version = "2023-03-15-preview"
        logger = Logger("test_logger")
        prompt = "hello world"
        messages = [{"role": "user", "content": prompt}]
        complete_request_settings = CompleteRequestSettings()

        token_bias = {1: -100}
        complete_request_settings.token_selection_biases = token_bias

        azure_chat_completion = AzureChatCompletion(
            deployment_name=deployment_name,
            endpoint=endpoint,
            api_key=api_key,
            api_version=api_version,
            logger=logger,
        )

        await azure_chat_completion.complete_async(prompt, complete_request_settings)

        mock_openai.ChatCompletion.acreate.assert_called_once_with(
            engine=deployment_name,
            api_key=api_key,
            api_type=api_type,
            api_base=endpoint,
            api_version=api_version,
            organization=None,
            messages=messages,
            temperature=complete_request_settings.temperature,
            max_tokens=complete_request_settings.max_tokens,
            top_p=complete_request_settings.top_p,
            presence_penalty=complete_request_settings.presence_penalty,
            frequency_penalty=complete_request_settings.frequency_penalty,
            n=complete_request_settings.number_of_responses,
            stream=False,
            logit_bias=token_bias,
        )
python/tests/unit/ai/open_ai/services/test_azure_text_completion.py
@@ -153,4 +153,54 @@ async def test_azure_text_completion_call_with_parameters() -> None:
             stop=None,
             n=complete_request_settings.number_of_responses,
             stream=False,
+            logit_bias=None,
         )


@pytest.mark.asyncio
async def test_azure_text_completion_call_with_parameters_logit_bias_not_none() -> None:
    mock_openai = AsyncMock()
    with patch(
        "semantic_kernel.connectors.ai.open_ai.services.open_ai_text_completion.openai",
        new=mock_openai,
    ):
        deployment_name = "test_deployment"
        endpoint = "https://test-endpoint.com"
        api_key = "test_api_key"
        api_type = "azure"
        api_version = "2023-03-15-preview"
        logger = Logger("test_logger")
        prompt = "hello world"
        complete_request_settings = CompleteRequestSettings()

        token_bias = {200: 100}
        complete_request_settings.token_selection_biases = token_bias

        azure_text_completion = AzureTextCompletion(
            deployment_name=deployment_name,
            endpoint=endpoint,
            api_key=api_key,
            api_version=api_version,
            logger=logger,
        )

        await azure_text_completion.complete_async(prompt, complete_request_settings)

        mock_openai.Completion.acreate.assert_called_once_with(
            engine=deployment_name,
            api_key=api_key,
            api_type=api_type,
            api_base=endpoint,
            api_version=api_version,
            organization=None,
            prompt=prompt,
            temperature=complete_request_settings.temperature,
            max_tokens=complete_request_settings.max_tokens,
            top_p=complete_request_settings.top_p,
            presence_penalty=complete_request_settings.presence_penalty,
            frequency_penalty=complete_request_settings.frequency_penalty,
            stop=None,
            n=complete_request_settings.number_of_responses,
            stream=False,
            logit_bias=token_bias,
        )
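
A note on running the new tests locally: from the repository's python/ directory (the tests/unit path is assumed from the repo layout), the added cases can be selected by keyword, e.g. pytest tests/unit -k "logit_bias or Logit_Bias_Defined".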