
Commit: max marg
Sasha committed Mar 25, 2020
1 parent 9f5a0f5 commit 255d132
Showing 5 changed files with 130 additions and 1 deletion.
2 changes: 1 addition & 1 deletion torch_struct/linearchain.py
@@ -74,7 +74,7 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
c = chart[:, :, :].view(ssize, batch * bin_N, C, C)
lp = big[:, :, :].view(ssize, batch * bin_N, C, C)
mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N).type_as(c)
-        mask = mask >= (lengths - 1).view(batch, 1)
+        mask = mask >= (lengths.float() - 1).view(batch, 1)
mask = mask.view(batch * bin_N, 1, 1).to(lp.device)
semiring.zero_mask_(lp.data, mask)
semiring.zero_mask_(c.data, (~mask))
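
An aside on the one-line change above, based on assumptions not stated in the commit: mask is cast to the chart's floating dtype via type_as(c), while lengths typically arrives as a torch.long tensor, and PyTorch releases of this period could reject comparisons between tensors of different dtypes, so the right-hand side is cast with .float() to match. A small self-contained sketch of the shapes and dtypes involved:

import torch

batch, bin_N = 2, 4
c = torch.rand(batch, bin_N)                      # chart entries are floating point
lengths = torch.tensor([3, 4])                    # lengths usually arrive as torch.long
mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N).type_as(c)
# Casting lengths keeps both sides of the comparison in the same floating dtype.
mask = mask >= (lengths.float() - 1).view(batch, 1)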
1 change: 1 addition & 0 deletions torch_struct/semirings/__init__.py
@@ -3,6 +3,7 @@
StdSemiring,
KMaxSemiring,
MaxSemiring,
MaxMarginalSemiring,
EntropySemiring,
TempMax,
)
70 changes: 70 additions & 0 deletions torch_struct/semirings/semirings.py
@@ -350,3 +350,73 @@ def sparse_sum(xs, dim=-1):
return m, (torch.zeros(a.shape[:-1]).long(), a)

return _TempMax


class _MaxMaginal(torch.autograd.Function):
    "Max over one dimension whose backward pass emits max-marginals instead of gradients."

    @staticmethod
    def forward(ctx, input, dim):
        # Forward is an ordinary max-reduction (the best, Viterbi-style score).
        m, _ = torch.max(input, dim=dim)
        ctx.save_for_backward(input, m, torch.tensor(dim))
        return m

    @staticmethod
    def backward(ctx, grad_output):
        logits, m, dim = ctx.saved_tensors
        # Each entry's deficit from the max along the reduced dimension.
        diff = logits - m.unsqueeze(dim)
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Propagate the incoming max-marginal plus the deficit, so the
            # "gradient" carries log-space max-marginals rather than derivatives.
            grad_input = grad_output.unsqueeze(dim).add(diff)
        return grad_input, None


class _MaxMaginalTimes(torch.autograd.Function):
    "Log-space multiplication (addition) that passes max-marginals through unchanged."

    @staticmethod
    def forward(ctx, a, b):
        return a + b

    @staticmethod
    def backward(ctx, grad_output):
        # Both factors receive the same incoming max-marginal.
        return grad_output, grad_output

class _MaxMaginalUnsqueeze(torch.autograd.Function):
    "Unsqueeze-and-expand whose backward reduces the broadcast dimension with a max."

    @staticmethod
    def forward(ctx, a, dim, shape):
        ctx.dim = dim
        a = a.unsqueeze(dim)
        t = list(a.shape)
        t[dim] = shape
        return a.expand(t)

    @staticmethod
    def backward(ctx, grad_output):
        # Reduce the expanded dimension with max (not sum) to keep max-marginal semantics.
        return grad_output.max(ctx.dim)[0], None, None


class MaxMarginalSemiring(_BaseLog):
    """
    Implements a max-marginal semiring.

    "Gradients" of the backward pass give max-marginals rather than
    true gradients. This is an exact approach.
    """

    @staticmethod
    def sum(xs, dim=-1):
        return _MaxMaginal.apply(xs, dim)

    @staticmethod
    def times(a, b):
        return _MaxMaginalTimes.apply(a, b)

    @classmethod
    def matmul(cls, a, b):
        "Generalized tensordot over the shared inner dimension, using the max-marginal ops."
        dims = 1
        act_on = -(dims + 1)
        # Broadcast a against b, combine with times, then max-reduce the contracted dimension.
        a = _MaxMaginalUnsqueeze.apply(a, -1, b.shape[-1])
        b = _MaxMaginalUnsqueeze.apply(b, act_on - 1, a.shape[act_on - 1])
        c = cls.times(a, b)
        for d in range(act_on, -1, 1):
            c = cls.sum(c.transpose(-2, -1))
        return c
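
A minimal usage sketch, not part of the commit: it assumes the package at this commit is installed and that MaxMarginalSemiring is importable from torch_struct.semirings, as the updated __init__.py above suggests. sum computes the best score; back-propagating through it leaves each entry with 1 plus its deficit from that score, so best + (grad - 1) recovers the max-marginal, mirroring the m + (alpha - 1.0) expression printed in test_algorithms.py below.

import torch
from torch_struct.semirings import MaxMarginalSemiring

# Toy single-factor "structure": each entry of phi is one full assignment.
phi = torch.tensor([[1.0, 3.0, 2.0]], requires_grad=True)

best = MaxMarginalSemiring.sum(phi, dim=-1)   # Viterbi-style max: tensor([3.])
best.sum().backward()

# phi.grad holds 1 + (phi - best); adding (best - 1) back recovers the score
# of the best assignment consistent with each entry -- here phi itself.
max_marginals = best.detach() + (phi.grad - 1.0)
print(max_marginals)   # tensor([[1., 3., 2.]])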
34 changes: 34 additions & 0 deletions torch_struct/semirings/test_semirings.py
@@ -10,6 +10,7 @@
KMaxSemiring,
MaxSemiring,
StdSemiring,
MaxMarginalSemiring
)


@@ -40,6 +41,39 @@ def test_max(a, b, c):
assert torch.isclose(b, b2[0]).all()


@given(lint, lint, lint)
def test_max_marg(a, b, c):
torch.manual_seed(0)
t1 = torch.rand(1, c).requires_grad_(True)
    t2 = torch.rand(1, c, a).requires_grad_(True)
t3 = torch.rand(1, c, a).requires_grad_(True)

r1 = MaxMarginalSemiring.dot(t1, MaxMarginalSemiring.dot(t2, t3))
r1.sum().backward()
print(r1)
print("grad a", t1.grad)
print("grad b", t2.grad)
print(t1)
print(t2)
assert(False)
# t1a = torch.zeros(2, a, 1, c)
# t2a = torch.zeros(2, 1, b, c)
# t1a[0] = t1
# t2a[0] = t2
# t1a[1].fill_(-1e10)
# t2a[1].fill_(-1e10)

# r2 = KMaxSemiring(2).dot(t1a, t2a)
# assert torch.isclose(r1, r2[0]).all()

# (a, b) = torch.autograd.grad(r1.sum(), (t1, t2))
# (a2, b2) = torch.autograd.grad(r2[0].sum(), (t1a, t2a))

# assert torch.isclose(a, a2[0]).all()
# assert torch.isclose(b, b2[0]).all()



@given(lint, lint, lint)
def test_checkpoint(a, b, c):
torch.manual_seed(0)
24 changes: 24 additions & 0 deletions torch_struct/test_algorithms.py
@@ -10,6 +10,7 @@
CheckpointShardSemiring,
KMaxSemiring,
SparseMaxSemiring,
MaxMarginalSemiring,
MaxSemiring,
StdSemiring,
SampledSemiring,
@@ -95,6 +96,29 @@ def test_entropy(data):
assert entropy.shape == alpha.shape
assert torch.isclose(entropy, alpha).all()

@given(data())
def test_maxmarginals(data):
model = data.draw(sampled_from([LinearChain, SemiMarkov]))
semiring = MaxMarginalSemiring
struct = model(semiring)
vals, (batch, N) = model._rand()
m = struct.sum(vals)
alpha = struct.marginals(vals)
print("done")
# alpha = struct().marginals(vals)
print((alpha - 1.0))

print(m)
# print(m.view(-1, 1, 1, 1) + (alpha - 1.0))
# assert(False)
# log_z = model(LogSemiring).sum(vals)
# log_probs = model(LogSemiring).enumerate(vals)[1]
# log_probs = torch.stack(log_probs, dim=1) - log_z
# print(log_probs.shape, log_z.shape, log_probs.exp().sum(1))
# entropy = -log_probs.mul(log_probs.exp()).sum(1).squeeze(0)
# assert entropy.shape == alpha.shape
# assert torch.isclose(entropy, alpha).all()


@given(data())
def test_kmax(data):
