From 24c390e216d5906e834f5a9f55eb0eaa21f57a87 Mon Sep 17 00:00:00 2001 From: Urchade Zaratiana <38214774+urchade@users.noreply.github.com> Date: Sun, 2 Jan 2022 12:16:52 +0100 Subject: [PATCH] Fix for GPU runtime error for semiring Fix #121 --- torch_struct/semirings/semirings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index cfc2311..6db6f4c 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -49,6 +49,7 @@ def dot(cls, a, b): @staticmethod def fill(c, mask, v): + mask = mask.to(c.device) return torch.where( mask, v.type_as(c).view((-1,) + (1,) * (len(c.shape) - 1)), c )