Skip to content

Commit

Permalink
remove deprecated back_prop kwarg from tf.while_loop
Browse files Browse the repository at this point in the history
  • Loading branch information
duhaime committed Oct 12, 2021
1 parent 0cc4a02 commit aebb53b
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions gpt_2_simple/src/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,23 @@ def body(past, prev, output):
def cond(*args):
return True

_, _, tokens = tf.while_loop(
cond=cond, body=body,
maximum_iterations=length,
loop_vars=[
context_output['presents'],
context[:, -1],
context,
],
shape_invariants=[
tf.TensorShape(model.past_shape(
hparams=hparams, batch_size=batch_size)),
tf.TensorShape([batch_size]),
tf.TensorShape([batch_size, None]),
],
back_prop=False,
_, _, tokens = tf.nest.map_structure(
tf.stop_gradient,
tf.while_loop(
cond=cond,
body=body,
maximum_iterations=length,
loop_vars=[
context_output['presents'],
context[:, -1],
context,
],
shape_invariants=[
tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
tf.TensorShape([batch_size]),
tf.TensorShape([batch_size, None]),
],
)
)

return tokens

0 comments on commit aebb53b

Please sign in to comment.