Commit
Sasha committed Nov 21, 2019
1 parent 7bcec54 commit fa935b0
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions torch_struct/semirings/checkpoint.py
@@ -1,4 +1,8 @@
import torch
try:
    import genbmm
except ImportError:
    pass


def broadcast_size(a, b):
@@ -26,9 +30,27 @@ def backward(ctx, grad_output):
            q = cls.matmul(a, b)
            return torch.autograd.grad(q, (a, b), grad_output)

class _CheckBand(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # BandedMatrix objects are not plain tensors, so stash them on ctx
        # directly rather than using save_for_backward.
        ctx.a = a
        ctx.b = b
        return cls.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.a, ctx.b
        with torch.enable_grad():
            # Recompute the forward product and backprop through it.
            q = cls.matmul(a, b)
            grad_a, grad_b = torch.autograd.grad(
                q.data, (a.data, b.data), grad_output
            )
        # Re-wrap the raw gradient tensors with the original band layouts.
        return (
            genbmm.BandedMatrix(grad_a, a.lu, a.lb, a.fill),
            genbmm.BandedMatrix(grad_b, b.lu, b.lb, b.fill),
        )


class _CheckpointSemiring(cls):
    @staticmethod
    def matmul(a, b):
        # Banded matrices take the dedicated checkpointed path (requires genbmm).
        if isinstance(a, genbmm.BandedMatrix):
            return _CheckBand.apply(a, b)
        # Only checkpoint dense products large enough to be worth recomputing.
        if broadcast_size(a, b) > min_size:
            return _Check.apply(a, b)
        else:
Expand Down
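
For context, both _Check and _CheckBand follow the standard gradient-checkpointing pattern: forward stores only its inputs, and backward recomputes the product under torch.enable_grad() before backpropagating the incoming gradient through it. The snippet below is a minimal, self-contained sketch of that pattern for a plain dense matmul; _CheckpointMatmul and the tensor shapes are illustrative only and are not part of torch_struct or this commit.

import torch

class _CheckpointMatmul(torch.autograd.Function):
    """Recompute a matmul in backward instead of storing its graph."""

    @staticmethod
    def forward(ctx, a, b):
        # Keep only the inputs; no intermediate graph is saved.
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        with torch.enable_grad():
            # Redo the forward pass, this time building a graph,
            # then backprop the incoming gradient through it.
            q = a @ b
            return torch.autograd.grad(q, (a, b), grad_output)

a = torch.randn(4, 5, requires_grad=True)
b = torch.randn(5, 3, requires_grad=True)
out = _CheckpointMatmul.apply(a, b)
out.sum().backward()
print(a.grad.shape, b.grad.shape)  # torch.Size([4, 5]) torch.Size([5, 3])

In the commit itself this logic is written against cls.matmul, so the same recompute-in-backward trick applies to whichever semiring product (log, max, etc.) the wrapped class defines.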
