Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge KernelConfig with Kernel #731

Merged
merged 13 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ kernel = sk.Kernel()

# Prepare OpenAI service using credentials stored in the `.env` file
api_key, org_id = sk.openai_settings_from_dot_env()
kernel.config.add_text_completion_service("dv", OpenAITextCompletion("text-davinci-003", api_key, org_id))
kernel.add_text_completion_service("dv", OpenAITextCompletion("text-davinci-003", api_key, org_id))

# Alternative using Azure:
# deployment, api_key, endpoint = sk.azure_openai_settings_from_dot_env()
# kernel.config.add_text_completion_service("dv", AzureTextCompletion(deployment, endpoint, api_key))
# kernel.add_text_completion_service("dv", AzureTextCompletion(deployment, endpoint, api_key))

# Wrap your prompt in a function
prompt = kernel.create_semantic_function("""
Expand Down
2 changes: 0 additions & 2 deletions python/semantic_kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from semantic_kernel import core_skills, memory
from semantic_kernel.kernel import Kernel
from semantic_kernel.kernel_config import KernelConfig
from semantic_kernel.orchestration.context_variables import ContextVariables
from semantic_kernel.orchestration.sk_context import SKContext
from semantic_kernel.orchestration.sk_function_base import SKFunctionBase
Expand All @@ -22,7 +21,6 @@

__all__ = [
"Kernel",
"KernelConfig",
"NullLogger",
"openai_settings_from_dot_env",
"azure_openai_settings_from_dot_env",
Expand Down
267 changes: 256 additions & 11 deletions python/semantic_kernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import inspect
from logging import Logger
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

from semantic_kernel.connectors.ai.ai_exception import AIException
from semantic_kernel.connectors.ai.chat_completion_client_base import (
Expand All @@ -12,11 +12,13 @@
from semantic_kernel.connectors.ai.complete_request_settings import (
CompleteRequestSettings,
)
from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import (
EmbeddingGeneratorBase,
)
from semantic_kernel.connectors.ai.text_completion_client_base import (
TextCompletionClientBase,
)
from semantic_kernel.kernel_base import KernelBase
from semantic_kernel.kernel_config import KernelConfig
from semantic_kernel.kernel_exception import KernelException
from semantic_kernel.kernel_extensions import KernelExtensions
from semantic_kernel.memory.memory_store_base import MemoryStoreBase
Expand All @@ -26,6 +28,10 @@
from semantic_kernel.orchestration.sk_context import SKContext
from semantic_kernel.orchestration.sk_function import SKFunction
from semantic_kernel.orchestration.sk_function_base import SKFunctionBase
from semantic_kernel.reliability.pass_through_without_retry import (
PassThroughWithoutRetry,
)
from semantic_kernel.reliability.retry_mechanism import RetryMechanism
from semantic_kernel.semantic_functions.semantic_function_config import (
SemanticFunctionConfig,
)
Expand All @@ -41,10 +47,10 @@
from semantic_kernel.utils.null_logger import NullLogger
from semantic_kernel.utils.validation import validate_function_name, validate_skill_name

T = TypeVar("T")

class Kernel(KernelBase, KernelExtensions):
_log: Logger
_config: KernelConfig
_skill_collection: SkillCollectionBase
_prompt_template_engine: PromptTemplatingEngine
_memory: SemanticTextMemoryBase
Expand All @@ -54,11 +60,9 @@ def __init__(
skill_collection: Optional[SkillCollectionBase] = None,
prompt_template_engine: Optional[PromptTemplatingEngine] = None,
memory: Optional[SemanticTextMemoryBase] = None,
config: Optional[KernelConfig] = None,
log: Optional[Logger] = None,
) -> None:
self._log = log if log else NullLogger()
self._config = config if config else KernelConfig()
self._skill_collection = (
skill_collection if skill_collection else SkillCollection(self._log)
)
Expand All @@ -69,13 +73,25 @@ def __init__(
)
self._memory = memory if memory else NullMemory()

self._text_completion_services: Dict[
str, Callable[["KernelBase"], TextCompletionClientBase]
] = {}
self._chat_services: Dict[
str, Callable[["KernelBase"], ChatCompletionClientBase]
] = {}
self._text_embedding_generation_services: Dict[
str, Callable[["KernelBase"], EmbeddingGeneratorBase]
] = {}

self._default_text_completion_service: Optional[str] = None
self._default_chat_service: Optional[str] = None
self._default_text_embedding_generation_service: Optional[str] = None

self._retry_mechanism: RetryMechanism = PassThroughWithoutRetry()

def kernel(self) -> KernelBase:
return self

@property
def config(self) -> KernelConfig:
return self._config

@property
def logger(self) -> Logger:
return self._log
Expand Down Expand Up @@ -248,6 +264,235 @@ def import_skill(

return skill

def get_ai_service(
self, type: Type[T], service_id: Optional[str] = None
) -> Callable[["KernelBase"], T]:
matching_type = {}
if type == TextCompletionClientBase:
service_id = service_id or self._default_text_completion_service
matching_type = self._text_completion_services
elif type == ChatCompletionClientBase:
service_id = service_id or self._default_chat_service
matching_type = self._chat_services
elif type == EmbeddingGeneratorBase:
service_id = service_id or self._default_text_embedding_generation_service
matching_type = self._text_embedding_generation_services
else:
raise ValueError(f"Unknown AI service type: {type.__name__}")

if service_id not in matching_type:
raise ValueError(
f"{type.__name__} service with service_id '{service_id}' not found"
)

return matching_type[service_id]

def all_text_completion_services(self) -> List[str]:
return list(self._text_completion_services.keys())

def all_chat_services(self) -> List[str]:
return list(self._chat_services.keys())

def all_text_embedding_generation_services(self) -> List[str]:
return list(self._text_embedding_generation_services.keys())

def add_text_completion_service(
self,
service_id: str,
service: Union[
TextCompletionClientBase, Callable[["KernelBase"], TextCompletionClientBase]
],
overwrite: bool = True,
) -> "Kernel":
if not service_id:
raise ValueError("service_id must be a non-empty string")
if not overwrite and service_id in self._text_completion_services:
raise ValueError(
f"Text service with service_id '{service_id}' already exists"
)

self._text_completion_services[service_id] = (
service if isinstance(service, Callable) else lambda _: service
)
if self._default_text_completion_service is None:
self._default_text_completion_service = service_id

return self

def add_chat_service(
self,
service_id: str,
service: Union[
ChatCompletionClientBase, Callable[["KernelBase"], ChatCompletionClientBase]
],
overwrite: bool = True,
) -> "Kernel":
if not service_id:
raise ValueError("service_id must be a non-empty string")
if not overwrite and service_id in self._chat_services:
raise ValueError(
f"Chat service with service_id '{service_id}' already exists"
)

self._chat_services[service_id] = (
service if isinstance(service, Callable) else lambda _: service
)
if self._default_chat_service is None:
self._default_chat_service = service_id

if isinstance(service, TextCompletionClientBase):
self.add_text_completion_service(service_id, service)
if self._default_text_completion_service is None:
self._default_text_completion_service = service_id

return self

def add_text_embedding_generation_service(
self,
service_id: str,
service: Union[
EmbeddingGeneratorBase, Callable[["KernelBase"], EmbeddingGeneratorBase]
],
overwrite: bool = False,
) -> "Kernel":
if not service_id:
raise ValueError("service_id must be a non-empty string")
if not overwrite and service_id in self._text_embedding_generation_services:
raise ValueError(
f"Embedding service with service_id '{service_id}' already exists"
)

self._text_embedding_generation_services[service_id] = (
service if isinstance(service, Callable) else lambda _: service
)
if self._default_text_embedding_generation_service is None:
self._default_text_embedding_generation_service = service_id

return self

def set_default_text_completion_service(self, service_id: str) -> "Kernel":
if service_id not in self._text_completion_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

self._default_text_completion_service = service_id
return self

def set_default_chat_service(self, service_id: str) -> "Kernel":
if service_id not in self._chat_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

self._default_chat_service = service_id
return self

def set_default_text_embedding_generation_service(
self, service_id: str
) -> "Kernel":
if service_id not in self._text_embedding_generation_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

self._default_text_embedding_generation_service = service_id
return self

def get_text_completion_service_service_id(
self, service_id: Optional[str] = None
) -> str:
if service_id is None or service_id not in self._text_completion_services:
if self._default_text_completion_service is None:
raise ValueError("No default text service is set")
return self._default_text_completion_service

return service_id

def get_chat_service_service_id(self, service_id: Optional[str] = None) -> str:
if service_id is None or service_id not in self._chat_services:
if self._default_chat_service is None:
raise ValueError("No default chat service is set")
return self._default_chat_service

return service_id

def get_text_embedding_generation_service_id(
self, service_id: Optional[str] = None
) -> str:
if (
service_id is None
or service_id not in self._text_embedding_generation_services
):
if self._default_text_embedding_generation_service is None:
raise ValueError("No default embedding service is set")
return self._default_text_embedding_generation_service

return service_id

def remove_text_completion_service(self, service_id: str) -> "Kernel":
if service_id not in self._text_completion_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

del self._text_completion_services[service_id]
if self._default_text_completion_service == service_id:
self._default_text_completion_service = next(
iter(self._text_completion_services), None
)
return self

def remove_chat_service(self, service_id: str) -> "Kernel":
if service_id not in self._chat_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

del self._chat_services[service_id]
if self._default_chat_service == service_id:
self._default_chat_service = next(iter(self._chat_services), None)
return self

def remove_text_embedding_generation_service(self, service_id: str) -> "Kernel":
if service_id not in self._text_embedding_generation_services:
raise ValueError(
f"AI service with service_id '{service_id}' does not exist"
)

del self._text_embedding_generation_services[service_id]
if self._default_text_embedding_generation_service == service_id:
self._default_text_embedding_generation_service = next(
iter(self._text_embedding_generation_services), None
)
return self

def clear_all_text_completion_services(self) -> "Kernel":
self._text_completion_services = {}
self._default_text_completion_service = None
return self

def clear_all_chat_services(self) -> "Kernel":
self._chat_services = {}
self._default_chat_service = None
return self

def clear_all_text_embedding_generation_services(self) -> "Kernel":
self._text_embedding_generation_services = {}
self._default_text_embedding_generation_service = None
return self

def clear_all_services(self) -> "Kernel":
self._text_completion_services = {}
self._chat_services = {}
self._text_embedding_generation_services = {}

self._default_text_completion_service = None
self._default_chat_service = None
self._default_text_embedding_generation_service = None

return self

def _create_semantic_function(
self,
skill_name: str,
Expand All @@ -274,7 +519,7 @@ def _create_semantic_function(
function.set_default_skill_collection(self.skills)

if function_config.has_chat_prompt:
service = self._config.get_ai_service(
service = self.get_ai_service(
ChatCompletionClientBase,
function_config.prompt_template_config.default_services[0]
if len(function_config.prompt_template_config.default_services) > 0
Expand All @@ -297,7 +542,7 @@ def _create_semantic_function(

function.set_chat_service(lambda: service(self))
else:
service = self._config.get_ai_service(
service = self.get_ai_service(
TextCompletionClientBase,
function_config.prompt_template_config.default_services[0]
if len(function_config.prompt_template_config.default_services) > 0
Expand Down
6 changes: 0 additions & 6 deletions python/semantic_kernel/kernel_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from logging import Logger
from typing import Any, Dict, Optional

from semantic_kernel.kernel_config import KernelConfig
from semantic_kernel.memory.semantic_text_memory_base import SemanticTextMemoryBase
from semantic_kernel.orchestration.context_variables import ContextVariables
from semantic_kernel.orchestration.sk_context import SKContext
Expand All @@ -21,11 +20,6 @@


class KernelBase(ABC):
@property
@abstractmethod
def config(self) -> KernelConfig:
pass

@property
@abstractmethod
def logger(self) -> Logger:
Expand Down