In [7]:
from abc import ABC
import torch
import torch.nn as nn

In [8]:
x0, y0 = torch.normal(mean=1.7, std=1, size=(100, 2)) + 1, torch.zeros(100)  # 数据集1
x1, y1 = torch.normal(mean=-1.7, std=1, size=(100, 2)) + 1, torch.ones(100)  # 数据集2
train_x, train_y = torch.cat((x0, x1), 0), torch.cat((y0, y1), 0)


class LR(nn.Module, ABC):
    def __init__(self):
        super(LR, self).__init__()
        self.features = nn.Linear(2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.features(x)
        x = self.sigmoid(x)
        return x


lr_net = LR()
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(lr_net.parameters(), lr=0.01, momentum=0.9)

for iteration in range(500):
    optimizer.zero_grad()  # 梯度清零
    y_pred = lr_net(train_x)
    loss = loss_fn(y_pred.squeeze(), train_y)
    loss.backward()
    optimizer.step()

    if iteration % 20 == 0:
        print("iteration：{iteration},    loss: {loss}".format(iteration=iteration, loss=loss))

iteration：0,    loss: 2.07944393157959
iteration：20,    loss: 0.16775953769683838
iteration：40,    loss: 0.12824194133281708
iteration：60,    loss: 0.11418388038873672
iteration：80,    loss: 0.10358712822198868
iteration：100,    loss: 0.09519942104816437
iteration：120,    loss: 0.0884229838848114
iteration：140,    loss: 0.08285802602767944
iteration：160,    loss: 0.07822190970182419
iteration：180,    loss: 0.07430948317050934
iteration：200,    loss: 0.07096906006336212
iteration：220,    loss: 0.0680866688489914
iteration：240,    loss: 0.06557554006576538
iteration：260,    loss: 0.0633687674999237
iteration：280,    loss: 0.06141415983438492
iteration：300,    loss: 0.0596705824136734
iteration：320,    loss: 0.0581052340567112
iteration：340,    loss: 0.056691765785217285
iteration：360,    loss: 0.055408719927072525
iteration：380,    loss: 0.05423854663968086
iteration：400,    loss: 0.0531667023897171
iteration：420,    loss: 0.052181076258420944
iteration：440,    loss: 0.051271501928567886

In [9]:
lr_dict = lr_net.state_dict()  # Returns a dictionary containing a whole state of the module
lr_dict

OrderedDict([('features.weight', tensor([[-1.6983, -1.4639]])),
             ('features.bias', tensor([2.2274]))])

In [10]:
torch.save(lr_dict,
           'torch_dict.pth')  # 保存的状态字典的文件

In [11]:
lr_net_new = LR()  # 定义一个新的模型

# torch.load():Loads an object saved with torch.save() from a file.
# .load_state_dict():Copies parameters and buffers from :attr:`state_dict` into this module and its descendants
lr_net_new.load_state_dict(torch.load('lr_dict.pth'))

<All keys matched successfully>

In [12]:
lr_net_new.state_dict()

OrderedDict([('features.weight', tensor([[-1.2590, -1.4755]])),
             ('features.bias', tensor([2.2686]))])