Skip to content
This repository has been archived by the owner. It is now read-only.
Permalink
Browse files
fix(encoder): fix unused variable
  • Loading branch information
raccoonliukai committed Aug 13, 2019
1 parent 732f2e6 commit 5fedf6dffccf881345986142fc3afc586ae38fec
Showing with 3 additions and 4 deletions.
  1. +3 −4 gnes/encoder/text/torch_transformers.py
@@ -67,7 +67,7 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
batch_data = np.zeros([batch_size, max_len], dtype=np.int64)
# batch_mask = np.zeros([batch_size, max_len], dtype=np.float32)
for i, ids in enumerate(tokens_ids):
batch_data[i, :tokens_lens[i]] = tokens_ids[i]
batch_data[i, :tokens_lens[i]] = ids
# batch_mask[i, :tokens_lens[i]] = 1

# Convert inputs to PyTorch tensors
@@ -85,8 +85,7 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
with torch.no_grad():
out_tensor = self.model(tokens_tensor)[0]
out_tensor = torch.mul(out_tensor, mask_tensor.unsqueeze(2))

if self.use_cuda:
output_tensor = output_tensor.cpu()
if self.use_cuda:
out_tensor = out_tensor.cpu()

return out_tensor.numpy()

0 comments on commit 5fedf6d

Please sign in to comment.