Merge 864314b into ad2ca2e
srush committed Nov 1, 2019
2 parents ad2ca2e + 864314b commit 730c88a
Showing 4 changed files with 54 additions and 4 deletions.
8 changes: 8 additions & 0 deletions docs/source/refs.bib
@@ -133,3 +133,11 @@ @article{DBLP:journals/corr/abs-1903-06059
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1903-06059},
bibsource = {dblp computer science bibliography, https://dblp.org}
}


@article{goyal2017differentiable,
  title={Differentiable scheduled sampling for credit assignment},
  author={Goyal, Kartik and Dyer, Chris and Berg-Kirkpatrick, Taylor},
  journal={arXiv preprint arXiv:1704.06970},
  year={2017}
}
17 changes: 15 additions & 2 deletions torch_struct/autoregressive.py
@@ -1,5 +1,5 @@
import torch
from .semirings import MaxSemiring, KMaxSemiring
from .semirings import MaxSemiring, KMaxSemiring, TempMax
from torch.distributions.distribution import Distribution


@@ -13,7 +13,7 @@ def forward(self, inputs, state=None):
        Compute the logits for all tokens in a batched sequence :math:`p(y_{t+1}, \ldots, y_T | y_1, \ldots, y_t)`

        Parameters:
            inputs (batch_size x N x C): next tokens to update representation
            inputs (batch_size x N x C ): next tokens to update representation
            state (tuple of batch_size x ...): everything needed for conditioning.

        Returns:
@@ -187,6 +187,19 @@ def greedy_argmax(self):
    def _greedy_max(self):
        return self._beam_search(MaxSemiring)[1].squeeze(0)

    def greedy_tempmax(self, alpha):
        """
        Compute differentiable scheduled sampling using greedy search.

        Based on:

        * Differentiable Scheduled Sampling for Credit Assignment :cite:`goyal2017differentiable`

        Returns:
            greedy_path (*batch x N x C*)
        """
        return self._beam_search(TempMax(alpha), alpha)[0].squeeze(0)

    def beam_topk(self, K):
        """
        Compute "top-k" using beam search
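A minimal sketch (not part of the diff) of what the alpha argument to greedy_tempmax controls: the relaxed decoding relies on softmax(alpha * scores) in place of a hard argmax (see the TempMax semiring added in semirings.py below), and the relaxation sharpens toward a one-hot argmax as alpha grows.

import torch

scores = torch.tensor([1.0, 2.5, 0.3])
for alpha in [1.0, 5.0, 50.0]:
    # softmax(alpha * scores) is the soft stand-in for argmax used by TempMax
    print(alpha, torch.softmax(alpha * scores, dim=-1).tolist())
# alpha=1.0  -> probability mass spread over all three classes
# alpha=50.0 -> essentially one-hot on index 1, i.e. scores.argmax()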
19 changes: 19 additions & 0 deletions torch_struct/semirings.py
@@ -136,6 +136,25 @@ def sparse_sum(xs, dim=-1):
        return m, (torch.zeros(a.shape).long(), a)


def TempMax(alpha):
    class _TempMax(_BaseLog):
        """
        Implements a max forward, hot softmax backward.
        """

        @staticmethod
        def sum(xs, dim=-1):
            pass

        @staticmethod
        def sparse_sum(xs, dim=-1):
            m, _ = torch.max(xs, dim=dim)
            a = torch.softmax(alpha * xs, dim)
            return m, (torch.zeros(a.shape[:-1]).long(), a)

    return _TempMax


def KMaxSemiring(k):
    """
    Implements the k-max semiring (kmax, +, [-inf, -inf..], [0, -inf, ...]).
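For illustration only (not part of the diff; toy values assumed), the new factory can be called directly: sparse_sum returns the hard max over the last dimension together with softmax(alpha * xs), the soft weights used in place of a hard argmax during backpropagation.

import torch
from torch_struct.semirings import TempMax  # import path as of this commit

xs = torch.tensor([[0.1, 2.0, -1.0]])
m, (idx, weights) = TempMax(2.0).sparse_sum(xs)
# m       -> tensor([2.]), the hard max over the last dimension
# weights -> torch.softmax(2.0 * xs, dim=-1), the soft backward surrogate
# idx     -> placeholder zero indices of shape xs.shape[:-1]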
14 changes: 12 additions & 2 deletions torch_struct/test_distributions.py
@@ -121,21 +121,31 @@ def t(a):
    init = (torch.zeros(batch, layer, H),)

    class AR(torch.nn.Module):
        def __init__(self):
        def __init__(self, sparse=True):
            super().__init__()
            self.sparse = sparse
            self.rnn = torch.nn.RNN(H, H, batch_first=True)
            self.proj = torch.nn.Linear(H, C)
            self.embed = torch.nn.Embedding(C, H)
            if sparse:
                self.embed = torch.nn.Embedding(C, H)
            else:
                self.embed = torch.nn.Linear(C, H)

        def forward(self, inputs, state):
            if not self.sparse and inputs.dim() == 2:
                inputs = torch.nn.functional.one_hot(inputs, C).float()
            inputs = self.embed(inputs)
            out, state = self.rnn(inputs, t(state)[0])
            out = self.proj(out)
            return out, t((state,))

    dist2 = Autoregressive(AR(sparse=False), init, C, N, normalize=False)
    path = dist2.greedy_tempmax(1)

    dist = Autoregressive(AR(), init, C, N, normalize=False)
    scores = dist._greedy_max()
    path = dist.greedy_argmax()

    assert torch.isclose(scores, dist.log_prob(path.unsqueeze(0))).all()
    scores = dist._beam_max(7)
    path = dist.beam_topk(7)
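An aside on why the test adds a sparse=False path (a sketch with assumed toy shapes): differentiable scheduled sampling feeds a soft distribution over the C classes back into the model, and a soft mixture can be embedded with a Linear layer, whereas nn.Embedding only accepts integer token ids.

import torch

C, H = 4, 8
embed_dense = torch.nn.Linear(C, H)
soft_tokens = torch.softmax(torch.randn(2, 3, C), dim=-1)  # batch x N x C
hidden = embed_dense(soft_tokens)                          # batch x N x H
# torch.nn.Embedding(C, H) would instead require integer ids of shape
# (batch x N) and could not consume soft_tokens directly.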
