Skip to content

Commit

Permalink
Added explicit encode, combine, decode functions to ECD (#2073)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed May 29, 2022
1 parent 59b2d1b commit c92d891
Showing 1 changed file with 44 additions and 30 deletions.
74 changes: 44 additions & 30 deletions ludwig/models/ecd.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,39 +103,12 @@ def input_shape(self):
# TODO(justin): Remove dummy implementation. Make input_shape and output_shape functions.
return torch.Size([1, 1])

def forward(
def encode(
self,
inputs: Union[
Dict[str, torch.Tensor], Dict[str, np.ndarray], Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
],
mask=None,
) -> Dict[str, torch.Tensor]:
"""Forward pass of the model.
Args:
inputs: Inputs to the model. Can be a dictionary of input names to
input tensors or a tuple of (inputs, targets) where inputs is
a dictionary of input names to input tensors and targets is a
dictionary of target names to target tensors.
mask: A mask for the inputs.
Returns:
A dictionary of output {feature name}::{tensor_name} -> output tensor.
"""

if isinstance(inputs, tuple):
inputs, targets = inputs
# Convert targets to tensors.
for target_feature_name, target_value in targets.items():
if not isinstance(target_value, torch.Tensor):
targets[target_feature_name] = torch.from_numpy(target_value)
else:
targets[target_feature_name] = target_value
else:
targets = None

assert list(inputs.keys()) == self.input_features.keys()

):
# Convert inputs to tensors.
for input_feature_name, input_values in inputs.items():
if not isinstance(input_values, torch.Tensor):
Expand All @@ -149,8 +122,12 @@ def forward(
encoder_output = encoder(input_values)
encoder_outputs[input_feature_name] = encoder_output

combiner_outputs = self.combiner(encoder_outputs)
return encoder_outputs

def combine(self, encoder_outputs):
return self.combiner(encoder_outputs)

def decode(self, combiner_outputs, targets, mask):
# Invoke output features.
output_logits = {}
output_last_hidden = {}
Expand All @@ -169,6 +146,43 @@ def forward(
output_last_hidden[output_feature_name] = decoder_outputs["last_hidden"]
return output_logits

def forward(
self,
inputs: Union[
Dict[str, torch.Tensor], Dict[str, np.ndarray], Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
],
mask=None,
) -> Dict[str, torch.Tensor]:
"""Forward pass of the model.
Args:
inputs: Inputs to the model. Can be a dictionary of input names to
input tensors or a tuple of (inputs, targets) where inputs is
a dictionary of input names to input tensors and targets is a
dictionary of target names to target tensors.
mask: A mask for the inputs.
Returns:
A dictionary of output {feature name}::{tensor_name} -> output tensor.
"""

if isinstance(inputs, tuple):
inputs, targets = inputs
# Convert targets to tensors.
for target_feature_name, target_value in targets.items():
if not isinstance(target_value, torch.Tensor):
targets[target_feature_name] = torch.from_numpy(target_value)
else:
targets[target_feature_name] = target_value
else:
targets = None

assert list(inputs.keys()) == self.input_features.keys()

encoder_outputs = self.encode(inputs)
combiner_outputs = self.combine(encoder_outputs)
return self.decode(combiner_outputs, targets, mask)

def predictions(self, inputs):
outputs = self(inputs)
predictions = {}
Expand Down

0 comments on commit c92d891

Please sign in to comment.