Skip to content

Commit

Permalink
Merge pull request #54 from better/misc
Browse files Browse the repository at this point in the history
Remove Seaborn req, add some sanity checking of input data
  • Loading branch information
erikbern committed May 12, 2018
2 parents 04dd396 + 58390c8 commit d9697a2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
4 changes: 2 additions & 2 deletions convoys/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
import numpy
import random
import seaborn
from matplotlib import pyplot
from convoys.multi import Exponential, Weibull, Gamma, GeneralizedGamma, \
KaplanMeier, Nonparametric
Expand Down Expand Up @@ -89,7 +88,8 @@ def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100,
extra_m.fit(G, B, T)

# Plot
colors = seaborn.color_palette('hls', len(groups))
colors = pyplot.get_cmap('tab10').colors
colors = [colors[i % len(colors)] for i in range(len(groups))]
t = numpy.linspace(0, t_max, 1000)
y_max = 0
result = []
Expand Down
15 changes: 12 additions & 3 deletions convoys/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from scipy.special import expit, gammainc, gammaincinv
import scipy.stats
import tensorflow as tf
import warnings
from convoys import tf_utils


Expand Down Expand Up @@ -52,11 +53,19 @@ class GeneralizedGamma(RegressionModel):
def fit(self, X, B, T, W=None, k=None, p=None, method='Powell'):
# Note on using Powell: tf.igamma returns the wrong gradient wrt k
# https://github.com/tensorflow/tensorflow/issues/17995
n_features = X.shape[1]
X, B, T = (numpy.array(z, dtype=numpy.float32) for z in (X, B, T))
# Sanity check input:
if W is None:
W = numpy.ones(B.shape, dtype=numpy.float32)
W = [1] * len(X)
XBTW = [(x, b, t, w) for x, b, t, w in zip(X, B, T, W)
if t > 0 or float(t) not in [0, 1] or w < 0]
if len(XBTW) < len(X):
n_removed = len(X) - len(XBTW)
warnings.warn('Warning! Removed %d entries from inputs where' +
'T <= 0 or B not 0/1 or W < 0' % n_removed)
X, B, T, W = (numpy.array([z[i] for z in XBTW], dtype=numpy.float32)
for i in range(4))

n_features = X.shape[1]
a = LinearCombination(X, n_features)
b = LinearCombination(X, n_features)
lambd = tf.exp(a.y)
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
matplotlib>=2.0.0
numpy
scipy
seaborn==0.8.1
tensorflow>=1.6.0

0 comments on commit d9697a2

Please sign in to comment.