Skip to content

Commit

Permalink
Seg loss afects whole network ultralytics#9
Browse files Browse the repository at this point in the history
  • Loading branch information
manole-alexandru committed Mar 26, 2023
1 parent 25b23f6 commit bf0183b
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def parse_opt(known=False):


def main(opt, callbacks=Callbacks()):
print('\n---------- VERSION:', '#0008', '----------\n')
print('\n---------- VERSION:', '#0009', '----------\n')
# Checks
if RANK in {-1, 0}:
print_args(vars(opt))
Expand Down
2 changes: 1 addition & 1 deletion utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def forward(self, pred, true):

def weighted_bce(y_pred, y_true, BETA=2):
weights = (y_true * (BETA - 1)) + 1
bce = nn.BCELoss(reduction='none')(y_pred, y_true)
bce = nn.BCEWithLogitsLoss(reduction='none')(y_pred, y_true)
wbce = torch.mean(bce * weights)
return wbce

Expand Down
4 changes: 2 additions & 2 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
g_det[1].append(p)
else:
g_det[0].append(p) # weight (with decay)
else:
# if m_index < 287 or m_index > 290: # SEGMENTATION Optimizer
# else:
if m_index < 287 or m_index > 290: # SEGMENTATION Optimizer
for p_index, (p_name, p) in enumerate(v.named_parameters(recurse=0)):
if p_name == 'bias': # bias (no decay)
g_seg[2].append(p)
Expand Down

0 comments on commit bf0183b

Please sign in to comment.