[api] Support temp=0. Return errors as json. (microsoft#168)
* [api] Support temp=0. Return errors as json.

* Update hub_utils.py
stephenroller committed Jun 22, 2022
1 parent 1aa40c2 commit fc037c7
Showing 2 changed files with 44 additions and 7 deletions.
metaseq/hub_utils.py: 5 changes (4 additions, 1 deletion)
@@ -553,8 +553,11 @@ def generate(
         self.cfg.generation.beam = best_of
         if temperature > 0:
             self.cfg.generation.temperature = temperature
-        else:
+        elif temperature == 0:
             self.cfg.generation.sampling = False
+            self.cfg.generation.temperature = 1.0
+        elif temperature < 0:
+            raise ValueError("temperature must be >= 0 and <= 1")
 
         MAX_SEQ_LEN = utils.resolve_max_positions(
             self.task.max_positions(), *[model.max_positions() for model in self.models]
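The new branching makes the temperature contract explicit: a positive value keeps sampling at that temperature, zero switches to deterministic (beam) decoding with a neutral temperature, and a negative value is rejected up front. A minimal standalone sketch of the same logic, using a hypothetical stand-in for metaseq's generation config (class and field names here are assumptions, not metaseq's real API):

from dataclasses import dataclass

@dataclass
class GenConfig:
    # stand-in for self.cfg.generation; not metaseq's actual class
    sampling: bool = True
    temperature: float = 1.0

def apply_temperature(cfg: GenConfig, temperature: float) -> None:
    if temperature > 0:
        cfg.temperature = temperature  # sample at the requested temperature
    elif temperature == 0:
        cfg.sampling = False           # temp=0 means deterministic decoding
        cfg.temperature = 1.0          # leave scores unscaled
    else:
        raise ValueError("temperature must be >= 0 and <= 1")

cfg = GenConfig()
apply_temperature(cfg, 0)
assert cfg.sampling is False and cfg.temperature == 1.0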
metaseq_cli/interactive_hosted.py: 46 changes (40 additions, 6 deletions)
@@ -16,9 +16,11 @@
 import pkg_resources
 import random
 import threading
+import traceback
 
 import torch
-from flask import Flask, request
+from flask import Flask, request, jsonify
+from werkzeug.exceptions import HTTPException
 
 from metaseq import options
 from metaseq.dataclass.configs import MetaseqConfig
@@ -125,7 +127,11 @@ def batching_loop(timeout=100, max_tokens=MAX_BATCH_TOKENS):
         dist_utils.broadcast_object(
             request_object, src_rank=0, group=dist_utils.get_global_group()
         )
-        generations = generator.generate(**request_object)
+        try:
+            generations = generator.generate(**request_object)
+        except Exception as e:
+            # propagate any exceptions to the response so we can report them
+            generations = [e] * len(batch)
         # broadcast them back
         for work_item, gen in zip(batch, generations):
             work_item.return_queue.put((work_item.uid, gen))
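Rather than letting a failed generate() take down the batching thread, the exception object itself is now fanned out to every request in the batch and re-raised later on the request thread (see the completions hunk below). A toy sketch of this exceptions-as-results pattern, with hypothetical names:

import queue

def run_batch(batch, return_queues):
    # stand-in for the generate() call inside batching_loop
    try:
        results = [f"echo: {prompt}" for prompt in batch]
    except Exception as e:
        # hand every waiting caller the same exception object
        results = [e] * len(batch)
    for q, res in zip(return_queues, results):
        q.put(res)

def await_result(q):
    res = q.get()
    if isinstance(res, Exception):
        raise res  # surfaces in the web handler, not the batching thread
    return res

q = queue.Queue()
run_batch(["hi"], [q])
print(await_result(q))  # echo: hi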
@@ -166,10 +172,36 @@ def worker_main(cfg1: MetaseqConfig, namespace_args=None):
     # useful in FSDP setting
     logger.info(f"Looping engaged! {get_my_ip()}:{port}")
     while True:
-        request_object = dist_utils.broadcast_object(
-            None, src_rank=0, group=dist_utils.get_global_group()
-        )
-        _ = generator.generate(**request_object)
+        try:
+            request_object = dist_utils.broadcast_object(
+                None, src_rank=0, group=dist_utils.get_global_group()
+            )
+            _ = generator.generate(**request_object)
+        except Exception:
+            # continue looping for the next generation so we don't lock up
+            pass
+
+
+@app.errorhandler(Exception)
+def handle_exception(e):
+    # pass through HTTP errors
+    if isinstance(e, HTTPException):
+        return e
+    # now you're handling non-HTTP exceptions only
+    response = jsonify(
+        {
+            "error": {
+                "message": str(e),
+                "type": "oops",
+                "stacktrace": traceback.format_tb(e.__traceback__),
+            }
+        }
+    )
+    if isinstance(e, ValueError):
+        response.status = 400
+    else:
+        response.status = 500
+    return response
 
 
 @app.route("/completions", methods=["POST"])
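With this handler registered, any uncaught exception that escapes a view becomes a JSON body instead of Flask's default HTML error page, with ValueError mapped to 400 and everything else to 500. A client would see something like the sketch below; the host, port, and exact request fields are assumptions, not taken from this diff:

import requests

resp = requests.post(
    "http://localhost:6010/completions",
    json={"prompt": "hello", "temperature": -1},
)
print(resp.status_code)  # 400: the negative temperature raises ValueError
print(resp.json()["error"]["message"])  # temperature must be >= 0 and <= 1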
@@ -244,6 +276,8 @@ def completions(engine=None):
     reordered = sorted(unordered_results, key=lambda x: x[0])
     results = []
     for prompt, (_, generations) in zip(prompts, reordered):
+        if isinstance(generations, Exception):
+            raise generations
         results += generations
     # transform the result into the openai format
     return OAIResponse(results).__dict__()
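This check is what ties the two files together: the exception that batching_loop stored in place of a generation crosses the return queue, is re-raised here on the Flask request thread, and handle_exception above serializes it as JSON. A toy illustration of the reorder-then-raise step (names are stand-ins):

def collect(unordered_results):
    results = []
    for _, generations in sorted(unordered_results, key=lambda x: x[0]):
        if isinstance(generations, Exception):
            raise generations  # caught by handle_exception, returned as JSON
        results += generations
    return results

print(collect([(1, ["gen-b"]), (0, ["gen-a"])]))  # ['gen-a', 'gen-b']
# collect([(0, ValueError("boom"))]) would raise and yield a 400 JSON error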
