Skip to content

Commit

Permalink
Add a typed mode operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
cdepillabout committed Jun 16, 2020
1 parent 78d504f commit a48dfa8
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions hasktorch/src/Torch/Typed/Functional.hs
Expand Up @@ -1406,6 +1406,38 @@ allclose
allclose rtol atol equalNaN input other =
unsafePerformIO $ ATen.cast5 ATen.Managed.allclose_ttddb input other rtol atol equalNaN

-- | mode
-- See https://pytorch.org/docs/stable/torch.html#torch.mode.
--
-- >>> t = fromJust [[0, 5], [0, 2], [3, 5]] :: CPUTensor 'D.Int64 '[3, 2]
--
-- >>> (modes :: CPUTensor 'D.Int64 '[2], indicies :: CPUTensor 'D.Int64 '[2]) = mode @0 @DropDim t
-- >>> (dtype modes, shape modes, D.asValue (toDynamic modes) :: [Int])
-- (Int64,[2],[0,5])
-- >>> (dtype indicies, shape indicies, D.asValue (toDynamic indicies) :: [Int])
-- (Int64,[2],[1,2])
--
-- >>> t = fromJust [[0, 0], [0, 1], [3, 3]] :: CPUTensor 'D.Float '[3, 2]
--
-- >>> (modes :: CPUTensor 'D.Float '[3,1], indicies :: CPUTensor 'D.Int64 '[3,1]) = mode @1 @KeepDim t
-- >>> (dtype modes, shape modes, D.asValue (toDynamic modes) :: [[Float]])
-- (Float,[3,1],[[0.0],[0.0],[3.0]])
-- >>> (dtype indicies, shape indicies, D.asValue (toDynamic indicies) :: [[Int]])
-- (Int64,[3,1],[[1],[0],[1]])
mode
:: forall dim keepOrDropDim shape' shape dtype device
. ( KnownNat dim
, KnownKeepOrDropDim keepOrDropDim
, shape' ~ ConditionalDropDimension shape dim keepOrDropDim
, DTypeIsNotBool device dtype
)
=> Tensor device dtype shape -- ^ input
-> (Tensor device dtype shape', Tensor device 'D.Int64 shape') -- ^ output
mode input = unsafePerformIO $ ATen.cast3 ATen.Managed.mode_tlb
input
(natValI @dim)
(keepOrDropDimVal @keepOrDropDim)

-- | argmax
-- See https://pytorch.org/docs/stable/torch.html#torch.argmax.
--
Expand Down

0 comments on commit a48dfa8

Please sign in to comment.