This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Enable dense features in ByteTokensDocumentModel #763

Closed
wants to merge 1 commit into from
44 changes: 40 additions & 4 deletions pytext/models/doc_model.py
@@ -225,7 +225,9 @@ def create_embedding(cls, config, tensorizers: Dict[str, Tensorizer]):
assert word_tensorizer.column == byte_tensorizer.column

word_embedding = create_module(
config.embedding, tensorizer=tensorizers["tokens"]
config.embedding,
tensorizer=tensorizers["tokens"],
init_from_saved_state=config.init_from_saved_state,
)
byte_embedding = CharacterEmbedding(
ByteTokenTensorizer.NUM_BYTES,
@@ -241,10 +243,16 @@ def arrange_model_inputs(self, tensor_dict):
tokens, seq_lens, _ = tensor_dict["tokens"]
token_bytes, byte_seq_lens, _ = tensor_dict["token_bytes"]
assert (seq_lens == byte_seq_lens).all().item()
return tokens, token_bytes, seq_lens
model_inputs = tokens, token_bytes, seq_lens
if "dense" in tensor_dict:
model_inputs += (tensor_dict["dense"],)
return model_inputs

def get_export_input_names(self, tensorizers):
return ["tokens", "token_bytes", "tokens_lens"]
names = ["tokens", "token_bytes", "tokens_lens"]
if "dense" in tensorizers:
names.append("float_vec_vals")
return names

def torchscriptify(self, tensorizers, traced_model):
output_layer = self.output_layer.torchscript_predictions()
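
With this hunk, arrange_model_inputs appends the dense-feature tensor to the model inputs only when a "dense" tensorizer is configured, and get_export_input_names advertises the extra float_vec_vals input for export; configs without dense features keep the original three inputs. A rough, hedged illustration of the expected input shapes (the batch size, sequence lengths, byte width, and dense width below are invented for the example):

import torch

# Hypothetical batch of 2 documents, up to 3 tokens each (shapes are illustrative only).
tensor_dict = {
    "tokens": (torch.randint(0, 100, (2, 3)), torch.tensor([3, 2]), None),
    "token_bytes": (torch.randint(0, 256, (2, 3, 15)), torch.tensor([3, 2]), None),
    "dense": torch.rand(2, 4),  # e.g. 4 hand-engineered float features per document
}

# arrange_model_inputs would return (tokens, token_bytes, seq_lens) for this dict
# without the "dense" entry, and (tokens, token_bytes, seq_lens, dense) with it,
# so existing ByteTokensDocumentModel configs are unaffected.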
@@ -277,7 +285,35 @@ def forward(self, tokens: List[List[str]]):
)
return self.output_layer(logits)

return Model()
class ModelWithDenseFeat(jit.ScriptModule):
def __init__(self):
super().__init__()
self.vocab = Vocabulary(input_vocab, unk_idx=input_vocab.idx[UNK])
self.max_byte_len = jit.Attribute(max_byte_len, int)
self.byte_offset_for_non_padding = jit.Attribute(
byte_offset_for_non_padding, int
)
self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
self.model = traced_model
self.output_layer = output_layer

@jit.script_method
def forward(self, tokens: List[List[str]], dense_feat: List[List[float]]):
seq_lens = make_sequence_lengths(tokens)
word_ids = self.vocab.lookup_indices_2d(tokens)
word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
token_bytes, _ = make_byte_inputs(
tokens, self.max_byte_len, self.byte_offset_for_non_padding
)
logits = self.model(
torch.tensor(word_ids),
token_bytes,
torch.tensor(seq_lens),
torch.tensor(dense_feat),
)
return self.output_layer(logits)

return ModelWithDenseFeat() if "dense" in tensorizers else Model()


class DocRegressionModel(DocModel):
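
After torchscriptify, the exported module for a dense-feature config is ModelWithDenseFeat, whose forward takes the raw token strings plus one list of floats per document. A minimal usage sketch, assuming the script module was saved to disk and that the dense width matches training (the file name and the width of 4 are made up):

import torch

# Load the TorchScript artifact produced by torchscriptify (path is hypothetical).
scripted = torch.jit.load("byte_tokens_doc_model.pt")

tokens = [["the", "movie", "was", "great"], ["terrible"]]
dense_feat = [[0.2, 1.0, 0.0, 3.5], [0.9, 0.0, 1.0, 0.1]]  # one float vector per document

# ModelWithDenseFeat pads the word ids, builds the byte inputs, and passes the
# dense features through to the traced model before applying the output layer.
scores = scripted(tokens, dense_feat)

Models exported without a "dense" tensorizer keep the original Model interface, so callers pass only tokens.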