# 逻辑回归
通过线性回归模型，我们已经了解过一个模型基本构成，其实模型就是一个数学函数簇，通过数据只是学习函数簇的参数得到具体的函数，从本篇开始学习逻辑回归，在学习逻辑回归之前，我们大概聊一下模型的分类。
通常来说，模型可以分如下几类：
+ 监督学习模型：训练样本不包含答案
+ 非监督学习模型：训练样本包含答案
  - 回归模型：答案是连续的实数
  - 分类模型：答案是离散的整数，分类模型又分为二分类和多分类模型

那么回归模型属于监督学习的回归模型，而本篇介绍的逻辑回归，就是属于监督学习中的二分类模型。


## 逻辑回归模型
下面我们快速撕开逻辑回归模型本质
+ 数据：根据问题来的，比如预测明天会不会下雨，我们会收集很多下雨天和非下雨天之前的天气情况，比如一天前的湿度、温度、前一周的平均下雨天数等【各种变量】，还会收集当天的天气【答案或者目标】，这个就是训练样本
+ 模型：逻辑回归就是规定函数形式就是：y_hat = sigmod(-(b + a1\*x1 + a2\*x2 + an\*xm))，sigmod 是一个将任意实数映射到 0-1 的概率的一个转换函数，sigmod(x)=1/(1+exp(-x))，sigmod(x) 的导数 = sigmod(1)\*(1-sigmod(x))
+ 参数：b a1 a2 am 就是参数
+ 损失函数：z = sum(-(y\*log(y_hat) + (1-y)\*log(1-y_hat)))/n，n 是样本数量
+ 优化方法：梯度下降等，得益于 PyTorch 自动求梯度，我们不再需要显示求出损失函数的导数

可以看到，逻辑回归的模型一部分和线性回归很相似，这也是它名字中包含回归的原因。

下面我们来用 PyTorch 来构建整个逻辑回归模型。


## 1. 数据准备
我们这里选择直接构造一个训练集

In [8]:
import torch
import numpy as np
X = torch.tensor([[2., 1.],
    [2., 2.],
    [5., 4.],
    [4., 5.],
    [2., 3.],
    [3., 2.],
    [6., 5.],
    [4., 1.],
    [6., 3.],
    [7, 4]], dtype=torch.float
)
print('X.shape', X.shape)
true_Y = torch.tensor([0, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=torch.float)
print('true_Y.shape', true_Y.shape)
print(true_Y)

X.shape torch.Size([10, 2])
true_Y.shape torch.Size([10])
tensor([0., 0., 1., 1., 0., 0., 1., 0., 1., 1.])


## 2. 定义基本工具函数
+ 模型函数，指定样本和权重，对应的输出
+ 初始化参数函数，初始化所有必须的参数
+ 损失函数，初始化衡量误差的函数
+ 样本获取函数，如何批量获取样本

In [9]:
# 定义回归模型
def logic_reg_model(weight, bias, input, batch_size):
    return torch.sigmoid(torch.mm(input, weight).view(batch_size) + bias)

def params_init():
    p_weights = torch.randn(2, 1, dtype=torch.float, requires_grad=True)
    p_bias = torch.zeros(1, dtype=torch.float, requires_grad=True)
    return p_weights, p_bias

def loss_func(true_Y, hat_Y):
    error = -(true_Y*torch.log(hat_Y) + (1-true_Y)*torch.log(hat_Y))
    return error ** 2

def sample_batchs(X, true_Y, sample_nums, batch_size):
    res = []
    inds = list(range(0, sample_nums))
    np.random.shuffle(inds)
    cur_ind = 0
    while cur_ind + batch_size < sample_nums:
        keep_inds = inds[cur_ind:cur_ind + batch_size]
        res.append((torch.index_select(X, 0, torch.tensor(keep_inds, dtype=torch.int64)), 
            torch.index_select(true_Y, 0, torch.tensor(keep_inds, dtype=torch.int64))))
        cur_ind += batch_size
    keep_inds = inds[cur_ind:cur_ind + batch_size]
    if keep_inds:
        res.append((torch.index_select(X, 0, torch.tensor(keep_inds, dtype=torch.int64)), 
            torch.index_select(true_Y, 0, torch.tensor(keep_inds, dtype=torch.int64))))
    return res

## 3. 开始模型训练
采用微批梯度下降法，需要指定以下参数：
+ epoch_nums: 整个样本迭代训练几次
+ batch_nums: 每次微批的数据量大小
+ X: 输入的数据 X
+ ture_Y: 样本真正的 Y
+ step_ratio: 迭代步长

In [16]:
epoch_nums = 200
batch_nums = 2
step_ratio = 0.03
sample_nums = 10
p_weight, p_bias = params_init()
print(p_weight)
print(p_bias)

for epoch in range(0, epoch_nums):
    batchs = sample_batchs(X, true_Y, sample_nums, batch_nums)
    for batch in batchs:
        b_X, b_true_Y = batch
        batch_size = b_true_Y.shape[0]
        b_hat_Y = logic_reg_model(p_weight, p_bias, b_X, batch_size)
        loss = loss_func(b_true_Y, b_hat_Y).sum()

        loss.backward()
        p_weight.data -= step_ratio*p_weight.grad/batch_size
        p_bias.data -= step_ratio*p_bias.grad/batch_size

        p_weight.grad.data.zero_()
        p_bias.grad.data.zero_()
    if (epoch+1) % 20 == 0:
        with torch.no_grad():
            hat_Y = logic_reg_model(p_weight, p_bias, X, sample_nums)
            loss = loss_func(true_Y, hat_Y)
            print("epoch={}, loss={}".format(epoch+1, loss.mean()))
            print("p_weight={}".format(p_weight))
            print("p_bias={}".format(p_bias))

tensor([[ 0.8588],
        [-0.2786]], requires_grad=True)
tensor([0.], requires_grad=True)
epoch=20, loss=0.006928201764822006
p_weight=tensor([[ 1.0085],
        [-0.1180]], requires_grad=True)
p_bias=tensor([0.0699], requires_grad=True)
epoch=40, loss=0.004014096688479185
p_weight=tensor([[ 1.0698],
        [-0.0553]], requires_grad=True)
p_bias=tensor([0.0995], requires_grad=True)
epoch=60, loss=0.002837373409420252
p_weight=tensor([[ 1.1095],
        [-0.0157]], requires_grad=True)
p_bias=tensor([0.1187], requires_grad=True)
epoch=80, loss=0.0021992085967212915
p_weight=tensor([[1.1390],
        [0.0131]], requires_grad=True)
p_bias=tensor([0.1330], requires_grad=True)
epoch=100, loss=0.0017981810960918665
p_weight=tensor([[1.1625],
        [0.0359]], requires_grad=True)
p_bias=tensor([0.1445], requires_grad=True)
epoch=120, loss=0.0015225138049572706
p_weight=tensor([[1.1821],
        [0.0546]], requires_grad=True)
p_bias=tensor([0.1541], requires_grad=True)
epoch=140, loss=0.001

## 4. 逻辑回归高级版本实现
通过 PyTorch Lightning 来实现， https://github.com/PyTorchLightning/pytorch-lightning

In [19]:
# import nessasary lib
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import pytorch_lightning as pl
import numpy as np

In [32]:
class LogicRegModel(pl.LightningModule):

    def __init__(self):
        super(LogicRegModel, self).__init__()
        # 定义模型结构
        self.l1 = torch.nn.Linear(2, 1)

    def forward(self, x):
        # 必须：定义模型
        return torch.sigmoid(self.l1(x))

    def training_step(self, batch, batch_nb):
        # 必须提供：定于训练过程
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        # 可选提供：定义验证过程
        x, y = batch
        y_hat = self(x)
        
        return {'val_loss': F.binary_cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # 可选提供：定义验证过程
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        # 可选提供：定义测试过程
        x, y = batch
        y_hat = self(x)
        return {'test_loss': F.binary_cross_entropy(y_hat, y)}

    def test_epoch_end(self, outputs):
        # 可选提供：定义测试过程
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'test_loss': avg_loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        # 必须提供：定义优化器
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def gen_data_loader(self, shuffle, batch_size):
        X = torch.tensor([[2., 1.],
            [2., 2.],
            [5., 4.],
            [4., 5.],
            [2., 3.],
            [3., 2.],
            [6., 5.],
            [4., 1.],
            [6., 3.],
            [7, 4]], dtype=torch.float
        )
        Y = torch.tensor([0, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=torch.float)
        # 先转换成 torch 能识别的 Dataset
        torch_dataset = TensorDataset(X, Y)

        # 把 dataset 放入 DataLoader
        loader = DataLoader(
            dataset=torch_dataset,      # torch TensorDataset format
            batch_size=batch_size,      # mini batch size
            shuffle=shuffle,            # 要不要打乱数据 (打乱比较好)
            num_workers=4,              # 多线程来读数据
        )
        return loader

    def train_dataloader(self):
        # 必须提供：提供训练数据集
        return self.gen_data_loader(True, 2)

    def val_dataloader(self):
        # 可选提供：提供验证数据集
        return self.gen_data_loader(False, 10)

    def test_dataloader(self):
        # 可选提供：提供测试数据集
        return self.gen_data_loader(False, 10)

In [34]:
logic_model = LogicRegModel()

# most basic trainer, uses good defaults (1 gpu)
trainer = pl.Trainer(max_epochs=100, num_sanity_val_steps=0)
trainer.fit(logic_model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 3     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [35]:
# 打印所有参数
for i in logic_model.parameters():
    print(i)

true_weight = torch.tensor([2, -3.4], dtype=torch.float32).view(2, 1)
true_bias = torch.tensor(4.2, dtype=torch.float32)

Parameter containing:
tensor([[0.6615, 0.8123]], requires_grad=True)
Parameter containing:
tensor([-4.4929], requires_grad=True)


In [30]:
trainer.test()

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
TEST RESULTS
{'test_loss': tensor(0.1366)}
--------------------------------------------------------------------------------



{'test_loss': 0.1365588903427124}