diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index cfc2311..6db6f4c 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -49,6 +49,7 @@ def dot(cls, a, b): @staticmethod def fill(c, mask, v): + mask = mask.to(c.device) return torch.where( mask, v.type_as(c).view((-1,) + (1,) * (len(c.shape) - 1)), c )