Restructure graph definition #788
```python
sensor_mask=sensor_mask,
string_mask=string_mask,
repeat_labels=repeat_labels,
node_definition=node_definition,  # -> kwargs
```
This just calls your pre_init function, which is a very complicated way of setting a member variable :-D
I agree with you. I am not a fan of the `pre_init` function. However, with the structure the classes have now, it is necessary: the call to `output_feature_names` here can require member-variable assignment specific to e.g. a `GraphDefinition`. But since inheriting from `torch.nn.Module` forces us to do all member-variable assignments after the `super().__init__()` call, this assignment has to happen somewhere between those two calls. If preferred, we could also pull the `output_feature_names` call out of `DataRepresentation.__init__()` and put it in the subclasses to get rid of this lazy fix with the `pre_init` function.
`AttributeError: cannot assign module before Module.__init__() call` is the error raised when member variables are assigned before the `super().__init__()` call. So if I am not mistaken, the fix proposed by you below will not work.
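To make the constraint concrete, here is a minimal standalone sketch (class names `Bad`/`Good` are invented for illustration, not from this PR) of the `nn.Module` assignment rule being discussed:

```python
import torch.nn as nn


class Bad(nn.Module):
    def __init__(self) -> None:
        # Assigning a submodule BEFORE super().__init__() triggers the error
        self.sub = nn.Linear(2, 2)
        super().__init__()


class Good(nn.Module):
    def __init__(self) -> None:
        super().__init__()  # base init first
        self.sub = nn.Linear(2, 2)  # then submodule assignment


try:
    Bad()
except AttributeError as err:
    print(f"Bad: {err}")

print(f"Good: {Good().sub}")
```

Note that plain (non-`Module`, non-`Parameter`) attributes can be set before `super().__init__()`; the error is specific to module- and parameter-typed members, which is why the chicken-and-egg problem only bites for that kind of state.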
Hi @sevmag - thanks for the details. I think this little chicken-and-egg inheritance challenge can be solved quite conveniently with the `@property` decorator. I made a little MWE that omits the pre-init method by replacing the explicit member variables with properties. This modification of `DataRepresentation` appears to work:
```python
class DataRepresentation(Model):
    """An abstract class to create data representations from."""

    def __init__(
        self,
        detector: Detector,
        input_feature_names: Optional[List[str]] = None,
        dtype: Optional[torch.dtype] = torch.float,
        perturbation_dict: Optional[Dict[str, float]] = None,
        seed: Optional[Union[int, Generator]] = None,
        add_inactive_sensors: bool = False,
        sensor_mask: Optional[List[int]] = None,
        string_mask: Optional[List[int]] = None,
        repeat_labels: bool = False,
    ):
        """Construct ´DataRepresentation´.

        The ´detector´ holds ´Detector´-specific code, e.g.
        scaling/standardization and geometry tables.

        Args:
            detector: The corresponding ´Detector´ representing the data.
            input_feature_names: Names of each column in the expected input
                data that will be built into the data. If not provided, it is
                assumed that all features in ´Detector´ are used.
            dtype: Data type used for features, e.g. ´torch.float´.
            perturbation_dict: Dictionary mapping a feature name to a
                standard deviation according to which the values for this
                feature should be randomly perturbed. Defaults to None.
            seed: Seed or ´Generator´ used to randomly sample perturbations.
                Defaults to None.
            add_inactive_sensors: If True, inactive sensors will be appended
                to the data with padded pulse information. Defaults to False.
            sensor_mask: A list of sensor ids to be masked from the data.
                Any sensor listed here will be removed from the data.
                Defaults to None.
            string_mask: A list of string ids to be masked from the data.
                Defaults to None.
            repeat_labels: If True, labels will be repeated to match the
                number of rows in the output of the GraphDefinition.
                Defaults to False.
        """
        # Base class constructor
        super().__init__(name=__name__, class_name=self.__class__.__name__)

        # Member variables
        self._detector = detector
        self._perturbation_dict = perturbation_dict
        self._sensor_mask = sensor_mask
        self._string_mask = string_mask
        self._add_inactive_sensors = add_inactive_sensors
        self._repeat_labels = repeat_labels
        self._resolve_masks()
        if input_feature_names is None:
            # Assume all features in Detector are used.
            input_feature_names = list(self._detector.feature_map().keys())
        self._input_feature_names = input_feature_names

        # Set data type
        self.to(dtype)

        # Set perturbation_cols if needed
        if isinstance(self._perturbation_dict, dict):
            self._perturbation_cols = [
                self._input_feature_names.index(key)
                for key in self._perturbation_dict.keys()
            ]
        if seed is not None:
            if isinstance(seed, int):
                self.rng = default_rng(seed)
            elif isinstance(seed, Generator):
                self.rng = seed
            else:
                raise ValueError(
                    "Invalid seed. Must be an int or a numpy Generator."
                )
        else:
            self.rng = default_rng()

    @property
    def nb_inputs(self) -> int:
        return len(self._input_feature_names)

    @property
    def nb_outputs(self) -> int:
        if not hasattr(self, "output_feature_names"):
            feats = self._set_output_feature_names(self._input_feature_names)
            self.output_feature_names = feats
        return len(self.output_feature_names)
```

I'll attach the MWE for you on slack. Could you check if that works for you?
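As a sanity check, the lazy-`@property` idea can be reduced to a tiny standalone sketch (class and method names here are invented stand-ins, not the actual graphnet API): the derived attribute is computed on first access, safely after `__init__` has completed:

```python
from typing import List


class TinyRep:
    """Stand-in for DataRepresentation to illustrate the lazy property."""

    def __init__(self, input_feature_names: List[str]) -> None:
        self._input_feature_names = input_feature_names
        # Note: no pre_init hook and no eager output_feature_names call here.

    def _set_output_feature_names(self, names: List[str]) -> List[str]:
        # A subclass would override this with representation-specific logic.
        return names + ["derived_feature"]

    @property
    def nb_outputs(self) -> int:
        # Computed lazily on first access, then cached on the instance.
        if not hasattr(self, "output_feature_names"):
            self.output_feature_names = self._set_output_feature_names(
                self._input_feature_names
            )
        return len(self.output_feature_names)


rep = TinyRep(["dom_x", "dom_y"])
print(rep.nb_outputs)  # 3
```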
Hi @sevmag - thanks for this very clean pull request!
Considering the magnitude of the change, I think we should have more eyes on it. Maybe @Aske-Rosted and/or @giogiopg has time to give this a look?
I think the restructuring looks great. I have just one comment:
The readability of the subclasses of `DataRepresentation` would be much improved if we could omit the `_pre_init` and `_forward_end` methods and move the `forward` function into the subclass, i.e. if the structure was:
```python
class MyRep(DataRepresentation):
    def __init__(self, data_rep_args, my_args):
        # My special init things
        self._my_member_variable = my_args
        # Base inheritance
        super().__init__(data_rep_args)
        # My special post-init things
        self._other_member_variable = ...

    def forward(self, *args, **kwargs):  # This describes exactly what my rep does
        # My special beginning of forward pass
        mystery = ...
        # Base processing
        data = super().forward(*args, **kwargs)  # Calls DataRepresentation.forward
        # My special post-processing
        data["odd_thing"] = 1e-8
        return data
```

By having the forward pass directly in the subclass, it becomes much easier to see exactly how a specific data representation processes data. What do you think @sevmag?
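The structure sketched above can be exercised with a runnable toy (plain classes standing in for `DataRepresentation`, no graphnet imports; all names here are illustrative):

```python
from typing import Any, Dict, List


class BaseRep:
    """Stand-in for DataRepresentation: base processing only."""

    def forward(self, x: List[float]) -> Dict[str, Any]:
        return {"x": x}


class MyRep(BaseRep):
    def forward(self, x: List[float]) -> Dict[str, Any]:
        # Base processing
        data = super().forward(x)
        # Subclass-specific post-processing, visible right here in forward
        data["odd_thing"] = 1e-8
        return data


print(MyRep().forward([1.0, 2.0]))  # {'x': [1.0, 2.0], 'odd_thing': 1e-08}
```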
@RasmusOrsoe Thanks for the feedback. I agree that not having the `_pre_init` function would be an improvement.

I adjusted bits and pieces following the recommendations from @RasmusOrsoe. Next to changing to properties, I also adjusted the forward pass. This way, we got rid of the `_pre_init` function. Let me know what you think!
```python
self._node_definition.set_output_feature_names(input_feature_names)
return self._node_definition._output_feature_names
```

```python
def _create_data(
```
Let's get rid of `_create_data` in the `DataRepresentation` and move the functionality into the forward pass, i.e.:
```python
def forward(  # type: ignore
    self,
    input_features: np.ndarray,
    input_feature_names: List[str],
    truth_dicts: Optional[List[Dict[str, Any]]] = None,
    custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
    loss_weight_column: Optional[str] = None,
    loss_weight: Optional[float] = None,
    loss_weight_default_value: Optional[float] = None,
    data_path: Optional[str] = None,
) -> Data:
    """Construct graph as ´Data´ object.

    Args:
        input_features: Input features for graph construction.
            Shape ´[num_rows, d]´.
        input_feature_names: Name of each column. Shape ´[, d]´.
        truth_dicts: Dictionary containing truth labels.
        custom_label_functions: Custom label functions.
        loss_weight_column: Name of column that holds loss weight.
            Defaults to None.
        loss_weight: Loss weight associated with event. Defaults to None.
        loss_weight_default_value: Default value for loss weight. Used in
            instances where some events have no pre-defined loss weight.
            Defaults to None.
        data_path: Path to dataset data files. Defaults to None.

    Returns:
        graph
    """
    # Base class call - get bare Data object
    data = super().forward(
        input_features=input_features,
        input_feature_names=input_feature_names,
        truth_dicts=truth_dicts,
        custom_label_functions=custom_label_functions,
        loss_weight_column=loss_weight_column,
        loss_weight=loss_weight,
        loss_weight_default_value=loss_weight_default_value,
        data_path=data_path,
    )
    # Create graph & get new node feature names
    data = self._node_definition(data.x)
    if self._sort_by is not None:
        data.x = data.x[data.x[:, self._sort_by].sort()[1]]
    # Assign edges
    if self._edge_definition is not None:
        data = self._edge_definition(data)
    if self._add_static_features:
        data = self._add_features_individually(
            data,
            self.output_feature_names,
        )
    return data
```
RasmusOrsoe left a comment
Hey @sevmag!
I think this looks great. Upon looking at the changes again, I noticed the abstract method of `DataRepresentation` called `_create_data`. I think we should consider removing this abstract method to decrease the complexity of the data flow. Instead, we can replace this method in `DataRepresentation` with

```python
# Create data & get new final data feature names
data = Data(x=input_features)  # instead of data = self._create_data(..)
```

When subsequent representations inherit from it, they then have the same starting point, i.e. in the forward pass of `GraphDefinition` we would then have:
```python
def forward(  # type: ignore
    self,
    input_features: np.ndarray,
    input_feature_names: List[str],
    truth_dicts: Optional[List[Dict[str, Any]]] = None,
    custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
    loss_weight_column: Optional[str] = None,
    loss_weight: Optional[float] = None,
    loss_weight_default_value: Optional[float] = None,
    data_path: Optional[str] = None,
) -> Data:
    """Construct graph as ´Data´ object.

    Args:
        input_features: Input features for graph construction.
            Shape ´[num_rows, d]´.
        input_feature_names: Name of each column. Shape ´[, d]´.
        truth_dicts: Dictionary containing truth labels.
        custom_label_functions: Custom label functions.
        loss_weight_column: Name of column that holds loss weight.
            Defaults to None.
        loss_weight: Loss weight associated with event. Defaults to None.
        loss_weight_default_value: Default value for loss weight. Used in
            instances where some events have no pre-defined loss weight.
            Defaults to None.
        data_path: Path to dataset data files. Defaults to None.

    Returns:
        graph
    """
    # Base class call - get bare Data object
    data = super().forward(
        input_features=input_features,
        input_feature_names=input_feature_names,
        truth_dicts=truth_dicts,
        custom_label_functions=custom_label_functions,
        loss_weight_column=loss_weight_column,
        loss_weight=loss_weight,
        loss_weight_default_value=loss_weight_default_value,
        data_path=data_path,
    )
    # Create graph & get new node feature names
    data = self._node_definition(data.x)
    if self._sort_by is not None:
        data.x = data.x[data.x[:, self._sort_by].sort()[1]]
    # Assign edges
    if self._edge_definition is not None:
        data = self._edge_definition(data)
    if self._add_static_features:
        data = self._add_features_individually(
            data,
            self.output_feature_names,
        )
    return data
```
```python
return data
```

```python
def _forward_end(
```
I don't think this function is needed. Am I missing something?
@sevmag thanks for the changes! At this stage I think we're good to go! @Aske-Rosted / @giogiopg do you have thoughts on this?
Aske-Rosted left a comment
Looks good to me, found two very small things.
```diff
 from graphnet.utilities.decorators import final
 from graphnet.models import Model
-from graphnet.models.graphs.utils import (
+from ..utils import (
```
Is it not more readable to have the absolute path?
```python
raise NotImplementedError

return graph

def _label_repeater(self, label: torch.Tensor, data: Data) -> torch.Tensor:
```
Should this have the abstractmethod decorator?
I do not think having this function in all subclasses of the `DataRepresentation` class is necessary. For example, I would not know which logic to follow when repeating the label for a 3D image. However, I also do not fully understand the reason for this functionality in the first place, so I may be missing something.

But suppose we did not use it for an image definition or something similar. In that case, it does not make sense to add an `abstractmethod` decorator, since it would force us to define an unnecessary function.

It would be nicer to push this function into `GraphDefinition`, but since it is called by `_add_custom_labels` and `_add_truth`, I think that would be too messy.
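For context, a graph-style label repeater boils down to something like this sketch (plain lists instead of tensors; the function name is invented), which shows why the operation is natural per-node but ambiguous for a fixed-size image representation:

```python
from typing import List


def repeat_label(label: float, num_rows: int) -> List[float]:
    # One copy of the per-event label for every node/row in the
    # representation; a fixed-size image has no such row count to match.
    return [label] * num_rows


print(repeat_label(1.0, 3))  # [1.0, 1.0, 1.0]
```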
An abstract base class called `DataRepresentation` is introduced for data representations. This step is necessary for data preprocessing for models that do not use graph data representations for training, like a CNN, which is interesting to Graphnet (see #771).

Changes:

- `GraphDefinition` is restructured and built on the new base class `DataRepresentation`, keeping its functionality as before
- The code is moved to `src/models`, keeping the skeleton of the `src/models/graphs/*` directories with adjusted `__init__.py` files, which allow for legacy imports (plus added warnings for these imports)
- `GraphDefinition` is reworked to use the more general `DataRepresentation`, keeping all necessary `GraphDefinition` arguments and class member attributes, with deprecation warnings
- A new field `"data_representation"` is added, keeping the old field `"graph_definition"`

Files affected by the deprecation:
- `src/graphnet/data/curated_datamodule.py`
- `src/graphnet/data/dataset/dataset.py`
- `src/graphnet/data/dataset/parquet/parquet_dataset.py`
- `src/graphnet/models/standard_averaged_model.py`
- `src/graphnet/models/standard_model.py`
- `src/graphnet/utilities/config/dataset_config.py`
- `src/models/graphs`

Let me know if I missed something!
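The legacy-import behaviour described above can be sketched like this self-contained stand-in (the real shim lives in the adjusted `__init__.py` files; the function and message here are illustrative, and a real module would expose this as a PEP 562 module-level `__getattr__`):

```python
import warnings


class GraphDefinition:
    """Stand-in for the real class at its new location."""


def _legacy_getattr(name: str):
    """Sketch of a module-level __getattr__ shim for legacy imports."""
    if name == "GraphDefinition":
        warnings.warn(
            "Importing GraphDefinition from the old graphs module is "
            "deprecated; import it from the new data representation "
            "module instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return GraphDefinition
    raise AttributeError(name)


# Demonstrate: the legacy name still resolves, but emits a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cls = _legacy_getattr("GraphDefinition")

print(cls.__name__, caught[0].category.__name__)
```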