Skip to content

Commit

Permalink
When pipelines fail to load in framework code, for whatever reason
Browse files Browse the repository at this point in the history
We need to display some error about what is happening.
  • Loading branch information
Narsil committed Jun 10, 2021
1 parent 5585292 commit 708425c
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 17 deletions.
10 changes: 5 additions & 5 deletions api-inference-community/api_inference_community/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ async def pipeline_route(request: Request) -> Response:
start = time.time()
payload = await request.body()
task = os.environ["TASK"]
pipe = request.app.pipeline
try:
sampling_rate = pipe.sampling_rate
except Exception:
sampling_rate = None
try:
pipe = request.app.get_pipeline()
try:
sampling_rate = pipe.sampling_rate
except Exception:
sampling_rate = None
inputs, params = normalize_payload(payload, task, sampling_rate=sampling_rate)
except ValidationError as e:
errors = []
Expand Down
32 changes: 21 additions & 11 deletions api-inference-community/docker_images/common/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,20 @@
}


def get_pipeline(task: str, model_id: str) -> Pipeline:
if task not in ALLOWED_TASKS:
raise EnvironmentError(f"{task} is not a valid pipeline for model : {model_id}")
return ALLOWED_TASKS[task](model_id)
PIPELINE = None


def get_pipeline() -> Pipeline:
global PIPELINE
if PIPELINE is None:
task = os.environ["TASK"]
model_id = os.environ["MODEL_ID"]
if task not in ALLOWED_TASKS:
raise EnvironmentError(
f"{task} is not a valid pipeline for model : {model_id}"
)
PIPELINE = ALLOWED_TASKS[task](model_id)
return PIPELINE


routes = [
Expand All @@ -61,13 +71,13 @@ async def startup_event():
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logger.handlers = [handler]

task = os.environ["TASK"]
model_id = os.environ["MODEL_ID"]
app.pipeline = get_pipeline(task, model_id)
# Link between `api-inference-community` and framework code.
app.get_pipeline = get_pipeline


if __name__ == "__main__":
task = os.environ["TASK"]
model_id = os.environ["MODEL_ID"]

get_pipeline(task, model_id)
try:
get_pipeline()
except Exception:
# We can fail so we can show exception later.
pass
38 changes: 37 additions & 1 deletion api-inference-community/tests/test_dockers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,44 @@ def test_speechbrain(self):
def test_timm(self):
self.framework_docker_test("timm", "image-classification", "sgugger/resnet50d")

def framework_docker_test(self, framework: str, task: str, model_id: str):
def framework_invalid_test(self, framework: str):
task = "invalid"
model_id = "invalid"
tag = self.create_docker(framework)
run_docker_command = [
"docker",
"run",
"-p",
"8000:80",
"-e",
f"TASK={task}",
"-e",
f"MODEL_ID={model_id}",
"-v",
"/tmp:/data",
"-t",
tag,
]

url = "http://localhost:8000"
timeout = 60
with DockerPopen(run_docker_command) as proc:
for i in range(400):
try:
response = httpx.get(url, timeout=10)
break
except Exception:
time.sleep(1)
self.assertEqual(response.content, b'{"ok":"ok"}')

response = httpx.post(url, data=b"This is a test", timeout=timeout)
self.assertEqual(response.status_code, 500)
self.assertEqual(response.headers["content-type"], "application/json")

proc.terminate()
proc.wait(5)

def framework_docker_test(self, framework: str, task: str, model_id: str):
tag = self.create_docker(framework)
run_docker_command = [
"docker",
Expand Down
42 changes: 42 additions & 0 deletions api-inference-community/tests/test_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
from unittest import TestCase
import logging
from api_inference_community.routes import status_ok, pipeline_route
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.testclient import TestClient


class ValidationTestCase(TestCase):
def test_invalid_pipeline(self):
os.environ["TASK"] = "invalid"

def get_pipeline():
raise Exception("We cannot load the pipeline")

routes = [
Route("/{whatever:path}", status_ok),
Route("/{whatever:path}", pipeline_route, methods=["POST"]),
]

app = Starlette(routes=routes)

@app.on_event("startup")
async def startup_event():
logger = logging.getLogger("uvicorn.access")
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
)
logger.handlers = [handler]

# Link between `api-inference-community` and framework code.
app.get_pipeline = get_pipeline

with TestClient(app) as client:
response = client.post("/", data=b"")
self.assertEqual(
response.status_code,
500,
)
self.assertEqual(response.content, b'{"error":"We cannot load the pipeline"}')

0 comments on commit 708425c

Please sign in to comment.