Skip to content

Commit

Permalink
[Fix] Fix binary C=1 focal loss & dataset fileio (open-mmlab#2935)
Browse files Browse the repository at this point in the history
  • Loading branch information
csatsurnh committed Apr 23, 2023
1 parent 8327e29 commit da09304
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 6 deletions.
4 changes: 3 additions & 1 deletion mmseg/datasets/chase_db1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
Expand Down Expand Up @@ -27,4 +28,5 @@ def __init__(self,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
assert self.file_client.exists(self.data_prefix['img_path'])
assert fileio.exists(
self.data_prefix['img_path'], backend_args=self.backend_args)
4 changes: 3 additions & 1 deletion mmseg/datasets/drive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
Expand Down Expand Up @@ -27,4 +28,5 @@ def __init__(self,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
assert self.file_client.exists(self.data_prefix['img_path'])
assert fileio.exists(
self.data_prefix['img_path'], backend_args=self.backend_args)
4 changes: 3 additions & 1 deletion mmseg/datasets/hrf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
Expand Down Expand Up @@ -27,4 +28,5 @@ def __init__(self,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
assert self.file_client.exists(self.data_prefix['img_path'])
assert fileio.exists(
self.data_prefix['img_path'], backend_args=self.backend_args)
5 changes: 4 additions & 1 deletion mmseg/datasets/stare.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine.fileio as fileio

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

Expand Down Expand Up @@ -26,4 +28,5 @@ def __init__(self,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
assert self.file_client.exists(self.data_prefix['img_path'])
assert fileio.exists(
self.data_prefix['img_path'], backend_args=self.backend_args)
14 changes: 12 additions & 2 deletions mmseg/models/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,13 @@ def forward(self,
num_classes = pred.size(1)
if torch.cuda.is_available() and pred.is_cuda:
if target.dim() == 1:
one_hot_target = F.one_hot(target, num_classes=num_classes)
one_hot_target = F.one_hot(
target, num_classes=num_classes + 1)
if num_classes == 1:
one_hot_target = one_hot_target[:, 1]
target = 1 - target
else:
one_hot_target = one_hot_target[:, :num_classes]
else:
one_hot_target = target
target = target.argmax(dim=1)
Expand All @@ -280,7 +286,11 @@ def forward(self,
else:
one_hot_target = None
if target.dim() == 1:
target = F.one_hot(target, num_classes=num_classes)
target = F.one_hot(target, num_classes=num_classes + 1)
if num_classes == 1:
target = target[:, 1]
else:
target = target[:, num_classes]
else:
valid_mask = (target.argmax(dim=1) != ignore_index).view(
-1, 1)
Expand Down

0 comments on commit da09304

Please sign in to comment.