Remove generate endpoints #3654

Merged · 1 commit · Apr 30, 2024

python/kserve/kserve/model.py (7 changes: 0 additions & 7 deletions)
@@ -36,7 +36,6 @@
 from .protocol.grpc import grpc_predict_v2_pb2_grpc
 from .protocol.grpc.grpc_predict_v2_pb2 import ModelInferRequest, ModelInferResponse
 from .protocol.infer_type import InferRequest, InferResponse
-from .protocol.rest.v2_datamodels import GenerateRequest, GenerateResponse
 
 PREDICTOR_URL_FORMAT = "{0}://{1}/v1/models/{2}:predict"
 EXPLAINER_URL_FORMAT = "{0}://{1}/v1/models/{2}:explain"
@@ -417,12 +416,6 @@ async def predict(
             else res
         )
 
-    async def generate(
-        self, payload: GenerateRequest, headers: Dict[str, str] = None
-    ) -> Union[GenerateResponse, AsyncIterator[Any]]:
-        """`generate` handler can be overridden to implement text generation."""
-        raise NotImplementedError("generate is not implemented")
-
     async def explain(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
         """`explain` handler can be overridden to implement the model explanation.
         The default implementation makes call to the explainer if ``explainer_host`` is specified.
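
For context, `Model.generate` was the extension point that custom predictors overrode to implement text generation. Below is a minimal sketch of such an override against the pre-removal API shown above; `EchoModel` and its echo behavior are hypothetical, not taken from this PR:

# Hypothetical predictor overriding the removed `generate` handler
# (runs only against kserve versions that predate this PR).
from typing import Dict

from kserve import Model
from kserve.protocol.rest.v2_datamodels import GenerateRequest, GenerateResponse


class EchoModel(Model):
    async def generate(
        self, payload: GenerateRequest, headers: Dict[str, str] = None
    ) -> GenerateResponse:
        # A real model would run text generation here; this sketch echoes the prompt.
        return GenerateResponse(text_output=payload.text_input, model_name=self.name)
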
python/kserve/kserve/protocol/dataplane.py (27 changes: 1 addition & 26 deletions)
@@ -15,7 +15,7 @@
 import logging
 import time
 from importlib import metadata
-from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import cloudevents.exceptions as ce
 import orjson
@@ -30,7 +30,6 @@
 from ..model_repository import ModelRepository
 from ..utils.utils import create_response_cloudevent, is_structured_cloudevent
 from .infer_type import InferRequest, InferResponse
-from .rest.v2_datamodels import GenerateRequest, GenerateResponse
 
 JSON_HEADERS = [
     "application/json",
@@ -340,30 +339,6 @@ async def infer(
         response = await model(request, headers=headers)
         return response, headers
 
-    async def generate(
-        self,
-        model_name: str,
-        request: Union[Dict, GenerateRequest],
-        headers: Optional[Dict[str, str]] = None,
-    ) -> Tuple[Union[GenerateResponse, AsyncIterator[Any]], Dict[str, str]]:
-        """Generate the text with the provided text prompt.
-
-        Args:
-            model_name (str): Model name.
-            request (bytes|GenerateRequest): Generate Request / ChatCompletion Request body data.
-            headers: (Optional[Dict[str, str]]): Request headers.
-
-        Returns:
-            response: The generated output or output stream.
-            response_headers: Headers to construct the HTTP response.
-
-        Raises:
-            InvalidInput: An error when the body bytes can't be decoded as JSON.
-        """
-        model = self.get_model(model_name)
-        response = await model.generate(request, headers=headers)
-        return response, headers
-
     async def explain(
         self,
         model_name: str,
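
The deleted `DataPlane.generate` was a thin wrapper: look up the model by name, await its `generate` handler, and pass the headers through. A sketch of how it could have been driven, assuming an already wired-up DataPlane instance; the model name and payload are illustrative, copied from the schema example later in this PR:

# Hypothetical driver for the removed DataPlane.generate method.
from kserve.protocol.dataplane import DataPlane


async def generate_once(dataplane: DataPlane) -> None:
    request = {"text_input": "Tell me about the AI", "parameters": {"temperature": 0.8}}
    # generate() returned the output (or an async output stream) plus headers.
    response, response_headers = await dataplane.generate(
        "bloom7b1", request, headers={"content-type": "application/json"}
    )
    print(response, response_headers)
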
python/kserve/kserve/protocol/rest/server.py (12 changes: 0 additions & 12 deletions)
@@ -177,18 +177,6 @@ def create_application(self) -> FastAPI:
                 response_model=InferenceResponse,
                 tags=["V2"],
             ),
-            FastAPIRoute(
-                r"/v2/models/{model_name}/generate",
-                v2_endpoints.generate,
-                methods=["POST"],
-                tags=["V2"],
-            ),
-            FastAPIRoute(
-                r"/v2/models/{model_name}/generate_stream",
-                v2_endpoints.generate_stream,
-                methods=["POST"],
-                tags=["V2"],
-            ),
             FastAPIRoute(
                 r"/v2/models/{model_name}/versions/{model_version}/infer",
                 v2_endpoints.infer,
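
With these routes gone, POST requests to the two generate paths are no longer served by the v2 REST server. A sketch of the kind of client call the first route used to handle, assuming a server on localhost:8080 and a model named bloom7b1 (both illustrative):

# Hypothetical client call against the removed v2 generate route.
import requests

resp = requests.post(
    "http://localhost:8080/v2/models/bloom7b1/generate",
    json={"text_input": "Tell me about the AI", "parameters": {"temperature": 0.8}},
)
print(resp.status_code, resp.json())
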
python/kserve/kserve/protocol/rest/v2_datamodels.py (217 changes: 0 additions & 217 deletions)
@@ -283,220 +283,3 @@ class InferenceResponse(BaseModel):
 
     class Config:
         schema_extra = inference_response_schema_extra
-
-
-generate_request_schema_extra = {
-    "example": {
-        "text_input": "Tell me about the AI",
-        "parameters": {
-            "temperature": 0.8,
-            "top_p": 0.9,
-        },
-    }
-}
-
-
-class GenerateRequest(BaseModel):
-    """GenerateRequest Model
-
-    $generate_request =
-    {
-        "text_input" : $string,
-        "parameters" : $string #optional,
-    }
-    """
-
-    text_input: str
-    parameters: Optional[Parameters] = None
-
-    if is_pydantic_2:
-        model_config = ConfigDict(json_schema_extra=generate_request_schema_extra)
-    else:
-
-        class Config:
-            json_loads = orjson.loads
-            schema_extra = generate_request_schema_extra
-
-
-token_schema_extra = {
-    "example": {
-        "id": 267,
-        "logprob": -2.0723474,
-        "special": False,
-        "text": " a",
-    }
-}
-
-
-class Token(BaseModel):
-    """Token Data Model"""
-
-    id: int
-    logprob: float
-    special: bool
-    text: str
-
-    if is_pydantic_2:
-        model_config = ConfigDict(json_schema_extra=token_schema_extra)
-    else:
-
-        class Config:
-            json_loads = orjson.loads
-            schema_extra = token_schema_extra
-
-
-details_schema_extra = {
-    "example": {
-        "finish_reason": "stop",
-        "logprobs": [
-            {
-                "id": 267,
-                "logprob": -2.0723474,
-                "special": False,
-                "text": " a",
-            }
-        ],
-    }
-}
-
-
-class Details(BaseModel):
-    """Generate response details"""
-
-    finish_reason: str
-    logprobs: List[Token]
-
-    if is_pydantic_2:
-        model_config = ConfigDict(
-            json_schema_extra=details_schema_extra,
-        )
-    else:
-
-        class Config:
-            json_loads = orjson.loads
-            schema_extra = details_schema_extra
-
-
-streaming_details_schema_extra = {
-    "example": {
-        "finish_reason": "stop",
-        "logprobs": {
-            "id": 267,
-            "logprob": -2.0723474,
-            "special": False,
-            "text": " a",
-        },
-    }
-}
-
-
-class StreamingDetails(BaseModel):
-    """Generate response details"""
-
-    finish_reason: str
-    logprobs: Token
-
-    if is_pydantic_2:
-        model_config = ConfigDict(
-            json_schema_extra=streaming_details_schema_extra,
-        )
-    else:
-
-        class Config:
-            json_loads = orjson.loads
-            schema_extra = streaming_details_schema_extra
-
-
-generate_response_schema_extra = {
-    "example": {
-        "text_output": "Tell me about the AI",
-        "model_name": "bloom7b1",
-        "details": {
-            "finish_reason": "stop",
-            "logprobs": [
-                {
-                    "id": "267",
-                    "logprob": -2.0723474,
-                    "special": False,
-                    "text": " a",
-                }
-            ],
-        },
-    }
-}
-
-
-class GenerateResponse(BaseModel):
-    """GenerateResponse Model
-
-    $generate_response =
-    {
-        "text_output" : $string,
-        "model_name" : $string,
-        "model_version" : $string #optional,
-        "details": $Details #optional
-    }
-    """
-
-    text_output: str
-    model_name: str
-    model_version: Optional[str] = None
-    details: Optional[Details] = None
-
-    if is_pydantic_2:
-        model_config = ConfigDict(
-            protected_namespaces=(),
-            json_schema_extra=generate_response_schema_extra,
-        )
-    else:
-
-        class Config:
-            json_loads = orjson.loads
-            schema_extra = generate_response_schema_extra
-
-
-generate_streaming_response_schema_extra = {
-    "example": {
-        "text_output": "Tell me about the AI",
-        "model_name": "bloom7b1",
-        "details": {
-            "finish_reason": "stop",
-            "logprobs": {
-                "id": "267",
-                "logprob": -2.0723474,
-                "special": False,
-                "text": " a",
-            },
-        },
-    }
-}
-
-
-class GenerateStreamingResponse(BaseModel):
-    """GenerateStreamingResponse Model
-
-    $generate_response =
-    {
-        "text_output" : $string,
-        "model_name" : $string,
-        "model_version" : $string #optional,
-        "details": $Details #optional
-    }
-    """
-
-    text_output: str
-    model_name: str
-    model_version: Optional[str] = None
-    details: Optional[StreamingDetails] = None
-
-    if is_pydantic_2:
-        model_config = ConfigDict(
-            protected_namespaces=(),
-            json_schema_extra=generate_streaming_response_schema_extra,
-        )
-
-    else:
-
-        class Config:
-            json_loads = orjson.loads
-            schema_extra = generate_streaming_response_schema_extra
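
Since the deleted classes are ordinary Pydantic models, their behavior is easy to reconstruct. A quick sketch that instantiates them with the example values from the schemas above (imports assume a kserve version that predates this PR):

# Instantiating the removed data models with their documented example values.
from kserve.protocol.rest.v2_datamodels import (
    Details,
    GenerateRequest,
    GenerateResponse,
    Token,
)

request = GenerateRequest(text_input="Tell me about the AI")
token = Token(id=267, logprob=-2.0723474, special=False, text=" a")
details = Details(finish_reason="stop", logprobs=[token])
response = GenerateResponse(
    text_output="Tell me about the AI",
    model_name="bloom7b1",
    details=details,
)
print(response.model_name)  # -> "bloom7b1"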