Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gr.ScatterPlot component #2764

Merged
merged 34 commits into from
Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5fbc4e4
Try clean install
freddyaboulton Nov 29, 2022
3dd95fc
Resolve peer dependencies?
freddyaboulton Nov 29, 2022
50b6006
CHANGELOG
freddyaboulton Nov 29, 2022
1d56cde
Add outbreak_forcast notebook
freddyaboulton Dec 2, 2022
bfa84f7
generate again
freddyaboulton Dec 2, 2022
ee242ba
CHANGELOG
freddyaboulton Nov 29, 2022
1f832c9
Add image to changelog
freddyaboulton Nov 29, 2022
18079f7
Color palette
freddyaboulton Dec 1, 2022
44a043c
Fix colors + legend
freddyaboulton Dec 2, 2022
7f2a058
Tooltip
freddyaboulton Dec 2, 2022
d7f985a
Add axis titles
freddyaboulton Dec 2, 2022
80a0095
Clean up code a bit + quant scales
freddyaboulton Dec 5, 2022
c8f8a18
Add code
freddyaboulton Dec 5, 2022
f455f4d
Add size, shape + rename legend title
freddyaboulton Dec 5, 2022
f546f03
Fix demo
freddyaboulton Dec 5, 2022
289f075
Add update + demo
freddyaboulton Dec 7, 2022
c4006e8
Handle darkmode better
freddyaboulton Dec 7, 2022
6dbf58c
Try new font
freddyaboulton Dec 7, 2022
7401b80
Use sans-serif
freddyaboulton Dec 7, 2022
3e1f397
Add caption
freddyaboulton Dec 7, 2022
611e764
Changelog + tests
freddyaboulton Dec 7, 2022
4ef2d91
More tests
freddyaboulton Dec 7, 2022
d56eb10
Address comments
freddyaboulton Dec 8, 2022
08ffebb
Make caption fontsize smaller and enable interactivity
freddyaboulton Dec 8, 2022
130ee75
Add docstrings + add height + width
freddyaboulton Dec 8, 2022
076ab56
Merge branch 'main' into scatter-plot
abidlabs Dec 9, 2022
5dc41e2
Use normal font weight
freddyaboulton Dec 9, 2022
8715506
Merge main
freddyaboulton Dec 9, 2022
0b73f5e
Merge branch 'main' into scatter-plot
freddyaboulton Dec 9, 2022
620f656
Merge upstream changelog
freddyaboulton Dec 9, 2022
ec01da3
Make last values keyword only
freddyaboulton Dec 9, 2022
8146fda
Fix typo
freddyaboulton Dec 9, 2022
920cd92
Accept value as fn
freddyaboulton Dec 9, 2022
f2cb2e9
reword changelog a bit
freddyaboulton Dec 9, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,41 @@

## New Features:

### Scatter plot component

It is now possible to create a scatter plot without knowledge of a plotting library!
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a suggestion, feel free to ignore

Suggested change
It is now possible to create a scatter plot without knowledge of a plotting library!
It is now possible to create a scatter plot natively in Gradio!


The `gr.ScatterPlot` component accepts a pandas dataframe and some optional configuration parameters
and will automatically create a plot for you!

This is the first of many native plotting components in Gradio!

For an example of how to use `gr.ScatterPlot` see below:

```python
import gradio as gr
from vega_datasets import data

cars = data.cars()

with gr.Blocks() as demo:
gr.ScatterPlot(show_label=False,
value=cars,
x="Horsepower",
y="Miles_per_Gallon",
color="Origin",
tooltip="Name",
title="Car Data",
y_title="Miles per Gallon",
color_legend_title="Origin of Car").style(container=False)

demo.launch()
```

By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2764](https://github.com/gradio-app/gradio/pull/2764)



### Support for altair plots

The `Plot` component can now accept altair plots as values!
Expand Down Expand Up @@ -33,7 +68,7 @@ demo.launch()

By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2741](https://github.com/gradio-app/gradio/pull/2741)

### Set the background color of a Label component
### Set the background color of a Label component

The `Label` component now accepts a `color` argument by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2736](https://github.com/gradio-app/gradio/pull/2736).
The `color` argument should either be a valid css color name or hexadecimal string.
Expand Down Expand Up @@ -70,7 +105,6 @@ demo.queue().launch()

![label_bg_color_update](https://user-images.githubusercontent.com/41651716/204400372-80e53857-f26f-4a38-a1ae-1acadff75e89.gif)


## Bug Fixes:
* Fixed issue where image thumbnails were not showing when an example directory was provided
by by [@abidlabs](https://github.com/abidlabs) in [PR 2745](https://github.com/gradio-app/gradio/pull/2745)
Expand Down
1 change: 1 addition & 0 deletions demo/native_plots/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vega_datasets
1 change: 1 addition & 0 deletions demo/native_plots/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
12 changes: 12 additions & 0 deletions demo/native_plots/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import gradio as gr

from scatter_plot_demo import scatter_plot


with gr.Blocks() as demo:
with gr.Tabs():
with gr.TabItem("Scatter Plot"):
scatter_plot.render()

if __name__ == "__main__":
demo.launch()
47 changes: 47 additions & 0 deletions demo/native_plots/scatter_plot_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import gradio as gr

from vega_datasets import data

cars = data.cars()
iris = data.iris()


def scatter_plot_fn(dataset):
if dataset == "iris":
return gr.ScatterPlot.update(
value=iris,
x="petalWidth",
y="petalLength",
color="species",
title="Iris Dataset",
color_legend_title="Species",
x_title="Petal Width",
y_title="Petal Length",
tooltip=["petalWidth", "petalLength", "species"],
caption="",
)
else:
return gr.ScatterPlot.update(
value=cars,
x="Horsepower",
y="Miles_per_Gallon",
color="Origin",
tooltip="Name",
title="Car Data",
y_title="Miles per Gallon",
color_legend_title="Origin of Car",
caption="MPG vs Horsepower of various cars"
)


with gr.Blocks() as scatter_plot:
with gr.Row():
with gr.Column():
dataset = gr.Dropdown(choices=["cars", "iris"], value="cars")
with gr.Column():
plot = gr.ScatterPlot(show_label=False).style(container=True)
dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot)
scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot)

if __name__ == "__main__":
scatter_plot.launch()
1 change: 1 addition & 0 deletions demo/scatterplot_component/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vega_datasets
1 change: 1 addition & 0 deletions demo/scatterplot_component/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: scatterplot_component"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from vega_datasets import data\n", "\n", "cars = data.cars()\n", "\n", "with gr.Blocks() as demo:\n", " gr.ScatterPlot(show_label=False,\n", " value=cars,\n", " x=\"Horsepower\",\n", " y=\"Miles_per_Gallon\",\n", " color=\"Origin\",\n", " tooltip=\"Name\",\n", " title=\"Car Data\",\n", " y_title=\"Miles per Gallon\",\n", " color_legend_title=\"Origin of Car\").style(container=False)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
18 changes: 18 additions & 0 deletions demo/scatterplot_component/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import gradio as gr
from vega_datasets import data

cars = data.cars()

with gr.Blocks() as demo:
gr.ScatterPlot(show_label=False,
value=cars,
x="Horsepower",
y="Miles_per_Gallon",
color="Origin",
tooltip="Name",
title="Car Data",
y_title="Miles per Gallon",
color_legend_title="Origin of Car").style(container=False)

if __name__ == "__main__":
demo.launch()
1 change: 1 addition & 0 deletions gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Number,
Plot,
Radio,
ScatterPlot,
Slider,
State,
StatusTracker,
Expand Down
207 changes: 205 additions & 2 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import altair as alt
import matplotlib.figure
import numpy as np
import pandas as pd
Expand All @@ -28,6 +29,7 @@
from ffmpy import FFmpeg
from markdown_it import MarkdownIt
from mdit_py_plugins.dollarmath import dollarmath_plugin
from pandas.api.types import is_numeric_dtype

from gradio import media_data, processing_utils, utils
from gradio.blocks import Block
Expand Down Expand Up @@ -3937,8 +3939,209 @@ def postprocess(self, y: str | None) -> Dict[str, str] | None:
out_y = y.to_json()
return {"type": dtype, "plot": out_y}

def style(self):
return self
def style(self, container: Optional[bool] = None):
return IOComponent.style(
self,
container=container,
)


@document("change", "clear")
class ScatterPlot(Plot):
def __init__(
self,
x: Optional[str] = None,
y: Optional[str] = None,
value: Optional[pd.DataFrame] = None,
freddyaboulton marked this conversation as resolved.
Show resolved Hide resolved
color: Optional[str] = None,
size: Optional[str] = None,
shape: Optional[str] = None,
title: Optional[str] = None,
tooltip: Optional[List[str] | str] = None,
x_title: Optional[str] = None,
y_title: Optional[str] = None,
color_legend_title: Optional[str] = None,
size_legend_title: Optional[str] = None,
shape_legend_title: Optional[str] = None,
caption: Optional[str] = None,
label: Optional[str] = None,
show_label: bool = True,
visible: bool = True,
elem_id: Optional[str] = None,
):
self.x = x
self.y = y
self.color = color
self.size = size
self.shape = shape
self.tooltip = tooltip
self.title = title
self.x_title = x_title
self.y_title = y_title
self.color_legend_title = color_legend_title
self.size_legend_title = size_legend_title
self.shape_legend_title = shape_legend_title
self.caption = caption
self.value = None
if value is not None:
self.value = self.postprocess(value)
super().__init__(
value, label=label, show_label=show_label, visible=visible, elem_id=elem_id
)

def get_config(self):
config = super().get_config()
config["caption"] = self.caption
return config

def get_block_name(self) -> str:
return "plot"

@staticmethod
def update(
value: Optional[Any] = _Keywords.NO_VALUE,
x: Optional[str] = None,
y: Optional[str] = None,
color: Optional[str] = None,
size: Optional[str] = None,
shape: Optional[str] = None,
title: Optional[str] = None,
tooltip: Optional[List[str] | str] = None,
x_title: Optional[str] = None,
y_title: Optional[str] = None,
color_legend_title: Optional[str] = None,
size_legend_title: Optional[str] = None,
shape_legend_title: Optional[str] = None,
caption: Optional[str] = None,
label: Optional[str] = None,
show_label: Optional[bool] = None,
visible: Optional[bool] = None,
):
properties = [
x,
y,
color,
size,
shape,
title,
tooltip,
x_title,
y_title,
color_legend_title,
size_legend_title,
shape_legend_title,
]
if any(properties):
if value is _Keywords.NO_VALUE:
raise ValueError(
"In order to update plot properties the value parameter "
"must be provided. Please pass a value parameter to "
"gr.ScatterPlot.update."
)
if x is None or y is None:
raise ValueError(
"In order to update plot properties, the x and y axis data "
"must be specified. Please pass valid values for x an y to "
"gr.ScatterPlot.update."
)

chart = ScatterPlot.create_plot(value, *properties)
new_chart_str = {"type": "altair", "plot": chart.to_json(), "chart": "scatter"}

updated_config = {
"label": label,
"show_label": show_label,
"visible": visible,
"value": new_chart_str,
"caption": caption,
"__type__": "update",
}
return updated_config

@staticmethod
def create_plot(
value: pd.DataFrame,
x: str,
y: str,
color: Optional[str] = None,
size: Optional[str] = None,
shape: Optional[str] = None,
title: Optional[str] = None,
tooltip: Optional[List[str] | str] = None,
x_title: Optional[str] = None,
y_title: Optional[str] = None,
color_legend_title: Optional[str] = None,
size_legend_title: Optional[str] = None,
shape_legend_title: Optional[str] = None,
):

encodings = dict(
x=alt.X(x, title=x_title or x),
y=alt.Y(y, title=y_title or y),
)
properties = {}
if title:
properties["title"] = title
if color:
if is_numeric_dtype(value[color]):
domain = [value[color].min(), value[color].max()]
range_ = [0, 1]
type_ = "quantitative"
else:
domain = value[color].unique().tolist()
range_ = list(range(len(domain)))
type_ = "nominal"

encodings["color"] = {
"field": color,
"type": type_,
"legend": {"title": color_legend_title or color},
"scale": {"domain": domain, "range": range_},
}
if tooltip:
encodings["tooltip"] = tooltip
if size:
encodings["size"] = {
"field": size,
"type": "quantitative" if is_numeric_dtype(value[size]) else "nominal",
"legend": {"title": size_legend_title or size},
}
if shape:
encodings["shape"] = {
"field": shape,
"type": "quantitative" if is_numeric_dtype(value[shape]) else "nominal",
"legend": {"title": shape_legend_title or shape},
}

return (
alt.Chart(value)
.mark_point()
.encode(**encodings)
.properties(background="transparent", **properties)
.interactive()
)

def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None:
# if None or update
if y is None or isinstance(y, Dict):
return y
chart = self.create_plot(
value=y,
x=self.x,
y=self.y,
color=self.color,
size=self.size,
shape=self.shape,
title=self.title,
tooltip=self.tooltip,
x_title=self.x_title,
y_title=self.y_title,
color_legend_title=self.color_legend_title,
size_legend_title=self.size_legend_title,
shape_legend_title=self.size_legend_title,
)

return {"type": "altair", "plot": chart.to_json(), "chart": "scatter"}


@document("change")
Expand Down
1 change: 1 addition & 0 deletions scripts/copy_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def copy_all_demos(source_dir: str, dest_dir: str):
"kitchen_sink_random",
"matrix_transpose",
"model3D",
"native_plots",
"reset_components",
"reverse_audio",
"stt_or_tts",
Expand Down
Loading