From 5fbc4e4c1b3d21b30cc52ab4c278791699f46dc3 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Tue, 29 Nov 2022 14:20:27 -0500 Subject: [PATCH 01/30] Try clean install --- ui/pnpm-lock.yaml | 55 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/ui/pnpm-lock.yaml b/ui/pnpm-lock.yaml index a156ee1631c0..61088341eeb5 100644 --- a/ui/pnpm-lock.yaml +++ b/ui/pnpm-lock.yaml @@ -414,12 +414,21 @@ importers: svelte: 3.49.0 devDependencies: '@sveltejs/adapter-auto': 1.0.0-next.90 +<<<<<<< HEAD '@sveltejs/kit': 1.0.0-next.318_svelte@3.49.0 autoprefixer: 10.4.2_postcss@8.4.6 postcss: 8.4.6 postcss-load-config: 3.1.1 svelte-check: 2.4.1_onvlxjpnd23pr3hxbmout2wrjm svelte-preprocess: 4.10.2_2udzbozq3wemyrf2xz7puuv2zy +======= + '@sveltejs/kit': 1.0.0-next.318 + autoprefixer: 10.4.2_postcss@8.4.6 + postcss: 8.4.6 + postcss-load-config: 3.1.1 + svelte-check: 2.4.1_2y4otvh2n6klv6metqycpfiuzy + svelte-preprocess: 4.10.2_bw7ic75prjd4umr4fb55sbospu +>>>>>>> ad12c698 (Try clean install) tailwindcss: 3.0.23_autoprefixer@10.4.2 tslib: 2.3.1 typescript: 4.5.5 @@ -597,16 +606,15 @@ packages: import-meta-resolve: 2.2.0 dev: true - /@sveltejs/kit/1.0.0-next.318_svelte@3.49.0: + /@sveltejs/kit/1.0.0-next.318: resolution: {integrity: sha512-/M/XNvEqK71KCGro1xLuiUuklsMPe+G5DiVMs39tpfFIFhH4oCzAt+YBaIZDKORogGz3QDaYc5BV+eFv9E5cyw==} engines: {node: '>=14.13'} hasBin: true peerDependencies: svelte: ^3.44.0 dependencies: - '@sveltejs/vite-plugin-svelte': 1.0.0-next.44_svelte@3.49.0+vite@2.9.9 + '@sveltejs/vite-plugin-svelte': 1.0.0-next.44_vite@2.9.9 sade: 1.8.1 - svelte: 3.49.0 vite: 2.9.9 transitivePeerDependencies: - diff-match-patch @@ -639,7 +647,7 @@ packages: - supports-color dev: false - /@sveltejs/vite-plugin-svelte/1.0.0-next.44_svelte@3.49.0+vite@2.9.9: + /@sveltejs/vite-plugin-svelte/1.0.0-next.44_vite@2.9.9: resolution: {integrity: sha512-n+sssEWbzykPS447FmnNyU5GxEhrBPDVd0lxNZnxRGz9P6651LjjwAnISKr3CKgT9v8IybP8VD0n2i5XzbqExg==} engines: {node: ^14.13.1 || >= 16} peerDependencies: @@ -655,8 +663,7 @@ packages: deepmerge: 4.2.2 kleur: 4.1.4 magic-string: 0.26.1 - svelte: 3.49.0 - svelte-hmr: 0.14.11_svelte@3.49.0 + svelte-hmr: 0.14.11 vite: 2.9.9 transitivePeerDependencies: - supports-color @@ -3104,7 +3111,11 @@ packages: resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==} engines: {node: '>= 0.4'} +<<<<<<< HEAD /svelte-check/2.4.1_onvlxjpnd23pr3hxbmout2wrjm: +======= + /svelte-check/2.4.1_2y4otvh2n6klv6metqycpfiuzy: +>>>>>>> ad12c698 (Try clean install) resolution: {integrity: sha512-xhf3ShP5rnRwBokrgTBJ/0cO9QIc1DAVu1NWNRTfCDsDBNjGmkS3HgitgUadRuoMKj1+irZR/yHJ+Uqobnkbrw==} hasBin: true peerDependencies: @@ -3117,8 +3128,12 @@ packages: picocolors: 1.0.0 sade: 1.8.1 source-map: 0.7.3 +<<<<<<< HEAD svelte: 3.49.0 svelte-preprocess: 4.10.2_2udzbozq3wemyrf2xz7puuv2zy +======= + svelte-preprocess: 4.10.2_bw7ic75prjd4umr4fb55sbospu +>>>>>>> ad12c698 (Try clean install) typescript: 4.5.5 transitivePeerDependencies: - '@babel/core' @@ -3161,6 +3176,13 @@ packages: - sugarss dev: false + /svelte-hmr/0.14.11: + resolution: {integrity: sha512-R9CVfX6DXxW1Kn45Jtmx+yUe+sPhrbYSUp7TkzbW0jI5fVPn6lsNG9NEs5dFg5qRhFNAoVdRw5qQDLALNKhwbQ==} + engines: {node: ^12.20 || ^14.13.1 || >= 16} + peerDependencies: + svelte: '>=3.19.0' + dev: true + /svelte-hmr/0.14.11_svelte@3.49.0: resolution: {integrity: sha512-R9CVfX6DXxW1Kn45Jtmx+yUe+sPhrbYSUp7TkzbW0jI5fVPn6lsNG9NEs5dFg5qRhFNAoVdRw5qQDLALNKhwbQ==} engines: {node: ^12.20 || ^14.13.1 || >= 16} @@ -3168,6 +3190,21 @@ packages: svelte: '>=3.19.0' dependencies: svelte: 3.49.0 + dev: false + + /svelte-i18n/3.3.13: + resolution: {integrity: sha512-RQM+ys4+Y9ztH//tX22H1UL2cniLNmIR+N4xmYygV6QpQ6EyQvloZiENRew8XrVzfvJ8HaE8NU6/yurLkl7z3g==} + engines: {node: '>= 11.15.0'} + hasBin: true + peerDependencies: + svelte: ^3.25.1 + dependencies: + deepmerge: 4.2.2 + estree-walker: 2.0.2 + intl-messageformat: 9.11.4 + sade: 1.8.1 + tiny-glob: 0.2.9 + dev: false /svelte-i18n/3.3.13_svelte@3.49.0: resolution: {integrity: sha512-RQM+ys4+Y9ztH//tX22H1UL2cniLNmIR+N4xmYygV6QpQ6EyQvloZiENRew8XrVzfvJ8HaE8NU6/yurLkl7z3g==} @@ -3184,7 +3221,11 @@ packages: tiny-glob: 0.2.9 dev: false +<<<<<<< HEAD /svelte-preprocess/4.10.2_2udzbozq3wemyrf2xz7puuv2zy: +======= + /svelte-preprocess/4.10.2_bw7ic75prjd4umr4fb55sbospu: +>>>>>>> ad12c698 (Try clean install) resolution: {integrity: sha512-aPpkCreSo8EL/y8kJSa1trhiX0oyAtTjlNNM7BNjRAsMJ8Yy2LtqHt0zyd4pQPXt+D4PzbO3qTjjio3kwOxDlA==} engines: {node: '>= 9.11.2'} requiresBuild: true @@ -3233,7 +3274,6 @@ packages: postcss-load-config: 3.1.1 sorcery: 0.10.0 strip-indent: 3.0.0 - svelte: 3.49.0 typescript: 4.5.5 dev: true @@ -3308,6 +3348,7 @@ packages: /svelte/3.49.0: resolution: {integrity: sha512-+lmjic1pApJWDfPCpUUTc1m8azDqYCG1JN9YEngrx/hUyIcFJo6VZhj0A1Ai0wqoHcEIuQy+e9tk+4uDgdtsFA==} engines: {node: '>= 8'} + dev: false /sync-request/6.1.0: resolution: {integrity: sha512-8fjNkrNlNCrVc/av+Jn+xxqfCjYaBoHqCsDz6mt030UMxJGr+GSfCV1dQt2gRtlL63+VPidwDVLr7V2OcTSdRw==} From 3dd95fc80b25b41c11d0272d3c7d8ed10c94b935 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Tue, 29 Nov 2022 14:30:26 -0500 Subject: [PATCH 02/30] Resolve peer dependencies? --- ui/pnpm-lock.yaml | 55 ++++++----------------------------------------- 1 file changed, 7 insertions(+), 48 deletions(-) diff --git a/ui/pnpm-lock.yaml b/ui/pnpm-lock.yaml index 61088341eeb5..a156ee1631c0 100644 --- a/ui/pnpm-lock.yaml +++ b/ui/pnpm-lock.yaml @@ -414,21 +414,12 @@ importers: svelte: 3.49.0 devDependencies: '@sveltejs/adapter-auto': 1.0.0-next.90 -<<<<<<< HEAD '@sveltejs/kit': 1.0.0-next.318_svelte@3.49.0 autoprefixer: 10.4.2_postcss@8.4.6 postcss: 8.4.6 postcss-load-config: 3.1.1 svelte-check: 2.4.1_onvlxjpnd23pr3hxbmout2wrjm svelte-preprocess: 4.10.2_2udzbozq3wemyrf2xz7puuv2zy -======= - '@sveltejs/kit': 1.0.0-next.318 - autoprefixer: 10.4.2_postcss@8.4.6 - postcss: 8.4.6 - postcss-load-config: 3.1.1 - svelte-check: 2.4.1_2y4otvh2n6klv6metqycpfiuzy - svelte-preprocess: 4.10.2_bw7ic75prjd4umr4fb55sbospu ->>>>>>> ad12c698 (Try clean install) tailwindcss: 3.0.23_autoprefixer@10.4.2 tslib: 2.3.1 typescript: 4.5.5 @@ -606,15 +597,16 @@ packages: import-meta-resolve: 2.2.0 dev: true - /@sveltejs/kit/1.0.0-next.318: + /@sveltejs/kit/1.0.0-next.318_svelte@3.49.0: resolution: {integrity: sha512-/M/XNvEqK71KCGro1xLuiUuklsMPe+G5DiVMs39tpfFIFhH4oCzAt+YBaIZDKORogGz3QDaYc5BV+eFv9E5cyw==} engines: {node: '>=14.13'} hasBin: true peerDependencies: svelte: ^3.44.0 dependencies: - '@sveltejs/vite-plugin-svelte': 1.0.0-next.44_vite@2.9.9 + '@sveltejs/vite-plugin-svelte': 1.0.0-next.44_svelte@3.49.0+vite@2.9.9 sade: 1.8.1 + svelte: 3.49.0 vite: 2.9.9 transitivePeerDependencies: - diff-match-patch @@ -647,7 +639,7 @@ packages: - supports-color dev: false - /@sveltejs/vite-plugin-svelte/1.0.0-next.44_vite@2.9.9: + /@sveltejs/vite-plugin-svelte/1.0.0-next.44_svelte@3.49.0+vite@2.9.9: resolution: {integrity: sha512-n+sssEWbzykPS447FmnNyU5GxEhrBPDVd0lxNZnxRGz9P6651LjjwAnISKr3CKgT9v8IybP8VD0n2i5XzbqExg==} engines: {node: ^14.13.1 || >= 16} peerDependencies: @@ -663,7 +655,8 @@ packages: deepmerge: 4.2.2 kleur: 4.1.4 magic-string: 0.26.1 - svelte-hmr: 0.14.11 + svelte: 3.49.0 + svelte-hmr: 0.14.11_svelte@3.49.0 vite: 2.9.9 transitivePeerDependencies: - supports-color @@ -3111,11 +3104,7 @@ packages: resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==} engines: {node: '>= 0.4'} -<<<<<<< HEAD /svelte-check/2.4.1_onvlxjpnd23pr3hxbmout2wrjm: -======= - /svelte-check/2.4.1_2y4otvh2n6klv6metqycpfiuzy: ->>>>>>> ad12c698 (Try clean install) resolution: {integrity: sha512-xhf3ShP5rnRwBokrgTBJ/0cO9QIc1DAVu1NWNRTfCDsDBNjGmkS3HgitgUadRuoMKj1+irZR/yHJ+Uqobnkbrw==} hasBin: true peerDependencies: @@ -3128,12 +3117,8 @@ packages: picocolors: 1.0.0 sade: 1.8.1 source-map: 0.7.3 -<<<<<<< HEAD svelte: 3.49.0 svelte-preprocess: 4.10.2_2udzbozq3wemyrf2xz7puuv2zy -======= - svelte-preprocess: 4.10.2_bw7ic75prjd4umr4fb55sbospu ->>>>>>> ad12c698 (Try clean install) typescript: 4.5.5 transitivePeerDependencies: - '@babel/core' @@ -3176,13 +3161,6 @@ packages: - sugarss dev: false - /svelte-hmr/0.14.11: - resolution: {integrity: sha512-R9CVfX6DXxW1Kn45Jtmx+yUe+sPhrbYSUp7TkzbW0jI5fVPn6lsNG9NEs5dFg5qRhFNAoVdRw5qQDLALNKhwbQ==} - engines: {node: ^12.20 || ^14.13.1 || >= 16} - peerDependencies: - svelte: '>=3.19.0' - dev: true - /svelte-hmr/0.14.11_svelte@3.49.0: resolution: {integrity: sha512-R9CVfX6DXxW1Kn45Jtmx+yUe+sPhrbYSUp7TkzbW0jI5fVPn6lsNG9NEs5dFg5qRhFNAoVdRw5qQDLALNKhwbQ==} engines: {node: ^12.20 || ^14.13.1 || >= 16} @@ -3190,21 +3168,6 @@ packages: svelte: '>=3.19.0' dependencies: svelte: 3.49.0 - dev: false - - /svelte-i18n/3.3.13: - resolution: {integrity: sha512-RQM+ys4+Y9ztH//tX22H1UL2cniLNmIR+N4xmYygV6QpQ6EyQvloZiENRew8XrVzfvJ8HaE8NU6/yurLkl7z3g==} - engines: {node: '>= 11.15.0'} - hasBin: true - peerDependencies: - svelte: ^3.25.1 - dependencies: - deepmerge: 4.2.2 - estree-walker: 2.0.2 - intl-messageformat: 9.11.4 - sade: 1.8.1 - tiny-glob: 0.2.9 - dev: false /svelte-i18n/3.3.13_svelte@3.49.0: resolution: {integrity: sha512-RQM+ys4+Y9ztH//tX22H1UL2cniLNmIR+N4xmYygV6QpQ6EyQvloZiENRew8XrVzfvJ8HaE8NU6/yurLkl7z3g==} @@ -3221,11 +3184,7 @@ packages: tiny-glob: 0.2.9 dev: false -<<<<<<< HEAD /svelte-preprocess/4.10.2_2udzbozq3wemyrf2xz7puuv2zy: -======= - /svelte-preprocess/4.10.2_bw7ic75prjd4umr4fb55sbospu: ->>>>>>> ad12c698 (Try clean install) resolution: {integrity: sha512-aPpkCreSo8EL/y8kJSa1trhiX0oyAtTjlNNM7BNjRAsMJ8Yy2LtqHt0zyd4pQPXt+D4PzbO3qTjjio3kwOxDlA==} engines: {node: '>= 9.11.2'} requiresBuild: true @@ -3274,6 +3233,7 @@ packages: postcss-load-config: 3.1.1 sorcery: 0.10.0 strip-indent: 3.0.0 + svelte: 3.49.0 typescript: 4.5.5 dev: true @@ -3348,7 +3308,6 @@ packages: /svelte/3.49.0: resolution: {integrity: sha512-+lmjic1pApJWDfPCpUUTc1m8azDqYCG1JN9YEngrx/hUyIcFJo6VZhj0A1Ai0wqoHcEIuQy+e9tk+4uDgdtsFA==} engines: {node: '>= 8'} - dev: false /sync-request/6.1.0: resolution: {integrity: sha512-8fjNkrNlNCrVc/av+Jn+xxqfCjYaBoHqCsDz6mt030UMxJGr+GSfCV1dQt2gRtlL63+VPidwDVLr7V2OcTSdRw==} From 50b60065e5b1286d2df08345697e1863a716483e Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Tue, 29 Nov 2022 17:17:38 -0500 Subject: [PATCH 03/30] CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 493d9953d927..20659f6a30b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,7 +33,7 @@ demo.launch() By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2741](https://github.com/gradio-app/gradio/pull/2741) -### Set the background color of a Label component +### Set the background color of a Label component The `Label` component now accepts a `color` argument by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2736](https://github.com/gradio-app/gradio/pull/2736). The `color` argument should either be a valid css color name or hexadecimal string. From 1d56cdef9ddc0e57ad822b412b978e779bff904f Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 2 Dec 2022 09:22:36 -0500 Subject: [PATCH 04/30] Add outbreak_forcast notebook --- demo/outbreak_forecast/run.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demo/outbreak_forecast/run.ipynb b/demo/outbreak_forecast/run.ipynb index 816159e1e27b..8a59c8aba915 100644 --- a/demo/outbreak_forecast/run.ipynb +++ b/demo/outbreak_forecast/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} From bfa84f7ac0f8de48e9275ce513e2a2babb2b0a3c Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 2 Dec 2022 09:53:56 -0500 Subject: [PATCH 05/30] generate again --- demo/outbreak_forecast/run.ipynb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/demo/outbreak_forecast/run.ipynb b/demo/outbreak_forecast/run.ipynb index 8a59c8aba915..278fdfd31f26 100644 --- a/demo/outbreak_forecast/run.ipynb +++ b/demo/outbreak_forecast/run.ipynb @@ -1 +1,5 @@ +<<<<<<< HEAD {"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} +======= +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} +>>>>>>> 5f4d29de (generate again) From ee242ba26bdac5b93713435965bee4437bab09be Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Tue, 29 Nov 2022 17:17:38 -0500 Subject: [PATCH 06/30] CHANGELOG --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20659f6a30b4..0440d4911cc1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,6 +70,11 @@ demo.queue().launch() ![label_bg_color_update](https://user-images.githubusercontent.com/41651716/204400372-80e53857-f26f-4a38-a1ae-1acadff75e89.gif) +======= + + +By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2741](https://github.com/gradio-app/gradio/pull/2741) +>>>>>>> d457c6c1 (CHANGELOG) ## Bug Fixes: * Fixed issue where image thumbnails were not showing when an example directory was provided From 1f832c9380c0609d3217817340677457304d3554 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Tue, 29 Nov 2022 17:19:07 -0500 Subject: [PATCH 07/30] Add image to changelog --- CHANGELOG.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0440d4911cc1..5d34202593a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,12 +70,6 @@ demo.queue().launch() ![label_bg_color_update](https://user-images.githubusercontent.com/41651716/204400372-80e53857-f26f-4a38-a1ae-1acadff75e89.gif) -======= - - -By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2741](https://github.com/gradio-app/gradio/pull/2741) ->>>>>>> d457c6c1 (CHANGELOG) - ## Bug Fixes: * Fixed issue where image thumbnails were not showing when an example directory was provided by by [@abidlabs](https://github.com/abidlabs) in [PR 2745](https://github.com/gradio-app/gradio/pull/2745) From 18079f72e183eaae4cefcaf292e7cf6772f2138a Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Thu, 1 Dec 2022 17:50:28 -0500 Subject: [PATCH 08/30] Color palette --- gradio/__init__.py | 1 + gradio/components.py | 42 ++++++++++++++++++++++++++++++ ui/packages/plot/src/Plot.svelte | 44 +++++++++++++++++++++++++++++++- 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/gradio/__init__.py b/gradio/__init__.py index 3c76ebbc9ec5..b125955d2f89 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -46,6 +46,7 @@ Variable, Video, component, + ScatterPlot ) from gradio.examples import create_examples as Examples from gradio.exceptions import Error diff --git a/gradio/components.py b/gradio/components.py index b6b7938595c6..a3cf42da9da5 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3941,6 +3941,48 @@ def style(self): return self +@document("change", "clear") +class ScatterPlot(Plot): + + def __init__(self, value: pd.DataFrame, + x: str, + y: str, + color: Optional[str] = None, + label: Optional[str] = None, + show_label: bool = True, + visible: bool = True, + elem_id: Optional[str] = None,): + self.x = x + self.y = y + self.color = color + self.value = self.postprocess(value) + super().__init__(value, label=label, show_label=show_label, visible=visible, elem_id=elem_id) + + def get_block_name(self) -> str: + return "plot" + + def postprocess(self, y: pd.DataFrame | None) -> Dict[str, str] | None: + import altair as alt + + encodings = dict(x=self.x, y=self.y) + if self.color: + domain = y[self.color].unique().tolist() + encodings['color'] = { + 'field': self.color, + "type": "nominal", + "scale": { + "domain": domain, + "range": list(range(len(domain))) + } + } + + chart = alt.Chart(y).mark_point().\ + encode(**encodings).\ + properties(background='transparent') + + return {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} + + @document("change") class Markdown(IOComponent, Changeable, SimpleSerializable): """ diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index 774cc50648c7..45ed0da06d4e 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -1,17 +1,59 @@ - - + diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index 39ba8bbb7e6d..047d539bdd42 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -4,10 +4,8 @@ import { Plot as PlotIcon } from "@gradio/icons"; import { colors as color_palette, ordered_colors } from "@gradio/theme"; import { get_next_color } from "@gradio/utils"; - import { Vega } from "svelte-vega"; import tw_colors from "tailwindcss/colors"; - import { afterUpdate, onDestroy } from "svelte"; export let value; @@ -49,6 +47,10 @@ "labelFont": "monospace", "titleColor": darkmode ? tw_colors.slate['200'] : "black", "titleFont": 'monospace', + }, + "title": { + "color": darkmode ? tw_colors.slate['200'] : "black", + "titleFont": "monospace" } } if (spec['encoding']['color']) { From 80a00950963433fab96fba8ea41817905d30b601 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Mon, 5 Dec 2022 10:28:08 -0500 Subject: [PATCH 12/30] Clean up code a bit + quant scales --- gradio/components.py | 21 ++++++++++++++----- ui/packages/plot/src/Plot.svelte | 36 +++++++++----------------------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index 6ae596af6457..5e60811f94c4 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -20,6 +20,7 @@ from types import ModuleType from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +import altair as alt import matplotlib.figure import numpy as np import pandas as pd @@ -28,6 +29,7 @@ from ffmpy import FFmpeg from markdown_it import MarkdownIt from mdit_py_plugins.dollarmath import dollarmath_plugin +from pandas.api.types import is_numeric_dtype from gradio import media_data, processing_utils, utils from gradio.blocks import Block @@ -3956,6 +3958,7 @@ def __init__( tooltip: Optional[str] = None, x_title: Optional[str] = None, y_title: Optional[str] = None, + legend_title: Optional[str] = None, label: Optional[str] = None, show_label: bool = True, visible: bool = True, @@ -3968,6 +3971,7 @@ def __init__( self.title = title self.x_title = x_title self.y_title = y_title + self.legend_title = legend_title self.value = self.postprocess(value) super().__init__( value, label=label, show_label=show_label, visible=visible, elem_id=elem_id @@ -3977,8 +3981,6 @@ def get_block_name(self) -> str: return "plot" def postprocess(self, y: pd.DataFrame | None) -> Dict[str, str] | None: - import altair as alt - encodings = dict( x=alt.X(self.x, title=self.x_title or self.x), y=alt.Y(self.y, title=self.y_title or self.y), @@ -3987,11 +3989,20 @@ def postprocess(self, y: pd.DataFrame | None) -> Dict[str, str] | None: if self.title: properties["title"] = self.title if self.color: - domain = y[self.color].unique().tolist() + if is_numeric_dtype(y[self.color]): + domain = [y[self.color].min(), y[self.color].max()] + range_ = [0, 1] + type_ = "quantitative" + else: + domain = y[self.color].unique().tolist() + range_ = list(range(len(domain))) + type_ = "nominal" + encodings["color"] = { "field": self.color, - "type": "nominal", - "scale": {"domain": domain, "range": list(range(len(domain)))}, + "type": type_, + "legend": {"title": self.legend_title or self.color}, + "scale": {"domain": domain, "range": range_}, } if self.tooltip: encodings["tooltip"] = self.tooltip diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index 047d539bdd42..bcb69186d446 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -5,8 +5,8 @@ import { colors as color_palette, ordered_colors } from "@gradio/theme"; import { get_next_color } from "@gradio/utils"; import { Vega } from "svelte-vega"; - import tw_colors from "tailwindcss/colors"; import { afterUpdate, onDestroy } from "svelte"; + import {create_config} from "./utils" export let value; export let target; @@ -31,34 +31,18 @@ $: if(value && value['type'] == "altair") { spec = JSON.parse(value['plot']) - switch (value['chart'] || 'foo') { + const config = create_config(darkmode); + spec['config'] = config; + switch (value['chart'] || '') { case "scatter": - const config = { - "axis": { - "labelFont": 'monospace', - "labelColor": darkmode ? tw_colors.slate['200'] : "black", - "titleFont": 'monospace', - "titleColor": darkmode ? tw_colors.slate['200'] : "black", - "tickColor": "#aaa", - "gridColor": "#aaa" - }, - "legend": { - "labelColor": darkmode ? tw_colors.slate['200'] : "black", - "labelFont": "monospace", - "titleColor": darkmode ? tw_colors.slate['200'] : "black", - "titleFont": 'monospace', - }, - "title": { - "color": darkmode ? tw_colors.slate['200'] : "black", - "titleFont": "monospace" - } - } - if (spec['encoding']['color']) { - console.log(spec); + if (spec['encoding']['color'] && spec['encoding']['color']['type'] == 'nominal') { spec['encoding']['color']['scale']['range'] = spec['encoding']['color']['scale']['range'].map((e, i) => get_color(i)); } - spec['config'] = config; - console.log(spec); + else if (spec['encoding']['color'] && spec['encoding']['color']['type'] == 'quantitative') { + spec['encoding']['color']['scale']['range'] = ['#eff6ff', '#1e3a8a']; + spec['encoding']['color']['scale']['interpolate'] = "hsv"; + + } break; default: break; From c8f8a189634860de3c6e2f54a89c68e5f3dbb35b Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Mon, 5 Dec 2022 11:01:21 -0500 Subject: [PATCH 13/30] Add code --- gradio/components.py | 8 ++++++-- ui/packages/plot/src/Plot.svelte | 5 +++-- ui/packages/plot/src/utils.ts | 25 +++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 ui/packages/plot/src/utils.ts diff --git a/gradio/components.py b/gradio/components.py index 5e60811f94c4..dfe3afa2e118 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3950,9 +3950,9 @@ def style(self, container: Optional[bool] = None): class ScatterPlot(Plot): def __init__( self, - value: pd.DataFrame, x: str, y: str, + value: Optional[pd.DataFrame] = None, color: Optional[str] = None, title: Optional[str] = None, tooltip: Optional[str] = None, @@ -3972,7 +3972,9 @@ def __init__( self.x_title = x_title self.y_title = y_title self.legend_title = legend_title - self.value = self.postprocess(value) + self.value = None + if value is not None: + self.value = self.postprocess(value) super().__init__( value, label=label, show_label=show_label, visible=visible, elem_id=elem_id ) @@ -3981,6 +3983,8 @@ def get_block_name(self) -> str: return "plot" def postprocess(self, y: pd.DataFrame | None) -> Dict[str, str] | None: + if y is None: + return y encodings = dict( x=alt.X(self.x, title=self.x_title or self.x), y=alt.Y(self.y, title=self.y_title or self.y), diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index bcb69186d446..82a848ff9c73 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -33,6 +33,7 @@ spec = JSON.parse(value['plot']) const config = create_config(darkmode); spec['config'] = config; + console.log(spec); switch (value['chart'] || '') { case "scatter": if (spec['encoding']['color'] && spec['encoding']['color']['type'] == 'nominal') { @@ -40,8 +41,8 @@ } else if (spec['encoding']['color'] && spec['encoding']['color']['type'] == 'quantitative') { spec['encoding']['color']['scale']['range'] = ['#eff6ff', '#1e3a8a']; - spec['encoding']['color']['scale']['interpolate'] = "hsv"; - + spec['encoding']['color']['scale']['interpolate'] = "hsl"; + console.log(spec); } break; default: diff --git a/ui/packages/plot/src/utils.ts b/ui/packages/plot/src/utils.ts new file mode 100644 index 000000000000..db947a2868b4 --- /dev/null +++ b/ui/packages/plot/src/utils.ts @@ -0,0 +1,25 @@ +import type { Config as VegaConfig } from "vega"; +import tw_colors from "tailwindcss/colors"; + +export function create_config(darkmode: boolean): VegaConfig { + return { + axis: { + labelFont: "monospace", + labelColor: darkmode ? tw_colors.slate["200"] : "black", + titleFont: "monospace", + titleColor: darkmode ? tw_colors.slate["200"] : "black", + tickColor: "#aaa", + gridColor: "#aaa" + }, + legend: { + labelColor: darkmode ? tw_colors.slate["200"] : "black", + labelFont: "monospace", + titleColor: darkmode ? tw_colors.slate["200"] : "black", + titleFont: "monospace" + }, + title: { + color: darkmode ? tw_colors.slate["200"] : "black", + font: "monospace" + } + }; +} From f455f4dcbf44950fd1f74b9c0a1f27312a3da5ca Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Mon, 5 Dec 2022 15:36:18 -0500 Subject: [PATCH 14/30] Add size, shape + rename legend title --- gradio/components.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index dfe3afa2e118..be3d994f662b 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3954,11 +3954,15 @@ def __init__( y: str, value: Optional[pd.DataFrame] = None, color: Optional[str] = None, + size: Optional[str] = None, + shape: Optional[str] = None, title: Optional[str] = None, tooltip: Optional[str] = None, x_title: Optional[str] = None, y_title: Optional[str] = None, - legend_title: Optional[str] = None, + color_legend_title: Optional[str] = None, + size_legend_title: Optional[str] = None, + shape_legend_title: Optional[str] = None, label: Optional[str] = None, show_label: bool = True, visible: bool = True, @@ -3967,11 +3971,15 @@ def __init__( self.x = x self.y = y self.color = color + self.size = size + self.shape = shape self.tooltip = tooltip self.title = title self.x_title = x_title self.y_title = y_title - self.legend_title = legend_title + self.color_legend_title = color_legend_title + self.size_legend_title = size_legend_title + self.shape_legend_title = shape_legend_title self.value = None if value is not None: self.value = self.postprocess(value) @@ -4005,11 +4013,25 @@ def postprocess(self, y: pd.DataFrame | None) -> Dict[str, str] | None: encodings["color"] = { "field": self.color, "type": type_, - "legend": {"title": self.legend_title or self.color}, + "legend": {"title": self.color_legend_title or self.color}, "scale": {"domain": domain, "range": range_}, } if self.tooltip: encodings["tooltip"] = self.tooltip + if self.size: + encodings["size"] = { + "field": self.size, + "type": "quantitative" if is_numeric_dtype(y[self.size]) else "nominal", + "legend": {"title": self.size_legend_title or self.size}, + } + if self.shape: + encodings["shape"] = { + "field": self.shape, + "type": "quantitative" + if is_numeric_dtype(y[self.shape]) + else "nominal", + "legend": {"title": self.shape_legend_title or self.shape}, + } chart = ( alt.Chart(y) From f546f03d7bddc9da1d4b7a1471f300497483cf86 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Mon, 5 Dec 2022 15:39:10 -0500 Subject: [PATCH 15/30] Fix demo --- demo/outbreak_forecast/run.ipynb | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/demo/outbreak_forecast/run.ipynb b/demo/outbreak_forecast/run.ipynb index 278fdfd31f26..816159e1e27b 100644 --- a/demo/outbreak_forecast/run.ipynb +++ b/demo/outbreak_forecast/run.ipynb @@ -1,5 +1 @@ -<<<<<<< HEAD -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} -======= -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} ->>>>>>> 5f4d29de (generate again) +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file From 289f075b56ad70e6d93536cb26693809a0f40bf0 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 7 Dec 2022 11:08:22 -0500 Subject: [PATCH 16/30] Add update + demo --- demo/native_plots/requirements.txt | 1 + demo/native_plots/run.ipynb | 1 + demo/native_plots/run.py | 12 ++ demo/native_plots/scatter_plot_demo.py | 45 ++++++++ demo/scatterplot_component/requirements.txt | 1 + demo/scatterplot_component/run.ipynb | 1 + demo/scatterplot_component/run.py | 18 +++ gradio/components.py | 118 +++++++++++++++----- scripts/copy_demos.py | 1 + ui/packages/plot/src/Plot.svelte | 2 - 10 files changed, 170 insertions(+), 30 deletions(-) create mode 100644 demo/native_plots/requirements.txt create mode 100644 demo/native_plots/run.ipynb create mode 100644 demo/native_plots/run.py create mode 100644 demo/native_plots/scatter_plot_demo.py create mode 100644 demo/scatterplot_component/requirements.txt create mode 100644 demo/scatterplot_component/run.ipynb create mode 100644 demo/scatterplot_component/run.py diff --git a/demo/native_plots/requirements.txt b/demo/native_plots/requirements.txt new file mode 100644 index 000000000000..d1c8a7ae0396 --- /dev/null +++ b/demo/native_plots/requirements.txt @@ -0,0 +1 @@ +vega_datasets \ No newline at end of file diff --git a/demo/native_plots/run.ipynb b/demo/native_plots/run.ipynb new file mode 100644 index 000000000000..4d3070de8e8e --- /dev/null +++ b/demo/native_plots/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/native_plots/run.py b/demo/native_plots/run.py new file mode 100644 index 000000000000..a0194d2917a4 --- /dev/null +++ b/demo/native_plots/run.py @@ -0,0 +1,12 @@ +import gradio as gr + +from scatter_plot_demo import scatter_plot + + +with gr.Blocks() as demo: + with gr.Tabs(): + with gr.TabItem("Scatter Plot"): + scatter_plot.render() + +if __name__ == "__main__": + demo.launch() diff --git a/demo/native_plots/scatter_plot_demo.py b/demo/native_plots/scatter_plot_demo.py new file mode 100644 index 000000000000..c428ae32a998 --- /dev/null +++ b/demo/native_plots/scatter_plot_demo.py @@ -0,0 +1,45 @@ +import gradio as gr + +from vega_datasets import data + +cars = data.cars() +iris = data.iris() + + +def scatter_plot_fn(dataset): + if dataset == "iris": + return gr.ScatterPlot.update( + value=iris, + x="petalWidth", + y="petalLength", + color="species", + title="Iris Dataset", + color_legend_title="Species", + x_title="Petal Width", + y_title="Petal Length", + tooltip=["petalWidth", "petalLength", "species"], + ) + else: + return gr.ScatterPlot.update( + value=cars, + x="Horsepower", + y="Miles_per_Gallon", + color="Origin", + tooltip="Name", + title="Car Data", + y_title="Miles per Gallon", + color_legend_title="Origin of Car", + ) + + +with gr.Blocks() as scatter_plot: + with gr.Row(): + with gr.Column(): + dataset = gr.Dropdown(choices=["cars", "iris"], value="cars") + with gr.Column(): + plot = gr.ScatterPlot(show_label=False).style(container=True) + dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot) + scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot) + +if __name__ == "__main__": + scatter_plot.launch() diff --git a/demo/scatterplot_component/requirements.txt b/demo/scatterplot_component/requirements.txt new file mode 100644 index 000000000000..d1c8a7ae0396 --- /dev/null +++ b/demo/scatterplot_component/requirements.txt @@ -0,0 +1 @@ +vega_datasets \ No newline at end of file diff --git a/demo/scatterplot_component/run.ipynb b/demo/scatterplot_component/run.ipynb new file mode 100644 index 000000000000..75d7a60528ab --- /dev/null +++ b/demo/scatterplot_component/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: scatterplot_component"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from vega_datasets import data\n", "\n", "cars = data.cars()\n", "\n", "with gr.Blocks() as demo:\n", " gr.ScatterPlot(show_label=False,\n", " value=cars,\n", " x=\"Horsepower\",\n", " y=\"Miles_per_Gallon\",\n", " color=\"Origin\",\n", " tooltip=\"Name\",\n", " title=\"Car Data\",\n", " y_title=\"Miles per Gallon\",\n", " color_legend_title=\"Origin of Car\").style(container=False)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/scatterplot_component/run.py b/demo/scatterplot_component/run.py new file mode 100644 index 000000000000..4b005c639021 --- /dev/null +++ b/demo/scatterplot_component/run.py @@ -0,0 +1,18 @@ +import gradio as gr +from vega_datasets import data + +cars = data.cars() + +with gr.Blocks() as demo: + gr.ScatterPlot(show_label=False, + value=cars, + x="Horsepower", + y="Miles_per_Gallon", + color="Origin", + tooltip="Name", + title="Car Data", + y_title="Miles per Gallon", + color_legend_title="Origin of Car").style(container=False) + +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/gradio/components.py b/gradio/components.py index be3d994f662b..1ff521bb4434 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3950,14 +3950,14 @@ def style(self, container: Optional[bool] = None): class ScatterPlot(Plot): def __init__( self, - x: str, - y: str, + x: Optional[str] = None, + y: Optional[str] = None, value: Optional[pd.DataFrame] = None, color: Optional[str] = None, size: Optional[str] = None, shape: Optional[str] = None, title: Optional[str] = None, - tooltip: Optional[str] = None, + tooltip: Optional[List[str] | str] = None, x_title: Optional[str] = None, y_title: Optional[str] = None, color_legend_title: Optional[str] = None, @@ -3990,57 +3990,119 @@ def __init__( def get_block_name(self) -> str: return "plot" - def postprocess(self, y: pd.DataFrame | None) -> Dict[str, str] | None: - if y is None: - return y + @staticmethod + def update( + value: Optional[Any] = _Keywords.NO_VALUE, + x: Optional[str] = None, + y: Optional[str] = None, + color: Optional[str] = None, + size: Optional[str] = None, + shape: Optional[str] = None, + title: Optional[str] = None, + tooltip: Optional[List[str] | str] = None, + x_title: Optional[str] = None, + y_title: Optional[str] = None, + color_legend_title: Optional[str] = None, + size_legend_title: Optional[str] = None, + shape_legend_title: Optional[str] = None, + label: Optional[str] = None, + show_label: Optional[bool] = None, + visible: Optional[bool] = None, + ): + properties = [x, y, color, size, shape, title, tooltip, x_title, y_title, + color_legend_title, size_legend_title, shape_legend_title] + if any(properties): + if value is _Keywords.NO_VALUE: + raise ValueError("In order to update plot properties the value parameter " + "must be provided.") + + chart = ScatterPlot.create_plot(value, *properties) + new_chart_str = {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} + + updated_config = { + "label": label, + "show_label": show_label, + "visible": visible, + "value": new_chart_str, + "__type__": "update", + } + return updated_config + + @staticmethod + def create_plot( + value: pd.DataFrame, + x: str, + y: str, + color: Optional[str] = None, + size: Optional[str] = None, + shape: Optional[str] = None, + title: Optional[str] = None, + tooltip: Optional[List[str] | str] = None, + x_title: Optional[str] = None, + y_title: Optional[str] = None, + color_legend_title: Optional[str] = None, + size_legend_title: Optional[str] = None, + shape_legend_title: Optional[str] = None + ): + encodings = dict( - x=alt.X(self.x, title=self.x_title or self.x), - y=alt.Y(self.y, title=self.y_title or self.y), + x=alt.X(x, title=x_title or x), + y=alt.Y(y, title=y_title or y), ) properties = {} - if self.title: - properties["title"] = self.title - if self.color: - if is_numeric_dtype(y[self.color]): - domain = [y[self.color].min(), y[self.color].max()] + if title: + properties["title"] = title + if color: + if is_numeric_dtype(value[color]): + domain = [value[color].min(), value[color].max()] range_ = [0, 1] type_ = "quantitative" else: - domain = y[self.color].unique().tolist() + domain = value[color].unique().tolist() range_ = list(range(len(domain))) type_ = "nominal" encodings["color"] = { - "field": self.color, + "field": color, "type": type_, - "legend": {"title": self.color_legend_title or self.color}, + "legend": {"title": color_legend_title or color}, "scale": {"domain": domain, "range": range_}, } - if self.tooltip: - encodings["tooltip"] = self.tooltip - if self.size: + if tooltip: + encodings["tooltip"] = tooltip + if size: encodings["size"] = { - "field": self.size, - "type": "quantitative" if is_numeric_dtype(y[self.size]) else "nominal", - "legend": {"title": self.size_legend_title or self.size}, + "field": size, + "type": "quantitative" if is_numeric_dtype(value[size]) else "nominal", + "legend": {"title": size_legend_title or size}, } - if self.shape: + if shape: encodings["shape"] = { - "field": self.shape, + "field": shape, "type": "quantitative" - if is_numeric_dtype(y[self.shape]) + if is_numeric_dtype(value[shape]) else "nominal", - "legend": {"title": self.shape_legend_title or self.shape}, + "legend": {"title": shape_legend_title or shape}, } - chart = ( - alt.Chart(y) + return ( + alt.Chart(value) .mark_point() .encode(**encodings) .properties(background="transparent", **properties) .interactive() ) + def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None: + # if None or update + if y is None or isinstance(y, Dict): + return y + chart = self.create_plot(value=y, x=self.x, y=self.y, color=self.color, + size=self.size, shape=self.shape, title=self.title, + tooltip=self.tooltip, x_title=self.x_title, y_title=self.y_title, + color_legend_title=self.color_legend_title, size_legend_title=self.size_legend_title, + shape_legend_title=self.size_legend_title) + return {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} diff --git a/scripts/copy_demos.py b/scripts/copy_demos.py index efd59cd494c9..a33e88f50a77 100644 --- a/scripts/copy_demos.py +++ b/scripts/copy_demos.py @@ -28,6 +28,7 @@ def copy_all_demos(source_dir: str, dest_dir: str): "kitchen_sink_random", "matrix_transpose", "model3D", + "native_plots", "reset_components", "reverse_audio", "stt_or_tts", diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index 82a848ff9c73..8d1a53b668a6 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -33,7 +33,6 @@ spec = JSON.parse(value['plot']) const config = create_config(darkmode); spec['config'] = config; - console.log(spec); switch (value['chart'] || '') { case "scatter": if (spec['encoding']['color'] && spec['encoding']['color']['type'] == 'nominal') { @@ -42,7 +41,6 @@ else if (spec['encoding']['color'] && spec['encoding']['color']['type'] == 'quantitative') { spec['encoding']['color']['scale']['range'] = ['#eff6ff', '#1e3a8a']; spec['encoding']['color']['scale']['interpolate'] = "hsl"; - console.log(spec); } break; default: From c4006e898321ed30c1aa53bc146a94b7dbbd9a4d Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 7 Dec 2022 12:42:53 -0500 Subject: [PATCH 17/30] Handle darkmode better --- ui/packages/app/src/Blocks.svelte | 2 ++ ui/packages/app/src/Render.svelte | 3 ++ .../app/src/components/Plot/Plot.svelte | 3 +- ui/packages/app/src/main.ts | 30 ++++++++++++------- ui/packages/plot/src/Plot.svelte | 3 +- 5 files changed, 29 insertions(+), 12 deletions(-) diff --git a/ui/packages/app/src/Blocks.svelte b/ui/packages/app/src/Blocks.svelte index 57f1de65a783..902b9acadc87 100644 --- a/ui/packages/app/src/Blocks.svelte +++ b/ui/packages/app/src/Blocks.svelte @@ -40,6 +40,7 @@ export let show_api: boolean = true; export let control_page_title = false; export let app_mode: boolean; + export let theme: string; let loading_status = create_loading_status_store(); @@ -419,6 +420,7 @@ {instance_map} {root} {target} + {theme} on:mount={handle_mount} on:destroy={({ detail }) => handle_destroy(detail)} /> diff --git a/ui/packages/app/src/Render.svelte b/ui/packages/app/src/Render.svelte index 87adf5693d5b..2e070e24e8b3 100644 --- a/ui/packages/app/src/Render.svelte +++ b/ui/packages/app/src/Render.svelte @@ -14,6 +14,7 @@ export let has_modes: boolean | undefined; export let parent: string | null = null; export let target: HTMLElement; + export let theme: string; const dispatch = createEventDispatcher<{ mount: number; destroy: number }>(); @@ -56,6 +57,7 @@ on:prop_change={handle_prop_change} {target} {...props} + {theme} {root} > {#if children && children.length} @@ -70,6 +72,7 @@ children={_children} {dynamic_ids} {has_modes} + {theme} on:destroy on:mount /> diff --git a/ui/packages/app/src/components/Plot/Plot.svelte b/ui/packages/app/src/components/Plot/Plot.svelte index 4649ecf7c807..87326b6ed8fc 100644 --- a/ui/packages/app/src/components/Plot/Plot.svelte +++ b/ui/packages/app/src/components/Plot/Plot.svelte @@ -18,6 +18,7 @@ export let show_label: boolean; export let target: HTMLElement; export let style: Styles = {}; + export let theme: string; - + diff --git a/ui/packages/app/src/main.ts b/ui/packages/app/src/main.ts index 1d498f1cb82e..69b177a592a2 100644 --- a/ui/packages/app/src/main.ts +++ b/ui/packages/app/src/main.ts @@ -192,6 +192,7 @@ function create_custom_element() { root: ShadowRoot; wrapper: HTMLDivElement; _id: number; + theme: string; constructor() { super(); @@ -210,6 +211,7 @@ function create_custom_element() { this.wrapper.style.position = "relative"; this.wrapper.style.width = "100%"; this.wrapper.style.minHeight = "100vh"; + this.theme = "light"; window.__gradio_loader__[this._id] = new Loader({ target: this.wrapper, @@ -223,7 +225,7 @@ function create_custom_element() { this.root.append(this.wrapper); if (window.__gradio_mode__ !== "website") { - handle_darkmode(this.wrapper); + this.theme = handle_darkmode(this.wrapper); } } @@ -260,6 +262,7 @@ function create_custom_element() { const _autoscroll = autoscroll === "true" ? true : false; this.wrapper.style.minHeight = initial_height || "300px"; + console.log(this.theme); const config = await handle_config(this.root, source); if (config === null) { @@ -268,6 +271,7 @@ function create_custom_element() { mount_app( { ...config, + theme: this.theme, control_page_title: control_page_title && control_page_title === "true" ? true : false }, @@ -305,8 +309,9 @@ async function unscoped_mount() { mount_app({ ...config, control_page_title: true }, false, target, 0); } -function handle_darkmode(target: HTMLDivElement) { +function handle_darkmode(target: HTMLDivElement): string { let url = new URL(window.location.toString()); + let theme = "light"; const color_mode: "light" | "dark" | "system" | null = url.searchParams.get( "__theme" @@ -314,39 +319,44 @@ function handle_darkmode(target: HTMLDivElement) { if (color_mode !== null) { if (color_mode === "dark") { - darkmode(target); + theme = darkmode(target); } else if (color_mode === "system") { - use_system_theme(target); + theme = use_system_theme(target); } // light is default, so we don't need to do anything else } else if (url.searchParams.get("__dark-theme") === "true") { - darkmode(target); + theme = darkmode(target); } else { - use_system_theme(target); + theme = use_system_theme(target); } + return theme; } -function use_system_theme(target: HTMLDivElement) { - update_scheme(); +function use_system_theme(target: HTMLDivElement): string { + const theme = update_scheme(); window ?.matchMedia("(prefers-color-scheme: dark)") ?.addEventListener("change", update_scheme); function update_scheme() { + let theme = "light"; const is_dark = window?.matchMedia?.("(prefers-color-scheme: dark)").matches ?? null; if (is_dark) { - darkmode(target); + theme = darkmode(target); } + return theme; } + return theme; } -function darkmode(target: HTMLDivElement) { +function darkmode(target: HTMLDivElement): string { target.classList.add("dark"); if (app_mode) { document.body.style.backgroundColor = "rgb(11, 15, 25)"; // bg-gray-950 for scrolling outside the body } + return "dark"; } // dev mode or if inside an iframe diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index 8d1a53b668a6..5826b2c73597 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -12,6 +12,7 @@ export let target; let spec = null; export let colors: Array = []; + export let theme: string; function get_color(index: number) { let current_color = colors[index % colors.length]; @@ -27,7 +28,7 @@ } } - $: darkmode = window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches + $: darkmode = theme == "dark" $: if(value && value['type'] == "altair") { spec = JSON.parse(value['plot']) From 6dbf58cec9111e6318c8ccd3968d58a21c0115f5 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 7 Dec 2022 12:55:35 -0500 Subject: [PATCH 18/30] Try new font --- ui/packages/plot/src/utils.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ui/packages/plot/src/utils.ts b/ui/packages/plot/src/utils.ts index db947a2868b4..85e72b2cf37e 100644 --- a/ui/packages/plot/src/utils.ts +++ b/ui/packages/plot/src/utils.ts @@ -4,22 +4,22 @@ import tw_colors from "tailwindcss/colors"; export function create_config(darkmode: boolean): VegaConfig { return { axis: { - labelFont: "monospace", + labelFont: "font-sans", labelColor: darkmode ? tw_colors.slate["200"] : "black", - titleFont: "monospace", + titleFont: "font-sans", titleColor: darkmode ? tw_colors.slate["200"] : "black", tickColor: "#aaa", gridColor: "#aaa" }, legend: { labelColor: darkmode ? tw_colors.slate["200"] : "black", - labelFont: "monospace", + labelFont: "font-sans", titleColor: darkmode ? tw_colors.slate["200"] : "black", - titleFont: "monospace" + titleFont: "font-sans" }, title: { color: darkmode ? tw_colors.slate["200"] : "black", - font: "monospace" + font: "font-sans" } }; } From 7401b803b7277e0681fab40b9c2020910afe537f Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 7 Dec 2022 13:01:38 -0500 Subject: [PATCH 19/30] Use sans-serif --- ui/packages/plot/src/utils.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ui/packages/plot/src/utils.ts b/ui/packages/plot/src/utils.ts index 85e72b2cf37e..afc20499a7b0 100644 --- a/ui/packages/plot/src/utils.ts +++ b/ui/packages/plot/src/utils.ts @@ -4,22 +4,22 @@ import tw_colors from "tailwindcss/colors"; export function create_config(darkmode: boolean): VegaConfig { return { axis: { - labelFont: "font-sans", + labelFont: "sans-serif", labelColor: darkmode ? tw_colors.slate["200"] : "black", - titleFont: "font-sans", + titleFont: "sans-serif", titleColor: darkmode ? tw_colors.slate["200"] : "black", tickColor: "#aaa", gridColor: "#aaa" }, legend: { labelColor: darkmode ? tw_colors.slate["200"] : "black", - labelFont: "font-sans", + labelFont: "sans-serif", titleColor: darkmode ? tw_colors.slate["200"] : "black", - titleFont: "font-sans" + titleFont: "sans-serif" }, title: { color: darkmode ? tw_colors.slate["200"] : "black", - font: "font-sans" + font: "sans-serif" } }; } From 3e1f397d3f01af2bcd0783142cc201d0c3bcb947 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 7 Dec 2022 14:43:35 -0500 Subject: [PATCH 20/30] Add caption --- demo/native_plots/scatter_plot_demo.py | 2 + gradio/components.py | 64 +++++++++++++++---- test/test_blocks.py | 8 ++- .../app/src/components/Plot/Plot.svelte | 3 +- ui/packages/plot/src/Plot.svelte | 11 +++- 5 files changed, 69 insertions(+), 19 deletions(-) diff --git a/demo/native_plots/scatter_plot_demo.py b/demo/native_plots/scatter_plot_demo.py index c428ae32a998..223a0110df4f 100644 --- a/demo/native_plots/scatter_plot_demo.py +++ b/demo/native_plots/scatter_plot_demo.py @@ -18,6 +18,7 @@ def scatter_plot_fn(dataset): x_title="Petal Width", y_title="Petal Length", tooltip=["petalWidth", "petalLength", "species"], + caption="", ) else: return gr.ScatterPlot.update( @@ -29,6 +30,7 @@ def scatter_plot_fn(dataset): title="Car Data", y_title="Miles per Gallon", color_legend_title="Origin of Car", + caption="MPG vs Horsepower of various cars" ) diff --git a/gradio/components.py b/gradio/components.py index 1ff521bb4434..36bac2c34458 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3963,6 +3963,7 @@ def __init__( color_legend_title: Optional[str] = None, size_legend_title: Optional[str] = None, shape_legend_title: Optional[str] = None, + caption: Optional[str] = None, label: Optional[str] = None, show_label: bool = True, visible: bool = True, @@ -3980,6 +3981,7 @@ def __init__( self.color_legend_title = color_legend_title self.size_legend_title = size_legend_title self.shape_legend_title = shape_legend_title + self.caption = caption self.value = None if value is not None: self.value = self.postprocess(value) @@ -3987,6 +3989,11 @@ def __init__( value, label=label, show_label=show_label, visible=visible, elem_id=elem_id ) + def get_config(self): + config = super().get_config() + config["caption"] = self.caption + return config + def get_block_name(self) -> str: return "plot" @@ -4005,16 +4012,38 @@ def update( color_legend_title: Optional[str] = None, size_legend_title: Optional[str] = None, shape_legend_title: Optional[str] = None, + caption: Optional[str] = None, label: Optional[str] = None, show_label: Optional[bool] = None, visible: Optional[bool] = None, ): - properties = [x, y, color, size, shape, title, tooltip, x_title, y_title, - color_legend_title, size_legend_title, shape_legend_title] + properties = [ + x, + y, + color, + size, + shape, + title, + tooltip, + x_title, + y_title, + color_legend_title, + size_legend_title, + shape_legend_title, + ] if any(properties): if value is _Keywords.NO_VALUE: - raise ValueError("In order to update plot properties the value parameter " - "must be provided.") + raise ValueError( + "In order to update plot properties the value parameter " + "must be provided. Please pass a value parameter to " + "gr.ScatterPlot.update." + ) + if x is None and y is None: + raise ValueError( + "In order to update plot properties, the x and y axis data " + "must be specified. Please pass valid values for x an y to " + "gr.ScatterPlot.update." + ) chart = ScatterPlot.create_plot(value, *properties) new_chart_str = {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} @@ -4024,6 +4053,7 @@ def update( "show_label": show_label, "visible": visible, "value": new_chart_str, + "caption": caption, "__type__": "update", } return updated_config @@ -4042,7 +4072,7 @@ def create_plot( y_title: Optional[str] = None, color_legend_title: Optional[str] = None, size_legend_title: Optional[str] = None, - shape_legend_title: Optional[str] = None + shape_legend_title: Optional[str] = None, ): encodings = dict( @@ -4079,9 +4109,7 @@ def create_plot( if shape: encodings["shape"] = { "field": shape, - "type": "quantitative" - if is_numeric_dtype(value[shape]) - else "nominal", + "type": "quantitative" if is_numeric_dtype(value[shape]) else "nominal", "legend": {"title": shape_legend_title or shape}, } @@ -4097,11 +4125,21 @@ def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None: # if None or update if y is None or isinstance(y, Dict): return y - chart = self.create_plot(value=y, x=self.x, y=self.y, color=self.color, - size=self.size, shape=self.shape, title=self.title, - tooltip=self.tooltip, x_title=self.x_title, y_title=self.y_title, - color_legend_title=self.color_legend_title, size_legend_title=self.size_legend_title, - shape_legend_title=self.size_legend_title) + chart = self.create_plot( + value=y, + x=self.x, + y=self.y, + color=self.color, + size=self.size, + shape=self.shape, + title=self.title, + tooltip=self.tooltip, + x_title=self.x_title, + y_title=self.y_title, + color_legend_title=self.color_legend_title, + size_legend_title=self.size_legend_title, + shape_legend_title=self.size_legend_title, + ) return {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} diff --git a/test/test_blocks.py b/test/test_blocks.py index 818496ad65b8..6d5225d6b2c0 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -294,7 +294,9 @@ def test_slider_random_value_config(self): assert not any([dep["queue"] for dep in demo.config["dependencies"]]) def test_io_components_attach_load_events_when_value_is_fn(self, io_components): - io_components = [comp for comp in io_components if not (comp == gr.State)] + io_components = [ + comp for comp in io_components if comp not in [gr.State, gr.ScatterPlot] + ] interface = gr.Interface( lambda *args: None, inputs=[comp(value=lambda: None) for comp in io_components], @@ -307,7 +309,9 @@ def test_io_components_attach_load_events_when_value_is_fn(self, io_components): assert len(dependencies_on_load) == len(io_components) def test_blocks_do_not_filter_none_values_from_updates(self, io_components): - io_components = [c() for c in io_components if c not in [gr.State, gr.Button]] + io_components = [ + c() for c in io_components if c not in [gr.State, gr.Button, gr.ScatterPlot] + ] with gr.Blocks() as demo: for component in io_components: component.render() diff --git a/ui/packages/app/src/components/Plot/Plot.svelte b/ui/packages/app/src/components/Plot/Plot.svelte index 87326b6ed8fc..1672d62783ce 100644 --- a/ui/packages/app/src/components/Plot/Plot.svelte +++ b/ui/packages/app/src/components/Plot/Plot.svelte @@ -19,6 +19,7 @@ export let target: HTMLElement; export let style: Styles = {}; export let theme: string; + export let caption: string; - + diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index 5826b2c73597..a5792acf7e33 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -13,6 +13,7 @@ let spec = null; export let colors: Array = []; export let theme: string; + export let caption: string; function get_color(index: number) { let current_color = colors[index % colors.length]; @@ -49,7 +50,6 @@ } } - // Plotly let plotDiv; let plotlyGlobalStyle; @@ -152,8 +152,13 @@ {:else if value && value["type"] == "bokeh"}
{:else if value && value['type'] == "altair"} -
- +
+ + {#if caption} +
+ {caption} +
+ {/if}
{:else if value && value["type"] == "matplotlib"}
From 611e764cd244866cccf4a3fbef5935994b794757 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 7 Dec 2022 15:41:55 -0500 Subject: [PATCH 21/30] Changelog + tests --- CHANGELOG.md | 35 +++++++++++++++++++ test/test_components.py | 74 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d34202593a2..67574186b9fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,41 @@ ## New Features: +### Scatter plot component + +It is now possible to create a scatter plot without knowledge of a plotting library! + +The `gr.ScatterPlot` component accepts a pandas dataframe and some optional configuration parameters +and will automatically create a plot for you! + +This is the first of many native plotting components in Gradio! + +For an example of how to use `gr.ScatterPlot` see below: + +```python +import gradio as gr +from vega_datasets import data + +cars = data.cars() + +with gr.Blocks() as demo: + gr.ScatterPlot(show_label=False, + value=cars, + x="Horsepower", + y="Miles_per_Gallon", + color="Origin", + tooltip="Name", + title="Car Data", + y_title="Miles per Gallon", + color_legend_title="Origin of Car").style(container=False) + +demo.launch() +``` + +By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2764](https://github.com/gradio-app/gradio/pull/2764) + + + ### Support for altair plots The `Plot` component can now accept altair plots as values! diff --git a/test/test_components.py b/test/test_components.py index 1bb3676ab22a..393f22877502 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -20,6 +20,7 @@ import pandas as pd import PIL import pytest +import vega_datasets from scipy.io import wavfile import gradio as gr @@ -1847,3 +1848,76 @@ def test_dataset_calls_as_example(*mocks): ], ) assert all([m.called for m in mocks]) + + +cars = vega_datasets.data.cars() + + +class TestScatterPlot: + def test_get_config(self): + assert gr.ScatterPlot().get_config() == { + "caption": None, + "elem_id": None, + "interactive": None, + "label": None, + "name": "plot", + "root_url": None, + "show_label": True, + "style": {}, + "value": None, + "visible": True, + } + + def test_no_color(self): + plot = gr.ScatterPlot( + x="Horsepower", + y="Miles_per_Gallon", + tooltip="Name", + title="Car Data", + x_title="Horse", + ) + output = plot.postprocess(cars) + assert sorted(list(output.keys())) == ["chart", "plot", "type"] + config = json.loads(output["plot"]) + assert config["encoding"]["x"]["field"] == "Horsepower" + assert config["encoding"]["x"]["title"] == "Horse" + assert config["encoding"]["y"]["field"] == "Miles_per_Gallon" + assert config["title"] == "Car Data" + + def test_color_encoding(self): + plot = gr.ScatterPlot( + x="Horsepower", + y="Miles_per_Gallon", + tooltip="Name", + title="Car Data", + color="Origin", + ) + output = plot.postprocess(cars) + config = json.loads(output["plot"]) + assert config["encoding"]["color"]["field"] == "Origin" + assert config["encoding"]["color"]["scale"] == { + "domain": ["USA", "Europe", "Japan"], + "range": [0, 1, 2], + } + assert config["encoding"]["color"]["type"] == "nominal" + + def test_two_encodings(self): + plot = gr.ScatterPlot( + show_label=False, + title="Two encodings", + x="Horsepower", + y="Miles_per_Gallon", + color="Acceleration", + shape="Origin", + ) + output = plot.postprocess(cars) + config = json.loads(output["plot"]) + assert config["encoding"]["color"]["field"] == "Acceleration" + assert config["encoding"]["color"]["scale"] == { + "domain": [cars.Acceleration.min(), cars.Acceleration.max()], + "range": [0, 1], + } + assert config["encoding"]["color"]["type"] == "quantitative" + + assert config["encoding"]["shape"]["field"] == "Origin" + assert config["encoding"]["shape"]["type"] == "nominal" From 4ef2d9170279e8da76df606dddd0c24879cb998a Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 7 Dec 2022 15:48:57 -0500 Subject: [PATCH 22/30] More tests --- gradio/components.py | 2 +- test/test_components.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/gradio/components.py b/gradio/components.py index 36bac2c34458..4f006403a636 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -4038,7 +4038,7 @@ def update( "must be provided. Please pass a value parameter to " "gr.ScatterPlot.update." ) - if x is None and y is None: + if x is None or y is None: raise ValueError( "In order to update plot properties, the x and y axis data " "must be specified. Please pass valid values for x an y to " diff --git a/test/test_components.py b/test/test_components.py index 393f22877502..659567efc333 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -1921,3 +1921,20 @@ def test_two_encodings(self): assert config["encoding"]["shape"]["field"] == "Origin" assert config["encoding"]["shape"]["type"] == "nominal" + + def test_update(self): + output = gr.ScatterPlot.update(value=cars, x="Horsepower", y="Miles_per_Gallon") + postprocessed = gr.ScatterPlot().postprocess(output["value"]) + assert postprocessed == output["value"] + + def test_update_errors(self): + with pytest.raises( + ValueError, match="In order to update plot properties the value parameter" + ): + gr.ScatterPlot.update(x="foo", y="bar") + + with pytest.raises( + ValueError, + match="In order to update plot properties, the x and y axis data", + ): + gr.ScatterPlot.update(value=cars, x="foo") From d56eb1087bcf7fa096d781f88392714cb4980382 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Thu, 8 Dec 2022 11:51:25 -0500 Subject: [PATCH 23/30] Address comments --- ui/packages/app/src/main.ts | 1 - ui/packages/plot/src/Plot.svelte | 16 ++++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/ui/packages/app/src/main.ts b/ui/packages/app/src/main.ts index 69b177a592a2..f6ce362edd13 100644 --- a/ui/packages/app/src/main.ts +++ b/ui/packages/app/src/main.ts @@ -262,7 +262,6 @@ function create_custom_element() { const _autoscroll = autoscroll === "true" ? true : false; this.wrapper.style.minHeight = initial_height || "300px"; - console.log(this.theme); const config = await handle_config(this.root, source); if (config === null) { diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index a5792acf7e33..27a542ef8515 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -31,18 +31,18 @@ $: darkmode = theme == "dark" - $: if(value && value['type'] == "altair") { + $: if(value && value.type == "altair") { spec = JSON.parse(value['plot']) const config = create_config(darkmode); - spec['config'] = config; - switch (value['chart'] || '') { + spec.config = config; + switch (value.chart || '') { case "scatter": - if (spec['encoding']['color'] && spec['encoding']['color']['type'] == 'nominal') { - spec['encoding']['color']['scale']['range'] = spec['encoding']['color']['scale']['range'].map((e, i) => get_color(i)); + if (spec.encoding.color && spec.encoding.color.type == 'nominal') { + spec.encoding.color.scale.range = spec.encoding.color.scale.range.map((e, i) => get_color(i)); } - else if (spec['encoding']['color'] && spec['encoding']['color']['type'] == 'quantitative') { - spec['encoding']['color']['scale']['range'] = ['#eff6ff', '#1e3a8a']; - spec['encoding']['color']['scale']['interpolate'] = "hsl"; + else if (spec.encoding.color && spec.encoding.color.type == 'quantitative') { + spec.encoding.color.scale.range = ['#eff6ff', '#1e3a8a']; + spec.encoding.color.scale.range.interpolate = "hsl"; } break; default: From 08ffebbd673aba0546343e57f75508a287131174 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Thu, 8 Dec 2022 14:23:32 -0500 Subject: [PATCH 24/30] Make caption fontsize smaller and enable interactivity --- gradio/components.py | 21 +++++++++++++++++---- test/test_components.py | 18 +++++++++++++++++- ui/packages/plot/src/Plot.svelte | 2 +- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index 4f006403a636..f5f9f3db3172 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3964,6 +3964,7 @@ def __init__( size_legend_title: Optional[str] = None, shape_legend_title: Optional[str] = None, caption: Optional[str] = None, + interactive: Optional[bool] = True, label: Optional[str] = None, show_label: bool = True, visible: bool = True, @@ -3982,11 +3983,17 @@ def __init__( self.size_legend_title = size_legend_title self.shape_legend_title = shape_legend_title self.caption = caption + self.interactive = interactive self.value = None if value is not None: self.value = self.postprocess(value) super().__init__( - value, label=label, show_label=show_label, visible=visible, elem_id=elem_id + value, + label=label, + show_label=show_label, + visible=visible, + elem_id=elem_id, + interactive=interactive, ) def get_config(self): @@ -4012,6 +4019,7 @@ def update( color_legend_title: Optional[str] = None, size_legend_title: Optional[str] = None, shape_legend_title: Optional[str] = None, + interactive: Optional[bool] = True, caption: Optional[str] = None, label: Optional[str] = None, show_label: Optional[bool] = None, @@ -4030,6 +4038,7 @@ def update( color_legend_title, size_legend_title, shape_legend_title, + interactive, ] if any(properties): if value is _Keywords.NO_VALUE: @@ -4073,6 +4082,7 @@ def create_plot( color_legend_title: Optional[str] = None, size_legend_title: Optional[str] = None, shape_legend_title: Optional[str] = None, + interactive: Optional[bool] = True, ): encodings = dict( @@ -4112,14 +4122,16 @@ def create_plot( "type": "quantitative" if is_numeric_dtype(value[shape]) else "nominal", "legend": {"title": shape_legend_title or shape}, } - - return ( + chart = ( alt.Chart(value) .mark_point() .encode(**encodings) .properties(background="transparent", **properties) - .interactive() ) + if interactive: + chart = chart.interactive() + + return chart def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None: # if None or update @@ -4139,6 +4151,7 @@ def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None: color_legend_title=self.color_legend_title, size_legend_title=self.size_legend_title, shape_legend_title=self.size_legend_title, + interactive=self.interactive, ) return {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} diff --git a/test/test_components.py b/test/test_components.py index 659567efc333..54d5f0612afd 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -1858,7 +1858,7 @@ def test_get_config(self): assert gr.ScatterPlot().get_config() == { "caption": None, "elem_id": None, - "interactive": None, + "interactive": True, "label": None, "name": "plot", "root_url": None, @@ -1882,8 +1882,24 @@ def test_no_color(self): assert config["encoding"]["x"]["field"] == "Horsepower" assert config["encoding"]["x"]["title"] == "Horse" assert config["encoding"]["y"]["field"] == "Miles_per_Gallon" + assert config["selection"] == { + "selector001": { + "bind": "scales", + "encodings": ["x", "y"], + "type": "interval", + } + } assert config["title"] == "Car Data" + def test_no_interactive(self): + plot = gr.ScatterPlot( + x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False + ) + output = plot.postprocess(cars) + assert sorted(list(output.keys())) == ["chart", "plot", "type"] + config = json.loads(output["plot"]) + assert "selection" not in config + def test_color_encoding(self): plot = gr.ScatterPlot( x="Horsepower", diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index 27a542ef8515..e6cc174e8775 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -155,7 +155,7 @@
{#if caption} -
+
{caption}
{/if} From 130ee756b498473f4c5fe727adcd6f92c52677b9 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Thu, 8 Dec 2022 15:43:23 -0500 Subject: [PATCH 25/30] Add docstrings + add height + width --- gradio/components.py | 92 +++++++++++++++++++++++++++++++++++++---- test/test_components.py | 19 ++++++++- 2 files changed, 101 insertions(+), 10 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index f5f9f3db3172..2d89cd45b009 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3948,6 +3948,15 @@ def style(self, container: Optional[bool] = None): @document("change", "clear") class ScatterPlot(Plot): + """ + Create a scatter plot. + + Preprocessing: this component does *not* accept input. + Postprocessing: expects a pandas dataframe with the data to plot. + + Demos: native_plots + """ + def __init__( self, x: Optional[str] = None, @@ -3963,6 +3972,8 @@ def __init__( color_legend_title: Optional[str] = None, size_legend_title: Optional[str] = None, shape_legend_title: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, caption: Optional[str] = None, interactive: Optional[bool] = True, label: Optional[str] = None, @@ -3970,6 +3981,30 @@ def __init__( visible: bool = True, elem_id: Optional[str] = None, ): + """ + Parameters: + x: Column corresponding to the x axis. + y: Column corresponding to the y axis. + value: The pandas dataframe containing the data to display in a scatter plot. + color: The column to determine the point color. If the column contains numeric data, gradio will interpolate the column data so that small values correspond to light colors and large values correspond to dark values. + size: The column used to determine the point size. Should contain numeric data so that gradio can map the data to the point size. + shape: The column used to determine the point shape. Should contain categorical data. Gradio will map each unique value to a different shape. + title: The title to display on top of the chart. + tooltip: The column (or list of columns) to display on the tooltip when a user hovers a point on the plot. + x_title: The title given to the x axis. By default, uses the value of the x parameter. + y_title: The title given to the y axis. By default, uses the value of the y parameter. + color_legend_title: The title given to the color legend. By default, uses the value of color parameter. + size_legend_title: The title given to the size legend. By default, uses the value of the size parameter. + shape_legend_title: The title given to the shape legend. By default, uses the value of the shape parameter. + height: The height of the plot in pixels. + width: The width of the plot in pixels. + caption: The (optional) caption to display below the plot. + interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad. + label: The (optional) label to display on the top left corner of the plot. + show_label: Whether the label should be displayed. + visible: Whether the plot should be visible. + elem_id: Unique id used for custom css targetting. + """ self.x = x self.y = y self.color = color @@ -3983,7 +4018,9 @@ def __init__( self.size_legend_title = size_legend_title self.shape_legend_title = shape_legend_title self.caption = caption - self.interactive = interactive + self.interactive_chart = interactive + self.width = width + self.height = height self.value = None if value is not None: self.value = self.postprocess(value) @@ -3993,7 +4030,6 @@ def __init__( show_label=show_label, visible=visible, elem_id=elem_id, - interactive=interactive, ) def get_config(self): @@ -4019,12 +4055,40 @@ def update( color_legend_title: Optional[str] = None, size_legend_title: Optional[str] = None, shape_legend_title: Optional[str] = None, - interactive: Optional[bool] = True, + height: Optional[int] = None, + width: Optional[int] = None, + interactive: Optional[bool] = None, caption: Optional[str] = None, label: Optional[str] = None, show_label: Optional[bool] = None, visible: Optional[bool] = None, ): + """Update an existing plot component. + + If updating any of the plot properties (color, size, etc) the value, x, and y parameters must be specified. + + Parameters: + value: The pandas dataframe containing the data to display in a scatter plot. + x: Column corresponding to the x axis. + y: Column corresponding to the y axis. + color: The column to determine the point color. If the column contains numeric data, gradio will interpolate the column data so that small values correspond to light colors and large values correspond to dark values. + size: The column used to determine the point size. Should contain numeric data so that gradio can map the data to the point size. + shape: The column used to determine the point shape. Should contain categorical data. Gradio will map each unique value to a different shape. + title: The title to display on top of the chart. + tooltip: The column (or list of columns) to display on the tooltip when a user hovers a point on the plot. + x_title: The title given to the x axis. By default, uses the value of the x parameter. + y_title: The title given to the y axis. By default, uses the value of the y parameter. + color_legend_title: The title given to the color legend. By default, uses the value of color parameter. + size_legend_title: The title given to the size legend. By default, uses the value of the size parameter. + shape_legend_title: The title given to the shape legend. By default, uses the value of the shape parameter. + height: The height of the plot in pixels. + width: The width of the plot in pixels. + caption: The (optional) caption to display below the plot. + interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad. + label: The (optional) label to display in the top left corner of the plot. + show_label: Whether the label should be displayed. + visible: Whether the plot should be visible. + """ properties = [ x, y, @@ -4039,6 +4103,8 @@ def update( size_legend_title, shape_legend_title, interactive, + height, + width, ] if any(properties): if value is _Keywords.NO_VALUE: @@ -4053,15 +4119,14 @@ def update( "must be specified. Please pass valid values for x an y to " "gr.ScatterPlot.update." ) - - chart = ScatterPlot.create_plot(value, *properties) - new_chart_str = {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} + chart = ScatterPlot.create_plot(value, *properties) + value = {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} updated_config = { "label": label, "show_label": show_label, "visible": visible, - "value": new_chart_str, + "value": value, "caption": caption, "__type__": "update", } @@ -4082,9 +4147,12 @@ def create_plot( color_legend_title: Optional[str] = None, size_legend_title: Optional[str] = None, shape_legend_title: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, interactive: Optional[bool] = True, ): - + """Helper for creating the scatter plot.""" + interactive = True if interactive is None else interactive encodings = dict( x=alt.X(x, title=x_title or x), y=alt.Y(y, title=y_title or y), @@ -4092,6 +4160,10 @@ def create_plot( properties = {} if title: properties["title"] = title + if height: + properties["height"] = height + if width: + properties["width"] = width if color: if is_numeric_dtype(value[color]): domain = [value[color].min(), value[color].max()] @@ -4151,7 +4223,9 @@ def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None: color_legend_title=self.color_legend_title, size_legend_title=self.size_legend_title, shape_legend_title=self.size_legend_title, - interactive=self.interactive, + interactive=self.interactive_chart, + height=self.height, + width=self.width, ) return {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} diff --git a/test/test_components.py b/test/test_components.py index 54d5f0612afd..74d6ed5034e9 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -1858,7 +1858,7 @@ def test_get_config(self): assert gr.ScatterPlot().get_config() == { "caption": None, "elem_id": None, - "interactive": True, + "interactive": None, "label": None, "name": "plot", "root_url": None, @@ -1890,6 +1890,8 @@ def test_no_color(self): } } assert config["title"] == "Car Data" + assert "height" not in config + assert "width" not in config def test_no_interactive(self): plot = gr.ScatterPlot( @@ -1900,6 +1902,16 @@ def test_no_interactive(self): config = json.loads(output["plot"]) assert "selection" not in config + def test_height_width(self): + plot = gr.ScatterPlot( + x="Horsepower", y="Miles_per_Gallon", height=100, width=200 + ) + output = plot.postprocess(cars) + assert sorted(list(output.keys())) == ["chart", "plot", "type"] + config = json.loads(output["plot"]) + assert config["height"] == 100 + assert config["width"] == 200 + def test_color_encoding(self): plot = gr.ScatterPlot( x="Horsepower", @@ -1943,6 +1955,11 @@ def test_update(self): postprocessed = gr.ScatterPlot().postprocess(output["value"]) assert postprocessed == output["value"] + def test_update_visibility(self): + output = gr.ScatterPlot.update(visible=False) + assert not output["visible"] + assert output["value"] is gr.components._Keywords.NO_VALUE + def test_update_errors(self): with pytest.raises( ValueError, match="In order to update plot properties the value parameter" From 5dc41e2bb82364883cbe18cb5f0e76278ff54127 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 9 Dec 2022 09:00:00 -0500 Subject: [PATCH 26/30] Use normal font weight --- ui/packages/plot/src/utils.ts | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/ui/packages/plot/src/utils.ts b/ui/packages/plot/src/utils.ts index afc20499a7b0..fc0457d59a03 100644 --- a/ui/packages/plot/src/utils.ts +++ b/ui/packages/plot/src/utils.ts @@ -5,21 +5,26 @@ export function create_config(darkmode: boolean): VegaConfig { return { axis: { labelFont: "sans-serif", - labelColor: darkmode ? tw_colors.slate["200"] : "black", + labelColor: darkmode ? tw_colors.slate["200"] : tw_colors.gray[900], titleFont: "sans-serif", - titleColor: darkmode ? tw_colors.slate["200"] : "black", + titleColor: darkmode ? tw_colors.slate["200"] : tw_colors.gray[900], tickColor: "#aaa", - gridColor: "#aaa" + gridColor: "#aaa", + titleFontWeight: "normal", + labelFontWeight: "normal" }, legend: { - labelColor: darkmode ? tw_colors.slate["200"] : "black", + labelColor: darkmode ? tw_colors.slate["200"] : tw_colors.gray[900], labelFont: "sans-serif", - titleColor: darkmode ? tw_colors.slate["200"] : "black", - titleFont: "sans-serif" + titleColor: darkmode ? tw_colors.slate["200"] : tw_colors.gray[900], + titleFont: "sans-serif", + titleFontWeight: "normal", + labelFontWeight: "normal" }, title: { - color: darkmode ? tw_colors.slate["200"] : "black", - font: "sans-serif" + color: darkmode ? tw_colors.slate["200"] : tw_colors.gray[900], + font: "sans-serif", + fontWeight: "normal" } }; } From ec01da3d636e6ca37e09094c82376e5e1c7b2a7d Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Fri, 9 Dec 2022 09:28:18 -0500 Subject: [PATCH 27/30] Make last values keyword only Co-authored-by: Abubakar Abid --- gradio/components.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gradio/components.py b/gradio/components.py index 99d664817a86..1cadc320840d 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3987,9 +3987,10 @@ class ScatterPlot(Plot): def __init__( self, + value: Optional[pd.DataFrame] = None, x: Optional[str] = None, y: Optional[str] = None, - value: Optional[pd.DataFrame] = None, + * color: Optional[str] = None, size: Optional[str] = None, shape: Optional[str] = None, From 8146fda47ecf0025cd4a1262469a72e7f333909b Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 9 Dec 2022 09:29:46 -0500 Subject: [PATCH 28/30] Fix typo --- gradio/components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradio/components.py b/gradio/components.py index 1cadc320840d..6bf6932f88e5 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3990,7 +3990,7 @@ def __init__( value: Optional[pd.DataFrame] = None, x: Optional[str] = None, y: Optional[str] = None, - * + *, color: Optional[str] = None, size: Optional[str] = None, shape: Optional[str] = None, From 920cd925649cda50e8200232b43e65feb15cb605 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 9 Dec 2022 09:39:55 -0500 Subject: [PATCH 29/30] Accept value as fn --- gradio/components.py | 12 ++++++------ test/test_components.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/gradio/components.py b/gradio/components.py index 6bf6932f88e5..452a5a5dac43 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3987,7 +3987,7 @@ class ScatterPlot(Plot): def __init__( self, - value: Optional[pd.DataFrame] = None, + value: Optional[pd.DataFrame | Callable] = None, x: Optional[str] = None, y: Optional[str] = None, *, @@ -4012,9 +4012,9 @@ def __init__( ): """ Parameters: + value: The pandas dataframe containing the data to display in a scatter plot. x: Column corresponding to the x axis. y: Column corresponding to the y axis. - value: The pandas dataframe containing the data to display in a scatter plot. color: The column to determine the point color. If the column contains numeric data, gradio will interpolate the column data so that small values correspond to light colors and large values correspond to dark values. size: The column used to determine the point size. Should contain numeric data so that gradio can map the data to the point size. shape: The column used to determine the point shape. Should contain categorical data. Gradio will map each unique value to a different shape. @@ -4050,11 +4050,11 @@ def __init__( self.interactive_chart = interactive self.width = width self.height = height - self.value = None - if value is not None: - self.value = self.postprocess(value) + # self.value = None + # if value is not None: + # self.value = self.postprocess(value) super().__init__( - value, + value=value, label=label, show_label=show_label, visible=visible, diff --git a/test/test_components.py b/test/test_components.py index dbec6a24c2a4..447e26b72008 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -2016,3 +2016,13 @@ def test_update_errors(self): match="In order to update plot properties, the x and y axis data", ): gr.ScatterPlot.update(value=cars, x="foo") + + def test_scatterplot_accepts_fn_as_value(self): + plot = gr.ScatterPlot( + value=lambda: cars.sample(frac=0.1, replace=False), + x="Horsepower", + y="Miles_per_Gallon", + color="Origin", + ) + assert isinstance(plot.value, dict) + assert isinstance(plot.value["plot"], str) From f2cb2e9f3048daa713624c313d757cb97a26bad0 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 9 Dec 2022 09:43:41 -0500 Subject: [PATCH 30/30] reword changelog a bit --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 25f2b36ced5d..b6084c7db141 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### Scatter plot component -It is now possible to create a scatter plot without knowledge of a plotting library! +It is now possible to create a scatter plot natively in Gradio! The `gr.ScatterPlot` component accepts a pandas dataframe and some optional configuration parameters and will automatically create a plot for you!