Permalink
Find file
Fetching contributors…
Cannot retrieve contributors at this time
115 lines (93 sloc) 2.83 KB
-- ref.: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
local AM = torch.class("torch.AliasMultinomial")
--- Constructor: precompute the alias tables for the given distribution.
-- @param probs 1D tensor of outcome probabilities (forwarded to `setup`)
function AM:__init(probs)
   local alias, keep = self:setup(probs)
   self.J = alias
   self.q = keep
end
--- Build the alias tables used by `draw` and `batchdraw`.
-- Implements the setup phase of the alias method: every outcome kk gets a
-- "keep" probability q[kk] and an alias outcome J[kk] that is returned
-- instead of kk when the keep test fails.
-- @param probs 1D tensor of outcome probabilities (asserted to be 1D).
--   NOTE(review): presumably expected to sum to 1; this is not checked here.
-- @return J LongTensor of size K holding the alias index for each outcome
--   (0 where no alias was assigned; such entries get q == 1)
-- @return q tensor of the same type as probs, size K, with values in [0, 1]
function AM:setup(probs)
   assert(probs:dim() == 1)
   local K = probs:nElement()
   local q = probs.new(K):zero()
   local J = torch.LongTensor(K):zero()

   -- Sort the outcomes into those with probabilities
   -- that are smaller and larger than 1/K (i.e. K*p < 1 or >= 1).
   local smaller, larger = {}, {}
   for kk = 1, K do
      q[kk] = K * probs[kk]
      if q[kk] < 1 then
         smaller[#smaller + 1] = kk
      else
         larger[#larger + 1] = kk
      end
   end

   -- Loop through and create little binary mixtures that
   -- appropriately allocate the larger outcomes over the
   -- overall uniform mixture.
   while #smaller > 0 and #larger > 0 do
      local small = table.remove(smaller)
      local large = table.remove(larger)

      J[small] = large
      -- The large outcome donates the mass needed to fill up the small slot.
      q[large] = q[large] - (1.0 - q[small])

      if q[large] < 1.0 then
         smaller[#smaller + 1] = large
      else
         larger[#larger + 1] = large
      end
   end

   assert(q:min() >= 0)
   -- Rescale so that no keep-probability exceeds 1.
   if q:max() > 1 then
      q:div(q:max())
   end
   assert(q:max() <= 1)

   if J:min() <= 0 then
      -- Sometimes a large index isn't added to J.
      -- Fix it by making the probability 1 so that J isn't indexed.
      for i = 1, K do
         if J[i] <= 0 then
            q[i] = 1
         end
      end
   end

   return J, q
end
--- Draw a single sample from the distribution.
-- Uses Lua's `math.random` for both the uniform slot choice and the keep test.
-- @return index in 1..K of the sampled outcome
function AM:draw()
   -- Keep these local: bare assignments would create/overwrite globals J and q.
   local J = self.J
   local q = self.q
   local K = J:nElement()

   -- Draw from the overall uniform mixture.
   local kk = math.random(1, K)

   -- Draw from the binary mixture, either keeping the
   -- small one, or choosing the associated larger one.
   if math.random() < q[kk] then
      return kk
   else
      return J[kk]
   end
end
--- Fill `output` with samples drawn from the distribution.
-- Vectorized counterpart of `draw`: work buffers are cached on `self`
-- (`_kk`, `_q`, `_mask`, `__kk`) and reused across calls.
-- @param output non-empty torch.LongTensor (both asserted); overwritten in place.
--   NOTE(review): `output:view(-1)` presumably requires a contiguous tensor.
-- @return the same `output` tensor, holding sampled indices
function AM:batchdraw(output)
   assert(torch.type(output) == 'torch.LongTensor')
   assert(output:nElement() > 0)
   local J = self.J
   local K = J:nElement()

   -- Uniformly pick a slot kk for every output element.
   self._kk = self._kk or output.new()
   self._kk:resizeAs(output):random(1, K)

   -- Gather the keep-probability of each picked slot.
   self._q = self._q or self.q.new()
   self._q:index(self.q, 1, self._kk:view(-1))

   -- mask[i] == 1 means "keep kk[i]", 0 means "use the alias J[kk[i]]".
   -- Reuse the cached mask buffer (previously read a never-set field `_b`,
   -- which forced a fresh allocation on every call).
   self._mask = self._mask or torch.LongTensor()
   self._mask:resize(self._q:size()):bernoulli(self._q)

   -- Copy of kk with the non-kept entries zeroed out.
   self.__kk = self.__kk or output.new()
   self.__kk:resize(self._kk:size()):copy(self._kk)
   self.__kk:cmul(self._mask)

   -- if mask == 0 then output[i] = J[kk[i]] else output[i] = 0
   self._mask:add(-1):mul(-1) -- (1,0) -> (0,1)
   output:view(-1):index(J, 1, self._kk:view(-1))
   output:cmul(self._mask)

   -- elseif mask == 1 then output[i] = kk[i]
   output:add(self.__kk)

   return output
end