Commit 5ee7175

Merge 8eec68a into 034e8a2
srush committed Oct 28, 2019
2 parents 034e8a2 + 8eec68a commit 5ee7175
Showing 10 changed files with 1,942 additions and 226 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -9,7 +9,7 @@ PyTorch-Struct
README
model
networks
advanced
semiring
refs


807 changes: 745 additions & 62 deletions docs/source/model.ipynb

Large diffs are not rendered by default.

733 changes: 733 additions & 0 deletions docs/source/semiring.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name="torch_struct",
version="0.2",
version="0.3",
author="Alexander Rush",
author_email="arush@cornell.edu",
packages=["torch_struct", "torch_struct.data", "torch_struct.networks"],
5 changes: 4 additions & 1 deletion torch_struct/__init__.py
@@ -8,8 +8,9 @@
TreeCRF,
SentCFG,
AlignmentCRF,
HMM,
)
from .autoregressive import Autoregressive
from .autoregressive import Autoregressive, AutoregressiveModel
from .cky_crf import CKY_CRF
from .deptree import DepTree
from .linearchain import LinearChain
@@ -48,12 +49,14 @@
SelfCritical,
StructDistribution,
Autoregressive,
AutoregressiveModel,
LinearChainCRF,
SemiMarkovCRF,
DependencyCRF,
NonProjectiveDependencyCRF,
TreeCRF,
SentCFG,
HMM,
AlignmentCRF,
Alignment,
]
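For reference, the newly exported names can be pulled in directly; a tiny illustrative check (not part of the commit):

from torch_struct import Autoregressive, AutoregressiveModel, HMM

# AutoregressiveModel is now public so user models can subclass it;
# a fuller sketch follows the torch_struct/autoregressive.py diff below.
print(Autoregressive, AutoregressiveModel, HMM)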
202 changes: 145 additions & 57 deletions torch_struct/autoregressive.py
@@ -3,48 +3,33 @@
from torch.distributions.distribution import Distribution


class AutoregressiveModel:
class AutoregressiveModel(torch.nn.Module):
"""
User should implement as their favorite RNN / Transformer / etc.
"""

def sequence_logits(self, init, seq_inputs):
"""
Compute the logits for all tokens in a batched sequence :math:`p(y_1, ... y_{T})`
def forward(self, inputs, state=None):
r"""
Compute the logits for all tokens in a batched sequence :math:`p(y_{t+1}, \ldots, y_{T} | y_1, \ldots, y_t)`
Parameters:
init (batch_size x hidden_shape): everything needed for conditioning.
inputs (batch_size x N x C): next tokens to update representation
state (tuple of batch_size x ...): everything needed for conditioning.
Returns:
logits (batch_size x C): next set of logits.
"""
pass

def local_logits(self, state):
"""
Compute the local logits of :math:`p(y_t | y_{1:t-1})`
Parameters:
state (batch_size x hidden_shape): everything needed for conditioning.
logits (*batch_size x N x C*): next set of logits.
Returns:
logits (batch_size x C): next set of logits.
state (*tuple of batch_size x ...*): updated model state.
"""
pass

def update_state(self, prev_state, inputs):
"""
Update the model state based on previous state and inputs

Parameters:
prev_state (batch_size x hidden_shape): everything needed for conditioning.
inputs (batch_size x C): next tokens to update representation
def wrap(state, ssize):
return state.contiguous().view(ssize, -1, *state.shape[1:])

Returns:
state (batch_size x hidden_shape): everything needed for next conditioning.
"""
pass

def unwrap(state):
return state.contiguous().view(-1, *state.shape[2:])


class Autoregressive(Distribution):
@@ -56,60 +41,127 @@ class Autoregressive(Distribution):
Parameters:
model (AutoregressiveModel): A lazily computed autoregressive model.
init (tensor, batch_shape x hidden_shape): initial state of autoregressive model.
init (tuple of tensors, batch_shape x ...): initial state of autoregressive model.
n_classes (int): number of classes in each time step
n_length (int): max length of sequence
"""

def __init__(self, model, init, n_classes, n_length):
def __init__(
self,
model,
initial_state,
n_classes,
n_length,
normalize=True,
start_class=0,
end_class=None,
):
self.model = model
self.init = init
self.init = initial_state
self.n_length = n_length
self.n_classes = n_classes
self.start_class = start_class
self.normalize = normalize
event_shape = (n_length, n_classes)
batch_shape = init.shape[:1]
batch_shape = initial_state[0].shape[:1]
super().__init__(batch_shape=batch_shape, event_shape=event_shape)

def log_prob(self, value, normalize=True):
def log_prob(self, value, sparse=False):
"""
Compute log probability over values :math:`p(z)`.
Parameters:
value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)
value (tensor): One-hot events (*sample_shape x batch_shape x N*)
Returns:
log_probs (*sample_shape x batch_shape*)
"""
batch_shape, n_length, n_classes = value.shape
value = value.long()
logits = self.model.sequence_logits(self.init, value)
if normalize:
if not sparse:
sample, batch_shape, n_length, n_classes = value.shape
value = (
(value * torch.arange(n_classes).view(1, 1, n_classes)).sum(-1).long()
)
else:
sample, batch_shape, n_length = value.shape

value = torch.cat(
[torch.zeros(sample, batch_shape, 1).fill_(self.start_class).long(), value],
dim=2,
)
value = unwrap(value)
state = tuple(
(unwrap(i.unsqueeze(0).expand((sample,) + i.shape)) for i in self.init)
)

logits, _ = self.model(value, state)
b2, n2, c2 = logits.shape
assert (
(b2 == sample * batch_shape)
and (n2 == n_length + 1)
and (c2 == self.n_classes)
), "Model should return logits of shape `batch x N x C` "

if self.normalize:
log_probs = logits.log_softmax(-1)
else:
log_probs = logits

# batch_shape x event_shape (N x C)
return log_probs.masked_fill_(value == 0, 0).sum(-1).sum(-1)
scores = log_probs[:, :-1].gather(2, value[:, 1:].unsqueeze(-1)).sum(-1).sum(-1)
return wrap(scores, sample)

def _beam_search(self, semiring, gumbel=True):
def _beam_search(self, semiring, gumbel=False):
beam = semiring.one_(torch.zeros((semiring.size(),) + self.batch_shape))
state = self.init.unsqueeze(0).expand((semiring.size(),) + self.init.shape)
ssize = semiring.size()

def take(state, indices):
return tuple(
(
s.contiguous()[
(
indices * self.batch_shape[0]
+ torch.arange(self.batch_shape[0]).unsqueeze(0)
)
.contiguous()
.view(-1)
]
for s in state
)
)

tokens = (
torch.zeros((ssize * self.batch_shape[0])).long().fill_(self.start_class)
)
state = tuple(
(unwrap(i.unsqueeze(0).expand((ssize,) + i.shape)) for i in self.init)
)

# Beam Search
all_beams = []
for t in range(0, self.n_length):
logits = self.model.local_logits(state)
logits, state = self.model(unwrap(tokens).unsqueeze(1), state)
b2, n2, c2 = logits.shape
assert (
(b2 == ssize * self.batch_shape[0])
and (n2 == 1)
and (c2 == self.n_classes)
), "Model should return logits of shape `batch x N x C` "
for s in state:
assert (
s.shape[0] == ssize * self.batch_shape[0]
), "Model should return state tuple with shapes `batch x ...` "
logits = wrap(logits.squeeze(1), ssize)
if gumbel:
logits = logits + torch.distributions.Gumbel(0.0, 0.0).sample(
logits = logits + torch.distributions.Gumbel(0.0, 1.0).sample(
logits.shape
)

if self.normalize:
logits = logits.log_softmax(-1)
ex_beam = beam.unsqueeze(-1) + logits
ex_beam.requires_grad_(True)
all_beams.append(ex_beam)
beam, tokens = semiring.sparse_sum(ex_beam)
state = self.model.update_state(state, tokens)
beam, (positions, tokens) = semiring.sparse_sum(ex_beam)
state = take(state, positions)

# Back pointers
v = beam
@@ -121,43 +173,79 @@ def _beam_search(self, semiring, gumbel=True):
)
marg = torch.stack(marg, dim=2)
all_m.append(marg.sum(0))
return torch.stack(all_m, dim=0)
return torch.stack(all_m, dim=0), v

def greedy_argmax(self):
"""
Compute "argmax" using greedy search
Compute "argmax" using greedy search.
Returns:
greedy_path (*batch x N x C*)
"""
return self._beam_search(MaxSemiring).squeeze(0)
return self._beam_search(MaxSemiring)[0].squeeze(0)

def _greedy_max(self):
return self._beam_search(MaxSemiring)[1].squeeze(0)

def beam_topk(self, K):
"""
Compute "top-k" using beam search
Returns:
paths (*K x batch x N x C*)
"""
return self._beam_search(KMaxSemiring(K))
return self._beam_search(KMaxSemiring(K))[0]

def _beam_max(self, K):
return self._beam_search(KMaxSemiring(K))[1]

def sample_without_replacement(self, sample_shape=torch.Size()):
"""
Compute sampling without replacement using Gumbel trick.
Based on:
* Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for
Sampling Sequences Without Replacement :cite:`DBLP:journals/corr/abs-1903-06059`
Parameters:
sample_shape (torch.Size): batch_size
Returns:
paths (*K x batch x N x C*)
"""
K = sample_shape[0]
return self._beam_search(KMaxSemiring(K), gumbel=True)
return self._beam_search(KMaxSemiring(K), gumbel=True)[0]

def sample(self, sample_shape=torch.Size()):
r"""
Compute structured samples from the distribution :math:`z \sim p(z)`.
Parameters:
sample_shape (int): number of samples
sample_shape (torch.Size): number of samples
Returns:
samples (*sample_shape x batch_shape x event_shape*)
"""
sample_shape = sample_shape[0]
state = self.init.unsqueeze(0).expand((sample_shape,) + self.init.shape)
state = tuple(
(
unwrap(i.unsqueeze(0).expand((sample_shape,) + i.shape))
for i in self.init
)
)
all_tokens = []
tokens = (
torch.zeros((sample_shape * self.batch_shape[0]))
.long()
.fill_(self.start_class)
)
for t in range(0, self.n_length):
logits = self.model.local_logits(state)
tokens = torch.distributions.OneHotCategorical(logits).sample((1,))[0]
state = self.model.update_state(state, tokens)
logits, state = self.model(tokens.unsqueeze(-1), state)
logits = logits.squeeze(1)
tokens = torch.distributions.Categorical(logits=logits).sample((1,))[0]
all_tokens.append(tokens)
return torch.stack(all_tokens, dim=2)
v = wrap(torch.stack(all_tokens, dim=1), sample_shape)
return torch.nn.functional.one_hot(v, self.n_classes)
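As the calls in log_prob, _beam_search, and sample above show, the refactored model is a single torch.nn.Module whose forward takes integer token indices (batch x N) plus a state tuple and returns (logits, state), replacing the old sequence_logits / local_logits / update_state methods. Below is a minimal sketch (illustrative, not part of the commit) of a conforming model driving the Autoregressive distribution; the class ExampleLSTMModel and the toy sizes V, HIDDEN, BATCH, N are hypothetical.

import torch
from torch_struct import Autoregressive, AutoregressiveModel


class ExampleLSTMModel(AutoregressiveModel):
    # Hypothetical single-layer LSTM obeying the new forward(inputs, state) contract.
    def __init__(self, n_classes, hidden):
        super().__init__()
        self.embed = torch.nn.Embedding(n_classes, hidden)
        self.rnn = torch.nn.LSTM(hidden, hidden, batch_first=True)
        self.proj = torch.nn.Linear(hidden, n_classes)

    def forward(self, inputs, state=None):
        # inputs: (batch x N) token indices; state: tuple of (batch x hidden) tensors.
        h, c = state
        out, (h, c) = self.rnn(self.embed(inputs), (h.unsqueeze(0), c.unsqueeze(0)))
        # Return (batch x N x C) logits and a state tuple with batch as the leading dim,
        # matching the shape assertions in log_prob and _beam_search.
        return self.proj(out), (h.squeeze(0), c.squeeze(0))


V, HIDDEN, BATCH, N = 12, 16, 3, 5
model = ExampleLSTMModel(V, HIDDEN)
init = (torch.zeros(BATCH, HIDDEN), torch.zeros(BATCH, HIDDEN))
dist = Autoregressive(model, init, n_classes=V, n_length=N)

path = dist.greedy_argmax()               # (BATCH x N x V) one-hot greedy path
print(dist.log_prob(path.unsqueeze(0)))   # (1 x BATCH) sequence log-probabilities
samples = dist.sample(torch.Size([4]))    # (4 x BATCH x N x V) one-hot samples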
