Commit

cleanup code
ketranm committed Oct 28, 2016
1 parent b0ce5bb commit 5b66851
Showing 13 changed files with 1,612 additions and 0 deletions.
248 changes: 248 additions & 0 deletions BaumWelch.lua
@@ -0,0 +1,248 @@
--[[ Message passing for a first-order HMM (linear-chain model)
Reference:
[1] A tutorial on hidden Markov models and selected applications
in speech recognition.
L. R. Rabiner; AT&T Bell Labs, Murray Hill, NJ, USA
[2] http://www.cs.cmu.edu/~roni/11761-s16/assignments/shen_tutorial.pdf
Author: Ke Tran <m.k.tran@uva.nl>
NOTE: This code is written for GPUs, and for the love of speed.
For that reason, I intentionally use scaling factors instead of log-sum-exp when
computing the alpha and beta messages (multiplying is much faster).
The downside of using scaling factors is that we sometimes cannot compute
those messages exactly due to numerical error.
This happens most often at the beginning of learning.
Nevertheless, when I tested it (set debug = true),
the approximation error stays within an acceptable tolerance.
]]
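
-- NOTE (editorial, illustrative only): utils.renorm is defined elsewhere in this
-- repository. Judging from its call sites below, it is assumed to normalize a
-- tensor in place along the given dimension and to return the per-slice sums,
-- i.e. the inverse scaling factors 1/c_t in Rabiner's notation. A minimal sketch
-- of that assumed contract:
local function renorm_sketch(x, dim)
   local s = x:sum(dim)    -- per-slice normalizer, e.g. (N, 1) for an (N, K) slice
   x:cdiv(s:expandAs(x))   -- x now sums to 1 along `dim`
   return s                -- callers store this as the scale
end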

local BaumWelch, parent = torch.class('nn.BaumWelch', 'nn.Module')
local utils = require 'utils'

function BaumWelch:__init(padidx)
self.padidx = padidx
-- for message passing algorithm
self.alpha = torch.Tensor()
self.beta = torch.Tensor()
self.gamma = torch.Tensor()
self.eta = torch.Tensor()
self.scale = torch.Tensor()

-- BUFFER TENSOR
self.prob_trans = torch.Tensor()
self.prob_emiss = torch.Tensor()
self.prob_prior = torch.Tensor()
self.buffer = torch.Tensor()

self.debug = true
end


function BaumWelch:run(input, stats)
local N = input:size(1) -- batch size
local T = input:size(2) -- sequence length
local buffer = self.buffer
local prob_prior = self.prob_prior
local prob_trans = self.prob_trans
local prob_emiss = self.prob_emiss

local log_emiss, log_trans, log_prior = unpack(stats)

local K = log_prior:numel()

prob_prior:resizeAs(log_prior):copy(log_prior):exp()
prob_trans:resizeAs(log_trans):copy(log_trans):exp()
prob_emiss:resizeAs(log_emiss):copy(log_emiss):exp()

local masked = input:ne(self.padidx)
-- we need this to compute the log-likelihood correctly
local masked_pad = input:eq(self.padidx)

-- Message Passing

-- convenient aliases
local alpha = self.alpha
local beta = self.beta
local gamma = self.gamma
local scale = self.scale
local eta = self.eta

-- FORWARD MESSAGE

alpha:resize(N, T, K):zero()
scale:resize(N, T):zero()

-- (1) compute the first alpha
local a1 = alpha[{{}, 1, {}}]
a1:add(prob_prior:view(1, -1):expand(N, K))
a1:cmul(prob_emiss[{{}, 1}])

-- rescale
scale[{{}, {1}}] = utils.renorm(alpha[{{}, 1}], 2)

-- (2) compute the rest of alpha
for t = 2, T do
local emi_t = prob_emiss[{{}, {t}}]
local curr_a = alpha[{{}, {t}}] -- (N, 1, K)
local prev_a = alpha[{{}, {t-1}}]
local tran_t = prob_trans[{{}, t-1}]
-- the transition matrix is row-stochastic: each row sums to 1
curr_a:bmm(prev_a, tran_t):cmul(emi_t)
scale[{{}, {t}}] = utils.renorm(alpha[{{}, t}], 2)
end
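-- The loop above is the scaled forward recursion of [1], with position-dependent
-- transition probabilities:
--   alpha_t(j) = [ sum_i alpha_{t-1}(i) * A_{t-1}(i, j) ] * B_j(x_t),
-- each row is then divided by its sum, and those sums (kept in `scale`) are all
-- we need later to recover the log-likelihood.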

-- BACKWARD MESSAGE

beta:resize(N, T, K):fill(1)

-- we store the inverse of the scaling factor (the per-step normalizer), so divide
beta[{{}, T}]:cdiv(scale[{{}, {T}}]:expand(N, K))
buffer:resize(N, 1, K)

-- NOTE: if padding occupies the last index, we need to overwrite beta.
-- Check the sequence boundary where padding appears for the first time:
-- eos marks the true end of sequence; after that point only padding appears.
-- Because padding always sits at the end of a sequence,
-- we can do this in one go.

local eos = masked[{{}, {1, T-1}}]:ne(masked[{{}, {2, T}}])
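-- Example (matching the eta comment below): with padidx = 0 and
-- input  = 1 2 3 4 5 0 0 0   (T = 8)
-- masked = 1 1 1 1 1 0 0 0
-- eos    = 0 0 0 0 1 0 0     (length T-1; fires at the last real symbol)
-- so beta is re-initialized (set to 1, then rescaled) exactly where each true
-- sequence ends.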
for t = T-1, 1, -1 do
local eos_t = eos[{{}, {t}}]:expand(N, K)
local emi_t = prob_emiss[{{}, {t+1}}]
local prev_b = beta[{{}, {t+1}}]
local curr_b = beta[{{}, {t}}]
buffer:cmul(prev_b, emi_t)
local tran_t = prob_trans[{{}, t}]
curr_b:bmm(buffer, tran_t:transpose(2, 3))
if eos_t:sum() > 0 then
curr_b:maskedFill(eos_t, 1)
end
curr_b:cdiv(scale[{{}, {t}}]:expand(N, K))
end

-- compute posteriors
-- NOTE: the beta message is only computed correctly up to the EOS symbol;
-- after that it is wrong, but we keep it for the sake of speed.
-- gamma is computed correctly: we use the mask to zero out
-- padded positions.
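-- With these scaled messages, gamma_t(k) = p(z_t = k | x) is proportional to
-- alpha_t(k) * beta_t(k) (eq. (110) in [2] also multiplies in scale_t, which is
-- constant over k), so an element-wise product followed by renormalization
-- over k is enough.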

gamma:resize(N, T, K):zero()
for t = 1, T do
local gamma_t = gamma[{{}, {t}}]
gamma_t:cmul(alpha[{{}, {t}}], beta[{{}, {t}}])
-- NOTE: un-comment for debugging purposes
--[[
gamma_t:cmul(scale[{{}, {t}}]:expand(N, K)) -- partial term of eq. (110) in [2]
if self.debug then
-- check correctness of p(z_t | x): \sum_z p(z_t | x) = 1
local checksum = gamma[{{}, t}]:sum(2):add(-1):cmul(masked[{{}, t}]):abs():sum()
assert(checksum < 1e-2, string.format('gamma checksum error %.7f', checksum))
end
]]
end
-- comment out the following line (renorm) in debugging mode
utils.renorm(gamma, 3)
gamma:cmul(masked:view(N, T, 1):expand(N, T, K))

--[[ Compute eta.
Eta is only needed from the beginning of the sequence
up to the index just before the real end of the sequence.
For example, if 0 is used to indicate EOS, then for
input = 1 2 3 4 5 0 0 0
the etas are only needed for indices 1 2 3 4,
which correspond to the transitions (1, 2), (2, 3), ..., (4, 5).
So to compute the correct eta at time step t, we need to know whether
the word at time step t+1 is an EOS or not.
]]
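
-- For reference, the quantity computed below is
--   eta_t(i, j) = p(z_t = i, z_{t+1} = j | x)
--               ~ alpha_t(i) * A_t(i, j) * B_j(x_{t+1}) * beta_{t+1}(j),
-- renormalized over all (i, j) pairs.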

eta:resize(N, T, K, K):zero()
for t = 1, T-1 do
local emi_t = prob_emiss[{{}, {t+1}}]
local bmsg = beta[{{}, {t+1}}]
local amsg = alpha[{{}, {t}}]
local tran_t = prob_trans[{{}, t}]
local eta_t = eta[{{}, t}]
-- NOTE: un-comment for debugging
--[[
bmsg:cmul(emi_t):cmul(masked[{{}, {t+1}}]:expand(N, K))
eta_t:bmm(amsg:transpose(2, 3), bmsg):cmul(tran_t)
if self.debug then
-- what happens here: when we see the last real symbol before padding (EOS) at time t,
-- p(z_t | x) exists but p(z_t, z_{t+1} | x) does not,
-- so we have to zero out p(z_t | x) when checking
local eos_t = eos[{{}, {t}}] -- check for end of sequence
local gamma_t = gamma[{{}, t}]
local derr = eta_t:sum(3):squeeze():add(-1, gamma_t)
derr:maskedFill(eos_t:expand(N, K), 0)
local checksum = derr:abs():sum()
assert(checksum < 1e-3, string.format('eta checksum error %.7f', checksum))
-- good job Ke! This is a pain in the ass.
end
]]
-- comment out the following lines in debugging mode
bmsg:cmul(emi_t)
eta_t:bmm(amsg:transpose(2, 3), bmsg):cmul(tran_t) -- will be N, K, K
local z = eta_t:sum(2):sum(3):expand(N, K, K)
eta_t:cdiv(z):cmul(masked[{{}, {t+1}}]:contiguous():view(N, 1, 1):expand(N, K, K))
--eta_t:cdiv(z):cmul(masked[{{}, {t}}]:contiguous():view(N, 1, 1):expand(N, K, K))
end

local prior = gamma[{{}, 1}]:sum(1):squeeze()

scale:maskedFill(masked_pad, 1)
local loglik = scale:clone():log():sum() / masked:sum()
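-- since p(x) = prod_t scale_t, summing log(scale) and dividing by the number of
-- real tokens gives the per-token log-likelihood; padded positions contribute
-- log(1) = 0 thanks to the maskedFill above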
-- return posteriors
return {prior, eta, gamma}, loglik
end

function BaumWelch:argmax(input, stats)
-- Viterbi decoding: we only need the alpha (max-product) message
local T = input:numel()
local prob_prior = self.prob_prior
local prob_trans = self.prob_trans
local prob_emiss = self.prob_emiss

local log_emiss, log_trans, log_prior = unpack(stats)

local K = log_prior:numel()

prob_prior:resizeAs(log_prior):copy(log_prior):exp()
prob_trans:resizeAs(log_trans):copy(log_trans):exp()
prob_emiss:resizeAs(log_emiss):copy(log_emiss):exp()

local alpha = self.alpha
alpha:resize(T, K):zero()

-- borrow viterbi-path implementation from Kevin Murphy
-- psi[t][j]: the best predecessor state,
-- given that we ended up in state j at t
local psi = torch.zeros(T, K):typeAs(alpha)

local a1 = alpha[{1, {}}]
a1:add(prob_prior)
a1:cmul(prob_emiss[{{}, 1}])
utils.renorm(alpha[{{1}, {}}], 2)

for t = 2, T do
local emi_t = prob_emiss[{{}, t}]
local curr_a = alpha[t]
local prev_a = alpha[t-1]
local z = prev_a:view(-1, 1):repeatTensor(1, K)
z:cmul(prob_trans[{{}, t-1}])
local val, idx = z:max(1)
psi[t]:copy(idx)
curr_a:copy(val)
curr_a:cmul(emi_t)
utils.renorm(curr_a, 1)
end
local val, idx = alpha[{T, {}}]:max(1)
local path = torch.zeros(T)
path[T] = idx[1]
for t = T-1, 1, -1 do
path[t] = psi[t+1][path[t+1]]
end
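-- Backtracking example (hypothetical numbers): with T = 3, K = 2,
-- psi = {{0, 0}, {2, 1}, {1, 1}} and best final state 1, we get
-- path[3] = 1, path[2] = psi[3][1] = 1, path[1] = psi[2][1] = 2.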

return path
end
85 changes: 85 additions & 0 deletions EmiConv.lua
@@ -0,0 +1,85 @@
local factory = require 'factory'

local model, parent = torch.class('nn.EmiConv', 'nn.Module')
function model:__init(word2char, nvars, feature_maps, kernels, charsize, hidsize)
local K = nvars
self.word2char = word2char
local char_dim = charsize
local H = hidsize
local featsize = torch.Tensor(feature_maps):sum()
local nchars = word2char:max()
local maxchars = word2char:size(2)
local V = word2char:size(1)

local char_cnn = factory.build_cnn(feature_maps, kernels, charsize, hidsize, nchars, maxchars)

local state_emb = nn.Sequential()
state_emb:add(nn.LookupTable(K, H))
state_emb:add(nn.ReLU())

local prl = nn.ParallelTable()
prl:add(state_emb)
prl:add(char_cnn)

local bias = nn.Linear(1, V, false) -- shared

local emi0b = nn.Sequential()
emi0b:add(prl)
emi0b:add(nn.MM(false, true))

local prlx = nn.ParallelTable()
prlx:add(emi0b)
prlx:add(bias)

local emi = nn.Sequential()
emi:add(prlx)
emi:add(nn.CAddTable())
emi:add(nn.LogSoftMax())
self.net = emi
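-- The emission table is K x V: state embeddings (K, H) are matched against
-- character-CNN word representations (assumed to be (V, H) coming out of
-- factory.build_cnn) via nn.MM(false, true); a per-word bias shared across
-- all states is added and LogSoftMax normalizes over the vocabulary.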

self._input = {{torch.range(1, K), self.word2char}, torch.ones(K, 1)}
self.gradOutput = torch.Tensor(K, V)
self._buffer = torch.Tensor()
end

function model:reset()
self.net:reset()
end

function model:training()
self.net:training()
parent.training(self)
end

function model:evaluate()
self.net:evaluate()
parent.evaluate(self)
end

function model:precompute()
self._cache = self.net:forward(self._input)
end

function model:log_prob(input)
local N, T = input:size(1), input:size(2)
if not self._cache then
self._logp = self.net:forward(self._input)
else
self._logp = self._cache
end
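-- self._logp is (K, V); indexing dim 2 with the flattened input gives (K, N*T),
-- and the view/transposes below reshape it to (N, T, K): the log emission
-- probability of each observed token under each of the K hidden states.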

return self._logp:index(2, input:view(-1)):view(-1, N, T):transpose(1, 2):transpose(2, 3)
end

function model:update(input, gradOutput)
local N, T = input:size(1), input:size(2)
local dx = gradOutput:transpose(2, 3):transpose(1, 2)
self._buffer:resizeAs(dx):copy(dx)
self.gradOutput:zero()
self.gradOutput:indexAdd(2, input:view(-1), self._buffer:view(-1, N * T))
self.net:backward(self._input, self.gradOutput)
end

function model:parameters()
return self.net:parameters()
end
53 changes: 53 additions & 0 deletions Emission.lua
@@ -0,0 +1,53 @@
local model, parent = torch.class('nn.EmiNet', 'nn.Module')

function model:__init(nobs, nvars, hidsize)
local K, V, H = nvars, nobs, hidsize
self.net = nn.Sequential()
self.net:add(nn.LookupTable(K, H))
self.net:add(nn.ReLU())
self.net:add(nn.Linear(H, H))
self.net:add(nn.ReLU())
self.net:add(nn.Linear(H, V))
self.net:add(nn.LogSoftMax())

self._input = torch.range(1, K)
self.gradOutput = torch.Tensor(K, V)
self._buffer = torch.Tensor()
end

function model:reset()
self.net:reset()
end

function model:parameters()
return self.net:parameters()
end

function model:precompute()
self._cache = self.net:forward(self._input)
end

function model:log_prob(input)
local N, T = input:size(1), input:size(2)
if not self._cache then
self._logp = self.net:forward(self._input)
else
self._logp = self._cache
end

return self._logp:index(2, input:view(-1)):view(-1, N, T):transpose(1, 2):transpose(2, 3)
end

function model:update(input, gradOutput)
local N, T = input:size(1), input:size(2)
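-- gradOutput arrives as (N, T, K); it is permuted to (K, N*T) and scatter-added
-- into the (K, V) accumulator at the columns of the observed tokens, giving the
-- gradient w.r.t. the K x V log-emission table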
local dx = gradOutput:transpose(2, 3):transpose(1, 2)
self._buffer:resizeAs(dx):copy(dx)
self.gradOutput:zero()
self.gradOutput:indexAdd(2, input:view(-1), self._buffer:view(-1, N * T))
self.net:backward(self._input, self.gradOutput)
end


function model:parameters()
return self.net:parameters()
end