In [7]:
from abc import ABC
import torch
import torch.nn as nn

In [8]:
# Reproducibility: seed PyTorch's global RNG before anything random happens. This fixes
# both the synthetic data below and the random initialisation of the model built later,
# so Restart & Run All reproduces the same numbers.
SEED = 0
N_PER_CLASS = 100  # samples drawn for each of the two classes

torch.manual_seed(SEED)
# Dataset 1 (label 0): each coordinate drawn from N(1.7, 1), then shifted by +1 -> centred at (2.7, 2.7).
x0, y0 = torch.normal(mean=1.7, std=1, size=(N_PER_CLASS, 2)) + 1, torch.zeros(N_PER_CLASS)
# Dataset 2 (label 1): each coordinate drawn from N(-1.7, 1), then shifted by +1 -> centred at (-0.7, -0.7).
x1, y1 = torch.normal(mean=-1.7, std=1, size=(N_PER_CLASS, 2)) + 1, torch.ones(N_PER_CLASS)
# Stack along the sample axis: train_x is (2 * N_PER_CLASS, 2); train_y is (2 * N_PER_CLASS,) float labels.
train_x, train_y = torch.cat((x0, x1), 0), torch.cat((y0, y1), 0)


class LR(nn.Module, ABC):
    """Logistic-regression classifier: one linear layer followed by a sigmoid.

    Maps inputs of shape (..., in_features) to probabilities of shape (..., 1).

    Args:
        in_features: number of input features per sample. Defaults to 2 so that
            existing ``LR()`` calls and state dicts saved from them keep working.
    """

    # NOTE(review): the ABC mixin has no effect on behaviour here -- this class
    # declares no abstract methods.

    def __init__(self, in_features=2):
        super().__init__()
        # Single output unit: the logit for the positive class.
        self.features = nn.Linear(in_features, 1)
        # Squash the logit into (0, 1) so the output can be fed to nn.BCELoss.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the predicted probability of label 1 for each row of ``x``."""
        return self.sigmoid(self.features(x))


N_ITERATIONS = 500  # number of full-batch gradient steps
LOG_EVERY = 20      # print the loss once every LOG_EVERY iterations

lr_net = LR()
loss_fn = nn.BCELoss()  # expects probabilities, which LR's sigmoid output provides
optimizer = torch.optim.SGD(lr_net.parameters(), lr=0.01, momentum=0.9)

for iteration in range(N_ITERATIONS):
    optimizer.zero_grad()  # clear gradients accumulated by the previous step
    y_pred = lr_net(train_x)  # shape (N, 1): one probability per sample
    # squeeze(1) rather than squeeze(): drop only the output dimension, so even a
    # single-sample batch keeps shape (1,) and matches train_y.
    loss = loss_fn(y_pred.squeeze(1), train_y)
    loss.backward()
    optimizer.step()

    if iteration % LOG_EVERY == 0:
        # .item() turns the 0-d loss tensor into a plain Python float for logging.
        print("iteration：{iteration},    loss: {loss}".format(iteration=iteration, loss=loss.item()))

iteration：0,    loss: 0.5072945356369019
iteration：20,    loss: 0.24114477634429932
iteration：40,    loss: 0.16167768836021423
iteration：60,    loss: 0.13019175827503204
iteration：80,    loss: 0.11143177002668381
iteration：100,    loss: 0.09831245243549347
iteration：120,    loss: 0.08842919021844864
iteration：140,    loss: 0.0806545540690422
iteration：160,    loss: 0.07435550540685654
iteration：180,    loss: 0.06913800537586212
iteration：200,    loss: 0.06474029272794724
iteration：220,    loss: 0.06098010390996933
iteration：240,    loss: 0.05772606283426285
iteration：260,    loss: 0.05488075688481331
iteration：280,    loss: 0.0523703470826149
iteration：300,    loss: 0.05013774335384369
iteration：320,    loss: 0.048138175159692764
iteration：340,    loss: 0.04633596912026405
iteration：360,    loss: 0.04470236599445343
iteration：380,    loss: 0.04321398213505745
iteration：400,    loss: 0.041851550340652466
iteration：420,    loss: 0.04059908539056778
iteration：440,    loss: 0.0394432321190

In [9]:
# Capture the trained parameters as an ordered name -> tensor mapping
# ('features.weight', 'features.bias'). keep_vars=False (the default) yields detached
# tensors; they still share storage with the live parameters, so any further training
# of lr_net would be reflected here.
lr_dict = lr_net.state_dict(keep_vars=False)
lr_dict

OrderedDict([('features.weight', tensor([[-1.6137, -1.4067]])),
             ('features.bias', tensor([2.1918]))])

In [10]:
# Serialise the parameter mapping to a file in the current working directory.
torch.save(obj=lr_dict, f='torch_dict.pth')

In [11]:
# A freshly initialised model with the same architecture; its random parameters are
# overwritten below by the ones saved to disk.
lr_net_new = LR()

# torch.load reads back the object written by torch.save; load_state_dict then copies
# those tensors into lr_net_new's parameters (strict=True: the keys must match exactly).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints; newer
# PyTorch releases accept weights_only=True to restrict what can be unpickled.
lr_net_new.load_state_dict(state_dict=torch.load(f='torch_dict.pth'), strict=True)
lr_net_new

LR(
  (features): Linear(in_features=2, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [12]:
# Verify the save/load round trip instead of eyeballing the two printed dicts:
# same parameter names, and bit-identical tensors.
assert lr_net_new.state_dict().keys() == lr_net.state_dict().keys(), "parameter names differ"
assert all(
    torch.equal(saved, restored)
    for saved, restored in zip(lr_net.state_dict().values(), lr_net_new.state_dict().values())
), "loaded parameters differ from the trained model"
lr_net_new.state_dict()

OrderedDict([('features.weight', tensor([[-1.6137, -1.4067]])),
             ('features.bias', tensor([2.1918]))])