Restructure graph definition#788

Merged
sevmag merged 15 commits into graphnet-team:main from sevmag:restructure-Graph_Definition
Apr 14, 2025
Conversation

sevmag (Collaborator) commented Feb 27, 2025

This PR introduces an abstract base class, DataRepresentation, for data representations. It is a prerequisite for preprocessing data for models that do not train on graph representations, such as CNNs, which are of interest to GraphNet (see #771).

Changes:

  • an abstract base class called DataRepresentation
  • the existing GraphDefinition is restructured on top of the new base class, keeping its previous functionality
  • the folders in src/models are restructured; the skeleton of the src/models/graphs/* directories is kept, with adjusted __init__.py files that allow legacy imports (and emit deprecation warnings for them)
  • classes using GraphDefinition are reworked to use the more general DataRepresentation, keeping all necessary GraphDefinition arguments and class member attributes with deprecation warnings
  • a new field in the dataset config called "data_representation", while the old field "graph_definition" is kept

Files affected by the Deprecation:

  • src/graphnet/data/curated_datamodule.py
  • src/graphnet/data/dataset/dataset.py
  • src/graphnet/data/dataset/parquet/parquet_dataset.py
  • src/graphnet/models/standard_averaged_model.py
  • src/graphnet/models/standard_model.py
  • src/graphnet/utilities/config/dataset_config.py
  • directory: src/models/graphs

Let me know if I missed something!

sevmag requested a review from RasmusOrsoe February 27, 2025 12:53
sensor_mask=sensor_mask,
string_mask=string_mask,
repeat_labels=repeat_labels,
node_definition=node_definition, # -> kwargs
Collaborator:
This just calls your pre_init function, which is a very complicated way of setting a member variable :-D

Collaborator Author:
I agree with you. I am not a fan of the pre_init function. However, with the structure the classes currently have, it is necessary, since the call to output_feature_names here

# Set final data column names
self.output_feature_names = self._set_output_feature_names(
self._input_feature_names
)

can require member-variable assignments specific to, e.g., a GraphDefinition. But since inheriting from torch.nn.Module forces us to do all member-variable assignments after the super().__init__() call here

super().__init__(name=__name__, class_name=self.__class__.__name__)

this assignment has to happen somewhere between these two calls. If preferred, we can also pull the output_feature_names call out of DataRepresentation.__init__() and put it in the subclasses, to get rid of this lazy fix with the pre_init function.

Collaborator Author:
The error raised when member variables are assigned before the super().__init__() call is: AttributeError: cannot assign module before Module.__init__() call. So, if I am not mistaken, the fix you proposed below will not work.
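For reference, the constraint being discussed can be reproduced in a few lines of PyTorch (class names here are made up for the demonstration):

```python
# Demonstrates that torch.nn.Module rejects submodule assignment before
# Module.__init__() has run, which is the error quoted above.
import torch


class BadOrder(torch.nn.Module):
    def __init__(self):
        # Assigning a submodule BEFORE super().__init__() raises
        # AttributeError, because Module's internal registries
        # (_parameters, _modules, ...) do not exist yet.
        self.layer = torch.nn.Linear(2, 2)
        super().__init__()


class GoodOrder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)  # after super().__init__(): fine
```

Instantiating BadOrder raises the AttributeError, while GoodOrder constructs without issue.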

Collaborator:
Hi @sevmag - thanks for the details. I think this little chicken-and-egg inheritance challenge can be solved quite conveniently with the @property decorator. I made a little MWE that omits the pre_init method by replacing the explicit member variables with properties. This modification of DataRepresentation appears to work.

class DataRepresentation(Model):
    """An Abstract class to create data representations from."""

    def __init__(
        self,
        detector: Detector,
        input_feature_names: Optional[List[str]] = None,
        dtype: Optional[torch.dtype] = torch.float,
        perturbation_dict: Optional[Dict[str, float]] = None,
        seed: Optional[Union[int, Generator]] = None,
        add_inactive_sensors: bool = False,
        sensor_mask: Optional[List[int]] = None,
        string_mask: Optional[List[int]] = None,
        repeat_labels: bool = False,
    ):
        """Construct ´DataRepresentation´.

        The ´detector´ holds ´Detector´-specific code, e.g.
        scaling/standardization and geometry tables.

        Args:
            detector: The corresponding ´Detector´ representing the data.
            input_feature_names: Names of each column in the expected input
                data that will be built into the data representation. If not
                provided, it is assumed that all features in the ´Detector´
                are used.
            dtype: data type used for features. e.g. ´torch.float´
            perturbation_dict: Dictionary mapping a feature name to a standard
                               deviation according to which the values for this
                               feature should be randomly perturbed. Defaults
                               to None.
            seed: seed or Generator used to randomly sample perturbations.
                  Defaults to None.
            add_inactive_sensors: If True, inactive sensors will be appended
                to the data with padded pulse information. Defaults to False.
            sensor_mask: A list of sensor id's to be masked from the data. Any
                sensor listed here will be removed from the data.
                    Defaults to None.
            string_mask: A list of string id's to be masked from the data.
                Defaults to None.
            repeat_labels: If True, labels will be repeated to match the
                number of rows in the output of the GraphDefinition.
                Defaults to False.
        """
        # Base class constructor
        super().__init__(name=__name__, class_name=self.__class__.__name__)

        # Member Variables
        self._detector = detector
        self._perturbation_dict = perturbation_dict
        self._sensor_mask = sensor_mask
        self._string_mask = string_mask
        self._add_inactive_sensors = add_inactive_sensors
        self._repeat_labels = repeat_labels

        self._resolve_masks()

        if input_feature_names is None:
            # Assume all features in Detector are used.
            input_feature_names = list(
                self._detector.feature_map().keys()
            )  # noqa: E501 # type: ignore
        self._input_feature_names = input_feature_names

        # Set data type
        self.to(dtype)

        # Set perturbation_cols if needed
        if isinstance(self._perturbation_dict, dict):
            self._perturbation_cols = [
                self._input_feature_names.index(key)
                for key in self._perturbation_dict.keys()
            ]
        if seed is not None:
            if isinstance(seed, int):
                self.rng = default_rng(seed)
            elif isinstance(seed, Generator):
                self.rng = seed
            else:
                raise ValueError(
                    "Invalid seed. Must be an int or a numpy Generator."
                )
        else:
            self.rng = default_rng()

    @property
    def nb_inputs(self) -> int:
        """Return the number of input features."""
        return len(self._input_feature_names)

    @property
    def nb_outputs(self) -> int:
        """Return the number of output features, computing them lazily."""
        if not hasattr(self, "output_feature_names"):
            feats = self._set_output_feature_names(self._input_feature_names)
            self.output_feature_names = feats
        return len(self.output_feature_names)

I'll attach the MWE for you on slack. Could you check if that works for you?

RasmusOrsoe self-requested a review March 3, 2025 08:42
RasmusOrsoe (Collaborator) left a comment

Hi @sevmag - thanks for this very clean pull request!

Considering the magnitude of the change, I think we should have more eyes on it. Maybe @Aske-Rosted and/or @giogiopg has time to give this a look?

I think the restructuring looks great. I have just one comment:

The readability of the subclasses of DataRepresentation would be much improved if we could omit the _pre_init and _forward_end methods and move the forward function into the subclass, i.e. if the structure were:

class MyRep(DataRepresentation):

    def __init__(self, data_rep_args, my_args):

        # My special init things
        self._my_member_variable = my_args

        # Base inheritance
        super().__init__(data_rep_args)

        # My special post-init things
        self._other_member_variable = ..

    
    def forward(self, .. ):  # This describes exactly what my rep does
        # My special beginning of forward pass
        mystery = ..
        # Base processing
        data = super().forward( ... ) # Calls DataRepresentation.forward
        # My special post-processing
        data['odd_thing'] = 1e-8
        return data

By having the forward pass directly in the subclass, it becomes much easier to see exactly how a specific data representation processes data. What do you think @sevmag?
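As an editor's aside, the pattern proposed above can be distilled into a self-contained, runnable sketch. Class and attribute names are made up; the base class stands in for DataRepresentation:

```python
# Override-and-delegate pattern: the subclass owns forward() and calls
# super().forward() for the shared base processing, so the whole data
# flow for one representation is visible in a single method.
class DataRepresentationSketch:
    def forward(self, rows):
        # Shared base processing: here, just wrap the rows in a dict.
        return {"x": list(rows)}


class MyRep(DataRepresentationSketch):
    def forward(self, rows):
        # Special pre-processing for this representation
        rows = sorted(rows)
        # Base processing
        data = super().forward(rows)
        # Special post-processing
        data["odd_thing"] = 1e-8
        return data
```

Reading MyRep.forward from top to bottom shows exactly what this representation does, with no hidden _pre_init or _forward_end hooks.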

sevmag (Collaborator Author) commented Mar 4, 2025

@RasmusOrsoe Thanks for the feedback. I agree that not having the forward method buried in the inheritance chain makes things easier to read. I will rework that. For the pre_init function, see my comment above.

sevmag (Collaborator Author) commented Mar 18, 2025

I adjusted bits and pieces following the recommendations from @RasmusOrsoe. Besides changing to properties, I also adjusted the forward function of the NodeDefinition, since returning the output_feature_names was no longer necessary. I adjusted the unit test to match this new structure.

This way, we got rid of the pre_init function.

Let me know what you think!

RasmusOrsoe self-requested a review March 19, 2025 10:55
self._node_definition.set_output_feature_names(input_feature_names)
return self._node_definition._output_feature_names

def _create_data(
Collaborator:

Let's get rid of _create_data in the DataRepresentation and move the functionality into the forward pass. I.e.

def forward(  # type: ignore
        self,
        input_features: np.ndarray,
        input_feature_names: List[str],
        truth_dicts: Optional[List[Dict[str, Any]]] = None,
        custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
        loss_weight_column: Optional[str] = None,
        loss_weight: Optional[float] = None,
        loss_weight_default_value: Optional[float] = None,
        data_path: Optional[str] = None,
    ) -> Data:
        """Construct graph as ´Data´ object.

        Args:
            input_features: Input features for graph construction.
                Shape ´[num_rows, d]´
            input_feature_names: name of each column. Shape ´[,d]´.
            truth_dicts: Dictionary containing truth labels.
            custom_label_functions: Custom label functions.
            loss_weight_column: Name of column that holds loss weight.
                                Defaults to None.
            loss_weight: Loss weight associated with event. Defaults to None.
            loss_weight_default_value: default value for loss weight.
                    Used in instances where some events have
                    no pre-defined loss weight. Defaults to None.
            data_path: Path to dataset data files. Defaults to None.

        Returns:
            graph
        """

        # Base class call - Get bare Data Object
        data = super().forward(
            input_features=input_features,
            input_feature_names=input_feature_names,
            truth_dicts=truth_dicts,
            custom_label_functions=custom_label_functions,
            loss_weight_column=loss_weight_column,
            loss_weight=loss_weight,
            loss_weight_default_value=loss_weight_default_value,
            data_path=data_path,
        )


        # Create graph & get new node feature names
        data = self._node_definition(data.x)
        if self._sort_by is not None:
            data.x = data.x[data.x[:, self._sort_by].sort()[1]]

        # Assign edges
        if self._edge_definition is not None:
            data = self._edge_definition(data)


        if self._add_static_features:
            data = self._add_features_individually(
                data,
                self.output_feature_names,
            )

        return data

RasmusOrsoe (Collaborator) left a comment

Hey @sevmag!

I think this looks great. Upon looking at the changes again, I noticed the abstract method of DataRepresentation called _create_data. I think we should consider removing this abstract method to decrease the complexity of the data flow. Instead, we can replace this method in DataRepresentation with

# Create data & get new final data feature names
data = Data(x=input_features)  # instead of data = self._create_data(..)

When subsequent representations inherit from it, they then have the same starting point. I.e. in the forward pass of GraphDefinition, we would then have:

def forward(  # type: ignore
        self,
        input_features: np.ndarray,
        input_feature_names: List[str],
        truth_dicts: Optional[List[Dict[str, Any]]] = None,
        custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
        loss_weight_column: Optional[str] = None,
        loss_weight: Optional[float] = None,
        loss_weight_default_value: Optional[float] = None,
        data_path: Optional[str] = None,
    ) -> Data:
        """Construct graph as ´Data´ object.

        Args:
            input_features: Input features for graph construction.
                Shape ´[num_rows, d]´
            input_feature_names: name of each column. Shape ´[,d]´.
            truth_dicts: Dictionary containing truth labels.
            custom_label_functions: Custom label functions.
            loss_weight_column: Name of column that holds loss weight.
                                Defaults to None.
            loss_weight: Loss weight associated with event. Defaults to None.
            loss_weight_default_value: default value for loss weight.
                    Used in instances where some events have
                    no pre-defined loss weight. Defaults to None.
            data_path: Path to dataset data files. Defaults to None.

        Returns:
            graph
        """

        # Base class call - Get bare Data Object
        data = super().forward(
            input_features=input_features,
            input_feature_names=input_feature_names,
            truth_dicts=truth_dicts,
            custom_label_functions=custom_label_functions,
            loss_weight_column=loss_weight_column,
            loss_weight=loss_weight,
            loss_weight_default_value=loss_weight_default_value,
            data_path=data_path,
        )


        # Create graph & get new node feature names
        data = self._node_definition(data.x)
        if self._sort_by is not None:
            data.x = data.x[data.x[:, self._sort_by].sort()[1]]

        # Assign edges
        if self._edge_definition is not None:
            data = self._edge_definition(data)


        if self._add_static_features:
            data = self._add_features_individually(
                data,
                self.output_feature_names,
            )

        return data


return data

def _forward_end(
Collaborator:
I don't think this function is needed. Am I missing something?

RasmusOrsoe (Collaborator):

@sevmag thanks for the changes! At this stage I think we're good to go! @Aske-Rosted / @giogiopg do you have thoughts on this?

Aske-Rosted (Collaborator) left a comment

Looks good to me, found two very small things.

from graphnet.utilities.decorators import final
from graphnet.models import Model
from graphnet.models.graphs.utils import (
from ..utils import (
Collaborator:

Is it not more readable to have the absolute path?

raise NotImplementedError

return graph
def _label_repeater(self, label: torch.Tensor, data: Data) -> torch.Tensor:
Collaborator:

Should this have the abstractmethod decorator?

Collaborator Author:

I do not think having this function in all of the subclasses of the DataRepresentation class is necessary. For example, I would not know which logic to follow in repeating the label for a 3D image. However, I also do not fully understand the reason for this functionality in the first place, so I may be missing something.

But suppose we would not use it for an image definition or something similar. In that case, it does not make sense to put an abstractmethod decorator since it would force us to define an unnecessary function.

It would be nicer to push this function into the GraphDefinition, but since it's called by _add_custom_labels and _add_truth, I think that would be too messy.
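For illustration, the design trade-off discussed in this thread, an optional hook versus an @abstractmethod, can be sketched as follows (all names are illustrative):

```python
# An optional hook raises only when a representation that needs it forgot
# to implement it; an @abstractmethod would force EVERY subclass, including
# e.g. an image representation, to define it.
from abc import ABC, abstractmethod


class DataRepresentationSketch(ABC):
    def _label_repeater(self, label, data):
        # Optional hook: deliberately NOT abstract, so subclasses that have
        # no sensible row-wise label semantics can simply omit it.
        raise NotImplementedError(
            f"{type(self).__name__} does not support repeated labels."
        )

    @abstractmethod
    def forward(self, x):
        """Required in every subclass."""


class GraphLike(DataRepresentationSketch):
    def _label_repeater(self, label, data):
        # e.g. repeat the label once per node/row in the representation
        return [label] * len(data)

    def forward(self, x):
        return x


class ImageLike(DataRepresentationSketch):
    # Does not override _label_repeater: repeating one label across a
    # 3D image has no obvious row-wise meaning.
    def forward(self, x):
        return x
```

With this layout, only representations that opt in pay the cost of defining the hook, while forward remains mandatory everywhere.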

sevmag merged commit fb26ec2 into graphnet-team:main Apr 14, 2025
13 of 14 checks passed