Added temperature flag to generation script #131
Conversation
np.seterr(divide='ignore')
prediction = np.log(prediction) / args.temp
prediction[np.isneginf(prediction)] = 0
prediction = np.exp(prediction) / np.sum(np.exp(prediction))
I think this line is better implemented as something like:
prediction = prediction - scipy.misc.logsumexp(prediction)
prediction = np.exp(prediction)
By operating in the log domain as much as possible we can avoid the division which leads to instability when we divide by something very close to zero. Especially if we ever want to run using float16. And you were in the log domain with prediction already, anyway.
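To make the suggestion concrete, here is a sketch of the temperature scaling done entirely in the log domain. The function name `scale_with_temperature` is mine, and I've imported `logsumexp` from `scipy.special`, its current home (`scipy.misc.logsumexp` was the older location):

```python
import numpy as np
from scipy.special import logsumexp  # formerly scipy.misc.logsumexp

def scale_with_temperature(prediction, temperature):
    # Move to the log domain, apply the temperature, then normalize by
    # subtracting logsumexp instead of dividing by a sum of exponentials
    # that may be very close to zero.
    logp = np.log(prediction) / temperature
    return np.exp(logp - logsumexp(logp))
```

At temperature 1.0 this reduces to subtracting log(sum(p)) = log(1) = 0, so the distribution passes through unchanged.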
Never would have thought of that; thanks for the guidance!
I made @jyegerlehner's suggested changes, which introduce a scipy dependency.
If we don't like the dependency, would logsumexp() be equivalent to np.log(np.sum(np.exp(prediction)))? I wouldn't uncritically paste my code in there; I was just describing an idea, and my code is not usually right the first time around :). A test showing that temperature sampling with T=1 produces the same distribution as no temperature scaling would be good. And perhaps, if I understand this, in the limit as T->0 it should be equivalent to choosing the most likely quantization level, and as T->inf all quantization levels should become equally likely?
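For what it's worth, the equivalence and both limits described above can be checked directly. This is a sketch under my own naming, with logsumexp hand-rolled via the usual max-shift trick so it stays finite where a naive np.log(np.sum(np.exp(x))) could overflow:

```python
import numpy as np

def logsumexp(x):
    # Stable equivalent of np.log(np.sum(np.exp(x))): shifting by the
    # maximum keeps the exponentials from overflowing.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

def sample_dist(logp, temperature):
    # Temperature-scale log-probabilities and renormalize in the log domain.
    scaled = logp / temperature
    return np.exp(scaled - logsumexp(scaled))

probs = np.array([0.05, 0.10, 0.15, 0.20, 0.50])
logits = np.log(probs)

# T = 1 leaves the distribution unchanged.
assert np.allclose(sample_dist(logits, 1.0), probs)
# T -> 0 concentrates all the mass on the most likely quantization level.
assert sample_dist(logits, 1e-3).argmax() == probs.argmax()
# T -> inf approaches a uniform distribution over the 5 levels.
assert np.allclose(sample_dist(logits, 1e6), 0.2, atol=1e-6)
```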
Great idea on the test. I've never written a Python test before, so I'll poke at the existing tests in this project and figure it out. I'll try that numpy approach; I tried something different (np.logaddexp.reduce) and didn't get what I expected, so I think a more stepwise approach would be better for me.
I'm fine with using scipy as long as we put it into the requirements. This is a really nice addition to the project!
A couple nits.
@@ -179,6 +187,15 @@ def main():

# Run the WaveNet to predict the next sample.
prediction = sess.run(outputs, feed_dict={samples: window})[0]

# Scale sample distribution using temperature, if applicable.
if (args.temp != 1.0 and args.temp > 0):
Parentheses are not needed here.
@@ -36,6 +39,11 @@ def _str_to_bool(s):
    default=SAMPLES,
    help='How many waveform samples to generate')
parser.add_argument(
    '--temp',
Please use temperature or sampling_temperature.
Quick question for @jyegerlehner or anyone else as I'm fixing this up: when I scale the prediction distribution with temperature, the result only sums to 1 up to floating-point error. What tolerance should the check use?
@robinsloan That sounds plenty good to me. We might want to future-proof for fp16 and make the tolerance wider, because float16 only has 3 or 4 significant digits (it's float32 that has 6-7). Maybe use 1e-4 for a tolerance?
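The digit counts can be read straight off numpy's machine epsilon for each dtype; a quick check (the tolerances below are my own illustrative picks):

```python
import numpy as np

# Machine epsilon bounds the relative precision of each float type:
# float32 carries roughly 7 significant digits, float16 only 3-4.
assert np.finfo(np.float32).eps < 1e-6         # about 1.19e-07
assert 1e-4 < np.finfo(np.float16).eps < 1e-3  # about 9.77e-04

# So a renormalized float16 distribution sums to 1 only within a
# correspondingly coarse tolerance.
p = np.arange(1, 17, dtype=np.float16)
p = p / p.sum()
assert abs(float(p.sum()) - 1.0) < 1e-2
```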
Updates:
Thanks for all the feedback & support on this!
@@ -7,12 +7,14 @@
import os

import librosa

Why this blank line?
np.seterr(divide='warn')

# Prediction distribution at temperature=1.0 should be unchanged after scaling.
if args.temperature == 1.0:
This should be a unit test; we can move it later, when tests for generation are ready.
So, when #142 is merged.
I haven't written test coverage in Python before and couldn't determine where in the codebase to put this, if not here. Pointers appreciated!
You didn't find it because it doesn't exist yet; there is a PR that provides tests for this file. Thanks for specifying how to test it, though! We'll move it later, no problem.
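Once the generation tests from #142 land, the inline check could become a small standalone unit test. A sketch of what that might look like (the module layout and the `scale` helper name are assumptions, not the project's actual API):

```python
import unittest
import numpy as np
from scipy.special import logsumexp

def scale(prediction, temperature):
    # Log-domain temperature scaling, as discussed in this review.
    logp = np.log(prediction) / temperature
    return np.exp(logp - logsumexp(logp))

class TestTemperatureScaling(unittest.TestCase):
    def test_identity_at_temperature_one(self):
        # A random 256-bin distribution (matching the quantization levels)
        # should pass through unchanged at T=1, within fp tolerance.
        p = np.random.default_rng(0).dirichlet(np.ones(256))
        np.testing.assert_allclose(scale(p, 1.0), p, atol=1e-4)

if __name__ == '__main__':
    unittest.main()
```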
@@ -27,6 +29,12 @@ def _str_to_bool(s):
    'boolean, got {}'.format(s))
return {'true': True, 'false': False}[s.lower()]


def _ensure_positive_float(f):
Shouldn't this be generic and not specific to the sampling temperature?
I didn't want to change anything outside the scope of this feature; maybe this suggestion is better left to a general refactoring of the argument parsing?
Ohhh wait, I see what you're saying. I didn't notice that the ArgumentTypeError automatically indicated the argument to which it was objecting. So the error can just say "yo, you need a positive float." Got it, got it.
Didn't you add this function as part of the feature? It has a very generic name but it's specific to the temperature parameter; that's why I mentioned this. I think argparse takes care of referring to the argument name when displaying an error, so I suggested leaving the "positive float" message generic rather than adding a _parse_temperature function. The _str_to_bool helper is due for a refactor anyway, so we could do both of them together later on, but I feel it's a bit unnecessary.
Haha yep, we typed at the same time. Thanks!
For the other reviewers: these are minor nits; feel free to merge and mark to fix later.
LGTM
@lemonzi Would we want to wait to merge this until after PR 142 is merged? Those tests in 142 will be running this code, so we will get to see that they still pass before we merge this change.
LGTM, too!
@jyegerlehner Good point, we can wait; there's only that conflict left to resolve.
@jyegerlehner: Okay, that's a good idea. @robinsloan will have to rebase once that lands.
@robinsloan Could you push a dummy change to your branch (or rebase it?) to trigger another run of the tests, so we can see if they pass now that the generation test has been merged? Thx
(You can do that with an empty commit, e.g. git commit --allow-empty.)
That's new to me. Thanks @lemonzi!
It's nice to be able to specify sampling "temperature" when generating output, usually for aesthetic reasons, so I added some code to scale the sampling probabilities if a temperature other than 1.0 is provided.
Demo: https://soundcloud.com/robinsloan/sets/tensorflow-wavenet-temperature-demo