In [209]:
from abc import ABC
import torch
import torch.nn as nn
import copy

In [210]:
# Class 0: 100 two-dimensional points drawn from N(1.7, 1) and shifted by +1, labelled 0.
x0 = torch.normal(mean=1.7, std=1, size=(100, 2)) + 1
y0 = torch.zeros(100)

# Class 1: 100 two-dimensional points drawn from N(-1.7, 1) and shifted by +1, labelled 1.
x1 = torch.normal(mean=-1.7, std=1, size=(100, 2)) + 1
y1 = torch.ones(100)

# Stack both classes along the sample axis: features are (200, 2), labels are (200,).
train_x = torch.cat([x0, x1], dim=0)
train_y = torch.cat([y0, y1], dim=0)


class LR(nn.Module, ABC):
    """Small binary classifier: two stacked linear layers followed by a sigmoid.

    Maps inputs with 2 features to a single value in (0, 1) per sample.
    """

    def __init__(self):
        """Register the two linear layers and the sigmoid activation."""
        super().__init__()
        # Attribute names become the prefixes of the state-dict keys
        # ('features0.weight', 'features0.bias', 'features1.weight', 'features1.bias').
        self.features0 = nn.Linear(2, 2)
        self.features1 = nn.Linear(2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply features0, then features1, then the sigmoid, and return the result."""
        return self.sigmoid(self.features1(self.features0(x)))


# Model instance, loss function and optimiser used for training.
lr_net = LR()
loss_fn = nn.BCELoss()  # binary cross-entropy on the model's sigmoid outputs
optimizer = torch.optim.SGD(
    params=lr_net.parameters(),
    lr=0.01,
    momentum=0.9,
)

In [211]:
# Print the name of each parameter registered on the model (the parameter tensor itself is unused).
for name, para in lr_net.named_parameters():
    print(name)

features0.weight
features0.bias
features1.weight
features1.bias


In [212]:
# Deep-copy the freshly initialised state dict (an OrderedDict) so this snapshot keeps the
# initial values; it is displayed again after training for comparison.
lr_dict_init = copy.deepcopy(lr_net.state_dict())  # OrderedDict
lr_dict_init

OrderedDict([('features0.weight',
              tensor([[-0.4356,  0.2673],
                      [ 0.0616, -0.6159]])),
             ('features0.bias', tensor([-0.4102, -0.2546])),
             ('features1.weight', tensor([[-0.0896,  0.0550]])),
             ('features1.bias', tensor([-0.3759]))])

In [213]:
for iteration in range(500):
    # Clear the gradients left over from the previous iteration.
    optimizer.zero_grad()

    # Forward pass: squeeze() drops the size-1 output dimension so the
    # predictions line up with the 1-D label vector.
    y_pred = lr_net(train_x)
    loss = loss_fn(y_pred.squeeze(), train_y)

    # Backward pass, then one optimiser update.
    loss.backward()
    optimizer.step()

    # Report the current loss on every 20th iteration.
    if not iteration % 20:
        print(f"iteration：{iteration},    loss: {loss}")

iteration：0,    loss: 0.6967438459396362
iteration：20,    loss: 0.4755525290966034
iteration：40,    loss: 0.24046087265014648
iteration：60,    loss: 0.1566372662782669
iteration：80,    loss: 0.11610709130764008
iteration：100,    loss: 0.09247508645057678
iteration：120,    loss: 0.07726152241230011
iteration：140,    loss: 0.06672447174787521
iteration：160,    loss: 0.05904078856110573
iteration：180,    loss: 0.05322685092687607
iteration：200,    loss: 0.04869944602251053
iteration：220,    loss: 0.04508866369724274
iteration：240,    loss: 0.042149100452661514
iteration：260,    loss: 0.0397125743329525
iteration：280,    loss: 0.03766081854701042
iteration：300,    loss: 0.035908836871385574
iteration：320,    loss: 0.03439432010054588
iteration：340,    loss: 0.033070798963308334
iteration：360,    loss: 0.03190302476286888
iteration：380,    loss: 0.030863866209983826
iteration：400,    loss: 0.029932165518403053
iteration：420,    loss: 0.029091153293848038
iteration：440,    loss: 0.0283274408

In [214]:
# Grab the model's state dict after the training loop has run and display it.
lr_dict = lr_net.state_dict()  # state_dict() returns a dictionary containing the whole state of the module
lr_dict

OrderedDict([('features0.weight',
              tensor([[-0.2577,  0.3256],
                      [-1.2212, -0.9174]])),
             ('features0.bias', tensor([-0.6274,  1.4793])),
             ('features1.weight', tensor([[-0.3847,  2.0133]])),
             ('features1.bias', tensor([1.2211]))])

In [215]:
# Re-display the snapshot taken before training, for comparison with lr_dict.
lr_dict_init

OrderedDict([('features0.weight',
              tensor([[-0.4356,  0.2673],
                      [ 0.0616, -0.6159]])),
             ('features0.bias', tensor([-0.4102, -0.2546])),
             ('features1.weight', tensor([[-0.0896,  0.0550]])),
             ('features1.bias', tensor([-0.3759]))])

In [216]:
# Write the state dict to disk; 'torch_dict.pth' is the file that stores it.
torch.save(obj=lr_dict, f='torch_dict.pth')

In [217]:
# Create a new model instance of the same class.
lr_net_new = LR()

# torch.load() reads back the object that torch.save() wrote to the file;
# load_state_dict() then copies those parameters and buffers into the new module
# and its descendants.
lr_net_new.load_state_dict(state_dict=torch.load(f='torch_dict.pth'))
lr_net_new

LR(
  (features0): Linear(in_features=2, out_features=2, bias=True)
  (features1): Linear(in_features=2, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [218]:
# Display the new model's state dict after loading the saved file.
lr_net_new.state_dict()

OrderedDict([('features0.weight',
              tensor([[-0.2577,  0.3256],
                      [-1.2212, -0.9174]])),
             ('features0.bias', tensor([-0.6274,  1.4793])),
             ('features1.weight', tensor([[-0.3847,  2.0133]])),
             ('features1.bias', tensor([1.2211]))])