
Commit

srush committed Nov 1, 2019
1 parent 1e68e75 commit 0445372
Showing 2 changed files with 88 additions and 66 deletions.
129 changes: 76 additions & 53 deletions torch_struct/alignment.py
@@ -216,7 +216,31 @@ def reflect(x, size):

class Merge(torch.autograd.Function):
@staticmethod
def forward(ctx, left, right, rsize, nrsize):
def forward(ctx, xa, xb, rsize, size):
nrsize = (rsize - 1) * 2 + 3
st = []
left = (
pad_conv(
demote(xa[:, :, 0 : size * 2 : 2, :], 3), nrsize, 7, semiring, 2, 2
)
.transpose(-1, -2)
.view(ssize, batch, size, bin_MN, 1, LOC, LOC, 3, nrsize, rsize + 2)
)

right = (
pad_conv(
pad(
demote(xb[:, :, 1 : size * 2 : 2, :, :], 4),
1,
1,
-1,
semiring,
), nrsize, 3, semiring)
.transpose(-1, -2)
.view(ssize, batch, size, bin_MN, LOC, 1, LOC, 1, 3, nrsize, rsize)
)


st = []
grads = []
for op in (Up, Down, Mid):
@@ -249,34 +273,56 @@ def forward(ctx, left, right, rsize, nrsize):
)
return torch.stack(st, dim=-1)

@staticmethod
def backward(ctx, grad_output):
grad, ls, rs, v = ctx.saved_tensors
rsize, nrsize = v.tolist()
grad_in = grad.mul(grad_output.unsqueeze(-2))
left = torch.zeros(
*ls.tolist(), dtype=grad_output.dtype, device=grad_output.device
)
right = torch.zeros(
*rs.tolist(), dtype=grad_output.dtype, device=grad_output.device
)
# grad_in = grad_in.permute(0, 1, 2, 7, 3, 4, 5, 6, 8)
grad_in = grad_in.permute(9, 0, 1, 2, 4, 5, 6, 7, 3, 8)
for i, op in enumerate((Up, Down, Mid)):
top, bot = rsize + 1, 1
if op == Up:
top, bot = rsize + 2, 2
if op == Down:
top, bot = rsize, 0

left[:, :, :, :, :, Open, :, :, :, bot:top] += grad_in[i]
right[:, :, :, :, :, Open, :, 0, op, :, :] += grad_in[i].sum(-3)
return left, right, None, None, None
# @staticmethod
# def backward(ctx, grad_output):
# grad, ls, rs, v = ctx.saved_tensors
# rsize, nrsize = v.tolist()
# grad_in = grad.mul(grad_output.unsqueeze(-2))
# left = torch.zeros(
# *ls.tolist(), dtype=grad_output.dtype, device=grad_output.device
# )
# right = torch.zeros(
# *rs.tolist(), dtype=grad_output.dtype, device=grad_output.device
# )
# # grad_in = grad_in.permute(0, 1, 2, 7, 3, 4, 5, 6, 8)
# grad_in = grad_in.permute(9, 0, 1, 2, 4, 5, 6, 7, 3, 8)
# for i, op in enumerate((Up, Down, Mid)):
# top, bot = rsize + 1, 1
# if op == Up:
# top, bot = rsize + 2, 2
# if op == Down:
# top, bot = rsize, 0

# left[:, :, :, :, :, Open, :, :, :, bot:top] += grad_in[i]
# right[:, :, :, :, :, Open, :, 0, op, :, :] += grad_in[i].sum(-3)
# return left, right, None, None, None

merge = Merge.apply
else:
def merge(xa, xb, rsize, size):
nrsize = (rsize - 3) * 2 + 3
left = (
pad_conv(
demote(xa[:, :, 0 : size * 2 : 2, :], 3), nrsize, 7, semiring, 2, 2
)
.transpose(-1, -2)
.view(ssize, batch, size, bin_MN, 1, LOC, LOC, 3, nrsize, rsize + 2)
)

right = (
pad_conv(
pad(
demote(xb[:, :, 1 : size * 2 : 2, :, :], 4),
1,
1,
-1,
semiring,
), nrsize, 3, semiring)
.transpose(-1, -2)
.view(ssize, batch, size, bin_MN, LOC, 1, LOC, 1, 3, nrsize, rsize)
)


def merge(left, right, rsize, nrsize):
st = []
for op in (Up, Down, Mid):
top, bot = rsize + 1, 1
@@ -300,30 +346,7 @@ def merge2(xa, xb, size, rsize):
def merge2(xa, xb, size, rsize):
nrsize = (rsize - 1) * 2 + 3
rsize += 2
st = []
left = (
pad_conv(
demote(xa[:, :, 0 : size * 2 : 2, :], 3), nrsize, 7, semiring, 2, 2
)
.transpose(-1, -2)
.view(ssize, batch, size, bin_MN, 1, LOC, LOC, 3, nrsize, rsize + 2)
)

right = (
pad_conv(
pad(
demote(xb[:, :, 1 : size * 2 : 2, :, :], 4),
1,
1,
-1,
semiring,
), nrsize, 3, semiring)
.transpose(-1, -2)
.view(ssize, batch, size, bin_MN, LOC, 1, LOC, 1, 3, nrsize, rsize)
)

st = merge(left, right, rsize, nrsize)

st = merge(xa, xb, rsize, size)
if self.local:
left_ = pad(
xa[:, :, 0::2, :, :, Close, :, :],
@@ -344,8 +367,8 @@ def merge2(xa, xb, size, rsize):
st2.append(torch.stack([semiring.zero_(right.clone()), right], dim=-2))
st = torch.cat([st, torch.stack(st2, dim=-1)], dim=-1)
return semiring.sum(st)
# reporter = MemReporter()
# reporter.report()
reporter = MemReporter()
reporter.report()

size = bin_MN // 2
rsize = 2
@@ -369,8 +392,8 @@ def merge2(xa, xb, size, rsize):
:, :, 0, M - N + (charta[-1].shape[3] // 2), N, Open, Open, Mid
]

# reporter = MemReporter()
# reporter.report()
reporter = MemReporter()
reporter.report()
return v, [log_potentials], None

@staticmethod
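Aside on the autograd pattern in the hunks above: `Merge` is a standard custom `torch.autograd.Function`, a static `forward`/`backward` pair bound through `Merge.apply`, and this diff carries both an active and a commented-out copy of a hand-written `backward`. For readers unfamiliar with that API, here is a minimal self-contained sketch of the pattern only; it is not torch_struct's method, and the names `Scale` and `factor` are invented for illustration.

import torch

class Scale(torch.autograd.Function):
    # Minimal custom Function: forward saves what backward will need;
    # backward returns one gradient per forward argument.
    @staticmethod
    def forward(ctx, x, factor):
        ctx.save_for_backward(torch.tensor(factor, dtype=x.dtype))
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        (factor,) = ctx.saved_tensors
        # Gradient for x, then None for the non-tensor `factor` argument.
        return grad_output * factor, None

scale = Scale.apply  # same binding style as `merge = Merge.apply`

x = torch.randn(3, requires_grad=True)
scale(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])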
25 changes: 12 additions & 13 deletions torch_struct/semirings.py
@@ -139,8 +139,8 @@ def unaccumulate_(a, b, grad_output, fn, step=10000):
a_grad = a.clone().fill_(0)
b_grad = b.clone().fill_(0)
# print("chcek", a_grad.shape)
a_grad2 = torch.tensor(0.0, device=a.device, dtype=a.dtype).set_(a.clone().storage(), a.storage_offset(), a.size(), a.stride()).fill_(0)
b_grad2 = torch.tensor(0.0, device=b.device, dtype=b.dtype).set_(b.clone().storage(), b.storage_offset(), b.size(), b.stride()).fill_(0)
# a_grad2 = torch.tensor(0.0, device=a.device, dtype=a.dtype).set_(a.clone().storage(), a.storage_offset(), a.size(), a.stride()).fill_(0)
# b_grad2 = torch.tensor(0.0, device=b.device, dtype=b.dtype).set_(b.clone().storage(), b.storage_offset(), b.size(), b.stride()).fill_(0)

total = 1
for s in grad_output.shape:
@@ -170,14 +170,13 @@ def unaccumulate_(a, b, grad_output, fn, step=10000):

q = fn(a[tuple(a_ind)], b[tuple(b_ind)], grad_output[tuple(ind)])
# a_grad[tuple(a_ind)] = a_grad[tuple(a_ind)] + q
print(len(a_ind), q.shape, a_grad.shape)
a_grad.index_put_(tuple(a_ind), q, accumulate=True)
b_grad.index_put_(tuple(b_ind), q, accumulate=True)
a_grad2.index_put_(tuple(a_ind), q, accumulate=True)
b_grad2.index_put_(tuple(b_ind), q, accumulate=True)
assert torch.isclose(a_grad, a_grad2).all(), a_grad - a_grad2
# a_grad2.index_put_(tuple(a_ind), q, accumulate=True)
# b_grad2.index_put_(tuple(b_ind), q, accumulate=True)
# assert torch.isclose(a_grad, a_grad2).all(), a_grad - a_grad2

return a_grad2, b_grad2
return a_grad, b_grad


def accumulate_(a, b, ret, fn, step=10000):
@@ -271,9 +270,9 @@ def forward(ctx, a, b):
def backward(ctx, grad_output):

a, b = ctx.saved_tensors
# print("backing out", a.shape)
# reporter = MemReporter()
# reporter.report()
print("backing out", a.shape)
reporter = MemReporter()
reporter.report()

size = [max(p, q) for p, q in zip(a.shape, b.shape)][:-1]

@@ -294,9 +293,9 @@ def backward(ctx, grad_output):
grad_a = back.sum(dim=asum, keepdim=True)
grad_b = back.sum(dim=bsum, keepdim=True)

# print("backing out 2", a.shape)
# reporter = MemReporter()
# reporter.report()
print("backing out 2", a.shape)
reporter = MemReporter()
reporter.report()


return grad_a, grad_b
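Aside on the call the `unaccumulate_` hunk relies on: `Tensor.index_put_(indices, values, accumulate=True)` adds `values` at the given indices, summing where an index repeats, rather than overwriting; the commented-out `a_grad2`/`b_grad2` lines appear to have been a cross-check of that accumulation against a second buffer. A tiny standalone sketch with invented shapes and values:

import torch

grad = torch.zeros(4)
idx = (torch.tensor([0, 1, 1, 3]),)       # index 1 repeats
vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
# accumulate=True sums contributions at repeated indices instead of overwriting.
grad.index_put_(idx, vals, accumulate=True)
print(grad)  # tensor([1., 5., 0., 4.])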
