Skip to content

Commit

Permalink
🐛 swap syntactic evaluation decoder order
Browse files Browse the repository at this point in the history
  • Loading branch information
jumelet committed Jun 22, 2022
1 parent a4dbb77 commit 5e9a067
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions diagnnose/syntax/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,17 @@ def _dual_context_accuracy(
def _decode(
self, activations: Tensor, token_ids: Optional[Tensor] = None
) -> Tensor:
if hasattr(self.model, "decoder_w"):
if hasattr(self.model, "decoder"):
# Transformers
logits = getattr(self.model, "decoder")(activations)

if token_ids is not None:
batch_size = logits.size(0)
logits = logits[range(batch_size), token_ids]

return logits
elif hasattr(self.model, "decoder_w"):
# LSTMs
decoder_w = self.model.decoder_w
decoder_b = self.model.decoder_b
if token_ids is None:
Expand All @@ -285,13 +295,5 @@ def _decode(
logits += decoder_b[token_ids]

return logits
elif hasattr(self.model, "decoder"):
logits = getattr(self.model, "decoder")(activations)

if token_ids is not None:
batch_size = logits.size(0)
logits = logits[range(batch_size), token_ids]

return logits
else:
raise AttributeError("Model decoder not found")

0 comments on commit 5e9a067

Please sign in to comment.