This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Fix BoundaryPooling tracing (#713)
Summary:
Pull Request resolved: #713

Array slicing using the result of a .size() call stores the slice index as a constant in the trace. In this case we don't even need tracing to handle the dynamic size, because we can simply record the constant -1 as the slice index.

Reviewed By: liaimi

Differential Revision: D15932989

fbshipit-source-id: 2fc05cf953b091f15ba0bcd97c90ffcc01740aec
bethebunny authored and facebook-github-bot committed Jun 21, 2019
1 parent d994b4a commit 115ba44
9 changes: 4 additions & 5 deletions pytext/models/representations/pooling.py
@@ -91,18 +91,17 @@ def __init__(self, config: Config, n_input: int) -> None:
     def forward(
         self, inputs: torch.Tensor, seq_lengths: torch.Tensor = None
     ) -> torch.Tensor:
-        max_len = inputs.size()[1]
         if self.boundary_type == "first":
             return inputs[:, 0, :]
         elif self.boundary_type == "last":
             # could only have the bos values if add_bos or add_eos as False
             # should not reach here if the eos is not added.
-            assert max_len > 1
-            return inputs[:, max_len - 1, :]
+            assert inputs.size()[1] > 1
+            return inputs[:, -1, :]
         elif self.boundary_type == "firstlast":
-            assert max_len > 1
+            assert inputs.size()[1] > 1
             # merge from embed_dim into 2*emded_dim
-            return torch.cat((inputs[:, 0, :], inputs[:, max_len - 1, :]), dim=1)
+            return torch.cat((inputs[:, 0, :], inputs[:, -1, :]), dim=1)
         else:
             raise Exception("Unknown configuration type {}".format(self.boundary_type))

