Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure that there are enough colors to match the score index in visua… #3560

Merged
merged 1 commit into from
Aug 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions ludwig/utils/visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def compare_classifiers_plot(
width = 0.8 / num_metrics if num_metrics > 1 else 0.4
ticks = np.arange(len(scores[0]))

colors = plt.get_cmap("tab10").colors
if num_metrics <= 10:
jeffkinnison marked this conversation as resolved.
Show resolved Hide resolved
colors = plt.get_cmap("tab10").colors
else:
colors = plt.get_cmap("tab20").colors
Comment on lines +159 to +162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I want to see if we can work together to find a better solution. The problem I see here is, what if we have 30 categories? This would break once again and we'd have to add another conditional block.

Can you give something like this a try?

import colorsys

def generate_color_palette(num_colors):
    base_colors = [(0.0, 0.5, 0.7),  # Blue
                   (0.3, 0.6, 0.2),  # Green
                   (0.9, 0.4, 0.1)]  # Orange
    
    palette = []
    for base_color in base_colors:
        for i in range(num_colors):
            hue = (base_color[0] + i * 0.1) % 1.0
            saturation = base_color[1]
            lightness = base_color[2]
            rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
            palette.append(rgb)
    
    return palette

This should create good interpolation between colors and create something perceptually distinctive, but need to actually give it a try.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran this locally and this is what I see:

>>> plt.get_cmap("tab10").colors
((0.12156862745098039, 0.4666666666666667, 0.7058823529411765), (1.0, 0.4980392156862745, 0.054901960784313725), (0.17254901960784313, 0.6274509803921569, 0.17254901960784313), (0.8392156862745098, 0.15294117647058825, 0.1568627450980392), (0.5803921568627451, 0.403921568627451, 0.7411764705882353), (0.5490196078431373, 0.33725490196078434, 0.29411764705882354), (0.8901960784313725, 0.4666666666666667, 0.7607843137254902), (0.4980392156862745, 0.4980392156862745, 0.4980392156862745), (0.7372549019607844, 0.7411764705882353, 0.13333333333333333), (0.09019607843137255, 0.7450980392156863, 0.8117647058823529))
>>> generate_color_palette(10)
[(0.85, 0.5499999999999999, 0.5499999999999999), (0.85, 0.73, 0.5499999999999999), (0.7899999999999999, 0.85, 0.5499999999999999), (0.6099999999999999, 0.85, 0.5499999999999999), (0.5499999999999999, 0.85, 0.67), (0.5499999999999999, 0.8499999999999999, 0.85), (0.5499999999999999, 0.6699999999999997, 0.85), (0.6100000000000001, 0.5499999999999999, 0.85), (0.7899999999999999, 0.5499999999999999, 0.85), (0.85, 0.5499999999999999, 0.73), (0.12799999999999997, 0.32000000000000006, 0.07999999999999996), (0.07999999999999996, 0.32000000000000006, 0.17600000000000007), (0.07999999999999996, 0.32, 0.32000000000000006), (0.07999999999999996, 0.17599999999999982, 0.32000000000000006), (0.1279999999999998, 0.07999999999999996, 0.32000000000000006), (0.272, 0.07999999999999996, 0.32000000000000006), (0.32000000000000006, 0.07999999999999996, 0.22399999999999967), (0.32000000000000006, 0.07999999999999996, 0.07999999999999996), (0.32000000000000006, 0.22400000000000014, 0.07999999999999996), (0.2720000000000002, 0.32000000000000006, 0.07999999999999996), (0.13999999999999999, 0.060000000000000026, 0.108), (0.13999999999999999, 0.060000000000000026, 0.060000000000000026), (0.13999999999999999, 0.10800000000000004, 0.060000000000000026), (0.12399999999999993, 0.13999999999999999, 0.060000000000000026), (0.07600000000000001, 0.13999999999999999, 0.060000000000000026), (0.060000000000000026, 0.13999999999999999, 0.09199999999999997), (0.060000000000000026, 0.13999999999999996, 0.13999999999999999), (0.060000000000000026, 0.09199999999999994, 0.13999999999999999), (0.07600000000000007, 0.060000000000000026, 0.13999999999999999), (0.12399999999999999, 0.060000000000000026, 0.13999999999999999)]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great suggestion. Are there any concerns about potentially having inconsistent viz colors between different models? (same question for tab10/tab20 switching)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I considered the same thing @arnavgarg1 but thought from index not found to 20 was improvement. It is not a full solution however. If we are going to look for a proper solution instead of a hotfix then perhaps using the built-in normalization from matplotlib would be best? Something like this https://matplotlib.org/stable/gallery/images_contours_and_fields/colormap_normalizations.html#sphx-glr-gallery-images-contours-and-fields-colormap-normalizations-py

I'm fine with either one of the changes. Just let me know.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thelinuxkid Thanks for the prompt response - I guess the only challenge with this is that normalization may not ensure enough perceptual difference.

For now, if it helps to unblock you, I'm happy to merge your change in but would love to follow-up with another fix for this!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justinxzhao any thoughts on this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Going with what @thelinuxkid SGTM! Let's merge it.

if adaptive:
maximum = max(max(score) for score in scores)
else:
Expand Down Expand Up @@ -211,8 +214,14 @@ def compare_classifiers_line_plot(
filename=None,
callbacks=None,
):
assert len(scores) > 0

sns.set_style("whitegrid")
colors = plt.get_cmap("tab10").colors

if len(scores) <= 10:
colors = plt.get_cmap("tab10").colors
else:
colors = plt.get_cmap("tab20").colors

fig, ax = plt.subplots()

Expand Down Expand Up @@ -267,7 +276,10 @@ def compare_classifiers_multiclass_multimetric_plot(
width = 0.9 / len(scores)
ticks = np.arange(len(scores[0]))

colors = plt.get_cmap("tab10").colors
if len(scores) <= 10:
colors = plt.get_cmap("tab10").colors
else:
colors = plt.get_cmap("tab20").colors
ax.set_xlabel("class")
ax.set_xticks(ticks + width)
if labels is not None:
Expand Down
Loading