Skip to content

Commit

Permalink
tests now run, but don't pass?
Browse files Browse the repository at this point in the history
  • Loading branch information
mstksg committed May 3, 2018
1 parent e797e5c commit 40dd2ec
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
44 changes: 27 additions & 17 deletions test/Nudge.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Nudge where

Expand All @@ -21,7 +23,6 @@ import Hedgehog
import Lens.Micro
import Lens.Micro.Platform ()
import Numeric.Backprop
import Numeric.Backprop.Tuple
import qualified Data.Ix as Ix
import qualified Data.Vector.Sized as SV
import qualified Hedgehog.Gen as Gen
Expand All @@ -36,7 +37,7 @@ nudge = 1e-6
eps :: Double
eps = 1e-11

class (Num c, Show c, Show (TIx c)) => Testing c where
class (Backprop c, Show c, Show (TIx c)) => Testing c where
type TIx c :: Type
allIx :: c -> [TIx c]
ixLens :: TIx c -> Lens' c Double
Expand Down Expand Up @@ -95,7 +96,7 @@ instance Testing (HU.Matrix Double) where
scalarize = liftOp1 . op1 $ \xs -> (HU.sumElements xs, (`HU.konst` HU.size xs))
genTest = HU.fromLists <$> (replicateM 3 . replicateM 2) genTest

instance (KnownNat n, Testing a) => Testing (SV.Vector n a) where
instance (KnownNat n, Testing a, Num a) => Testing (SV.Vector n a) where
type TIx (SV.Vector n a) = (Finite n, TIx a)
allIx = fst . SV.imapM (\i x -> ((fromIntegral i,) <$> allIx x , x))
ixLens (i,j) = SV.ix i . ixLens j
Expand All @@ -105,30 +106,30 @@ instance (KnownNat n, Testing a) => Testing (SV.Vector n a) where
o = op1 $ \xs -> (SV.sum xs, SV.replicate)
genTest = SV.replicateM genTest

instance (Testing a, Testing b) => Testing (T2 a b) where
type TIx (T2 a b) = Either (TIx a) (TIx b)
allIx (T2 x y) = (Left <$> allIx x)
++ (Right <$> allIx y)
instance (Testing a, Testing b) => Testing (a, b) where
type TIx (a, b) = Either (TIx a) (TIx b)
allIx (x, y) = (Left <$> allIx x)
++ (Right <$> allIx y)
ixLens (Left i) = _1 . ixLens i
ixLens (Right j) = _2 . ixLens j
scalarize t = B.norm_2V (B.vec2 (scalarize (t ^^. _1))
(scalarize (t ^^. _2))
)
genTest = T2 <$> genTest <*> genTest
genTest = (,) <$> genTest <*> genTest

instance (Testing a, Testing b, Testing c, Num a, Num b, Num c) => Testing (T3 a b c) where
type TIx (T3 a b c) = Either (TIx a) (Either (TIx b) (TIx c))
allIx (T3 x y z) = (Left <$> allIx x)
++ (Right . Left <$> allIx y)
++ (Right . Right <$> allIx z)
instance (Testing a, Testing b, Testing c, Num a, Num b, Num c) => Testing (a, b, c) where
type TIx (a, b, c) = Either (TIx a) (Either (TIx b) (TIx c))
allIx (x, y, z) = (Left <$> allIx x)
++ (Right . Left <$> allIx y)
++ (Right . Right <$> allIx z)
ixLens (Left i ) = _1 . ixLens i
ixLens (Right (Left j)) = _2 . ixLens j
ixLens (Right (Right k)) = _3 . ixLens k
scalarize t = B.norm_2V (B.vec3 (scalarize (t ^^. _1))
(scalarize (t ^^. _2))
(scalarize (t ^^. _3))
)
genTest = T3 <$> genTest <*> genTest <*> genTest
genTest = (,,) <$> genTest <*> genTest <*> genTest

validGrad
:: Monad m
Expand Down Expand Up @@ -166,12 +167,21 @@ nudgeProp2 f = property $ do
(inpC, inpD, i) <- forAll $ do
inpC <- genTest
inpD <- genTest
i <- Gen.element (allIx (T2 inpC inpD))
i <- Gen.element (allIx (inpC, inpD))
return (inpC, inpD, i)
let (r, gr) = second tupT2 $ backprop2 (\x -> scalarize . f x) inpC inpD
let (r, gr) = backprop2 (\x -> scalarize . f x) inpC inpD
when (r**2 < eps) discard
(old, new) <- validGrad (ixLens i) (T2 inpC inpD) gr
(old, new) <- validGrad (ixLens i) (inpC, inpD) gr
(evalBP (\t -> scalarize $ f (t ^^. _1) (t ^^. _2)))
footnoteShow (r, gr, old, new, (old - new)**2, ((old - new)/old)**2)
assert $ ((old - new)/old)**2 < eps

instance (HU.Container HU.Vector a, Num a) => Backprop (HU.Matrix a) where
zero = HU.cmap (const 0)
add = HU.add
one = HU.cmap (const 1)

instance (KnownNat n, Num a) => Backprop (SV.Vector n a) where
zero = (0 <$)
add = (+)
one = (1 <$)
5 changes: 2 additions & 3 deletions test/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import Hedgehog
import Lens.Micro
import Nudge
import Numeric.Backprop
import Numeric.Backprop.Tuple
import Numeric.LinearAlgebra.Static (L, R)
import System.Exit
import System.IO
Expand All @@ -21,11 +20,11 @@ prop_vec2 :: Property
prop_vec2 = nudgeProp2 B.vec2

prop_vec3 :: Property
prop_vec3 = nudgeProp @(T3 Double Double Double)
prop_vec3 = nudgeProp @(Double, Double, Double)
(\t -> B.vec3 (t ^^. _1) (t ^^. _2) (t ^^. _3))

prop_vec4 :: Property
prop_vec4 = nudgeProp2 @(T2 Double Double) @(T2 Double Double)
prop_vec4 = nudgeProp2 @(Double, Double) @(Double, Double)
(\x y -> B.vec4 (x ^^. _1) (x ^^. _2) (y ^^. _1) (y ^^. _2))

prop_snoc :: Property
Expand Down

0 comments on commit 40dd2ec

Please sign in to comment.