From c92d8919bd66ad364721b9f808cd8dac8644064d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sat, 28 May 2022 21:34:01 -0700 Subject: [PATCH] Added explicit encode, combine, decode functions to ECD (#2073) --- ludwig/models/ecd.py | 74 ++++++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/ludwig/models/ecd.py b/ludwig/models/ecd.py index f29ef7ddd8f..a1a09230bee 100644 --- a/ludwig/models/ecd.py +++ b/ludwig/models/ecd.py @@ -103,39 +103,12 @@ def input_shape(self): # TODO(justin): Remove dummy implementation. Make input_shape and output_shape functions. return torch.Size([1, 1]) - def forward( + def encode( self, inputs: Union[ Dict[str, torch.Tensor], Dict[str, np.ndarray], Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]] ], - mask=None, - ) -> Dict[str, torch.Tensor]: - """Forward pass of the model. - - Args: - inputs: Inputs to the model. Can be a dictionary of input names to - input tensors or a tuple of (inputs, targets) where inputs is - a dictionary of input names to input tensors and targets is a - dictionary of target names to target tensors. - mask: A mask for the inputs. - - Returns: - A dictionary of output {feature name}::{tensor_name} -> output tensor. - """ - - if isinstance(inputs, tuple): - inputs, targets = inputs - # Convert targets to tensors. - for target_feature_name, target_value in targets.items(): - if not isinstance(target_value, torch.Tensor): - targets[target_feature_name] = torch.from_numpy(target_value) - else: - targets[target_feature_name] = target_value - else: - targets = None - - assert list(inputs.keys()) == self.input_features.keys() - + ): # Convert inputs to tensors. for input_feature_name, input_values in inputs.items(): if not isinstance(input_values, torch.Tensor): @@ -149,8 +122,12 @@ def forward( encoder_output = encoder(input_values) encoder_outputs[input_feature_name] = encoder_output - combiner_outputs = self.combiner(encoder_outputs) + return encoder_outputs + + def combine(self, encoder_outputs): + return self.combiner(encoder_outputs) + def decode(self, combiner_outputs, targets, mask): # Invoke output features. output_logits = {} output_last_hidden = {} @@ -169,6 +146,43 @@ def forward( output_last_hidden[output_feature_name] = decoder_outputs["last_hidden"] return output_logits + def forward( + self, + inputs: Union[ + Dict[str, torch.Tensor], Dict[str, np.ndarray], Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]] + ], + mask=None, + ) -> Dict[str, torch.Tensor]: + """Forward pass of the model. + + Args: + inputs: Inputs to the model. Can be a dictionary of input names to + input tensors or a tuple of (inputs, targets) where inputs is + a dictionary of input names to input tensors and targets is a + dictionary of target names to target tensors. + mask: A mask for the inputs. + + Returns: + A dictionary of output {feature name}::{tensor_name} -> output tensor. + """ + + if isinstance(inputs, tuple): + inputs, targets = inputs + # Convert targets to tensors. + for target_feature_name, target_value in targets.items(): + if not isinstance(target_value, torch.Tensor): + targets[target_feature_name] = torch.from_numpy(target_value) + else: + targets[target_feature_name] = target_value + else: + targets = None + + assert list(inputs.keys()) == self.input_features.keys() + + encoder_outputs = self.encode(inputs) + combiner_outputs = self.combine(encoder_outputs) + return self.decode(combiner_outputs, targets, mask) + def predictions(self, inputs): outputs = self(inputs) predictions = {}