Commit 3cdce94 (parent bf3c560)
Sasha committed Oct 23, 2019
Showing 6 changed files with 100 additions and 18 deletions.
47 changes: 35 additions & 12 deletions torch_struct/autoregressive.py
@@ -1,5 +1,5 @@
 import torch
-from .semirings import MaxSemiring
+from .semirings import MaxSemiring, KMaxSemiring
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import lazy_property

@@ -17,14 +17,13 @@ class Autoregressive(Distribution):
         n_length (int): max length of sequence
     """

     def __init__(self, model, init, n_classes, n_length):
         self.model = model
         self.init = init
         self.n_length = n_length
         self.n_classes = n_classes
         event_shape = (n_length, n_classes)
-        batch_shape = start_state.shape[:1]
+        batch_shape = init.shape[:1]
         super().__init__(batch_shape=batch_shape, event_shape=event_shape)


@@ -46,21 +45,45 @@ def log_prob(self, value):

     def _beam_search(self, semiring):
         # beam size
-        beam = semiring.ones_(
-            torch.zeros(semiring.size(), self.batch_shape))
+        beam = semiring.one_(
+            torch.zeros((semiring.size(),) + self.batch_shape))
         beam.requires_grad_(True)
+        state = self.init.unsqueeze(0).expand((semiring.size(),) + self.init.shape)
+        all_beams = []
         for t in range(0, self.n_length):
-            logits = self.model.logits(init)
+            logits = self.model.log_probs(state)
             # ssize x batch_size x C
-            beam = semiring.times(beam.unsqueeze(-1), logits.log_softmax(-1))
+            ex_beam = beam.unsqueeze(-1) + logits
+            ex_beam.requires_grad_(True)
+            all_beams.append(ex_beam)
             # ssize x batch_size x C
-            beam, backpointers = semiring.sparse_sum(beam)
+            beam, tokens = semiring.sparse_sum(ex_beam)
             # ssize x batch_size
-            state = self.model.update_state(state, backpointers)
-        return beam
+            state = self.model.update_state(state, tokens)
+
+        v = beam
+        all_m = []
+        for k in range(v.shape[0]):
+            obj = v[k].sum(dim=0)
+            marg = torch.autograd.grad(
+                obj,
+                all_beams,
+                create_graph=True,
+                only_inputs=True,
+                allow_unused=False,
+            )
+            marg = torch.stack(marg, dim=2)
+            all_m.append(marg.sum(0))
+        return torch.stack(all_m, dim=0)

     def greedy_argmax(self):
         return self._beam_search(MaxSemiring).squeeze(0)

-    def greedy_max(self):
-        return _beam_search(self, MaxSemiring)
+    def beam_topk(self, K):
+        return self._beam_search(KMaxSemiring(K))

     def sample(self, sample_shape=torch.Size()):
         r"""
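The reworked _beam_search threads a model state through the beam and then recovers the chosen paths as one-hot marginals by differentiating each final beam score through the saved all_beams, the same gradient trick helpers.py uses for raw marginals. Below is a minimal sketch of how the two entry points are intended to be called, assuming this commit is installed; StepModel is a hypothetical stand-in implementing the log_probs/update_state interface exercised by the test further down:

import torch
from torch_struct.autoregressive import Autoregressive

class StepModel:
    # Hypothetical model: fixed per-step scores; the state tensor just
    # stores the current time step.
    def __init__(self, scores):
        self.scores = scores  # batch x n_length x n_classes log-probs

    def update_state(self, prev_state, tokens):
        # A real model would condition on the chosen tokens; here we
        # only advance the step counter held in the state.
        return prev_state + 1

    def log_probs(self, state):
        K, batch, _ = state.shape
        t = int(state[0, 0, 0])  # current time step
        return self.scores[:, t, :].unsqueeze(0).expand(K, batch, -1)

batch, n_length, n_classes = 3, 5, 2
scores = torch.rand(batch, n_length, n_classes).log_softmax(-1)
init = torch.zeros(batch, 1).long()

dist = Autoregressive(StepModel(scores), init, n_classes, n_length)
path = dist.greedy_argmax()  # batch x n_length x n_classes, one-hot best path
top5 = dist.beam_topk(K=5)   # 5 x batch x n_length x n_classes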
3 changes: 2 additions & 1 deletion torch_struct/distributions.py
@@ -68,6 +68,7 @@ def entropy(self):
         """
         return self.struct(EntropySemiring).sum(self.log_potentials, self.lengths)

+
     @lazy_property
     def argmax(self):
         r"""
@@ -78,7 +79,7 @@ def argmax(self):
         """
         return self.struct(MaxSemiring).marginals(self.log_potentials, self.lengths)

-    def kmax(self, k):
+    def topk(self, k):
         r"""
         Compute the k-max for distribution :math:`k\max p(z)`.
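Renaming kmax to topk aligns the structured-distribution API with the new beam_topk above. A short usage sketch, assuming the post-commit API; the random log-potentials are placeholders:

import torch
from torch_struct.distributions import LinearChainCRF

log_potentials = torch.rand(3, 4, 2, 2)  # batch x (N-1) x C x C
dist = LinearChainCRF(log_potentials)

best = dist.argmax    # MAP assignment via MaxSemiring marginals
best5 = dist.topk(5)  # formerly dist.kmax(5): the 5 highest-scoring assignments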
1 change: 0 additions & 1 deletion torch_struct/helpers.py
@@ -93,7 +93,6 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
         v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True)
         if _raw:
             all_m = []
-            print(v)
             for k in range(v.shape[0]):
                 obj = v[k].sum(dim=0)
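Aside from the removed debug print, marginals with _raw=True differentiates each slice v[k] of the semiring value with respect to the input potentials, so the gradient itself carries the marginals (one-hot indicators under the max semiring). A toy illustration of that gradient trick in plain PyTorch, separate from torch_struct:

import torch

# Backpropagating through a max routes a unit gradient to the argmax
# entry, so the gradient is a one-hot indicator of the best choice.
edge = torch.rand(3, 4, requires_grad=True)
v = edge.max(dim=-1).values  # stand-in for a semiring reduction
obj = v.sum(dim=0)           # scalar objective, as in marginals()
(marg,) = torch.autograd.grad(obj, (edge,), create_graph=True)
print(marg)                  # exactly one 1.0 per row, zeros elsewhere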
1 change: 0 additions & 1 deletion torch_struct/linearchain.py
@@ -76,7 +76,6 @@ def merge(x, y, size):
             chart[n][:, :, :size] = merge(
                 left(chart[n - 1], size), right(chart[n - 1], size), size
             )
-        print(root(chart[-1][:]))
         v = semiring.sum(semiring.sum(root(chart[-1][:])))

         return v, [log_potentials], None
7 changes: 4 additions & 3 deletions torch_struct/semirings.py
@@ -149,8 +149,9 @@ def convert(cls, orig_potentials):
         potentials[0] = orig_potentials
         return potentials

-    @staticmethod
-    def one_(xs):
+    @classmethod
+    def one_(cls, xs):
+        cls.zero_(xs)
         xs[0].fill_(0)
         return xs

@@ -175,7 +176,7 @@ def sparse_sum(xs, dim=-1):
            xs = xs.contiguous().view(xs.shape[:-2] + (-1,))
            xs, xs2 = torch.topk(xs, k, dim=-1)
-           xs = xs.permute((xs.dim() - 1,) + tuple(range(0, xs.dim() - 1)))
+           xs = xs.permute((xs.dim() - 1,) + tuple(range(0, xs.dim() - 1)))
+           xs2 = xs2.permute((xs2.dim() - 1,) + tuple(range(0, xs2.dim() - 1)))
            assert xs.shape[0] == k
            return xs, xs2
        assert False
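Two fixes land here: one_ becomes a classmethod so it can clear the whole tensor with cls.zero_ before writing the multiplicative identity into slot 0, and sparse_sum now permutes the top-k indices the same way as the values, so both come back with the k dimension leading. A standalone sketch of the flatten-topk-permute pattern with hypothetical shapes, in plain PyTorch:

import torch

k = 5
xs = torch.rand(3, 4, 6, 6)                          # ... x C x C scores
flat = xs.contiguous().view(xs.shape[:-2] + (-1,))   # 3 x 4 x 36
vals, idx = torch.topk(flat, k, dim=-1)              # 3 x 4 x 5
# Move the k dimension to the front for values and indices alike.
perm = (vals.dim() - 1,) + tuple(range(0, vals.dim() - 1))
vals, idx = vals.permute(perm), idx.permute(perm)    # 5 x 3 x 4
assert vals.shape[0] == k and idx.shape == vals.shape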
59 changes: 59 additions & 0 deletions torch_struct/test_distributions.py
@@ -1,4 +1,5 @@
 from .distributions import LinearChainCRF
+from .autoregressive import Autoregressive
 import torch
 from hypothesis import given, settings
 from hypothesis.strategies import integers, data, sampled_from
@@ -43,3 +44,61 @@ def test_simple(data, seed):
     samples = dist.sample((100,))
     marginals = dist.marginals
     assert ((samples.mean(0) - marginals).abs() < 0.2).all()


+
+@given(data(), integers(min_value=1, max_value=20))
+@settings(max_examples=50, deadline=None)
+def test_autoregressive(data, seed):
+    n_classes = 2
+    n_length = 5
+    batch = 3
+
+    values = torch.rand(batch, n_length, n_classes)
+
+    # Expand the per-step scores into linear-chain transition potentials
+    # so the CRF below scores the same sequences as the beam search.
+    values2 = values.unsqueeze(-1).expand(batch, n_length, n_classes, n_classes).clone()
+    values2[:, 0, :, :] = -1e9
+    values2[:, 0, torch.arange(n_classes), torch.arange(n_classes)] = values[:, 0]
+
+    init = torch.zeros(batch, 5).long()
+
+    class Model:
+        def update_state(self, prev_state, inputs):
+            K, batch, hidden = prev_state.shape
+            return prev_state + 1
+
+        def sequence_logits(self, init, seq_inputs):
+            pass
+
+        def log_probs(self, state):
+            K, batch, hidden = state.shape
+            t = state[0, 0, 0]
+            x = values[:, t, :].unsqueeze(0).expand(K, batch, n_classes)
+            return x
+
+    auto = Autoregressive(Model(), init, n_classes, n_length)
+    v = auto.greedy_argmax()
+    assert (v == LinearChainCRF(values2).argmax.sum(-1)).all()
+
+    crf = LinearChainCRF(values2)
+    v2 = auto.beam_topk(K=5)
+    assert (v2.nonzero() == crf.topk(5).sum(-1).nonzero()).all()
+    assert (v2[0] == LinearChainCRF(values2).argmax.sum(-1)).all()
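The test uses LinearChainCRF as a reference implementation: greedy_argmax must match the CRF argmax, and beam_topk(K=5) must select the same top-5 paths as crf.topk(5). Assuming pytest and hypothesis are installed, the new test alone can be run with: pytest torch_struct/test_distributions.py -k test_autoregressive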
