Skip to content

Commit

Permalink
minor improvements after review
Browse files Browse the repository at this point in the history
  • Loading branch information
lferraz committed Nov 29, 2022
1 parent 512bc72 commit 8b3ca54
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
8 changes: 2 additions & 6 deletions kornia/contrib/face_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,10 @@ def forward(self, image: torch.Tensor) -> List[torch.Tensor]:
r"""Detect faces in a given batch of images.
Args:
image (torch.Tensor): batch of images :math:`(B,3,H,W)`
image: batch of images :math:`(B,3,H,W)`
Return:
List[torch.Tensor]: list with the boxes found on each image. :math:`Bx(N,15)`
List[torch.Tensor]: list with the boxes found on each image. :math:`Bx(N,15)`.
"""
img = self.preprocess(image)
Expand Down Expand Up @@ -326,11 +326,7 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
loc_data, conf_data, iou_data = head_data.split((14, 2, 1), dim=-1)

if self.phase == "test":
# trick to make it work with batches
# loc_data = loc_data.view(-1, 14)
# conf_data = torch.softmax(conf_data.view(-1, self.num_classes), dim=-1)
conf_data = torch.softmax(conf_data, dim=-1)
# iou_data = iou_data.view(-1, 1)
else:
loc_data = loc_data.view(loc_data.size(0), -1, 14)
conf_data = conf_data.view(conf_data.size(0), -1, self.num_classes)
Expand Down
4 changes: 2 additions & 2 deletions test/test_contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,9 +579,9 @@ def test_smoke(self, device, dtype):
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_valid(self, batch_size, device, dtype):
torch.manual_seed(44)
img = torch.rand(3, 320, 320, device=device, dtype=dtype)
img = torch.rand(batch_size, 3, 320, 320, device=device, dtype=dtype)
face_detection = kornia.contrib.FaceDetector().to(device, dtype)
dets = face_detection(torch.stack([img] * batch_size))
dets = face_detection(img)
assert isinstance(dets, list)
assert len(dets) == batch_size # same as the number of images
assert isinstance(dets[0], torch.Tensor)
Expand Down

0 comments on commit 8b3ca54

Please sign in to comment.