<a href="https://colab.research.google.com/github/junxnone/samples/blob/main/pytorch/pytorch_models_save.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [30]:
from torchinfo import summary
import torch
from torch import nn

class LSTMNet(nn.Module):
    """Embedding -> stacked LSTM -> linear projection back to vocabulary size.

    Args:
        vocab_size: Number of embedding rows and width of the output layer.
        embed_dim: Size of each embedding vector fed to the LSTM.
        hidden_dim: Hidden-state size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Submodules are created in embedding -> encoder -> decoder order; the
        # attribute names become the prefixes of the state_dict keys.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.encoder = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.decoder = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        """Run a batch of token ids through the network.

        Args:
            x: Integer index tensor laid out as (batch, seq), since the
                encoder is built with ``batch_first=True``.

        Returns:
            A tuple ``(logits, state)`` where ``logits`` is flattened to
            ``(batch * seq, vocab_size)`` and ``state`` is whatever the LSTM
            returns as its final hidden/cell state.
        """
        sequence_features, state = self.encoder(self.embedding(x))
        logits = self.decoder(sequence_features)
        # Merge the batch and sequence axes, keeping the vocabulary axis last.
        flat_logits = logits.view(-1, logits.size(2))
        return flat_logits, state

# Per-layer report for a freshly initialised model, traced with a single
# (1, 100) input of integer (torch.long) token ids. The call is the last
# expression of the cell, so its return value is what the notebook displays.
summary(
    LSTMNet(),
    input_size=(1, 100),
    dtypes=[torch.long],
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    col_width=16,
    row_settings=["var_names"],
    verbose=0,
)

Layer (type (var_name))                  Kernel Shape     Output Shape     Param #          Mult-Adds
LSTMNet (LSTMNet)                        --               [100, 20]        --               --
├─Embedding (embedding)                  --               [1, 100, 300]    6,000            6,000
├─LSTM (encoder)                         --               [1, 100, 512]    3,768,320        376,832,000
├─Linear (decoder)                       --               [1, 100, 20]     10,260           10,260
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 376.85
Input size (MB): 0.00
Forward/backward pass size (MB): 0.67
Params size (MB): 15.14
Estimated Total Size (MB): 15.80

## 保存结构和参数

In [7]:
# Pickle the whole module object (architecture + weights) to disk.
torch.save(LSTMNet(), 'LSTMNet.pt')
# A whole-module checkpoint is an arbitrary pickle, so the restricted
# weights-only unpickler (the default in newer PyTorch releases) rejects it.
# Opt out explicitly. NOTE: unpickling can execute arbitrary code, so only do
# this for checkpoint files you created yourself.
load_model = torch.load("LSTMNet.pt", weights_only=False)

In [10]:
!ls LSTMNet.pt -alh

-rw-r--r-- 1 root root 15M Aug  5 07:38 LSTMNet.pt


In [9]:
# Display the class of the object that torch.load returned for the
# full-model checkpoint.
type(load_model)

In [12]:
# Detailed report for the model restored from the full-model checkpoint.
# verbose=2 makes torchinfo print the table itself (including per-parameter
# rows); the trailing ';' stops Jupyter from echoing the returned object as
# well, which would otherwise render the same table a second time.
summary(
    load_model,
    (1, 100),
    dtypes=[torch.long],
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
);

Layer (type (var_name))                  Kernel Shape     Output Shape     Param #          Mult-Adds
LSTMNet (LSTMNet)                        --               [100, 20]        --               --
├─Embedding (embedding)                  --               [1, 100, 300]    6,000            6,000
│    └─weight                            [300, 20]                         └─6,000
├─LSTM (encoder)                         --               [1, 100, 512]    3,768,320        376,832,000
│    └─weight_ih_l0                      [2048, 300]                       ├─614,400
│    └─weight_hh_l0                      [2048, 512]                       ├─1,048,576
│    └─bias_ih_l0                        [2048]                            ├─2,048
│    └─bias_hh_l0                        [2048]                            ├─2,048
│    └─weight_ih_l1                      [2048, 512]                       ├─1,048,576
│    └─weight_hh_l1                      [2048, 512]                       ├─1,048,576
│    └

Layer (type (var_name))                  Kernel Shape     Output Shape     Param #          Mult-Adds
LSTMNet (LSTMNet)                        --               [100, 20]        --               --
├─Embedding (embedding)                  --               [1, 100, 300]    6,000            6,000
│    └─weight                            [300, 20]                         └─6,000
├─LSTM (encoder)                         --               [1, 100, 512]    3,768,320        376,832,000
│    └─weight_ih_l0                      [2048, 300]                       ├─614,400
│    └─weight_hh_l0                      [2048, 512]                       ├─1,048,576
│    └─bias_ih_l0                        [2048]                            ├─2,048
│    └─bias_hh_l0                        [2048]                            ├─2,048
│    └─weight_ih_l1                      [2048, 512]                       ├─1,048,576
│    └─weight_hh_l1                      [2048, 512]                       ├─1,048,576
│    └

## 仅保存参数
- 适用于在已有模型结构上加载不同的已保存权重

In [24]:
torch.save(LSTMNet().state_dict(), 'LSTMNet_state_dict.pt')
load_model = LSTMNet()
state_dict = torch.load("LSTMNet_state_dict.pt")
print(type(state_dict))
load_model.load_state_dict(state_dict)
print(type(load_model))
!ls LSTMNet_state_dict.pt -alh
summary(
    load_model,
    (1, 100),
    dtypes=[torch.long],
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)


<class 'collections.OrderedDict'>
<class '__main__.LSTMNet'>
-rw-r--r-- 1 root root 15M Aug  5 07:51 LSTMNet_state_dict.pt
Layer (type (var_name))                  Kernel Shape     Output Shape     Param #          Mult-Adds
LSTMNet (LSTMNet)                        --               [100, 20]        --               --
├─Embedding (embedding)                  --               [1, 100, 300]    6,000            6,000
│    └─weight                            [300, 20]                         └─6,000
├─LSTM (encoder)                         --               [1, 100, 512]    3,768,320        376,832,000
│    └─weight_ih_l0                      [2048, 300]                       ├─614,400
│    └─weight_hh_l0                      [2048, 512]                       ├─1,048,576
│    └─bias_ih_l0                        [2048]                            ├─2,048
│    └─bias_hh_l0                        [2048]                            ├─2,048
│    └─weight_ih_l1                      [2048, 512]     

Layer (type (var_name))                  Kernel Shape     Output Shape     Param #          Mult-Adds
LSTMNet (LSTMNet)                        --               [100, 20]        --               --
├─Embedding (embedding)                  --               [1, 100, 300]    6,000            6,000
│    └─weight                            [300, 20]                         └─6,000
├─LSTM (encoder)                         --               [1, 100, 512]    3,768,320        376,832,000
│    └─weight_ih_l0                      [2048, 300]                       ├─614,400
│    └─weight_hh_l0                      [2048, 512]                       ├─1,048,576
│    └─bias_ih_l0                        [2048]                            ├─2,048
│    └─bias_hh_l0                        [2048]                            ├─2,048
│    └─weight_ih_l1                      [2048, 512]                       ├─1,048,576
│    └─weight_hh_l1                      [2048, 512]                       ├─1,048,576
│    └

In [25]:
# List every entry of the loaded state_dict together with its tensor shape.
# (A bare `type(state_dict)` before the loop was never displayed, because only
# the last expression of a cell is echoed, so it has been dropped.)
for key, tensor in state_dict.items():
    print(f'{key}  : {tensor.shape}')


embedding.weight  : torch.Size([20, 300])
encoder.weight_ih_l0  : torch.Size([2048, 300])
encoder.weight_hh_l0  : torch.Size([2048, 512])
encoder.bias_ih_l0  : torch.Size([2048])
encoder.bias_hh_l0  : torch.Size([2048])
encoder.weight_ih_l1  : torch.Size([2048, 512])
encoder.weight_hh_l1  : torch.Size([2048, 512])
encoder.bias_ih_l1  : torch.Size([2048])
encoder.bias_hh_l1  : torch.Size([2048])
decoder.weight  : torch.Size([20, 512])
decoder.bias  : torch.Size([20])
