% Extensible Neural Networks with Backprop
% Justin Le
This write-up is a follow-up to the *MNIST* tutorial
([rendered][mnist-rendered] here, and [literate haskell][mnist-lhs] here).
This write-up itself is available as a [literate haskell file][lhs], and
also [rendered as a pdf][rendered].
[mnist-rendered]: https://github.com/mstksg/backprop/blob/master/renders/backprop-mnist.pdf
[mnist-lhs]: https://github.com/mstksg/backprop/blob/master/samples/backprop-mnist.lhs
[rendered]: https://github.com/mstksg/backprop/blob/master/renders/extensible-neural.pdf
[lhs]: https://github.com/mstksg/backprop/blob/master/samples/extensible-neural.lhs
The (extra) packages involved are:
* hmatrix
* lens
* mnist-idx
* mwc-random
* one-liner-instances
* singletons
* split
> {-# LANGUAGE BangPatterns #-}
> {-# LANGUAGE DataKinds #-}
> {-# LANGUAGE DeriveGeneric #-}
> {-# LANGUAGE FlexibleContexts #-}
> {-# LANGUAGE GADTs #-}
> {-# LANGUAGE InstanceSigs #-}
> {-# LANGUAGE LambdaCase #-}
> {-# LANGUAGE RankNTypes #-}
> {-# LANGUAGE ScopedTypeVariables #-}
> {-# LANGUAGE TemplateHaskell #-}
> {-# LANGUAGE TypeApplications #-}
> {-# LANGUAGE TypeInType #-}
> {-# LANGUAGE TypeOperators #-}
> {-# LANGUAGE ViewPatterns #-}
> {-# OPTIONS_GHC -Wno-orphans #-}
>
> import Control.DeepSeq
> import Control.Exception
> import Control.Monad
> import Control.Monad.IO.Class
> import Control.Monad.Primitive
> import Control.Monad.Trans.Maybe
> import Control.Monad.Trans.State
> import Data.Bitraversable
> import Data.Foldable
> import Data.IDX
> import Data.Kind
> import Data.List.Split
> import Data.Singletons
> import Data.Singletons.Prelude
> import Data.Singletons.TypeLits
> import Data.Time.Clock
> import Data.Traversable
> import Data.Tuple
> import GHC.Generics (Generic)
> import Lens.Micro
> import Lens.Micro.TH
> import Numeric.Backprop
> import Numeric.Backprop.Class
> import Numeric.LinearAlgebra.Static
> import Numeric.OneLiner
> import Text.Printf
> import qualified Data.Vector as V
> import qualified Data.Vector.Generic as VG
> import qualified Data.Vector.Unboxed as VU
> import qualified Numeric.LinearAlgebra as HM
> import qualified System.Random.MWC as MWC
> import qualified System.Random.MWC.Distributions as MWC
Introduction
============
The *[backprop][hackage]* library lets us manipulate our values in a
natural way. We write the function to compute our result, and the library
then automatically finds the *gradient* of that function, which we can use
for gradient descent.
[hackage]: http://hackage.haskell.org/package/backprop
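For example, here is a tiny standalone sketch (the name `gradExample` is made up for illustration): we write `\x -> x^2 + 3*x`, and `gradBP` hands us its derivative, `2*x + 3`, evaluated at the given point.

```haskell
-- the gradient of x^2 + 3x is 2x + 3, so at x = 5 backprop gives 13
gradExample :: Double
gradExample = gradBP (\x -> x^2 + 3*x) 5    -- 13.0
```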
In the last post, we looked at using a fixed-structure neural network.
However, in [this blog series][blog], I discuss a system of extensible
neural networks that can be chained and composed.
[blog]: https://blog.jle.im/entries/series/+practical-dependent-types-in-haskell.html
One issue in naively translating those implementations is that we normally
run the network by pattern matching on each layer; however, we cannot
directly pattern match on `BVar`s.

We *could* get around this by being clever with prisms and `^^?`,
extracting a `BVar` wrapped in `Maybe`. But we can do better! The *shape*
of a `Net i hs o` is already known at compile time, so there is no need for
runtime checks like prisms and `^^?`.

Instead, we can use lenses directly, since we know *exactly* which
constructor will be present: pattern matching on the singleton for `hs`
tells us which constructor we have, so we can always reach for the right
lens without any runtime nondeterminism.
Types
=====
First, our types:
> data Layer i o =
>     Layer { _lWeights :: !(L o i)
>           , _lBiases  :: !(R o)
>           }
>   deriving (Show, Generic)
>
> instance NFData (Layer i o)
> makeLenses ''Layer
>
> data Net :: Nat -> [Nat] -> Nat -> Type where
>     NO   :: !(Layer i o) -> Net i '[] o
>     (:~) :: !(Layer i h) -> !(Net h hs o) -> Net i (h ': hs) o
Unfortunately, we can't automatically generate lenses for GADTs, so we have
to make them by hand.[^poly]
[^poly]: We write them as polymorphic lens families to help with type
safety via parametric polymorphism.
> _NO :: Lens (Net i '[] o) (Net i' '[] o')
>             (Layer i o)   (Layer i' o')
> _NO f (NO l) = NO <$> f l
>
> _NIL :: Lens (Net i (h ': hs) o) (Net i' (h ': hs) o)
>              (Layer i h)         (Layer i' h)
> _NIL f (l :~ n) = (:~ n) <$> f l
>
> _NIN :: Lens (Net i (h ': hs) o) (Net i (h ': hs') o')
>              (Net h hs o)        (Net h hs' o')
> _NIN f (l :~ n) = (l :~) <$> f n
You can read `_NO` as:
```haskell
_NO :: Lens' (Net i '[] o) (Layer i o)
```
A lens into a single-layer network, and
```haskell
_NIL :: Lens' (Net i (h ': hs) o) (Layer i h )
_NIN :: Lens' (Net i (h ': hs) o) (Net h hs o)
```
Lenses into a multiple-layer network, getting the first layer and the tail
of the network.
If we pattern match on `Sing hs`, we can always determine exactly which
lenses we can use, and so never fumble around with prisms or
nondeterminism.
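For example, here is a small helper (hypothetical, just for illustration)
that scales the biases of the first layer of *any* network: the singleton
tells us which constructor we are looking at, so we always know which lens
applies, with no `Maybe` in sight.

```haskell
-- scale the first layer's biases; matching on Sing hs tells us whether
-- to go through _NO (single layer) or _NIL (first of many)
scaleFirstBiases
    :: (KnownNat i, KnownNat o)
    => Double
    -> Sing hs
    -> Net i hs o
    -> Net i hs o
scaleFirstBiases c = \case
    SNil         -> over (_NO  . lBiases) (konst c *)
    SCons SNat _ -> over (_NIL . lBiases) (konst c *)
```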
Running the network
===================
Here's the meat of the process, then: specifying how to run the network. We
re-use our `BVar`-based combinators defined in the last write-up:
> runLayer
>     :: (KnownNat i, KnownNat o, Reifies s W)
>     => BVar s (Layer i o)
>     -> BVar s (R i)
>     -> BVar s (R o)
> runLayer l x = (l ^^. lWeights) #>! x + (l ^^. lBiases)
> {-# INLINE runLayer #-}
For `runNetwork`, we pattern match on `hs` using singletons, so we always
know exactly what type of network we have:
> runNetwork
>     :: (KnownNat i, KnownNat o, Reifies s W)
>     => BVar s (Net i hs o)
>     -> Sing hs
>     -> BVar s (R i)
>     -> BVar s (R o)
> runNetwork n = \case
>     SNil          -> softMax . runLayer (n ^^. _NO)
>     SCons SNat hs -> withSingI hs $
>         runNetwork (n ^^. _NIN) hs
>       . logistic
>       . runLayer (n ^^. _NIL)
> {-# INLINE runNetwork #-}
The rest of it is the same as before.
> netErr
>     :: (KnownNat i, KnownNat o, SingI hs, Reifies s W)
>     => R i
>     -> R o
>     -> BVar s (Net i hs o)
>     -> BVar s Double
> netErr x targ n = crossEntropy targ (runNetwork n sing (constVar x))
> {-# INLINE netErr #-}
>
> trainStep
>     :: forall i hs o. (KnownNat i, KnownNat o, SingI hs)
>     => Double           -- ^ learning rate
>     -> R i              -- ^ input
>     -> R o              -- ^ target
>     -> Net i hs o       -- ^ initial network
>     -> Net i hs o
> trainStep r !x !targ !n = n - realToFrac r * gradBP (netErr x targ) n
> {-# INLINE trainStep #-}
>
> trainList
>     :: (KnownNat i, SingI hs, KnownNat o)
>     => Double           -- ^ learning rate
>     -> [(R i, R o)]     -- ^ input and target pairs
>     -> Net i hs o       -- ^ initial network
>     -> Net i hs o
> trainList r = flip $ foldl' (\n (x,y) -> trainStep r x y n)
> {-# INLINE trainList #-}
>
> testNet
>     :: forall i hs o. (KnownNat i, KnownNat o, SingI hs)
>     => [(R i, R o)]
>     -> Net i hs o
>     -> Double
> testNet xs n = sum (map (uncurry test) xs) / fromIntegral (length xs)
>   where
>     test :: R i -> R o -> Double    -- test if the max index is correct
>     test x (extract->t)
>         | HM.maxIndex t == HM.maxIndex (extract r) = 1
>         | otherwise                                = 0
>       where
>         r :: R o
>         r = evalBP (\n' -> runNetwork n' sing (constVar x)) n
And that's it!
Running
=======
Everything here is the same as before, except that we now pick the hidden
layer structure through the network's type. Here we use `'[300,100]` for
the hidden layer sizes.
> main :: IO ()
> main = MWC.withSystemRandom $ \g -> do
>     Just train <- loadMNIST "data/train-images-idx3-ubyte" "data/train-labels-idx1-ubyte"
>     Just test  <- loadMNIST "data/t10k-images-idx3-ubyte"  "data/t10k-labels-idx1-ubyte"
>     putStrLn "Loaded data."
>     net0 <- MWC.uniformR @(Net 784 '[300,100] 10) (-0.5, 0.5) g
>     flip evalStateT net0 . forM_ [1..] $ \e -> do
>       train' <- liftIO . fmap V.toList $ MWC.uniformShuffle (V.fromList train) g
>       liftIO $ printf "[Epoch %d]\n" (e :: Int)
>
>       forM_ ([1..] `zip` chunksOf batch train') $ \(b, chnk) -> StateT $ \n0 -> do
>         printf "(Batch %d)\n" (b :: Int)
>
>         t0 <- getCurrentTime
>         n' <- evaluate . force $ trainList rate chnk n0
>         t1 <- getCurrentTime
>         printf "Trained on %d points in %s.\n" batch (show (t1 `diffUTCTime` t0))
>
>         let trainScore = testNet chnk n'
>             testScore  = testNet test n'
>         printf "Training error: %.2f%%\n" ((1 - trainScore) * 100)
>         printf "Validation error: %.2f%%\n" ((1 - testScore ) * 100)
>
>         return ((), n')
>   where
>     rate  = 0.02
>     batch = 5000
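Because the hidden layer structure lives in the type, trying a different
architecture only means changing that type application; nothing else in
`main` has to change. For instance (a sketch, not compiled here):

```haskell
-- a deeper network: only the requested type changes
net0 <- MWC.uniformR @(Net 784 '[500,300,100] 10) (-0.5, 0.5) g
```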
Looking Forward
===============
One common extension is to allow mixing different *types* of layers. This
could be encoded as different constructors in `Layer`, in which case
`runLayer` would have to behave differently depending on which constructor
is present.
In this case, we can either:
1.  Have a type for layers that is indexed by the layer type, so we always
    know exactly which kind of layer is involved and never have to pattern
    match at runtime:

    ```haskell
    data LayerType = FullyConnected | Convolutional

    data Layer :: LayerType -> Nat -> Nat -> Type where
        LayerFC :: .... -> Layer 'FullyConnected i o
        LayerC  :: .... -> Layer 'Convolutional  i o
    ```

    We would then have `runLayer` take a `Sing (t :: LayerType)`, so we can
    again use `^^.` and pattern match directly.
2.  Use a typeclass-based approach, so users can add their own layer types.
    Each layer type would then be a distinct type, and running a layer
    would be a typeclass method providing our
    `BVar s (Layer i o) -> BVar s (R i) -> BVar s (R o)` operation (an
    example instance is sketched just after this list):

    ```haskell
    class Layer (l :: Nat -> Nat -> Type) where
        runLayer
            :: (KnownNat i, KnownNat o, Reifies s W)
            => BVar s (l i o)
            -> BVar s (R i)
            -> BVar s (R o)
    ```
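For the fully connected layers from this write-up, an instance of such a
class would just re-use the body of our existing `runLayer`. Here is a
rough sketch (the names `LayerLike` and `runLayerB` are made up, purely to
avoid clashing with the `Layer` type and `runLayer` function we already
have):

```haskell
class LayerLike (l :: Nat -> Nat -> Type) where
    runLayerB
        :: (KnownNat i, KnownNat o, Reifies s W)
        => BVar s (l i o)
        -> BVar s (R i)
        -> BVar s (R o)

-- our fully connected Layer is one instance
instance LayerLike Layer where
    runLayerB l x = (l ^^. lWeights) #>! x + (l ^^. lBiases)
```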
In either case, it shouldn't be much more cognitive overhead to use
*backprop* to build your neural network framework!
And remember that `evalBP` (running the function directly, without
computing gradients) introduces virtually zero overhead, so if you only
provide `BVar`-based functions, you can easily recover the original
non-`BVar` versions with `evalBP`, with nothing lost.
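For example, a plain network-running function (the same thing `testNet`
uses internally) is a one-liner; the name here is made up for illustration:

```haskell
-- recover a non-BVar runner from the BVar-based runNetwork;
-- `sing` supplies the singleton for the hidden layer structure
runNetOnInput
    :: (KnownNat i, KnownNat o, SingI hs)
    => Net i hs o
    -> R i
    -> R o
runNetOnInput n x = evalBP (\n' -> runNetwork n' sing (constVar x)) n
```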
What now?
---------
Ready to start? Check out the docs for the [Numeric.Backprop][] module for
the full technical specs, and find more examples and updates at the [github
repo][repo]!
[Numeric.Backprop]: http://hackage.haskell.org/package/backprop/docs/Numeric-Backprop.html
[repo]: https://github.com/mstksg/backprop
Internals
=========
That's it for the post! Now for the internal plumbing :)
> loadMNIST
>     :: FilePath
>     -> FilePath
>     -> IO (Maybe [(R 784, R 10)])
> loadMNIST fpI fpL = runMaybeT $ do
>     i <- MaybeT          $ decodeIDXFile       fpI
>     l <- MaybeT          $ decodeIDXLabelsFile fpL
>     d <- MaybeT . return $ labeledIntData l i
>     r <- MaybeT . return $ for d (bitraverse mkImage mkLabel . swap)
>     liftIO . evaluate $ force r
>   where
>     mkImage :: VU.Vector Int -> Maybe (R 784)
>     mkImage = create . VG.convert . VG.map (\i -> fromIntegral i / 255)
>     mkLabel :: Int -> Maybe (R 10)
>     mkLabel n = create $ HM.build 10 (\i -> if round i == n then 1 else 0)
HMatrix Operations
------------------
> infixr 8 #>!
> (#>!)
>     :: (KnownNat m, KnownNat n, Reifies s W)
>     => BVar s (L m n)
>     -> BVar s (R n)
>     -> BVar s (R m)
> (#>!) = liftOp2 . op2 $ \m v ->
>     ( m #> v, \g -> (g `outer` v, tr m #> g) )
>
> infixr 8 <.>!
> (<.>!)
>     :: (KnownNat n, Reifies s W)
>     => BVar s (R n)
>     -> BVar s (R n)
>     -> BVar s Double
> (<.>!) = liftOp2 . op2 $ \x y ->
>     ( x <.> y, \g -> (konst g * y, x * konst g)
>     )
>
> konst'
>     :: (KnownNat n, Reifies s W)
>     => BVar s Double
>     -> BVar s (R n)
> konst' = liftOp1 . op1 $ \c -> (konst c, HM.sumElements . extract)
>
> sumElements'
>     :: (KnownNat n, Reifies s W)
>     => BVar s (R n)
>     -> BVar s Double
> sumElements' = liftOp1 . op1 $ \x -> (HM.sumElements (extract x), konst)
>
> softMax :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s (R n)
> softMax x = konst' (1 / sumElements' expx) * expx
>   where
>     expx = exp x
> {-# INLINE softMax #-}
>
> crossEntropy
>     :: (KnownNat n, Reifies s W)
>     => R n
>     -> BVar s (R n)
>     -> BVar s Double
> crossEntropy targ res = -(log res <.>! constVar targ)
> {-# INLINE crossEntropy #-}
>
> logistic :: Floating a => a -> a
> logistic x = 1 / (1 + exp (-x))
> {-# INLINE logistic #-}
Instances
---------
> instance (KnownNat i, KnownNat o) => Num (Layer i o) where
>     (+)         = gPlus
>     (-)         = gMinus
>     (*)         = gTimes
>     negate      = gNegate
>     abs         = gAbs
>     signum      = gSignum
>     fromInteger = gFromInteger
>
> instance (KnownNat i, KnownNat o) => Fractional (Layer i o) where
>     (/)          = gDivide
>     recip        = gRecip
>     fromRational = gFromRational
>
> instance (KnownNat i, KnownNat o) => Backprop (Layer i o)
>
>
> liftNet0
>     :: forall i hs o. (KnownNat i, KnownNat o)
>     => (forall m n. (KnownNat m, KnownNat n) => Layer m n)
>     -> Sing hs
>     -> Net i hs o
> liftNet0 x = go
>   where
>     go :: forall w ws. KnownNat w => Sing ws -> Net w ws o
>     go = \case
>       SNil          -> NO x
>       SCons SNat hs -> x :~ go hs
>
> liftNet1
>     :: forall i hs o. (KnownNat i, KnownNat o)
>     => (forall m n. (KnownNat m, KnownNat n)
>           => Layer m n
>           -> Layer m n
>        )
>     -> Sing hs
>     -> Net i hs o
>     -> Net i hs o
> liftNet1 f = go
>   where
>     go :: forall w ws. KnownNat w
>        => Sing ws
>        -> Net w ws o
>        -> Net w ws o
>     go = \case
>       SNil -> \case
>         NO x -> NO (f x)
>       SCons SNat hs -> \case
>         x :~ xs -> f x :~ go hs xs
>
> liftNet2
>     :: forall i hs o. (KnownNat i, KnownNat o)
>     => (forall m n. (KnownNat m, KnownNat n)
>           => Layer m n
>           -> Layer m n
>           -> Layer m n
>        )
>     -> Sing hs
>     -> Net i hs o
>     -> Net i hs o
>     -> Net i hs o
> liftNet2 f = go
>   where
>     go :: forall w ws. KnownNat w
>        => Sing ws
>        -> Net w ws o
>        -> Net w ws o
>        -> Net w ws o
>     go = \case
>       SNil -> \case
>         NO x -> \case
>           NO y -> NO (f x y)
>       SCons SNat hs -> \case
>         x :~ xs -> \case
>           y :~ ys -> f x y :~ go hs xs ys
>
> instance ( KnownNat i
>          , KnownNat o
>          , SingI hs
>          )
>       => Num (Net i hs o) where
>     (+)           = liftNet2 (+) sing
>     (-)           = liftNet2 (-) sing
>     (*)           = liftNet2 (*) sing
>     negate        = liftNet1 negate sing
>     abs           = liftNet1 abs    sing
>     signum        = liftNet1 signum sing
>     fromInteger x = liftNet0 (fromInteger x) sing
>
> instance ( KnownNat i
>          , KnownNat o
>          , SingI hs
>          )
>       => Fractional (Net i hs o) where
>     (/)            = liftNet2 (/) sing
>     recip          = liftNet1 recip sing
>     fromRational x = liftNet0 (fromRational x) sing
>
> instance (KnownNat i, KnownNat o, SingI hs) => Backprop (Net i hs o) where
>     zero = liftNet1 zero sing
>     add  = liftNet2 add  sing
>     one  = liftNet1 one  sing
>
> instance KnownNat n => MWC.Variate (R n) where
>     uniform g = randomVector <$> MWC.uniform g <*> pure Uniform
>     uniformR (l, h) g = (\x -> x * (h - l) + l) <$> MWC.uniform g
>
> instance (KnownNat m, KnownNat n) => MWC.Variate (L m n) where
>     uniform g = uniformSample <$> MWC.uniform g <*> pure 0 <*> pure 1
>     uniformR (l, h) g = (\x -> x * (h - l) + l) <$> MWC.uniform g
>
> instance (KnownNat i, KnownNat o) => MWC.Variate (Layer i o) where
>     uniform g = Layer <$> MWC.uniform g <*> MWC.uniform g
>     uniformR (l, h) g = (\x -> x * (h - l) + l) <$> MWC.uniform g
>
> instance ( KnownNat i
>          , KnownNat o
>          , SingI hs
>          )
>       => MWC.Variate (Net i hs o) where
>     uniform :: forall m. PrimMonad m => MWC.Gen (PrimState m) -> m (Net i hs o)
>     uniform g = go sing
>       where
>         go :: forall w ws. KnownNat w => Sing ws -> m (Net w ws o)
>         go = \case
>           SNil          -> NO <$> MWC.uniform g
>           SCons SNat hs -> (:~) <$> MWC.uniform g <*> go hs
>     uniformR (l, h) g = (\x -> x * (h - l) + l) <$> MWC.uniform g
>
> instance NFData (Net i hs o) where
>     rnf = \case
>       NO l    -> rnf l
>       x :~ xs -> rnf x `seq` rnf xs
>
> instance KnownNat n => Backprop (R n) where
>     zero = zeroNum
>     add  = addNum
>     one  = oneNum
>
> instance (KnownNat n, KnownNat m) => Backprop (L m n) where
>     zero = zeroNum
>     add  = addNum
>     one  = oneNum