
TypeTransformers for PyTorch Tensor, Module, and Checkpoint #1032

Merged: 19 commits merged into master on Jul 7, 2022

Conversation

@samhita-alla (Contributor) commented Jun 1, 2022

Signed-off-by: Samhita Alla aallasamhita@gmail.com

TL;DR

This PR adds support for:

  • torch.Tensor as a native type.
  • saving and loading PyTorch modules (models & layers) using state_dict. PyTorchStateDict is the custom type defined to handle serialization and deserialization of state_dict. torch.nn.Module (the base class for all neural network modules) isn't considered a native type here; this is per the docs: "Instead of saving a module directly, for compatibility reasons it is recommended to instead save only its state dict." Subclassing torch.nn.Module in a TypeTransformer might mislead users into thinking that the PyTorch model itself is being serialized and deserialized when in fact only its state_dict is.

Type

  • Bug Fix
  • Feature
  • Plugin

Are all requirements met?

  • Code completed
  • Smoke tested
  • Unit tests added
  • Code documentation added
  • Any pending items have an associated Issue

Complete description

Module/Model serialization example:

import torch
from flytekit import task


@task
def generate_module() -> torch.nn.Module:
    bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    return bn

@task
def t4(model: torch.nn.Module) -> torch.nn.Module:
    # assumes the incoming model exposes an `l1` submodule
    return model.l1

Tensor serialization example:

@task
def generate_tensor_1d() -> torch.Tensor:
    return torch.zeros(5, dtype=torch.int32)

@task
def t1(tensor: torch.Tensor) -> torch.Tensor:
    assert tensor.dtype == torch.int32
    tensor[0] = 1
    return tensor

Checkpoint example:

import torch.nn as nn
import torch.optim as optim

from flytekit import task
from flytekit.extras.pytorch import PyTorchCheckpoint


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        ...

    def forward(self, x):
        ...
        return x


@task
def generate_model(hyperparameters: Hyperparameters) -> PyTorchCheckpoint:
    bn = Net()
    optimizer = optim.SGD(bn.parameters(), lr=0.001, momentum=0.9)
    return PyTorchCheckpoint(module=bn, hyperparameters=hyperparameters, optimizer=optimizer)

@task
def t1(checkpoint: PyTorchCheckpoint):
    new_bn = Net()
    new_bn.load_state_dict(checkpoint["module_state_dict"])
    optimizer = optim.SGD(new_bn.parameters(), lr=0.001, momentum=0.9)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
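
The example assumes a user-defined Hyperparameters type; a minimal sketch of what that could look like (the fields below are made up for illustration; any dict, NamedTuple, or dataclass would do):

from dataclasses import dataclass


@dataclass
class Hyperparameters:
    epochs: int = 10          # illustrative fields only
    learning_rate: float = 0.001
    momentum: float = 0.9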

Tracking Issue

flyteorg/flyte#2544

Follow-up issue

NA

@codecov bot commented Jun 1, 2022

Codecov Report

Attention: Patch coverage is 88.48684% with 35 lines in your changes missing coverage. Please review.

Project coverage is 86.92%. Comparing base (03a9487) to head (a48472a).
Report is 1038 commits behind head on master.

Files Patch % Lines
flytekit/extras/pytorch/checkpoint.py 84.50% 6 Missing and 5 partials ⚠️
...ts/flytekit/unit/extras/pytorch/test_checkpoint.py 91.13% 7 Missing ⚠️
flytekit/extras/pytorch/native.py 86.66% 4 Missing and 2 partials ⚠️
...ytekit/unit/extras/pytorch/test_transformations.py 90.90% 4 Missing and 1 partial ⚠️
tests/flytekit/unit/extras/pytorch/test_native.py 91.66% 4 Missing ⚠️
flytekit/extras/pytorch/__init__.py 66.66% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1032      +/-   ##
==========================================
+ Coverage   86.90%   86.92%   +0.01%     
==========================================
  Files         269      275       +6     
  Lines       25144    25448     +304     
  Branches     2834     2862      +28     
==========================================
+ Hits        21851    22120     +269     
- Misses       2823     2850      +27     
- Partials      470      478       +8     


@kumare3 (Contributor) commented Jun 1, 2022

@samhita-alla I think the type should still be torch.nn.Module and we should do the right thing? I read your point, but is there a problem with reloading an nn.Module? cc @cosmicBboy

@cosmicBboy (Contributor) left a comment

I feel like this should be a flytekit plugin and not part of the core flytekit package. It sets a precedent of adding more and more dependencies to the core package that I'm not sure we want to shoulder (pytorch, tensorflow, sklearn, etc.).

thoughts @eapolinario @kumare3?

edit: I see how we're handling the case where torch isn't installed. I suppose this is okay, as long as we're okay with establishing this pattern for other types.

@cosmicBboy (Contributor) commented:

I think the type should still be torch.nn.Module and we should do the right thing? I read your point, but is there a problem with reloading an nn.Module? cc @cosmicBboy

Agreed, I think any subclass of nn.Module should be automatically serialized/deserialized using the state dict.

Instead of saving a module directly, for compatibility reasons it is recommended to instead save only its state dict

So we should do the "right thing" automatically using state dict instead of pickling the module. We're already abstracting away how these types are stored in Flyte, so I think the risk of confusion here is minimal. The extra layer of indirection with PyTorchStateDict adds an additional thing for the user to learn and more friction porting code from vanilla python to Flyte.

In the to_python_value function, I think the nn.Module.load_state_dict method should be automatically called so that the user doesn't have to worry about loading it in the task function body.

Of course all of this assumes that the user code has access to the nn.Module subclass when it's specified in the input type signature:

class MyModel(nn.Module): ...

@task
def my_task(model: MyModel):
    # model is automatically converted to a MyModel type by the type engine.
    ...
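
For illustration, a minimal plain-torch sketch of the state-dict round trip being described (the helper name is made up and this is not the transformer implementation); note the no-argument construction it relies on, which is exactly the wrinkle discussed below:

import torch
import torch.nn as nn


def roundtrip_via_state_dict(module: nn.Module, module_factory, path: str) -> nn.Module:
    # serialize only the weights/buffers, then rebuild the module on the way back
    torch.save(module.state_dict(), path)
    rebuilt = module_factory()  # assumes the module can be constructed without extra arguments
    rebuilt.load_state_dict(torch.load(path))
    return rebuilt


bn = nn.BatchNorm1d(3, track_running_stats=True)
restored = roundtrip_via_state_dict(bn, lambda: nn.BatchNorm1d(3), "/tmp/bn_state.pt")
assert torch.equal(restored.running_mean, bn.running_mean)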

@samhita-alla (Contributor, Author) commented Jun 2, 2022

@cosmicBboy, I don't think this can be a standalone plugin, because doing so would require the plugin to install torch, which might not correspond to the version the user has already installed or wants to install. If a user wants to use torch.Tensor or any other such native type, the library will likely have already been installed to construct the models.

In the to_python_value function, I think the nn.Module.load_state_dict method should be automatically called so that the user doesn't have to worry about loading it in the task function body.

How should we serialize and deserialize a PyTorch model?

@cosmicBboy (Contributor) commented Jun 2, 2022

I don't think this can be a standalone plugin, because doing so would require the plugin to install torch, which might not correspond to the version the user has already installed or wants to install.

The plugin could specify an unpinned torch as a dependency, which should address that issue. Anyway, I think it's okay having it in the core library as long as we're aware of and consistently handle loose dependencies like this.

How should we serialize and deserialize a PyTorch model?

Playing around with this PR locally, it does seem like there are a bunch of issues associated with trying to handle serialization/deserialization in the type transformer:

  • we don't know ahead of time what hyperparameters to initialize the nn.Module subclass with, which is an issue because oftentimes one can't simply instantiate a model with MyModel().
  • we could do something like Annotated[MyModel, dict(...)] to capture hyperparameters, but it's not ideal since in most use cases the hyperparameters are part of the task arguments, so we don't know them ahead of time.

I have two thoughts here:

  1. Does it make sense to support nn.Module via torch.save and torch.load (i.e. using pickle)? Back-pedaling on my previous comment, I wonder if it's okay to support this and reference the pytorch docs to highlight the potential issues with doing so (I just have a feeling that this would be a welcome convenience for the lazy pytorch dev 🙃).
  2. I think the "right way" (using a PytorchStateDict dataclass) is the way to go, although I'd propose extending this to PytorchCheckpoint such that it includes additional metadata, going off of this guide:
import typing
from dataclasses import dataclass
from typing import Generic, TypeVar

import torch
from dataclasses_json import dataclass_json

T = TypeVar("T", bound=typing.Union[typing.Dict, typing.NamedTuple])


@dataclass_json
@dataclass
class PyTorchCheckpoint(Generic[T]):
    module: typing.Optional[torch.nn.Module] = None
    hyperparameters: typing.Optional[T] = None  # not required for models that have a hard-coded architecture
    optimizer: typing.Optional[torch.optim.Optimizer] = None
    epoch: typing.Optional[int] = None
    loss: typing.Optional[float] = None
Basically this supports the special case of just wanting to store the module state dict, while also supporting a fully generalized checkpoint that we can probably use in concert with the intra-task checkpointing.

Since hyperparameters are user-provided, we can't know their type ahead of time, hence the use of Generic:

class Hyperparameters(typing.NamedTuple):
    ...

ModelCheckpoint = PyTorchCheckpoint[Hyperparameters]

@task 
def produce_model(hyperparameters: Hyperparameters) -> ModelCheckpoint:
    model = MyModel(**hyperparameters._asdict())
    optim = torch.optim.SGD(model.parameters(), lr=0.01)
    ...  # train
    return ModelCheckpoint(
        module=model,
        hyperparameters=hyperparameters,
        optimizer=optim,
        epoch=...,
        loss=...,
    )

edit: since we'd include the hyperparameters in the state dict that's serialized, we may not need the Generic stuff

@samhita-alla (Contributor, Author) commented:

  • we don't know ahead of time what hyperparameters to initialize the nn.Module subclass with, which is an issue because oftentimes one can't simply instantiate a model with MyModel().
  • we could do something like Annotated[MyModel, dict(...)] to capture hyperparameters, but it's not ideal since in most use cases the hyperparameters are part of the task arguments, so we don't know them ahead of time.

Can we not store hyperparameters when the user returns a module with PyTorchStateDict as the attached type? The module must have already been initialized with the hyperparameters in case the architecture isn't hard-coded. WDYT?

@task
def generate_model() -> PyTorchStateDict:
    bn = MyModel(...)
    return PyTorchStateDict(module=bn)

Does it make sense to support nn.Module via torch.save and torch.load (i.e. using pickle)? Back-pedaling on my previous comment, I wonder if it's okay to support this and reference the pytorch docs to highlight the potential issues with doing so (I just have a feeling that this would be a welcome convenience for the lazy pytorch dev 🙃).

Yeah! torch.save on nn.Module stores the path to the file containing the class. The disadvantage with this approach as per their doc is that the code can break in various ways when the module is loaded in other projects or after refactors. Reference.

But isn't this the approach we'll have to follow if we want to support applying load_state_dict on the model within the plugin, because that's how we could store the subclass? Whether it's a module or a single network layer, we'd have to store what it is, say, BatchNorm1d. Please let me know if you disagree.

I think the "right way" using a PytorchStateDict dataclass is the way to go, although I'd propose extending this to PytorchCheckpoint such that it includes additional metadata, going off of this guide.:

I love this! We can have PyTorchCheckpoint rather than PyTorchStateDict to handle additional metadata.

@cosmicBboy (Contributor) commented Jun 3, 2022

Can we not store hyperparameters when the user returns a module with PyTorchStateDict as the attached type?

Yep! This is pretty much the PyTorchCheckpoint proposal with additional metadata

But isn't this the approach we'll have to follow if we want to support applying load_state_dict on the model within the plugin, because that's how we could store the subclass? Whether it's a module or a single network layer, we'd have to store what it is, say, BatchNorm1d. Please let me know if you disagree.

So if we use torch.save and torch.load on the entire module we offload storage of the complete state and class definition to pytorch, subject to the risk you point out:

The disadvantage with this approach as per their doc is that the code can break in various ways when the module is loaded in other projects or after refactors

So if we agree that PyTorchCheckpoint is the way to do things correctly and fully capture the model and training state, the question is:

do we want to support nn.Module as a type that uses torch.save/load, with documentation of the risks and encouraging use of PyTorchCheckpoint? I'm inclined to say yes because (i) pytorch supports it and (ii) in certain circumstances Flyte actually reduces the risk of incompatibilities since tasks are hardened via containerization.
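
For reference, this is the plain-torch behaviour in question: saving the whole module pickles a reference to its class along with the parameters, so loading depends on the defining code still being importable (a small illustrative sketch, not the transformer code):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())

# whole-module save: pickles the class reference together with the parameters
torch.save(model, "/tmp/whole_model.pt")

# works here because nn.Sequential is always importable; a custom subclass would
# need its defining module on the path, which is the compatibility risk noted above
reloaded = torch.load("/tmp/whole_model.pt")
assert isinstance(reloaded, nn.Sequential)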

@samhita-alla (Contributor, Author) commented Jun 6, 2022

do we want to support nn.Module as a type that uses torch.save/load, with documentation of the risks and encouraging use of PyTorchCheckpoint? I'm inclined to say yes because (i) pytorch supports it and (ii) in certain circumstances Flyte actually reduces the risk of incompatibilities since tasks are hardened via containerization.

Do you mean to say that we support:

  • nn.Module as a native type where we use torch.save and torch.load on the entire module under the hood, and
  • PyTorchCheckpoint to serialize a checkpoint and let users run load_state_dict on the deserialized checkpoint?

@samhita-alla (Contributor, Author) commented Jun 13, 2022

I've dynamically created transformers for torch.nn.Module and torch.Tensor because they use the same logic under the hood. To eliminate code duplication, I opted for dynamic class creation, and the two transformers are present in the native.py file.

On the whole, the following are the types enclosed in this PR:

  • PyTorch Tensor (torch.Tensor) native type
  • PyTorch Module (torch.nn.Module) native type
  • PyTorchCheckpoint

All the examples are available in the PR description.

Note: PyTorchCheckpoint supports serializing hyperparameters of type dict, NamedTuple, and dataclass.
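
To illustrate that note (assuming the PyTorchCheckpoint constructor shown in the PR description and its export from flytekit.extras.pytorch; the hyperparameter classes here are made up):

from dataclasses import dataclass
from typing import NamedTuple

import torch.nn as nn

from flytekit.extras.pytorch import PyTorchCheckpoint


class TupleHyperparameters(NamedTuple):
    lr: float = 0.001


@dataclass
class DataclassHyperparameters:
    lr: float = 0.001


bn = nn.BatchNorm1d(3)
PyTorchCheckpoint(module=bn, hyperparameters={"lr": 0.001})                # dict
PyTorchCheckpoint(module=bn, hyperparameters=TupleHyperparameters())      # NamedTuple
PyTorchCheckpoint(module=bn, hyperparameters=DataclassHyperparameters())  # dataclass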

@@ -1,13 +1 @@
"""
Contributor:

is removing this from docs intentional?

@samhita-alla (Author) replied:

Yeah. I don't think we'd want to have the transformer in the API reference because the methods within the TypeTransformer class remain the same.

@@ -0,0 +1,110 @@
from __future__ import annotations
@cosmicBboy (Contributor) commented Jun 15, 2022:

overall notes on this module:

I can appreciate the DRY principle here, but the dynamically generated classes in this module make it a little hard to reason about and read imo... I'm also not sure whether these dynamically generated classes will have useful code/auto-completion on the user side.

Suggestion: DRYing this logic with a parent class like BaseTensor and two subclasses for Tensor and Module could be clearer to read.

One suggestion for an approach would be:

import typing
from typing import Type

import torch
import torch.nn as nn

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Literal
from flytekit.models.types import LiteralType

T = typing.TypeVar("T")

# use generics to abstract out the types in the method definitions
class PytorchTypeTransformer(TypeTransformer, typing.Generic[T]):

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        return LiteralType(
            blob=_core_types.BlobType(
                format=self.PYTORCH_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
            )
        )

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        ...

    # implement the other common methods here

class PyTorchTensorTransformer(PytorchTypeTransformer[torch.Tensor]):
    PYTORCH_FORMAT = "PytorchTensor"

    def __init__(self):
        super().__init__(name="Pytorch Tensor", t=torch.Tensor)

    ...

class PyTorchModuleTransformer(PytorchTypeTransformer[nn.Module]):
    PYTORCH_FORMAT = "PytorchModule"

    def __init__(self):
        super().__init__(name="Pytorch Module", t=nn.Module)

There's probably a better way of doing this, but just wanted to propose a feasible alternative.

@samhita-alla (Contributor, Author) replied Jun 16, 2022:

I followed an imperative style to tackle this, but yes, the OOP approach is more readable. There shouldn't be a problem with auto-completion because users wouldn't be importing any of these modules in their code as such. However, I've modified the code to use inheritance now. :) Thanks for looking into this!

@samhita-alla changed the title from "TypeTransformers for PyTorch Tensor and Module" to "TypeTransformers for PyTorch Tensor, Module, and Checkpoint" on Jun 16, 2022
cosmicBboy previously approved these changes Jun 16, 2022
@cosmicBboy (Contributor) left a comment:

great work!

@cosmicBboy (Contributor) left a comment:

Oh, one thing I forgot to mention here: how do we want to handle CPU/CUDA tensors?

For example, if I serialize a CUDA tensor/module, the type engine will fail if we try to deserialize it on a CPU-only machine.

Should we leave that up to the user? Or should we do some automagic to handle it?

(Depending on the answer, we can work on a follow-up PR.)
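
For context, plain torch already lets the caller remap devices at load time, which is one option for that automagic (a small sketch, not this PR's implementation):

import torch

t = torch.zeros(5)
torch.save(t, "/tmp/tensor.pt")

# map_location moves GPU-saved tensors onto the CPU at deserialization time, so a
# checkpoint produced on a CUDA machine can still be loaded on a CPU-only one
restored = torch.load("/tmp/tensor.pt", map_location=torch.device("cpu"))
assert restored.device.type == "cpu"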

@cosmicBboy (Contributor) commented:

Amazing! Let's merge this @samhita-alla @eapolinario

@samhita-alla (Contributor, Author) commented:

+1, please. @cosmicBboy

@pingsutw (Member) left a comment:

Couple of minor comments. LGTM otherwise

@samhita-alla (Contributor, Author) commented:

Couple of minor comments. LGTM otherwise

@pingsutw, I missed this! Resolved it now.

cosmicBboy previously approved these changes Jun 24, 2022
@kumare3 (Contributor) commented Jun 29, 2022

Should we make this part of a pip-installable flytekitplugins-ml plugin? cc @cosmicBboy / @wild-endeavor

@cosmicBboy (Contributor) commented:

just for the paper trail:

@Yee and I chatted and we’re okay with the current approach on the PR. We figured it makes sense to have pytorch (and other “core” ML types) in flytekit.extras. Our general heuristic is: if it’s part of the “core” ML stack (i.e. it’s super popular, e.g. pytorch, tensorflow, keras, etc.), it can live in flytekit.extras, but if not, it goes into flytekitplugins … acknowledging that what is considered “core” will change over time.

@cosmicBboy (Contributor) left a comment:

Module organization makes sense to me, see comment about flytekit/__init__.py import

@@ -183,6 +183,7 @@
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
from flytekit.extras.pytorch import PyTorchModuleTransformer, PyTorchTensorTransformer
Contributor:

I think this will break if torch isn't installed (?)

Maybe we can do something like:

from flytekit.extras import pytorch

Which would automatically register the tensor and module type transformers (if torch is installed)
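
A rough sketch of what that guarded import could look like inside the package __init__.py (hypothetical contents; the actual file in this PR may differ):

# flytekit/extras/pytorch/__init__.py (illustrative)
from flytekit.loggers import logger

try:
    # importing the transformer module registers the transformers with the type engine
    from flytekit.extras.pytorch.native import PyTorchModuleTransformer, PyTorchTensorTransformer
except ImportError:
    logger.info("torch is not installed; skipping registration of the PyTorch type transformers.")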

@samhita-alla (Author) replied:

My bad. Fixed the import now.

cosmicBboy previously approved these changes Jul 6, 2022
pingsutw previously approved these changes Jul 6, 2022
wild-endeavor previously approved these changes Jul 6, 2022
from flytekit.loggers import logger

try:
from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer
Contributor:

we should stick with the full import in the future, just for consistency. merge as is, i'll update it in the future.

try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol
Contributor:

we can always use typing_extensions right?

@samhita-alla (Author) replied:

Yep! I'll merge this now but will make sure to modify the import to use typing_extensions in a different PR.

@samhita-alla (Author) replied:

Modified the import — I had to resolve a merge conflict.

@samhita-alla samhita-alla merged commit 650eb58 into master Jul 7, 2022
wild-endeavor pushed a commit that referenced this pull request Aug 2, 2022
* TypeTransformers for PyTorch Tensor and Module
* add torch to requirements
* add module as a native type and PyTorchCheckpoint
* update requirements
* procedural to OOP approach
* nit
* verify device conversion
* verify device conversion
* hyperparameters can be None
* device conversion
* device conversion
* checkpoint code cleanup
* resolve merge conflict
* fix pytorch api reference; resolve merge conflict
* fix pytorch import

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
wild-endeavor pushed a commit that referenced this pull request Aug 2, 2022