In [1]:
import torch
from torch import nn
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = FashionMNIST(root='../data/FashionMNIST/', train=True, download=True, transform=transform)
test_dataset = FashionMNIST(root='../data/FashionMNIST/', train=False, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

In [3]:
def evaluate_classify(model:torch.nn.Module, dataloader:DataLoader, criterion:nn.Module) -> (float, float):
    model.eval()
    total_correct = 0
    total_loss = 0
    total = 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            y_hat = model(X)
            loss = criterion(y_hat, y)

            total_loss += loss.item()
            total_correct += (y_hat.argmax(1) == y).type(torch.float).sum().item()
            total += y.size(0)

    model.train()
    
    return total_loss / total, total_correct / total * 100

In [4]:
def train_classify(model:nn.Module, train_loader:DataLoader, test_loader:DataLoader, 
                   optimizer:torch.optim.Optimizer, criterion:nn.Module, num_epochs:int=10) -> None:
    for epoch in range(num_epochs):
        total_loss = 0
        total_correct = 0
        total = 0
        
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))

        for i, (X, y) in progress_bar:
            X, y = X.to(device), y.to(device)

            y_hat = model(X)
            loss = criterion(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_correct += (y_hat.argmax(1) == y).type(torch.float).sum().item()
            total += y.size(0)
        
            progress_bar.set_description(f"Epoch {epoch+1}")
            progress_bar.set_postfix(loss=total_loss/(i+1), accuracy=100.*total_correct/total)
        
        print(f"Epoch: {epoch + 1}, loss: {total_loss / len(train_loader)}, acc: {100. * total_correct / total}")
        test_loss, test_acc = evaluate_classify(model, test_loader, criterion)
        print(f"Epoch: {epoch + 1}, test loss: {test_loss}, test acc: {test_acc}")

## AlexNet

从浅层网络到深层网络的关键一步

In [5]:
AlexNet = nn.Sequential(
    # 这里使用一个11*11的更大窗口来捕捉对象。
    # 同时，步幅为4，以减少输出的高度和宽度。
    # 另外，输出通道的数目远大于LeNet
    nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 减小卷积窗口，使用填充为2来使得输入与输出的高和宽一致，且增大输出通道数
    nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 使用三个连续的卷积层和较小的卷积窗口。
    # 除了最后的卷积层，输出通道的数量进一步增加。
    # 在前两个卷积层之后，汇聚层不用于减少输入的高度和宽度
    nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Flatten(),
    # 这里，全连接层的输出数量是LeNet中的好几倍。使用dropout层来减轻过拟合
    nn.Linear(6400, 4096), nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Dropout(p=0.5),
    # 最后是输出层。由于这里使用Fashion-MNIST，所以用类别数为10，而非论文中的1000
    nn.Linear(4096, 10)
).to(device)

In [6]:
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(AlexNet.parameters(), lr=0.01)
optimizer = torch.optim.AdamW(AlexNet.parameters(), lr=1e-3, weight_decay=1e-2)

In [7]:
train_classify(AlexNet, train_loader, test_loader, optimizer, criterion, num_epochs=10)

Epoch 1: 100%|██████████| 469/469 [00:50<00:00,  9.31it/s, accuracy=76.1, loss=0.623]


Epoch: 1, loss: 0.6232904042643524, acc: 76.07833333333333
Epoch: 1, test loss: 0.002927865283191204, test acc: 86.57000000000001


Epoch 2: 100%|██████████| 469/469 [00:47<00:00,  9.79it/s, accuracy=87.2, loss=0.342]


Epoch: 2, loss: 0.3420215174396917, acc: 87.15
Epoch: 2, test loss: 0.0024818234860897064, test acc: 88.14999999999999


Epoch 3: 100%|██████████| 469/469 [00:47<00:00,  9.81it/s, accuracy=88.9, loss=0.294]


Epoch: 3, loss: 0.29430521077819977, acc: 88.92
Epoch: 3, test loss: 0.0024063949570059776, test acc: 88.63


Epoch 4: 100%|██████████| 469/469 [00:47<00:00,  9.79it/s, accuracy=90.1, loss=0.265]


Epoch: 4, loss: 0.26465576415313585, acc: 90.115
Epoch: 4, test loss: 0.0024410379126667977, test acc: 88.94


Epoch 5: 100%|██████████| 469/469 [00:48<00:00,  9.72it/s, accuracy=91, loss=0.242]  


Epoch: 5, loss: 0.24221247701502557, acc: 90.99
Epoch: 5, test loss: 0.002158971853554249, test acc: 89.91


Epoch 6: 100%|██████████| 469/469 [00:50<00:00,  9.27it/s, accuracy=91.8, loss=0.223]


Epoch: 6, loss: 0.22324519634628093, acc: 91.82666666666667
Epoch: 6, test loss: 0.002097383114695549, test acc: 90.32


Epoch 7: 100%|██████████| 469/469 [00:51<00:00,  9.04it/s, accuracy=92.2, loss=0.209]


Epoch: 7, loss: 0.20925184305924088, acc: 92.18666666666667
Epoch: 7, test loss: 0.0021287918344140054, test acc: 90.03999999999999


Epoch 8: 100%|██████████| 469/469 [00:50<00:00,  9.21it/s, accuracy=92.8, loss=0.194]


Epoch: 8, loss: 0.19371465313981082, acc: 92.76333333333334
Epoch: 8, test loss: 0.002015212031453848, test acc: 90.93


Epoch 9: 100%|██████████| 469/469 [00:48<00:00,  9.61it/s, accuracy=93.1, loss=0.184]


Epoch: 9, loss: 0.18418925690816154, acc: 93.095
Epoch: 9, test loss: 0.00202462058365345, test acc: 91.36999999999999


Epoch 10: 100%|██████████| 469/469 [00:48<00:00,  9.63it/s, accuracy=93.6, loss=0.172]


Epoch: 10, loss: 0.17221499204254354, acc: 93.605
Epoch: 10, test loss: 0.001965224288403988, test acc: 91.27


## VGG网络

核心思想是使用块，每个块由数个网络层组成，通过堆叠块形成深度网络。

In [9]:
def vgg_block(num_convs:int, in_channels:int, out_channels:int):
    net = []
    for _ in range(num_convs):
        net.append(nn.Conv2d(in_channels, out_channels,
                             kernel_size=3, padding=1))
        net.append(nn.ReLU())
        in_channels = out_channels
    net.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*net)

In [11]:
def vgg_net(conv_arch:list):
    net = []

    # 卷积部分
    in_channels = 1
    for (num_convs, out_channels) in conv_arch:
        net.append(vgg_block(num_convs, in_channels, out_channels))
        in_channels = out_channels

    # 线性层
    return nn.Sequential(
        *net, nn.Flatten(),
        nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 10)
    )

conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))

# FashionMNIST数据集比较简单，将通道数缩放一下，降低计算量
ratio = 4
small_conv_arch = [(pair[0], pair[1] // ratio) for pair in conv_arch]
small_vgg_net = vgg_net(small_conv_arch).to(device)

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(small_vgg_net.parameters(), lr=1e-3, weight_decay=1e-2)

In [13]:
train_classify(small_vgg_net, train_loader, test_loader, optimizer, criterion, num_epochs=10)

Epoch 1: 100%|██████████| 469/469 [01:14<00:00,  6.30it/s, accuracy=79, loss=0.557]  


Epoch: 1, loss: 0.5565553918830367, acc: 78.96
Epoch: 1, test loss: 0.002487978158891201, test acc: 88.53


Epoch 2: 100%|██████████| 469/469 [01:12<00:00,  6.43it/s, accuracy=89.4, loss=0.288]


Epoch: 2, loss: 0.2876724307534537, acc: 89.42166666666667
Epoch: 2, test loss: 0.002082240116596222, test acc: 90.2


Epoch 3: 100%|██████████| 469/469 [01:12<00:00,  6.44it/s, accuracy=91.1, loss=0.241]


Epoch: 3, loss: 0.24068131121491065, acc: 91.11166666666666
Epoch: 3, test loss: 0.001996009835600853, test acc: 90.8


Epoch 4: 100%|██████████| 469/469 [01:14<00:00,  6.28it/s, accuracy=92.2, loss=0.214]


Epoch: 4, loss: 0.21352951376359347, acc: 92.17
Epoch: 4, test loss: 0.0017884517759084702, test acc: 92.05


Epoch 5: 100%|██████████| 469/469 [01:14<00:00,  6.27it/s, accuracy=93, loss=0.191]  


Epoch: 5, loss: 0.19062784035354535, acc: 92.97833333333334
Epoch: 5, test loss: 0.001710431595146656, test acc: 92.23


Epoch 6: 100%|██████████| 469/469 [01:14<00:00,  6.31it/s, accuracy=93.5, loss=0.174]


Epoch: 6, loss: 0.17355345886176837, acc: 93.53833333333333
Epoch: 6, test loss: 0.0018231979683041573, test acc: 92.04


Epoch 7: 100%|██████████| 469/469 [01:13<00:00,  6.42it/s, accuracy=94.4, loss=0.153]


Epoch: 7, loss: 0.15278119347624178, acc: 94.44166666666666
Epoch: 7, test loss: 0.0018598661419004202, test acc: 92.14


Epoch 8: 100%|██████████| 469/469 [01:12<00:00,  6.46it/s, accuracy=94.8, loss=0.14] 


Epoch: 8, loss: 0.14003762371663345, acc: 94.805
Epoch: 8, test loss: 0.001733958163112402, test acc: 92.36


Epoch 9: 100%|██████████| 469/469 [01:12<00:00,  6.47it/s, accuracy=95.4, loss=0.124]


Epoch: 9, loss: 0.1236349705146002, acc: 95.43166666666667
Epoch: 9, test loss: 0.001883041138201952, test acc: 92.32000000000001


Epoch 10: 100%|██████████| 469/469 [01:12<00:00,  6.45it/s, accuracy=96, loss=0.109]  


Epoch: 10, loss: 0.10936467909911421, acc: 96.005
Epoch: 10, test loss: 0.0018299723632633686, test acc: 93.31
