Skip to content

Commit

Permalink
split Linear from Layers
Browse files Browse the repository at this point in the history
  • Loading branch information
stites committed Jul 2, 2018
1 parent fbb8dee commit e67893c
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 53 deletions.
1 change: 0 additions & 1 deletion cabal.project.local-example
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

optimization: False
debug-info: True
executable-stripping: False
Expand Down
8 changes: 8 additions & 0 deletions core/hasktorch-core.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ library
, Torch.Core.Random
, Torch.Core.LogAdd

reexported-modules:
Torch.Types.Numeric

exposed-modules:
Torch.Byte
, Torch.Byte.Dynamic
Expand Down Expand Up @@ -604,6 +607,7 @@ library
, Torch.Double.NN.Conv2d
, Torch.Double.NN.Criterion
, Torch.Double.NN.Layers
, Torch.Double.NN.Linear
, Torch.Double.NN.Math
, Torch.Double.NN.Padding
, Torch.Double.NN.Pooling
Expand Down Expand Up @@ -681,6 +685,7 @@ library
,Torch.Indef.Static.NN.Conv2d as Torch.Double.NN.Conv2d
,Torch.Indef.Static.NN.Criterion as Torch.Double.NN.Criterion
,Torch.Indef.Static.NN.Layers as Torch.Double.NN.Layers
,Torch.Indef.Static.NN.Linear as Torch.Double.NN.Linear
,Torch.Indef.Static.NN.Math as Torch.Double.NN.Math
,Torch.Indef.Static.NN.Padding as Torch.Double.NN.Padding
,Torch.Indef.Static.NN.Pooling as Torch.Double.NN.Pooling
Expand Down Expand Up @@ -1149,6 +1154,7 @@ library
, Torch.Cuda.Double.NN.Conv2d
, Torch.Cuda.Double.NN.Criterion
, Torch.Cuda.Double.NN.Layers
, Torch.Cuda.Double.NN.Linear
, Torch.Cuda.Double.NN.Math
, Torch.Cuda.Double.NN.Padding
, Torch.Cuda.Double.NN.Pooling
Expand Down Expand Up @@ -1222,6 +1228,7 @@ library
,Torch.Indef.Static.NN.Conv2d as Torch.Cuda.Double.NN.Conv2d
,Torch.Indef.Static.NN.Criterion as Torch.Cuda.Double.NN.Criterion
,Torch.Indef.Static.NN.Layers as Torch.Cuda.Double.NN.Layers
,Torch.Indef.Static.NN.Linear as Torch.Cuda.Double.NN.Linear
,Torch.Indef.Static.NN.Math as Torch.Cuda.Double.NN.Math
,Torch.Indef.Static.NN.Padding as Torch.Cuda.Double.NN.Padding
,Torch.Indef.Static.NN.Pooling as Torch.Cuda.Double.NN.Pooling
Expand Down Expand Up @@ -1466,6 +1473,7 @@ library hasktorch-indef-floating
, Torch.Indef.Static.NN.Conv2d
, Torch.Indef.Static.NN.Criterion
, Torch.Indef.Static.NN.Layers
, Torch.Indef.Static.NN.Linear
, Torch.Indef.Static.NN.Math
, Torch.Indef.Static.NN.Padding
, Torch.Indef.Static.NN.Pooling
Expand Down
4 changes: 3 additions & 1 deletion core/tests/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import qualified GarbageCollectionSpec as GS
import qualified Torch.Core.LogAddSpec as LS
import qualified Torch.Core.RandomSpec as RS
import qualified Torch.Static.NN.AbsSpec as AbsNN
import qualified Torch.Static.NN.LinearSpec as LinearNN

main :: IO ()
main = hspec $ do
Expand All @@ -16,6 +17,7 @@ main = hspec $ do
describe "GarbageCollectionSpec" GS.spec
describe "Torch.Core.LogAddSpec" LS.spec
describe "Torch.Core.RandomSpec" RS.spec
describe "Torch.NN.Static.AbsSpec" AbsNN.spec
describe "Torch.Static.NN.AbsSpec" AbsNN.spec
describe "Torch.Static.NN.LinearSpec" LinearNN.spec


1 change: 1 addition & 0 deletions core/tests/Torch/Static/NN/LinearSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module Torch.Static.NN.LinearSpec where
import Test.Hspec
import Torch.Double
import Numeric.Backprop
import Torch.Double.NN.Linear

main :: IO ()
main = hspec spec
Expand Down
1 change: 1 addition & 0 deletions indef/hasktorch-indef.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ library
-- , Torch.Indef.Static.NN.Conv3d
, Torch.Indef.Static.NN.Criterion
, Torch.Indef.Static.NN.Layers
, Torch.Indef.Static.NN.Linear
, Torch.Indef.Static.NN.Math
, Torch.Indef.Static.NN.Padding
, Torch.Indef.Static.NN.Pooling
Expand Down
53 changes: 2 additions & 51 deletions indef/src/Torch/Indef/Static/NN/Layers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
-- Maintainer: sam@stites.io
-- Stability : experimental
-- Portability: non-portable
--
-- Miscellaneous layer functions.
-------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
Expand All @@ -24,57 +26,6 @@ import Torch.Indef.Static.Tensor.Math.Blas
import Torch.Indef.Static.NN.Backprop ()
import qualified Torch.Indef.Dynamic.NN as Dynamic

-- | datatype representing a linear layer with bias. Represents
-- @y = Ax + b@.
newtype Linear i o
= Linear { getTensors :: (Tensor '[i, o], Tensor '[o]) }

instance (KnownDim i, KnownDim o) => Show (Linear i o) where
show c = intercalate ","
[ "Linear ("
++ "input: " ++ show (inputSize c)
, " output: " ++ show (outputSize c)
++ ")"
]

instance (KnownDim i, KnownDim o) => Backprop (Linear i o) where
zero = const . Linear $ (constant 0, constant 0)
one = const . Linear $ (constant 1, constant 1)
add c0 c1 = Linear (weights c0 + weights c1, bias c0 + bias c1)

-- | the dense weight matrix of a linear layer
weights :: Linear i o -> Tensor '[i, o]
weights (Linear (w, _)) = w

-- | the bias vector of a linear layer
bias :: Linear i o -> Tensor '[o]
bias (Linear (_, b)) = b

-- | The input size of a linear layer
inputSize :: forall i o . KnownDim i => Linear i o -> Int
inputSize _ = fromIntegral (dimVal (dim :: Dim i))

-- | The output size of a linear layer
outputSize :: forall i o kW dW . KnownDim o => Linear i o -> Int
outputSize _ = fromIntegral (dimVal (dim :: Dim o))

-- ========================================================================= --

-- | Backprop linear function without batching
linear
:: forall s i o
. Reifies s W
=> All KnownDim '[i,o]
=> BVar s (Linear i o)
-> BVar s (Tensor '[i])
-> BVar s (Tensor '[o])
linear = liftOp2 $ op2 $ \l i -> (transpose2d (weights l) `mv` i + bias l, go l i)
where
go :: Linear i o -> Tensor '[i] -> Tensor '[o] -> (Linear i o, Tensor '[i])
go (Linear (w, b)) i gout = (Linear (i `outer` b', b'), w `mv` b')
where
b' = gout - b

-- | A backpropable 'flatten' operation
flattenBP
:: (Reifies s W, KnownDim (Product d), Dimensions (d::[Nat]))
Expand Down
80 changes: 80 additions & 0 deletions indef/src/Torch/Indef/Static/NN/Linear.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
-------------------------------------------------------------------------------
-- |
-- Module : Torch.Indef.Static.NN.Linear
-- Copyright : (c) Sam Stites 2017
-- License : BSD3
-- Maintainer: sam@stites.io
-- Stability : experimental
-- Portability: non-portable
--
-- Linear layers
-------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MonoLocalBinds #-}
module Torch.Indef.Static.NN.Linear where

import Data.List
import Data.Singletons.Prelude.List hiding (All)
import Numeric.Backprop
import Numeric.Dimensions

import Torch.Indef.Types
import Torch.Indef.Static.Tensor
import Torch.Indef.Static.Tensor.Math
import Torch.Indef.Static.Tensor.Math.Blas
import Torch.Indef.Static.NN.Backprop ()
import qualified Torch.Indef.Dynamic.NN as Dynamic

-- | datatype representing a linear layer with bias. Represents
-- @y = Ax + b@.
newtype Linear i o
= Linear { getTensors :: (Tensor '[i, o], Tensor '[o]) }

instance (KnownDim i, KnownDim o) => Show (Linear i o) where
show c = intercalate ","
[ "Linear ("
++ "input: " ++ show (inputSize c)
, " output: " ++ show (outputSize c)
++ ")"
]

instance (KnownDim i, KnownDim o) => Backprop (Linear i o) where
zero = const . Linear $ (constant 0, constant 0)
one = const . Linear $ (constant 1, constant 1)
add c0 c1 = Linear (weights c0 + weights c1, bias c0 + bias c1)

-- | the dense weight matrix of a linear layer
weights :: Linear i o -> Tensor '[i, o]
weights (Linear (w, _)) = w

-- | the bias vector of a linear layer
bias :: Linear i o -> Tensor '[o]
bias (Linear (_, b)) = b

-- | The input size of a linear layer
inputSize :: forall i o . KnownDim i => Linear i o -> Int
inputSize _ = fromIntegral (dimVal (dim :: Dim i))

-- | The output size of a linear layer
outputSize :: forall i o kW dW . KnownDim o => Linear i o -> Int
outputSize _ = fromIntegral (dimVal (dim :: Dim o))

-- ========================================================================= --

-- | Backprop linear function without batching
linear
:: forall s i o
. Reifies s W
=> All KnownDim '[i,o]
=> BVar s (Linear i o)
-> BVar s (Tensor '[i])
-> BVar s (Tensor '[o])
linear = liftOp2 $ op2 $ \l i -> (transpose2d (weights l) `mv` i + bias l, go l i)
where
go :: Linear i o -> Tensor '[i] -> Tensor '[o] -> (Linear i o, Tensor '[i])
go (Linear (w, b)) i gout = (Linear (i `outer` b', b'), w `mv` b')
where
b' = gout - b


0 comments on commit e67893c

Please sign in to comment.