This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

add a new PyText get_num_examples_from_batch function in model
Summary:
In the PyText data loader, each row (sentence) is usually read as one example, i.e. a batch is created as List[sentences], and the number of examples in the batch is simply len(List[sentences]).

This is not correct for language models: in an LM, each word is considered an example, so the number of examples in a batch is the total number of words across all rows.
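
As a rough illustration of the difference, here is a plain-Python sketch (not PyText code). The toy batch and the helper names `count_rows` and `count_target_words` are made up for this example. It assumes the `seq_lens - 1` convention used by `arrange_targets` in the diff below, so a sentence of n tokens contributes n - 1 target words.

```python
from typing import List

# Hypothetical toy batch: each row is one tokenized sentence.
batch: List[List[str]] = [
    ["<s>", "the", "cat", "sat", "</s>"],
    ["<s>", "hello", "</s>"],
    ["<s>", "a", "b", "c", "d", "</s>"],
]


def count_rows(rows: List[List[str]]) -> int:
    """Row-based count: one example per sentence."""
    return len(rows)


def count_target_words(rows: List[List[str]]) -> int:
    """Word-based count: a row of n tokens yields n - 1 next-word targets."""
    return sum(len(row) - 1 for row in rows)


print(count_rows(batch))          # 3
print(count_target_words(batch))  # 4 + 2 + 5 = 11
```

The same batch therefore counts as 3 examples under the row-based rule and 11 under the word-based rule.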

Reviewed By: kmalik22

Differential Revision: D21032297

fbshipit-source-id: 619e3509e2dadb64431700516e23c4911cc4699a
psuzhanhy authored and facebook-github-bot committed Apr 15, 2020
1 parent f5278f6 commit e491b61
Showing 2 changed files with 16 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pytext/models/language_models/lmlstm.py
@@ -142,6 +142,11 @@ def arrange_targets(self, tensor_dict):
        tokens, seq_lens, _ = tensor_dict["tokens"]
        return (tokens[:, 1:].contiguous(), seq_lens - 1)

    def get_num_examples_from_batch(self, batch):
        targets = self.arrange_targets(batch)
        num_words_in_batch = targets[1].sum().item()
        return num_words_in_batch

    def get_export_input_names(self, tensorizers):
        return ["tokens_vals", "tokens_lens"]

11 changes: 11 additions & 0 deletions pytext/models/model.py
@@ -228,6 +228,17 @@ def arrange_caffe2_model_inputs(self, tensor_dict):
                flat_model_inputs.append(model_input)
        return flat_model_inputs

    def get_num_examples_from_batch(self, batch):
        """
        Usually, the number of examples in a batch is just the number
        of rows (len) in the batch, but this is not always true. For
        example, in a language model each row has multiple words, and
        the total number of examples is the total number of words across
        all rows in the batch. LM models must therefore override this function.
        """
        targets = self.arrange_targets(batch)
        return len(targets)


class Model(BaseModel):
    """
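
The override pattern introduced by this commit can be sketched in isolation as follows. This is a minimal sketch under stated assumptions, not the PyText implementation: `ToyBaseModel` and `ToyLanguageModel` are hypothetical stand-ins, the batch layout is invented for the example, and plain Python lists replace the tensors used in the real code.

```python
from typing import Any, Dict


class ToyBaseModel:
    """Hypothetical stand-in for the base model; not the PyText class."""

    def arrange_targets(self, batch: Dict[str, Any]) -> Any:
        # Assumed layout: one label per row.
        return batch["labels"]

    def get_num_examples_from_batch(self, batch: Dict[str, Any]) -> int:
        # Default rule: one example per row.
        return len(self.arrange_targets(batch))


class ToyLanguageModel(ToyBaseModel):
    """Hypothetical stand-in for an LM; lists replace tensors here."""

    def arrange_targets(self, batch: Dict[str, Any]) -> Any:
        tokens, seq_lens = batch["tokens"]
        # Drop the first token of each row, so each row has one fewer target.
        return [row[1:] for row in tokens], [n - 1 for n in seq_lens]

    def get_num_examples_from_batch(self, batch: Dict[str, Any]) -> int:
        # Override: count target words across all rows, not rows.
        _, target_lens = self.arrange_targets(batch)
        return sum(target_lens)


clf_batch = {"labels": [0, 1, 1]}
# Two rows padded to length 4 with 0; true lengths are 4 and 3.
lm_batch = {"tokens": ([[1, 5, 6, 2], [1, 7, 2, 0]], [4, 3])}

print(ToyBaseModel().get_num_examples_from_batch(clf_batch))     # 3
print(ToyLanguageModel().get_num_examples_from_batch(lm_batch))  # 3 + 2 = 5
```

Summing the per-row target lengths, rather than counting tokens in the padded rows, keeps padding out of the example count.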
