diff --git a/lux/vislib/matplotlib/ScatterChart.py b/lux/vislib/matplotlib/ScatterChart.py index cd605f93..cb038ee5 100644 --- a/lux/vislib/matplotlib/ScatterChart.py +++ b/lux/vislib/matplotlib/ScatterChart.py @@ -53,11 +53,11 @@ def initialize_chart(self): x_pts = df[x_attr.attribute] y_pts = df[y_attr.attribute] + set_fig_code = "" plot_code = "" color_attr = self.vis.get_attr_by_channel("color") if len(color_attr) == 1: - self.fig, self.ax = matplotlib_setup(6, 5) color_attr_name = color_attr[0].attribute color_attr_type = color_attr[0].data_type colors = df[color_attr_name].values @@ -65,6 +65,8 @@ def initialize_chart(self): unique = list(set(colors)) vals = [unique.index(i) for i in colors] if color_attr_type == "quantitative": + self.fig, self.ax = matplotlib_setup(7, 5) + set_fig_code = "fig, ax = plt.subplots(7, 5)\n" self.ax.scatter(x_pts, y_pts, c=vals, cmap="Blues", alpha=0.5) plot_code += f"ax.scatter(x_pts, y_pts, c={vals}, cmap='Blues', alpha=0.5)\n" my_cmap = plt.cm.get_cmap("Blues") @@ -82,10 +84,29 @@ def initialize_chart(self): plot_code += f"cbar = plt.colorbar(sm, label='{color_attr_name}')\n" plot_code += f"cbar.outline.set_linewidth(0)\n" else: - scatter = self.ax.scatter(x_pts, y_pts, c=vals, cmap="Set1") - plot_code += f"scatter = ax.scatter(x_pts, y_pts, c={vals}, cmap='Set1')\n" + if len(unique) >= 16: + unique = unique[:16] + + maxlen = 0 + for i in range(len(unique)): + unique[i] = str(unique[i]) + if len(unique[i]) > 26: + unique[i] = unique[i][:26] + "..." + if len(unique[i]) > maxlen: + maxlen = len(unique[i]) + if maxlen > 20: + self.fig, self.ax = matplotlib_setup(9, 5) + set_fig_code = "fig, ax = plt.subplots(9, 5)\n" + else: + self.fig, self.ax = matplotlib_setup(7, 5) + set_fig_code = "fig, ax = plt.subplots(7, 5)\n" + + cmap = "Set1" + if len(unique) > 9: + cmap = "tab20c" + scatter = self.ax.scatter(x_pts, y_pts, c=vals, cmap=cmap) + plot_code += f"scatter = ax.scatter(x_pts, y_pts, c={vals}, cmap={cmap})\n" - unique = [str(i) for i in unique] leg = self.ax.legend( handles=scatter.legend_elements(num=range(0, len(unique)))[0], labels=unique, @@ -95,6 +116,7 @@ def initialize_chart(self): loc="upper left", ncol=1, frameon=False, + fontsize="13", ) scatter.set_alpha(0.5) plot_code += f"""ax.legend( @@ -105,23 +127,24 @@ def initialize_chart(self): bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1, - frameon=False,)\n""" + frameon=False, + fontsize='13')\n""" plot_code += "scatter.set_alpha(0.5)\n" else: + set_fig_code = "fig, ax = plt.subplots(4.5, 4)\n" self.ax.scatter(x_pts, y_pts, alpha=0.5) plot_code += f"ax.scatter(x_pts, y_pts, alpha=0.5)\n" - - self.ax.set_xlabel(x_attr_abv) - self.ax.set_ylabel(y_attr_abv) + self.ax.set_xlabel(x_attr_abv, fontsize="15") + self.ax.set_ylabel(y_attr_abv, fontsize="15") self.code += "import numpy as np\n" self.code += "from math import nan\n" self.code += "from matplotlib.cm import ScalarMappable\n" - self.code += f"fig, ax = plt.subplots()\n" + self.code += set_fig_code self.code += f"x_pts = df['{x_attr.attribute}']\n" self.code += f"y_pts = df['{y_attr.attribute}']\n" self.code += plot_code - self.code += f"ax.set_xlabel('{x_attr_abv}')\n" - self.code += f"ax.set_ylabel('{y_attr_abv}')\n" + self.code += f"ax.set_xlabel('{x_attr_abv}', fontsize='15')\n" + self.code += f"ax.set_ylabel('{y_attr_abv}', fontsize='15')\n"