Commit 14f096e

Merge 7ddff0f into 589478d

srush committed Oct 24, 2019
2 parents 589478d + 7ddff0f

Showing 3 changed files with 76 additions and 1 deletion.
2 changes: 2 additions & 0 deletions torch_struct/__init__.py
@@ -18,6 +18,7 @@
    LogSemiring,
    StdSemiring,
    KMaxSemiring,
    SparseMaxSemiring,
    SampledSemiring,
    MaxSemiring,
    EntropySemiring,
@@ -38,6 +39,7 @@
    StdSemiring,
    SampledSemiring,
    MaxSemiring,
    SparseMaxSemiring,
    KMaxSemiring,
    EntropySemiring,
    MultiSampledSemiring,
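
With both hunks applied, the new semiring is re-exported at the package top level; a minimal sketch, assuming this commit is installed:

from torch_struct import SparseMaxSemiring  # newly exposed alongside the existing semirings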
62 changes: 61 additions & 1 deletion torch_struct/semirings.py
@@ -252,7 +252,6 @@ def forward(ctx, input, dim):

    @staticmethod
    def backward(ctx, grad_output):
        # assert ((grad_output == 64) + (grad_output == 0) + (grad_output ==1)).all()

        logits, part, dim = ctx.saved_tensors
        grad_input = None
@@ -307,3 +306,64 @@ def to_discrete(xs, j):
        final = xs % 2
        mbits = bits.type_as(xs)
        return (((xs % mbits[i + 1]) - (xs % mbits[i]) + final) != 0).type_as(xs)


class SparseMaxSemiring(_BaseLog):
    @staticmethod
    def sum(xs, dim=-1):
        return _SimplexProject.apply(xs, dim)


class _SimplexProject(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim, z=1):
        w_star = project_simplex(input, dim)
        ctx.save_for_backward(w_star.clone(), torch.tensor(dim))
        x = (w_star - input).norm(p=2, dim=dim)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        w_star, dim = ctx.saved_tensors
        w_star.requires_grad_(True)

        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.unsqueeze(dim).mul(
                _SparseMaxGrad.apply(w_star, dim)
            )
        return grad_input, None, None


class _SparseMaxGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w_star, dim):
        ctx.save_for_backward(w_star, dim)
        return w_star

    @staticmethod
    def backward(ctx, grad_output):
        w_star, dim = ctx.saved_tensors
        print(grad_output.shape, w_star.shape, dim)
        return sparsemax_grad(grad_output, w_star, dim.item()), None


def project_simplex(v, dim, z=1):
    v_sorted, _ = torch.sort(v, dim=dim, descending=True)
    cssv = torch.cumsum(v_sorted, dim=dim) - z
    ind = torch.arange(1, 1 + len(v)).to(dtype=v.dtype)
    cond = v_sorted - cssv / ind >= 0
    k = cond.sum(dim=dim, keepdim=True)
    tau = cssv.gather(dim, k - 1) / k.to(dtype=v.dtype)
    w = torch.clamp(v - tau, min=0)
    return w
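
For intuition, a minimal sketch of the projection on a 1-D vector, assuming the torch_struct version from this commit is installed so that project_simplex is importable:

import torch
from torch_struct.semirings import project_simplex

# Hedged example: project a 1-D score vector onto the probability simplex.
v = torch.tensor([1.5, 0.2, -0.4])
w = project_simplex(v, dim=-1)
print(w, w.sum())  # non-negative, sums to 1; for this input the result is exactly [1., 0., 0.]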


def sparsemax_grad(dout, w_star, dim):
    out = dout.clone()
    supp = w_star > 0
    out[w_star <= 0] = 0
    nnz = supp.to(dtype=dout.dtype).sum(dim=dim, keepdim=True)
    out = out - (out.sum(dim=dim, keepdim=True) / nnz)
    out[w_star <= 0] = 0
    return out
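
A similarly hedged sketch of the semiring itself: per the code above, the backward pass of _SimplexProject returns the simplex projection w_star scaled by the incoming gradient, so differentiating sum yields a sparse probability distribution (again assuming this commit is installed):

import torch
from torch_struct.semirings import SparseMaxSemiring

# Hedged example: the forward value is the distance between x and its simplex
# projection; the gradient is the (sparse) projection itself.
x = torch.tensor([1.5, 0.2, -0.4], requires_grad=True)
SparseMaxSemiring.sum(x, dim=-1).backward()
print(x.grad)  # a point on the simplex with exact zeros, e.g. [1., 0., 0.] here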
13 changes: 13 additions & 0 deletions torch_struct/test_algorithms.py
@@ -6,6 +6,7 @@
from .semirings import (
    LogSemiring,
    KMaxSemiring,
    SparseMaxSemiring,
    MaxSemiring,
    StdSemiring,
    SampledSemiring,
@@ -298,3 +299,15 @@ def test_hmm():
    observations = torch.randint(0, V, (batch, N))
    out = LinearChain.hmm(transition, emission, init, observations)
    LinearChain().sum(out)


@given(data())
def test_sparse_max(data):
    model = data.draw(sampled_from([LinearChain]))
    semiring = SparseMaxSemiring
    vals, (batch, N) = model._rand()
    vals.requires_grad_(True)
    model(semiring).sum(vals)
    sparsemax = model(semiring).marginals(vals)
    print(vals.requires_grad)
    sparsemax.sum().backward()
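
Outside the hypothesis-driven test, a hedged standalone sketch of the same pattern; the edge-potential shape below (batch x (N - 1) x C x C) is an assumption about LinearChain's convention and is not shown in the diff:

import torch
from torch_struct import LinearChain, SparseMaxSemiring

# Hedged example: the shape convention for the edge potentials is assumed, not taken from the diff.
batch, N, C = 2, 5, 3
edge = torch.randn(batch, N - 1, C, C, requires_grad=True)
chain = LinearChain(SparseMaxSemiring)
chain.sum(edge)                    # sparsemax analogue of the partition value
marginals = chain.marginals(edge)  # sparse marginals, computed through the custom backward above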
