-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
factor out conversion tool for dataframes
- Loading branch information
Showing
5 changed files
with
181 additions
and
131 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,121 +0,0 @@ | ||
import datetime | ||
import numpy | ||
import random | ||
from matplotlib import pyplot | ||
from convoys.multi import Exponential, Weibull, Gamma, GeneralizedGamma, \ | ||
KaplanMeier | ||
|
||
|
||
def get_timescale(t): | ||
def get_timedelta_converter(t_factor): | ||
return lambda td: td.total_seconds() * t_factor | ||
|
||
if type(t) != datetime.timedelta: | ||
# Assume numeric type | ||
return '', lambda x: x | ||
elif t >= datetime.timedelta(days=1): | ||
return 'Days', get_timedelta_converter(1./(24*60*60)) | ||
elif t >= datetime.timedelta(hours=1): | ||
return 'Hours', get_timedelta_converter(1./(60*60)) | ||
elif t >= datetime.timedelta(minutes=1): | ||
return 'Minutes', get_timedelta_converter(1./60) | ||
else: | ||
return 'Minutes', get_timedelta_converter(1) | ||
|
||
|
||
def get_arrays(groups, data, t_converter): | ||
G, B, T = [], [], [] | ||
group2j = dict((group, j) for j, group in enumerate(groups)) | ||
for group, created_at, converted_at, now in data: | ||
if created_at is None: | ||
print('created at is None') | ||
continue | ||
if converted_at is not None and converted_at <= created_at: | ||
print('created at', created_at, 'but converted at', converted_at) | ||
continue | ||
if now < created_at: | ||
print('created at', created_at, 'but now is', now) | ||
continue | ||
if converted_at is not None and now < converted_at: | ||
print('converted at', converted_at, 'but now is', now) | ||
continue | ||
if group in group2j: | ||
G.append(group2j[group]) | ||
B.append(converted_at is not None) | ||
T.append(t_converter(converted_at - created_at) if converted_at is not None else t_converter(now - created_at)) | ||
return numpy.array(G), numpy.array(B), numpy.array(T) | ||
|
||
|
||
def get_groups(data, group_min_size, max_groups): | ||
group2count = {} | ||
for group, created_at, converted_at, now in data: | ||
group2count[group] = group2count.get(group, 0) + 1 | ||
|
||
# Remove groups with too few data points | ||
# Pick the top groups | ||
# Sort groups lexicographically | ||
groups = [group for group, count in group2count.items() if count >= group_min_size] | ||
groups = sorted(groups, key=group2count.get, reverse=True)[:max_groups] | ||
return sorted(groups) | ||
|
||
|
||
_models = { | ||
'kaplan-meier': KaplanMeier, | ||
'exponential': lambda: Exponential(ci=True), | ||
'weibull': lambda: Weibull(ci=True), | ||
'gamma': lambda: Gamma(ci=True), | ||
'generalized-gamma': lambda: GeneralizedGamma(ci=True), | ||
} | ||
|
||
|
||
def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, model='kaplan-meier', ci=0.95, extra_model=None): | ||
# Set x scale | ||
if t_max is None: | ||
t_max = max(now - created_at for group, created_at, converted_at, now in data) | ||
t_unit, t_converter = get_timescale(t_max) | ||
t_max = t_converter(t_max) | ||
|
||
# Split data by group and get data | ||
groups = get_groups(data, group_min_size, max_groups) | ||
G, B, T = get_arrays(groups, data, t_converter) | ||
|
||
# Fit model | ||
m = _models[model]() | ||
m.fit(G, B, T) | ||
if extra_model is not None: | ||
extra_m = _models[extra_model]() | ||
extra_m.fit(G, B, T) | ||
|
||
# Plot | ||
colors = pyplot.get_cmap('tab10').colors | ||
colors = [colors[i % len(colors)] for i in range(len(groups))] | ||
t = numpy.linspace(0, t_max, 1000) | ||
y_max = 0 | ||
result = [] | ||
for j, (group, color) in enumerate(zip(groups, colors)): | ||
n = sum(1 for g in G if g == j) # TODO: slow | ||
k = sum(1 for g, b in zip(G, B) if g == j and b) # TODO: slow | ||
label = '%s (n=%.0f, k=%.0f)' % (group, n, k) | ||
|
||
if ci is not None: | ||
p_y, p_y_lo, p_y_hi = m.cdf(j, t, ci=ci).T | ||
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label) | ||
pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, color=color, alpha=0.2) | ||
else: | ||
p_y = m.cdf(j, t).T | ||
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label) | ||
|
||
if extra_model is not None: | ||
extra_p_y = extra_m.cdf(j, t) | ||
pyplot.plot(t, 100. * extra_p_y, color=color, linestyle='--', linewidth=1.5, alpha=0.7) | ||
y_max = max(y_max, 110. * max(p_y)) | ||
|
||
if title: | ||
pyplot.title(title) | ||
pyplot.xlim([0, t_max]) | ||
pyplot.ylim([0, y_max]) | ||
pyplot.xlabel(t_unit) | ||
pyplot.ylabel('Conversion rate %') | ||
pyplot.legend() | ||
pyplot.gca().grid(True) | ||
return m | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import datetime | ||
import numpy | ||
from matplotlib import pyplot | ||
import convoys.multi | ||
|
||
|
||
_models = { | ||
'kaplan-meier': convoys.multi.KaplanMeier, | ||
'exponential': lambda: convoys.multi.Exponential(ci=True), | ||
'weibull': lambda: convoys.multi.Weibull(ci=True), | ||
'gamma': lambda: convoys.multi.Gamma(ci=True), | ||
'generalized-gamma': lambda: convoys.multi.GeneralizedGamma(ci=True), | ||
} | ||
|
||
|
||
def plot_cohorts(G, B, T, t_max=None, title=None, model='kaplan-meier', ci=0.95, extra_model=None): | ||
# Set x scale | ||
if t_max is None: | ||
t_max = max(T) | ||
|
||
groups = set(G) # TODO: fix | ||
|
||
# Fit model | ||
m = _models[model]() | ||
m.fit(G, B, T) | ||
if extra_model is not None: | ||
extra_m = _models[extra_model]() | ||
extra_m.fit(G, B, T) | ||
|
||
# Plot | ||
colors = pyplot.get_cmap('tab10').colors | ||
colors = [colors[i % len(colors)] for i in range(len(groups))] | ||
t = numpy.linspace(0, t_max, 1000) | ||
y_max = 0 | ||
result = [] | ||
for j, (group, color) in enumerate(zip(groups, colors)): | ||
n = sum(1 for g in G if g == j) # TODO: slow | ||
k = sum(1 for g, b in zip(G, B) if g == j and b) # TODO: slow | ||
label = '%s (n=%.0f, k=%.0f)' % (group, n, k) | ||
|
||
if ci is not None: | ||
p_y, p_y_lo, p_y_hi = m.cdf(j, t, ci=ci).T | ||
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label) | ||
pyplot.fill_between(t, 100. * p_y_lo, 100. * p_y_hi, color=color, alpha=0.2) | ||
else: | ||
p_y = m.cdf(j, t).T | ||
pyplot.plot(t, 100. * p_y, color=color, linewidth=1.5, alpha=0.7, label=label) | ||
|
||
if extra_model is not None: | ||
extra_p_y = extra_m.cdf(j, t) | ||
pyplot.plot(t, 100. * extra_p_y, color=color, linestyle='--', linewidth=1.5, alpha=0.7) | ||
y_max = max(y_max, 110. * max(p_y)) | ||
|
||
if title: | ||
pyplot.title(title) | ||
pyplot.xlim([0, t_max]) | ||
pyplot.ylim([0, y_max]) | ||
# pyplot.xlabel(t_unit) | ||
pyplot.ylabel('Conversion rate %') | ||
pyplot.legend() | ||
pyplot.gca().grid(True) | ||
return m |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import datetime | ||
import numpy | ||
import pandas | ||
|
||
|
||
def get_timescale(t): | ||
''' Take a datetime or a numerical type, return two things: | ||
1. A unit | ||
2. A function that converts it to numerical form | ||
''' | ||
def get_timedelta_converter(t_factor): | ||
return lambda td: td.total_seconds() * t_factor | ||
|
||
if not isinstance(t, datetime.timedelta): | ||
# Assume numeric type | ||
return '', lambda x: x | ||
elif t >= datetime.timedelta(days=1): | ||
return 'Days', get_timedelta_converter(1./(24*60*60)) | ||
elif t >= datetime.timedelta(hours=1): | ||
return 'Hours', get_timedelta_converter(1./(60*60)) | ||
elif t >= datetime.timedelta(minutes=1): | ||
return 'Minutes', get_timedelta_converter(1./60) | ||
else: | ||
return 'Minutes', get_timedelta_converter(1) | ||
|
||
|
||
def get_groups(data, group_min_size, max_groups): | ||
''' Picks the top groups out of a dataset | ||
1. Remove groups with too few data points | ||
2. Pick the top groups | ||
3. Sort groups lexicographically | ||
''' | ||
group2count = {} | ||
for group in data: | ||
group2count[group] = group2count.get(group, 0) + 1 | ||
|
||
groups = [group for group, count in group2count.items() if count >= group_min_size] | ||
if max_groups >= 0: | ||
groups = sorted(groups, key=group2count.get, reverse=True)[:max_groups] | ||
return sorted(groups) | ||
|
||
|
||
def get_arrays(data, features=None, groups=None, created=None, converted=None, now=None, group_min_size=0, max_groups=-1): | ||
''' Converts a dataframe to a list of numpy arrays. | ||
Each input refers to a column in the dataframe. | ||
TODO: more doc | ||
''' | ||
if groups is not None: | ||
group2j = dict((group, j) for j, group in enumerate(get_groups(data[groups], group_min_size, max_groups))) | ||
|
||
# TODO: sanity check inputs | ||
|
||
X, G, B = [], [], [] | ||
T_raw = [] # might be timedeltas or numerical | ||
for i, row in data.iterrows(): | ||
if groups is not None: | ||
if row[groups] not in group2j: | ||
continue | ||
G.append(group2j[row[groups]]) | ||
if features is not None: | ||
X.append(row[features]) | ||
if not pandas.isnull(row[converted]): | ||
B.append(True) | ||
if created is not None: | ||
T_raw.append(row[converted] - row[created]) | ||
else: | ||
T_raw.append(row[converted]) | ||
else: | ||
B.append(False) | ||
if created is not None: | ||
if now is not None: | ||
T_raw.append(row[now] - row[created]) | ||
else: | ||
T_raw.append(datetime.datetime.now(tzinfo=row[created].tzinfo) - row[created_at]) | ||
else: | ||
T_raw.append(row[now]) | ||
|
||
unit, converter = get_timescale(max(T_raw)) | ||
T = [converter(t) for t in T_raw] | ||
X, G, B, T = (numpy.array(z) for z in (X, G, B, T)) | ||
|
||
res = [] | ||
if groups is not None: | ||
res.append(G) | ||
elif features is not None: | ||
res.append(X) | ||
res.append(B) | ||
res.append(T) | ||
|
||
return unit, tuple(res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
autograd | ||
emcee | ||
matplotlib>=2.0.0 | ||
pandas | ||
numpy | ||
scipy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters