Feature: beam search for improving global quality of new text samples #138
Comments
I recently implemented beam search for an RNN language model in the context of image captioning in the NeuralTalk2 repo (https://github.com/karpathy/neuraltalk2/blob/master/misc/LanguageModel.lua#L166). It does make things work quite a bit better there. I should try to port the option to char-rnn. Thanks!
If you've implemented it before, then even better! I'll be curious to see if beam search does fix my CSS woes, and if it helps out with another
I wrote a beam search implementation on top of char-rnn some time ago. See sample-beam.lua here: https://github.com/pender/char-rnn/blob/master/sample-beam.lua . A beam width of 2 on a fully trained model effectively eliminates all of the typos even at a temperature of 1; I've set this as the default and recommend it, because with a wider beam the generated text is noticeably conservative in odd ways. I hope this is useful!
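For readers who want to see the shape of the algorithm, here is a minimal plain-Lua sketch of character-level beam search. It is not pender's sample-beam.lua; `next_char_probs` is a hypothetical stand-in for a forward pass of a trained char-rnn that returns a table of next-character probabilities.

```lua
-- Minimal beam search sketch over a character-level language model.
-- `next_char_probs(seq)` is a hypothetical stand-in for one forward pass of a
-- trained char-rnn: it should return a table mapping each character in the
-- vocabulary to its probability of coming next after `seq`.
local function beam_search(next_char_probs, prime, beam_width, length)
  -- Each beam holds the text generated so far and its cumulative log-probability.
  local beams = { { text = prime, logprob = 0 } }
  for step = 1, length do
    local candidates = {}
    for _, beam in ipairs(beams) do
      local probs = next_char_probs(beam.text)
      for ch, p in pairs(probs) do
        if p > 0 then
          table.insert(candidates, {
            text = beam.text .. ch,
            logprob = beam.logprob + math.log(p),
          })
        end
      end
    end
    -- Keep only the `beam_width` highest-scoring extensions.
    table.sort(candidates, function(a, b) return a.logprob > b.logprob end)
    beams = {}
    for i = 1, math.min(beam_width, #candidates) do
      beams[i] = candidates[i]
    end
  end
  return beams[1].text  -- best-scoring sequence found
end
```

In this framing, the recommended default beam width of 2 simply means keeping only the two highest-scoring partial sequences at every step.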
pender: I've given it a shot. It seems to crash randomly with odd errors: lower temperatures & higher beam widths seem more likely to trigger the multinomial error - beam widths of 50+ hardly ever work (some sort of numerical issue?). Anyway, when it does work, I see what you mean about it being "noticeably conservative". Example: different RNN, beam of 11, fixed seed, and primetext "The Dragon Reborn ": No prime: Even an extremely wide beam of 50 doesn't help with this issue: Same prime, but beam of 2: (Kind of odd - I had been expecting b>10 to be a waste of computation, not to make results worse.) Looking a little more systematically at beam widths 1-20, using a char-rnn I trained on my website & IRC logs a while back, with the shell command:
I don't see any obvious quality trend other than that b=1 is the worst, and it's interesting that hyperlinks only begin to show up at higher beam widths. Going back to my CSS RNN, using high beam widths triggers the same repetition issue. What about my data-URI problem? I tried forcing it to manifest by providing the data-URI preamble as a primetext, which worked, since even 10k characters are not enough to break out of the URI trap: I did another loop generating 5k characters to see if at any particular beam width it would break out. However, this might reflect the idiosyncratic nature of my trained RNNs, with potential overfitting etc., since I'm not very good at this. If someone wants to try it out on other datasets like Tiny Shakespeare and see if there's similar catastrophic repetition at higher beam widths, that would be interesting. (Could pender's implementation be wrong? I still find it weird that a wider beam makes things much worse.) So to summarize:
I think it's recognized theoretically that widening the beam can result in an inferior result, and I seem to recall that the papers that mention beam search in the context of sequence-prediction nets generally use a beam width of 2, or another low value. But in this case my intuition is that a sequence-prediction net trained to predict the next single token will have stretches when it is very confident about the next token (e.g. when reciting a really long word, or after seeing the first three characters of the string "http://www."), and occasions when it is much less confident (e.g. when starting the first word of a new sentence). The wider your beam, the harder you're selecting for sequences that contain long stretches of predictable characters. So you end up in a loop of a particularly predictable string of text, stitched together at a location where it wasn't confident about what would come next anyway. Whatever minimal probability penalty it takes by looping back to the start of the phrase when it was in a toss-up situation anyway is outweighed by the gains of being so confident about the rest of the phrase. If I've got this right, then the flaw is in the net rather than the search algorithm, and if you searched exhaustively and found the sequence of a particular length with the globally highest cumulative probability, it would look similar.

The first crash that you noted above seems to be the result of a misspecified network name, and I suspect that the second is an issue that occurs when your beam is wider than the number of characters in your net's vocabulary. (Given the poor results of wide beams, I haven't done much testing of beam widths greater than 10, nor any for beam widths greater than 50.) I just pushed a fix to the latter.
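To make this argument concrete, here is a toy calculation with invented probabilities showing why a beam scored by cumulative log-probability can prefer looping back into a highly predictable phrase over starting a fresh, less certain sentence.

```lua
-- Toy numbers illustrating the argument above (all probabilities are invented).
-- Looping back into a memorized, highly predictable phrase costs a little at
-- the "toss-up" junction but is repaid by near-certain characters afterwards.
local function cumulative_logprob(char_probs)
  local total = 0
  for _, p in ipairs(char_probs) do total = total + math.log(p) end
  return total
end

-- Fresh sentence: moderately confident at every character.
local fresh  = { 0.25, 0.6, 0.5, 0.6, 0.5, 0.6, 0.5, 0.6 }
-- Repeated phrase: one unlikely pick at the junction, then near-certainty.
local looped = { 0.05, 0.95, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99 }

print(cumulative_logprob(fresh))   -- ~ -5.5
print(cumulative_logprob(looped))  -- ~ -3.1, so the loop wins the beam
```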
Thanks again to pender for the excellent 'sample-beam.lua', which greatly improves sampling results with char-rnn. I've also been testing this word-level version: https://github.com/larspars/word-rnn Is it possible to modify 'sample-beam.lua' to work properly at the word level? Currently the sampling output is concatenated: "Thequickbrownfoxjumpedoverthelazydog." instead of "The quick brown fox jumped over the lazy dog." I'm new to Lua/Torch and would appreciate any suggestions for modifying the code. This link provides a Python solution as a possible reference: http://stackoverflow.com/questions/8870261/how-to-split-text-without-spaces-into-list-of-words Cheers,
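One likely culprit, sketched without having checked either repository's internals: when sampling word-level tokens, the decoded tokens need to be joined with a space rather than concatenated directly. The `decode` function and its arguments below are made up for illustration, not taken from sample-beam.lua or word-rnn.

```lua
-- Hypothetical decode step for a word-level sampler: the only change needed
-- for readable output is to join tokens with a space instead of nothing.
local function decode(tokens, word_level)
  local sep = word_level and " " or ""
  return table.concat(tokens, sep)
end

print(decode({ "The", "quick", "brown", "fox" }, false)) --> Thequickbrownfox
print(decode({ "The", "quick", "brown", "fox" }, true))  --> The quick brown fox
```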
(This is a followup to my earlier comment on "The Unreasonable Effectiveness of Recurrent Neural Networks".)
Once a char-rnn is trained, the user wants to use the RNN to generate funny and plausible-sounding text by sampling from it. Currently, sample.lua does its generation in a greedy, temperature-based fashion, selecting one character at a time with occasional random picks of lower-probability characters. This works fairly well, but the local/greedy approach can yield suboptimal global picks - each individual character is sensible, but the overall paragraph can still be nonsensical. Sometimes sampling can get trapped in local minima as well. (I believe this is what happened when I tried out char-rnn on generating CSS and sampling would become 'stuck' on data URIs, where it could only keep emitting more random base-64 characters because it was too hard to reach the end-of-URI delimiter, leading to thousands of characters of very low quality and low probability.)

A better sampling strategy would be beam search. Such a strategy might look like: at each character, the b most probable next characters are sampled; to quote one description of how this applies to RNN sampling:
Beam search has been applied to RNN work before and generally yields better results, more fully exploiting what the RNNs have learned:
The downside (aside from the work of coding an implementation) is that for a large RNN which already takes something like 1 s/character, beam search will slow sampling down even further, since each emitted character now needs roughly one forward pass per beam.
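As a rough back-of-the-envelope illustration, assuming the naive case where every beam needs its own forward pass and a single forward pass costs the ~1 second quoted above:

```lua
-- Back-of-the-envelope cost estimate, assuming each beam needs its own forward
-- pass and one forward pass takes ~1 second (the figure quoted above).
local seconds_per_forward = 1
for _, b in ipairs({ 1, 2, 5, 10 }) do
  print(("beam width %2d: ~%d s per generated character"):format(b, b * seconds_per_forward))
end
```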
On the other hand, I don't think anyone is seriously trying to apply char-rnn to anything which needs high performance in generating text, and for the most part users would be happy to have better-sampled text at the cost of some more time when generating, with no need for additional training text, GPU training, hyperparameter search, or architectural innovation. I would suggest that it be enabled by default as well, so users don't need to evaluate the tradeoff themselves; a beam of 5 seems, from the earlier papers, adequate for better quality.