Merge KernelConfig with Kernel (microsoft#731)
### Motivation and Context
KernelConfig isn't needed: it is less Pythonic and is never used independently of Kernel. This change moves all of its methods and members into Kernel.

### Description
- Removed KernelConfig
- Moved KernelConfig's methods and variables into Kernel (see the before/after sketch below)
- Ran unit and end-to-end tests
- Did not run the notebooks because they are out of date with the current repo
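For illustration, a minimal before/after sketch of the caller-facing change, assuming the OpenAI text-completion connector and the `openai_settings_from_dot_env()` helper shown in the README diff below; the exact import path for `OpenAITextCompletion` is not part of this diff and is an assumption:

```python
import semantic_kernel as sk
# Assumed connector import path; adjust to the repo's actual layout.
from semantic_kernel.connectors.ai.open_ai import OpenAITextCompletion

kernel = sk.Kernel()
api_key, org_id = sk.openai_settings_from_dot_env()

# Before this change: service registration went through the KernelConfig object.
# kernel.config.add_text_completion_service(
#     "dv", OpenAITextCompletion("text-davinci-003", api_key, org_id)
# )

# After this change: the same method now lives directly on Kernel.
kernel.add_text_completion_service(
    "dv", OpenAITextCompletion("text-davinci-003", api_key, org_id)
)
```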

---------
Co-authored-by: Shawn Callegari <36091529+shawncal@users.noreply.github.com>
mkarle committed May 6, 2023
1 parent 73c1dd2 commit 656d4c4
Showing 39 changed files with 486 additions and 347 deletions.
4 changes: 2 additions & 2 deletions python/README.md
@@ -33,11 +33,11 @@ kernel = sk.Kernel()

# Prepare OpenAI service using credentials stored in the `.env` file
api_key, org_id = sk.openai_settings_from_dot_env()
-kernel.config.add_text_completion_service("dv", OpenAITextCompletion("text-davinci-003", api_key, org_id))
kernel.add_text_completion_service("dv", OpenAITextCompletion("text-davinci-003", api_key, org_id))

# Alternative using Azure:
# deployment, api_key, endpoint = sk.azure_openai_settings_from_dot_env()
-# kernel.config.add_text_completion_service("dv", AzureTextCompletion(deployment, endpoint, api_key))
# kernel.add_text_completion_service("dv", AzureTextCompletion(deployment, endpoint, api_key))

# Wrap your prompt in a function
prompt = kernel.create_semantic_function("""
2 changes: 0 additions & 2 deletions python/semantic_kernel/__init__.py
@@ -2,7 +2,6 @@

from semantic_kernel import core_skills, memory
from semantic_kernel.kernel import Kernel
-from semantic_kernel.kernel_config import KernelConfig
from semantic_kernel.orchestration.context_variables import ContextVariables
from semantic_kernel.orchestration.sk_context import SKContext
from semantic_kernel.orchestration.sk_function_base import SKFunctionBase
@@ -22,7 +21,6 @@

__all__ = [
"Kernel",
"KernelConfig",
"NullLogger",
"openai_settings_from_dot_env",
"azure_openai_settings_from_dot_env",
267 changes: 256 additions & 11 deletions python/semantic_kernel/kernel.py
@@ -2,7 +2,7 @@

import inspect
from logging import Logger
-from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

from semantic_kernel.connectors.ai.ai_exception import AIException
from semantic_kernel.connectors.ai.chat_completion_client_base import (
@@ -12,11 +12,13 @@
from semantic_kernel.connectors.ai.complete_request_settings import (
CompleteRequestSettings,
)
from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import (
EmbeddingGeneratorBase,
)
from semantic_kernel.connectors.ai.text_completion_client_base import (
TextCompletionClientBase,
)
from semantic_kernel.kernel_base import KernelBase
-from semantic_kernel.kernel_config import KernelConfig
from semantic_kernel.kernel_exception import KernelException
from semantic_kernel.kernel_extensions import KernelExtensions
from semantic_kernel.memory.memory_store_base import MemoryStoreBase
@@ -26,6 +28,10 @@
from semantic_kernel.orchestration.sk_context import SKContext
from semantic_kernel.orchestration.sk_function import SKFunction
from semantic_kernel.orchestration.sk_function_base import SKFunctionBase
from semantic_kernel.reliability.pass_through_without_retry import (
PassThroughWithoutRetry,
)
from semantic_kernel.reliability.retry_mechanism import RetryMechanism
from semantic_kernel.semantic_functions.semantic_function_config import (
SemanticFunctionConfig,
)
@@ -41,10 +47,10 @@
from semantic_kernel.utils.null_logger import NullLogger
from semantic_kernel.utils.validation import validate_function_name, validate_skill_name

T = TypeVar("T")

class Kernel(KernelBase, KernelExtensions):
_log: Logger
-_config: KernelConfig
_skill_collection: SkillCollectionBase
_prompt_template_engine: PromptTemplatingEngine
_memory: SemanticTextMemoryBase
@@ -54,11 +60,9 @@ def __init__(
skill_collection: Optional[SkillCollectionBase] = None,
prompt_template_engine: Optional[PromptTemplatingEngine] = None,
memory: Optional[SemanticTextMemoryBase] = None,
-config: Optional[KernelConfig] = None,
log: Optional[Logger] = None,
) -> None:
self._log = log if log else NullLogger()
-self._config = config if config else KernelConfig()
self._skill_collection = (
skill_collection if skill_collection else SkillCollection(self._log)
)
@@ -69,13 +73,25 @@
)
self._memory = memory if memory else NullMemory()

self._text_completion_services: Dict[
str, Callable[["KernelBase"], TextCompletionClientBase]
] = {}
self._chat_services: Dict[
str, Callable[["KernelBase"], ChatCompletionClientBase]
] = {}
self._text_embedding_generation_services: Dict[
str, Callable[["KernelBase"], EmbeddingGeneratorBase]
] = {}

self._default_text_completion_service: Optional[str] = None
self._default_chat_service: Optional[str] = None
self._default_text_embedding_generation_service: Optional[str] = None

self._retry_mechanism: RetryMechanism = PassThroughWithoutRetry()

def kernel(self) -> KernelBase:
return self

-@property
-def config(self) -> KernelConfig:
-return self._config

@property
def logger(self) -> Logger:
return self._log
@@ -248,6 +264,235 @@ def import_skill(

return skill

def get_ai_service(
self, type: Type[T], service_id: Optional[str] = None
) -> Callable[["KernelBase"], T]:
matching_type = {}
if type == TextCompletionClientBase:
service_id = service_id or self._default_text_completion_service
matching_type = self._text_completion_services
elif type == ChatCompletionClientBase:
service_id = service_id or self._default_chat_service
matching_type = self._chat_services
elif type == EmbeddingGeneratorBase:
service_id = service_id or self._default_text_embedding_generation_service
matching_type = self._text_embedding_generation_services
else:
raise ValueError(f"Unknown AI service type: {type.__name__}")

if service_id not in matching_type:
raise ValueError(
f"{type.__name__} service with service_id '{service_id}' not found"
)

return matching_type[service_id]

def all_text_completion_services(self) -> List[str]:
return list(self._text_completion_services.keys())

def all_chat_services(self) -> List[str]:
return list(self._chat_services.keys())

def all_text_embedding_generation_services(self) -> List[str]:
return list(self._text_embedding_generation_services.keys())

def add_text_completion_service(
self,
service_id: str,
service: Union[
TextCompletionClientBase, Callable[["KernelBase"], TextCompletionClientBase]
],
overwrite: bool = True,
) -> "Kernel":
if not service_id:
raise ValueError("service_id must be a non-empty string")
if not overwrite and service_id in self._text_completion_services:
raise ValueError(
f"Text service with service_id '{service_id}' already exists"
)

self._text_completion_services[service_id] = (
service if isinstance(service, Callable) else lambda _: service
)
if self._default_text_completion_service is None:
self._default_text_completion_service = service_id

return self

def add_chat_service(
self,
service_id: str,
service: Union[
ChatCompletionClientBase, Callable[["KernelBase"], ChatCompletionClientBase]
],
overwrite: bool = True,
) -> "Kernel":
if not service_id:
raise ValueError("service_id must be a non-empty string")
if not overwrite and service_id in self._chat_services:
raise ValueError(
f"Chat service with service_id '{service_id}' already exists"
)

self._chat_services[service_id] = (
service if isinstance(service, Callable) else lambda _: service
)
if self._default_chat_service is None:
self._default_chat_service = service_id

if isinstance(service, TextCompletionClientBase):
self.add_text_completion_service(service_id, service)
if self._default_text_completion_service is None:
self._default_text_completion_service = service_id

return self

def add_text_embedding_generation_service(
self,
service_id: str,
service: Union[
EmbeddingGeneratorBase, Callable[["KernelBase"], EmbeddingGeneratorBase]
],
overwrite: bool = False,
) -> "Kernel":
if not service_id:
raise ValueError("service_id must be a non-empty string")
if not overwrite and service_id in self._text_embedding_generation_services:
raise ValueError(
f"Embedding service with service_id '{service_id}' already exists"
)

self._text_embedding_generation_services[service_id] = (
service if isinstance(service, Callable) else lambda _: service
)
if self._default_text_embedding_generation_service is None:
self._default_text_embedding_generation_service = service_id

return self

def set_default_text_completion_service(self, service_id: str) -> "Kernel":
if service_id not in self._text_completion_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

self._default_text_completion_service = service_id
return self

def set_default_chat_service(self, service_id: str) -> "Kernel":
if service_id not in self._chat_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

self._default_chat_service = service_id
return self

def set_default_text_embedding_generation_service(
self, service_id: str
) -> "Kernel":
if service_id not in self._text_embedding_generation_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

self._default_text_embedding_generation_service = service_id
return self

def get_text_completion_service_service_id(
self, service_id: Optional[str] = None
) -> str:
if service_id is None or service_id not in self._text_completion_services:
if self._default_text_completion_service is None:
raise ValueError("No default text service is set")
return self._default_text_completion_service

return service_id

def get_chat_service_service_id(self, service_id: Optional[str] = None) -> str:
if service_id is None or service_id not in self._chat_services:
if self._default_chat_service is None:
raise ValueError("No default chat service is set")
return self._default_chat_service

return service_id

def get_text_embedding_generation_service_id(
self, service_id: Optional[str] = None
) -> str:
if (
service_id is None
or service_id not in self._text_embedding_generation_services
):
if self._default_text_embedding_generation_service is None:
raise ValueError("No default embedding service is set")
return self._default_text_embedding_generation_service

return service_id

def remove_text_completion_service(self, service_id: str) -> "Kernel":
if service_id not in self._text_completion_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

del self._text_completion_services[service_id]
if self._default_text_completion_service == service_id:
self._default_text_completion_service = next(
iter(self._text_completion_services), None
)
return self

def remove_chat_service(self, service_id: str) -> "Kernel":
if service_id not in self._chat_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

del self._chat_services[service_id]
if self._default_chat_service == service_id:
self._default_chat_service = next(iter(self._chat_services), None)
return self

def remove_text_embedding_generation_service(self, service_id: str) -> "Kernel":
if service_id not in self._text_embedding_generation_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

del self._text_embedding_generation_services[service_id]
if self._default_text_embedding_generation_service == service_id:
self._default_text_embedding_generation_service = next(
iter(self._text_embedding_generation_services), None
)
return self

def clear_all_text_completion_services(self) -> "Kernel":
self._text_completion_services = {}
self._default_text_completion_service = None
return self

def clear_all_chat_services(self) -> "Kernel":
self._chat_services = {}
self._default_chat_service = None
return self

def clear_all_text_embedding_generation_services(self) -> "Kernel":
self._text_embedding_generation_services = {}
self._default_text_embedding_generation_service = None
return self

def clear_all_services(self) -> "Kernel":
self._text_completion_services = {}
self._chat_services = {}
self._text_embedding_generation_services = {}

self._default_text_completion_service = None
self._default_chat_service = None
self._default_text_embedding_generation_service = None

return self

def _create_semantic_function(
self,
skill_name: str,
@@ -274,7 +519,7 @@ def _create_semantic_function(
function.set_default_skill_collection(self.skills)

if function_config.has_chat_prompt:
-service = self._config.get_ai_service(
service = self.get_ai_service(
ChatCompletionClientBase,
function_config.prompt_template_config.default_services[0]
if len(function_config.prompt_template_config.default_services) > 0
@@ -297,7 +542,7 @@

function.set_chat_service(lambda: service(self))
else:
-service = self._config.get_ai_service(
service = self.get_ai_service(
TextCompletionClientBase,
function_config.prompt_template_config.default_services[0]
if len(function_config.prompt_template_config.default_services) > 0
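As a quick illustration of the service-registration API that now lives on Kernel, here is a hedged usage sketch. The `Kernel` methods and the `TextCompletionClientBase` import come from the diff above; the `OpenAITextCompletion` import path and the placeholder credentials are assumptions made for the example:

```python
from semantic_kernel import Kernel
# Assumed connector import path and placeholder credentials, for illustration only.
from semantic_kernel.connectors.ai.open_ai import OpenAITextCompletion
from semantic_kernel.connectors.ai.text_completion_client_base import (
    TextCompletionClientBase,
)

kernel = Kernel()
service = OpenAITextCompletion("text-davinci-003", "sk-placeholder", "org-placeholder")

# add_* accepts either a ready-made instance or a factory taking the kernel;
# plain instances are wrapped internally in a `lambda _: service` factory.
kernel.add_text_completion_service("dv", service)
kernel.add_text_completion_service("dv-factory", lambda k: service)

# The first id registered for a service type becomes that type's default.
assert kernel.get_text_completion_service_service_id() == "dv"
print(kernel.all_text_completion_services())  # ['dv', 'dv-factory']

# get_ai_service returns the stored factory; call it with the kernel to get the client.
factory = kernel.get_ai_service(TextCompletionClientBase)
client = factory(kernel)

# Defaults can be re-pointed or removed; removing the current default
# promotes the next registered service id, if any.
kernel.set_default_text_completion_service("dv-factory")
kernel.remove_text_completion_service("dv-factory")
```

A design note visible in the diff: a service passed to add_chat_service that also implements TextCompletionClientBase is additionally registered as a text-completion service, so a single chat model can back both prompt styles.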
6 changes: 0 additions & 6 deletions python/semantic_kernel/kernel_base.py
@@ -4,7 +4,6 @@
from logging import Logger
from typing import Any, Dict, Optional

-from semantic_kernel.kernel_config import KernelConfig
from semantic_kernel.memory.semantic_text_memory_base import SemanticTextMemoryBase
from semantic_kernel.orchestration.context_variables import ContextVariables
from semantic_kernel.orchestration.sk_context import SKContext
@@ -21,11 +20,6 @@


class KernelBase(ABC):
-@property
-@abstractmethod
-def config(self) -> KernelConfig:
-pass

@property
@abstractmethod
def logger(self) -> Logger: