Commit bf3c560
Sasha committed Oct 23, 2019
1 parent b770f21 commit bf3c560
Showing 3 changed files with 93 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torch_struct/__init__.py
@@ -8,6 +8,9 @@
TreeCRF,
SentCFG,
)
from .autoregressive import (
Autoregressive
)
from .cky_crf import CKY_CRF
from .deptree import DepTree
from .linearchain import LinearChain
@@ -42,6 +45,7 @@
MultiSampledSemiring,
SelfCritical,
StructDistribution,
Autoregressive,
LinearChainCRF,
SemiMarkovCRF,
DependencyCRF,
75 changes: 75 additions & 0 deletions torch_struct/autoregressive.py
@@ -0,0 +1,75 @@
import torch
from .semirings import MaxSemiring
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property

class Autoregressive(Distribution):
"""
Autoregressive sequence model utilizing beam search.
* batch_shape -> Given by initializer
* event_shape -> N x T sequence of choices
Parameters:
model:
init (tensor, batch_shape x hidden_shape):
n_classes (int): number of classes in each time step
n_length (int): max length of sequence
"""

def __init__(self, model, init, n_classes, n_length):
self.model = model
self.init = init
self.n_length = n_length
self.n_classes = n_classes
event_shape = (n_length, n_classes)
        batch_shape = init.shape[:1]
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)

def log_prob(self, value):
"""
Compute log probability over values :math:`p(z)`.
Parameters:
value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)
Returns:
log_probs (*sample_shape x batch_shape*)
"""
        logits = self.model.sequence_logits(self.init, value)
        # batch_shape x event_shape (N x C)
        log_probs = logits.log_softmax(-1)
        # `value` is one-hot over classes, so masking the per-position
        # log-probabilities and summing over the class and position
        # dimensions gives the sequence log-probability.
        return (log_probs * value).sum(-1).sum(-1)

    def _beam_search(self, semiring):
        # beam: ssize x batch_shape running score of each hypothesis
        beam = semiring.ones_(
            torch.zeros((semiring.size(),) + self.batch_shape))
        state = self.init.unsqueeze(0).expand((semiring.size(),) + self.init.shape)
        for t in range(0, self.n_length):
            logits = self.model.logits(state)
            # ssize x batch_size x C
            beam = semiring.times(beam.unsqueeze(-1), logits.log_softmax(-1))
            # ssize x batch_size x C
            beam, backpointers = semiring.sparse_sum(beam)
            # ssize x batch_size
            state = self.model.update_state(state, backpointers)
        return beam

    def greedy_max(self):
        return self._beam_search(MaxSemiring)

def sample(self, sample_shape=torch.Size()):
r"""
Compute structured samples from the distribution :math:`z \sim p(z)`.
Parameters:
            sample_shape (torch.Size): shape of the samples to draw
Returns:
samples (*sample_shape x batch_shape x event_shape*)
"""
pass
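
Note: to make the intended interface concrete, here is a minimal, hypothetical driver. Only the constructor signature and the three hooks this file actually calls (`sequence_logits`, `logits`, `update_state`) are taken from the diff; `ToyModel` and all shapes are assumptions.

import torch
from torch_struct import Autoregressive

BATCH, HIDDEN, N, C = 4, 8, 5, 10

class ToyModel:
    # Hypothetical stand-in: only the three hooks used above are assumed.
    def __init__(self):
        self.proj = torch.randn(HIDDEN, C)

    def logits(self, state):
        # Next-step scores: ssize x batch x HIDDEN -> ssize x batch x C.
        return state @ self.proj

    def update_state(self, state, backpointers):
        # Fold the chosen classes back into the hidden state (toy update).
        return state + backpointers.unsqueeze(-1).float()

    def sequence_logits(self, init, value):
        # Per-position scores for a full sequence: batch x N x C.
        return init.unsqueeze(1).expand(-1, N, -1) @ self.proj

dist = Autoregressive(ToyModel(), torch.zeros(BATCH, HIDDEN), C, N)
scores = dist.greedy_max()  # ssize x batch scores of the greedy sequences
value = torch.nn.functional.one_hot(torch.randint(C, (BATCH, N)), C).float()
lp = dist.log_prob(value)   # batch-shaped log-probabilities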
14 changes: 14 additions & 0 deletions torch_struct/semirings.py
@@ -129,6 +129,8 @@ class MaxSemiring(_BaseLog):
def sum(xs, dim=-1):
return torch.max(xs, dim=dim)[0]

    def sparse_sum(xs, dim=-1):
        # Max values plus argmax indices, used as beam-search backpointers.
        return torch.max(xs, dim=dim)

def KMaxSemiring(k):
class KMaxSemiring(_BaseLog):
@@ -167,6 +169,18 @@ def sum(xs, dim=-1):
return xs
assert False

        def sparse_sum(xs, dim=-1):
            if dim == -1:
                # Move the k dimension last and merge it with the class
                # dimension so topk ranks all k * C candidates per batch.
                xs = xs.permute(tuple(range(1, xs.dim())) + (0,))
                xs = xs.contiguous().view(xs.shape[:-2] + (-1,))
                xs, xs2 = torch.topk(xs, k, dim=-1)
                xs = xs.permute((xs.dim() - 1,) + tuple(range(0, xs.dim() - 1)))
                xs2 = xs2.permute((xs2.dim() - 1,) + tuple(range(0, xs2.dim() - 1)))
                assert xs.shape[0] == k
                return xs, xs2
            assert False


@staticmethod
def mul(a, b):
a = a.view((k, 1) + a.shape[1:])
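
As a sanity check on the two additions, a small hypothetical sketch of their return values (shapes and the index decoding are assumptions read off the code above):

import torch
from torch_struct.semirings import MaxSemiring, KMaxSemiring

scores = torch.randn(1, 2, 5)                  # ssize x batch x classes
vals, ptrs = MaxSemiring.sparse_sum(scores)    # each: 1 x 2
# ptrs holds argmax class indices -- the beam-search backpointers.

KMax2 = KMaxSemiring(2)
beam = torch.randn(2, 3, 5)                    # k x batch x classes
vals, ptrs = KMax2.sparse_sum(beam)            # each: 2 x 3
# ptrs index the flattened classes-by-k candidate grid, so here
# class = ptrs // 2 and parent hypothesis = ptrs % 2.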
