
Added temperature flag to generation script #131

Merged
merged 7 commits into from Oct 18, 2016

Conversation

robinsloan
Contributor

It's nice to be able to specify sampling "temperature" when generating output, usually for aesthetic reasons, so I added some code to scale the sampling probabilities if a temperature other than 1.0 is provided.

Demo: https://soundcloud.com/robinsloan/sets/tensorflow-wavenet-temperature-demo

np.seterr(divide='ignore')                    # silence divide-by-zero warnings from log(0)
prediction = np.log(prediction) / args.temp   # scale the distribution in log space
prediction[np.isneginf(prediction)] = 0       # replace -inf entries produced by log(0) with 0
prediction = np.exp(prediction) / np.sum(np.exp(prediction))  # back to normalized probabilities
Contributor

I think this line is better implemented as something like:

prediction = prediction - scipy.misc.logsumexp(prediction)
prediction = np.exp(prediction)

By operating in the log domain as much as possible, we avoid the division, which becomes unstable when the denominator is very close to zero. That matters especially if we ever want to run using float16. And you were already in the log domain with prediction anyway.
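
For reference, a minimal sketch of the full log-domain scaling being suggested here (scale_prediction is a hypothetical helper name, and prediction is assumed to be a normalized probability vector; logsumexp lived in scipy.misc in 2016 and is in scipy.special today):

import numpy as np
from scipy.special import logsumexp  # was scipy.misc.logsumexp in 2016-era scipy

def scale_prediction(prediction, temperature):
    """Temperature-scale a normalized distribution, staying in log space."""
    with np.errstate(divide='ignore'):            # log(0) -> -inf is acceptable here
        logits = np.log(prediction) / temperature
    # Subtracting logsumexp normalizes in the log domain, so the only exp
    # happens at the very end and there is no division by a tiny sum.
    return np.exp(logits - logsumexp(logits))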

Contributor Author

Never would have thought of that; thanks for the guidance!

@robinsloan
Contributor Author

I made @jyegerlehner's suggested changes, which introduces a scipy dependency, which is… not great? But there's no direct numpy equivalent to the scipy.misc.logsumexp function, unfortunately.

@jyegerlehner
Contributor

If we don't like the dependency, would logsumexp() be equivalent to np.log(np.sum(np.exp(prediction)))?

I wouldn't uncritically paste my code in there; I was just describing an idea, and my code is not usually right the first time around :).

A test that shows that temperature sampling with T=1 produces the same distribution as without it would be good.

And perhaps, if I understand this, in the limit as T -> 0, it should be equivalent to choosing the most likely quantization level. And then as T -> inf, all quantization levels become equally likely?
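
Both limits are easy to check numerically; a quick sketch, using the hypothetical scale_prediction helper from the earlier note:

import numpy as np

p = np.array([0.1, 0.2, 0.7])
assert np.allclose(scale_prediction(p, 1.0), p)               # T = 1: unchanged
assert np.argmax(scale_prediction(p, 1e-6)) == np.argmax(p)   # T -> 0: argmax dominates
assert np.allclose(scale_prediction(p, 1e6), np.ones(3) / 3)  # T -> inf: uniform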

@robinsloan
Contributor Author

Great idea on the test. I've never written a Python test before so I'll poke at the existing tests in this project and figure it out.

I'll try that numpy approach. I tried something different (np.logaddexp.reduce) and didn't get what I expected, so I think a more stepwise approach would be better for me.

@ibab
Owner

ibab commented Oct 10, 2016

I'm fine with using scipy as long as we put it into the requirements.txt :)
Note that logsumexp(predictions) is different from doing log(sum(exp(predictions))).
It shifts by max(predictions) in order to avoid underflow.

This is a really nice addition to the project!
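
A numpy-only version of that max-shift trick might look like this (a sketch, not the code that was merged):

import numpy as np

def logsumexp(x):
    """log(sum(exp(x))), computed stably by shifting out the maximum."""
    m = np.max(x)
    # exp(x - m) is at most 1, so the sum cannot overflow; adding m back
    # afterwards restores the true value.
    return m + np.log(np.sum(np.exp(x - m)))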

@lemonzi
Collaborator

A couple nits.

@@ -179,6 +187,15 @@ def main():

# Run the WaveNet to predict the next sample.
prediction = sess.run(outputs, feed_dict={samples: window})[0]

# Scale sample distribution using temperature, if applicable.
if (args.temp != 1.0 and args.temp > 0):
Collaborator

Parentheses are not needed here.

@@ -36,6 +39,11 @@ def _str_to_bool(s):
default=SAMPLES,
help='How many waveform samples to generate')
parser.add_argument(
'--temp',
Collaborator

Please use temperature or sampling_temperature.

@robinsloan
Contributor Author

Quick question for @jyegerlehner or anyone else as I'm fixing this up:

When I scale the prediction distribution with temperature=1.0, running it through log space, I get back the original distribution, as expected, though it's not identical -- it is np.allclose to within 1e-09. Does that sound reasonable? (I have no intuition for this.)

@jyegerlehner
Contributor

jyegerlehner commented Oct 13, 2016

@robinsloan That sounds plenty good to me. We might want to future-proof for fp16, and make the tolerance wider, because float16 only has 6-7 significant digits. Maybe use 1e-4 for a tolerance?

[Edit] oops, float32 has 6-7 digits. fp16 only has 3 or 4.

@robinsloan
Contributor Author

Updates:

  • Application of temperature to prediction distribution now happens in log space
  • A test ensures scaling at temperature=1.0 gives us back the original predictions, to within 1e-5
  • No scipy dependency
  • Style changes, per @lemonzi

Thanks for all the feedback & support on this!
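
For the record, the temperature=1.0 round-trip check described above can be sketched like this (scale_prediction is still the hypothetical helper from earlier, not the merged code):

import numpy as np

np.random.seed(0)
prediction = np.random.dirichlet(np.ones(256))  # random 256-way distribution,
                                                # like the quantized audio levels
scaled = scale_prediction(prediction, 1.0)
np.testing.assert_allclose(scaled, prediction, atol=1e-5)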

@@ -7,12 +7,14 @@
import os

import librosa

Collaborator

Why this blank line?

np.seterr(divide='warn')

# Prediction distribution at temperature=1.0 should be unchanged after scaling.
if args.temperature == 1.0:
Collaborator

This should be a unit test; we can move it later, when tests for generation are ready.

Collaborator

So, when #142 is merged.

Contributor Author

I haven't written test coverage in Python before and couldn't determine where in the codebase to put this, if not here. Pointers appreciated!

Collaborator

You didn't find it because it doesn't exist yet -- there is a PR that provides tests for this file. Thanks for specifying how to test it, though! We'll move it later, no problem.

@@ -27,6 +29,12 @@ def _str_to_bool(s):
'boolean, got {}'.format(s))
return {'true': True, 'false': False}[s.lower()]

def _ensure_positive_float(f):
Collaborator

Shouldn't this be generic and not specific to the sampling temperature?

Contributor Author

I didn't want to change anything outside the scope of this feature; maybe this suggestion is better left to a general refactoring of the argument parsing?

Contributor Author

Ohhh wait, I see what you're saying. I didn't notice that the ArgumentTypeError automatically indicates the argument to which it was objecting. So the error can just say "yo you need a positive float." Got it, got it.

Collaborator

Didn't you add this function as part of the feature? It has a very generic name, but it's specific to the temperature parameter; that's why I mentioned it. I think argparse takes care of referring to the argument name when displaying an error, so I suggested leaving the "positive float" message generic rather than adding a _parse_temperature function. The _str_to_bool is due for a refactor anyway, so we could do both of them together later on, but I feel it's a bit unnecessary.
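
A sketch of the generic validator being described; the _ensure_positive_float name comes from the diff above, but the body here is illustrative, not the PR's exact code. Since argparse prepends "argument --temperature:" to an ArgumentTypeError on its own, the message can stay generic:

import argparse

def _ensure_positive_float(f):
    """Argparse type check for any positive float."""
    value = float(f)
    if value <= 0.0:
        # argparse reports this as: "argument --temperature: <message>"
        raise argparse.ArgumentTypeError('must be a positive float, got {}'.format(f))
    return value

parser = argparse.ArgumentParser()
parser.add_argument('--temperature', type=_ensure_positive_float, default=1.0,
                    help='Sampling temperature (must be greater than 0)')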

Collaborator

Haha yep, we typed at the same time. Thanks!

@lemonzi
Collaborator

lemonzi commented Oct 13, 2016

For the other reviewers: these are minor nits, feel free to merge and mark to fix later.

@lemonzi
Collaborator

lemonzi commented Oct 14, 2016

LGTM

@jyegerlehner
Contributor

jyegerlehner commented Oct 14, 2016

@lemonzi Would we want to wait to merge this until after PR 142 is merged? Those tests in 142 will be running this code, so we will get to see that they still pass before we merge this change.

@ibab
Owner

ibab commented Oct 14, 2016

LGTM, too!
Feel free to merge.

@lemonzi
Collaborator

lemonzi commented Oct 14, 2016

@jyegerlehner Good point, we can wait; there's only that conflict left to resolve.

@ibab
Owner

ibab commented Oct 14, 2016

@jyegerlehner: Okay, that's a good idea. @robinsloan will have to rebase on top of master once we merge that PR, so that Travis will run your test.

@jyegerlehner
Contributor

jyegerlehner commented Oct 17, 2016

@robinsloan Could you push a dummy change to your branch (or rebase it?), to trigger another run of the tests, so we can see if they pass, now that the generation test has been merged? Thx

@lemonzi
Collaborator

lemonzi commented Oct 17, 2016

(You can do that with git commit --allow-empty -m "Trigger Travis")

@robinsloan
Contributor Author

That's new to me—thanks @lemonzi!

@jyegerlehner merged commit 2606971 into ibab:master Oct 18, 2016