Commit 299c4c6
Sasha committed Oct 23, 2019
1 parent ebcf96f commit 299c4c6
Showing 5 changed files with 98 additions and 27 deletions.
4 changes: 1 addition & 3 deletions torch_struct/__init__.py
@@ -8,9 +8,7 @@
    TreeCRF,
    SentCFG,
)
from .autoregressive import (
    Autoregressive
)
from .autoregressive import Autoregressive
from .cky_crf import CKY_CRF
from .deptree import DepTree
from .linearchain import LinearChain
88 changes: 73 additions & 15 deletions torch_struct/autoregressive.py
@@ -3,6 +3,51 @@
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property


class AutoregressiveModel:
    """
    Base interface; users should implement it with their favorite RNN / Transformer / etc.
    """

    def sequence_logits(self, init, seq_inputs):
        """
        Compute the logits for all tokens in a batched sequence :math:`p(y_1, ... y_{T})`
        Parameters:
            init (batch_size x hidden_shape): everything needed for conditioning.
            seq_inputs (batch_size x N x C): next tokens to update representation
        Returns:
            logits (batch_size x N x C): logits for each position in the sequence.
        """
        pass

    def local_logits(self, state):
        """
        Compute the local logits of :math:`p(y_t | y_{1:t-1})`
        Parameters:
            state (batch_size x hidden_shape): everything needed for conditioning.
        Returns:
            logits (batch_size x C): next set of logits.
        """
        pass

    def update_state(self, prev_state, inputs):
        """
        Update the model state based on previous state and inputs.
        Parameters:
            prev_state (batch_size x hidden_shape): everything needed for conditioning.
            inputs (batch_size x C): next tokens to update representation
        Returns:
            state (batch_size x hidden_shape): everything needed for next conditioning.
        """
        pass
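For concreteness, here is a minimal sketch of this interface backed by a GRU. The class name, the nn.GRUCell backbone, and all sizes are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn


class GRUAutoregressiveModel(AutoregressiveModel):
    # Hypothetical example; any RNN / Transformer exposing these three hooks works.
    def __init__(self, n_classes, hidden):
        self.embed = nn.Embedding(n_classes, hidden)
        self.cell = nn.GRUCell(hidden, hidden)
        self.proj = nn.Linear(hidden, n_classes)

    def local_logits(self, state):
        # state: K x batch_size x hidden -> K x batch_size x C
        return self.proj(state)

    def update_state(self, prev_state, inputs):
        # inputs: K x batch_size token indices; flatten the beam dimension
        # so the GRUCell sees a 2-D batch, then restore it.
        K, batch, hidden = prev_state.shape
        out = self.cell(
            self.embed(inputs).reshape(K * batch, hidden),
            prev_state.reshape(K * batch, hidden),
        )
        return out.reshape(K, batch, hidden)

    def sequence_logits(self, init, seq_inputs):
        # seq_inputs: batch_size x N token indices -> batch_size x N x C
        state, logits = init.unsqueeze(0), []
        for t in range(seq_inputs.shape[1]):
            logits.append(self.local_logits(state)[0])
            state = self.update_state(state, seq_inputs[:, t].unsqueeze(0))
        return torch.stack(logits, dim=1)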


class Autoregressive(Distribution):
    """
    Autoregressive sequence model utilizing beam search.
@@ -11,12 +56,13 @@ class Autoregressive(Distribution):
    * event_shape -> N x C sequence of choices
    Parameters:
        model:
        init (tensor, batch_shape x hidden_shape):
        model (AutoregressiveModel): A lazily computed autoregressive model.
        init (tensor, batch_shape x hidden_shape): initial state of autoregressive model.
        n_classes (int): number of classes in each time step
        n_length (int): max length of sequence
    """

    def __init__(self, model, init, n_classes, n_length):
        self.model = model
        self.init = init
@@ -26,7 +72,6 @@ def __init__(self, model, init, n_classes, n_length):
        batch_shape = init.shape[:1]
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)


    def log_prob(self, value, normalize=True):
        """
        Compute log probability over values :math:`p(z)`.
@@ -48,45 +93,58 @@
        # batch_shape x event_shape (N x C)
        positions = torch.arange(self.n_length)
        batch = torch.arange(batch_shape)
        return log_probs.masked_fill_(value==0, 0).sum(-1).sum(-1)
        return log_probs.masked_fill_(value == 0, 0).sum(-1).sum(-1)
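As a sanity check on the masking trick in the return line above, a tiny hypothetical example (shapes and values made up):

import torch

# value: batch x N x C one-hot choices; log_probs: batch x N x C per-token scores
value = torch.nn.functional.one_hot(torch.tensor([[1, 2]]), 3).float()
log_probs = torch.log_softmax(torch.randn(1, 2, 3), dim=-1)
# zero out everything except the chosen tokens, then sum over N and C
score = log_probs.masked_fill(value == 0, 0).sum(-1).sum(-1)  # shape: (1,)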

    def _beam_search(self, semiring):
        # beam size
        beam = semiring.one_(
            torch.zeros((semiring.size(),) + self.batch_shape))
        beam.requires_grad_(True)
    def _beam_search(self, semiring, gumbel=False):
        beam = semiring.one_(torch.zeros((semiring.size(),) + self.batch_shape))
        state = self.init.unsqueeze(0).expand((semiring.size(),) + self.init.shape)

        # Beam Search
        all_beams = []
        for t in range(0, self.n_length):
            logits = self.model.local_logits(state)
            # ssize x batch_size x C
            if gumbel:
                # Perturb logits with standard Gumbel(0, 1) noise; combined
                # with a K-max semiring this implements sampling without
                # replacement (Gumbel-top-k).
                logits = logits + torch.distributions.Gumbel(0.0, 1.0).sample(
                    logits.shape
                )

            ex_beam = beam.unsqueeze(-1) + logits
            ex_beam.requires_grad_(True)
            all_beams.append(ex_beam)
            beam, tokens = semiring.sparse_sum(ex_beam)
            state = self.model.update_state(state, tokens)

        # Back pointers: differentiate each final beam score with respect to
        # the expanded beams; the gradients are one-hot indicators of the
        # token chosen at each step, so autograd recovers the paths without
        # explicit back-pointer bookkeeping.
        v = beam
        all_m = []
        for k in range(v.shape[0]):
            obj = v[k].sum(dim=0)
            marg = torch.autograd.grad(
                obj,
                all_beams,
                create_graph=True,
                only_inputs=True,
                allow_unused=False,
                obj, all_beams, create_graph=True, only_inputs=True, allow_unused=False
            )
            marg = torch.stack(marg, dim=2)
            all_m.append(marg.sum(0))
        return torch.stack(all_m, dim=0)

    def greedy_argmax(self):
        """
        Compute "argmax" using greedy search.
        """
        return self._beam_search(MaxSemiring).squeeze(0)

    def beam_topk(self, K):
        """
        Compute "top-k" using beam search.
        """
        return self._beam_search(KMaxSemiring(K))

    def sample_without_replacement(self, sample_shape=torch.Size()):
        """
        Compute sampling without replacement using the Gumbel-top-k trick.
        """
        K = sample_shape[0]
        return self._beam_search(KMaxSemiring(K), gumbel=True)

    def sample(self, sample_shape=torch.Size()):
        r"""
        Compute structured samples from the distribution :math:`z \sim p(z)`.
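Putting the pieces together, a hypothetical usage sketch (reusing the illustrative GRUAutoregressiveModel from above; all sizes made up):

model = GRUAutoregressiveModel(n_classes=10, hidden=32)
init = torch.zeros(4, 32)  # batch of 4 initial hidden states
dist = Autoregressive(model, init, n_classes=10, n_length=8)

best = dist.greedy_argmax()                  # batch x N x C one-hot argmax
topk = dist.beam_topk(K=5)                   # 5 x batch x N x C beam results
swr = dist.sample_without_replacement((5,))  # 5 distinct samples per batch item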
9 changes: 7 additions & 2 deletions torch_struct/distributions.py
@@ -6,7 +6,13 @@
from .semimarkov import SemiMarkov
from .deptree import DepTree, deptree_nonproj, deptree_part
from .cky_crf import CKY_CRF
from .semirings import LogSemiring, MaxSemiring, EntropySemiring, MultiSampledSemiring, KMaxSemiring
from .semirings import (
    LogSemiring,
    MaxSemiring,
    EntropySemiring,
    MultiSampledSemiring,
    KMaxSemiring,
)


class StructDistribution(Distribution):
@@ -68,7 +74,6 @@ def entropy(self):
"""
return self.struct(EntropySemiring).sum(self.log_potentials, self.lengths)


    @lazy_property
    def argmax(self):
        r"""
2 changes: 1 addition & 1 deletion torch_struct/semirings.py
@@ -1,6 +1,7 @@
import torch
import torch.distributions


class Semiring:
    @classmethod
    def size(cls):
@@ -182,7 +183,6 @@ def sparse_sum(xs, dim=-1):
                return xs, xs2
            assert False


        @staticmethod
        def mul(a, b):
            a = a.view((k, 1) + a.shape[1:])
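For intuition about what sparse_sum computes under the K-max semiring, a standalone sketch using plain topk (independent of the semiring classes here; names and sizes made up):

import torch

K, batch, C = 5, 3, 7
scores = torch.randn(K, batch, C)                # beam x batch x class scores
flat = scores.permute(1, 0, 2).reshape(batch, K * C)
vals, flat_idx = flat.topk(K, dim=-1)            # K best combined scores per item
prev_beam = torch.div(flat_idx, C, rounding_mode="floor")  # originating beam
token = flat_idx % C                                       # chosen class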
22 changes: 16 additions & 6 deletions torch_struct/test_distributions.py
@@ -61,6 +61,7 @@ def test_autoregressive(data, seed):
    values2[:, 0, torch.arange(n_classes), torch.arange(n_classes)] = values[:, 0]

    init = torch.zeros(batch, 5).long()

    class Model:
        def update_state(self, prev_state, inputs):
            K, batch, hidden = prev_state.shape
@@ -71,25 +72,34 @@ def sequence_logits(self, init, seq_inputs):

        def local_logits(self, state):
            K, batch, hidden = state.shape
            t = state[0,0,0]
            t = state[0, 0, 0]
            x = values[:, t, :].unsqueeze(0).expand(K, batch, n_classes)
            return x

    auto = Autoregressive(Model(), init, n_classes, n_length)
    v = auto.greedy_argmax()
    assert((v == LinearChainCRF(values2).argmax.sum(-1)).all())
    assert (v == LinearChainCRF(values2).argmax.sum(-1)).all()
    crf = LinearChainCRF(values2)
    v2 = auto.beam_topk(K=5)

    # print(crf.struct().score(crf.topk(5), values2, batch_dims=[0,1]))
    # print(crf.topk(5)[0].nonzero())
    # print(crf.topk(5)[1].nonzero())

    assert((v2.nonzero() == crf.topk(5).sum(-1).nonzero()).all())
    assert((v2[0] == LinearChainCRF(values2).argmax.sum(-1)).all())
    assert (v2.nonzero() == crf.topk(5).sum(-1).nonzero()).all()
    assert (v2[0] == LinearChainCRF(values2).argmax.sum(-1)).all()

    print(auto.log_prob(v, normalize=False))
    print(crf.struct().score(crf.argmax, values2))
    assert ((auto.log_prob(v, normalize=False) == crf.struct().score(crf.argmax, values2)).all())
    assert (
        auto.log_prob(v, normalize=False) == crf.struct().score(crf.argmax, values2)
    ).all()

    assert auto.sample((7,)).shape == (7, batch, n_length, n_classes)

    assert(auto.sample((7,)).shape == (7, batch, n_length, n_classes))
    assert auto.sample_without_replacement((7,)).shape == (
        7,
        batch,
        n_length,
        n_classes,
    )
