In [2]:
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms

# 1. 载入CIFAR10数据集

In [None]:
DATA_PATH = "./data/datasets"
cifar10_train = datasets.CIFAR10(
    DATA_PATH,
    train=True,
    download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4915, 0.4823, 0.4468],
                             [0.2470, 0.2435, 0.2616],),
    ]),
)
cifar10_val = datasets.CIFAR10(
    DATA_PATH,
    train=False,
    download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4915, 0.4823, 0.4468],
                             [0.2470, 0.2435, 0.2616],),
    ]),
)

In [5]:
# 本例只需要区分airplane和bird
# 因此从整个数据集中抽取airplane和bird
label_map = {0: 0, 2: 1}  # 原始数据集中 0: airplane 2:bird
cifar2_train = [(img, label_map[label]) for img, label in cifar10_train if label in (0, 2)]
cifar2_val = [(img, label_map[label]) for img, label in cifar10_val if label in (0, 2)]
len(cifar2_train), len(cifar2_val)

(10000, 2000)

# 2. 创建全连接层

In [None]:
model = torch.nn.Sequential(
    torch.nn.Linear(32 * 32 * 3, 512),
    torch.nn.Tanh(),
    torch.nn.Linear(512, 2),
    torch.nn.LogSoftmax(dim=-1),
)

In [8]:
# Softmax：将一个向量转换使其符合概率分布
t1 = torch.tensor([1., 2., 3., 4.])
softmax = torch.nn.Softmax(dim=-1)  # 指定Softmax操作的维度
softmax(t1)

tensor([0.0321, 0.0871, 0.2369, 0.6439])

In [9]:
t2 = torch.tensor([[1., 2., 3., 4.],
                   [1., 2., 3., 4.]])
softmax(t2)

tensor([[0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439]])

In [13]:
# LogSoftmax：对Softmax的结果取对数
# 解决了当概率趋于0时求log易出错的问题
t1 = torch.tensor([1., 2., 3., 4.])
torch.nn.LogSoftmax(dim=-1)(t1)

tensor([-3.4402, -2.4402, -1.4402, -0.4402])

In [14]:
torch.log(softmax(t1))

tensor([-3.4402, -2.4402, -1.4402, -0.4402])

In [15]:
# NLL：negative log likelihood
# NLL = -log(prob)
t1 = torch.tensor([0.70, 0.10, 0.05, 0.15])
log_t1 = torch.log(t1)
torch.nn.NLLLoss()(log_t1)

TypeError: forward() missing 1 required positional argument: 'target'