Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
executable file 140 lines (123 sloc) 4.69 KB
#!/usr/bin/env stack
-- stack --resolver lts-5.15 --install-ghc runghc --package hmatrix --package MonadRandom
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Control.Monad
import Control.Monad.Random
import Data.List
import Data.Maybe
import Numeric.LinearAlgebra
import System.Environment
import Text.Read
-- | Parameters of one fully-connected \"m to n\" layer, mapping an
-- m-dimensional input to an n-dimensional output.
data Weights = W
  { wBiases :: !(Vector Double)  -- ^ biases, one per output (length n)
  , wNodes  :: !(Matrix Double)  -- ^ connection weights, n rows x m columns
  }
-- | A feed-forward network as a non-empty chain of layers: 'O' holds the
-- final (output) layer and ':&~' prepends an inner layer to the rest of
-- the network. All fields are strict.
data Network
  = O !Weights
  | !Weights :&~ !Network
infixr 5 :&~
-- | The logistic sigmoid, @1 / (1 + e^(-x))@, applied to any 'Floating'
-- value (used element-wise on vectors in this file).
logistic :: Floating a => a -> a
logistic = recip . (1 +) . exp . negate
-- | Derivative of 'logistic', via the identity
-- @logistic' x = logistic x * (1 - logistic x)@.
logistic' :: Floating a => a -> a
logistic' x =
  let s = logistic x
  in  s * (1 - s)
-- | Apply a layer's affine map, @biases + weights #> input@, with no
-- activation function.
runLayer :: Weights -> Vector Double -> Vector Double
runLayer w input = wBiases w + (wNodes w #> input)
-- | Feed an input vector forward through every layer of the network,
-- applying 'logistic' after each layer (including the output layer).
runNet :: Network -> Vector Double -> Vector Double
runNet net v = case net of
  O layer         -> v `seq` activate layer
  layer :&~ rest  -> v `seq` runNet rest (activate layer)
  where
    -- output of a single layer for this call's input
    activate layer = logistic (runLayer layer v)
-- | Randomly initialise a layer with @i@ inputs and @o@ outputs.
-- Biases are a length-@o@ vector and weights an @o@ x @i@ matrix. Both are
-- drawn from hmatrix's seeded generators, with seeds taken from the
-- 'MonadRandom' context (biases first, then weights). Values are intended
-- to lie in [-1, 1] (assumes 'Uniform' samples [0, 1) -- TODO confirm).
randomWeights :: MonadRandom m => Int -> Int -> m Weights
randomWeights i o = build <$> getRandom <*> getRandom
  where
    build :: Int -> Int -> Weights
    build biasSeed nodeSeed =
      W (randomVector biasSeed Uniform o * 2 - 1)
        (uniformSample nodeSeed o (replicate i (-1, 1)))
-- | Build a randomly initialised network with @i@ inputs, hidden layers of
-- the listed sizes (first element is nearest the input), and @o@ outputs.
randomNet :: MonadRandom m => Int -> [Int] -> Int -> m Network
randomNet i hs o = case hs of
  []       -> fmap O (randomWeights i o)
  h : rest -> do
    layer   <- randomWeights i h
    tailNet <- randomNet h rest o
    return (layer :&~ tailNet)
-- | Perform one step of gradient descent on a single (input, target) pair
-- and return the updated network.
--
-- The output-layer error signal is @o - target@, i.e. the gradient of the
-- squared error @1/2 * ||o - target||^2@ with respect to the output @o@.
train :: Double           -- ^ learning rate
      -> Vector Double    -- ^ input vector
      -> Vector Double    -- ^ target vector
      -> Network          -- ^ network to train
      -> Network
train rate x0 target = fst . go x0
  where
    -- Forward pass through the first layer, recurse into the rest of the
    -- network, then backpropagate. Returns the updated network and dE/dx,
    -- the gradient of the error with respect to this layer's input.
    go :: Vector Double -> Network -> (Network, Vector Double)
    -- output layer: dE/do comes straight from the error function
    go !x (O w) =
      let o          = logistic (runLayer w x)
          (w', dEdx) = updateLayer w x o (o - target)
      in  (O w', dEdx)
    -- inner layer: dE/do is the input gradient reported by the next layer
    go !x (w :&~ n) =
      let o          = logistic (runLayer w x)
          (n', dEdo) = go o n
          (w', dEdx) = updateLayer w x o dEdo
      in  (w' :&~ n', dEdx)

    -- Gradient-descent update shared by every layer. Given the layer's
    -- weights, its input x, its activated output o and dE/do, return the
    -- updated weights and dE/dx.
    updateLayer :: Weights -> Vector Double -> Vector Double -> Vector Double
                -> (Weights, Vector Double)
    updateLayer (W wB wN) x o dEdo = (W wB' wN', dEdx)
      where
        -- Since o = logistic y, logistic' y == o * (1 - o); reusing o avoids
        -- recomputing the exponential for the derivative.
        dEdy = o * (1 - o) * dEdo
        wB'  = wB - scale rate dEdy
        wN'  = wN - scale rate (dEdy `outer` x)
        -- propagate through the *old* weights, as in standard backprop
        dEdx = tr wN #> dEdy
-- | Train a fresh 2-16-8-1 network on @n@ random 2-D sample points, then
-- render the trained network's output over a 51 x 21 grid covering
-- [-1, 1] x [-1, 1] as ASCII art.
--
-- Each point is labelled 1 when it lies within distance 0.33 of either
-- centre @fromRational 0.33@ or @fromRational (-0.33)@, and 0 otherwise.
-- NOTE(review): those centres are 1-element vectors, presumably broadcast
-- by hmatrix to (0.33, 0.33) and (-0.33, -0.33) -- confirm.
netTest :: MonadRandom m => Double -> Int -> m String
netTest rate n = do
  points <- replicateM n (randomPoint <$> getRandom)
  net0   <- randomNet 2 [16,8] 1
  let samples        = [ (p, label p) | p <- points ]
      step nt (i, o) = train rate i o nt
      trained        = foldl' step net0 samples
      cell x y       =
        render (norm_2 (runNet trained (vector [x / 25 - 1, y / 10 - 1])))
  return (unlines [ [ cell x y | x <- [0..50] ] | y <- [0..20] ])
  where
    -- a 2-D point scaled from randomVector's output into roughly [-1, 1]
    randomPoint :: Int -> Vector Double
    randomPoint s = randomVector s Uniform 2 * 2 - 1

    -- target output for a sample point
    label :: Vector Double -> Vector Double
    label v
      | v `inCircle` (fromRational 0.33, 0.33)
          || v `inCircle` (fromRational (-0.33), 0.33) = fromRational 1
      | otherwise                                       = fromRational 0

    -- map an output magnitude to a shading character
    render :: Double -> Char
    render r
      | r <= 0.2  = ' '
      | r <= 0.4  = '.'
      | r <= 0.6  = '-'
      | r <= 0.8  = '='
      | otherwise = '#'

    inCircle :: Vector Double -> (Vector Double, Double) -> Bool
    v `inCircle` (o, r) = norm_2 (v - o) <= r
-- | Entry point. Optional positional arguments: the number of training
-- samples (default 500000) and the learning rate (default 0.25). Arguments
-- that are missing or fail to parse fall back to their defaults.
main :: IO ()
main = do
  args <- getArgs
  let -- explicit signature keeps this polymorphic under MonoLocalBinds
      -- (implied by GADTs), so it can parse both an Int and a Double
      argAt :: Read a => Int -> Maybe a
      argAt k = args !!? k >>= readMaybe
      n       = fromMaybe 500000 (argAt 0)
      rate    = fromMaybe 0.25   (argAt 1)
  putStrLn "Training network..."
  picture <- evalRandIO (netTest rate n)
  putStrLn picture
-- | Total list indexing: @xs !!? i@ is @Just@ the element at zero-based
-- index @i@, or 'Nothing' when @i@ is negative or past the end of the list.
-- (Without the guard, a negative index would return the head, because
-- @drop@ with a negative count leaves the list unchanged.)
(!!?) :: [a] -> Int -> Maybe a
xs !!? i
  | i < 0     = Nothing
  | otherwise = listToMaybe (drop i xs)
You can’t perform that action at this time.