
Commit

srush committed Oct 31, 2019
1 parent f806c0e commit 1e89b98
Showing 1 changed file with 4 additions and 4 deletions.
torch_struct/semirings.py (8 changes: 4 additions & 4 deletions)
@@ -136,11 +136,11 @@ def back(x):

 def unaccumulate_(a, b, grad_output, fn, step=10000):
     slices = []
-    # a_grad = a.clone().fill_(0)
-    # b_grad = b.clone().fill_(0)
+    a_grad = a.clone().fill_(0)
+    b_grad = b.clone().fill_(0)
     # print("chcek", a_grad.shape)
-    a_grad = torch.tensor(0.0, device=a.device).set_(a.clone().storage(), a.storage_offset(), a.size(), a.stride()).fill_(0)
-    b_grad = torch.tensor(0.0, device=b.device).set_(b.clone().storage(), b.storage_offset(), b.size(), b.stride()).fill_(0)
+    # a_grad = torch.tensor(0.0, device=a.device).set_(a.clone().storage(), a.storage_offset(), a.size(), a.stride()).fill_(0)
+    # b_grad = torch.tensor(0.0, device=b.device).set_(b.clone().storage(), b.storage_offset(), b.size(), b.stride()).fill_(0)

     total = 1
     for s in grad_output.shape:
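The commit swaps the storage-aliasing construction of the zero gradient buffers for a plain clone-and-fill, keeping the old construction as a comment. A minimal standalone sketch, not part of the repository (the tensor `a` and the variable names are illustrative), showing that both constructions yield an all-zero buffer with the same shape, strides, and device as the input:

import torch

a = torch.randn(3, 4)

# New construction from this commit: copy the tensor and zero it in place.
a_grad_new = a.clone().fill_(0)

# Old construction (now commented out in the diff): make a 0-d tensor, point it
# at a cloned copy of a's storage via set_, then zero it in place.
a_grad_old = torch.tensor(0.0, device=a.device).set_(
    a.clone().storage(), a.storage_offset(), a.size(), a.stride()
).fill_(0)

# Both buffers are all zeros and mirror a's shape and strides.
assert a_grad_new.shape == a_grad_old.shape == a.shape
assert a_grad_new.stride() == a_grad_old.stride() == a.stride()
assert torch.equal(a_grad_new, a_grad_old)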
