diff --git a/pkg/workloads/cortex/serve/serve.py b/pkg/workloads/cortex/serve/serve.py index af2a548d2e..44f322775c 100644 --- a/pkg/workloads/cortex/serve/serve.py +++ b/pkg/workloads/cortex/serve/serve.py @@ -42,7 +42,7 @@ def start(args): assert_api_version() storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"]) try: - raw_api_spec = get_spec(args.cache_dir, args.spec) + raw_api_spec = get_spec(storage, args.cache_dir, args.spec) api = API(storage=storage, cache_dir=args.cache_dir, **raw_api_spec) client = api.predictor.initialize_client(args) cx_logger().info("loading the predictor from {}".format(api.predictor.path)) @@ -122,7 +122,7 @@ def after_request(response): try: api.post_latency_metrics(response.status_code, g.start_time) - if api.tracker is not None: + if int(response.status_code / 100) == 2 and api.tracker is not None: predicted_value = api.tracker.extract_predicted_value(prediction) api.post_tracker_metrics(predicted_value) if predicted_value is not None and predicted_value not in local_cache["class_set"]: @@ -164,10 +164,10 @@ def assert_api_version(): ) -def get_spec(cache_dir, s3_path): +def get_spec(storage, cache_dir, s3_path): local_spec_path = os.path.join(cache_dir, "api_spec.msgpack") - bucket, key = S3.deconstruct_s3_path(s3_path) - S3(bucket, client_config={}).download_file(key, local_spec_path) + _, key = S3.deconstruct_s3_path(s3_path) + storage.download_file(key, local_spec_path) return util.read_msgpack(local_spec_path)