Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.
Download ZIP
Browse files

progress

  • Loading branch information...
commit bf90379e31b1f81530988e1044f3ff327f4d163d 1 parent d7d39a8
@kirel authored
Showing with 70 additions and 26 deletions.
  1. +24 −11 Classifier.hs
  2. +4 −4 StrokeSample.hs
  3. +42 −11 bench.hs
View
35 Classifier.hs
@@ -3,12 +3,13 @@ module Classifier
newClassifier,
trainClassifier,
classifyWithClassifier,
- Sample(..)
+ Sample(..),
) where
import Control.Monad
import Control.Concurrent.STM
import Data.Heap
+import Data.Map
import Data.Maybe
import Strokes
@@ -31,7 +32,10 @@ instance Eq (Hit s) where
instance Ord (Hit s) where
compare h o = compare (score h) (score o)
-data Classifier a = Classifier (TVar [a]) -- Classifier holds Training Data
+data Classifier a = Classifier Int (TVar [a]) -- Classifier holds Training Data
+
+type Score = Double
+type Results = [(String, Score)]
-- helper
update :: TVar a -> (a -> a) -> STM ()
@@ -39,22 +43,31 @@ update var f = readTVar var >>= (writeTVar var) . f
-- classifier logic
findKNearestNeighbors :: Sample s => Int -> s -> [s] -> [Hit s]
-findKNearestNeighbors k unknown known = toList $ foldl step (empty :: MaxHeap (Hit s)) known where
- step heap next | size heap < k = insert (Hit dist next) heap
- | (lb < limit) && (dist < limit) = insert (Hit dist next) $ fromJust $ viewTail heap where
+findKNearestNeighbors k unknown known = Data.Heap.toList $ foldl step (Data.Heap.empty :: MaxHeap (Hit s)) known where
+ step heap next | Data.Heap.size heap < k = Data.Heap.insert (Hit dist next) heap
+ | (lb < limit) && (dist < limit) = Data.Heap.insert (Hit dist next) $ fromJust $ viewTail heap
+ | otherwise = heap where
lb = distancelb unknown next
dist = distance unknown next
limit = score $ fromJust $ viewHead heap
+alterMin :: Score -> Maybe Score -> Maybe Score
+alterMin next Nothing = Just next
+alterMin next (Just before) = Just $ min before next
+
+results :: Sample s => [Hit s] -> Results
+results hits = Data.Map.toList $ foldl step Data.Map.empty hits where
+ step results hit = alter (alterMin $ score hit) (fromJust $ identifier $ sample hit) results
+
-- classifier interface
-newClassifier :: IO (Classifier s)
-newClassifier = atomically $ liftM Classifier (newTVar [])
+newClassifier :: Int -> IO (Classifier s)
+newClassifier k = atomically $ liftM (Classifier k) (newTVar [])
trainClassifier :: Sample s => Classifier s -> s -> IO ()
trainClassifier _ sample | identifier sample == Nothing = error "Can only train samples of known classes."
-trainClassifier (Classifier t) sample = atomically $ update t (sample:)
+trainClassifier (Classifier _ t) sample = atomically $ Classifier.update t (sample:)
-classifyWithClassifier :: Sample s => Classifier s -> s -> IO [Hit s]
-classifyWithClassifier (Classifier t) sample = do
+classifyWithClassifier :: Sample s => Classifier s -> s -> IO Results
+classifyWithClassifier (Classifier k t) sample = do
samples <- atomically $ readTVar t
- return $ findKNearestNeighbors 50 sample samples
+ return $ results $ findKNearestNeighbors k sample samples
View
8 StrokeSample.hs
@@ -15,15 +15,15 @@ data StrokeSample = StrokeSample {
hullSeries :: [ConvexHull],
windowWidth :: Int,
sidentifier :: Maybe String
- }
+ } deriving (Show)
newStrokeSample :: Stroke -> Maybe String -> StrokeSample
-newStrokeSample s i = StrokeSample s (LB.hullSeries _window_ s) _window_ i
+newStrokeSample s i = StrokeSample s hulls _window_ i where
+ hulls = (LB.hullSeries _window_ s)
instance Sample StrokeSample where
distance (StrokeSample _ _ w _) (StrokeSample _ _ v _) | w /= v = error "Dimension mismatch."
distance (StrokeSample a _ _ _) (StrokeSample b _ _ _) = dtw euclideanDistance 2 a b
distancelb (StrokeSample _ _ w _) (StrokeSample _ _ v _) | w /= v = error "Dimension mismatch."
distancelb (StrokeSample a _ _ _) (StrokeSample _ b _ _) = dtwlb a b
- identifier = sidentifier
-
+ identifier = sidentifier
View
53 bench.hs
@@ -6,20 +6,18 @@ import Random
import Control.Parallel.Strategies
import Control.Parallel
+import Control.Monad
import Strokes
-import DTW
-import LB
+import StrokeSample
import Classifier
-instance Sample StrokeSample where
- distance s t = dtw euclideanDistance 2 (stroke s) (stroke t)
-
pmap = parMap rwhnf
compute stuff = hPrint stderr stuff
strokeFromList (x:y:rest) = (x,y):(strokeFromList rest)
+randomStroke :: Int -> Stroke
randomStroke = (strokeFromList . randoms . mkStdGen)
-- main = do
@@ -57,10 +55,43 @@ randomStroke = (strokeFromList . randoms . mkStdGen)
-- compute mini
-- stop <- getPOSIXTime
-- print $ stop - start
-
+
+cK = 50
+strokesize = 30
+num = 20000
+stroke = take strokesize $ (randomStroke 0)
+strokes = take num $ (map ((take strokesize) . randomStroke) [1..])
+ids = cycle $ map (:[]) ['a'..'z']
+
main = do
- c <- newClassifier
- trainClassifier c "blub" [(0,0)]
- trainClassifier c "bla" [(0,0)]
- str <- classifyWithClassifier c [(0,0)]
- print str
+ c <- newClassifier cK
+ -- train with samples
+ print "Training the classifier..."
+ start <- getPOSIXTime
+ forM_ (zip strokes ids) $ \(s, ident) -> do
+ trainClassifier c $ newStrokeSample s (Just ident)
+ stop <- getPOSIXTime
+ print $ stop - start
+
+ print "Classifying..."
+ start <- getPOSIXTime
+ results <- classifyWithClassifier c $ newStrokeSample stroke Nothing
+ compute results
+ stop <- getPOSIXTime
+ print $ stop - start
+
+ print "Classifying again..."
+ start <- getPOSIXTime
+ results <- classifyWithClassifier c $ newStrokeSample stroke Nothing
+ compute results
+ stop <- getPOSIXTime
+ print $ stop - start
+
+ print "And again..."
+ start <- getPOSIXTime
+ results <- classifyWithClassifier c $ newStrokeSample stroke Nothing
+ compute results
+ stop <- getPOSIXTime
+ print $ stop - start
+
+ print $ results
Please sign in to comment.
Something went wrong with that request. Please try again.