Commit 83fe8e1

Merge 3cdce94 into b770f21
srush committed Oct 23, 2019
2 parents b770f21 + 3cdce94

Showing 7 changed files with 180 additions and 5 deletions.
4 changes: 4 additions & 0 deletions torch_struct/__init__.py
@@ -8,6 +8,9 @@
    TreeCRF,
    SentCFG,
)
from .autoregressive import (
    Autoregressive
)
from .cky_crf import CKY_CRF
from .deptree import DepTree
from .linearchain import LinearChain
@@ -42,6 +45,7 @@
    MultiSampledSemiring,
    SelfCritical,
    StructDistribution,
    Autoregressive,
    LinearChainCRF,
    SemiMarkovCRF,
    DependencyCRF,
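With the export added to `__all__`, the new distribution is importable from the package root. A quick sketch of the import (nothing beyond the names added above):

    from torch_struct import Autoregressive, LinearChainCRF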
98 changes: 98 additions & 0 deletions torch_struct/autoregressive.py
@@ -0,0 +1,98 @@
import torch
from .semirings import MaxSemiring, KMaxSemiring
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property

class Autoregressive(Distribution):
    """
    Autoregressive sequence model utilizing beam search.

    * batch_shape -> Given by initializer
    * event_shape -> N x C sequence of choices

    Parameters:
        model: object providing `sequence_logits`, `log_probs`, and `update_state`
        init (tensor, batch_shape x hidden_shape): initial state of the model
        n_classes (int): number of classes in each time step
        n_length (int): max length of sequence
    """

    def __init__(self, model, init, n_classes, n_length):
        self.model = model
        self.init = init
        self.n_length = n_length
        self.n_classes = n_classes
        event_shape = (n_length, n_classes)
        batch_shape = init.shape[:1]
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)

    def log_prob(self, value):
        """
        Compute log probability over values :math:`p(z)`.

        Parameters:
            value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)

        Returns:
            log_probs (*sample_shape x batch_shape*)
        """
        logits = self.model.sequence_logits(self.init, value)
        # batch_shape x event_shape (N x C)
        log_probs = logits.log_softmax(-1)
        positions = torch.arange(self.n_length)
        return log_probs[:, positions, value[positions]].sum(-1)

    def _beam_search(self, semiring):
        # beam: ssize x batch_size, initialized to the semiring identity
        beam = semiring.one_(torch.zeros((semiring.size(),) + self.batch_shape))
        beam.requires_grad_(True)
        state = self.init.unsqueeze(0).expand((semiring.size(),) + self.init.shape)
        all_beams = []
        for t in range(0, self.n_length):
            logits = self.model.log_probs(state)
            # ssize x batch_size x C
            ex_beam = beam.unsqueeze(-1) + logits
            ex_beam.requires_grad_(True)
            all_beams.append(ex_beam)
            # beam: ssize x batch_size; tokens: ssize x batch_size
            beam, tokens = semiring.sparse_sum(ex_beam)
            state = self.model.update_state(state, tokens)

        # Recover the chosen sequences by differentiating the final beam
        # scores with respect to each step's expanded beam.
        v = beam
        all_m = []
        for k in range(v.shape[0]):
            obj = v[k].sum(dim=0)
            marg = torch.autograd.grad(
                obj, all_beams, create_graph=True, only_inputs=True, allow_unused=False
            )
            marg = torch.stack(marg, dim=2)
            all_m.append(marg.sum(0))
        return torch.stack(all_m, dim=0)

    def greedy_argmax(self):
        return self._beam_search(MaxSemiring).squeeze(0)

    def beam_topk(self, K):
        return self._beam_search(KMaxSemiring(K))

    def sample(self, sample_shape=torch.Size()):
        r"""
        Compute structured samples from the distribution :math:`z \sim p(z)`.

        Parameters:
            sample_shape (int): number of samples

        Returns:
            samples (*sample_shape x batch_shape x event_shape*)
        """
        pass
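For orientation, a minimal sketch of driving the new class end to end. `ToyModel` and all of its constants are hypothetical; the only requirement is the three-method interface (`log_probs`, `update_state`, `sequence_logits`) that `Autoregressive` calls above:

    import torch
    from torch_struct import Autoregressive

    class ToyModel:
        # Hypothetical stand-in: uniform logits, state never changes.
        def __init__(self, n_classes):
            self.n_classes = n_classes

        def log_probs(self, state):
            # state: ssize x batch x hidden -> logits: ssize x batch x C
            K, batch, _ = state.shape
            return torch.zeros(K, batch, self.n_classes)

        def update_state(self, prev_state, tokens):
            return prev_state  # ignore the chosen tokens

        def sequence_logits(self, init, seq_inputs):
            raise NotImplementedError  # only needed for log_prob

    batch, hidden, n_classes, n_length = 3, 4, 2, 5
    init = torch.zeros(batch, hidden)
    dist = Autoregressive(ToyModel(n_classes), init, n_classes, n_length)
    path = dist.greedy_argmax()   # batch x n_length x n_classes indicators
    paths = dist.beam_topk(K=5)   # 5 x batch x n_length x n_classes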
3 changes: 2 additions & 1 deletion torch_struct/distributions.py
@@ -68,6 +68,7 @@ def entropy(self):
"""
return self.struct(EntropySemiring).sum(self.log_potentials, self.lengths)


    @lazy_property
    def argmax(self):
        r"""
@@ -78,7 +79,7 @@ def argmax(self):
"""
return self.struct(MaxSemiring).marginals(self.log_potentials, self.lengths)

-    def kmax(self, k):
+    def topk(self, k):
r"""
Compute the k-max for distribution :math:`k\max p(z)`.
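A minimal sketch of the renamed call, assuming random linear-chain potentials (shapes follow the test below):

    import torch
    from torch_struct import LinearChainCRF

    dist = LinearChainCRF(torch.rand(3, 5, 2, 2))  # batch x N x C x C
    best5 = dist.topk(5)  # formerly dist.kmax(5); 5 x batch x N x C x C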
1 change: 0 additions & 1 deletion torch_struct/helpers.py
@@ -93,7 +93,6 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
        v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True)
        if _raw:
            all_m = []
-            print(v)
            for k in range(v.shape[0]):
                obj = v[k].sum(dim=0)

1 change: 0 additions & 1 deletion torch_struct/linearchain.py
@@ -76,7 +76,6 @@ def merge(x, y, size):
            chart[n][:, :, :size] = merge(
                left(chart[n - 1], size), right(chart[n - 1], size), size
            )
-        print(root(chart[-1][:]))
        v = semiring.sum(semiring.sum(root(chart[-1][:])))

        return v, [log_potentials], None
19 changes: 17 additions & 2 deletions torch_struct/semirings.py
@@ -129,6 +129,8 @@ class MaxSemiring(_BaseLog):
    def sum(xs, dim=-1):
        return torch.max(xs, dim=dim)[0]

    def sparse_sum(xs, dim=-1):
        return torch.max(xs, dim=dim)

def KMaxSemiring(k):
    class KMaxSemiring(_BaseLog):
@@ -147,8 +149,9 @@ def convert(cls, orig_potentials):
            potentials[0] = orig_potentials
            return potentials

-        @staticmethod
-        def one_(xs):
+        @classmethod
+        def one_(cls, xs):
+            cls.zero_(xs)
            xs[0].fill_(0)
            return xs

@@ -167,6 +170,18 @@ def sum(xs, dim=-1):
                return xs
            assert False

        def sparse_sum(xs, dim=-1):
            if dim == -1:
                xs = xs.permute(tuple(range(1, xs.dim())) + (0,))
                xs = xs.contiguous().view(xs.shape[:-2] + (-1,))
                xs, xs2 = torch.topk(xs, k, dim=-1)
                xs = xs.permute((xs.dim() - 1,) + tuple(range(0, xs.dim() - 1)))
                xs2 = xs2.permute((xs.dim() - 1,) + tuple(range(0, xs.dim() - 1)))
                assert xs.shape[0] == k
                return xs, xs2
            assert False

        @staticmethod
        def mul(a, b):
            a = a.view((k, 1) + a.shape[1:])
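`sparse_sum` mirrors `sum` but also returns the indices of the selected entries, which `Autoregressive._beam_search` feeds back to the model as tokens. A small illustrative sketch (values are made up):

    import torch
    from torch_struct.semirings import MaxSemiring

    xs = torch.tensor([[1.0, 3.0, 2.0]])
    MaxSemiring.sum(xs)         # tensor([3.])
    MaxSemiring.sparse_sum(xs)  # (tensor([3.]), tensor([1])), value and index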
59 changes: 59 additions & 0 deletions torch_struct/test_distributions.py
@@ -1,4 +1,5 @@
from .distributions import LinearChainCRF
from .autoregressive import Autoregressive
import torch
from hypothesis import given, settings
from hypothesis.strategies import integers, data, sampled_from
@@ -43,3 +44,61 @@ def test_simple(data, seed):
    samples = dist.sample((100,))
    marginals = dist.marginals
    assert ((samples.mean(0) - marginals).abs() < 0.2).all()


@given(data(), integers(min_value=1, max_value=20))
@settings(max_examples=50, deadline=None)
def test_autoregressive(data, seed):
    n_classes = 2
    n_length = 5
    batch = 3

    values = torch.rand(batch, n_length, n_classes)

    # Build equivalent linear-chain potentials from the per-step unaries;
    # position 0 is restricted to the diagonal.
    values2 = values.unsqueeze(-1).expand(batch, n_length, n_classes, n_classes).clone()
    values2[:, 0, :, :] = -1e9
    values2[:, 0, torch.arange(n_classes), torch.arange(n_classes)] = values[:, 0]

    init = torch.zeros(batch, 5).long()

    class Model:
        def update_state(self, prev_state, inputs):
            K, batch, hidden = prev_state.shape
            return prev_state + 1

        def sequence_logits(self, init, seq_inputs):
            pass

        def log_probs(self, state):
            K, batch, hidden = state.shape
            t = state[0, 0, 0]
            return values[:, t, :].unsqueeze(0).expand(K, batch, n_classes)

    auto = Autoregressive(Model(), init, n_classes, n_length)
    crf = LinearChainCRF(values2)

    # Greedy beam search must match the chain argmax.
    v = auto.greedy_argmax()
    assert (v == crf.argmax.sum(-1)).all()

    # Top-K beam search must select the same positions as the chain top-k,
    # and its best hypothesis must again be the argmax.
    v2 = auto.beam_topk(K=5)
    assert (v2.nonzero() == crf.topk(5).sum(-1).nonzero()).all()
    assert (v2[0] == crf.argmax.sum(-1)).all()