Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
BeatrixCohere committed May 23, 2024
1 parent 6f1c18f commit e56fb71
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
8 changes: 4 additions & 4 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ class CustomChat(BaseChat):

logger = get_logger()

def chat(
self, chat_request: CohereChatRequest, **kwargs: Any
) -> Any:
def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
"""
Chat flow for custom models.
Expand All @@ -32,7 +30,9 @@ def chat(
Generator[StreamResponse, None, None]: Chat response.
"""
# Choose the deployment model - validation already performed by request validator
deployment_model = get_deployment(kwargs.get("deployment_name"), kwargs.get("deployment_config"))
deployment_model = get_deployment(
kwargs.get("deployment_name"), kwargs.get("deployment_config")
)
self.logger.info(f"Using deployment {deployment_model.__class__.__name__}")

if len(chat_request.tools) > 0 and len(chat_request.documents) > 0:
Expand Down
10 changes: 5 additions & 5 deletions src/backend/chat/custom/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from backend.model_deployments.base import BaseDeployment


def get_deployment(deployment_name, model_config) -> BaseDeployment:
def get_deployment(name, config) -> BaseDeployment:
"""Get the deployment implementation.
Args:
Expand All @@ -14,17 +14,17 @@ def get_deployment(deployment_name, model_config) -> BaseDeployment:
Raises:
ValueError: If the deployment is not supported.
"""
deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(deployment_name)
deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(name)

# Check provided deployment against config const
if deployment is not None and deployment.is_available:
return deployment.deployment_class(model_config, **deployment.kwargs)
return deployment.deployment_class(config, **deployment.kwargs)

# Fallback to first available deployment
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.is_available:
return deployment.deployment_class(model_config)
return deployment.deployment_class(config)

raise ValueError(
f"Deployment {deployment_name} is not supported, and no available deployments were found."
f"Deployment {name} is not supported, and no available deployments were found."
)
3 changes: 1 addition & 2 deletions src/backend/model_deployments/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def is_available(cls) -> bool:
return all([os.environ.get(var) is not None for var in SAGE_MAKER_ENV_VARS])

def invoke_chat_stream(
self, model_config: dict, chat_request: CohereChatRequest, **kwargs: Any
self, chat_request: CohereChatRequest, **kwargs: Any
) -> Generator[StreamedChatResponse, None, None]:
# Create the payload for the request
json_params = {
Expand All @@ -95,7 +95,6 @@ def invoke_chat_stream(

def invoke_search_queries(
self,
model_config: dict,
message: str,
chat_history: List[Dict[str, str]] | None = None,
**kwargs: Any
Expand Down

0 comments on commit e56fb71

Please sign in to comment.