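# File I/O

This notebook shows how to persist tensors and model parameters to disk with `torch.save` and restore them with `torch.load`.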

## Loading and Saving Tensors

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [3]:
# Build the integer tensor [0, 1, 2, 3] and write it to the file "arange-4".
x = torch.arange(0, 4)
torch.save(obj=x, f="arange-4")

In [5]:
# Read the tensor back. `torch.load` unpickles the file, so only load files you trust.
x2 = torch.load("arange-4")
# The round trip must reproduce the saved tensor exactly.
assert torch.equal(x2, x)
x2

tensor([0, 1, 2, 3])

In [6]:
# Several tensors can be saved to one file by bundling them in a list.
y = torch.zeros(4)
torch.save(obj=[x, y], f="xy")

In [7]:
# Loading returns the same list structure that was saved: [x, y].
xy_list = torch.load("xy")
xy_list

[tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.])]

In [8]:
# A dict keyed by name lets each tensor be looked up by key after loading.
mydict = dict(x=x, y=y)
torch.save(obj=mydict, f="xydict")

In [10]:
# The loaded object is a dict with the same keys ("x", "y") as the one saved.
loaded_dict = torch.load("xydict")
loaded_dict

{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

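## Loading and Saving Model Parameters

`net.state_dict()` stores only the parameter tensors, keyed by name, and not the model's code. To restore a model, first instantiate the same architecture (here `MLP`), then load the saved parameters into it with `load_state_dict`.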

In [11]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        in_features: size of the last dimension of the input (default 20).
        hidden_features: width of the hidden layer (default 256).
        out_features: size of the last dimension of the output (default 10).

    The submodules are registered as ``hidden`` and ``out``. These attribute
    names become the key prefixes in ``state_dict()`` (e.g. ``hidden.weight``),
    so they must stay stable for previously saved parameter files to load.
    """

    def __init__(self, in_features=20, hidden_features=256, out_features=10):
        super().__init__()
        self.hidden = nn.Linear(in_features, hidden_features)
        self.out = nn.Linear(hidden_features, out_features)

    def forward(self, X):
        """Map ``X`` of shape (..., in_features) to (..., out_features)."""
        return self.out(F.relu(self.hidden(X)))

# Seed the global RNG so the random initial weights and the random input
# (and therefore the output shown below) are reproducible on Restart & Run All.
torch.manual_seed(0)
net = MLP()
# Note: this rebinds `x` (earlier the arange tensor) to a (2, 20) random input batch.
x = torch.rand(2, 20)
net(x)

tensor([[ 0.0776,  0.1052,  0.0336, -0.0342,  0.4236,  0.0860, -0.0179, -0.0939,
         -0.0709, -0.0778],
        [ 0.2364,  0.0829,  0.0299,  0.0339,  0.1920,  0.0753, -0.0074, -0.1779,
         -0.0821, -0.0153]], grad_fn=<AddmmBackward>)

In [15]:
# List every learnable tensor with its shape; these names are the parameter keys of net.state_dict().
for param_name, param in net.named_parameters():
    print(param_name, param.shape)

hidden.weight torch.Size([256, 20])
hidden.bias torch.Size([256])
out.weight torch.Size([10, 256])
out.bias torch.Size([10])


In [16]:
# Save only the parameter tensors (the state_dict), not the MLP class itself.
torch.save(obj=net.state_dict(), f="mlp.params")

In [22]:
# Rebuild the architecture from code, then fill it with the saved parameters.
clone = MLP()
clone.load_state_dict(torch.load("mlp.params"))
# With identical parameters and input, the clone must reproduce net's output exactly.
Y_clone = clone(x)
assert torch.equal(Y_clone, net(x))
Y_clone

tensor([[ 0.0776,  0.1052,  0.0336, -0.0342,  0.4236,  0.0860, -0.0179, -0.0939,
         -0.0709, -0.0778],
        [ 0.2364,  0.0829,  0.0299,  0.0339,  0.1920,  0.0753, -0.0074, -0.1779,
         -0.0821, -0.0153]], grad_fn=<AddmmBackward>)

In [23]:
# Parameters can also be saved per submodule. Each submodule's state_dict keys
# are relative to that submodule (e.g. "weight", "bias"), so each file must be
# loaded back into the matching submodule of a freshly built MLP.
torch.save(net.hidden.state_dict(), "mlp-hidden.params")
torch.save(net.out.state_dict(), "mlp-out.params")

clone2 = MLP()
clone2.hidden.load_state_dict(torch.load("mlp-hidden.params"))
clone2.out.load_state_dict(torch.load("mlp-out.params"))
# Restoring layer by layer must give the same result as the original network.
Y_clone2 = clone2(x)
assert torch.equal(Y_clone2, net(x))
Y_clone2

tensor([[ 0.0776,  0.1052,  0.0336, -0.0342,  0.4236,  0.0860, -0.0179, -0.0939,
         -0.0709, -0.0778],
        [ 0.2364,  0.0829,  0.0299,  0.0339,  0.1920,  0.0753, -0.0074, -0.1779,
         -0.0821, -0.0153]], grad_fn=<AddmmBackward>)