fix(examples): Move non-persistent buffers to a valid device #91
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Non-persistent buffer devices
During deserialization using
load_into_module
, only the buffers loaded from the deserializer are overwritten on the target module, and any extra buffers on the target module remain unmodified. This can lead to an error if there are existing buffers on the target module that are on a different device from the deserializer, since the model will end up scattered across multiple devices.This primarily arises as an issue when serializing with
include_non_persistent_buffers=False
, since non-persistent buffers may be created on any device, and will not be overwritten by the deserializer. In particular, this led to an error when running the code indeserialize.py
(which does not expect buffers from anything but the deserializer) on a model serialized usingexamples/hf_serialization.py
(which usesinclude_non_persistent_buffers=False
).This change adds an extra section to
deserialize.py
that moves remaining non-persistent buffers to a device matching the deserializer after regular deserialization finishes. It is not included in the speed measurement because non-persistent buffers are explicitly not any business of the deserializer—only of external code—and they can be included and measured by settinginclude_non_persistent_buffers=True
if it is desired to force the deserializer to handle them instead.