Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved parsing of pytorch models using torch.FX - Clean #799

Merged
merged 11 commits into from Jun 22, 2023

Conversation

JanFSchulte
Copy link
Contributor

Refreshed version of #723 to leave behind messy git history

Current parsing of pytorch models uses a loop of the named_modules of the model (https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/converters/pytorch_to_hls.py#L163). This has several disadvantages:

  • Captures only layers that are defined as members of the model class
  • Can't infer the correct order of models
  • Ignores other operations that are part of the forward() method of the model

In this PR, we propose to fix this by first created a graph representation of the model's forward() function using the symbolic tracing functionality of https://pytorch.org/docs/stable/fx.html. Each operation in the forward() is represented by a node in the graph. Nodes can be of these types:
image

For example, for this model

class MyModuleConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3,3,3)
        
    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y

the resulting graph representation is

graph():
    %x : [#users=1] = placeholder[target=x]
    %conv : [#users=2] = call_module[target=conv](args = (%x,), kwargs = {})
    %relu : [#users=1] = call_function[target=torch.relu](args = (%conv,), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%relu, %conv), kwargs = {})
    %relu_1 : [#users=1] = call_function[target=torch.relu](args = (%add,), kwargs = {})
    return relu_1

As the nodes in the graph follow the order of operations of the forward() function, we can then simply loop over them and parse each node into one node in the hls4ml model representation. For the parsing of the individual layers, existing code is used where available without significant changes. Functionality for more types of layers is also added by this PR.

The types of layers currently understood by the parser are

  • Linear
  • Softmax
  • Relu
  • LeakyReLU
  • ThresholdedReLU
  • ELU
  • PReLU
  • Sigmoid
  • BatchNorm2d
  • BatchNorm1d'
  • Batch_norm
  • MaxPool1d
  • MaxPool2d
  • AvgPool1d
  • AvgPool2d
  • Add
  • Subtract
  • Multiply
  • Average
  • Maximum
  • Minimum
  • Concatenate
  • Dot
  • Conv1d
  • Conv2d
  • View
  • Dropout
  • Flatten
  • Sequential

This PR also fixes #409

Changes are mostly confined to the frontend, but small changes are made to the backend to the templates for pooling layers to add the option that zero-padded entries are included in average pooling operations.

One big difference between pytorch and keras is the data format of the input tensors, which is channels_first by default, instead of the channels_last used by keras. The built-in tools in pytorch to convert a model to channels_last don't work for all dimensions of the input. Therefore the functionality has been added to transpose the inputs within hls4ml so the existing channels_last implementations of layers can be used. By default the inputs are transposed for io_parrallel but not io_stream since we don't have transpose layers for all dimensions in io_stream. The outputs are not transposed by default, but this can be switched on, again only for io_parallel.

Limitations:

  • Many types of layers not supported yet
  • The same functionality is available in pytorch either as torch.nn classes or torch.functional functions in many cases. These have to be parsed differently, which I have implemented only sporadically for the functionals so far.

Type of change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change which adds functionality)

Tests

The new parsing was tested using 5-6 different pytorch model examples from around the web. In addition, I verified that the two example models for pytroch included with hls4ml get parsed successfully. A test for the API was added in the test/pytest folder, in analogy to the test for the keras parser. All tests pass successfully.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@vloncar vloncar added the please test Trigger testing by creating local PR branch label Jun 11, 2023
@vloncar
Copy link
Contributor

vloncar commented Jun 11, 2023

This has reached a reached a very high level of stability and feature-set, it is far better than what is currently in the main branch. Support for some ops is not complete and there are some guards to be added to ensure proper parsing, but these are mostly corner cases that we can address later. So I would propose we merge it in this state and continue with bugfixes as we go. There's significant developments built on top already that I wouldn't like to push as part of this PR.

@jmduarte jmduarte added this to the v0.8.0 milestone Jun 15, 2023
@vloncar vloncar added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Jun 15, 2023
@jmitrevs jmitrevs merged commit fdbdb99 into fastmachinelearning:main Jun 22, 2023
9 checks passed
calad0i pushed a commit to calad0i/hls4ml that referenced this pull request Jul 1, 2023
 Improved parsing of pytorch models using torch.FX - Clean
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
please test Trigger testing by creating local PR branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

PyTorch conversion issues
4 participants