Skip to content

Commit

Permalink
get rid of seaborn requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Mar 18, 2018
1 parent 433d23c commit 565bbd6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
18 changes: 13 additions & 5 deletions convoys/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import abc
import datetime
import lifelines
import math
import itertools
import numpy
import random
import seaborn
from matplotlib import pyplot
from convoys.multi import Exponential, Weibull, Gamma, KaplanMeier

Expand Down Expand Up @@ -76,6 +73,17 @@ def get_groups(data, group_min_size, max_groups):
return sorted(groups)


def generate_n_colors(n):
vs = numpy.linspace(0.4, 1.0, 7)
colors = [(1., .3, .3)]
def euclidean(a, b):
return sum((x-y)**2 for x, y in zip(a, b))
while len(colors) < n:
new_color = max(itertools.product(vs, vs, vs), key=lambda a: min(euclidean(a, b) for b in colors))
colors.append(new_color)
return colors


_models = {
'kaplan-meier': KaplanMeier,
'exponential': Exponential,
Expand All @@ -100,7 +108,7 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
m.fit(G, B, T)

# Plot
colors = seaborn.color_palette('hls', len(groups))
colors = generate_n_colors(len(groups))
t = numpy.linspace(0, t_max, 1000)
y_max = 0
result = []
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@ lifelines==0.11.2
matplotlib>=2.0.0
numpy
scipy
seaborn==0.8.1
tensorflow==1.6.0rc1

0 comments on commit 565bbd6

Please sign in to comment.