Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Oct 24, 2019
1 parent 299c4c6 commit c95f024
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 5 deletions.
4 changes: 0 additions & 4 deletions torch_struct/autoregressive.py
@@ -1,7 +1,6 @@
import torch
from .semirings import MaxSemiring, KMaxSemiring
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property


class AutoregressiveModel:
Expand Down Expand Up @@ -91,8 +90,6 @@ def log_prob(self, value, normalize=True):
log_probs = logits

# batch_shape x event_shape (N x C)
positions = torch.arange(self.n_length)
batch = torch.arange(batch_shape)
return log_probs.masked_fill_(value == 0, 0).sum(-1).sum(-1)

def _beam_search(self, semiring, gumbel=True):
Expand Down Expand Up @@ -156,7 +153,6 @@ def sample(self, sample_shape=torch.Size()):
samples (*sample_shape x batch_shape x event_shape*)
"""
sample_shape = sample_shape[0]
beam = torch.zeros((sample_shape,) + self.batch_shape)
state = self.init.unsqueeze(0).expand((sample_shape,) + self.init.shape)
all_tokens = []
for t in range(0, self.n_length):
Expand Down
2 changes: 2 additions & 0 deletions torch_struct/semirings.py
Expand Up @@ -130,6 +130,7 @@ class MaxSemiring(_BaseLog):
def sum(xs, dim=-1):
return torch.max(xs, dim=dim)[0]

@staticmethod
def sparse_sum(xs, dim=-1):
return torch.max(xs, dim=dim)

Expand Down Expand Up @@ -172,6 +173,7 @@ def sum(xs, dim=-1):
return xs
assert False

@staticmethod
def sparse_sum(xs, dim=-1):
if dim == -1:
xs = xs.permute(tuple(range(1, xs.dim())) + (0,))
Expand Down
1 change: 0 additions & 1 deletion torch_struct/test_distributions.py
Expand Up @@ -49,7 +49,6 @@ def test_simple(data, seed):
@given(data(), integers(min_value=1, max_value=20))
@settings(max_examples=50, deadline=None)
def test_autoregressive(data, seed):
model = Autoregressive
n_classes = 2
n_length = 5
batch = 3
Expand Down

0 comments on commit c95f024

Please sign in to comment.