Gumbel-CRF Semiring (#81)
srush committed Jan 15, 2021
1 parent f8f46ee commit 0339886
Showing 5 changed files with 135 additions and 3 deletions.
14 changes: 14 additions & 0 deletions tests/test_algorithms.py
@@ -3,6 +3,7 @@
    LogSemiring,
    CheckpointSemiring,
    CheckpointShardSemiring,
    GumbelCRFSemiring,
    KMaxSemiring,
    SparseMaxSemiring,
    MaxSemiring,
@@ -511,3 +512,16 @@ def test_lc_custom():
    # s2 = struct.sum(vals)
    # assert torch.isclose(s, s2).all()
    # assert torch.isclose(marginals, marginals2).all()


@given(data())
def test_gumbel(data):
    model = data.draw(sampled_from([LinearChain, SemiMarkov, DepTree]))
    semiring = GumbelCRFSemiring(1.0)
    test = test_lookup[model]()
    struct = model(semiring)
    vals, (batch, N) = test._rand()
    vals.requires_grad_(True)
    alpha = struct.marginals(vals)
    print(alpha[0])
    print(torch.autograd.grad(alpha, vals, alpha.detach())[0][0])
8 changes: 8 additions & 0 deletions torch_struct/distributions.py
@@ -16,6 +16,7 @@
    MultiSampledSemiring,
    KMaxSemiring,
    StdSemiring,
    GumbelCRFSemiring,
)


@@ -183,6 +184,13 @@ def count(self):
        ones[self.log_potentials.eq(-float("inf"))] = 0
        return self._struct(StdSemiring).sum(ones, self.lengths)

    def gumbel_crf(self, temperature=1.0):
        with torch.enable_grad():
            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
                self.log_potentials, self.lengths
            )
        return st_gumbel

    # @constraints.dependent_property
    # def support(self):
    #     pass
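A possible usage sketch for the new `gumbel_crf` method (not part of the commit). It assumes the usual `LinearChainCRF` convention of batch x (N-1) x C x C log-potentials and that the distribution class is importable from the package root; shapes and numbers are purely illustrative:

import torch
from torch_struct import LinearChainCRF

batch, N, C = 2, 6, 3
log_potentials = torch.randn(batch, N - 1, C, C, requires_grad=True)
dist = LinearChainCRF(log_potentials)

# Relaxed, reparameterized sample: hard one-hot decisions in the forward
# pass, softmax gradients in the backward pass (straight-through).
sample = dist.gumbel_crf(temperature=0.5)
sample.sum().backward()  # gradients reach log_potentials through the relaxation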
9 changes: 8 additions & 1 deletion torch_struct/helpers.py
@@ -104,7 +104,7 @@ def backward(ctx, grad_v):

            return DPManual.apply(edge)

    def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
    def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=False):
        """
        Compute the marginals of a structured model.
@@ -135,6 +135,13 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
                    )
                    all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
                return torch.stack(all_m, dim=0)
            elif _combine:
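                # _combine: reduce over both leading dimensions of v (the semiring
                # dimension and the batch) before differentiating, and return the
                # arranged marginals without unconverting them from the semiring layout.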
                obj = v.sum(dim=0).sum(dim=0)
                marg = torch.autograd.grad(
                    obj, edges, create_graph=True, only_inputs=True, allow_unused=False
                )
                a_m = self._arrange_marginals(marg)
                return a_m
            else:
                obj = self.semiring.unconvert(v).sum(dim=0)
                marg = torch.autograd.grad(
100 changes: 100 additions & 0 deletions torch_struct/semirings/sample.py
@@ -53,6 +53,106 @@ def sum(xs, dim=-1):
        return _SampledLogSumExp.apply(xs, dim)


def GumbelMaxSemiring(temp):
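    # Gumbel-max variant: `sum` is an exact logsumexp in the forward pass, but
    # the backward pass replaces the softmax gradient with a hard one-hot sample
    # obtained by perturbing the scores with Gumbel(0, 1) noise and taking the argmax.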
    class _GumbelMaxLogSumExp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, dim):
            ctx.save_for_backward(input, torch.tensor(dim))
            return torch.logsumexp(input, dim=dim)

        @staticmethod
        def backward(ctx, grad_output):
            logits, dim = ctx.saved_tensors
            grad_input = None
            if ctx.needs_input_grad[0]:

                def sample(ls):
                    pre_shape = ls.shape
                    update = (
                        ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
                    ) / temp
                    out = torch.nn.functional.one_hot(update.max(-1)[1], pre_shape[-1])
                    return out

                if dim == -1:
                    s = sample(logits)
                else:
                    dim = dim if dim >= 0 else logits.dim() + dim
                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
                    rev_perm = [
                        a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
                    ]
                    s = sample(logits.permute(perm)).permute(rev_perm)

                grad_input = grad_output.unsqueeze(dim).mul(s)
            return grad_input, None

    class _GumbelMaxSemiring(_BaseLog):
        @staticmethod
        def sum(xs, dim=-1):
            return _GumbelMaxLogSumExp.apply(xs, dim)

    return _GumbelMaxSemiring


def GumbelCRFSemiring(temp):
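    # Gumbel-CRF variant: like GumbelMaxSemiring, the forward pass is an exact
    # logsumexp, but the backward pass substitutes a straight-through
    # Gumbel-softmax sample, so marginals computed through this semiring are
    # differentiable relaxed samples.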
    class ST(torch.autograd.Function):
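        # Straight-through estimator: forward returns a hard one-hot of the
        # argmax; backward routes gradients through the softmax of the logits.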
        @staticmethod
        def forward(ctx, logits, dim):
            out = torch.nn.functional.one_hot(logits.max(-1)[1], dim)
            out = out.type_as(logits)
            ctx.save_for_backward(logits, out)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            logits, out = ctx.saved_tensors
            with torch.enable_grad():
                ret = torch.autograd.grad(
                    logits.softmax(-1), logits, out * grad_output
                )[0]
            return ret, None

    class _GumbelCRFLogSumExp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, dim):
            ctx.save_for_backward(input, torch.tensor(dim))
            return torch.logsumexp(input, dim=dim)

        @staticmethod
        def backward(ctx, grad_output):
            logits, dim = ctx.saved_tensors
            grad_input = None
            if ctx.needs_input_grad[0]:

                def sample(ls):
                    update = (
                        ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
                    ) / temp
                    out = ST.apply(update, ls.shape[-1])
                    return out

                if dim == -1:
                    s = sample(logits)
                else:
                    dim = dim if dim >= 0 else logits.dim() + dim
                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
                    rev_perm = [
                        a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
                    ]
                    s = sample(logits.permute(perm)).permute(rev_perm)

                grad_input = grad_output.unsqueeze(dim).mul(s)
            return grad_input, None

    class _GumbelCRFSemiring(_BaseLog):
        @staticmethod
        def sum(xs, dim=-1):
            return _GumbelCRFLogSumExp.apply(xs, dim)

    return _GumbelCRFSemiring
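For reference, a minimal sketch of exercising the semiring directly with a structure class, mirroring the test added above (import paths and shapes are assumptions for illustration, not part of the commit):

import torch
from torch_struct import LinearChain, GumbelCRFSemiring

batch, N, C = 2, 6, 3
edge = torch.randn(batch, N - 1, C, C, requires_grad=True)

struct = LinearChain(GumbelCRFSemiring(1.0))
alpha = struct.marginals(edge)  # relaxed edge sample, same shape as edge
grad = torch.autograd.grad(alpha, edge, alpha.detach())[0]  # flows via the ST softmax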


bits = torch.tensor([pow(2, i) for i in range(1, 18)])


7 changes: 5 additions & 2 deletions torch_struct/semirings/semirings.py
@@ -277,9 +277,12 @@ class KLDivergenceSemiring(Semiring):
    Based on descriptions in:
    * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
    * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
    * Parameter estimation for probabilistic finite-state
      transducers :cite:`eisner2002parameter`
    * First-and second-order expectation semirings with applications to
      minimum-risk training on translation forests :cite:`li2009first`
    * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
    """

    zero = 0
