Permalink
Browse files

Added viterbi algorithm, more performance tests, more ways to create …

…default HMMs
  • Loading branch information...
1 parent 3fe845a commit f6377a4d1ce4f830c3029d06f1b9226a89f3ce25 @mikeizbicki committed Mar 17, 2012
Showing with 150 additions and 65 deletions.
  1. +13 −2 BioHMM.hs
  2. +120 −21 HMM.hs
  3. +17 −42 HMMPerf.hs
View
@@ -33,6 +33,13 @@ import System.IO
-- putStrLn $ show hmm'
-- applyLoop nexthmm $ tail tfL
+findGenes = do
+ hmmTF <- loadHMM "hmm/TF-3.hmm"
+ hmmDNA <- loadHMM "hmm/autowinegrape-1000-2.hmm"
+ let hmm' = hmmJoin hmmDNA hmmTF 0.999
+ dna <- loadDNAArray 1000
+ return $ viterbi hmm' dna
+
createTFhmm file hmm = do
x <- strTF
let hmm' = baumWelch hmm (listArray (1,length x) x) 10
@@ -61,7 +68,12 @@ createDNAhmm file len hmm = do
putStrLn $ show hmm'
saveHMM file hmm'
return hmm'
-
+
+verifyHMMFile file = do
+ hmm <- ((loadHMM file) :: IO (HMM String Char))
+ verifyhmm hmm
+
+
loadDNAArray len = do
dna <- readFile "dna/winegrape-chromosone2"
let dnaArray = listArray (1,len) $ filter isBP dna
@@ -111,7 +123,6 @@ omTest s e
| s==2 && e=='C' = 0.4
| s==2 && e=='T' = 0.1
-
bwTest = do
hmm <- loadHMM "hmm/test" ::IO (HMM String Char)
return $ baumWelch hmm (listArray (1,10) "AAAAAAGTGC") 10
View
141 HMM.hs
@@ -2,15 +2,18 @@ module HMM
( HMM(..), Prob, rnf
, forward
, backward
- , baumWelch, baumWelchItr --, baumWelchIO
+ , viterbi
+ , baumWelch, baumWelchItr
, alpha, beta
- , simpleMM, simpleMM2
+ , simpleMM, simpleMM2, simpleHMM, hmmJoin
+ , verifyhmm
)
where
import Debug.Trace
import Data.Array
import Data.List
+import Data.List.Extras
import Data.Number.LogFloat
import qualified Data.MemoCombinators as Memo
import Control.DeepSeq
@@ -44,13 +47,14 @@ hmm2str hmm = "HMM" ++ "{ states=" ++ (show $ states hmm)
++ ", outMatrix=" ++ (show [(s,e,outMatrix hmm s e) | s <- states hmm, e <- events hmm])
++ "}"
+elemIndex2 :: (Show a, Eq a) => a -> [a] -> Int
elemIndex2 e list = case elemIndex e list of
- Nothing -> seq (error "stateIndex: Index "++show e++" not in HMM "++show list) 0
+ Nothing -> seq (error ("elemIndex2: Index "++show e++" not in HMM "++show list)) 0
Just x -> x
stateIndex :: (Show stateType, Show eventType, Eq stateType) => HMM stateType eventType -> stateType -> Int
stateIndex hmm state = case elemIndex state $ states hmm of
- Nothing -> seq (error "stateIndex: Index "++show state++" not in HMM "++show hmm) 0
+ Nothing -> seq (error ("stateIndex: Index "++show state++" not in HMM "++show hmm)) 0
Just x -> x
eventIndex :: (Show stateType, Show eventType, Eq eventType) => HMM stateType eventType -> eventType -> Int
@@ -59,18 +63,6 @@ eventIndex hmm event = case elemIndex event $ events hmm of
Just x -> x
--- standardHMM :: [stateType] -> [eventType] -> HMM stateType eventType
--- standardHMM sL eL = HMM { states=sL
--- , events=eL
--- , initProbs = ip
--- , transMatrix = tm
--- , outMatrix = om
--- }
--- where
--- ip s = 1.0 / (logFloat $ length sL)
--- tm s1 s2 = 1.0 / (logFloat $ length sL)
--- om s e = 1.0 / (logFloat $ length eL)
-
simpleMM2 eL order = HMM { states = sL
, events = eL
, initProbs = \s -> 1.0 / (logFloat $ length sL)
@@ -89,9 +81,9 @@ simpleMM2 eL order = HMM { states = sL
simpleMM eL order = HMM { states = sL
, events = eL
- , initProbs = \s -> skewedDist s
+ , initProbs = \s -> evenDist--skewedDist s
, transMatrix = \s1 -> \s2 -> if (length s1==0) || (isPrefixOf (tail s1) s2)
- then skewedDist s2 -- 1.0 / (logFloat $ length sL)
+ then skewedDist s2 --1.0 / (logFloat $ length sL)
else 0.0
, outMatrix = \s -> \e -> 1.0/(logFloat $ length eL)
}
@@ -103,6 +95,18 @@ simpleMM eL order = HMM { states = sL
| order' == 0 = list
| otherwise = enumerateStates (order'-1) [symbol:l | l <- list, symbol <- eL]
+simpleHMM :: (Eq stateType, Show eventType, Show stateType) =>
+ [stateType] -> [eventType] -> HMM stateType eventType
+simpleHMM sL eL = HMM { states = sL
+ , events = eL
+ , initProbs = \s -> evenDist--skewedDist s
+ , transMatrix = \s1 -> \s2 -> skewedDist s2
+ , outMatrix = \s -> \e -> 1.0/(logFloat $ length eL)
+ }
+ where evenDist = 1.0 / sLlen
+ skewedDist s = (logFloat $ 1+elemIndex2 s sL) / ( (sLlen * (sLlen+ (logFloat (1.0 :: Double))))/2.0)
+ sLlen = logFloat $ length sL
+
-- | forward algorithm
@@ -165,6 +169,40 @@ beta hmm obs = memo_beta
]
+ -- | Viterbi
+
+viterbi :: (Eq eventType, Eq stateType, Show eventType, Show stateType) =>
+ HMM stateType eventType -> Array Int eventType -> [stateType]
+viterbi hmm obs = [memo_x' t | t <- [1..bT]]
+ where bT = snd $ bounds obs
+
+-- x' :: Int -> stateType
+{- memo_x' t = memo_newInitProbs2 (stateIndex hmm state)
+ memo_x'2 = Memo.integral memo_newInitProbs3
+ memo_x'3 state = newInitProbs (states hmm !! state)-}
+ memo_x' = Memo.integral x'
+ x' t
+ | t == bT = argmax (\i -> memo_delta bT i) (states hmm)
+ | otherwise = memo_psi (t+1) (memo_x' (t+1))
+
+-- delta :: Int -> stateType -> Prob
+ memo_delta t state = memo_delta2 t (stateIndex hmm state)
+ memo_delta2 = (Memo.memo2 Memo.integral Memo.integral memo_delta3)
+ memo_delta3 t state = delta t (states hmm !! state)
+ delta t state
+ | t == 1 = (outMatrix hmm state $ obs!t)*(initProbs hmm state)
+ | otherwise = maximum [(memo_delta (t-1) i)*(transMatrix hmm i state)*(outMatrix hmm (state) $ obs!t)
+ | i <- states hmm
+ ]
+
+-- psi :: Int -> stateType -> stateType
+ memo_psi t state = memo_psi2 t (stateIndex hmm state)
+ memo_psi2 = (Memo.memo2 Memo.integral Memo.integral memo_psi3)
+ memo_psi3 t state = psi t (states hmm !! state)
+ psi t state
+ | t == 1 = (states hmm) !! 0
+ | otherwise = argmax (\i -> (memo_delta (t-1) i) * (transMatrix hmm i state)) (states hmm)
+
-- | Baum-Welch
{-gammaArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
@@ -197,8 +235,8 @@ baumWelchItr hmm obs = --par newInitProbs $ par newTransMatrix $ par newOutMatri
HMM { states = states hmm
, events = events hmm
, initProbs = memo_newInitProbs
- , transMatrix = {-newTransMatrix-} memo_newTransMatrix
- , outMatrix = {-outMatrix hmm-} memo_newOutMatrix
+ , transMatrix = {-newTransMatrix---} memo_newTransMatrix
+ , outMatrix = {-outMatrix hmm ---} memo_newOutMatrix
}
where bT = snd $ bounds obs
memo_newInitProbs state = memo_newInitProbs2 (stateIndex hmm state)
@@ -210,7 +248,7 @@ baumWelchItr hmm obs = --par newInitProbs $ par newTransMatrix $ par newOutMatri
memo_newTransMatrix2 = (Memo.memo2 Memo.integral Memo.integral memo_newTransMatrix3)
memo_newTransMatrix3 state1 state2 = newTransMatrix (states hmm !! state1) (states hmm !! state2)
newTransMatrix state1 state2 = --trace ("newTransMatrix"++(hmmid hmm)) $
- sum [xi t state1 state2 | t <- [2..bT]]
+ sum [xi t state2 state1 | t <- [2..bT]]
/sum [gamma t state1 | t <- [2..bT]]
memo_newOutMatrix state event = memo_newOutMatrix2 (stateIndex hmm state) (eventIndex hmm event)
@@ -253,6 +291,67 @@ baumWelchItr hmm obs = --par newInitProbs $ par newTransMatrix $ par newOutMatri
| otherwise = (outMatrix hmm (states hmm !! state') $ obs!t')*(sum [(memo_alpha (t'-1) state2)*(transMatrix hmm state2 (states hmm !! state')) | state2 <- states hmm])
+--
+
+hmmJoin :: (Eq stateType, Eq eventType, Read stateType, Show stateType) =>
+ HMM stateType eventType -> HMM stateType eventType -> Prob -> HMM (Int,stateType) eventType
+hmmJoin hmm1 hmm2 ratio = HMM { states = states1 ++ states2
+ , events = if (events hmm1) == (events hmm2)
+ then events hmm1
+ else error "hmmJoin: event sets not equal"
+ , initProbs = \s -> if (s `elem` states1)
+ then (initProbs hmm1 $ lift s)*r1
+ else (initProbs hmm2 $ lift s)*r2
+ , transMatrix = \s1 -> \s2 -> if (s1 `elem` states1 && s2 `elem` states1)
+ then (transMatrix hmm1 (lift s1) (lift s2))*r1
+ else if (s2 `elem` states2 && s2 `elem` states2)
+ then (transMatrix hmm2 (lift s1) (lift s2))*r2
+ else if (s1 `elem` states1)
+ then (r2)/(logFloat $ length $ states2)
+ else (r1)/(logFloat $ length $ states1)
+ , outMatrix = \s -> if (s `elem` states1)
+ then (outMatrix hmm1 $ lift s)
+ else (outMatrix hmm2 $ lift s)
+ }
+ where r1=ratio
+ r2=1-ratio
+ states1 = map (\x -> (1,x)) $ states hmm1
+ states2 = map (\x -> (2,x)) $ states hmm2
+
+-- lift :: (Int,String) -> a
+ lift x =snd x
+-- lift x =read $ (snd x )
-- debug utils
hmmid hmm = show $ initProbs hmm $ (states hmm) !! 1
+
+ -- | tests
+
+listCPExp :: [a] -> Int -> [[a]]
+listCPExp language order = listCPExp' order [[]]
+ where
+ listCPExp' order list
+ | order == 0 = list
+ | otherwise = listCPExp' (order-1) [symbol:l | l <- list, symbol <- language]
+
+ -- these should equal ~1 if our recurrence if alpha and beta are correct
+
+forwardtest hmm x = sum [forward hmm e | e <- listCPExp (events hmm) x]
+backwardtest hmm x = sum [backward hmm e | e <- listCPExp (events hmm) x]
+
+fbtest hmm events = "fwd: " ++ show (forward hmm events) ++ " bkwd:" ++ show (backward hmm events)
+
+verifyhmm hmm = do
+ seq ip $ check "initProbs" ip
+ check "transMatrix" tm
+ check "outMatrix" om
+
+ where check str var = do
+ putStrLn $ str++" tollerance check: "++show var
+{- if abs(var-1)<0.0001
+ then putStrLn "True"
+ else putStrLn "False"-}
+
+ ip = sum $ [initProbs hmm s | s <- states hmm]
+ tm = (sum $ [transMatrix hmm s1 s2 | s1 <- states hmm, s2 <- states hmm]) -- (length $ states hmm)
+ om = sum $ [outMatrix hmm s e | s <- states hmm, e <- events hmm] -- / length $ states hmm
View
@@ -30,48 +30,23 @@ forceEval x = putStrLn $ (show $ length str) ++ " -> " ++ (take 30 str)
where str = show x
main = defaultMainWith myConfig (return ())
- [ bench ("baumWelch (itr="++show itr++",ord="++show order++",len="++(show arraylen)++")") $
- whnf (baumWelch (simpleMM "AGCT" order) (genArray arraylen)) itr
-
- | arraylen <- [1000]
- , order <- [1..6]
- , itr <- [1]
- ]
-
-{- , bench "newHMM - forward" $ forceEval $ forward newHMM $ genString 100
- , bench "newHMM - backward" $ forceEval $ backward newHMM $ genString 100
- , bench "oldHMM" $ forceEval $ OldHMM.sequenceProb oldHMM $ genString 100-}
-
- -- | tests
-
-listCPExp :: [a] -> Int -> [[a]]
-listCPExp language order = listCPExp' order [[]]
- where
- listCPExp' order list
- | order == 0 = list
- | otherwise = listCPExp' (order-1) [symbol:l | l <- list, symbol <- language]
-
- -- these should equal ~1 if our recurrence if alpha and beta are correct
-
-forwardtest hmm x = sum [forward hmm e | e <- listCPExp (events hmm) x]
-backwardtest hmm x = sum [backward hmm e | e <- listCPExp (events hmm) x]
-
-fbtest hmm events = "fwd: " ++ show (forward hmm events) ++ " bkwd:" ++ show (backward hmm events)
-
-verifyhmm hmm = do
- check "initProbs" ip
- check "transMatrix" tm
- check "outMatrix" om
-
- where check str var = do
- putStrLn $ str++" tollerance check: "++show var
-{- if abs(var-1)<0.0001
- then putStrLn "True"
- else putStrLn "False"-}
-
- ip = sum $ [initProbs hmm s | s <- states hmm]
- tm = (sum $ [transMatrix hmm s1 s2 | s1 <- states hmm, s2 <- states hmm]) -- (length $ states hmm)
- om = sum $ [outMatrix hmm s e | s <- states hmm, e <- events hmm] -- / length $ states hmm
+-- [ bench ("baumWelch (itr="++show itr++",ord="++show order++",len="++(show arraylen)++")") $
+-- whnf (baumWelch (simpleMM "AGCT" order) (genArray arraylen)) itr
+--
+-- | arraylen <- [1000]
+-- , order <- [1..6]
+-- , itr <- [1]
+-- ]
+
+ [ bench ("newHMM - viterbi (states="++show states++",len="++show len++")") $ putStrLn $ show $ viterbi (simpleHMM [1..states] "AGCT") $ genArray len
+ | len <- [1000] -- [10,100,1000,10000,20000,30000,40000,50000,60000,70000,80000,90000,100000,200000,300000,400000,500000,1000000]
+ , states <- [1..100]
+ ]
+-- [ bench "newHMM - forward" $ forceEval $ forward newHMM $ genString len
+-- , bench "newHMM - backward" $ forceEval $ backward newHMM $ genString len
+-- , bench "oldHMM" $ forceEval $ OldHMM.sequenceProb oldHMM $ genString len
+-- | len <- [100,1000,10000]
+-- ]
-- | OldHMM definition

0 comments on commit f6377a4

Please sign in to comment.