In [1]:
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TupleSections         #-}
{-# LANGUAGE TypeFamilies          #-}
import           Control.Monad
import           Control.Monad.Random
import           Data.List ( foldl' )

import qualified Data.ByteString as B
import           Data.Serialize
import           Data.Semigroup ( (<>) )

import           GHC.TypeLits

import qualified Numeric.LinearAlgebra.Static as SA

import           Options.Applicative

import           Grenade

In [14]:
-- | Feed-forward network used throughout this notebook.
--
-- The first type-level list is the layer stack:
-- FullyConnected 2 40, Tanh, FullyConnected 40 10, Relu, FullyConnected 10 1,
-- Logit. The second list holds the shapes between layers: the 2-element input,
-- each intermediate shape, and the 1-element output. That gives seven entries
-- for six layers.
type FFNet =
  Network
    '[ FullyConnected 2 40, Tanh
     , FullyConnected 40 10, Relu
     , FullyConnected 10 1, Logit
     ]
    '[ D1 2
     , D1 40, D1 40
     , D1 10, D1 10
     , D1 1, D1 1
     ]

-- | A freshly initialised 'FFNet', drawn from the ambient 'MonadRandom'.
randomNet :: MonadRandom m => m FFNet
randomNet = randomNetwork

In [3]:
-- | Train @net0@ to flag 2-D points that fall inside one of two circles.
--
-- Draws @n@ random 2-D inputs and labels each one @1@ or @0@:
--
-- * @1@ if it lies within Euclidean distance 0.33 ('SA.norm_2') of the
--   constant vector 0.33 or the constant vector -0.33;
-- * @0@ otherwise.
--
-- The constant vectors are (0.33, 0.33) and (-0.33, -0.33), assuming hmatrix's
-- numeric literals broadcast to every component.
--
-- It then makes @epochs@ passes over that fixed sample set. Each pass calls
-- Grenade's 'train' once per sample with learning parameters @rate@, and the
-- network is threaded through a strict left fold.
netTrain :: MonadRandom m => FFNet -> LearningParameters -> Int -> Int -> m FFNet
netTrain net0 rate n epochs = do
    inps <- replicateM n $ do
      s <- getRandom
      -- Rescale the seeded random vector (presumably uniform in [0, 1) — TODO
      -- confirm against hmatrix) to the range [-1, 1).
      return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1
    let outs = flip map inps $ \(S1D v) ->
                 if v `inCircle` (0.33, 0.33) || v `inCircle` (-0.33, 0.33)
                   then S1D 1
                   else S1D 0
        samples = zip inps outs
        -- The same samples, in the same order, once per epoch; yields [] when
        -- epochs <= 0.
        -- TODO: each epoch should be randomized
        trainingData = concat (replicate epochs samples)
    return $ foldl' trainEach net0 trainingData
  where
    -- True when v is within distance r of centre o.
    inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
    v `inCircle` (o, r) = SA.norm_2 (v - o) <= r
    -- One training step on a single (input, label) pair; the bang forces the
    -- incoming network to WHNF before the step.
    trainEach !network (i, o) = train rate network i o

In [4]:
-- | Load a serialised 'FFNet' from the file at @modelPath@.
--
-- The file is read as a strict 'B.ByteString' and decoded with the 'Serialize'
-- instance for 'FFNet'. A decoding failure is raised in 'IO' via 'fail',
-- carrying the decoder's error message.
netLoad :: FilePath -> IO FFNet
netLoad modelPath =
  B.readFile modelPath >>= either fail pure . runGet get

-- | Print an ASCII-art picture of the network's output over [-1, 1] x [-1, 1].
--
-- The picture has 21 rows and 51 columns:
--
-- * column @c@ (0..50) samples x = c/25 - 1;
-- * row @r@ (20 down to 0) samples y = r/10 - 1, so the top line is y = 1.
--
-- Each cell is the network's single output bucketed into a character:
-- ' ' (<= 0.2), '.' (<= 0.4), '-' (<= 0.6), '=' (<= 0.8), otherwise '#'.
netScore :: FFNet -> IO ()
netScore network = putStrLn (unlines (map drawRow ys))
  where
    ys, xs :: [Double]
    ys = [20, 19 .. 0]
    xs = [0 .. 50]

    -- One text line: every column sampled at a fixed row coordinate.
    drawRow :: Double -> String
    drawRow y = [ shade (outputAt x y) | x <- xs ]

    -- Run the network on a grid point rescaled into [-1, 1]^2.
    outputAt :: Double -> Double -> Double
    outputAt x y =
      scalar (runNet network (S1D (SA.vector [x / 25 - 1, y / 10 - 1])))

    -- Collapse the one-element output vector to a plain Double.
    scalar :: S ('D1 1) -> Double
    scalar (S1D r) = SA.mean r

    shade :: Double -> Char
    shade v
      | v <= 0.2  = ' '
      | v <= 0.4  = '.'
      | v <= 0.6  = '-'
      | v <= 0.8  = '='
      | otherwise = '#'

In [15]:
-- Draw an untrained network with 'randomNet' and render it with 'netScore',
-- giving a pre-training baseline picture.
net0 <- randomNet
netScore net0

###################################################
###################################################
###################################################
###################################################
###################################################
###################################################
###################################################
###################################################
###################################################
###################################################
###################################################
###################################################
###########=====###################################

In [17]:
-- Train net0 on 250 random samples for 400 epochs, then render the result.
-- NOTE(review): LearningParameters' positional fields are presumably learning
-- rate, momentum and L2 regulariser, in that order — confirm against Grenade.
net <- netTrain net0 (LearningParameters 0.01 0.9 0.0005) 250 400
netScore net

                                                   
                                                   
                                                   
                                                   
                            -=######=..            
                          -############-           
                        -##############=.          
                       -###############=.          
                        .=##############-          
                          -=#######==--.           
               .-==-..     ...                     
           .-=######==-.                           
          =############=-                          
         .##############.                          
         .#############=.                          
          .=###########-                           
            .-=#####=-.                            
               .