
Commit

Merge 0f3b577 into 9f93432
srush committed Nov 8, 2020
2 parents 9f93432 + 0f3b577 commit 01e58b0
Showing 7 changed files with 227 additions and 5 deletions.
4 changes: 4 additions & 0 deletions torch_struct/__init__.py
@@ -20,6 +20,8 @@
from .semirings import (
LogSemiring,
FastLogSemiring,
GumbelMaxSemiring,
GumbelCRFSemiring,
TempMax,
FastMaxSemiring,
FastSampleSemiring,
@@ -56,6 +58,8 @@
EntropySemiring,
MultiSampledSemiring,
SelfCritical,
GumbelMaxSemiring,
GumbelCRFSemiring,
StructDistribution,
Autoregressive,
AutoregressiveModel,
19 changes: 19 additions & 0 deletions torch_struct/distributions.py
@@ -16,6 +16,7 @@
MultiSampledSemiring,
KMaxSemiring,
StdSemiring,
GumbelCRFSemiring
)


@@ -171,6 +172,14 @@ def count(self):
ones, self.lengths
)


def gumbel_crf(self, temperature=1.0):
    with torch.enable_grad():
        st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
            self.log_potentials, self.lengths
        )
    return st_gumbel

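For orientation, a minimal usage sketch of the new gumbel_crf method (not part of this commit); the sizes and temperature below are illustrative assumptions:

import torch
from torch_struct import LinearChainCRF

# Hypothetical sizes: a batch of 2 sequences of length 6 with 4 classes.
batch, N, C = 2, 6, 4
log_potentials = torch.randn(batch, N - 1, C, C, requires_grad=True)

dist = LinearChainCRF(log_potentials)
relaxed = dist.gumbel_crf(temperature=0.5)  # relaxed, reparameterized sample of edge marginals
relaxed.sum().backward()                    # gradients should reach log_potentials through the relaxation
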
# @constraints.dependent_property
# def support(self):
# pass
@@ -233,6 +242,16 @@ def _struct(self, sr=None):
return self.struct(sr if sr is not None else LogSemiring)


class StraightThrough(torch.autograd.Function):
    # Pass the hard sample through unchanged; send gradients through the soft relaxation.
    @staticmethod
    def forward(ctx, hard, soft):
        ctx.save_for_backward(soft)
        return hard

    @staticmethod
    def backward(ctx, grad_output):
        (soft,) = ctx.saved_tensors
        return None, soft.mul(grad_output)
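
Illustration only (not part of the commit), assuming the backward signature above: the helper returns hard unchanged in the forward pass and routes the incoming gradient through soft on the way back.

hard = torch.tensor([0.0, 1.0, 0.0])                      # e.g. a one-hot argmax
soft = torch.tensor([0.2, 0.5, 0.3], requires_grad=True)  # e.g. a softmax relaxation
out = StraightThrough.apply(hard, soft)                   # forward value equals hard
out.sum().backward()                                      # gradient flows to soft (scaled by soft here)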

class LinearChainCRF(StructDistribution):
r"""
Represents structured linear-chain CRFs with C classes.
9 changes: 8 additions & 1 deletion torch_struct/helpers.py
@@ -145,7 +145,7 @@ def backward(ctx, grad_v):

return DPManual.apply(edge)

def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=False):
"""
Compute the marginals of a structured model.
@@ -178,6 +178,13 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
)
all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
return torch.stack(all_m, dim=0)
elif _combine:
obj = v.sum(dim=0).sum(dim=0)
marg = torch.autograd.grad(
obj, edges, create_graph=True, only_inputs=True, allow_unused=False
)
a_m = self._arrange_marginals(marg)
return a_m
else:
obj = self.semiring.unconvert(v).sum(dim=0)
marg = torch.autograd.grad(
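
For context (not part of this commit): with the new _combine=True flag, the objective is summed over its leading semiring and batch dimensions before differentiation and the usual semiring.unconvert step is skipped, so the gradient-based marginals come back as one combined tensor. A hedged call sketch, with names and shapes as assumptions:

from torch_struct import LinearChain, GumbelCRFSemiring

struct = LinearChain(GumbelCRFSemiring(1.0))
marg = struct.marginals(log_potentials, _combine=True)  # log_potentials shaped (batch, N - 1, C, C)
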
6 changes: 4 additions & 2 deletions torch_struct/semirings/__init__.py
@@ -4,9 +4,9 @@
KMaxSemiring,
MaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
TempMax,
)

from .fast_semirings import FastLogSemiring, FastMaxSemiring, FastSampleSemiring
@@ -16,14 +16,16 @@

from .sparse_max import SparseMaxSemiring

from .sample import MultiSampledSemiring, SampledSemiring, GumbelMaxSemiring, GumbelCRFSemiring


# For flake8 compatibility.
__all__ = [
FastLogSemiring,
FastMaxSemiring,
FastSampleSemiring,
GumbelCRFSemiring,
GumbelMaxSemiring,
LogSemiring,
StdSemiring,
SampledSemiring,
156 changes: 156 additions & 0 deletions torch_struct/semirings/sample.py
@@ -51,6 +51,92 @@ class SampledSemiring(_BaseLog):
@staticmethod
def sum(xs, dim=-1):
return _SampledLogSumExp.apply(xs, dim)


def GumbelMaxSemiring(temp):
    class _GumbelMaxLogSumExp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, dim):
            ctx.save_for_backward(input, torch.tensor(dim))
            return torch.logsumexp(input, dim=dim)

        @staticmethod
        def backward(ctx, grad_output):
            logits, dim = ctx.saved_tensors
            grad_input = None
            if ctx.needs_input_grad[0]:
                # Replace the softmax gradient with a hard one-hot Gumbel-max sample.
                def sample(ls):
                    update = (ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))) / temp
                    out = torch.nn.functional.one_hot(update.max(-1)[1], ls.shape[-1])
                    return out.type_as(ls)

                if dim == -1:
                    s = sample(logits)
                else:
                    dim = dim if dim >= 0 else logits.dim() + dim
                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
                    rev_perm = [a for a, b in sorted(enumerate(perm), key=lambda a: a[1])]
                    s = sample(logits.permute(perm)).permute(rev_perm)

                grad_input = grad_output.unsqueeze(dim).mul(s)
            return grad_input, None

    class _GumbelMaxSemiring(_BaseLog):
        @staticmethod
        def sum(xs, dim=-1):
            return _GumbelMaxLogSumExp.apply(xs, dim)

    return _GumbelMaxSemiring

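A standalone sketch of the Gumbel-max trick the semiring above relies on (not part of the commit): perturbing logits with Gumbel(0, 1) noise and taking the argmax draws an exact sample from the categorical distribution softmax(logits).

import torch

logits = torch.tensor([1.0, 2.0, 0.5])
gumbels = torch.distributions.Gumbel(0, 1).sample(logits.shape)
idx = (logits + gumbels).argmax(-1)  # index distributed as Categorical(probs=softmax(logits))
one_hot = torch.nn.functional.one_hot(idx, logits.shape[-1])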

def GumbelCRFSemiring(temp):
    class ST(torch.autograd.Function):
        # Straight-through estimator: hard one-hot in the forward pass,
        # softmax gradient in the backward pass.
        @staticmethod
        def forward(ctx, logits, dim):
            ctx.save_for_backward(logits)
            out = torch.nn.functional.one_hot(logits.max(-1)[1], dim)
            out = out.type_as(logits)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            logits, = ctx.saved_tensors
            return logits.softmax(-1) * grad_output, None

    class _GumbelCRFLogSumExp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, dim):
            ctx.save_for_backward(input, torch.tensor(dim))
            return torch.logsumexp(input, dim=dim)

        @staticmethod
        def backward(ctx, grad_output):
            logits, dim = ctx.saved_tensors
            grad_input = None
            if ctx.needs_input_grad[0]:
                # Perturb with Gumbel noise, then take a straight-through relaxed argmax.
                def sample(ls):
                    update = (ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))) / temp
                    out = ST.apply(update, ls.shape[-1])
                    return out

                if dim == -1:
                    s = sample(logits)
                else:
                    dim = dim if dim >= 0 else logits.dim() + dim
                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
                    rev_perm = [a for a, b in sorted(enumerate(perm), key=lambda a: a[1])]
                    s = sample(logits.permute(perm)).permute(rev_perm)

                grad_input = grad_output.unsqueeze(dim).mul(s)
            return grad_input, None

    class _GumbelCRFSemiring(_BaseLog):
        @staticmethod
        def sum(xs, dim=-1):
            return _GumbelCRFLogSumExp.apply(xs, dim)

    return _GumbelCRFSemiring

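For comparison, a minimal straight-through Gumbel-softmax sketch for a single categorical variable (not part of the commit); GumbelCRFSemiring applies the same hard-forward / soft-backward idea inside the dynamic program through the ST function above. The helper name and temperature are illustrative:

import torch

def st_gumbel_softmax(logits, temperature=1.0):
    gumbels = torch.distributions.Gumbel(0, 1).sample(logits.shape)
    soft = ((logits + gumbels) / temperature).softmax(-1)
    hard = torch.nn.functional.one_hot(soft.argmax(-1), logits.shape[-1]).type_as(soft)
    # Forward value is the hard one-hot; gradients follow the soft relaxation.
    return hard + (soft - soft.detach())

x = torch.randn(4, 5, requires_grad=True)
y = st_gumbel_softmax(x, temperature=0.5)
y.sum().backward()  # x.grad is populated through the softmax path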

bits = torch.tensor([pow(2, i) for i in range(1, 18)])
@@ -125,3 +211,73 @@ def to_discrete(xs, j):
final = xs % 2
mbits = bits.type_as(xs)
return (((xs % mbits[i + 1]) - (xs % mbits[i]) + final) != 0).type_as(xs)


# def GumbelCRFSemiring(temp):
# class _GumbelCRF_LSE(torch.autograd.Function):
# @staticmethod
# def forward(ctx, input, dim):
# ctx.save_for_backward(input, torch.tensor(dim))
# return torch.logsumexp(input, dim=dim)

# @staticmethod
# def backward(ctx, grad_output):
# logits, dim = ctx.saved_tensors
# grad_input = None
# hard = grad_output[0]
# soft = grad_output[1]
# print(hard.shape, logits[0].shape)
# new_logits = logits[0]

# if ctx.needs_input_grad[0]:
# def sample(ls):
# pre_shape = ls.shape
# update = (ls + torch.distributions.Gumbel(0, 1).sample((pre_shape[-1],))) / temp
# hard = torch.nn.functional.one_hot(update.max(-1)[1], pre_shape[-1])
# soft = update.softmax(-1)
# return hard, soft

# sample_hard, sample_soft = sample(new_logits)
# grad_input = torch.stack(
# [hard.unsqueeze(dim).mul(sample_hard),
# soft.unsqueeze(dim).mul(sample_soft)], dim=0)
# return grad_input, None

# class GumbelCRFSemiring(_BaseLog):
# @staticmethod
# def size():
# return 2

# @classmethod
# def convert(cls, orig_potentials):
# potentials = torch.zeros(
# (2,) + orig_potentials.shape,
# dtype=orig_potentials.dtype,
# device=orig_potentials.device,
# )
# cls.zero_(potentials)
# potentials[0] = orig_potentials
# potentials[1] = orig_potentials
# return potentials

# @classmethod
# def one_(cls, xs):
# cls.zero_(xs)
# xs.fill_(0)
# return xs

# @staticmethod
# def unconvert(potentials):
# return potentials[0]

# @staticmethod
# def sum(xs, dim=-1):
# if dim == -1:
# return _GumbelCRF_LSE.apply(xs, dim)
# assert False

# @staticmethod
# def mul(a, b):
# return a + b

# return GumbelCRFSemiring
8 changes: 6 additions & 2 deletions torch_struct/semirings/semirings.py
@@ -277,10 +277,14 @@ class KLDivergenceSemiring(Semiring):
Based on descriptions in:
* Parameter estimation for probabilistic finite-state
  transducers :cite:`eisner2002parameter`
* First- and second-order expectation semirings with applications to
  minimum-risk training on translation forests :cite:`li2009first`
* Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
"""

zero = 0
@staticmethod
def size():
30 changes: 30 additions & 0 deletions torch_struct/test_algorithms.py
@@ -8,6 +8,7 @@
LogSemiring,
CheckpointSemiring,
CheckpointShardSemiring,
GumbelCRFSemiring,
KMaxSemiring,
SparseMaxSemiring,
MaxSemiring,
@@ -509,3 +510,32 @@ def test_lc_custom():
# s2 = struct.sum(vals)
# assert torch.isclose(s, s2).all()
# assert torch.isclose(marginals, marginals2).all()

@given(data())
def test_gumbel(data):
    model = data.draw(sampled_from([LinearChain, SemiMarkov, DepTree]))
    K = 2
    semiring = GumbelCRFSemiring(1.0)
    struct = model(semiring)
    vals, (batch, N) = model._rand()
    vals.requires_grad_(True)
    alpha = struct.marginals(vals)
    print(alpha[0])
    print(torch.autograd.grad(alpha, vals, alpha.detach())[0][0])

    assert(False)
    # assert (alpha[0] == max1).all()
    # assert (alpha[1] <= max1).all()

    # topk = struct.marginals(vals, _raw=True)
    # argmax = model(MaxSemiring).marginals(vals)
    # assert (topk[0] == argmax).all()
    # print(topk[0].nonzero(), topk[1].nonzero())
    # assert (topk[1] != topk[0]).any()

    # if model != DepTree:
    #     log_probs = model(MaxSemiring).enumerate(vals)[1]
    #     tops = torch.topk(torch.cat(log_probs, dim=0), 5, 0)[0]
    #     assert torch.isclose(struct.score(topk[1], vals), alpha[1]).all()
    #     for k in range(K):
    #         assert (torch.isclose(alpha[k], tops[k])).all()
