In [1]:
import torch
from torchvision import transforms
# from torchvision import datasets
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
import torch.optim as optim
import os
import gzip
import numpy as np

In [2]:
class MNISTDataset(Dataset):
    def __init__(self, root, train=True, transform=None):
        self.root = root
        self.train = train
        self.transform = transform

        if self.train:
            images_file = 'train-images-idx3-ubyte.gz'
            labels_file = 'train-labels-idx1-ubyte.gz'
        else:
            images_file = 't10k-images-idx3-ubyte.gz'
            labels_file = 't10k-labels-idx1-ubyte.gz'

        self.images_path = os.path.join(root, images_file)
        self.labels_path = os.path.join(root, labels_file)

        self.images, self.labels = self.load_data()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = int(self.labels[idx])

        if self.transform:
            image = self.transform(image)

        return image, label

    def load_data(self):
        with gzip.open(self.images_path, 'rb') as f_images:
            images = np.frombuffer(f_images.read(), dtype=np.uint8, offset=16).reshape(-1, 28, 28)

        with gzip.open(self.labels_path, 'rb') as f_labels:
            labels = np.frombuffer(f_labels.read(), dtype=np.uint8, offset=8)

        return images, labels

In [3]:
batch_size = 64
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])

In [4]:
train_dataset = MNISTDataset(root='./mnist_dataset/', train=True, transform=transform)
train_loader =  DataLoader(train_dataset,shuffle=True,batch_size=batch_size)

# 示例：获取第一个样本
image, label = train_dataset[0]
print(train_dataset.__len__())
print(image.shape)
print(label)

60000
torch.Size([1, 28, 28])
5


  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()


In [5]:
test_dataset = MNISTDataset(root='./mnist_dataset/', train=False, transform=transform)
test_loader =  DataLoader(test_dataset,shuffle=False,batch_size=batch_size)
print(test_dataset.__len__())

10000


In [6]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(784, 512)
        self.linear2 = torch.nn.Linear(512, 256)
        self.linear3 = torch.nn.Linear(256, 128)
        self.linear4 = torch.nn.Linear(128, 64)
        self.linear5 = torch.nn.Linear(64, 10)
 
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        x = F.relu(self.linear4(x))
        return self.linear5(x)
 
 
model = Model()

In [7]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


In [8]:
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader,0):
        inputs, target = data
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
#         print(batch_idx)
        if batch_idx%300==299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

In [9]:
def test():
    correct = 0 
    total =0 
    with torch.no_grad():
        for data in test_loader:
            images,label=data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    print('Accuracy on test set: %d %%' % (100 * correct / total))


In [10]:
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()


[1,   300] loss: 2.235
[1,   600] loss: 0.945
[1,   900] loss: 0.407
Accuracy on test set: 88 %
[2,   300] loss: 0.321
[2,   600] loss: 0.271
[2,   900] loss: 0.231
Accuracy on test set: 93 %
[3,   300] loss: 0.196
[3,   600] loss: 0.173
[3,   900] loss: 0.155
Accuracy on test set: 95 %
[4,   300] loss: 0.129
[4,   600] loss: 0.127
[4,   900] loss: 0.122
Accuracy on test set: 96 %
[5,   300] loss: 0.103
[5,   600] loss: 0.099
[5,   900] loss: 0.096
Accuracy on test set: 96 %
[6,   300] loss: 0.080
[6,   600] loss: 0.077
[6,   900] loss: 0.075
Accuracy on test set: 97 %
[7,   300] loss: 0.061
[7,   600] loss: 0.063
[7,   900] loss: 0.063
Accuracy on test set: 97 %
[8,   300] loss: 0.049
[8,   600] loss: 0.052
[8,   900] loss: 0.053
Accuracy on test set: 97 %
[9,   300] loss: 0.038
[9,   600] loss: 0.043
[9,   900] loss: 0.044
Accuracy on test set: 97 %
[10,   300] loss: 0.032
[10,   600] loss: 0.035
[10,   900] loss: 0.032
Accuracy on test set: 97 %
