Skip to content

Commit

Permalink
Changed loss and model to work with BCEWithLogits
Browse files Browse the repository at this point in the history
  • Loading branch information
manole-alexandru committed Mar 23, 2023
1 parent fbdcf9f commit 4048228
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def __init__(self, in_channels):
self.cv2 = Conv(32, 64, k=3)
self.cv3 = Conv(64, 1)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
# self.sigmoid = nn.Sigmoid()

def forward(self, x):
# print('----entry shape', x.shape, '---\n')
Expand All @@ -869,7 +869,7 @@ def forward(self, x):
x = self.relu(x)
x = self.cv3(x)
# print('----out shape', x.shape, '---\n')
x = self.sigmoid(x)
# x = self.sigmoid(x)
return x


Expand Down
2 changes: 1 addition & 1 deletion utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __call__(self, preds, targets, seg_masks): # predictions, targets
# print('\n-----REAL MASK', seg_masks.shape, '-------\n')
print('\n----------- PRED VALID: ', torch.all(pred_mask >= 0), '-----------------\n')
print('\n----------- SEG MASK VALID: ', torch.all(seg_masks >= 0), '-----------------\n')
seg_loss = nn.functional.binary_cross_entropy(pred_mask, seg_masks, reduce=False, reduction='none').mean()
seg_loss = nn.functional.binary_cross_entropy_with_logits(pred_mask, seg_masks, reduction='none').mean()
lseg += seg_loss

# Append targets to text file
Expand Down
1 change: 1 addition & 0 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def save_one_json(predn, jdict, path, class_map):

def compute_seg_iou(pred, target, n_classes=2):
ious = []
pred = torch.sigmoid(pred)
pred = pred.view(-1)
target = target.view(-1)
print(target)
Expand Down

0 comments on commit 4048228

Please sign in to comment.