-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Skops integration: Load tabular classification and regression models from the hub #2126
Changes from 6 commits
4566c86
06db951
aae0305
798fbcb
c0b4265
54c5424
6a29df7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,27 @@ | ||
"""This module should not be used directly as its API is subject to change. Instead, | ||
use the `gr.Blocks.load()` or `gr.Interface.load()` functions.""" | ||
|
||
from __future__ import annotations | ||
|
||
import base64 | ||
import json | ||
import math | ||
import numbers | ||
import operator | ||
import re | ||
import warnings | ||
from copy import deepcopy | ||
from typing import Callable, Dict | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple | ||
|
||
import requests | ||
import yaml | ||
|
||
import gradio | ||
from gradio import components, utils | ||
|
||
if TYPE_CHECKING: | ||
from gradio.components import DataframeData | ||
|
||
|
||
class TooManyRequestsError(Exception): | ||
"""Raised when the Hugging Face API returns a 429 status code.""" | ||
|
@@ -42,6 +51,58 @@ def load_blocks_from_repo(name, src=None, api_key=None, alias=None, **kwargs): | |
return blocks | ||
|
||
|
||
def get_tabular_examples(model_name) -> Dict[str, List[float]]:
    """Fetch example input rows for a tabular model from its Hub README.

    Downloads ``README.md`` for ``model_name`` from the Hugging Face Hub and
    extracts the ``widget.structuredData`` mapping (column name -> list of
    values) from the YAML front matter. Float NaN values are replaced with
    the string ``"NaN"`` because the inference API cannot serialize them.

    Raises:
        ValueError: if no example data can be found, since a bare Dataframe
            input with unknown feature names would make the demo unusable.
    """
    readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
    if readme.status_code != 200:
        warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
        example_data = {}
    else:
        yaml_regex = re.search(
            "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
        )
        if yaml_regex is None:
            # README exists but has no YAML front matter at all; fall through
            # to the explanatory error below instead of crashing on .span().
            example_data = {}
        else:
            example_yaml = next(
                yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
            )
            # Front matter may parse to None (e.g. an empty "--- ---" block);
            # guard before calling .get on it.
            example_data = (example_yaml or {}).get("widget", {}).get(
                "structuredData", {}
            )
    if not example_data:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )
    # replace nan with string NaN for inference API
    for data in example_data.values():
        for i, val in enumerate(data):
            if isinstance(val, numbers.Number) and math.isnan(val):
                data[i] = "NaN"
    return example_data
|
||
|
||
def cols_to_rows(
    example_data: Dict[str, List[float]]
) -> Tuple[List[str], List[List[float]]]:
    """Transpose column-oriented example data into ``(headers, rows)``.

    Columns may be ``None`` or of unequal length; missing cells are filled
    with the string ``"NaN"`` so every row has one entry per header (the
    inference API expects ``"NaN"`` rather than an empty cell).

    Returns:
        A tuple of the header names and the row-major data.
    """
    headers = list(example_data.keys())
    # default=0 handles an empty mapping, which would otherwise make max()
    # raise ValueError on an empty sequence.
    n_rows = max(
        (len(example_data[header] or []) for header in headers), default=0
    )
    data = []
    for row_index in range(n_rows):
        row_data = []
        for header in headers:
            col = example_data[header] or []
            if row_index >= len(col):
                row_data.append("NaN")
            else:
                row_data.append(col[row_index])
        data.append(row_data)
    return headers, data
|
||
|
||
def rows_to_cols(
    incoming_data: DataframeData,
) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
    """Convert row-major dataframe data into the column-wise JSON payload
    expected by the Hugging Face inference API (all values stringified)."""
    columns = {
        header: [str(row[col_idx]) for row in incoming_data["data"]]
        for col_idx, header in enumerate(incoming_data["headers"])
    }
    return {"inputs": {"data": columns}}
|
||
|
||
def get_models_interface(model_name, api_key, alias, **kwargs): | ||
model_url = "https://huggingface.co/{}".format(model_name) | ||
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name) | ||
|
@@ -260,6 +321,29 @@ def encode_to_base64(r: requests.Response) -> str: | |
}, | ||
} | ||
|
||
if p in ["tabular-classification", "tabular-regression"]: | ||
example_data = get_tabular_examples(model_name) | ||
col_names, example_data = cols_to_rows(example_data) | ||
example_data = [[example_data]] if example_data else None | ||
|
||
pipelines[p] = { | ||
"inputs": components.Dataframe( | ||
label="Input Rows", | ||
type="pandas", | ||
headers=col_names, | ||
col_count=(len(col_names), "fixed"), | ||
), | ||
"outputs": components.Dataframe( | ||
label="Predictions", type="array", headers=["prediction"] | ||
), | ||
"preprocess": rows_to_cols, | ||
"postprocess": lambda r: { | ||
"headers": ["prediction"], | ||
"data": [[pred] for pred in json.loads(r.text)], | ||
}, | ||
"examples": example_data, | ||
} | ||
|
||
if p is None or not (p in pipelines): | ||
raise ValueError("Unsupported pipeline type: {}".format(p)) | ||
|
||
|
@@ -275,10 +359,16 @@ def query_huggingface_api(*params): | |
data = json.dumps(data) | ||
response = requests.request("POST", api_url, headers=headers, data=data) | ||
if not (response.status_code == 200): | ||
errors_json = response.json() | ||
errors, warns = "", "" | ||
if errors_json.get("error"): | ||
errors = f", Error: {errors_json.get('error')}" | ||
if errors_json.get("warnings"): | ||
warns = f", Warnings: {errors_json.get('warnings')}" | ||
raise ValueError( | ||
"Could not complete request to HuggingFace API, Error {}".format( | ||
response.status_code | ||
) | ||
f"Could not complete request to HuggingFace API, Status Code: {response.status_code}" | ||
+ errors | ||
+ warns | ||
) | ||
if ( | ||
p == "token-classification" | ||
|
@@ -299,6 +389,7 @@ def query_huggingface_api(*params): | |
"inputs": pipeline["inputs"], | ||
"outputs": pipeline["outputs"], | ||
"title": model_name, | ||
"examples": pipeline.get("examples"), | ||
} | ||
|
||
kwargs = dict(interface_info, **kwargs) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ pillow | |
pycryptodome | ||
python-multipart | ||
pydub | ||
pyyaml | ||
requests | ||
uvicorn | ||
Jinja2 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,16 @@ | ||
import json | ||
import os | ||
import pathlib | ||
import textwrap | ||
import unittest | ||
from unittest.mock import patch | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
import transformers | ||
|
||
import gradio as gr | ||
from gradio.external import TooManyRequestsError | ||
from gradio import utils | ||
from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples | ||
|
||
""" | ||
WARNING: These tests have an external dependency: namely that Hugging Face's | ||
|
@@ -242,5 +244,92 @@ def test_interface_load_cache_examples(tmp_path): | |
) | ||
|
||
|
||
def test_get_tabular_examples_replaces_nan_with_str_nan():
    """YAML ``.nan`` values in README widget data become the string "NaN"."""
    readme_contents = textwrap.dedent(
        """
        ---
        tags:
        - sklearn
        - skops
        - tabular-classification
        widget:
          structuredData:
            attribute_0:
            - material_7
            - material_7
            - material_7
            measurement_2:
            - 14.206
            - 15.094
            - .nan
        ---
        """
    )
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.text = readme_contents

    with patch("gradio.external.requests.get", return_value=fake_response):
        examples = get_tabular_examples("foo-model")
        assert examples["measurement_2"] == [14.206, 15.094, "NaN"]
|
||
|
||
def test_cols_to_rows():
    """cols_to_rows pads ragged or missing columns with "NaN" and transposes."""
    cases = [
        (
            {"a": [1, 2, "NaN"], "b": [1, "NaN", 3]},
            (["a", "b"], [[1, 1], [2, "NaN"], ["NaN", 3]]),
        ),
        (
            {"a": [1, 2, "NaN", 4], "b": [1, "NaN", 3]},
            (["a", "b"], [[1, 1], [2, "NaN"], ["NaN", 3], [4, "NaN"]]),
        ),
        (
            {"a": [1, 2, "NaN"], "b": [1, "NaN", 3, 5]},
            (["a", "b"], [[1, 1], [2, "NaN"], ["NaN", 3], ["NaN", 5]]),
        ),
        (
            {"a": None, "b": [1, "NaN", 3, 5]},
            (["a", "b"], [["NaN", 1], ["NaN", "NaN"], ["NaN", 3], ["NaN", 5]]),
        ),
        # An entirely empty widget section yields headers but no rows.
        ({"a": None, "b": None}, (["a", "b"], [])),
    ]
    for columns, expected in cases:
        assert cols_to_rows(columns) == expected
|
||
|
||
def check_dataframe(config):
    """Assert the interface config contains the expected input Dataframe."""
    input_df = next(
        component
        for component in config["components"]
        if component["props"].get("label", "") == "Input Rows"
    )
    props = input_df["props"]
    assert props["headers"] == ["a", "b"]
    assert props["row_count"] == (1, "dynamic")
    assert props["col_count"] == (2, "fixed")
|
||
|
||
def check_dataset(config, readme_examples):
    """Assert the example dataset component matches the README widget data."""
    if not any(readme_examples.values()):
        # With no example values, no dataset component should be created.
        assert all(c["type"] != "dataset" for c in config["components"])
        return
    dataset = next(c for c in config["components"] if c["type"] == "dataset")
    _, rows = cols_to_rows(readme_examples)
    assert dataset["props"]["samples"] == [[utils.delete_none(rows)]]
|
||
|
||
@pytest.mark.parametrize(
    "hypothetical_readme",
    [
        # Covers well-formed and malformed README widget data alike:
        # ragged columns, empty columns, and fully empty widget sections.
        {"a": [1, 2, "NaN"], "b": [1, "NaN", 3]},
        {"a": [1, 2, "NaN", 4], "b": [1, "NaN", 3]},
        {"a": [1, 2, "NaN"], "b": [1, "NaN", 3, 5]},
        {"a": None, "b": [1, "NaN", 3, 5]},
        {"a": None, "b": None},
    ],
)
def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme):
    """Loading a tabular model builds a valid demo regardless of how the
    README widget data is shaped."""
    examples_patch = patch(
        "gradio.external.get_tabular_examples", return_value=hypothetical_readme
    )
    with examples_patch:
        io = gr.Interface.load("models/scikit-learn/tabular-playground")
        check_dataframe(io.config)
        check_dataset(io.config, hypothetical_readme)
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can either get the example data from the README or the `config.json`, but the `config.json` will only have the example data if the model was uploaded with skops. I think it would be better if gradio could create a demo for any tabular model and not just those created with skops. The downside is that it introduces a pyyaml dependency.
In the future, once the skops `config.json` file contains richer metadata about feature types (categorical vs. null, etc.), we can read from the `config.json` if it's present.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it was something we didn't want to have in model card specifically
(cc @adrinjalali is working on having dtypes atm)
Maybe you could check for both? @freddyaboulton