Skip to content

Commit

Permalink
Merge pull request #37 from rht/master
Browse files Browse the repository at this point in the history
Move the sigmoid activation to the model itself
  • Loading branch information
milesial committed Nov 11, 2018
2 parents 24a26ea + e3f8ca7 commit 7dd7c8b
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion eval.py
Expand Up @@ -20,7 +20,7 @@ def eval_net(net, dataset, gpu=False):
true_mask = true_mask.cuda()

mask_pred = net(img)[0]
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
mask_pred = (mask_pred > 0.5).float()

tot += dice_coeff(mask_pred, true_mask).item()
return tot / i
4 changes: 2 additions & 2 deletions predict.py
Expand Up @@ -43,8 +43,8 @@ def predict_img(net,
output_left = net(X_left)
output_right = net(X_right)

left_probs = F.sigmoid(output_left).squeeze(0)
right_probs = F.sigmoid(output_right).squeeze(0)
left_probs = output_left.squeeze(0)
right_probs = output_right.squeeze(0)

tf = transforms.Compose(
[
Expand Down
4 changes: 1 addition & 3 deletions train.py
Expand Up @@ -6,7 +6,6 @@
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from eval import eval_net
Expand Down Expand Up @@ -74,8 +73,7 @@ def train_net(net,
true_masks = true_masks.cuda()

masks_pred = net(imgs)
masks_probs = F.sigmoid(masks_pred)
masks_probs_flat = masks_probs.view(-1)
masks_probs_flat = masks_pred.view(-1)

true_masks_flat = true_masks.view(-1)

Expand Down
4 changes: 3 additions & 1 deletion unet/unet_model.py
@@ -1,5 +1,7 @@
# full assembly of the sub-parts to form the complete net

import torch.nn.functional as F

from .unet_parts import *

class UNet(nn.Module):
Expand Down Expand Up @@ -27,4 +29,4 @@ def forward(self, x):
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return x
return F.sigmoid(x)

0 comments on commit 7dd7c8b

Please sign in to comment.