In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 30
BATCH_SIZE = 64 # For faster purpose
LR =  0.0003 # [0.0003, 0.003, 0.03, 0.07, 0.13]

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, ))
])

In [4]:
train_set = datasets.MNIST("data_sets", train=True, download=True, transform=transform)
test_set = datasets.MNIST("data_sets", train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [5]:
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.dense1 = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU()
        )
        self.dense2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.dense3 = nn.Sequential(
            nn.Linear(84, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        return x

In [6]:
model = LeNet().to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[300, 600, 900], gamma=0.5, last_epoch=-1)

In [7]:
def train_model(my_model, device, trains_loader, optimizers, lr_scheduler, epoches):
    # 模型训练
    my_model.train()
    for batch_idx, (data, target) in enumerate(trains_loader):
        data, target = data.to(device), target.to(device)
        optimizers.zero_grad()
        output = my_model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizers.step()
        lr_scheduler.step()

def test_model(my_model, device, test_loder):
    my_model.eval()
    correct = 0
    test_loss = 0
    with torch.no_grad(): 
        for data, target in test_loder:
            data, target = data.to(device), target.to(device)
            output = my_model(data)
            test_loss += F.cross_entropy(output, target).item()
            predict = output.argmax(dim=1)
            correct += predict.eq(target.view_as(predict)).sum().item()
        avg_loss = test_loss / len(test_loder.dataset)
        correct_ratio = 100 * correct / len(test_loder.dataset)
        return avg_loss, correct_ratio

In [8]:
for epoch in range(1, EPOCHS+1):
#     train_model(model, DEVICE, train_loader, optimizer, epoch)
    train_model(model, DEVICE, train_loader, optimizer, lr_scheduler, epoch)
    avg_loss, correct_ratio = test_model(model, DEVICE, test_loader)
    print("Epoch: " + str(epoch) + "\t Loss: {:.5f}\t Accuracy: {:.5f}".format(avg_loss, correct_ratio))

Epoch: 1	 Loss: 0.03614	 Accuracy: 9.94000
Epoch: 2	 Loss: 0.03614	 Accuracy: 10.05000
Epoch: 3	 Loss: 0.03614	 Accuracy: 10.13000
Epoch: 4	 Loss: 0.03614	 Accuracy: 10.25000
Epoch: 5	 Loss: 0.03613	 Accuracy: 10.34000
Epoch: 6	 Loss: 0.03613	 Accuracy: 10.41000
Epoch: 7	 Loss: 0.03613	 Accuracy: 10.43000
Epoch: 8	 Loss: 0.03613	 Accuracy: 10.53000
Epoch: 9	 Loss: 0.03613	 Accuracy: 10.55000
Epoch: 10	 Loss: 0.03612	 Accuracy: 10.60000
Epoch: 11	 Loss: 0.03612	 Accuracy: 10.72000
Epoch: 12	 Loss: 0.03612	 Accuracy: 10.77000
Epoch: 13	 Loss: 0.03612	 Accuracy: 10.82000
Epoch: 14	 Loss: 0.03611	 Accuracy: 10.97000
Epoch: 15	 Loss: 0.03611	 Accuracy: 11.16000
Epoch: 16	 Loss: 0.03611	 Accuracy: 11.27000
Epoch: 17	 Loss: 0.03610	 Accuracy: 11.46000
Epoch: 18	 Loss: 0.03610	 Accuracy: 11.74000
Epoch: 19	 Loss: 0.03610	 Accuracy: 12.03000
Epoch: 20	 Loss: 0.03609	 Accuracy: 12.52000
Epoch: 21	 Loss: 0.03609	 Accuracy: 12.97000
Epoch: 22	 Loss: 0.03608	 Accuracy: 13.80000
Epoch: 23	 Loss: 0.0

In [9]:
torch.cuda.empty_cache()