Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Support token inputs from both sides to two-tower document classifier (#1454)
Browse files · Browse the repository at this point in the history

Summary:
Pull Request resolved: #1454

Added a left-side encoder to the two-tower document classifier. It now supports passing text token features from both sides.

Reviewed By: sinonwang

Differential Revision: D23487446

fbshipit-source-id: a178238e42311e86d1ad949621ef27eee15f34b7
  • Loading branch information
HannaMao authored and facebook-github-bot committed Sep 9, 2020
1 parent 5f2fbfb commit 9902f48
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 63 deletions.
14 changes: 9 additions & 5 deletions pytext/models/decoders/mlp_decoder_two_tower.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,21 @@ def get_mlp(
return nn.Sequential(*layers)

def forward(self, *x: List[torch.Tensor]) -> torch.Tensor:
    """Fuse each tower's text embedding with its dense features and decode.

    Expected positional inputs:
        x[0]: right_text_emb, x[1]: left_text_emb,
        x[2]: right_dense,    x[3]: left_dense

    Returns:
        Logits from the final MLP applied to the concatenated right/left
        tower outputs.
    """
    assert len(x) == 4

    # Concatenate text embedding + dense features along the feature dim
    # once, then convert dtype — avoids duplicating the torch.cat call in
    # both branches of the FP16/FP32 precision switch.
    right_cat = torch.cat((x[0], x[2]), 1)
    right_tensor = right_cat.half() if precision.FP16_ENABLED else right_cat.float()
    right_output = self.mlp_for_right(right_tensor)

    left_cat = torch.cat((x[1], x[3]), 1)
    left_tensor = left_cat.half() if precision.FP16_ENABLED else left_cat.float()
    left_output = self.mlp_for_left(left_tensor)

    return self.mlp(torch.cat((right_output, left_output), 1))
Expand Down
93 changes: 51 additions & 42 deletions pytext/models/two_tower_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,54 +39,45 @@ class TwoTowerClassificationModel(BaseModel):

class Config(BaseModel.Config):
    """Config for the two-tower classifier: one RoBERTa tensorizer/encoder
    per side plus optional per-side dense features."""

    class InputConfig(ConfigBase):
        # Token features for each tower; each side gets its own tensorizer.
        right_tokens: RoBERTaTensorizer.Config = RoBERTaTensorizer.Config()
        left_tokens: RoBERTaTensorizer.Config = RoBERTaTensorizer.Config()
        # Dense (float-list) features per side; None disables them in config,
        # though torchscriptify expects both to be present at export time.
        right_dense: FloatListTensorizer.Config = None
        left_dense: FloatListTensorizer.Config = None

        labels: LabelTensorizer.Config = LabelTensorizer.Config()

    inputs: InputConfig = InputConfig()
    # Independent encoder per tower (weights are not shared).
    right_encoder: RoBERTaEncoderBase.Config = RoBERTaEncoder.Config()
    left_encoder: RoBERTaEncoderBase.Config = RoBERTaEncoder.Config()
    decoder: MLPDecoderTwoTower.Config = MLPDecoderTwoTower.Config()
    output_layer: ClassificationOutputLayer.Config = (
        ClassificationOutputLayer.Config()
    )

def trace(self, inputs):
    """Trace the whole model (both encoders + decoder) with TorchScript.

    The previous encoder-only export branch (`self.encoder.export_encoder`)
    no longer applies with two encoders, so the full model is always traced.
    """
    return torch.jit.trace(self, inputs)

def torchscriptify(self, tensorizers, traced_model):
    """Using the traced model, create a ScriptModule which has a nicer API that
    includes generating tensors from simple data types, and returns classified
    values according to the output layer (eg. as a dict mapping class name to score).

    Builds one scripted tensorizer per tower and wires in the per-side dense
    normalizers; "right_dense"/"left_dense" are assumed present in
    `tensorizers` — TODO confirm against callers.
    """
    right_script_tensorizer = tensorizers["right_tokens"].torchscriptify()
    left_script_tensorizer = tensorizers["left_tokens"].torchscriptify()

    return ScriptPyTextTwoTowerModuleWithDense(
        model=traced_model,
        output_layer=self.output_layer.torchscript_predictions(),
        right_tensorizer=right_script_tensorizer,
        left_tensorizer=left_script_tensorizer,
        right_normalizer=tensorizers["right_dense"].normalizer,
        left_normalizer=tensorizers["left_dense"].normalizer,
    )

def arrange_model_inputs(self, tensor_dict):
model_inputs = (
tensor_dict["tokens"],
tensor_dict["right_tokens"],
tensor_dict["left_tokens"],
tensor_dict["right_dense"],
tensor_dict["left_dense"],
)
def arrange_targets(self, tensor_dict):
    """Pull the label tensor out of the tensorized batch for loss computation."""
    return tensor_dict["labels"]

def forward(
    self,
    right_encoder_inputs: Tuple[torch.Tensor, ...],
    left_encoder_inputs: Tuple[torch.Tensor, ...],
    *args
) -> List[torch.Tensor]:
    """Encode each side with its own tower, then decode to logits.

    Args:
        right_encoder_inputs: tensor tuple fed to the right encoder.
        left_encoder_inputs: tensor tuple fed to the left encoder.
        *args: extra decoder inputs (right_dense, left_dense tensors,
            per arrange_model_inputs ordering).
    """
    # When an encoder also returns per-layer encodings, the pooled
    # representation is element [1]; otherwise it is element [0].
    if self.right_encoder.output_encoded_layers:
        # if encoded layers are returned, discard them
        right_representation = self.right_encoder(right_encoder_inputs)[1]
    else:
        right_representation = self.right_encoder(right_encoder_inputs)[0]
    if self.left_encoder.output_encoded_layers:
        # if encoded layers are returned, discard them
        left_representation = self.left_encoder(left_encoder_inputs)[1]
    else:
        left_representation = self.left_encoder(left_encoder_inputs)[0]
    return self.decoder(right_representation, left_representation, *args)

def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
    """Caffe2/ONNX export is not supported for the two-tower model."""
    raise NotImplementedError
Expand All @@ -115,18 +114,26 @@ def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
if not labels:
raise ValueError("Labels were not created, see preceding errors")

vocab = tensorizers["tokens"].vocab
encoder = create_module(
config.encoder, padding_idx=vocab.get_pad_index(), vocab_size=len(vocab)
right_vocab = tensorizers["right_tokens"].vocab
right_encoder = create_module(
config.right_encoder,
padding_idx=right_vocab.get_pad_index(),
vocab_size=len(right_vocab),
)
left_vocab = tensorizers["left_tokens"].vocab
left_encoder = create_module(
config.left_encoder,
padding_idx=left_vocab.get_pad_index(),
vocab_size=len(left_vocab),
)

right_dense_dim = tensorizers["right_dense"].dim
left_dense_dim = tensorizers["left_dense"].dim

decoder = create_module(
config.decoder,
right_dim=encoder.representation_dim + right_dense_dim,
left_dim=left_dense_dim,
right_dim=right_encoder.representation_dim + right_dense_dim,
left_dim=left_encoder.representation_dim + left_dense_dim,
to_dim=len(labels),
)

Expand All @@ -146,14 +153,16 @@ def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
output_layer_cls = MulticlassOutputLayer

output_layer = output_layer_cls(list(labels), loss)
return cls(encoder, decoder, output_layer)
return cls(right_encoder, left_encoder, decoder, output_layer)

def __init__(
    self, right_encoder, left_encoder, decoder, output_layer, stage=Stage.TRAIN
) -> None:
    """Wire both encoder towers, the two-tower decoder and the output layer.

    Args:
        right_encoder / left_encoder: one encoder module per tower.
        decoder: MLP decoder combining both towers' representations.
        output_layer: classification output layer.
        stage: training stage (defaults to Stage.TRAIN).
    """
    super().__init__(stage=stage)
    self.right_encoder = right_encoder
    self.left_encoder = left_encoder
    self.decoder = decoder
    # All trainable submodules, so freezing/optimizer logic sees both towers.
    # NOTE: the commit left a stale trailing `self.module_list =
    # [encoder, decoder]` assignment after this one; `encoder` is no longer
    # a parameter, so that line would raise NameError and clobber the
    # correct list — it is removed here.
    self.module_list = [right_encoder, left_encoder, decoder]
    self.output_layer = output_layer
    self.stage = stage
    log_class_usage(__class__)
92 changes: 76 additions & 16 deletions pytext/torchscript/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,63 @@ def forward(
return self.output_layer(logits)


class ScriptTwoTowerModule(torch.jit.ScriptModule):
    """Base class for scripted two-tower modules.

    Assumes subclasses define `right_tensorizer` and `left_tensorizer`
    attributes; `set_device` forwards the device to both of them.
    """

    @torch.jit.script_method
    def set_device(self, device: str):
        self.right_tensorizer.set_device(device)
        self.left_tensorizer.set_device(device)

class ScriptPyTextTwoTowerModule(ScriptTwoTowerModule):
    """Scripted two-tower module: tensorizes right/left text inputs
    separately, runs the traced model, and applies the output layer."""

    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
    ):
        super().__init__()
        self.model = model
        self.output_layer = output_layer
        self.right_tensorizer = right_tensorizer
        self.left_tensorizer = left_tensorizer

    @torch.jit.script_method
    def forward(
        self,
        right_texts: Optional[List[str]] = None,
        left_texts: Optional[List[str]] = None,
        right_tokens: Optional[List[List[str]]] = None,
        left_tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        # Each side is tensorized independently; `languages` is shared
        # across both towers.
        right_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(right_texts),
            tokens=squeeze_2d(right_tokens),
            languages=squeeze_1d(languages),
        )
        right_input_tensors = self.right_tensorizer(right_inputs)
        left_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(left_texts),
            tokens=squeeze_2d(left_tokens),
            languages=squeeze_1d(languages),
        )
        left_input_tensors = self.left_tensorizer(left_inputs)
        logits = self.model(right_input_tensors, left_input_tensors)
        return self.output_layer(logits)


class ScriptPyTextTwoTowerModuleWithDense(ScriptPyTextTwoTowerModule):
    """Two-tower scripted module that additionally normalizes and feeds
    per-side dense float features to the model."""

    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
        right_normalizer: VectorNormalizer,
        left_normalizer: VectorNormalizer,
    ):
        super().__init__(model, output_layer, right_tensorizer, left_tensorizer)
        self.right_normalizer = right_normalizer
        self.left_normalizer = left_normalizer

    @torch.jit.script_method
    def forward(
        self,
        right_dense_feat: List[List[float]],
        left_dense_feat: List[List[float]],
        right_texts: Optional[List[str]] = None,
        left_texts: Optional[List[str]] = None,
        right_tokens: Optional[List[List[str]]] = None,
        left_tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        # Tensorize each side's text inputs independently.
        right_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(right_texts),
            tokens=squeeze_2d(right_tokens),
            languages=squeeze_1d(languages),
        )
        right_input_tensors = self.right_tensorizer(right_inputs)
        left_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(left_texts),
            tokens=squeeze_2d(left_tokens),
            languages=squeeze_1d(languages),
        )
        left_input_tensors = self.left_tensorizer(left_inputs)

        # Normalize dense features per side, then move them to whatever
        # device each tensorizer was placed on (empty string = default).
        right_dense_feat = self.right_normalizer.normalize(right_dense_feat)
        left_dense_feat = self.left_normalizer.normalize(left_dense_feat)

        right_dense_tensor = torch.tensor(right_dense_feat, dtype=torch.float)
        left_dense_tensor = torch.tensor(left_dense_feat, dtype=torch.float)
        if self.right_tensorizer.device != "":
            right_dense_tensor = right_dense_tensor.to(self.right_tensorizer.device)
        if self.left_tensorizer.device != "":
            left_dense_tensor = left_dense_tensor.to(self.left_tensorizer.device)

        logits = self.model(
            right_input_tensors,
            left_input_tensors,
            right_dense_tensor,
            left_dense_tensor,
        )
        return self.output_layer(logits)


Expand Down

0 comments on commit 9902f48

Please sign in to comment.