diff --git a/torch_em/loss/dice.py b/torch_em/loss/dice.py index b7fd27ed..da0d147f 100644 --- a/torch_em/loss/dice.py +++ b/torch_em/loss/dice.py @@ -89,6 +89,7 @@ def __init__(self, channelwise=True, eps=1e-7, reduce_channel="sum"): super().__init__() self.channelwise = channelwise self.eps = eps + self.reduce_channel = reduce_channel # all torch_em classes should store init kwargs to easily recreate the init call self.init_kwargs = {"channelwise": channelwise, "eps": self.eps, "reduce_channel": self.reduce_channel}