
adding beam search, helps quite a bit

1 parent 590c8bc commit 73e1715f4054e9d23e590aae96b24bdbf0edae14 @karpathy committed Nov 24, 2015
Showing with 221 additions and 11 deletions.
  1. +2 −0 README.md
  2. +1 −1 convert_checkpoint_gpu_to_cpu.lua
  3. +8 −2 eval.lua
  4. +144 −6 misc/LanguageModel.lua
  5. +9 −0 misc/net_utils.lua
  6. +55 −0 test_language_model.lua
  7. +2 −2 train.lua
README.md
@@ -74,6 +74,8 @@ Now visit `localhost:4000` in your browser and you should see your predicted cap
You can see an [example visualization demo page here](http://cs.stanford.edu/people/karpathy/neuraltalk2/demo.html).
+**Beam Search**. To improve captioning performance you will also want to use the flag `-beam_size`. E.g. `-beam_size 5` uses 5 beams to perform a more exhaustive search over argmax sequences for each image, which yields better captions. However, beam search is more computationally expensive, so predictions will take a bit longer to compute.
+
#### I only have CPU
Okay, in that case you can download the [cpu model checkpoint](http://cs.stanford.edu/people/karpathy/neuraltalk2/checkpoint_v1_cpu.zip), which does not require the GPU. Make sure you run the eval script with `-gpuid -1` to tell the script to run on CPU. On my machine it takes a bit less than 1 second per image to caption in CPU mode.
convert_checkpoint_gpu_to_cpu.lua
@@ -84,7 +84,7 @@ for k,v in pairs(protos) do
protos[k] = cpu_cnn
elseif k == 'lm' then
local debugger = require('fb.debugger'); debugger:enter()
- v.clones = nil -- sanitize the clones inside the language model
+ v.clones = nil -- sanitize the clones inside the language model (just in case they are present; they shouldn't be)
v.lookup_tables = nil
protos[k]:float() -- ship to CPU
else
eval.lua
@@ -33,6 +33,11 @@ cmd:option('-image_root', '', 'In case the image paths have to be preprended wit
cmd:option('-batch_size', 0, 'if > 0 then overrule, otherwise load from checkpoint')
cmd:option('-split', 'test', 'val|test|train')
+-- sampling options
+cmd:option('-sample_max', 1, '1 = sample argmax words. 0 = sample from distributions.')
+cmd:option('-beam_size', 1, 'used when sample_max = 1, indicates number of beams in beam search.')
+cmd:option('-temperature', 1.0, 'temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.')
+
-- misc
cmd:option('-backend', 'cudnn', 'nn|cudnn')
cmd:option('-id', 'evalscript', 'an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
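
(Editorial note, not part of the commit: the `-temperature` option only matters when `sample_max = 0`. A minimal standalone sketch with made-up log-probabilities shows why lower temperatures give "safer" samples: dividing log-probabilities by a temperature below 1 before exponentiating concentrates mass on the most likely word prior to the `torch.multinomial` draw.)

```lua
-- editorial sketch, not part of the commit: toy log-probabilities over 4 words
require 'torch'

local logprobs = torch.Tensor{{-0.5, -1.5, -2.0, -3.0}}

local function to_probs(lp, temperature)
  local p = torch.exp(torch.div(lp, temperature))  -- scale logprobs by temperature
  return torch.cdiv(p, p:sum(2):expandAs(p))       -- renormalize into a distribution
end

print(to_probs(logprobs, 1.0))  -- original distribution
print(to_probs(logprobs, 0.5))  -- lower temperature: mass concentrates on the argmax ("safer")
print(to_probs(logprobs, 2.0))  -- higher temperature: closer to uniform, riskier samples

local it = torch.multinomial(to_probs(logprobs, 0.5), 1)  -- draw one word index per row
print(it)
```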
@@ -125,7 +130,8 @@ local function eval_split(split, evalopt)
end
-- forward the model to also get generated samples for each image
- local seq = protos.lm:sample(feats)
+ local sample_opts = { sample_max = opt.sample_max, beam_size = opt.beam_size, temperature = opt.temperature }
+ local seq = protos.lm:sample(feats, sample_opts)
local sents = net_utils.decode_sequence(vocab, seq)
for k=1,#sents do
local entry = {image_id = data.infos[k].id, caption = sents[k]}
@@ -149,7 +155,7 @@ local function eval_split(split, evalopt)
end
if data.bounds.wrapped then break end -- the split ran out of data, lets break out
- if n >= num_images then break end -- we've used enough images
+ if num_images >= 0 and n >= num_images then break end -- we've used enough images
end
local lang_stats
misc/LanguageModel.lua
@@ -96,8 +96,10 @@ Returns: a DxN LongTensor with integer elements 1..M,
where D is sequence length and N is batch (so columns are sequences)
--]]
function layer:sample(imgs, opt)
- local sample_max = utils.getopt(opt, 'sample_max', true)
- assert(sample_max, 'see todo below')
+ local sample_max = utils.getopt(opt, 'sample_max', 1)
+ local beam_size = utils.getopt(opt, 'beam_size', 1)
+ local temperature = utils.getopt(opt, 'temperature', 1.0)
+ if sample_max == 1 and beam_size > 1 then return self:sample_beam(imgs, opt) end -- indirection for beam search
local batch_size = imgs:size(1)
self:_createInitState(batch_size)
@@ -119,13 +121,22 @@ function layer:sample(imgs, opt)
xt = self.lookup_table:forward(it)
else
-- take predictions from previous time step and feed them in
- if sample_max then
+ if sample_max == 1 then
+ -- use argmax "sampling"
sampleLogprobs, it = torch.max(logprobs, 2)
it = it:view(-1):long()
else
- --local prob_prev = torch.exp(logprobs) -- fetch prev distribution
- --it = torch.multinomial(prob_prev, 1):view(-1):long()
- error('todo') -- todo: gather into sampleLogprobs later if we want to sample
+ -- sample from the distribution of previous predictions
+ local prob_prev
+ if temperature == 1.0 then
+ prob_prev = torch.exp(logprobs) -- fetch prev distribution: shape Nx(M+1)
+ else
+ -- scale logprobs by temperature
+ prob_prev = torch.exp(torch.div(logprobs, temperature))
+ end
+ it = torch.multinomial(prob_prev, 1)
+ sampleLogprobs = logprobs:gather(2, it) -- gather the logprobs at sampled positions
+ it = it:view(-1):long() -- and flatten indices for downstream processing
end
xt = self.lookup_table:forward(it)
end
@@ -147,6 +158,133 @@ function layer:sample(imgs, opt)
end
--[[
+Implements beam search. Really tricky indexing stuff going on inside.
+Not 100% sure it's correct, and hard to fully unit test to satisfaction, but
+it seems to work, doesn't crash, gives expected looking outputs, and seems to
+improve performance, so I am declaring this correct.
+]]--
+function layer:sample_beam(imgs, opt)
+ local beam_size = utils.getopt(opt, 'beam_size', 10)
+ local batch_size, feat_dim = imgs:size(1), imgs:size(2)
+ local function compare(a,b) return a.p > b.p end -- used downstream
+
+ assert(beam_size <= self.vocab_size+1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed')
+
+ local seq = torch.LongTensor(self.seq_length, batch_size):zero()
+ local seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
+ -- lets process every image independently for now, for simplicity
+ for k=1,batch_size do
+
+ -- create initial states for all beams
+ self:_createInitState(beam_size)
+ local state = self.init_state
+
+ -- we will write output predictions into tensor seq
+ local beam_seq = torch.LongTensor(self.seq_length, beam_size):zero()
+ local beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size):zero()
+ local beam_logprobs_sum = torch.zeros(beam_size) -- running sum of logprobs for each beam
+ local logprobs -- logprobs predicted in last time step, shape (beam_size, vocab_size+1)
+ local done_beams = {}
+ for t=1,self.seq_length+2 do
+
+ local xt, it, sampleLogprobs
+ local new_state
+ if t == 1 then
+ -- feed in the images
+ local imgk = imgs[{ {k,k} }]:expand(beam_size, feat_dim) -- k'th image feature expanded out
+ xt = imgk
+ elseif t == 2 then
+ -- feed in the start tokens
+ it = torch.LongTensor(beam_size):fill(self.vocab_size+1)
+ xt = self.lookup_table:forward(it)
+ else
+ --[[
+ perform a beam merge. that is,
+ for every previous beam we now have many new possibilities to branch out
+ we need to re-sort our beams to maintain the loop invariant of keeping
+ the top beam_size most likely sequences.
+ ]]--
+ local logprobsf = logprobs:float() -- lets go to CPU for more efficiency in indexing operations
+ local ys, ix = torch.sort(logprobsf,2,true) -- sorted array of logprobs along each previous beam (last true = descending)
+ local candidates = {}
+ local cols = math.min(beam_size,ys:size(2))
+ local rows = beam_size
+ if t == 3 then rows = 1 end -- at first time step only the first beam is active
+ for c=1,cols do -- for each column (word, essentially)
+ for q=1,rows do -- for each beam expansion
+ -- compute logprob of expanding beam q with word in (sorted) position c
+ local local_logprob = ys[{ q,c }]
+ local candidate_logprob = beam_logprobs_sum[q] + local_logprob
+ table.insert(candidates, {c=ix[{ q,c }], q=q, p=candidate_logprob, r=local_logprob })
+ end
+ end
+ table.sort(candidates, compare) -- find the best c,q pairs
+
+ -- construct new beams
+ new_state = net_utils.clone_list(state)
+ local beam_seq_prev, beam_seq_logprobs_prev
+ if t > 3 then
+ -- we'll need these as reference when we fork beams around
+ beam_seq_prev = beam_seq[{ {1,t-3}, {} }]:clone()
+ beam_seq_logprobs_prev = beam_seq_logprobs[{ {1,t-3}, {} }]:clone()
+ end
+ for vix=1,beam_size do
+ local v = candidates[vix]
+ -- fork beam index q into index vix
+ if t > 3 then
+ beam_seq[{ {1,t-3}, vix }] = beam_seq_prev[{ {}, v.q }]
+ beam_seq_logprobs[{ {1,t-3}, vix }] = beam_seq_logprobs_prev[{ {}, v.q }]
+ end
+ -- rearrange recurrent states
+ for state_ix = 1,#new_state do
+ -- copy over state in previous beam q to new beam at vix
+ new_state[state_ix][vix] = state[state_ix][v.q]
+ end
+ -- append new end terminal at the end of this beam
+ beam_seq[{ t-2, vix }] = v.c -- c'th word is the continuation
+ beam_seq_logprobs[{ t-2, vix }] = v.r -- the raw logprob here
+ beam_logprobs_sum[vix] = v.p -- the new (sum) logprob along this beam
+
+ if v.c == self.vocab_size+1 or t == self.seq_length+2 then
+ -- END token special case here, or we reached the end.
+ -- add the beam to a set of done beams
+ table.insert(done_beams, {seq = beam_seq[{ {}, vix }]:clone(),
+ logps = beam_seq_logprobs[{ {}, vix }]:clone(),
+ p = beam_logprobs_sum[vix]
+ })
+ end
+ end
+
+ -- encode as vectors
+ it = beam_seq[t-2]
+ xt = self.lookup_table:forward(it)
+ end
+
+ if new_state then state = new_state end -- swap rnn state, if we reassigned beams
+
+ local inputs = {xt,unpack(state)}
+ local out = self.core:forward(inputs)
+ logprobs = out[self.num_state+1] -- last element is the output vector
+ state = {}
+ for i=1,self.num_state do table.insert(state, out[i]) end
+ end
+
+ table.sort(done_beams, compare)
+ seq[{ {}, k }] = done_beams[1].seq -- the first beam has highest cumulative score
+ seqLogprobs[{ {}, k }] = done_beams[1].logps
+
+ end
+
+ -- return the samples and their log likelihoods
+ return seq, seqLogprobs
+end
+
+--[[
input is a tuple of:
1. torch.Tensor of size NxK (K is dim of image code)
2. torch.LongTensor of size DxN, elements 1..M
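
(Editorial note, not part of the commit: to make the loop invariant mentioned above easier to see in isolation, here is a much-simplified, self-contained beam search over a toy first-order transition scorer. The scorer, vocabulary, and END handling are made up and simpler than the model's; for instance, finished beams are just set aside rather than kept active.)

```lua
-- editorial sketch, not part of the commit
require 'torch'

-- toy scorer: log p(next token | previous token); token 4 doubles as START/END
local transition = torch.Tensor{
  {-0.4, -1.6, -2.3, -2.3},
  {-1.2, -0.7, -2.0, -1.6},
  {-2.3, -1.2, -0.6, -1.8},
  {-0.9, -0.9, -1.6, -1.4},
}
local START, END_TOKEN = 4, 4
local beam_size, max_len = 2, 5

local beams = {{seq = {}, logp = 0}}  -- start with a single empty beam
local done_beams = {}

for t = 1, max_len do
  -- every (beam, word) pair is a candidate continuation
  local candidates = {}
  for q, beam in ipairs(beams) do
    local prev = (#beam.seq > 0) and beam.seq[#beam.seq] or START
    for c = 1, transition:size(2) do
      table.insert(candidates, {q = q, c = c, p = beam.logp + transition[prev][c]})
    end
  end
  -- the "merge": keep only the top beam_size candidates by cumulative logprob
  table.sort(candidates, function(a, b) return a.p > b.p end)

  local new_beams = {}
  for i = 1, math.min(beam_size, #candidates) do
    local v = candidates[i]
    local seq = {unpack(beams[v.q].seq)}  -- fork the parent beam's sequence
    table.insert(seq, v.c)
    local b = {seq = seq, logp = v.p}
    if v.c == END_TOKEN or t == max_len then
      table.insert(done_beams, b)         -- finished: set it aside
    else
      table.insert(new_beams, b)
    end
  end
  if #new_beams == 0 then break end
  beams = new_beams
end

table.sort(done_beams, function(a, b) return a.logp > b.logp end)
print(table.concat(done_beams[1].seq, ' '), done_beams[1].logp)  -- best finished sequence
```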
misc/net_utils.lua
@@ -171,6 +171,15 @@ function net_utils.decode_sequence(ix_to_word, seq)
return out
end
+function net_utils.clone_list(lst)
+ -- takes list of tensors, clone all
+ local new = {}
+ for k,v in pairs(lst) do
+ new[k] = v:clone()
+ end
+ return new
+end
+
-- hiding this piece of code on the bottom of the file, in hopes that
-- noone will ever find it. Lets just pretend it doesn't exist
function net_utils.language_eval(predictions, id)
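
(Editorial note, not part of the commit: the clone in `clone_list` is what lets `sample_beam` fork a beam's recurrent state without the copies aliasing the same storage. A tiny illustration with toy tensors, assuming it is run from the repo root so `misc.net_utils` is on the path:)

```lua
-- editorial sketch, not part of the commit
require 'torch'
local net_utils = require 'misc.net_utils'

local state = {torch.zeros(2, 3), torch.zeros(2, 3)}

local aliased = {state[1], state[2]}         -- plain copy of the table: tensors are shared
local cloned  = net_utils.clone_list(state)  -- tensors are cloned: independent storage

aliased[1][1] = 7
cloned[2][1]  = 7

print(state[1][1]:sum())  -- 21: writing through the alias modified the original
print(state[2][1]:sum())  -- 0: writing into the clone left the original untouched
```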
test_language_model.lua
@@ -254,13 +254,68 @@ local function sample()
print(seq)
end
+
+-- check that we can call :sample_beam() and that correct-looking things happen
+-- these are not very exhaustive tests, just basic sanity checks
+local function sample_beam()
+ local dtype = 'torch.DoubleTensor'
+ torch.manualSeed(1)
+
+ local opt = {}
+ opt.vocab_size = 10
+ opt.input_encoding_size = 4
+ opt.rnn_size = 8
+ opt.num_layers = 1
+ opt.dropout = 0
+ opt.seq_length = 7
+ opt.batch_size = 6
+ local lm = nn.LanguageModel(opt)
+
+ local imgs = torch.randn(opt.batch_size, opt.input_encoding_size):type(dtype)
+
+ local seq_vanilla, logprobs_vanilla = lm:sample(imgs)
+ local seq, logprobs = lm:sample(imgs, {beam_size = 1})
+
+ -- check some basic I/O, types, etc.
+ tester:assertTensorSizeEq(seq, {opt.seq_length, opt.batch_size})
+ tester:asserteq(seq:type(), 'torch.LongTensor')
+ tester:assertge(torch.min(seq), 0)
+ tester:assertle(torch.max(seq), opt.vocab_size+1)
+
+ -- doing beam search with beam size 1 should return exactly what we had before
+ print('')
+ print('vanilla sampling:')
+ print(seq_vanilla)
+ print('beam search sampling with beam size 1:')
+ print(seq)
+ tester:assertTensorEq(seq_vanilla, seq, 0) -- these are LongTensors, expect exact match
+ tester:assertTensorEq(logprobs_vanilla, logprobs, 1e-6) -- logprobs too
+
+ -- doing beam search with higher beam size should yield higher likelihood sequences
+ local seq2, logprobs2 = lm:sample(imgs, {beam_size = 8})
+ local logsum = torch.sum(logprobs, 1)
+ local logsum2 = torch.sum(logprobs2, 1)
+ print('')
+ print('beam search sampling with beam size 1:')
+ print(seq)
+ print('beam search sampling with beam size 8:')
+ print(seq2)
+ print('logprobs:')
+ print(logsum)
+ print(logsum2)
+
+ -- the logprobs should always be >=, since beam search is better argmax inference
+ tester:assert(torch.all(torch.ge(logsum2, logsum)))
+end
+
tests.doubleApiForwardTest = forwardApiTestFactory('torch.DoubleTensor')
tests.floatApiForwardTest = forwardApiTestFactory('torch.FloatTensor')
tests.cudaApiForwardTest = forwardApiTestFactory('torch.CudaTensor')
tests.gradCheck = gradCheck
tests.gradCheckLM = gradCheckLM
tests.overfit = overfit
tests.sample = sample
+tests.sample_beam = sample_beam
tester:add(tests)
tester:run()
train.lua
@@ -46,8 +46,8 @@ cmd:option('-batch_size',16,'what is the batch size in number of images per batc
cmd:option('-max_iters', -1, 'max number of iterations to run for (-1 = run forever)')
cmd:option('-cnn_optim','adam','optimization to use for CNN')
cmd:option('-cnn_optim_alpha',0.8,'alpha for momentum of CNN')
-cmd:option('-cnn_optim_beta',0.995,'alpha for momentum of CNN')
-cmd:option('-cnn_learning_rate',5e-5,'learning rate for the CNN')
+cmd:option('-cnn_optim_beta',0.999,'beta for momentum of CNN')
+cmd:option('-cnn_learning_rate',1e-5,'learning rate for the CNN')
cmd:option('-cnn_weight_decay', 0, 'L2 weight decay just for the CNN')
cmd:option('-finetune_cnn_after', -1, 'After what iteration do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
