Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend color legend size for scatter charts for matplotlib #286

Merged
merged 24 commits into from
Mar 16, 2021
45 changes: 34 additions & 11 deletions lux/vislib/matplotlib/ScatterChart.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,20 @@ def initialize_chart(self):
x_pts = df[x_attr.attribute]
y_pts = df[y_attr.attribute]

set_fig_code = ""
plot_code = ""

color_attr = self.vis.get_attr_by_channel("color")
if len(color_attr) == 1:
self.fig, self.ax = matplotlib_setup(6, 5)
color_attr_name = color_attr[0].attribute
color_attr_type = color_attr[0].data_type
colors = df[color_attr_name].values
plot_code += f"colors = df['{color_attr_name}'].values\n"
unique = list(set(colors))
vals = [unique.index(i) for i in colors]
if color_attr_type == "quantitative":
self.fig, self.ax = matplotlib_setup(7, 5)
caitlynachen marked this conversation as resolved.
Show resolved Hide resolved
set_fig_code = "fig, ax = plt.subplots(7, 5)\n"
self.ax.scatter(x_pts, y_pts, c=vals, cmap="Blues", alpha=0.5)
plot_code += f"ax.scatter(x_pts, y_pts, c={vals}, cmap='Blues', alpha=0.5)\n"
my_cmap = plt.cm.get_cmap("Blues")
Expand All @@ -82,10 +84,29 @@ def initialize_chart(self):
plot_code += f"cbar = plt.colorbar(sm, label='{color_attr_name}')\n"
plot_code += f"cbar.outline.set_linewidth(0)\n"
else:
scatter = self.ax.scatter(x_pts, y_pts, c=vals, cmap="Set1")
plot_code += f"scatter = ax.scatter(x_pts, y_pts, c={vals}, cmap='Set1')\n"
if len(unique) >= 16:
unique = unique[:16]

maxlen = 0
for i in range(len(unique)):
unique[i] = str(unique[i])
if len(unique[i]) > 26:
unique[i] = unique[i][:26] + "..."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not showing up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you mean by this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the abbreviation "..." is not showing up in the legend, see comment above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

o hmm i think it was working on my end
109049161-c787d900-768c-11eb-95d0-44363d65fabc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you seeing something differently

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm able to see it.

Copy link
Member

@dorisjlee dorisjlee Mar 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

Oops sorry, I meant when the ... vertically (not horizontally) when there are too many color categories.

In the Brand legend example, there are only 16 colors displayed in the legend but there is 29 different Brand in the dataset.

In this example, you would have ... under the amc legend

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh this is a good point, a user might be confused why a specific category is not showing up, so something in the legend would be nice to address this. Another idea rather than using the '...' (which seems a bit off to me) could be to use multiple columns. Try maybe something like this: https://stackoverflow.com/questions/58344791/how-to-handle-categorical-legends-with-too-many-entries-to-match-a-fixed-chart-h

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh the "..." is not for the length of the legend, but rather the individual categories. I checked and Altair also didn't have a "..." at the end of the list, so didn't include it in this PR. If you want, maybe I can tackle that issue separately

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an issue #295 for this new change

if len(unique[i]) > maxlen:
maxlen = len(unique[i])
if maxlen > 20:
self.fig, self.ax = matplotlib_setup(9, 5)
set_fig_code = "fig, ax = plt.subplots(9, 5)\n"
else:
self.fig, self.ax = matplotlib_setup(7, 5)
set_fig_code = "fig, ax = plt.subplots(7, 5)\n"

cmap = "Set1"
if len(unique) > 9:
cmap = "tab20c"
scatter = self.ax.scatter(x_pts, y_pts, c=vals, cmap=cmap)
plot_code += f"scatter = ax.scatter(x_pts, y_pts, c={vals}, cmap={cmap})\n"

unique = [str(i) for i in unique]
leg = self.ax.legend(
handles=scatter.legend_elements(num=range(0, len(unique)))[0],
labels=unique,
Expand All @@ -95,6 +116,7 @@ def initialize_chart(self):
loc="upper left",
ncol=1,
frameon=False,
fontsize="13",
)
scatter.set_alpha(0.5)
plot_code += f"""ax.legend(
Expand All @@ -105,23 +127,24 @@ def initialize_chart(self):
bbox_to_anchor=(1.05, 1),
loc='upper left',
ncol=1,
frameon=False,)\n"""
frameon=False,
fontsize='13')\n"""
plot_code += "scatter.set_alpha(0.5)\n"
else:
set_fig_code = "fig, ax = plt.subplots(4.5, 4)\n"
self.ax.scatter(x_pts, y_pts, alpha=0.5)
plot_code += f"ax.scatter(x_pts, y_pts, alpha=0.5)\n"

self.ax.set_xlabel(x_attr_abv)
self.ax.set_ylabel(y_attr_abv)
self.ax.set_xlabel(x_attr_abv, fontsize="15")
self.ax.set_ylabel(y_attr_abv, fontsize="15")

self.code += "import numpy as np\n"
self.code += "from math import nan\n"
self.code += "from matplotlib.cm import ScalarMappable\n"

self.code += f"fig, ax = plt.subplots()\n"
self.code += set_fig_code
self.code += f"x_pts = df['{x_attr.attribute}']\n"
self.code += f"y_pts = df['{y_attr.attribute}']\n"

self.code += plot_code
self.code += f"ax.set_xlabel('{x_attr_abv}')\n"
self.code += f"ax.set_ylabel('{y_attr_abv}')\n"
self.code += f"ax.set_xlabel('{x_attr_abv}', fontsize='15')\n"
self.code += f"ax.set_ylabel('{y_attr_abv}', fontsize='15')\n"