Skip to content

Commit

Permalink
Merge e974a9b into 31545c4
Browse files Browse the repository at this point in the history
  • Loading branch information
jburos committed Jul 7, 2017
2 parents 31545c4 + e974a9b commit 488182a
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 36 deletions.
2 changes: 2 additions & 0 deletions cohorts/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,8 @@ def plot_survival(self,
df = filter_not_null(df, plot_col)
if df[plot_col].dtype == "bool":
default_threshold = None
if df[plot_col].dtype == "O": # is string
default_threshold = None
else:
default_threshold = "median"
results = plot_kmf(
Expand Down
105 changes: 71 additions & 34 deletions cohorts/survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from lifelines import KaplanMeierFitter
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test
import matplotlib.colors as colors
from matplotlib import pyplot as plt
import seaborn as sb
import patsy
from .rounding import float_str

def plot_kmf(df,
Expand Down Expand Up @@ -59,8 +62,12 @@ def plot_kmf(df,
print_as_title: bool, optional, whether or not to print text
within the plot's title vs. stdout, default False
"""
with_condition_color = colors.hex2color(with_condition_color)
no_condition_color = colors.hex2color(no_condition_color)
if ax is None:
ax = plt.subplot(111)
if colors.is_color_like(with_condition_color):
with_condition_color = colors.to_hex(with_condition_color)
if colors.is_color_like(no_condition_color):
no_condition_color = colors.to_hex(no_condition_color)
kmf = KaplanMeierFitter()
if threshold is not None:
is_median = threshold == "median"
Expand All @@ -72,59 +79,89 @@ def plot_kmf(df,
if is_median:
label_suffix += " (median)"
default_label_with_condition = "%s > %s" % (condition_col, label_suffix)
else:
with_condition_label = with_condition_label or default_label_with_condition
no_condition_label = no_condition_label or default_label_no_condition
label_map = {False: no_condition_label,
True: with_condition_label}
color_map = {False: no_condition_color,
True: with_condition_color}
elif df[condition_col].dtype == 'O':
condition = df[condition_col].astype('category')
label_map = dict()
[label_map.update({condition_value: '{} = {}'.format(condition_col,
condition_value)})
for condition_value in condition.unique()]
rgb_values = sb.color_palette("Set2", len(label_map.keys()))
hex_values = [colors.to_hex(col) for col in rgb_values]
color_map = dict(zip(label_map.keys(), hex_values))
elif df[condition_col].dtype == 'bool':
condition = df[condition_col]
default_label_with_condition = "= {}".format(condition_col)
default_label_no_condition = "¬ {}".format(condition_col)

with_condition_label = with_condition_label or default_label_with_condition
no_condition_label = no_condition_label or default_label_no_condition

df_with_condition = df[condition]
df_no_condition = df[~condition]
survival_no_condition = df_no_condition[survival_col]
survival_with_condition = df_with_condition[survival_col]

event_no_condition = (df_no_condition[censor_col].astype(bool))
event_with_condition = (df_with_condition[censor_col].astype(bool))

kmf.fit(survival_no_condition, event_no_condition, label=(no_condition_label))
if ax:
kmf.plot(ax=ax, show_censors=True, ci_show=ci_show, color=no_condition_color)
with_condition_label = with_condition_label or default_label_with_condition
no_condition_label = no_condition_label or default_label_no_condition
label_map = {False: no_condition_label,
True: with_condition_label}
color_map = {False: no_condition_color,
True: with_condition_color}
else:
ax = kmf.plot(show_censors=True, ci_show=ci_show, color=no_condition_color)

kmf.fit(survival_with_condition, event_with_condition, label=(with_condition_label))
plot = kmf.plot(ax=ax, show_censors=True, ci_show=ci_show, color=with_condition_color)
raise ValueError('Don\'t know how to plot data of type\
{}'.format(df[condition_col].dtype))

grp_desc = list()
grp_survival_data = dict()
grp_event_data = dict()
grp_names = list(condition.unique())
for grp_name, grp_df in df.groupby(condition):
grp_survival = grp_df[survival_col]
grp_event = (grp_df[censor_col].astype(bool))
grp_label = label_map[grp_name]
grp_color = color_map[grp_name]
kmf.fit(grp_survival, grp_event, label=grp_label)
desc_str = "# {}: {}".format(grp_label, len(grp_survival))
grp_desc.append(desc_str)
grp_survival_data[grp_name] = grp_survival
grp_event_data[grp_name] = grp_event
if ax:
ax = kmf.plot(ax=ax, show_censors=True, ci_show=ci_show, color=grp_color)
else:
ax = kmf.plot(show_censors=True, ci_show=ci_show, color=grp_color)

# Set the y-axis to range 0 to 1
ax.set_ylim(0, 1)
y_tick_vals = ax.get_yticks()
ax.set_yticklabels(["%d" % int(y_tick_val * 100) for y_tick_val in y_tick_vals])

no_cond_str = "# no condition {}".format(len(survival_no_condition))
cond_str = "# with condition {}".format(len(survival_with_condition))
if title:
ax.set_title(title)
elif print_as_title:
ax.set_title("%s | %s" % (no_cond_str, cond_str))
ax.set_title(' | '.join(grp_desc))
else:
print(no_cond_str)
print(cond_str)
[print(desc) for desc in grp_desc]

if xlabel:
ax.set_xlabel(xlabel)
if ylabel:
ax.set_ylabel(ylabel)

results = logrank_test(survival_no_condition,
survival_with_condition,
event_observed_A=event_no_condition,
event_observed_B=event_with_condition)
results.without_condition_series = survival_no_condition
results.with_condition_series = survival_with_condition
if len(grp_names) == 2:
results = logrank_test(grp_survival_data[grp_names[0]],
grp_survival_data[grp_names[1]],
event_observed_A=grp_event_data[grp_names[0]],
event_observed_B=grp_event_data[grp_names[1]])
else:
cf = CoxPHFitter()
cox_df = patsy.dmatrix('+'.join([condition_col, survival_col,
censor_col]),
df, return_type='dataframe')
del cox_df['Intercept']
results = cf.fit(cox_df, survival_col, event_col=censor_col)
results.print_summary()
results.survival_data_series = grp_survival_data
results.event_data_series = grp_event_data
return results


def logrank(df,
condition_col,
censor_col,
Expand Down
2 changes: 1 addition & 1 deletion pylintrc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[TYPECHECK]
# Without ignoring this, we get errors like:
# E:249,20: Module 'numpy' has no 'nan' member (no-member)
ignored-modules = numpy, inspect
ignored-modules = numpy, inspect, patsy
ignored-classes = DataFrameHolder, TextFileReader, tuple, list, zip, izip, str
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ topiary>=0.1.0, <0.2.0
mhctools>=0.3.0, <0.4.0
varcode>=0.5.15, <0.6.0
pyensembl>=1.0.1, <1.1.0
lifelines>=0.9.1.0
lifelines>=0.9.3.2
scikit-learn>=0.17.1
vcf-annotate-polyphen>=0.1.2
nose>=1.3.3
Expand Down

0 comments on commit 488182a

Please sign in to comment.