Skip to content

Commit

Permalink
Pass stat_func through jointplot
Browse files Browse the repository at this point in the history
  • Loading branch information
tavinathanson committed Aug 17, 2016
1 parent b1461ca commit 1b245d1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions cohorts/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from topiary.sequence_helpers import contains_mutant_residues
from isovar.protein_sequence import variants_to_protein_sequences_dataframe
from pysam import AlignmentFile
from scipy.stats import pearsonr

from .utils import strip_column_names as _strip_column_names
from .survival import plot_kmf
Expand Down Expand Up @@ -1219,7 +1220,7 @@ def plot_survival(self, on, col=None, how="os", survival_units="Days", ax=None,
ci_show=ci_show)
return results

def plot_joint(self, on, on_two=None, **kwargs):
def plot_joint(self, on, on_two=None, stat_func=pearsonr, **kwargs):
"""Plot a jointplot.
Parameters
Expand All @@ -1234,7 +1235,7 @@ def plot_joint(self, on, on_two=None, **kwargs):
plot_cols, df = self.as_dataframe(on, **kwargs)
for plot_col in plot_cols:
df = filter_not_null(df, plot_col)
p = sb.jointplot(data=df, x=plot_cols[0], y=plot_cols[1])
p = sb.jointplot(data=df, x=plot_cols[0], y=plot_cols[1], stat_func=stat_func)
return p

def _list_patient_ids(self):
Expand Down
8 changes: 4 additions & 4 deletions cohorts/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ def bootstrap_auc(df, col, pred_col, n_bootstrap=1000):
scores[i] = roc_auc_score(sampled_pred, sampled_counts)
return scores

def cohort_bootstrap_auc(cohort, func, pred_col="is_benefit", n_bootstrap=1000):
col, df = cohort.as_dataframe(func)
def cohort_bootstrap_auc(cohort, func, pred_col="is_benefit", n_bootstrap=1000, **kwargs):
col, df = cohort.as_dataframe(func, **kwargs)
return bootstrap_auc(df=df,
col=col,
pred_col=pred_col,
n_bootstrap=n_bootstrap)

def cohort_mean_bootstrap_auc(cohort, func, pred_col="is_benefit", n_bootstrap=1000):
return cohort_bootstrap_auc(cohort, func, pred_col, n_bootstrap).mean()
def cohort_mean_bootstrap_auc(cohort, func, pred_col="is_benefit", n_bootstrap=1000, **kwargs):
return cohort_bootstrap_auc(cohort, func, pred_col, n_bootstrap, **kwargs).mean()

def coxph_model(formula, data, time_col, event_col, **kwargs):
# pylint: disable=no-member
Expand Down

0 comments on commit 1b245d1

Please sign in to comment.