Skip to content

Commit

Permalink
Extend color legend size for scatter charts for matplotlib (#286)
Browse files Browse the repository at this point in the history
* rec series

* rec

* add changes back

* fix tests

* rm

* change

* set plotting_style

* nice

* black

* rm

* change rec

* reuse export

* change size of plot

* rm

* ref

* fixed

* reform

* enlarge font

* enlarge font

Co-authored-by: Caitlyn Chen <caitlynachen@berkeley.edu>
  • Loading branch information
caitlynachen and Caitlyn Chen committed Mar 16, 2021
1 parent 0748949 commit 182a4b7
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions lux/vislib/matplotlib/ScatterChart.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,20 @@ def initialize_chart(self):
x_pts = df[x_attr.attribute]
y_pts = df[y_attr.attribute]

set_fig_code = ""
plot_code = ""

color_attr = self.vis.get_attr_by_channel("color")
if len(color_attr) == 1:
self.fig, self.ax = matplotlib_setup(6, 5)
color_attr_name = color_attr[0].attribute
color_attr_type = color_attr[0].data_type
colors = df[color_attr_name].values
plot_code += f"colors = df['{color_attr_name}'].values\n"
unique = list(set(colors))
vals = [unique.index(i) for i in colors]
if color_attr_type == "quantitative":
self.fig, self.ax = matplotlib_setup(7, 5)
set_fig_code = "fig, ax = plt.subplots(7, 5)\n"
self.ax.scatter(x_pts, y_pts, c=vals, cmap="Blues", alpha=0.5)
plot_code += f"ax.scatter(x_pts, y_pts, c={vals}, cmap='Blues', alpha=0.5)\n"
my_cmap = plt.cm.get_cmap("Blues")
Expand All @@ -82,10 +84,29 @@ def initialize_chart(self):
plot_code += f"cbar = plt.colorbar(sm, label='{color_attr_name}')\n"
plot_code += f"cbar.outline.set_linewidth(0)\n"
else:
scatter = self.ax.scatter(x_pts, y_pts, c=vals, cmap="Set1")
plot_code += f"scatter = ax.scatter(x_pts, y_pts, c={vals}, cmap='Set1')\n"
if len(unique) >= 16:
unique = unique[:16]

maxlen = 0
for i in range(len(unique)):
unique[i] = str(unique[i])
if len(unique[i]) > 26:
unique[i] = unique[i][:26] + "..."
if len(unique[i]) > maxlen:
maxlen = len(unique[i])
if maxlen > 20:
self.fig, self.ax = matplotlib_setup(9, 5)
set_fig_code = "fig, ax = plt.subplots(9, 5)\n"
else:
self.fig, self.ax = matplotlib_setup(7, 5)
set_fig_code = "fig, ax = plt.subplots(7, 5)\n"

cmap = "Set1"
if len(unique) > 9:
cmap = "tab20c"
scatter = self.ax.scatter(x_pts, y_pts, c=vals, cmap=cmap)
plot_code += f"scatter = ax.scatter(x_pts, y_pts, c={vals}, cmap={cmap})\n"

unique = [str(i) for i in unique]
leg = self.ax.legend(
handles=scatter.legend_elements(num=range(0, len(unique)))[0],
labels=unique,
Expand All @@ -95,6 +116,7 @@ def initialize_chart(self):
loc="upper left",
ncol=1,
frameon=False,
fontsize="13",
)
scatter.set_alpha(0.5)
plot_code += f"""ax.legend(
Expand All @@ -105,23 +127,24 @@ def initialize_chart(self):
bbox_to_anchor=(1.05, 1),
loc='upper left',
ncol=1,
frameon=False,)\n"""
frameon=False,
fontsize='13')\n"""
plot_code += "scatter.set_alpha(0.5)\n"
else:
set_fig_code = "fig, ax = plt.subplots(4.5, 4)\n"
self.ax.scatter(x_pts, y_pts, alpha=0.5)
plot_code += f"ax.scatter(x_pts, y_pts, alpha=0.5)\n"

self.ax.set_xlabel(x_attr_abv)
self.ax.set_ylabel(y_attr_abv)
self.ax.set_xlabel(x_attr_abv, fontsize="15")
self.ax.set_ylabel(y_attr_abv, fontsize="15")

self.code += "import numpy as np\n"
self.code += "from math import nan\n"
self.code += "from matplotlib.cm import ScalarMappable\n"

self.code += f"fig, ax = plt.subplots()\n"
self.code += set_fig_code
self.code += f"x_pts = df['{x_attr.attribute}']\n"
self.code += f"y_pts = df['{y_attr.attribute}']\n"

self.code += plot_code
self.code += f"ax.set_xlabel('{x_attr_abv}')\n"
self.code += f"ax.set_ylabel('{y_attr_abv}')\n"
self.code += f"ax.set_xlabel('{x_attr_abv}', fontsize='15')\n"
self.code += f"ax.set_ylabel('{y_attr_abv}', fontsize='15')\n"

0 comments on commit 182a4b7

Please sign in to comment.