Missing torch.load #403

Closed
GeorgeS2019 opened this issue Oct 8, 2021 · 17 comments
Labels: question (Further information is requested)


@GeorgeS2019 commented Oct 8, 2021

Similar to the discussion of the missing torch.save:

PyTorch's torch.load is implemented in torch\serialization.py, whose docstring describes the parameters available for loading.

Example:

        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
        # Load tensor from io.BytesIO object
        >>> with open('tensor.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
        # Load a module with 'ascii' encoding for unpickling
        >>> torch.load('module.pt', encoding='ascii')

Currently LibTorchSharp implements only one of the possible loading options listed above.

Since pickling is overkill for .NET, as discussed previously, and as I am still learning: is there a need for TorchSharp to provide more of the loading options offered by torch.load, rather than only Module.load?

I am raising this issue because I failed to load a saved state_dict, created through exportsd.py, back into TorchSharp using Module.Load.

I did not get any error message; the process simply crashes.

Suggestion: should there be error messages when loading fails, to help make loading a state_dict more reliable?

@NiklasGustafsson (Contributor)

> I did not get any error message, as the process crashes.

Python pickling is not going to be a solution here; it is too intimately tied to Python's object model. If there is a bug with saving a state_dict() in Python and loading it in .NET, I want to make sure to fix it, or at a minimum provide a more actionable error message. As the article on saving and loading demonstrates, you have to use the special format for exporting model weights.

Could you please provide more details on what you are saving and how you are loading it? If you can provide some source files for Python and .NET (narrowed down, preferably), that would be great.

@GeorgeS2019 (Author)

I am attempting to hijack the populated model described in main.py and save its state_dict using exportsd.py, as described.

At this empty line in main.py, the following code was inserted to export the state_dict:

    f = open("gpt2-pytorch_model.ts", "wb")
    exportsd.save_state_dict(model.to("cpu").state_dict(), f)
    f.close()

When I attempted to load gpt2-pytorch_model.ts using TorchSharp's Module.Load, it crashed without any error message.

Since in PyTorch it is possible to load the state_dict without first defining the model, I attempted to check whether that is also possible in TorchSharp.

Code:

    state_dict = torch.load('gpt2-pytorch_model.bin', map_location='cpu' if not torch.cuda.is_available() else None)

As the PyTorch load function takes more parameters, I wonder whether TorchSharp needs them too.

I hope I understand correctly that this is how the exportsd.py function is meant to be used.

@NiklasGustafsson (Contributor) commented Oct 8, 2021

Thanks for that information. In terms of how TorchSharp model serialization works, it loads and saves model parameters (weights), not models. That means that in order to load weights, you have to have an exact copy of the original model defined in .NET, and an instance of the model created (presumably with random or empty weights).

Parameters should be represented as fields in the model, as should buffers (tensors that are used by the model, but not affected by training), and they must have exactly the same name as in the original model.

I'll construct some negative unit tests for this and see if I can improve the error messages to be more informative.
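
A minimal sketch of what that looks like on the .NET side (the submodule names "lin1", "lin2" and the file name are just placeholders; they have to match whatever keys the exporting Python model produces):

    using static TorchSharp.torch.nn;

    // Define a .NET model whose structure and submodule names mirror the Python original.
    var model = Sequential(
        ("lin1", Linear(10, 5)),
        ("relu1", ReLU()),
        ("lin2", Linear(5, 1)));

    // The instance starts with freshly initialized weights; load() then overwrites the
    // parameters with the ones exported from Python via exportsd.py.
    model.load("exported_weights.dat");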

@GeorgeS2019 (Author)

I created a unit test to test the feasibility:

[Fact]
public void LoadModelTest()
{
    string fileName = "gpt2-pytorch_model.ts";
    if (File.Exists(fileName)) {
        var stateDict = Module.Load(fileName);
        Assert.NotNull(stateDict);
    }
}

@GeorgeS2019 (Author) commented Oct 8, 2021

@NiklasGustafsson

Do you think this is a valid use case?

Enhance the exportsd.py script so that TorchSharp-compatible parameters and weights can be exported and then loaded back into TorchSharp, populating a TorchSharp custom model equivalent to the one described in PyTorch's main.py.

This would give the .NET community additional ways to evaluate how closely compatible TorchSharp is with its PyTorch counterpart.

@NiklasGustafsson (Contributor)

Yes, that is exactly the intent of Module.load -- to deserialize parameters. However, the model definition needs to exactly match the original, since (unlike Python) .NET cannot create a class definition on the fly and then instantiate it. As I mentioned, Module.load() loads weights, not modules.

@GeorgeS2019 (Author) commented Oct 8, 2021

This is the method for loading weights, not modules, as described in utils.py.

Is the following conceivably possible with TorchSharp?

Export TorchSharp-compatible weights and load them back into a compatible model using TorchSharp-compatible code, as described below:

def load_weight(model, state_dict):
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if key.endswith(".g"):
            new_key = key[:-2] + ".weight"
        elif key.endswith(".b"):
            new_key = key[:-2] + ".bias"
        elif key.endswith(".w"):
            new_key = key[:-2] + ".weight"
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

@NiklasGustafsson (Contributor)

Anything that looks like a <String,Tensor> dictionary and was saved using the format that exportsd.py also uses should be possible to load, but when loading, the keys come from the model instance (either a custom module or Sequential) that the weights are being loaded into. On the saving side, the keys likewise come from the original model.

Thus, the two have to exactly match -- that's the key here. Without seeing the model definition on both sides, it's hard to help debug it. The best I can do, and I will try to get that into the next release, is to improve the error messages so that they are more informative.

@NiklasGustafsson (Contributor)

@GeorgeS2019 -- I suggest adding a print statement (on your machine) to exportsd.py, something like:

    for entry in sd:
        print(entry)
        stream.write(leb128.u.encode(len(entry)))
        stream.write(bytes(entry, 'utf-8'))
        _write_tensor(sd[entry], stream)

and see what the names of all the state_dict entries are, then compare that to your .NET module that you are loading the weights into.
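
On the .NET side, the corresponding names can be listed from the model instance the weights are being loaded into; a rough sketch (using a throwaway Sequential as a stand-in for the real model):

    using System;
    using static TorchSharp.torch.nn;

    // Print the parameter names of the .NET model so they can be compared
    // against the state_dict entries printed by exportsd.py on the Python side.
    var model = Sequential(("lin1", Linear(10, 5)), ("lin2", Linear(5, 1)));
    foreach (var (name, _) in model.named_parameters())
        Console.WriteLine(name);   // e.g. "lin1.weight", "lin1.bias", ...

Any mismatch between the two lists points at the parameter that is breaking the load.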

@GeorgeS2019 (Author)

Ref:

> PyTorch serialization formats are a moving target and saving to ONNX is more reliable.

That discussion was back in 2019. I wonder how much of a moving target the PyTorch format still is in 2021, and whether that will affect TorchSharp's ability to import PyTorch models in the future.

@NiklasGustafsson (Contributor)

@GeorgeS2019 -

In order to train with TorchSharp, you will always need a representation of the model in code.

A longer-term solution than the one we have in place now (loading weights into a model instance) will be to generate the TorchSharp code from an ONNX graph, so that you have full fidelity, then load the weights into that.

Implementing ONNX import will take some time, so if you need to transfer from PyTorch to TorchSharp before then, the existing mechanism is your only option.

If the article at: https://github.com/dotnet/TorchSharp/blob/main/docfx/articles/saveload.md doesn't describe the mechanics of the current approach in sufficient detail, please file a documentation issue with feedback on where the article is lacking in detail or clarity.

@GeorgeS2019 (Author) commented Oct 11, 2021

@NiklasGustafsson I see this as eventually related to the ML.NET/TorchSharp integration:

[image: diagram of ML.NET integration with TensorFlow and PyTorch]

It is not a simple decision. Let us hope there will be active participation in such a discussion from the TorchSharp and ML.NET communities.

@NiklasGustafsson (Contributor)

@GeorgeS2019 I'm curious -- where did the diagram above come from?

@GeorgeS2019 (Author)

@NiklasGustafsson I took an existing diagram from Microsoft's documentation on TensorFlow integration with ML.NET, cloned it, and added the PyTorch part.

@NiklasGustafsson (Contributor) commented Oct 11, 2021

One of the differences between TF and PyTorch is that TF saves the graph using protobuf, which is language-independent.

In PyTorch, there really isn't a graph, and the weights are saved using Python pickling, which is a serialization format that is specific to Python. That's why you need to use something like ONNX to get not only the weights, but also the graph. This is why you need the model code in TorchSharp before you can load the weights.

See, for example, the PyTorch documentation for 'load_state_dict()' (which is analogous to what we're doing in TorchSharp):

    load_state_dict(state_dict, strict=True)
        Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.

Once the model has been exported to ONNX, we can (theoretically) use that to recreate the model code in C# (or F#), which can then be used to load the model weights. That will take a lot of work, so it's not something that is coming soon. In the meantime, you will have to recreate the model code in .NET if you want to load parameters from Python.

@GeorgeS2019 (Author)

I am closing this issue now that a unit test has been provided showing how to save a state_dict in PyTorch and load it into TorchSharp.
