
Loading tensors in lists/dict that have not yet been instantiated #101

@vmoens

Description


🚀 The feature

We'd like to be able to load tensors that are saved on disk but are not yet present in the destination module.

Motivation, pitch

Say we have a module that stores a list of tensors. During training, we increment that list.

If I'm using regular torch.save(state_dict), we end up with a dictionary containing a list of tensors, and we can simply load it back where it belongs (since loading is not done in place).

With torchsnapshot, my understanding is that the snapshot will look at my current state_dict and repopulate it in place. Hence, if my list of tensors is empty (which I expect it to be when I load a checkpoint), all the tensors in the list will be discarded.

Example:

from torchsnapshot import StateDict, Snapshot
import torch
import os

def list_files(startpath):
    for root, dirs, files in os.walk(startpath):
        level = root.replace(startpath, '').count(os.sep)
        indent = ' ' * 4 * (level)
        print('{}{}/'.format(indent, os.path.basename(root)))
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print('{}{}'.format(subindent, f))

class ClassWithSD:
    def __init__(self):
        self.obj = []
    def state_dict(self):
        return {"obj": self.obj}
    def load_state_dict(self, sd):
        self.obj = sd["obj"]


x = ClassWithSD()

# let's put 2 tensors in our list. We'd like to get them back when loading
x.obj.append(torch.tensor([1.0]))
x.obj.append(torch.tensor([2.0]))

app_state = {"x": x}
Snapshot.take(app_state=app_state, path="./")


snapshot = Snapshot(path="./")
y = ClassWithSD()
app_state = {"x": y}
snapshot.restore(app_state=app_state)

list_files("./0")
print("content before take:", x.obj)
print("content after restore:", y.obj)

# with torch.save

torch.save(x.state_dict(), "torch_saved.pt")
y = ClassWithSD()
y.load_state_dict(torch.load("torch_saved.pt"))
print("torch.save:", y.obj)

Alternatives

No response

Additional context

Looking at this: https://github.com/pytorch/torchsnapshot/blob/4596fc6baf0fc9662cbfbc8d363cf115dc46d517/torchsnapshot/snapshot.py#L681-L736

I guess what I would like is that, if not all available_entries are loaded, the remaining logical_paths are still loaded into the state_dict that is passed to stateful.load_state_dict(...) at line 736.
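To make the request concrete, here is a minimal pure-Python sketch of the proposed merge step. The names `available_entries` and `logical_paths` echo the linked torchsnapshot code, but the helper below is hypothetical and does not reflect torchsnapshot's actual internals:

```python
# Hypothetical sketch of the requested behavior: after restoring the
# entries that matched the live state_dict in place, also surface the
# logical paths that exist in the snapshot but had no live counterpart,
# so that load_state_dict(...) still receives them.
def merge_restored(current_state_dict, restored_entries, remaining_entries):
    # current_state_dict: what the stateful currently holds (may be empty)
    # restored_entries: logical_path -> value pairs matched in place
    # remaining_entries: logical_path -> value pairs with no live counterpart
    merged = dict(current_state_dict)
    merged.update(restored_entries)
    # Proposed change: keep the on-disk entries that could not be
    # loaded in place instead of silently discarding them.
    merged.update(remaining_entries)
    return merged

# Example: the live list is empty, but the snapshot holds two values.
sd = merge_restored({"obj": []}, {}, {"obj": [1.0, 2.0]})
```

With this behavior, `sd["obj"]` would contain the two saved values even though the destination list was empty at restore time, matching what `torch.load` gives in the example above.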
