From 8cade23dfc530f54377f6a0502cec32eb1b0a8ef Mon Sep 17 00:00:00 2001
From: Sasha
Date: Fri, 8 Nov 2019 14:26:43 -0500
Subject: [PATCH] .

---
 torch_struct/semirings/checkpoint.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/torch_struct/semirings/checkpoint.py b/torch_struct/semirings/checkpoint.py
index e03adee2..1fdcc769 100644
--- a/torch_struct/semirings/checkpoint.py
+++ b/torch_struct/semirings/checkpoint.py
@@ -6,14 +6,11 @@ class _Check(torch.autograd.Function):
     @staticmethod
     def forward(ctx, a, b):
         ctx.save_for_backward(a, b)
-        if True:
-            return cls.dot(a, b)
-        else:
-            size = [max(p, q) for p, q in zip(a.shape, b.shape)][:-1]
-            return accumulate_(a, b, size,
-                               lambda a, b: cls.dot(a, b),
-                               preserve=len(ret.shape),
-                               step=max_size // a.shape[-1] + 2)
+        size = [max(p, q) for p, q in zip(a.shape, b.shape)][:-1]
+        return accumulate_(a, b, size,
+                           lambda a, b: cls.dot(a, b),
+                           preserve=len(ret.shape),
+                           step=max_size // a.shape[-1] + 2)

     @staticmethod
     def backward(ctx, grad_output):