Skip to content

Commit

Permalink
Merge 4da7f41 into d828506
Browse files Browse the repository at this point in the history
  • Loading branch information
da03 committed Jan 29, 2020
2 parents d828506 + 4da7f41 commit df7e77a
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torch_struct/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
self.normalize = normalize
event_shape = (n_length, n_classes)
batch_shape = initial_state[0].shape[:1]
self.device = initial_state[0].device
super().__init__(batch_shape=batch_shape, event_shape=event_shape)

def log_prob(self, value, sparse=False):
Expand Down Expand Up @@ -116,7 +117,7 @@ def log_prob(self, value, sparse=False):
return wrap(scores, sample)

def _beam_search(self, semiring, gumbel=False):
beam = semiring.one_(torch.zeros((semiring.size(),) + self.batch_shape))
beam = semiring.one_(torch.zeros((semiring.size(),) + self.batch_shape, device=self.device))
ssize = semiring.size()

def take(state, indices):
Expand All @@ -125,7 +126,7 @@ def take(state, indices):
s.contiguous()[
(
indices * self.batch_shape[0]
+ torch.arange(self.batch_shape[0]).unsqueeze(0)
+ torch.arange(self.batch_shape[0], device=self.device).unsqueeze(0)
)
.contiguous()
.view(-1)
Expand Down

0 comments on commit df7e77a

Please sign in to comment.