Skip to content

Commit

Permalink
use constants instead of placeholders for inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Mar 17, 2018
1 parent 169b02f commit 75520da
Showing 1 changed file with 20 additions and 27 deletions.
47 changes: 20 additions & 27 deletions convoys/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,11 @@

tf.logging.set_verbosity(3)

def _get_placeholders(n, k):
    """Create float32 feed placeholders for a dataset of n rows.

    Returns a 3-tuple of placeholders: X with shape (n, k), and the
    two per-row vectors B and T, each with shape (n,).
    """
    x_ph = tf.placeholder(tf.float32, [n, k])
    b_ph = tf.placeholder(tf.float32, [n])
    t_ph = tf.placeholder(tf.float32, [n])
    return (x_ph, b_ph, t_ph)
def _get_constants(args):
    """Convert each array-like in *args* into a float32 TF constant.

    Each element must support ``.astype`` (e.g. a numpy array); it is cast
    to float32 before being wrapped in ``tf.constant``.

    Returns a tuple rather than a bare generator expression: a generator is
    single-use (it can be unpacked only once and has no ``len``), which is a
    trap if a caller ever stores the result. Tuple unpacking at existing
    call sites is unaffected.
    """
    return tuple(tf.constant(arg.astype(numpy.float32)) for arg in args)


def _optimize(sess, target, feed_dict, variables):
def _optimize(sess, target, variables):
learning_rate_input = tf.placeholder(tf.float32, [])
optimizer = tf.train.AdamOptimizer(learning_rate_input).minimize(-target)

Expand All @@ -28,18 +24,18 @@ def _optimize(sess, target, feed_dict, variables):

best_step, step = 0, 0
learning_rate = 1.0
best_cost = sess.run(target, feed_dict=feed_dict)
best_cost = sess.run(target)
any_var_is_nan = tf.is_nan(tf.add_n([tf.reduce_sum(v) for v in variables]))

while True:
feed_dict[learning_rate_input] = learning_rate
feed_dict = {learning_rate_input: learning_rate}
if step < 120:
feed_dict[learning_rate_input] = min(learning_rate, 10**(step//20-6))
sess.run(optimizer, feed_dict=feed_dict)
if sess.run(any_var_is_nan):
cost = float('-inf')
else:
cost = sess.run(target, feed_dict=feed_dict)
cost = sess.run(target)
if cost > best_cost:
best_cost, best_step = cost, step
sess.run(store_best_state)
Expand All @@ -61,8 +57,8 @@ def _get_params(sess, params):
return {key: sess.run(param) for key, param in params.items()}


def _get_hessian(sess, f, param, feed_dict):
return sess.run(tf.hessians(-f, [param]), feed_dict=feed_dict)[0]
def _get_hessian(sess, f, param):
    """Evaluate and return the Hessian matrix of ``-f`` w.r.t. ``param``.

    ``tf.hessians`` returns a list (one entry per variable); only the
    single requested entry is returned, as a numpy array.
    """
    hessian_ops = tf.hessians(-f, [param])
    return sess.run(hessian_ops)[0]


def _fix_t(t):
Expand Down Expand Up @@ -101,7 +97,7 @@ def __init__(self, L2_reg=1.0):
class ExponentialRegression(Regression):
def fit(self, X, B, T):
n, k = X.shape
X_input, B_input, T_input = _get_placeholders(n, k)
X_input, B_input, T_input = _get_constants((X, B, T))

alpha = tf.Variable(tf.zeros([k]), 'alpha')
beta = tf.Variable(tf.zeros([k]), 'beta')
Expand All @@ -120,11 +116,10 @@ def fit(self, X, B, T):
LL_penalized = LL - self._L2_reg * tf.reduce_sum(beta * beta, 0)

with tf.Session() as sess:
feed_dict = {X_input: X, B_input: B, T_input: T}
_optimize(sess, LL_penalized, feed_dict, (alpha, beta))
_optimize(sess, LL_penalized, (alpha, beta))
self.params = _get_params(sess, {'beta': beta, 'alpha': alpha})
self.params['alpha_hessian'] = _get_hessian(sess, LL_penalized, alpha, feed_dict)
self.params['beta_hessian'] = _get_hessian(sess, LL_penalized, beta, feed_dict)
self.params['alpha_hessian'] = _get_hessian(sess, LL_penalized, alpha)
self.params['beta_hessian'] = _get_hessian(sess, LL_penalized, beta)

def predict(self, x, t, ci=None, n=1000):
t = _fix_t(t)
Expand All @@ -144,7 +139,7 @@ def predict_time(self, x, ci=None, n=1000):
class WeibullRegression(Regression):
def fit(self, X, B, T):
n, k = X.shape
X_input, B_input, T_input = _get_placeholders(n, k)
X_input, B_input, T_input = _get_constants((X, B, T))

alpha = tf.Variable(tf.zeros([k]), 'alpha')
beta = tf.Variable(tf.zeros([k]), 'beta')
Expand All @@ -167,11 +162,10 @@ def fit(self, X, B, T):
LL_penalized = LL - self._L2_reg * tf.reduce_sum(beta * beta, 0)

with tf.Session() as sess:
feed_dict = {X_input: X, B_input: B, T_input: T}
_optimize(sess, LL_penalized, feed_dict, (alpha, beta, log_k_var))
_optimize(sess, LL_penalized, (alpha, beta, log_k_var))
self.params = _get_params(sess, {'beta': beta, 'alpha': alpha, 'k': k})
self.params['alpha_hessian'] = _get_hessian(sess, LL_penalized, alpha, feed_dict)
self.params['beta_hessian'] = _get_hessian(sess, LL_penalized, beta, feed_dict)
self.params['alpha_hessian'] = _get_hessian(sess, LL_penalized, alpha)
self.params['beta_hessian'] = _get_hessian(sess, LL_penalized, beta)

def predict(self, x, t, ci=None, n=1000):
t = _fix_t(t)
Expand All @@ -191,7 +185,7 @@ def predict_time(self, x, ci=None, n=1000):
class GammaRegression(Regression):
def fit(self, X, B, T):
n, k = X.shape
X_input, B_input, T_input = _get_placeholders(n, k)
X_input, B_input, T_input = _get_constants((X, B, T))

alpha = tf.Variable(tf.zeros([k]), 'alpha')
beta = tf.Variable(tf.zeros([k]), 'beta')
Expand All @@ -214,11 +208,10 @@ def fit(self, X, B, T):
LL_penalized = LL - self._L2_reg * tf.reduce_sum(beta * beta, 0)

with tf.Session() as sess:
feed_dict = {X_input: X, B_input: B, T_input: T}
_optimize(sess, LL_penalized, feed_dict, (alpha, beta, log_k_var))
_optimize(sess, LL_penalized, (alpha, beta, log_k_var))
self.params = _get_params(sess, {'beta': beta, 'alpha': alpha, 'k': k})
self.params['alpha_hessian'] = _get_hessian(sess, LL_penalized, alpha, feed_dict)
self.params['beta_hessian'] = _get_hessian(sess, LL_penalized, beta, feed_dict)
self.params['alpha_hessian'] = _get_hessian(sess, LL_penalized, alpha)
self.params['beta_hessian'] = _get_hessian(sess, LL_penalized, beta)

def predict(self, x, t, ci=None, n=1000):
t = _fix_t(t)
Expand Down

0 comments on commit 75520da

Please sign in to comment.