Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Much improved performance by memoizing alpha and beta functions

  • Loading branch information...
commit f5820c0168bab5c5d1f2ec2d62ef5a37be67b257 1 parent 751fca0
Mike Izbicki authored

Showing 1 changed file with 73 additions and 101 deletions. Show diff stats Hide diff stats

  1. +73 101 HMM2.hs
174 HMM2.hs
@@ -3,6 +3,7 @@ module HMM2
3 3
4 4 import Debug.Trace
5 5 import Data.Array
  6 +import Data.List
6 7 import Data.Number.LogFloat
7 8 import qualified Data.MemoCombinators as Memo
8 9
@@ -10,7 +11,8 @@ type Prob = LogFloat
10 11
11 12 -- | The data type for our HMM
12 13
13   -data HMM stateType eventType = HMM { states :: [stateType]
  14 +data -- (Eq eventType, Eq stateType, Show eventType, Show stateType) =>
  15 + HMM stateType eventType = HMM { states :: [stateType]
14 16 , events :: [eventType]
15 17 , initProbs :: (stateType -> Prob)
16 18 , transMatrix :: (stateType -> stateType -> Prob)
@@ -24,127 +26,108 @@ instance (Show state, Show observation) => Show (HMM state observation) where
24 26 ++ " transMatrix=" ++ (show [(s1,s2,transMatrix hmm s1 s2) | s1 <- states hmm, s2 <- states hmm])
25 27 ++ " outMatrix=" ++ (show [(s,e,outMatrix hmm s e) | s <- states hmm, e <- events hmm])
26 28
  29 +stateIndex :: (Show stateType, Show eventType, Eq stateType) => HMM stateType eventType -> stateType -> Int
  30 +stateIndex hmm state = case elemIndex state $ states hmm of
  31 + Nothing -> seq (error "stateIndex: Index "++show state++" not in HMM "++show hmm) 0
  32 + Just x -> x
  33 +
  34 +eventIndex :: (Show stateType, Show eventType, Eq eventType) => HMM stateType eventType -> eventType -> Int
  35 +eventIndex hmm event = case elemIndex event $ events hmm of
  36 + Nothing -> seq (error "stateIndex: Index "++show event++" not in HMM "++show hmm) 0
  37 + Just x -> x
  38 +
27 39 -- | forward algorithm
28 40
29   -forward :: (Eq eventType) => HMM stateType eventType -> [eventType] -> Prob
30   -forward=forwardArray
31   -
32   -forwardList :: (Eq eventType) => HMM stateType eventType -> [eventType] -> Prob
33   -forwardList hmm obs = sum [alphaList hmm obs state | state <- states hmm]
34   -
35   -alphaList :: (Eq eventType) => HMM stateType eventType -> [eventType] -> stateType -> Prob
36   -alphaList hmm obs@(x:xs) state
37   - | xs==[] = (outMatrix hmm state x)*(initProbs hmm state)
38   - | otherwise = (outMatrix hmm state x)*(sum [(alphaList hmm xs state)*(transMatrix hmm state state2) | state2 <- states hmm
39   - ])
40   -
41   -forwardArray :: (Eq eventType) => HMM stateType eventType -> [eventType] -> Prob
42   -forwardArray hmm obs = sum [alphaArray hmm (listArray (1,bT) obs) bT state | state <- states hmm]
  41 +forward :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> [eventType] -> Prob
  42 +forward hmm obs = forwardArray hmm (listArray (1,bT) obs)
43 43 where
44 44 bT = length obs
  45 +
  46 +forwardArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Prob
  47 +forwardArray hmm obs = sum [alpha hmm obs bT state | state <- states hmm]
  48 + where
  49 + bT = snd $ bounds obs
  50 +
  51 +alpha :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
  52 + -> Array Int eventType
  53 + -> Int
  54 + -> stateType
  55 + -> Prob
  56 +alpha hmm obs = memo_alpha
  57 + where memo_alpha t state = memo_alpha2 t (stateIndex hmm state)
  58 + memo_alpha2 = (Memo.memo2 Memo.integral Memo.integral memo_alpha3)
  59 + memo_alpha3 t' state'
  60 + | t' == 1 = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $
  61 + (outMatrix hmm (states hmm !! state') $ obs!t')*(initProbs hmm $ states hmm !! state')
  62 + | otherwise = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $
  63 + (outMatrix hmm (states hmm !! state') $ obs!t')*(sum [(memo_alpha (t'-1) state2)*(transMatrix hmm state2 (states hmm !! state')) | state2 <- states hmm])
45 64
46   -alphaArray :: (Eq eventType) => HMM stateType eventType -> Array Int eventType -> Int -> stateType -> Prob
47   -alphaArray hmm obs t state
48   - | t == 1 = (outMatrix hmm state $ obs!t)*(initProbs hmm state)
49   - | otherwise = (outMatrix hmm state $ obs!t)*(sum [(alphaArray hmm obs (t-1) state2)*(transMatrix hmm state2 state) | state2 <- states hmm
50   - ])
51   --- memoized_alphaArray :: (Eq eventType) => HMM stateType eventType -> Array Int eventType -> Int -> stateType -> Prob
52   -memoized_alphaArray hmm obs t = (map aa (states hmm) !!)
53   - where aa state = if t==1
54   - then (outMatrix hmm state $ obs!t)*(initProbs hmm state)
55   - else (outMatrix hmm state $ obs!t)*(sum [(memoized_alphaArray hmm obs (t-1) state)*(transMatrix hmm state state2) | state2 <- states hmm])
56   -
57   -memo_alphaArray :: (Eq eventType) => HMM Integer eventType -> Array Int eventType -> Int -> Integer -> Prob
58   -memo_alphaArray hmm obs = Memo.memo2 Memo.integral Memo.integral aa
59   - where aa t state
60   - | t == 1 = (outMatrix hmm state $ obs!t)*(initProbs hmm state)
61   - | otherwise = (outMatrix hmm state $ obs!t)*(sum [(memo_alphaArray hmm obs (t-1) state)*(transMatrix hmm state state2) | state2 <- states hmm
62   - ])
63   -memoized_fib :: Int -> Integer
64   -memoized_fib = (map fib [0 .. 10] !!)
65   - where fib 0 = 0
66   - fib 1 = 1
67   - fib n = memoized_fib (n-2) + memoized_fib (n-1)
68 65
69 66 -- | backwards algorithm
70 67
71   -backward :: (Eq eventType, Show eventType) => HMM stateType eventType -> [eventType] -> Prob
  68 +backward :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> [eventType] -> Prob
72 69 backward hmm obs = backwardArray hmm $ listArray (1,length obs) obs
73 70
74   -backwardArray :: (Eq eventType,Show eventType) => HMM stateType eventType -> Array Int eventType -> Prob
  71 +backwardArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Prob
75 72 backwardArray hmm obs = backwardArray' hmm obs
76 73 where
77 74 backwardArray' hmm obs = sum [(initProbs hmm state)
78 75 *(outMatrix hmm state $ obs!1)
79   - *(betaArray hmm obs 1 state)
  76 + *(beta hmm obs 1 state)
80 77 | state <- states hmm
81 78 ]
82 79
83   -betaArray :: (Eq eventType) => HMM stateType eventType -> Array Int eventType -> Int -> stateType -> Prob
84   -betaArray hmm obs t state
85   - | t == bT = 1
86   - | otherwise = sum [(transMatrix hmm state state2)
87   - *(outMatrix hmm state2 $ obs!(t+1))
88   - *(betaArray hmm obs (t+1) state2)
89   - | state2 <- states hmm
90   - ]
91   - where
92   - bT = snd $ bounds obs
93   -
94   -
95   --- This implementation has a bug somewhere, but it is also not used in Baum-Welch
96   -
97   -backwardList :: (Eq eventType,Show eventType) => HMM stateType eventType -> [eventType] -> Prob
98   -backwardList hmm obs = backwardList' hmm $ reverse obs
99   - where
100   - backwardList' hmm obsrev = sum [(initProbs hmm state)
101   - *(outMatrix hmm state $ head obsrev)
102   - *(betaArray hmm (listArray (1,length obsrev) obsrev) 1 state)
103   --- *(betaList hmm obsrev state)
104   - | state <- states hmm
105   - ]
106   -
107   -betaList :: (Eq eventType) => HMM stateType eventType -> [eventType] -> stateType -> Prob
108   -betaList hmm obs@(x:xs) state
109   - | xs == [] = 1
110   - | otherwise = sum [(transMatrix hmm state state2)
111   - *(outMatrix hmm state2 x)
112   - *(betaList hmm xs state2)
113   - | state2 <- states hmm
114   - ]
  80 +beta :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
  81 + -> Array Int eventType
  82 + -> Int
  83 + -> stateType
  84 + -> Prob
  85 +beta hmm obs = memo_beta
  86 + where bT = snd $ bounds obs
  87 + memo_beta t state = memo_beta2 t (stateIndex hmm state)
  88 + memo_beta2 = (Memo.memo2 Memo.integral Memo.integral memo_beta3)
  89 + memo_beta3 t' state'
  90 + | t' == bT = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $
  91 + 1
  92 + | otherwise = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $
  93 + sum [(transMatrix hmm (states hmm !! state') state2)
  94 + *(outMatrix hmm state2 $ obs!(t'+1))
  95 + *(memo_beta (t'+1) state2)
  96 + | state2 <- states hmm
  97 + ]
115 98
116 99
117 100 -- | Baum-Welch
118 101
119   -gammaArray :: (Eq eventType, Show eventType) => HMM stateType eventType
120   - -> Array Int eventType
121   - -> Int
122   - -> stateType
123   - -> Prob
124   -gammaArray hmm obs t state = (alphaArray hmm obs t state)
125   - *(betaArray hmm obs t state)
  102 +gammaArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
  103 + -> Array Int eventType
  104 + -> Int
  105 + -> stateType
  106 + -> Prob
  107 +gammaArray hmm obs t state = (alpha hmm obs t state)
  108 + *(beta hmm obs t state)
126 109 /(backwardArray hmm obs)
127 110
128 111 -- xi i j = P(state (t-1) == i && state (t) == j | obs, lambda)
129 112
130   -xiArray :: (Eq eventType, Show eventType) => HMM stateType eventType
131   - -> Array Int eventType
132   - -> Int
133   - -> stateType
134   - -> stateType
135   - -> Prob
136   -xiArray hmm obs t state1 state2 = (alphaArray hmm obs (t-1) state1)
  113 +xiArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType
  114 + -> Array Int eventType
  115 + -> Int
  116 + -> stateType
  117 + -> stateType
  118 + -> Prob
  119 +xiArray hmm obs t state1 state2 = (alpha hmm obs (t-1) state1)
137 120 *(transMatrix hmm state1 state2)
138 121 *(outMatrix hmm state2 $ obs!t)
139   - *(betaArray hmm obs t state2)
  122 + *(beta hmm obs t state2)
140 123 /(backwardArray hmm obs)
141 124
142   -baumWelch :: (Eq eventType, Show eventType) => HMM stateType eventType -> Array Int eventType -> Int -> HMM stateType eventType
  125 +baumWelch :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Int -> HMM stateType eventType
143 126 baumWelch hmm obs count
144 127 | count == 0 = hmm
145 128 | otherwise = baumWelch (baumWelchItr hmm obs) obs (count-1)
146 129
147   -baumWelchItr :: (Eq eventType, Show eventType) => HMM stateType eventType -> Array Int eventType -> HMM stateType eventType
  130 +baumWelchItr :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> HMM stateType eventType
148 131 baumWelchItr hmm obs = HMM { states = states hmm
149 132 , events = events hmm
150 133 , initProbs = newInitProbs
@@ -174,23 +157,15 @@ listCPExp language order = listCPExp' order [[]]
174 157
175 158 -- | tests
176 159
177   --- this should equal ~1 if our recurrence in alpha is correct
178   -alphatest hmm x = sum [alphaList hmm e s | e <- listCPExp (events hmm) x, s <- states hmm]
  160 +-- these should equal ~1 if our recurrence in alpha is correct
179 161
180 162 forwardtest hmm x = sum [forward hmm e | e <- listCPExp (events hmm) x]
181   -
182 163 backwardtest hmm x = sum [backward hmm e | e <- listCPExp (events hmm) x]
183 164
184   -fftest hmm events = "fwdLst: " ++ show (forwardList hmm events) ++ " fwdArr:" ++ show (forwardArray hmm events)
185   -bbtest hmm events = "bckLst: " ++ show (backwardList hmm events) ++ " bckArr:" ++ show (backwardArray hmm $ listArray (1,length events) events)
186   -
187 165 fbtest hmm events = "fwd: " ++ show (forward hmm events) ++ " bkwd:" ++ show (backward hmm events)
188 166
189 167 -- | sample HMM used for testing
190 168
191   -arr :: Array Int Char
192   -arr = listArray (1,5) "AGTCA"
193   -
194 169 simpleHMM = HMM { states=[1,2]
195 170 , events=['A','G','C','T']
196 171 , initProbs = ipTest
@@ -198,9 +173,6 @@ simpleHMM = HMM { states=[1,2]
198 173 , outMatrix = omTest
199 174 }
200 175
201   --- ipTest :: Array Int Prob
202   --- ipTest = listArray (1,2) [0.1,0.9]
203   -
204 176 ipTest s
205 177 | s == 1 = 0.1
206 178 | s == 2 = 0.9

0 comments on commit f5820c0

Please sign in to comment.
Something went wrong with that request. Please try again.