In [1]:
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE LiberalTypeSynonyms #-}



import qualified Torch as Untyped
import qualified Torch.Functional.Internal as Untyped
import Torch.Typed
import Torch.HList
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Tensor
import qualified Torch.DType as D
import Data.Maybe (fromJust)
import GHC.Exts
import GHC.TypeLits
import Data.Proxy
import Data.Constraint
import Unsafe.Coerce (unsafeCoerce)

boxes :: Tensor '(CPU,0) 'Double '[3,4]
boxes = fromJust [[0,0,1,1],[0.5,0.5,1.5,1.5],[2,0,3,1]]

-- select @0 @0 boxes :: Tensor '(CPU,0) 'Double '[4]

unsafeConstraint :: forall c a. (c => a) -> a
unsafeConstraint = withDict (dummyDict @c)
  where
    dummyDict :: forall b. Dict b
    dummyDict = unsafeCoerce (Dict :: Dict ())

forEach :: forall n a. KnownNat n => (forall i. KnownNat i => Proxy i -> a) -> [a]
forEach func = map (\i -> withNat i func) [0.. (natValI @n -1)]

-- forEach :: forall bi d t. (KnownDevice d, KnownDType t, KnownNat bi) => Tensor d t '[bi] -> (forall bo. KnownNat bo => Tensor d t '[bi,5] -> Tensor d t '[bi]

maximum :: forall shape'' shape shape' dtype device. (shape'' ~ Broadcast shape shape')
        => Tensor device dtype shape -> Tensor device dtype shape' -> Tensor device dtype shape''
maximum a b = UnsafeMkTensor $ Untyped.maximum (toDynamic a) (toDynamic b)

minimum :: forall shape'' shape shape' dtype device. (shape'' ~ Broadcast shape shape')
        => Tensor device dtype shape -> Tensor device dtype shape' -> Tensor device dtype shape''
minimum a b = UnsafeMkTensor $ Untyped.minimum (toDynamic a) (toDynamic b)

logicalOr :: forall shape'' shape shape' device. (shape'' ~ Broadcast shape shape')
        => Tensor device 'Bool shape -> Tensor device 'Bool shape' -> Tensor device 'Bool shape''
logicalOr a b = UnsafeMkTensor $ Untyped.logical_or (toDynamic a) (toDynamic b)


iou :: forall bi d t n. 
    ( KnownDevice d
    , KnownDType t
    , KnownNat bi
    , KnownNat n
    , BasicArithmeticDTypeIsValid d t
    , InRange '[n] 0 0
    , InRange '[n] 0 1
    , InRange '[n] 0 2
    , InRange '[n] 0 3
    , InRange '[bi,n] 1 0
    , InRange '[bi,n] 1 1
    , InRange '[bi,n] 1 2
    , InRange '[bi,n] 1 3
    )
    => Tensor d t '[n] 
    -> Tensor d t '[bi,n] 
    -> Tensor d t '[bi]
iou source targets = 
  let sx1 = select @0 @0 source :: Tensor d t '[]
      sy1 = select @0 @1 source :: Tensor d t '[]
      sx2 = select @0 @2 source :: Tensor d t '[]
      sy2 = select @0 @3 source :: Tensor d t '[]
      tx1 = select @1 @0 targets :: Tensor d t '[bi]
      ty1 = select @1 @1 targets :: Tensor d t '[bi]
      tx2 = select @1 @2 targets :: Tensor d t '[bi]
      ty2 = select @1 @3 targets :: Tensor d t '[bi]
      dx = minimum tx2 sx2 - maximum tx1 sx1 :: Tensor d t '[bi]
      dy = minimum ty2 sy2 - maximum ty1 sy1 :: Tensor d t '[bi]  
      dxdy = dx * dy :: Tensor d t '[bi]
      s = (sx2 - sx1) * (sy2 - sy1)  :: Tensor d t '[]
      t = (tx2 - tx1) * (ty2 - ty1) :: Tensor d t '[bi]
      ts = t `add` s :: Tensor d t '[bi]
  in  dxdy / (ts - dxdy)


batchedIou :: forall b0 b1 d t n. 
    ( KnownDevice d
    , KnownDType t
    , KnownNat b0
    , KnownNat b1
    , KnownNat n
    , BasicArithmeticDTypeIsValid d t
    , InRange '[b0,n] 1 0
    , InRange '[b0,n] 1 1
    , InRange '[b0,n] 1 2
    , InRange '[b0,n] 1 3
    , InRange '[b1,n] 1 0
    , InRange '[b1,n] 1 1
    , InRange '[b1,n] 1 2
    , InRange '[b1,n] 1 3
    )
    => Tensor d t '[b0,n] 
    -> Tensor d t '[b1,n] 
    -> Tensor d t '[b0,b1]
batchedIou source targets = 
  let sx1 = select @1 @0 source :: Tensor d t '[b0]
      sy1 = select @1 @1 source :: Tensor d t '[b0]
      sx2 = select @1 @2 source :: Tensor d t '[b0]
      sy2 = select @1 @3 source :: Tensor d t '[b0]
      sx1' = reshape sx1 :: Tensor d t '[b0,1]
      sy1' = reshape sy1 :: Tensor d t '[b0,1]
      sx2' = reshape sx2 :: Tensor d t '[b0,1]
      sy2' = reshape sy2 :: Tensor d t '[b0,1]
      tx1 = select @1 @0 targets :: Tensor d t '[b1]
      ty1 = select @1 @1 targets :: Tensor d t '[b1]
      tx2 = select @1 @2 targets :: Tensor d t '[b1]
      ty2 = select @1 @3 targets :: Tensor d t '[b1]
      tx1' = reshape tx1 :: Tensor d t '[1,b1]
      ty1' = reshape ty1 :: Tensor d t '[1,b1]
      tx2' = reshape tx2 :: Tensor d t '[1,b1]
      ty2' = reshape ty2 :: Tensor d t '[1,b1]
      dx = minimum tx2' sx2' - maximum tx1' sx1' :: Tensor d t '[b0,b1]
      dy = minimum ty2' sy2' - maximum ty1' sy1' :: Tensor d t '[b0,b1]  
      dxdy = dx * dy :: Tensor d t '[b0,b1]
      s = (sx2' - sx1') * (sy2' - sy1')  :: Tensor d t '[b0,1]
      t = (tx2' - tx1') * (ty2' - ty1') :: Tensor d t '[1,b1]
      ts = t `add` s :: Tensor d t '[b0,b1]
  in  dxdy / (ts - dxdy)

batchedIou boxes boxes
Untyped.asValue (toDynamic boxes) :: [[Double]]


Tensor Double [3,3] [[ 1.0000   ,  0.1429   , -0.3333   ],
                     [ 0.1429   ,  1.0000   , -0.1111   ],
                     [-0.3333   , -0.1111   ,  1.0000   ]]

[[0.0,0.0,1.0,1.0],[0.5,0.5,1.5,1.5],[2.0,0.0,3.0,1.0]]

In [114]:
splitBatch :: forall b. KnownNat b => Tensor '(CPU,0) 'Double '[b,4] -> [Tensor '(CPU,0) 'Double '[4]]
splitBatch ti = forEach @b $ \(Proxy :: Proxy ii) -> unsafeConstraint @(InRange [b,4] 0 ii) $ select @0 @ii ti

splitBatch boxes

[Tensor Double [4] [ 0.0000,  0.0000,  1.0000   ,  1.0000   ],Tensor Double [4] [ 0.5000   ,  0.5000   ,  1.5000   ,  1.5000   ],Tensor Double [4] [ 2.0000   ,  0.0000,  3.0000   ,  1.0000   ]]

In [None]:

toBool :: Tensor d 'Bool '[] -> Bool
toBool = Untyped.asValue . toDynamic



nms :: forall d t bi a. 
  (KnownNat bi, BasicArithmeticDTypeIsValid d t, ComparisonDTypeIsValid d t, KnownDevice d,KnownDType t)
  => Tensor d t '[]
  -> Tensor d t '[bi, 4]
  -> (forall bo. KnownNat bo => Tensor d t '[bo,4] -> a) -> a
nms thresh boxes func = 
    let to = toDynamic boxes Untyped.! toDynamic deletedIdxes
    in withNat (head $ Untyped.shape to) $ \(Proxy :: Proxy bo) -> func (UnsafeMkTensor @d @t @'[bo,4] to)
  where
    deletedIdxes = loop [0..(natValI @bi - 1)] (zeros :: Tensor d 'Bool '[bi])
    candidates :: Tensor d 'Bool '[bi,bi]
    candidates = tril (-1) $ batchedIou boxes boxes `gt` thresh 
    loop :: [Int] -> Tensor d 'Bool '[bi] -> Tensor d 'Bool '[bi]
    loop [] v = v
    loop (x:xs) deleted = 
      withNat x $ \(Proxy :: Proxy i) -> 
        if toBool (unsafeConstraint @(InRange '[bi] 0 i) $ select @0 @i deleted :: Tensor d 'Bool '[])
          then loop xs deleted 
          else loop xs (deleted `logicalOr` (unsafeConstraint @(InRange '[bi,bi] 0 i) $ select @0 @i candidates :: Tensor d 'Bool '[bi]))

---- Can not define cps style in this case.
-- import Control.Monad.Cont
-- cpsNms thresh boxes = cont (nms thresh boxes)

nms 0.5 boxes (print.shape) --  :: forall n. Tensor '(CPU,0) 'Double '[n,4]


: 

In [None]:
-- A variable length tensor
data VTensor d t shape = forall b. KnownNat b => VTensor { unVTensor :: Tensor d t (b : shape) } 

instance Show (VTensor d t shape) where
  show v = case v of 
    (VTensor (te :: Tensor d t (b : shape))) -> show te -- show $ toDynamic te

nms' :: forall d t bi a. 
  (KnownNat bi, BasicArithmeticDTypeIsValid d t, ComparisonDTypeIsValid d t, KnownDevice d,KnownDType t)
  => Tensor d t '[]
  -> Tensor d t '[bi, 4]
  -> VTensor d t '[4]
nms' thresh boxes = 
    let to = toDynamic boxes Untyped.! toDynamic deletedIdxes
    in withNat (head $ Untyped.shape to) $ \(Proxy :: Proxy n) -> VTensor $ UnsafeMkTensor @d @t @'[n,4] to
  where
    deletedIdxes = loop [0..(natValI @bi - 1)] (zeros :: Tensor d 'Bool '[bi])
    candidates :: Tensor d 'Bool '[bi,bi]
    candidates = tril (-1) $ batchedIou boxes boxes `gt` thresh 
    loop :: [Int] -> Tensor d 'Bool '[bi] -> Tensor d 'Bool '[bi]
    loop [] v = v
    loop (x:xs) deleted = 
      withNat x $ \(Proxy :: Proxy i) -> 
        if toBool (unsafeConstraint @(InRange '[bi] 0 i) $ select @0 @i deleted :: Tensor d 'Bool '[])
          then loop xs deleted 
          else loop xs (deleted `logicalOr` (unsafeConstraint @(InRange '[bi,bi] 0 i) $ select @0 @i candidates :: Tensor d 'Bool '[bi]))

nms' 0.5 boxes 

Tensor Double [0,4] []

In [7]:
a = zeros :: Tensor '(CPU,0) 'Double '[4]
reshape a :: Tensor '(CPU,0) 'Double '[1,4]
reshape a :: Tensor '(CPU,0) 'Double '[4,1]

Tensor Double [1,4] [[ 0.0000,  0.0000,  0.0000,  0.0000]]

Tensor Double [4,1] [[ 0.0000],
                     [ 0.0000],
                     [ 0.0000],
                     [ 0.0000]]