- Reference: [Serialization Semantics](https://pytorch.org/docs/stable/notes/serialization.html)

In [1]:
import torch

# Saving and loading tensors

In [2]:
# 'torch.save()' and 'torch.load()' let you easily save and load tensors:

t = torch.tensor([1., 2.])
torch.save(t, 'tensor.pt')

In [3]:
torch.load('tensor.pt')

tensor([1., 2.])

In [4]:
# 'torch.save()' and 'torch.load()' use Python's pickle by default,
# so you can also save multiple tensors as part of Python objects like tuples, lists, and dicts:
d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
torch.save(d, 'tensor_dict.pt')

In [5]:
torch.load('tensor_dict.pt')

{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

# Saving and loading tensors preserves views

In [6]:
numbers = torch.arange(1, 10)
evens = numbers[1::2]

torch.save([numbers, evens], 'tensors.pt')

In [7]:
# Saving tensors preserve their view relationship
# Behind the scenes, these tensors share the same "storage"
loaded_numbers, loaded_evens = torch.load('tensors.pt')
loaded_evens *= 2
print(loaded_numbers)

tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])


In [8]:
# In some caees, however, saving the current storage objects may be unnecessary and
# create prohibitively large files.
# In the folloing snippet a storage must larger than saved tensor is written to a file:
large = torch.arange(1, 1000)
small = large[0:5]

torch.save(small, 'small.pt')

In [9]:
loaded_small = torch.load('small.pt')
print(loaded_small.storage().size())

999


  print(loaded_small.storage().size())


In [10]:
# When saving tensors with fewer elements than their storage objects,
# the size of the svaed file can be reduced by first cloning the tensors.
torch.save(small.clone(), 'small.pt') # saves a clone of small

In [11]:
loaded_small = torch.load('small.pt')
print(loaded_small.storage().size())

5


# Saving and loading torch.nn.Modules

- see also: [Saving and Loading Models](https://pytorch.org/tutorials/beginner/saving_loading_models.html)

In PyTorch, a module's state is frequently serialized using a 'state dict'. A module's state dict contains all of its parameters and persistent buffers:

In [12]:
bn = torch.nn.BatchNorm1d(3, track_running_stats=True)

In [13]:
list(bn.named_parameters())

[('weight',
  Parameter containing:
  tensor([1., 1., 1.], requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([0., 0., 0.], requires_grad=True))]

In [14]:
list(bn.named_buffers())

[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

In [15]:
bn.state_dict()

OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

Instead of saving a module directly, for compatibility reasons it is recommended to instead save only its state dict. Python modules even have a function, `load_state_dict()`, to restore their states from a state dict:

In [16]:
torch.save(bn.state_dict(), 'bn.pt')
bn_state_dict = torch.load('bn.pt')
new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
new_bn.load_state_dict(bn_state_dict)

<All keys matched successfully>

Even custom modules and modules containing other modules have state dicts and can use this pattern:

In [17]:
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

    def forward(self, input):
        out = self.l0(input)
        out_relu = torch.nn.functional.relu(out)
        return self.l1(out_relu)

m = MyModule()
print(m.state_dict())

OrderedDict([('l0.weight', tensor([[ 0.4830,  0.2238, -0.3981,  0.0723],
        [ 0.2290, -0.3365, -0.2250,  0.0290]])), ('l0.bias', tensor([-0.0201,  0.2434])), ('l1.weight', tensor([[-0.1918,  0.5304]])), ('l1.bias', tensor([-0.2418]))])


In [18]:
torch.save(m.state_dict(), 'mymodule.pt')
m_state_dict = torch.load('mymodule.pt')
new_m = MyModule()
new_m.load_state_dict(m_state_dict)

<All keys matched successfully>