Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

BW now does ~30000bp/min. Began test suite.

  • Loading branch information...
commit b3e59952d419d81c20711e72061d05d47935056d 1 parent c574a35
Mike Izbicki authored
Showing with 160 additions and 54 deletions.
  1. +56 −49 HMM.hs
  2. +37 −5 HMMPerf.hs
  3. +67 −0 HMMTest.hs
105 HMM.hs
View
@@ -3,6 +3,7 @@ module HMM
, forward
, backward
, baumWelch
+ , alpha, beta
)
where
@@ -12,6 +13,7 @@ import Data.List
import Data.Number.LogFloat
import qualified Data.MemoCombinators as Memo
import Control.DeepSeq
+import Control.Parallel
type Prob = LogFloat
@@ -108,68 +110,73 @@ beta hmm obs = memo_beta
-- | Baum-Welch
-gammaArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
+{-gammaArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
-> Array Int eventType
-> Int
-> stateType
- -> Prob
-gammaArray hmm obs t state = (alpha hmm obs t state)
- *(beta hmm obs t state)
- /(backwardArray hmm obs)
+ -> Prob-}
-- xi i j = P(state (t-1) == i && state (t) == j | obs, lambda)
-xiArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
- -> Array Int eventType
- -> Int
- -> stateType
- -> stateType
- -> Prob
-xiArray hmm obs t state1 state2 = (alpha hmm obs (t-1) state1)
- *(transMatrix hmm state1 state2)
- *(outMatrix hmm state2 $ obs!t)
- *(beta hmm obs t state2)
- /(backwardArray hmm obs)
+-- xiArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
+-- -> Array Int eventType
+-- -> Int
+-- -> stateType
+-- -> stateType
+-- -> Prob
+
baumWelch :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Int -> HMM stateType eventType
baumWelch hmm obs count
| count == 0 = hmm
- | otherwise = baumWelch (baumWelchItr hmm obs) obs (count-1)
+ | otherwise = itr `seq` baumWelch itr obs (count-1)
+ where itr = baumWelchItr hmm obs
baumWelchItr :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> HMM stateType eventType
-baumWelchItr hmm obs = HMM { states = states hmm
+baumWelchItr hmm obs = --par newInitProbs $ par newTransMatrix $ par newOutMatrix
+ HMM { states = states hmm
, events = events hmm
, initProbs = newInitProbs
- , transMatrix = newTransMatrix
- , outMatrix = newOutMatrix
+ , transMatrix = {-transMatrix hmm-} newTransMatrix
+ , outMatrix = {-outMatrix hmm-} newOutMatrix
}
- where newInitProbs state = gammaArray hmm obs 1 state
- newTransMatrix state1 state2 = sum [xiArray hmm obs t state1 state2 | t <- [2..(snd $ bounds obs)]]
- /sum [gammaArray hmm obs t state1 | t <- [2..(snd $ bounds obs)]]
- newOutMatrix state event = sum [if (obs!t == event)
- then gammaArray hmm obs t state
- else 0
- | t <- [2..(snd $ bounds obs)]
- ]
- /sum [gammaArray hmm obs t state | t <- [2..(snd $ bounds obs)]]
-
- -- | utility functions
- --
- -- | takes the cross product of a list multiple times
-
-listCPExp :: [a] -> Int -> [[a]]
-listCPExp language order = listCPExp' order [[]]
- where
- listCPExp' order list
- | order == 0 = list
- | otherwise = listCPExp' (order-1) [symbol:l | l <- list, symbol <- language]
-
- -- | tests
-
--- these should equal ~1 if our recurrence in alpha is correct
-
-forwardtest hmm x = sum [forward hmm e | e <- listCPExp (events hmm) x]
-backwardtest hmm x = sum [backward hmm e | e <- listCPExp (events hmm) x]
+ where bT = snd $ bounds obs
+ newInitProbs state = gamma 1 state
+ newTransMatrix state1 state2 = sum [xi t state1 state2 | t <- [2..(snd $ bounds obs)]]
+ /sum [gamma t state1 | t <- [2..(snd $ bounds obs)]]
+ newOutMatrix state event = sum [if (obs!t == event)
+ then gamma t state
+ else 0
+ | t <- [2..(snd $ bounds obs)]
+ ]
+ /sum [gamma t state | t <- [2..(snd $ bounds obs)]]
+
+ -- Greek functions, included here for memoization
+ xi t state1 state2 = (memo_alpha (t-1) state1)
+ *(transMatrix hmm state1 state2)
+ *(outMatrix hmm state2 $ obs!t)
+ *(memo_beta t state2)
+ /backwardArrayVar -- (backwardArray hmm obs)
+
+ gamma t state = (memo_alpha t state)
+ *(memo_beta t state)
+ /backwardArrayVar
+
+ backwardArrayVar = (backwardArray hmm obs)
-fbtest hmm events = "fwd: " ++ show (forward hmm events) ++ " bkwd:" ++ show (backward hmm events)
-
+ memo_beta t state = memo_beta2 t (stateIndex hmm state)
+ memo_beta2 = (Memo.memo2 Memo.integral Memo.integral memo_beta3)
+ memo_beta3 t' state'
+ | t' == bT = 1
+ | otherwise = sum [(transMatrix hmm (states hmm !! state') state2)
+ *(outMatrix hmm state2 $ obs!(t'+1))
+ *(memo_beta (t'+1) state2)
+ | state2 <- states hmm
+ ]
+
+ memo_alpha t state = memo_alpha2 t (stateIndex hmm state)
+ memo_alpha2 = (Memo.memo2 Memo.integral Memo.integral memo_alpha3)
+ memo_alpha3 t' state'
+ | t' == 1 = (outMatrix hmm (states hmm !! state') $ obs!t')*(initProbs hmm $ states hmm !! state')
+ | otherwise = (outMatrix hmm (states hmm !! state') $ obs!t')*(sum [(memo_alpha (t'-1) state2)*(transMatrix hmm state2 (states hmm !! state')) | state2 <- states hmm])
+
42 HMMPerf.hs
View
@@ -22,16 +22,48 @@ forceEval x = putStrLn $ (show $ length str) ++ " -> " ++ (take 30 str)
where str = show x
main = defaultMainWith myConfig (return ())
- [ bench "newHMM - baumWelch 20" $ forceEval $ baumWelch newHMM (genArray 20) 1
- , bench "newHMM - baumWelch 40" $ forceEval $ baumWelch newHMM (genArray 40) 1
- , bench "newHMM - baumWelch 60" $ forceEval $ baumWelch newHMM (genArray 60) 1
- , bench "newHMM - baumWelch 80" $ forceEval $ baumWelch newHMM (genArray 80) 1
- , bench "newHMM - baumWelch 100" $ forceEval $ baumWelch newHMM (genArray 100) 1
+ [ bench "newHMM - baumWelch 10" $ forceEval $ baumWelch newHMM (genArray 10) 1
+ , bench "newHMM - baumWelch 100" $ forceEval $ baumWelch newHMM (genArray 100) 1
+ , bench "newHMM - baumWelch 1000" $ forceEval $ baumWelch newHMM (genArray 1000) 1
+ , bench "newHMM - baumWelch 10000" $ forceEval $ baumWelch newHMM (genArray 10000) 1
+ , bench "newHMM - baumWelch 100000" $ forceEval $ baumWelch newHMM (genArray 100000) 1
{- , bench "newHMM - forward" $ forceEval $ forward newHMM $ genString 100
, bench "newHMM - backward" $ forceEval $ backward newHMM $ genString 100
, bench "oldHMM" $ forceEval $ OldHMM.sequenceProb oldHMM $ genString 100-}
]
+ -- | tests
+
+listCPExp :: [a] -> Int -> [[a]]
+listCPExp language order = listCPExp' order [[]]
+ where
+ listCPExp' order list
+ | order == 0 = list
+ | otherwise = listCPExp' (order-1) [symbol:l | l <- list, symbol <- language]
+
+ -- these should equal ~1 if our recurrences for alpha and beta are correct
+
+forwardtest hmm x = sum [forward hmm e | e <- listCPExp (events hmm) x]
+backwardtest hmm x = sum [backward hmm e | e <- listCPExp (events hmm) x]
+
+fbtest hmm events = "fwd: " ++ show (forward hmm events) ++ " bkwd:" ++ show (backward hmm events)
+
+verifyhmm hmm = do
+ check "initProbs" ip
+ check "transMatrix" tm
+ check "outMatrix" om
+
+ where check str var = do
+ putStrLn $ str++" tollerance check: "++show var
+{- if abs(var-1)<0.0001
+ then putStrLn "True"
+ else putStrLn "False"-}
+
+ ip = sum $ [initProbs hmm s | s <- states hmm]
+ tm = (sum $ [transMatrix hmm s1 s2 | s1 <- states hmm, s2 <- states hmm]) -- (length $ states hmm)
+ om = sum $ [outMatrix hmm s e | s <- states hmm, e <- events hmm] -- / length $ states hmm
+
+
-- | OldHMM definition
-- data HMM state observation = HMM [state] [Prob] [[Prob]] (observation -> [Prob])
67 HMMTest.hs
View
@@ -0,0 +1,67 @@
+import HMM
+
+import Debug.Trace
+
+ -- | utility functions
+ --
+ -- | builds every sequence of length `order` over the list (the Cartesian product of the list with itself, taken `order` times)
+
+listCPExp :: [a] -> Int -> [[a]]
+listCPExp language order = listCPExp' order [[]]
+ where
+ listCPExp' order list
+ | order == 0 = list
+ | otherwise = listCPExp' (order-1) [symbol:l | l <- list, symbol <- language]
+
+ -- | tests
+
+-- these should equal ~1 if our recurrences for alpha and beta are correct
+
+forwardtest hmm x = sum [forward hmm e | e <- listCPExp (events hmm) x]
+backwardtest hmm x = sum [backward hmm e | e <- listCPExp (events hmm) x]
+
+fbtest hmm events = "fwd: " ++ show (forward hmm events) ++ " bkwd:" ++ show (backward hmm events)
+
+verifyhmm hmm = do
+ check "initProbs" ip
+ check "transMatrix" tm
+ check "outMatrix" om
+
+ where check str var = do
+ putStrLn $ str++" tollerance check: "++show var
+{- if abs(var-1)<0.0001
+ then putStrLn "True"
+ else putStrLn "False"-}
+
+ ip = sum $ [initProbs hmm s | s <- states hmm]
+ tm = (sum $ [transMatrix hmm s1 s2 | s1 <- states hmm, s2 <- states hmm]) -- (length $ states hmm)
+ om = sum $ [outMatrix hmm s e | s <- states hmm, e <- events hmm] -- / length $ states hmm
+
+-- Test HMMs
+
+newHMM = HMM { states=[1,2]
+ , events=['A','G','C','T']
+ , initProbs = ipTest
+ , transMatrix = tmTest
+ , outMatrix = omTest
+ }
+
+ipTest s
+ | s == 1 = 0.1
+ | s == 2 = 0.9
+
+tmTest s1 s2
+ | s1==1 && s2==1 = 0.9
+ | s1==1 && s2==2 = 0.1
+ | s1==2 && s2==1 = 0.5
+ | s1==2 && s2==2 = 0.5
+
+omTest s e
+ | s==1 && e=='A' = 0.4
+ | s==1 && e=='G' = 0.1
+ | s==1 && e=='C' = 0.1
+ | s==1 && e=='T' = 0.4
+ | s==2 && e=='A' = 0.1
+ | s==2 && e=='G' = 0.4
+ | s==2 && e=='C' = 0.4
+ | s==2 && e=='T' = 0.1
Please sign in to comment.
Something went wrong with that request. Please try again.