sns.lmplot x_estimator / logistic estimation speed #347

Closed
tomwallis opened this Issue Oct 29, 2014 · 6 comments

tomwallis commented Oct 29, 2014

I'm a long-time ggplot2 (R) user switching over to Python. Seaborn is lovely, but I've noticed that, at least for me, its x_estimator summaries and logistic regression fits are very slow compared to R.

You can see more details on my blog here, but the short story is: plotting a summary of binomial data, with summarised x values, some logistic regression lines, and a facet_wrap by subject (col='subject') takes 3 seconds in R and 2 minutes in Seaborn. Test data here. Seaborn code here:

# do imports
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import ggplot

dat = pd.read_csv('contrast_data.csv')

# make a log contrast variable (because Seaborn can't auto-log an axis like ggplot):
dat['log_contrast'] = np.log(dat['contrast'])

# set some styles we like:
sns.set_style("white")
sns.set_style("ticks")
pal = sns.cubehelix_palette(5, start=2, rot=0, dark=0.1, light=.8, reverse=True)

%%timeit
fig = sns.lmplot("log_contrast", "correct", dat,
                 x_estimator=np.mean, 
                 col="subject", 
                 hue='sf',
                 col_wrap=3,
                 logistic=True,
                 palette=pal,
                 ci=True);

I'm afraid I haven't done extensive testing, so I can't tell you what specifically is slow (the bootstrapping? the logistic model fitting?). I wanted to raise it here because I couldn't see anyone discussing this on Stack Overflow or in this issue tracker.

Are there plans to optimise the backend you're using? If someone wanted to help, where would they start?

mwaskom commented Oct 29, 2014

Yes, it's probably the fact that the CIs are computed with a bootstrap rather than analytically.
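For anyone reading along, here is a minimal sketch of what a bootstrapped confidence band for a logistic fit involves, and why the cost scales with the number of resamples. This is an illustration only, not seaborn's internals; the statsmodels GLM call and the function name are my assumptions:

import numpy as np
import statsmodels.api as sm

def bootstrap_logistic_band(x, y, grid, n_boot=1000, ci=95, seed=0):
    """Illustrative bootstrap CI band for a logistic regression curve."""
    rng = np.random.RandomState(seed)
    X_grid = np.column_stack([np.ones(len(grid)), grid])
    boots = np.empty((n_boot, len(grid)))
    for i in range(n_boot):
        # resample the data with replacement and refit the model, so the
        # total cost is roughly n_boot times the cost of a single fit
        idx = rng.randint(0, len(x), len(x))
        X = np.column_stack([np.ones(len(x)), x[idx]])
        fit = sm.GLM(y[idx], X, family=sm.families.Binomial()).fit()
        boots[i] = fit.predict(X_grid)
    lo, hi = np.percentile(boots, [(100. - ci) / 2, 100. - (100. - ci) / 2], axis=0)
    return lo, hi

With roughly a thousand resamples, that is roughly a thousand model fits for every hue level in every facet, which is where the minutes go.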

mwaskom commented Oct 29, 2014

By the way, ci=True is going to be interpreted as "show the 1% confidence interval", since True compares equal to 1. That value should be either the confidence level (e.g. 95) or None to skip bootstrapping.
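In other words, reusing sns and dat from the snippet above, the intended usage would be something like this (abbreviated calls, just to illustrate the ci argument):

# ci takes a confidence level in percent, or None to skip the bootstrap entirely
sns.lmplot("log_contrast", "correct", dat, logistic=True, ci=95)    # 95% confidence band
sns.lmplot("log_contrast", "correct", dat, logistic=True, ci=None)  # no band, no bootstrap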

tomwallis commented Oct 29, 2014

The R (ggplot2) code I compare it to is also bootstrapping (at least for the data points). The line

stat_summary(fun.data = "mean_cl_boot")

produces points that show the data mean and runs 1000 bootstrap iterations to compute 95% confidence intervals, so it's not that R is computing these analytically.
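For comparison, this is roughly what mean_cl_boot does for each group of points, sketched in Python (an illustration of the idea only, not the Hmisc implementation behind it; the function name here is made up):

import numpy as np

def mean_cl_boot(values, n_boot=1000, ci=95, seed=0):
    """Bootstrap CI for the mean of one bin: resample with replacement,
    take the mean of each resample, and report percentile limits."""
    values = np.asarray(values)
    rng = np.random.RandomState(seed)
    means = np.array([rng.choice(values, size=len(values), replace=True).mean()
                      for _ in range(n_boot)])
    half = (100. - ci) / 2
    return values.mean(), np.percentile(means, half), np.percentile(means, 100. - half)

Each iteration here only recomputes a mean, so even 1000 of them per bin is cheap compared with refitting a regression model 1000 times, which is the distinction Michael draws below.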

Thanks for the tip on the CIs.

mwaskom commented Oct 29, 2014

I don't really know ggplot but I gather that's just getting error bars on the point estimates for each data bin. What's taking time is bootstrapping the logistic regression to get error bands on the regression prediction.

mwaskom commented Oct 29, 2014

Using the class that actually does all the work gives some further insight. This is on my two-year-old MacBook Air:

dat_conditioned = dat.query("subject == 'S1' and sf == 0.5")
plotter = sns.linearmodels._RegressionPlotter("log_contrast", "correct", dat_conditioned,
                                               x_estimator=np.mean, logistic=True)

This computes the point estimate and CIs for each level of conditioning (it's actually a property):

%timeit plotter.estimate_data
1 loops, best of 3: 228 ms per loop

This needs to be done 25 times to get the whole plot (5 hue levels × 5 col levels), which means the aggregating and bootstrapping for the point estimates takes about 5-6 seconds.

Bootstrapping the logistic regression takes substantially longer:

%timeit plotter.fit_regression(x_range=(-7, 0))
1 loops, best of 3: 4.2 s per loop

Even a single logistic regression fit is fairly computationally expensive, and the bootstrap multiplies that cost by the number of resamples:

plotter = sns.linearmodels._RegressionPlotter("log_contrast", "correct", dat_conditioned,
                                              x_estimator=np.mean, logistic=True, ci=None)
%timeit plotter.fit_regression(x_range=(-7, 0))
100 loops, best of 3: 4.6 ms per loop

tomwallis commented Oct 29, 2014

Ah, awesome stuff, thanks Michael. It is indeed the logistic regression bootstrapping. With ci=False, the example I posted now takes 6.7 seconds instead of 2 minutes.
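For anyone landing here later, the faster variants of the original call would look something like the following (a sketch reusing dat and pal from the first snippet; I'm assuming lmplot's n_boot argument controls the number of bootstrap resamples, with ci=None skipping the regression band entirely):

# drop the regression CI band entirely (fastest)
sns.lmplot("log_contrast", "correct", dat,
           x_estimator=np.mean, col="subject", hue="sf", col_wrap=3,
           logistic=True, palette=pal, ci=None)

# or keep the band but lower n_boot from its default (1000)
sns.lmplot("log_contrast", "correct", dat,
           x_estimator=np.mean, col="subject", hue="sf", col_wrap=3,
           logistic=True, palette=pal, ci=95, n_boot=250)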
