Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 8, 2019
1 parent 26be6e1 commit a37b029
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torch_struct/semirings/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import numpy as np
from time import time

def CheckpointSemiring(cls, max_size, min_size=0):
class _Check(torch.autograd.Function):
Expand Down Expand Up @@ -57,17 +58,22 @@ def accumulate_(a, b, size, fn, preserve, step=10000):
    # Fast path: if a single chunk would cover the whole preserved-index
    # space, just apply fn once with no chunking.
    if step > total:
        return fn(a, b)

    # NOTE(review): debug print left in committed code.
    print("trigger", step, total)

    # Output accumulator with the requested size, matching a's dtype/device.
    ret = torch.zeros(*size, dtype=a.dtype, device=a.device)
    # `ones` is defined elsewhere in this file — presumably expands/reshapes
    # its argument for broadcast indexing; TODO confirm against full source.
    a_one, b_one = ones(a), ones(b)

    # BUG(review): the file does `from time import time` (see imports), so
    # `time` is the function itself here — `time.time()` raises
    # AttributeError. Should be `t = time()`.
    t = time.time()
    # Enumerate every index tuple over the preserved leading dimensions as a
    # (n_preserved_dims, total) tensor. `slices` comes from the elided part
    # of this function above — presumably one slice per preserved dim.
    indices = torch.tensor(np.mgrid[slices]).view(len(ret.shape[:preserve]), -1)
    # BUG(review): `time.time() - a` subtracts the *tensor* argument `a`,
    # not the start timestamp — almost certainly meant `time() - t`.
    print("trigger", step, total, time.time() - a)

    # BUG(review): same `time.time()` AttributeError as above.
    t = time.time()
    # Walk the flattened index grid in chunks of `step` to bound peak memory.
    for p in range(0, total, step):
        # NOTE(review): per-chunk debug print left in committed code.
        print(p)

        # One tuple of index tensors for this chunk of preserved positions.
        ind = indices[:, p : p + step].unbind()
        # `mind` is defined elsewhere in this file — presumably maps the
        # chunk indices into each operand's indexing scheme; TODO confirm.
        a_ind = mind(a_one, ind)
        b_ind = mind(b_one, ind)
        # Apply fn to the gathered slices and scatter into the output.
        ret[ind] = fn(a[tuple(a_ind)], b[tuple(b_ind)])
    # BUG(review): `time.time()` again — would raise before reaching here.
    print("done", time.time() - t)
    return ret

# def unaccumulate_(a, b, grad_output, fn, step=10000):
Expand Down

0 comments on commit a37b029

Please sign in to comment.