Skip to content
Fetching contributors… Cannot retrieve contributors at this time
executable file 140 lines (123 sloc) 4.69 KB
 #!/usr/bin/env stack
-- stack --resolver lts-5.15 --install-ghc runghc --package hmatrix --package MonadRandom

{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Control.Monad
import Control.Monad.Random
import Data.List
import Data.Maybe
import Numeric.LinearAlgebra
import System.Environment
import Text.Read

-- | Parameters of a single fully connected layer taking @m@ inputs to @n@
-- outputs.  Nothing in the types ties the two dimensions together; callers
-- are responsible for keeping them consistent.
data Weights = W { wBiases :: !(Vector Double)  -- n
                 , wNodes  :: !(Matrix Double)  -- n x m
                 }                              -- "m to n" layer

-- | A non-empty chain of layers.  'O' is the final (output) layer and
-- ':&~' prepends one more layer in front of an existing network.
data Network :: * where
    O     :: !Weights
          -> Network
    (:&~) :: !Weights
          -> !Network
          -> Network
infixr 5 :&~

-- | The logistic function @1 / (1 + exp (-x))@.
logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (-x))

-- | Derivative of 'logistic', expressed through 'logistic' itself.
logistic' :: Floating a => a -> a
logistic' x = logix * (1 - logix)
  where
    logix = logistic x

-- | Affine part of a layer: bias plus weight matrix applied to the input.
-- No activation function is applied here.
runLayer :: Weights -> Vector Double -> Vector Double
runLayer (W wB wN) v = wB + wN #> v

-- | Feed an input vector through every layer of the network, applying
-- 'logistic' after each layer (including the output layer).
runNet :: Network -> Vector Double -> Vector Double
runNet (O w)      !v = logistic (runLayer w v)
runNet (w :&~ n') !v = let v' = logistic (runLayer w v)
                       in  runNet n' v'

-- | Random weights for a layer with @i@ inputs and @o@ outputs.
--
-- Two seeds are drawn from the ambient 'MonadRandom' and handed to hmatrix's
-- own generators.  The bias vector is @randomVector .. Uniform o@ rescaled by
-- @* 2 - 1@; the weight matrix is sampled with per-column bounds @(-1, 1)@.
randomWeights :: MonadRandom m => Int -> Int -> m Weights
randomWeights i o = do
    seed1 :: Int <- getRandom
    seed2 :: Int <- getRandom
    let wB = randomVector  seed1 Uniform o * 2 - 1
        wN = uniformSample seed2 o (replicate i (-1, 1))
    return $ W wB wN

-- | Random network with @i@ inputs, the given hidden-layer sizes (in order
-- from input to output) and @o@ outputs.
randomNet :: MonadRandom m => Int -> [Int] -> Int -> m Network
randomNet i []     o =     O <$> randomWeights i o
randomNet i (h:hs) o = (:&~) <$> randomWeights i h <*> randomNet h hs o

-- | One step of gradient descent by backpropagation on a single
-- (input, target) pair.  Returns the network with every layer's biases and
-- weights nudged against the gradient, scaled by the learning rate.
train :: Double           -- ^ learning rate
      -> Vector Double    -- ^ input vector
      -> Vector Double    -- ^ target vector
      -> Network          -- ^ network to train
      -> Network
train rate x0 target = fst . go x0
  where
    -- Walks forward through the layers, then updates them on the way back.
    -- The second component of the result is the bundle of derivatives that
    -- the preceding layer needs in order to compute its own gradient.
    go :: Vector Double    -- input vector for this layer
       -> Network          -- remaining network to train
       -> (Network, Vector Double)
    -- handle the output layer
    go !x (O w@(W wB wN))
        = let y    = runLayer w x
              o    = logistic y
              -- the gradient (how much y affects the error)
              --   (logistic' is the derivative of logistic)
              dEdy = logistic' y * (o - target)
              -- new bias weights and node weights
              wB'  = wB - scale rate dEdy
              wN'  = wN - scale rate (dEdy `outer` x)
              w'   = W wB' wN'
              -- bundle of derivatives for next step
              -- (uses the weights from *before* the update)
              dWs  = tr wN #> dEdy
          in  (O w', dWs)
    -- handle the inner layers
    go !x (w@(W wB wN) :&~ n)
        = let y          = runLayer w x
              o          = logistic y
              -- get dWs', bundle of derivatives from rest of the net
              (n', dWs') = go o n
              -- the gradient (how much y affects the error)
              dEdy       = logistic' y * dWs'
              -- new bias weights and node weights
              wB'  = wB - scale rate dEdy
              wN'  = wN - scale rate (dEdy `outer` x)
              w'   = W wB' wN'
              -- bundle of derivatives for next step
              -- (uses the weights from *before* the update)
              dWs  = tr wN #> dEdy
          in  (w' :&~ n', dWs)

-- | Train a fresh 2 -> 16 -> 8 -> 1 network on @n@ random 2-D points and
-- render its response as ASCII art.
--
-- A point is labelled 1 when it lies within distance 0.33 of either circle
-- centre and 0 otherwise.  The centres are written as @fromRational 0.33@ and
-- @fromRational (-0.33)@ at type @Vector Double@.
-- NOTE(review): this presumably relies on hmatrix treating such a constant
-- vector as the same value in every coordinate -- confirm against hmatrix.
--
-- The picture samples the trained network on a 51 x 21 grid covering
-- @[-1, 1]@ in both coordinates and maps the output norm to one of five
-- glyphs.
netTest :: MonadRandom m => Double -> Int -> m String
netTest rate n = do
    inps <- replicateM n $ do
      s <- getRandom
      return $ randomVector s Uniform 2 * 2 - 1
    let outs = flip map inps $ \v ->
                 if v `inCircle` (fromRational 0.33, 0.33)
                      || v `inCircle` (fromRational (-0.33), 0.33)
                   then fromRational 1
                   else fromRational 0
    net0 <- randomNet 2 [16,8] 1
    -- strict left fold: each sample updates the network exactly once, in order
    let trained = foldl' trainEach net0 (zip inps outs)
          where
            trainEach :: Network -> (Vector Double, Vector Double) -> Network
            trainEach nt (i, o) = train rate i o nt

        outMat = [ [ render (norm_2 (runNet trained (vector [x / 25 - 1,y / 10 - 1])))
                   | x <- [0..50] ]
                 | y <- [0..20] ]
        render r | r <= 0.2  = ' '
                 | r <= 0.4  = '.'
                 | r <= 0.6  = '-'
                 | r <= 0.8  = '='
                 | otherwise = '#'

    return $ unlines outMat
  where
    -- Is @v@ within Euclidean distance @r@ of the centre @o@?
    inCircle :: Vector Double -> (Vector Double, Double) -> Bool
    v `inCircle` (o, r) = norm_2 (v - o) <= r

-- | Entry point.  Optional positional arguments: the number of training
-- samples (default 500000) and the learning rate (default 0.25).  A missing
-- or unparseable argument silently falls back to its default.
main :: IO ()
main = do
    args <- getArgs
    let n    = readMaybe =<< (args !!? 0)
        rate = readMaybe =<< (args !!? 1)
    putStrLn "Training network..."
    putStrLn =<< evalRandIO (netTest (fromMaybe 0.25 rate)
                                     (fromMaybe 500000 n   )
                            )

-- | Total list indexing: @Nothing@ when the index is out of range.
-- Negative indices are rejected explicitly, because 'drop' treats a negative
-- count as zero and would otherwise yield the head of the list.
(!!?) :: [a] -> Int -> Maybe a
xs !!? i
  | i < 0     = Nothing
  | otherwise = listToMaybe (drop i xs)
You can’t perform that action at this time.