Permalink
Browse files

Added performance testing and reorganized

  • Loading branch information...
1 parent f5820c0 commit c574a3501624e53120d43429c5cba15e8f0f80dd @mikeizbicki committed Mar 1, 2012
Showing with 375 additions and 323 deletions.
  1. +164 −128 HMM.hs
  2. +0 −195 HMM2.hs
  3. +77 −0 HMMPerf.hs
  4. +134 −0 OldHMM.hs
View
292 HMM.hs
@@ -1,139 +1,175 @@
-{-# LANGUAGE ParallelListComp #-}
-{-
-Adapted from Hackage.HMM
--}
-
-module HMM
- (Prob, HMM, HMM.HMM, train, bestSequence, sequenceProb)
+module HMM
+ ( HMM(..), rnf
+ , forward
+ , backward
+ , baumWelch
+ )
where
-
-import qualified Data.Map as M
-import Data.List (sort, groupBy, maximumBy, foldl')
-import Data.Maybe (fromMaybe, fromJust)
-import Data.Ord (comparing)
-import Data.Function (on)
-import Control.Monad
+import Debug.Trace
+import Data.Array
+import Data.List
import Data.Number.LogFloat
-
+import qualified Data.MemoCombinators as Memo
+import Control.DeepSeq
type Prob = LogFloat
--- | The type of Hidden Markov Models.
-data HMM state observation = HMM [state] [Prob] [[Prob]] (observation -> [Prob])
+ -- | The data type for our HMM
+
+data -- (Eq eventType, Eq stateType, Show eventType, Show stateType) =>
+ HMM stateType eventType = HMM { states :: [stateType]
+ , events :: [eventType]
+ , initProbs :: (stateType -> Prob)
+ , transMatrix :: (stateType -> stateType -> Prob)
+ , outMatrix :: (stateType -> eventType -> Prob)
+ }
+
+instance NFData (HMM stateType eventType) where
+ rnf a = a `seq` ()
instance (Show state, Show observation) => Show (HMM state observation) where
- show (HMM states probs tpm _) = "HMM " ++ show states ++ " "
- ++ show probs ++ " " ++ show tpm ++ " <func>"
-
--- | Perform a single step in the Viterbi algorithm.
---
--- Takes a list of path probabilities, and an observation, and returns the updated
--- list of (surviving) paths with probabilities.
-viterbi :: HMM state observation
- -> [(Prob, [state])]
- -> observation
- -> [(Prob, [state])]
-viterbi (HMM states _ state_transitions observations) prev x =
- deepSeq prev `seq`
- [maximumBy (comparing fst)
- [(transition_prob * prev_prob * observation_prob,
- new_state:path)
- | transition_prob <- transition_probs
- | (prev_prob, path) <- prev
- | observation_prob <- observation_probs]
- | transition_probs <- state_transitions
- | new_state <- states]
+ show hmm = "HMM" ++ " states=" ++ (show $ states hmm)
+ ++ " events=" ++ (show $ events hmm)
+ ++ " initProbs=" ++ (show [(s,initProbs hmm s) | s <- states hmm])
+ ++ " transMatrix=" ++ (show [(s1,s2,transMatrix hmm s1 s2) | s1 <- states hmm, s2 <- states hmm])
+ ++ " outMatrix=" ++ (show [(s,e,outMatrix hmm s e) | s <- states hmm, e <- events hmm])
+
+stateIndex :: (Show stateType, Show eventType, Eq stateType) => HMM stateType eventType -> stateType -> Int
+stateIndex hmm state = case elemIndex state $ states hmm of
+ Nothing -> seq (error "stateIndex: Index "++show state++" not in HMM "++show hmm) 0
+ Just x -> x
+
+eventIndex :: (Show stateType, Show eventType, Eq eventType) => HMM stateType eventType -> eventType -> Int
+eventIndex hmm event = case elemIndex event $ events hmm of
+ Nothing -> seq (error "stateIndex: Index "++show event++" not in HMM "++show hmm) 0
+ Just x -> x
+
+ -- | forward algorithm
+
+forward :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> [eventType] -> Prob
+forward hmm obs = forwardArray hmm (listArray (1,bT) obs)
where
- observation_probs = observations x
- deepSeq ((x, y:ys):xs) = x `seq` y `seq` (deepSeq xs)
- deepSeq ((x, _):xs) = x `seq` (deepSeq xs)
- deepSeq [] = []
-
--- | The initial value for the Viterbi algorithm
-viterbi_init :: HMM state observation -> [(Prob, [state])]
-viterbi_init (HMM states state_probs _ _) = zip state_probs (map (:[]) states)
-
--- | Perform a single step of the forward algorithm
---
--- Each item in the input and output list is the probability that the system
--- ended in the respective state.
-forward :: HMM state observation
- -> [Prob]
- -> observation
- -> [Prob]
-forward (HMM _ _ state_transitions observations) prev x =
- last prev `seq`
- [sum [transition_prob * prev_prob * observation_prob
- | transition_prob <- transition_probs
- | prev_prob <- prev
- | observation_prob <- observation_probs]
- | transition_probs <- state_transitions]
+ bT = length obs
+
+forwardArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Prob
+forwardArray hmm obs = sum [alpha hmm obs bT state | state <- states hmm]
where
- observation_probs = observations x
-
--- | The initial value for the forward algorithm
-forward_init :: HMM state observation -> [Prob]
-forward_init (HMM _ state_probs _ _) = state_probs
-
-learn_states :: (Ord state, Fractional prob) => [(observation, state)] -> M.Map state prob
-learn_states xs = histogram $ map snd xs
-
-learn_transitions :: (Ord state, Fractional prob) => [(observation, state)] -> M.Map (state, state) prob
-learn_transitions xs = let xs' = map snd xs in
- histogram $ zip xs' (tail xs')
-
-learn_observations :: (Ord state, Ord observation, Fractional prob) =>
- M.Map state prob
- -> [(observation, state)]
- -> M.Map (observation, state) prob
-learn_observations state_prob = M.mapWithKey (\ (observation, state) prob -> prob / (fromJust $ M.lookup state state_prob))
- . histogram
-
-histogram :: (Ord a, Fractional prob) => [a] -> M.Map a prob
-histogram xs = let hist = foldl' (flip $ flip (M.insertWith (+)) 1) M.empty xs in
- M.map (/ M.fold (+) 0 hist) hist
-
--- | Calculate the parameters of an HMM from a list of observations
--- and the corresponding states.
-train :: (Ord observation, Ord state) =>
- [(observation, state)]
- -> HMM state observation
-train sample = model
+ bT = snd $ bounds obs
+
+alpha :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
+ -> Array Int eventType
+ -> Int
+ -> stateType
+ -> Prob
+alpha hmm obs = memo_alpha
+ where memo_alpha t state = memo_alpha2 t (stateIndex hmm state)
+ memo_alpha2 = (Memo.memo2 Memo.integral Memo.integral memo_alpha3)
+ memo_alpha3 t' state'
+ | t' == 1 = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $
+ (outMatrix hmm (states hmm !! state') $ obs!t')*(initProbs hmm $ states hmm !! state')
+ | otherwise = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $
+ (outMatrix hmm (states hmm !! state') $ obs!t')*(sum [(memo_alpha (t'-1) state2)*(transMatrix hmm state2 (states hmm !! state')) | state2 <- states hmm])
+
+
+ -- | backwards algorithm
+
+backward :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> [eventType] -> Prob
+backward hmm obs = backwardArray hmm $ listArray (1,length obs) obs
+
+backwardArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Prob
+backwardArray hmm obs = backwardArray' hmm obs
+ where
+ backwardArray' hmm obs = sum [(initProbs hmm state)
+ *(outMatrix hmm state $ obs!1)
+ *(beta hmm obs 1 state)
+ | state <- states hmm
+ ]
+
+beta :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
+ -> Array Int eventType
+ -> Int
+ -> stateType
+ -> Prob
+beta hmm obs = memo_beta
+ where bT = snd $ bounds obs
+ memo_beta t state = memo_beta2 t (stateIndex hmm state)
+ memo_beta2 = (Memo.memo2 Memo.integral Memo.integral memo_beta3)
+ memo_beta3 t' state'
+ | t' == bT = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $
+ 1
+ | otherwise = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $
+ sum [(transMatrix hmm (states hmm !! state') state2)
+ *(outMatrix hmm state2 $ obs!(t'+1))
+ *(memo_beta (t'+1) state2)
+ | state2 <- states hmm
+ ]
+
+
+ -- | Baum-Welch
+
+gammaArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
+ -> Array Int eventType
+ -> Int
+ -> stateType
+ -> Prob
+gammaArray hmm obs t state = (alpha hmm obs t state)
+ *(beta hmm obs t state)
+ /(backwardArray hmm obs)
+
+ -- xi i j = P(state (t-1) == i && state (t) == j | obs, lambda)
+
+xiArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
+ -> Array Int eventType
+ -> Int
+ -> stateType
+ -> stateType
+ -> Prob
+xiArray hmm obs t state1 state2 = (alpha hmm obs (t-1) state1)
+ *(transMatrix hmm state1 state2)
+ *(outMatrix hmm state2 $ obs!t)
+ *(beta hmm obs t state2)
+ /(backwardArray hmm obs)
+
+baumWelch :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Int -> HMM stateType eventType
+baumWelch hmm obs count
+ | count == 0 = hmm
+ | otherwise = baumWelch (baumWelchItr hmm obs) obs (count-1)
+
+baumWelchItr :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> HMM stateType eventType
+baumWelchItr hmm obs = HMM { states = states hmm
+ , events = events hmm
+ , initProbs = newInitProbs
+ , transMatrix = newTransMatrix
+ , outMatrix = newOutMatrix
+ }
+ where newInitProbs state = gammaArray hmm obs 1 state
+ newTransMatrix state1 state2 = sum [xiArray hmm obs t state1 state2 | t <- [2..(snd $ bounds obs)]]
+ /sum [gammaArray hmm obs t state1 | t <- [2..(snd $ bounds obs)]]
+ newOutMatrix state event = sum [if (obs!t == event)
+ then gammaArray hmm obs t state
+ else 0
+ | t <- [2..(snd $ bounds obs)]
+ ]
+ /sum [gammaArray hmm obs t state | t <- [2..(snd $ bounds obs)]]
+
+ -- | utility functions
+ --
+ -- | takes the cross product of a list multiple times
+
+listCPExp :: [a] -> Int -> [[a]]
+listCPExp language order = listCPExp' order [[]]
where
- states = learn_states sample
- state_list = M.keys states
-
- transitions = learn_transitions sample
- trans_prob_mtx = [[fromMaybe 1e-10 $ M.lookup (old_state, new_state) transitions
- | old_state <- state_list]
- | new_state <- state_list]
-
- observations = learn_observations states sample
- observation_probs = fromMaybe (fill state_list []) . (flip M.lookup $
- M.fromList $ map (\ (e, xs) -> (e, fill state_list xs)) $
- map (\ xs -> (fst $ head xs, map snd xs)) $
- groupBy ((==) `on` fst)
- [(observation, (state, prob))
- | ((observation, state), prob) <- M.toAscList observations])
-
- initial = map (\ state -> (fromJust $ M.lookup state states, [state])) state_list
-
- model = HMM state_list (fill state_list $ M.toAscList states) trans_prob_mtx observation_probs
-
- fill :: Eq state => [state] -> [(state, Prob)] -> [Prob]
- fill states [] = map (const 1e-10) states
- fill (s:states) xs@((s', p):xs') = if s /= s' then
- 1e-10 : fill states xs
- else
- p : fill states xs'
-
--- | Calculate the most likely sequence of states for a given sequence of observations
--- using Viterbi's algorithm
-bestSequence :: (Ord observation) => HMM state observation -> [observation] -> [state]
-bestSequence hmm = (reverse . tail . snd . (maximumBy (comparing fst))) . (foldl' (viterbi hmm) (viterbi_init hmm))
-
--- | Calculate the probability of a given sequence of observations
--- using the forward algorithm.
-sequenceProb :: (Ord observation) => HMM state observation -> [observation] -> Prob
-sequenceProb hmm = sum . (foldl' (forward hmm) (forward_init hmm))
+ listCPExp' order list
+ | order == 0 = list
+ | otherwise = listCPExp' (order-1) [symbol:l | l <- list, symbol <- language]
+
+ -- | tests
+
+-- these should equal ~1 if our recurrence in alpha is correct
+
+forwardtest hmm x = sum [forward hmm e | e <- listCPExp (events hmm) x]
+backwardtest hmm x = sum [backward hmm e | e <- listCPExp (events hmm) x]
+
+fbtest hmm events = "fwd: " ++ show (forward hmm events) ++ " bkwd:" ++ show (backward hmm events)
+
Oops, something went wrong.

0 comments on commit c574a35

Please sign in to comment.