# 0.1 使用GPU训练网络
- 只需要将输入和网络转移到GPU上

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

# Allow up to 120 characters per line when tensors are printed.
torch.set_printoptions(linewidth=120)

# Record the library versions this notebook was executed with, one per line.
print(torch.__version__, torchvision.__version__, sep='\n')

1.6.0
0.7.0


In [2]:
class Network(nn.Module):
    """Small convolutional classifier producing 10 scores per input image.

    Two 5x5 convolutions (each followed by ReLU and 2x2 max pooling) feed
    three fully connected layers. ``fc1`` takes 12 * 4 * 4 flattened
    features, which is what single-channel 28x28 inputs yield after the two
    conv/pool stages (28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor: 1 -> 6 -> 12 channels.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)

        # Fully connected head: 192 -> 120 -> 60 -> 10.
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, t):
        """Run a batch ``t`` through the network and return the raw scores.

        No softmax is applied to the output of the final linear layer.
        """
        # Each conv stage: convolution -> ReLU -> 2x2 max pool.
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)

        # Collapse everything except the batch dimension.
        t = t.flatten(start_dim=1)

        # Hidden fully connected layers, each followed by ReLU.
        for hidden in (self.fc1, self.fc2):
            t = F.relu(hidden(t))

        return self.out(t)

In [3]:
# Convert each sample to a tensor; wrapped in Compose so further transforms
# can be appended to the list later.
to_tensor = transforms.Compose([transforms.ToTensor()])

# FashionMNIST training split, downloaded into ./data/FashionMNIST/ if needed.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST/',
    train=True,
    download=True,
    transform=to_tensor,
)

# Serve the training set in batches of 100 samples.
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=100)

In [4]:
def get_num_corrent(preds, labels):
    """Count the predictions whose highest-scoring class equals the label.

    ``preds`` holds one row of class scores per sample; the arg-max along
    dim 1 is compared element-wise with ``labels``. The number of matches is
    returned as a plain Python number via ``.item()``.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes.eq(labels)
    return matches.sum().item()

In [7]:
# Train on the GPU when one is available; fall back to the CPU instead of
# raising on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the model and move its parameters to the selected device.
network = Network().to(device)

# Adam over all of the network's parameters with a learning rate of 0.01.
optimizer = optim.Adam(network.parameters(), lr=0.01)

In [8]:
# Send every batch to the device the network's parameters live on, so the
# inputs always match the model (GPU or CPU) rather than assuming CUDA.
device = next(network.parameters()).device

for epoch in range(5):
    # Running totals for this epoch: summed per-batch loss and number of
    # correctly classified samples.
    total_loss = 0
    total_corrent = 0

    for images, labels in train_loader:
        # Inputs and targets must be on the same device as the network.
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass: predict class scores.
        preds = network(images)
        # Loss function.
        loss = F.cross_entropy(preds, labels)
        # Compute gradients (clearing those left over from the previous step).
        optimizer.zero_grad()
        loss.backward()
        # Update the weights.
        optimizer.step()

        total_loss += loss.item()
        total_corrent += get_num_corrent(preds, labels)
    print(f'epoch: {epoch}, loss: {total_loss}, total_corrent: {total_corrent}')

epoch: 0, loss: 338.8103349804878, total_corrent: 47169
epoch: 1, loss: 235.20256623625755, total_corrent: 51375
epoch: 2, loss: 217.9649288803339, total_corrent: 51975
epoch: 3, loss: 208.6604774147272, total_corrent: 52214
epoch: 4, loss: 203.08801139891148, total_corrent: 52493


由于网络比较小，使用GPU训练相比CPU没有明显的加速效果。