Skip to content

Commit

Permalink
Add nms for a motivational example of Record Dot Syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
junjihashimoto committed Dec 9, 2021
1 parent 9560d14 commit 095f061
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions hasktorch/src/Torch/Typed/Vision.hs
Expand Up @@ -141,3 +141,75 @@ initMnist path = do
testImagesBS <- decompressFile path "t10k-images-idx3-ubyte.gz"
testLabelsBS <- decompressFile path "t10k-labels-idx1-ubyte.gz"
return (MnistData imagesBS labelsBS, MnistData testImagesBS testLabelsBS)

data Box a = Box
{ x1 :: a,
y1 :: a,
x2 :: a,
y2 :: a,
score :: a
}
deriving (Show, Eq, Generic, Default)

nms_cpu :: Num a => [Box a] -> a -> [Box a]
nms_cpu dets = nms_cpu' (sortBy score dets)
where
nms_cpu' :: Num a => [Box a] -> a -> [Box a]
nms_cpu' [] _ = []
nms_cpu' (head_:tail_) iou_threshold = x: nms_cpu filtered_boxes iou_threshold
where
head_area = (x2 head_ - x1 head_) * (y2 head_ - y1 head_)
filtered_boxes = filter (\v ->
let
xx1 = max (x1 head_) (x1 v)
yy1 = max (y1 head_) (y1 v)
xx2 = min (x2 head_) (x2 v)
yy1 = min (y2 head_) (y2 v)
v_area = (xx2 - xx1) * (yy2 - yy1)
inter_area = (xx2 - xx1) * (yy2 - yy1)
iou = inter_area / (head_area + v_area - inter_area)
in iou < iou_threshold
) tail_

-- THe reference code of nms
-- https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py

nms :: NamedTensor device dtype '[Vector n, Box] -> Float -> NamedTensor device dtype '[Vector m, Box]
nms dets iou_threshold = dets ! (loop sort_idxes 0)
where
sort_idxes = sortNamedDim @"score" dets
areas =
(det ^. field @"x2" - det ^. field @"x1") *
(det ^. field @"y2" - det ^. field @"y1")
loop sort_idxes i | length sort_idxes <= i = sort_idxes
| otherwise =
let
idx = sort_idxes ! i
other_idxes = sort_idxes ! [slice|({i}+1):|]
xx1 = max (dets ! idx ^.field@"x1") ((dets ! other_idxes) ^. field @"x1")
xx2 = min (dets ! idx ^.field@"x2") ((dets ! other_idxes) ^. field @"x2")
yy1 = max (dets ! idx ^.field@"y1") ((dets ! other_idxes) ^. field @"y1")
yy2 = min (dets ! idx ^.field@"y2") ((dets ! other_idxes) ^. field @"y2")
inter_areas = (xx2 - xx1) * (yy2 - yy1)
iou = inter_areas / ((area ! idx) + (areas ! other_idxes) - inter_areas)
in loop (delete sort_idxes (iou >= iou_threshold)) (i+1)

nms_with_dotsyntax :: NamedTensor device dtype '[Vector n, Box] -> Float -> NamedTensor device dtype '[Vector m, Box]
nms_with_dotsyntax dets iou_threshold = dets ! (loop sort_idxes 0)
where
sort_idxes = sortNamedDim @"score" dets
areas =
(det.x2 - det.x1) *
(det.y2 - det.y1)
loop sort_idxes i | length sort_idxes <= i = sort_idxes
| otherwise =
let
idx = sort_idxes ! i
other_idxes = sort_idxes ! [slice|({i}+1):|]
xx1 = max (dets ! i).x1 (dets ! other_idxes).x1
xx2 = min (dets ! i).x2 (dets ! other_idxes).x2
yy1 = max (dets ! i).y1 (dets ! other_idxes).y1
yy2 = min (dets ! i).y2 (dets ! other_idxes).y2
inter_areas = (xx2 - xx1) * (yy2 - yy1)
iou = inter_areas / ((area ! idx) + (areas ! other_idxes) - inter_areas)
in loop (delete sort_idxes (iou >= iou_threshold)) (i+1)

0 comments on commit 095f061

Please sign in to comment.