support KL and cross-entropy semiring

sonta committed Jul 30, 2020
1 parent d9157fc commit 5adffa7

Showing 10 changed files with 206 additions and 13 deletions.
1 change: 1 addition & 0 deletions torch_struct/__init__.py
@@ -72,3 +72,4 @@
CheckpointShardSemiring,
TempMax,
]

3 changes: 1 addition & 2 deletions torch_struct/cky_crf.py
@@ -6,8 +6,7 @@

class CKY_CRF(_Struct):
def _check_potentials(self, edge, lengths=None):
batch, N, _, NT = edge.shape
edge.requires_grad_(True)
batch, N, _, NT = self._get_dimension(edge)
edge = self.semiring.convert(edge)
if lengths is None:
lengths = torch.LongTensor([N] * batch).to(edge.device)
6 changes: 4 additions & 2 deletions torch_struct/deptree.py
@@ -43,6 +43,8 @@ class DepTree(_Struct):
Parameters:
arc_scores_in: Arc scores of shape (B, N, N) or (B, N, N, L) with root scores on
diagonal.
Note: For the single-root case, do not set cache=True for now.
"""

def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
@@ -61,7 +63,7 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
alpha = [
[
[
Chart((batch, N, N), arc_scores, semiring, cache=cache)
Chart((batch, N, N), arc_scores, semiring, cache=multiroot)
for _ in range(2)
]
for _ in range(2)
@@ -113,7 +115,7 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):

def _check_potentials(self, arc_scores, lengths=None):
semiring = self.semiring
batch, N, N2 = arc_scores.shape[:3]
batch, N, N2, *_ = self._get_dimension(arc_scores)
assert N == N2, "Non-square potentials"
if lengths is None:
lengths = torch.LongTensor([N - 1] * batch).to(arc_scores.device)
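
A usage sketch for the single-root path this change affects (illustrative values, not part of the diff; DependencyCRF is the public wrapper over DepTree):

import torch
from torch_struct import DependencyCRF

# Arc scores of shape (B, N, N); root scores sit on the diagonal.
arc_scores = torch.randn(2, 6, 6)

# multiroot=False takes the single-root path; with the change above the
# inside charts are left uncached (cache=multiroot) in that case.
dist = DependencyCRF(arc_scores, multiroot=False)
print(dist.partition)  # log-partition, shape (2,)
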
27 changes: 26 additions & 1 deletion torch_struct/distributions.py
@@ -11,12 +11,15 @@
LogSemiring,
MaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
MultiSampledSemiring,
KMaxSemiring,
StdSemiring,
)


class StructDistribution(Distribution):
r"""
Base structured distribution class.
@@ -65,6 +68,8 @@ def log_prob(self, value):
value.type_as(self.log_potentials),
batch_dims=batch_dims,
)
return v - self.partition

@lazy_property
@@ -75,13 +80,32 @@ def entropy(self):
Returns:
entropy (*batch_shape*)
"""
return self._struct(EntropySemiring).sum(self.log_potentials, self.lengths)

def cross_entropy(self, other):
"""
Compute the cross entropy between distributions :math:`p` (self) and :math:`q` (other), :math:`H[p, q]`.
Returns:
cross entropy (*batch_shape*)
"""
return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

def kl(self, other):
"""
Compute the KL divergence between distributions :math:`p` (self) and :math:`q` (other), :math:`KL[p || q] = H[p, q] - H[p]`.
Returns:
kl divergence (*batch_shape*)
"""
return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

@lazy_property
def max(self):
r"""
Compute the max of the distribution, :math:`\max p(z)`.
Returns:
max (*batch_shape*)
"""
@@ -355,6 +379,7 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=True):
setattr(self.struct, "multiroot", multiroot)


class TreeCRF(StructDistribution):
r"""
Represents a 0th-order span parser with NT nonterminals. Implemented using a
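
A minimal sketch of the two methods added in this file, on a pair of linear-chain CRFs (assumed usage, not part of the diff):

import torch
from torch_struct import LinearChainCRF

batch, N, C = 2, 6, 3
p = LinearChainCRF(torch.randn(batch, N - 1, C, C))
q = LinearChainCRF(torch.randn(batch, N - 1, C, C))

h_pq = p.cross_entropy(q)  # H[p, q], shape (batch,)
kl_pq = p.kl(q)            # KL[p || q], shape (batch,)

# Identity from the kl docstring: KL[p || q] = H[p, q] - H[p].
assert torch.allclose(kl_pq, h_pq - p.entropy, atol=1e-4)
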
9 changes: 9 additions & 0 deletions torch_struct/helpers.py
@@ -79,6 +79,15 @@ def _bin_length(self, length):
bin_N = int(math.pow(2, log_N))
return log_N, bin_N

def _get_dimension(self, edge):
if isinstance(edge, list):
for t in edge:
t.requires_grad_(True)
return edge[0].shape
else:
edge.requires_grad_(True)
return edge.shape

def _chart(self, size, potentials, force_grad):
return self._make_chart(1, size, potentials, force_grad)[0]

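
The new helper lets each model's _check_potentials accept either a single potentials tensor or a list of two, which is how the KL and cross-entropy semirings pass a (p, q) pair. A standalone mirror, for illustration only:

import torch

def get_dimension(edge):  # mirrors _Struct._get_dimension above
    if isinstance(edge, list):
        for t in edge:
            t.requires_grad_(True)
        return edge[0].shape
    edge.requires_grad_(True)
    return edge.shape

single = torch.randn(2, 5, 3, 3)
pair = [torch.randn(2, 5, 3, 3), torch.randn(2, 5, 3, 3)]
assert get_dimension(single) == get_dimension(pair)
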
4 changes: 1 addition & 3 deletions torch_struct/linearchain.py
@@ -28,10 +28,8 @@ class LinearChain(_Struct):
"""

def _check_potentials(self, edge, lengths=None):
batch, N_1, C, C2 = edge.shape
edge.requires_grad_(True)
batch, N_1, C, C2 = self._get_dimension(edge)
edge = self.semiring.convert(edge)

N = N_1 + 1

if lengths is None:
2 changes: 1 addition & 1 deletion torch_struct/semimarkov.py
@@ -8,7 +8,7 @@ class SemiMarkov(_Struct):
"""

def _check_potentials(self, edge, lengths=None):
batch, N_1, K, C, C2 = edge.shape
batch, N_1, K, C, C2 = self._get_dimension(edge)
edge = self.semiring.convert(edge)
N = N_1 + 1
if lengths is None:
4 changes: 4 additions & 0 deletions torch_struct/semirings/__init__.py
@@ -4,6 +4,8 @@
KMaxSemiring,
MaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
TempMax,
)

@@ -29,6 +31,8 @@
SparseMaxSemiring,
KMaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
MultiSampledSemiring,
CheckpointSemiring,
CheckpointShardSemiring,
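
Assuming the matching re-export in torch_struct/__init__.py above, the new semirings become importable at the package root:

from torch_struct import CrossEntropySemiring, KLDivergenceSemiring
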
145 changes: 145 additions & 0 deletions torch_struct/semirings/semirings.py
@@ -269,6 +269,150 @@ def mul(a, b):
return KMaxSemiring


class KLDivergenceSemiring(Semiring):
"""
Implements a KL-divergence semiring.
Computes both the log-values of two distributions and the running KL divergence between them.
Based on descriptions in:
* Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
* First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
* Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
"""

zero = 0

@staticmethod
def size():
return 3

@staticmethod
def convert(xs):
values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
values[0] = xs[0]
values[1] = xs[1]
values[2] = 0
return values

@staticmethod
def unconvert(xs):
return xs[-1]

@staticmethod
def sum(xs, dim=-1):
assert dim != 0
d = dim - 1 if dim > 0 else dim
part_p = torch.logsumexp(xs[0], dim=d)
part_q = torch.logsumexp(xs[1], dim=d)
log_sm_p = xs[0] - part_p.unsqueeze(d)
log_sm_q = xs[1] - part_q.unsqueeze(d)
sm_p = log_sm_p.exp()
return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d)))

@staticmethod
def mul(a, b):
return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))

@classmethod
def prod(cls, xs, dim=-1):
return xs.sum(dim)

@classmethod
def zero_mask_(cls, xs, mask):
"Fill *ssize x ...* tensor with additive identity."
xs[0].masked_fill_(mask, -1e5)
xs[1].masked_fill_(mask, -1e5)
xs[2].masked_fill_(mask, 0)

@staticmethod
def zero_(xs):
xs[0].fill_(-1e5)
xs[1].fill_(-1e5)
xs[2].fill_(0)
return xs

@staticmethod
def one_(xs):
xs[0].fill_(0)
xs[1].fill_(0)
xs[2].fill_(0)
return xs
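
A quick numeric check, illustrative rather than part of the test suite: a single sum over freshly converted values (third slot zero) reduces to the textbook KL divergence of the two locally normalized distributions.

import torch

a, b = torch.randn(5), torch.randn(5)

# Mirror of KLDivergenceSemiring.sum on converted values, where xs[2] == 0.
log_sm_p = a - torch.logsumexp(a, -1)
log_sm_q = b - torch.logsumexp(b, -1)
running = (log_sm_p.exp() * (log_sm_p - log_sm_q)).sum()

# Definition: KL(p || q) = sum_i p_i * log(p_i / q_i).
p, q = a.softmax(-1), b.softmax(-1)
assert torch.allclose(running, (p * (p / q).log()).sum(), atol=1e-5)
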

class CrossEntropySemiring(Semiring):
"""
Implements a cross-entropy expectation semiring.
Computes both the log-values of two distributions and the running cross entropy between them.
Based on descriptions in:
* Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
* First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
* Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
"""

zero = 0

@staticmethod
def size():
return 3

@staticmethod
def convert(xs):
values = torch.zeros((3,) + xs[0].shape).type_as(xs[0])
values[0] = xs[0]
values[1] = xs[1]
values[2] = 0
return values

@staticmethod
def unconvert(xs):
return xs[-1]

@staticmethod
def sum(xs, dim=-1):
assert dim != 0
d = dim - 1 if dim > 0 else dim
part_p = torch.logsumexp(xs[0], dim=d)
part_q = torch.logsumexp(xs[1], dim=d)
log_sm_p = xs[0] - part_p.unsqueeze(d)
log_sm_q = xs[1] - part_q.unsqueeze(d)
sm_p = log_sm_p.exp()
return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d)))

@staticmethod
def mul(a, b):
return torch.stack((a[0] + b[0], a[1] + b[1], a[2] + b[2]))

@classmethod
def prod(cls, xs, dim=-1):
return xs.sum(dim)

@classmethod
def zero_mask_(cls, xs, mask):
"Fill *ssize x ...* tensor with additive identity."
xs[0].masked_fill_(mask, -1e5)
xs[1].masked_fill_(mask, -1e5)
xs[2].masked_fill_(mask, 0)

@staticmethod
def zero_(xs):
xs[0].fill_(-1e5)
xs[1].fill_(-1e5)
xs[2].fill_(0)
return xs

@staticmethod
def one_(xs):
xs[0].fill_(0)
xs[1].fill_(0)
xs[2].fill_(0)
return xs
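
The same style of check for this semiring: one reduction over converted values yields the cross entropy H[p, q] = -sum_i p_i * log(q_i) (illustrative only):

import torch

a, b = torch.randn(5), torch.randn(5)
log_sm_p = a - torch.logsumexp(a, -1)
log_sm_q = b - torch.logsumexp(b, -1)
running = -(log_sm_p.exp() * log_sm_q).sum()  # the semiring's third component

p, q = a.softmax(-1), b.softmax(-1)
assert torch.allclose(running, -(p * q.log()).sum(), atol=1e-5)
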

class EntropySemiring(Semiring):
"""
Implements an entropy expectation semiring.
@@ -279,6 +423,7 @@ class EntropySemiring(Semiring):
* Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
* First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
* Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
"""

zero = 0
18 changes: 14 additions & 4 deletions torch_struct/test_distributions.py
@@ -21,19 +21,29 @@ def test_simple(data, seed):
lengths = torch.tensor(
[data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
)

dist = model(vals, lengths)
edges, enum_lengths = dist.enumerate_support()
print(edges.shape)
log_probs = dist.log_prob(edges)
for b in range(lengths.shape[0]):
log_probs[enum_lengths[b] :, b] = -1e9

assert torch.isclose(log_probs.exp().sum(0), torch.tensor(1.0)).all()

entropy = dist.entropy
assert torch.isclose(entropy, -log_probs.exp().mul(log_probs).sum(0)).all()

vals2 = torch.rand(*vals.shape)
dist2 = model(vals2, lengths)

cross_entropy = dist.cross_entropy(other=dist2)
kl = dist.kl(other=dist2)

edges2, enum_lengths2 = dist2.enumerate_support()
log_probs2 = dist2.log_prob(edges2)
for b in range(lengths.shape[0]):
log_probs2[enum_lengths2[b] :, b] = -1e9

assert torch.isclose(cross_entropy, -log_probs.exp().mul(log_probs2).sum(0)).all()
assert torch.isclose(kl, -log_probs.exp().mul(log_probs2 - log_probs).sum(0)).all()

argmax = dist.argmax
_, max_indices = log_probs.max(0)

