
Making loss_not_reduced work with DiceLoss #3583

Merged: 1 commit into fastai:master, Feb 16, 2022

Conversation

@hiromis (Contributor) commented on Feb 16, 2022

When using the loss_not_reduced context manager with DiceLoss, the loss is still reduced by sum, because DiceLoss uses "if self.reduction == 'mean' ... else" logic: any reduction other than 'mean', including 'none', falls through to the sum branch.
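
Schematically, the return in DiceLoss is a two-way branch along these lines (a paraphrase of the relevant code, not the exact source), so reduction='none' silently collapses to a scalar:

# Paraphrased: any reduction other than 'mean', including 'none',
# is reduced to a 0-dim scalar by the sum branch.
return ((1 - dice_score).flatten().mean() if self.reduction == "mean"
        else (1 - dice_score).flatten().sum())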

The original issue I came across was the following:

from fastai.vision.all import *  # provides unet_learner, DiceLoss, SegmentationInterpretation

# dls is a segmentation DataLoaders built elsewhere
learn = unet_learner(dls, resnet34, loss_func=DiceLoss(axis=1))
interp = SegmentationInterpretation.from_learner(learn)

The above code will result in:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_60504/2451220244.py in <module>
      1 learn = unet_learner(dls, resnet34, loss_func=DiceLoss(axis=1))
----> 2 interp = SegmentationInterpretation.from_learner(learn)

/workspace/fastai/fastai/interpret.py in from_learner(cls, learn, ds_idx, dl, act)
     39         if dl is None: dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)
     40         _,_,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=False,
---> 41                                      with_preds=False, with_targs=False, act=act)
     42         return cls(learn, dl, losses, act)
     43 

/workspace/fastai/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    253         if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    254         with ContextManagers(ctx_mgrs):
--> 255             self._do_epoch_validate(dl=dl)
    256             if act is None: act = getattr(self.loss_func, 'activation', noop)
    257             res = cb.all_tensors()

/workspace/fastai/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
    201         if dl is None: dl = self.dls[ds_idx]
    202         self.dl = dl
--> 203         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    204 
    205     def _do_epoch(self):

/workspace/fastai/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

/workspace/fastai/fastai/learner.py in all_batches(self)
    167     def all_batches(self):
    168         self.n_iter = len(self.dl)
--> 169         for o in enumerate(self.dl): self.one_batch(*o)
    170 
    171     def _do_one_batch(self):

/workspace/fastai/fastai/learner.py in one_batch(self, i, b)
    192         b = self._set_device(b)
    193         self._split(b)
--> 194         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    195 
    196     def _do_epoch_train(self):

/workspace/fastai/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
--> 165         self(f'after_{event_type}');  final()
    166 
    167     def all_batches(self):

/workspace/fastai/fastai/learner.py in __call__(self, event_name)
    139 
    140     def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]
--> 141     def __call__(self, event_name): L(event_name).map(self._call_one)
    142 
    143     def _call_one(self, event_name):

~/.local/lib/python3.7/site-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
    153     def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
    154 
--> 155     def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
    156     def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
    157     def argfirst(self, f, negate=False): return first(i for i,o in self.enumerate() if f(o))

~/.local/lib/python3.7/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
    696     res = map(g, iterable)
    697     if gen: return res
--> 698     return list(res)
    699 
    700 # Cell

~/.local/lib/python3.7/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
    681             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    682         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 683         return self.func(*fargs, **kwargs)
    684 
    685 # Cell

/workspace/fastai/fastai/learner.py in _call_one(self, event_name)
    143     def _call_one(self, event_name):
    144         if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 145         for cb in self.cbs.sorted('order'): cb(event_name)
    146 
    147     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

/workspace/fastai/fastai/callback/core.py in __call__(self, event_name)
     43                (self.run_valid and not getattr(self, 'training', False)))
     44         res = None
---> 45         if self.run and _run: res = getattr(self, event_name, noop)()
     46         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     47         return res

/workspace/fastai/fastai/callback/core.py in after_batch(self)
    131         if self.with_loss:
    132             bs = find_bs(self.yb)
--> 133             loss = self.loss if self.loss.numel() == bs else self.loss.view(bs,-1).mean(1)
    134             self.losses.append(self.learn.to_detach(loss))
    135 

/workspace/fastai/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
    339         convert=False
    340         if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 341         res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
    342         if convert: res = convert(res)
    343         if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)

/opt/conda/lib/python3.7/site-packages/torch/_tensor.py in __torch_function__(cls, func, types, args, kwargs)
   1049 
   1050         with _C.DisableTorchFunction():
-> 1051             ret = func(*args, **kwargs)
   1052             if func in get_default_nowrap_functions():
   1053                 return ret

RuntimeError: shape '[8, -1]' is invalid for input of size 1
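
Here loss_not_reduced has set reduction='none', yet the loss arrives at GatherPredsCallback.after_batch as a 0-dim scalar (numel() == 1), so loss.view(bs, -1) with bs=8 raises the shape error above. The fix is to make each reduction explicit so that 'none' is a real no-op; roughly (a sketch of the change, not the exact diff):

loss = 1 - dice_score
if self.reduction == 'mean': loss = loss.mean()
elif self.reduction == 'sum': loss = loss.sum()
return loss  # with reduction='none', per-item losses pass through unreduced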

@hiromis requested a review from jph00 as a code owner on February 16, 2022 19:50

@hiromis changed the title from "Making loss_not_reduced to work with DiceLoss" to "Making loss_not_reduced work with DiceLoss" on Feb 16, 2022
@jph00 (Member) commented on Feb 16, 2022

Many thanks!

@jph00 jph00 merged commit 0255ced into fastai:master Feb 16, 2022
@jph00 jph00 added the bug label Mar 25, 2022