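# Softmax Regression
Logistic regression introduced the basics of classification, but it only handles binary problems. Many real-world tasks involve more than two classes, so we need a model that supports multi-class classification. The softmax regression model described here does exactly that: its prediction is a vector with one entry per class. The sections below walk through its basic components.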

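## Softmax regression model
+ Data: determined by the problem. For example, to predict whether a picture shows a car, a cat, a dog, and so on, we collect many pictures. A color picture usually has 3 channels, each channel is a matrix (say 28*28), and each pixel takes an integer value from 0 to 255. Many variables can be extracted from a picture; the simplest choice is to treat every pixel as one variable. We also record what each picture actually contains (the answer, or target). Together these form the training samples.
+ Model: suppose the softmax regression model predicts 3 classes from m input variables x1 ... xm
  - O1 = b11 + a11\*x1 + a12\*x2 + ... + a1m\*xm
  - O2 = b21 + a21\*x1 + a22\*x2 + ... + a2m\*xm
  - O3 = b31 + a31\*x1 + a32\*x2 + ... + a3m\*xm

  The prediction vector is then: \[y_hat1=exp(O1)/sum(exp(Oi)), y_hat2=exp(O2)/sum(exp(Oi)), y_hat3=exp(O3)/sum(exp(Oi))\]
+ Parameters: b11, a11, a12, ..., a1m, together with the corresponding b21, a21, ..., a2m and b31, a31, ..., a3m
+ Loss function: z = sum(-(y1\*log(y_hat1) + y2\*log(y_hat2) + y3\*log(y_hat3)))/n, where the outer sum runs over the samples, n is the number of samples, and (y1, y2, y3) is the one-hot label of a sample
+ Optimization: gradient descent and its variants. Because PyTorch computes gradients automatically, we no longer need to derive the loss function's derivatives explicitly.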
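
The model and loss above can be traced by hand on a tiny example. The following is a minimal NumPy sketch, not the implementation used later in this article: the data `X`, the weights `A`, the biases `b`, and the helper names `softmax` and `cross_entropy` are all made up for illustration (2 samples, m = 4 variables, 3 classes).

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def cross_entropy(y_onehot, y_hat):
    """Mean over samples of -sum_k y_k * log(y_hat_k)."""
    return float(-(y_onehot * np.log(y_hat)).sum(axis=1).mean())

# Hypothetical toy problem: n = 2 samples, m = 4 variables, 3 classes.
X = np.array([[1.0, 2.0, 0.0, 1.0],
              [0.0, 1.0, 3.0, 2.0]])
A = np.array([[0.1, 0.2, 0.0, -0.1],    # a11 .. a1m
              [0.0, -0.2, 0.3, 0.1],    # a21 .. a2m
              [-0.1, 0.1, 0.1, 0.0]])   # a31 .. a3m
b = np.array([0.0, 0.1, -0.1])          # one bias per class

O = X @ A.T + b             # shape (2, 3): O1, O2, O3 for each sample
y_hat = softmax(O)          # prediction vectors, each row sums to 1
Y = np.array([[1, 0, 0],    # one-hot labels
              [0, 1, 0]])
loss = cross_entropy(Y, y_hat)
print(y_hat)
print(loss)
```

Subtracting the row maximum before `exp` does not change the result of softmax, but it keeps the exponentials from overflowing when the scores are large.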

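As you can see, part of the softmax regression model (the linear scores O1, O2, O3) looks just like linear regression. As with logistic regression, this is why its name contains "regression".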

Now that we understand how a model is put together, from this article on all models are implemented directly with the high-level API of PyTorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning


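## 1. Preparing the training samples
First, let's get to know a dataset we will use often from here on: Fashion-MNIST. The most widely used image classification dataset is MNIST, the handwritten digit dataset, but most models exceed 95% accuracy on it. MNIST has become the HelloWorld of deep learning and is thoroughly worn out, so we will use Fashion-MNIST instead, whose images are more complex. It is still small (only a few tens of MB), so a computer without a GPU can handle it.
This section uses the torchvision package, a companion library for the PyTorch deep learning framework that is mainly used to build computer vision models. torchvision consists of the following main parts: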

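+ torchvision.datasets: data loading functions and interfaces to common datasets
+ torchvision.models: common model architectures (including pretrained models), such as AlexNet, VGG, and ResNet
+ torchvision.transforms: common image transformations, such as cropping and rotation
+ torchvision.utils: other useful utilities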

## 1.1 Downloading the dataset

In [5]:
import torch
import torchvision
import torchvision.transforms as transforms
# transform must be an instance, ToTensor(), not the class itself
minist_train = torchvision.datasets.FashionMNIST('/home/qspace/git_workspace/PyTorchLearn/datas', train=True, download=True, transform=transforms.ToTensor())
minist_test = torchvision.datasets.FashionMNIST('/home/qspace/git_workspace/PyTorchLearn/datas', train=False, download=True, transform=transforms.ToTensor())

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw/train-images-idx3-ubyte.gz



Extracting /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw/train-images-idx3-ubyte.gz to /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw/train-labels-idx1-ubyte.gz



Extracting /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw/t10k-images-idx3-ubyte.gz



Extracting /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz






Extracting /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /home/qspace/git_workspace/PyTorchLearn/datas/FashionMNIST/raw
Processing...
Done!


In [7]:
# Print the dataset type and the number of training and test samples
print(type(minist_train))
print('train sample len = {}'.format(len(minist_train)))
print('test sample len = {}'.format(len(minist_test)))

<class 'torchvision.datasets.mnist.FashionMNIST'>
train sample len = 60000
test sample len = 10000
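
Each Fashion-MNIST sample carries an integer label from 0 to 9. As a small aid for reading those labels, here is a sketch that maps them to text. The class names follow the list in the Fashion-MNIST README; the constant `FASHION_MNIST_CLASSES` and the helper `label_to_text` are hypothetical names introduced only for this illustration.

```python
# Class names as listed in the Fashion-MNIST README, indexed by label id.
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def label_to_text(labels):
    """Map an iterable of integer labels (0-9) to class names."""
    return [FASHION_MNIST_CLASSES[int(i)] for i in labels]

print(label_to_text([9, 0, 3]))  # ['Ankle boot', 'T-shirt/top', 'Dress']
```

torchvision also exposes the same names through the dataset's `classes` attribute, which avoids hard-coding the list.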


## 2. Softmax regression: high-level implementation
Implemented with PyTorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning

In [19]:
# import the necessary libraries
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import pytorch_lightning as pl
import numpy as np

In [32]:
class LogicRegModel(pl.LightningModule):

    def __init__(self):
        super(LogicRegModel, self).__init__()
        # Define the model structure
        self.l1 = torch.nn.Linear(2, 1)

    def forward(self, x):
        # Required: define the forward computation of the model
        return torch.sigmoid(self.l1(x))

    def training_step(self, batch, batch_nb):
        # Required: define one training step
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        # Optional: define one validation step
        x, y = batch
        y_hat = self(x)
        
        return {'val_loss': F.binary_cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # Optional: aggregate the validation results at the end of each epoch
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        # Optional: define one test step
        x, y = batch
        y_hat = self(x)
        return {'test_loss': F.binary_cross_entropy(y_hat, y)}

    def test_epoch_end(self, outputs):
        # Optional: aggregate the test results at the end of the test run
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'test_loss': avg_loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        # Required: define the optimizer
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def gen_data_loader(self, shuffle, batch_size):
        X = torch.tensor([[2., 1.],
            [2., 2.],
            [5., 4.],
            [4., 5.],
            [2., 3.],
            [3., 2.],
            [6., 5.],
            [4., 1.],
            [6., 3.],
            [7, 4]], dtype=torch.float
        )
        Y = torch.tensor([0, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=torch.float)
        # First wrap the tensors in a Dataset that torch can consume
        torch_dataset = TensorDataset(X, Y)

        # Put the dataset into a DataLoader
        loader = DataLoader(
            dataset=torch_dataset,      # torch TensorDataset format
            batch_size=batch_size,      # mini batch size
            shuffle=shuffle,            # whether to shuffle the data (recommended for training)
            num_workers=4,              # number of worker subprocesses used to load data
        )
        return loader

    def train_dataloader(self):
        # Required: provide the training data
        return self.gen_data_loader(True, 2)

    def val_dataloader(self):
        # Optional: provide the validation data
        return self.gen_data_loader(False, 10)

    def test_dataloader(self):
        # Optional: provide the test data
        return self.gen_data_loader(False, 10)

In [34]:
logic_model = LogicRegModel()

# most basic trainer, uses good defaults (1 gpu)
trainer = pl.Trainer(max_epochs=100, num_sanity_val_steps=0)
trainer.fit(logic_model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 3     






1

In [35]:
# Print all parameters
for i in logic_model.parameters():
    print(i)

true_weight = torch.tensor([2, -3.4], dtype=torch.float32).view(2, 1)
true_bias = torch.tensor(4.2, dtype=torch.float32)

Parameter containing:
tensor([[0.6615, 0.8123]], requires_grad=True)
Parameter containing:
tensor([-4.4929], requires_grad=True)
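
The printed weight and bias can be plugged back into the sigmoid model to see how it classifies the ten training points. This is a sketch in NumPy that copies the numbers from the output above and the data from `gen_data_loader`. The loss it computes belongs to these particular parameter values, so it is not expected to reproduce a loss reported by a different run.

```python
import numpy as np

# Training points and labels copied from gen_data_loader.
X = np.array([[2, 1], [2, 2], [5, 4], [4, 5], [2, 3],
              [3, 2], [6, 5], [4, 1], [6, 3], [7, 4]], dtype=float)
Y = np.array([0, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=float)

# Weight and bias copied from the printed parameters of this run.
w = np.array([0.6615, 0.8123])
b = -4.4929

p = 1.0 / (1.0 + np.exp(-(X @ w + b)))   # sigmoid(l1(x))
pred = (p >= 0.5).astype(float)          # threshold the probability at 0.5
accuracy = float((pred == Y).mean())
bce = float(-(Y * np.log(p) + (1 - Y) * np.log(1 - p)).mean())
print(accuracy, bce)
```

With these values every point falls on the correct side of the decision boundary, so the accuracy is 1.0.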


In [30]:
trainer.test()


--------------------------------------------------------------------------------
TEST RESULTS
{'test_loss': tensor(0.1366)}
--------------------------------------------------------------------------------



{'test_loss': 0.1365588903427124}
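
The module above has a single sigmoid output, which is the binary case. For 10 classes such as those in Fashion-MNIST, the linear layer needs one output per class and the loss becomes cross-entropy over the softmax. The following is a minimal sketch in plain PyTorch under several assumptions: the class name `SoftmaxRegression` is hypothetical, random tensors stand in for a real batch of 28x28 images, and the learning rate of 0.01 is chosen only for this toy step.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

class SoftmaxRegression(torch.nn.Module):
    """One linear layer mapping flattened 28x28 images to 10 class scores."""

    def __init__(self, num_inputs=28 * 28, num_classes=10):
        super().__init__()
        self.l1 = torch.nn.Linear(num_inputs, num_classes)

    def forward(self, x):
        # Flatten (batch, 1, 28, 28) to (batch, 784) and return the raw scores O.
        return self.l1(x.view(x.shape[0], -1))

# Random stand-ins for one batch (assumption: real batches have this shape).
x = torch.rand(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))

model = SoftmaxRegression()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

logits = model(x)
loss_before = F.cross_entropy(logits, y)   # applies log-softmax internally
optimizer.zero_grad()
loss_before.backward()
optimizer.step()                           # one gradient descent step
loss_after = F.cross_entropy(model(x), y)

probs = F.softmax(logits, dim=1)           # y_hat: each row sums to 1
print(probs.shape, loss_before.item(), loss_after.item())
```

Returning raw scores from `forward` and letting `F.cross_entropy` apply the log-softmax is numerically more stable than taking the softmax first and then its logarithm. The same `forward` and loss could be placed in a LightningModule like the one above.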