From 6d51297d563c9dd1771e9f929b7c9267fa7ae9ed Mon Sep 17 00:00:00 2001 From: Sasha Date: Fri, 8 Nov 2019 14:28:13 -0500 Subject: [PATCH] Fix `preserve` argument in checkpoint dot (use len(size), not len(ret.shape)) and remove leftover debug print --- torch_struct/semirings/checkpoint.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch_struct/semirings/checkpoint.py b/torch_struct/semirings/checkpoint.py index 1fdcc769..924cfeea 100644 --- a/torch_struct/semirings/checkpoint.py +++ b/torch_struct/semirings/checkpoint.py @@ -9,7 +9,7 @@ def forward(ctx, a, b): size = [max(p, q) for p, q in zip(a.shape, b.shape)][:-1] return accumulate_(a, b, size, lambda a, b: cls.dot(a, b), - preserve=len(ret.shape), + preserve=len(size), step=max_size // a.shape[-1] + 2) @staticmethod @@ -57,8 +57,6 @@ def accumulate_(a, b, size, fn, preserve, step=10000): if step > total: return fn(a, b) - print("miss") - ret = torch.zeros(*size, dtype=a.dtype, device=a.device) a_one, b_one = ones(a), ones(b) indices = torch.tensor(np.mgrid[slices]).view(len(ret.shape[:preserve]), -1)