Skip to content

Commit

Permalink
MVP of skops integration
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton committed Aug 30, 2022
1 parent 99833d5 commit 0a62f8f
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
80 changes: 76 additions & 4 deletions gradio/external.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
"""This module should not be used directly as its API is subject to change. Instead,
use the `gr.Blocks.load()` or `gr.Interface.load()` functions."""

from __future__ import annotations

import base64
import json
import operator
import re
import warnings
from copy import deepcopy
from typing import Callable, Dict
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple

import requests
import yaml

import gradio
from gradio import components, utils

if TYPE_CHECKING:
from gradio.components import DataframeData


class TooManyRequestsError(Exception):
"""Raised when the Hugging Face API returns a 429 status code."""
Expand Down Expand Up @@ -42,6 +49,42 @@ def load_blocks_from_repo(name, src=None, api_key=None, alias=None, **kwargs):
return blocks


def get_tabular_examples(model_name) -> Dict[str, List[float]]:
readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
if readme.status_code != 200:
warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
example_data = {}
else:
yaml_regex = re.search(
"(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
)
example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]))
example_data = example_yaml["widget"]["structuredData"]
return example_data


def cols_to_rows(
example_data: Dict[str, List[float]]
) -> Tuple[List[str], List[List[float]]]:
headers = list(example_data.keys())
data = []
for row_index in range(len(example_data[headers[0]])):
row_data = []
for header in headers:
row_data.append(example_data[header][row_index])
data.append(row_data)
return headers, data


def rows_to_cols(
incoming_data: DataframeData,
) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
data_column_wise = {}
for i, header in enumerate(incoming_data["headers"]):
data_column_wise[header] = [str(row[i]) for row in incoming_data["data"]]
return {"inputs": {"data": data_column_wise}}


def get_models_interface(model_name, api_key, alias, **kwargs):
model_url = "https://huggingface.co/{}".format(model_name)
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
Expand Down Expand Up @@ -260,6 +303,28 @@ def encode_to_base64(r: requests.Response) -> str:
},
}

if p in ["tabular-classification", "tabular-regression"]:
example_data = get_tabular_examples(model_name)
col_names, example_data = cols_to_rows(example_data)

pipelines[p] = {
"inputs": components.Dataframe(
label="Input Rows",
type="pandas",
headers=col_names,
col_count=(len(col_names), "fixed"),
),
"outputs": components.Dataframe(
label="Predictions", type="array", headers=["prediction"]
),
"preprocess": rows_to_cols,
"postprocess": lambda r: {
"headers": ["prediction"],
"data": [[pred] for pred in json.loads(r.text)],
},
"examples": [[example_data]],
}

if p is None or not (p in pipelines):
raise ValueError("Unsupported pipeline type: {}".format(p))

Expand All @@ -275,10 +340,16 @@ def query_huggingface_api(*params):
data = json.dumps(data)
response = requests.request("POST", api_url, headers=headers, data=data)
if not (response.status_code == 200):
errors_json = response.json()
errors, warns = "", ""
if errors_json.get("error"):
errors = f", Error: {errors_json.get('error')}"
if errors_json.get("warnings"):
warns = f", Warnings: {errors_json.get('warnings')}"
raise ValueError(
"Could not complete request to HuggingFace API, Error {}".format(
response.status_code
)
f"Could not complete request to HuggingFace API, Status Code: {response.status_code}"
+ errors
+ warns
)
if (
p == "token-classification"
Expand All @@ -299,6 +370,7 @@ def query_huggingface_api(*params):
"inputs": pipeline["inputs"],
"outputs": pipeline["outputs"],
"title": model_name,
"examples": pipeline.get("examples"),
}

kwargs = dict(interface_info, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pillow
pycryptodome
python-multipart
pydub
pyyaml
requests
uvicorn
Jinja2
Expand Down

0 comments on commit 0a62f8f

Please sign in to comment.