Units keyword fails for pandas string datatype when calling lineplot #2797

Closed
OlgerSiebinga opened this issue May 13, 2022 · 2 comments

@OlgerSiebinga

I have encountered a problem when calling lineplot with the units keyword argument. It happens when the data is a pandas DataFrame and the units column has the pandas string dtype (pd.StringDtype()). A minimal example that reproduces the error is below. I am using the following versions on Windows 10:

Python 3.8
Seaborn 0.11.2
Numpy 1.22.3
Pandas 1.4.2

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


if __name__ == '__main__':
    units = np.random.randint(0, 5, 500).astype(str)
    x = np.random.random(500) * 100
    y = np.random.random(500) * 20 - 10

    data = pd.DataFrame({'units': units,
                         'x': x,
                         'y': y})

    data['units'] = data['units'].astype(pd.StringDtype())

    sns.lineplot(data=data, x='x', y='y', units='units', hue='units', estimator=None)
    plt.show()

Running this code raises the following error:

Traceback (most recent call last):
  File "C:/Users/OlgerSiebinga/PycharmProjects/seaborn_bug/main.py", line 18, in <module>
    sns.lineplot(data=data, x='x', y='y', units='units', hue='units', estimator=None)
  File "C:\Users\OlgerSiebinga\PycharmProjects\seaborn_bug\venv\lib\site-packages\seaborn\_decorators.py", line 46, in inner_f
    return f(**kwargs)
  File "C:\Users\OlgerSiebinga\PycharmProjects\seaborn_bug\venv\lib\site-packages\seaborn\relational.py", line 710, in lineplot
    p.plot(ax, kwargs)
  File "C:\Users\OlgerSiebinga\PycharmProjects\seaborn_bug\venv\lib\site-packages\seaborn\relational.py", line 527, in plot
    ax.plot(x[rows], y[rows], **kws)
IndexError: arrays used as indices must be of integer (or boolean) type

I think the problem is on line 526 of relational.py: when u has dtype pd.StringDtype() and u_i is a Python str, the comparison u == u_i produces a nullable boolean Series, and np.asarray(u == u_i) turns it into an array of dtype object. NumPy does not accept an object array as an index, hence the IndexError raised from line 527. A possible fix would be to change line 526 to np.asarray(u == u_i, dtype=bool). That fixes the issue for me, but I'm not sure whether it breaks anything else, as I did not run any further tests. (Sorry, I didn't have time to fork the repository and open a pull request.)
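The mechanism can be reproduced without seaborn. The sketch below is only an illustration, not seaborn's code: the values are made up, and obj_rows is built explicitly to mimic the object-dtype array that older pandas versions (such as the 1.x releases in this report) return from np.asarray on a nullable boolean Series. Newer pandas releases may already return a plain bool array there, which is why the sketch does not rely on that behavior.

```python
import numpy as np
import pandas as pd

# Made-up units stored with the pandas string dtype, as in the report
units = pd.Series(["0", "1", "0", "2"], dtype=pd.StringDtype())
values = np.array([10.0, 20.0, 30.0, 40.0])

# Comparing a string-dtype Series gives a nullable boolean Series
mask = units == "0"
print(mask.dtype)

# Mimic what np.asarray(mask) returned on pandas 1.x: an object array.
# NumPy refuses object arrays as indices and raises IndexError.
obj_rows = np.asarray(mask).astype(object)
try:
    values[obj_rows]
except IndexError as err:
    print("IndexError:", err)

# The proposed fix: ask for a real NumPy bool array explicitly
rows = np.asarray(mask, dtype=bool)
print(values[rows])  # [10. 30.]
```

One caveat worth checking with this fix: converting to dtype=bool raises an error if the mask contains missing values (pd.NA), which would happen if the units column itself had missing entries. The data in the report has none.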

I also found a workaround for the time being, in case other people run into the same issue: the problem goes away when you cast the units column of the DataFrame to str or object (e.g., change line 16 in the example above to data['units'] = data['units'].astype(str)).
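A minimal sketch of the workaround, using a small made-up frame instead of the random data above. It casts with astype(object) rather than astype(str); this is an assumption on the safe side, since in recent pandas releases astype(str) may map to pandas' own string dtype rather than to object. After the cast, the comparison yields an ordinary NumPy bool mask.

```python
import numpy as np
import pandas as pd

# Small made-up frame: units stored with the pandas string dtype
data = pd.DataFrame({
    "units": pd.Series(["0", "1", "0", "2"], dtype=pd.StringDtype()),
    "x": [1.0, 2.0, 3.0, 4.0],
    "y": [0.5, -0.5, 1.5, -1.5],
})

# Workaround: cast the units column to plain object dtype before plotting
data["units"] = data["units"].astype(object)

# Comparing an object column gives a NumPy-backed bool Series,
# so np.asarray(...) yields a mask that NumPy accepts as an index
mask = np.asarray(data["units"] == "0")
print(mask.dtype)  # bool
print(data["x"].to_numpy()[mask])  # [1. 3.]
```

With the column cast this way, sns.lineplot(data=data, x='x', y='y', units='units', hue='units', estimator=None) can be called as in the report.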

This should give the expected outcome, a figure like this:
[Figure: expected output of the example]

@mwaskom
Owner

mwaskom commented May 13, 2022

Thank you for the reproducible example. I can reproduce the error on 0.11.2 with slightly different dependency versions (numpy 1.22.0, pandas 1.3.5, matplotlib 3.5.1), but it runs fine for me on the development version of seaborn. I don't recall addressing this specifically, but the relevant code does look different: it does a pandas groupby over the unit variable instead of NumPy boolean indexing. So this may be moot?
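For illustration, here is a rough sketch of the groupby idea mentioned above. It is not the actual code in seaborn's development version: the frame is made up, and collecting coordinates into the hypothetical segments dict stands in for drawing one line per unit. Because groupby selects each unit's rows itself, no comparison mask ever has to be converted into a NumPy array.

```python
import pandas as pd

# Made-up frame with units stored as the pandas string dtype
data = pd.DataFrame({
    "units": pd.Series(["0", "1", "0", "2"], dtype=pd.StringDtype()),
    "x": [1.0, 2.0, 3.0, 4.0],
    "y": [0.5, -0.5, 1.5, -1.5],
})

# One sub-frame per unit; pandas handles the string dtype internally
segments = {}
for unit, sub in data.groupby("units", sort=True):
    # A plotting routine would draw each sub-frame as its own line,
    # e.g. ax.plot(sub["x"], sub["y"]); here we only collect coordinates.
    segments[unit] = list(zip(sub["x"].tolist(), sub["y"].tolist()))

print(segments)
```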

@OlgerSiebinga
Author

Hi @mwaskom! Thanks for checking this. Moving to the development version fixed it for me as well, so I'll use that version until it is released and I can upgrade. Thanks!

mwaskom closed this as completed May 17, 2022