diff --git a/.travis.yml b/.travis.yml index 7055c67..46d7a50 100644 --- a/.travis.yml +++ b/.travis.yml @@ -45,14 +45,6 @@ matrix: compiler: ": #stack default" addons: {apt: {packages: [ghc-8.2.2,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} - - env: BUILD=stack ARGS="--resolver lts-8" - compiler: ": #stack 8.0.2" - addons: {apt: {packages: [ghc-8.0.2,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} - - - env: BUILD=stack ARGS="--resolver lts-9" - compiler: ": #stack 8.0.2" - addons: {apt: {packages: [ghc-8.0.2,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} - - env: BUILD=stack ARGS="--resolver lts-10" compiler: ": #stack 8.2.2" addons: {apt: {packages: [ghc-8.2.2,libblas-dev,liblapack-dev], sources: [hvr-ghc]}} @@ -67,14 +59,6 @@ matrix: compiler: ": #stack default osx" os: osx - - env: BUILD=stack ARGS="--resolver lts-8" - compiler: ": #stack 8.0.2 osx" - os: osx - - - env: BUILD=stack ARGS="--resolver lts-9" - compiler: ": #stack 8.0.2 osx" - os: osx - - env: BUILD=stack ARGS="--resolver lts-10" compiler: ": #stack 8.2.2 osx" os: osx @@ -126,7 +110,7 @@ install: set -ex case "$BUILD" in stack) - stack --no-terminal --install-ghc $ARGS test --bench --only-dependencies + stack --no-terminal --install-ghc $ARGS test --bench --only-dependencies ;; cabal) cabal --version @@ -141,10 +125,10 @@ script: set -ex case "$BUILD" in stack) - stack --no-terminal $ARGS test --bench --no-run-benchmarks --haddock --no-haddock-deps + stack --no-terminal $ARGS test --bench --no-run-benchmarks --haddock --no-haddock-deps --ghc-options=-Werror ;; cabal) - cabal install --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES + cabal install --enable-tests --enable-benchmarks --force-reinstalls --ghc-options="-O0 -Werror" --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES ORIGDIR=$(pwd) for dir in $PACKAGES diff --git a/package.yaml b/package.yaml index ba8e66d..867842e 100644 --- a/package.yaml +++ b/package.yaml @@ -21,29 +21,32 @@ description: Please see the README on Github at = 4.7 && < 5 -- hmatrix >= 0.18.1 - backprop >= 0.1.2 -- ghc-typelits-natnormalise -- ghc-typelits-knownnat +- base >= 4.7 && < 5 +- hmatrix >= 0.18 - microlens -- vector -- vector-sized +- vector-sized >= 0.6 library: source-dirs: src exposed-modules: - Numeric.LinearAlgebra.Static.Backprop - ghc-options: - - -fwarn-redundant-constraints dependencies: - ANum >= 0.2 + - ghc-typelits-knownnat + - ghc-typelits-natnormalise + - vector tests: hmatrix-backprop-test: main: Spec.hs + other-modules: + - Nudge source-dirs: test ghc-options: - -threaded diff --git a/src/Numeric/LinearAlgebra/Static/Backprop.hs b/src/Numeric/LinearAlgebra/Static/Backprop.hs index 1008ad6..2d4cc8e 100644 --- a/src/Numeric/LinearAlgebra/Static/Backprop.hs +++ b/src/Numeric/LinearAlgebra/Static/Backprop.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} @@ -9,6 +10,10 @@ {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +#if MIN_VERSION_base(4,11,0) +#else +{-# OPTIONS_GHC -Wno-compat #-} +#endif -- | -- Module : Numeric.LinearAlgebra.Static.Backprop @@ -80,9 +85,6 @@ -- Some other notes: -- -- * Added 'sumElements', as well, for convenience. --- * Lifted 'H.svd' is temporarily currently not exported, due to a bug --- in /hmatrix/ in the 'H.diagR' function. When this bug is patched, --- 'H.svd' will exported. module Numeric.LinearAlgebra.Static.Backprop ( -- * Vector @@ -124,9 +126,8 @@ module Numeric.LinearAlgebra.Static.Backprop ( , (#>) , (<.>) -- * Factorizations - -- $svd - -- , svd - -- , svd_ + , svd + , svd_ , H.Eigen , eigensystem , eigenvalues @@ -200,6 +201,10 @@ import qualified Numeric.LinearAlgebra as HU import qualified Numeric.LinearAlgebra.Devel as HU import qualified Numeric.LinearAlgebra.Static as H +#if MIN_VERSION_base(4,11,0) +import Prelude hiding ((<>)) +#endif + vec2 :: Reifies s W => BVar s H.ℝ @@ -436,35 +441,29 @@ infixr 8 #> infixr 8 <.> {-# INLINE (<.>) #-} --- $svd --- --- Note: Lifted versions of 'H.svd' temporarily unexported due to a bug in --- /hmatrix/, in 'H.diagR'. - -- | Can only get the singular values, for now. Let me know if you find an -- algorithm that can compute the gradients based on differentials for the -- other matricies! -- --- TODO: bug in diagR -_svd :: forall m n s. (Reifies s W, KnownNat m, KnownNat n) +svd :: forall m n s. (Reifies s W, KnownNat m, KnownNat n) => BVar s (H.L m n) -> BVar s (H.R n) -_svd = liftOp1 . op1 $ \x -> +svd = liftOp1 . op1 $ \x -> let (u, σ, v) = H.svd x in ( σ - , \dΣ -> u H.<> H.diagR 0 dΣ H.<> H.tr v + , \(dΣ :: H.R n) -> (u H.<> H.diagR 0 dΣ) H.<> H.tr v + -- must manually associate because of bug in diagR in + -- hmatrix-0.18.2.0 ) -{-# INLINE _svd #-} +{-# INLINE svd #-} -- | Version of 'svd' that returns the full SVD, but if you attempt to find -- the gradient, it will fail at runtime if you ever use U or V. --- --- Useful if you want to only use 'evalBP'. -_svd_ +svd_ :: forall m n s. (Reifies s W, KnownNat m, KnownNat n) => BVar s (H.L m n) -> (BVar s (H.L m m), BVar s (H.R n), BVar s (H.L n n)) -_svd_ r = (t ^^. _1, t ^^. _2, t ^^. _3) +svd_ r = (t ^^. _1, t ^^. _2, t ^^. _3) where o :: Op '[H.L m n] (T3 (H.L m m) (H.R n) (H.L n n)) o = op1 $ \x -> @@ -472,13 +471,13 @@ _svd_ r = (t ^^. _1, t ^^. _2, t ^^. _3) in ( T3 u σ v , \(T3 dU dΣ dV) -> if H.norm_0 dU == 0 && H.norm_0 dV == 0 - then u H.<> H.diagR 0 dΣ H.<> H.tr v + then (u H.<> H.diagR 0 dΣ) H.<> H.tr v else error "svd_: Cannot backprop if U and V are used." ) {-# INLINE o #-} t = liftOp1 o r {-# NOINLINE t #-} -{-# INLINE _svd_ #-} +{-# INLINE svd_ #-} helpEigen :: KnownNat n => H.Sym n -> (H.R n, H.L n n, H.L n n, H.L n n) helpEigen x = (l, v, H.inv v, H.tr v) @@ -792,9 +791,6 @@ cross = liftOp2 . op2 $ \x y -> {-# INLINE cross #-} -- | Create matrix with diagonal, and fill with default entries --- --- Note that this inherits the bug in 'H.diagR' if used with a version of --- /hmatrix/ wiith the bug (currently, 0.18.2.0). diagR :: forall m n k field vec mat s. ( Reifies s W @@ -836,6 +832,9 @@ dvmap f = liftOp1 . op1 $ \x -> ) {-# INLINE dvmap #-} +-- TODO: Can be made more efficient if backprop exports +-- a custom-total-derivative version + -- | A version of 'dvmap' that is less performant but is based on -- 'H.zipWithVector' from 'H.Domain'. dvmap' diff --git a/test/Nudge.hs b/test/Nudge.hs index 55f06cf..fc6c4c9 100644 --- a/test/Nudge.hs +++ b/test/Nudge.hs @@ -1,12 +1,13 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# OPTIONS_GHC -fno-warn-redundant-constraints #-} module Nudge where @@ -16,7 +17,7 @@ import Data.Finite import Data.Kind import Data.Maybe import Data.Proxy -import GHC.TypeNats +import GHC.TypeLits import Hedgehog import Lens.Micro import Lens.Micro.Platform () @@ -70,7 +71,7 @@ instance KnownNat n => Testing (H.R n) where scalarize = B.norm_2V genTest = H.vector <$> replicateM n genTest where - n = fromIntegral $ natVal (Proxy @n) + n = fromInteger $ natVal (Proxy @n) instance (KnownNat n, KnownNat m) => Testing (H.L n m) where type TIx (H.L n m) = (Int, Int) @@ -79,7 +80,7 @@ instance (KnownNat n, KnownNat m) => Testing (H.L n m) where scalarize = sqrt . B.sumElements . (**2) genTest = H.matrix <$> replicateM nm genTest where - nm = fromIntegral $ natVal (Proxy @n) * natVal (Proxy @m) + nm = fromInteger $ natVal (Proxy @n) * natVal (Proxy @m) instance Testing (HU.Vector Double) where type TIx (HU.Vector Double) = Int diff --git a/test/Spec.hs b/test/Spec.hs index 7d1ddb0..2e4f282 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -88,13 +88,11 @@ prop_tr = nudgeProp @(L 3 2) B.tr prop_diag :: Property prop_diag = nudgeProp @(R 3) B.diag --- TODO: bug in diagR --- prop_svd :: Property --- prop_svd = nudgeProp (genMat @5 @4) B.svd +prop_svd :: Property +prop_svd = nudgeProp @(L 3 2) B.svd --- TODO: bug in diagR --- prop_svd_ :: Property --- prop_svd_ = nudgeProp (genMat @5 @4) ((\(_,x,_) -> x) . B.svd_) +prop_svd_ :: Property +prop_svd_ = nudgeProp @(L 3 2) ((\(_,x,_) -> x) . B.svd_) prop_eigensystem1 :: Property prop_eigensystem1 = nudgeProp @(L 3 2) (fst . B.eigensystem . B.mTm)