Skip to content

Commit

Permalink
hound
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed Apr 10, 2018
1 parent 2eaf8fc commit 3661685
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
3 changes: 2 additions & 1 deletion convoys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import random
import seaborn
from matplotlib import pyplot
from convoys.multi import *
from convoys.multi import Exponential, Weibull, Gamma, GeneralizedGamma, \
KaplanMeier, Nonparametric


def get_timescale(t):
Expand Down
18 changes: 13 additions & 5 deletions convoys/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ def fit(self, X, B, T, k=None, p=None):
p = tf.constant(p, tf.float32)

# PDF: p*lambda^(k*p) / gamma(k) * t^(k*p-1) * exp(-(x*lambda)^p)
log_pdf = tf.log(p) + (k*p) * tf.log(lambd) - tf.lgamma(k) + (k*p-1) * tf.log(T_batch) - (T_batch*lambd)**p
log_pdf = \
tf.log(p) + (k*p) * tf.log(lambd) \
- tf.lgamma(k) + (k*p-1) * tf.log(T_batch) \
- (T_batch*lambd)**p
cdf = tf.igamma(k, (T_batch*lambd)**p)

LL_observed = tf.log(c) + log_pdf
Expand All @@ -76,9 +79,10 @@ def fit(self, X, B, T, k=None, p=None):

with tf.Session() as sess:
feed_dict = {X_batch: X, B_batch: B, T_batch: T}
tf_utils.optimize(sess, LL, feed_dict,
update_callback=(tf_utils.get_tweaker(sess, LL, k, feed_dict)
if should_update_k else None))
tf_utils.optimize(
sess, LL, feed_dict,
update_callback=(tf_utils.get_tweaker(sess, LL, k, feed_dict)
if should_update_k else None))
self.params = {
'a': a.params(sess, LL, feed_dict),
'b': b.params(sess, LL, feed_dict),
Expand All @@ -90,7 +94,11 @@ def cdf(self, x, t, ci=None, n=1000):
t = numpy.array(t)
a = LinearCombination.sample(self.params['a'], x, ci, n)
b = LinearCombination.sample(self.params['b'], x, ci, n)
return tf_utils.predict(expit(b) * gammainc(self.params['k'], numpy.multiply.outer(t, numpy.exp(a))**self.params['p']), ci)
return tf_utils.predict(
expit(b) * gammainc(
self.params['k'],
numpy.multiply.outer(t, numpy.exp(a))**self.params['p']),
ci)


class Exponential(GeneralizedGamma):
Expand Down

0 comments on commit 3661685

Please sign in to comment.