Skip to content

Commit

Permalink
Fix pre-terminal rule probabilities in NeuralCFG (#54)
Browse files Browse the repository at this point in the history
* minimize the CKY for debugging

* add tests for the CKY

* fix formatting issues

* fix pre-terminal rule probabilities in NeuralCFG
  • Loading branch information
zhaoyanpeng committed Mar 8, 2020
1 parent 67aa60d commit 0e47fe9
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions torch_struct/networks/NeuralCFG.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,17 @@ def forward(self, input):
T, NT = self.T, self.NT

def terms(words):
    # Compute pre-terminal (tag -> word) emission log-probabilities for
    # every position of the batch; returns a tensor of shape (b, n, T).
    # NOTE(review): the body reads the enclosing `input` tensor, not the
    # `words` argument — presumably the caller passes the same tensor;
    # verify against the caller of `terms`.
    b, n = input.shape[:2]
    # Score every (tag, word) pair and normalize over the vocabulary
    # dimension (-1), so each tag's emission distribution sums to one.
    # (The pre-commit code normalized over the tag dimension, -2, which
    # is what this commit fixes.)
    # assumes self.word_emb is (V, H) and self.term_emb is (T, H), so the
    # einsum yields a (T, V) score matrix — TODO confirm shapes upstream.
    term_prob = (
        torch.einsum("vh,th->tv", self.word_emb, self.mlp1(self.term_emb))
        .log_softmax(-1)
        .unsqueeze(0)
        .unsqueeze(0)
        .expand(b, n, self.T, self.V)
    )
    # For each batch position, gather the log-probability of the observed
    # word under every tag: (b, n, T, V) -> (b, n, T).
    indices = input.unsqueeze(2).expand(b, n, self.T).unsqueeze(3)
    term_prob = torch.gather(term_prob, 3, indices).squeeze(3)
    return term_prob

def rules(b):
return (
Expand Down

0 comments on commit 0e47fe9

Please sign in to comment.