In [2]:
import torch
from torchvision import datasets, transforms
import torchvision
from tqdm import tqdm
 
device_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # 可用GPU
BATCH_SIZE = 64
 
transform = transforms.Compose([transforms.ToTensor()])
data_train = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train=True,
                            download=True)
data_test = datasets.MNIST(root="./data/",
                           transform=transform,
                           train=False)
 
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                # 单卡batch size * 卡数
                                                batch_size=BATCH_SIZE * len(device_ids),
                                                shuffle=True,
                                                num_workers=2)
 
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size=BATCH_SIZE * len(device_ids),
                                               shuffle=True,
                                               num_workers=2)
 
 
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = torch.nn.Sequential(
        torch.nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(stride=2, kernel_size=2),
    )
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(14 * 14 * 128, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, 10)
    )
    def forward(self, x):
        x = self.conv1(x)
        x = x.view(-1, 14 * 14 * 128)
        x = self.dense(x)
        return x
 
 
model = Model()
# 指定要用到的设备
model = torch.nn.DataParallel(model, device_ids=device_ids)
# 模型加载到设备0
model = model.cuda(device=device_ids[0])
 
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
from time import sleep
n_epochs = 50
for epoch in range(n_epochs):
    running_loss = 0.0
    running_correct = 0
    print("Epoch {}/{}".format(epoch, n_epochs))
    print("-"*10)
    for data in tqdm(data_loader_train):
        X_train, y_train = data
        # 指定设备0
        X_train, y_train = X_train.cuda(device=device_ids[0]), y_train.cuda(device=device_ids[0])
        outputs = model(X_train)
        _,pred = torch.max(outputs.data, 1)
        optimizer.zero_grad()
        loss = cost(outputs, y_train)
 
        loss.backward()
        optimizer.step()
        running_loss += loss.data.item()
        running_correct += torch.sum(pred == y_train.data)
    testing_correct = 0
    for data in data_loader_test:
        X_test, y_test = data
        # 指定设备1
        X_test, y_test = X_test.cuda(device=device_ids[0]), y_test.cuda(device=device_ids[0])
        outputs = model(X_test)
        _, pred = torch.max(outputs.data, 1)
        testing_correct += torch.sum(pred == y_test.data)
    print("Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Accuracy is:{:.4f}".format(torch.true_divide(running_loss, len(data_train)),
                                                                                      torch.true_divide(100*running_correct, len(data_train)),
                                                                                      torch.true_divide(100*testing_correct, len(data_test))))
torch.save(model.state_dict(), "model_parameter.pkl")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 0/50
----------


  9%|▊         | 8/94 [00:26<03:10,  2.21s/it]




 10%|▉         | 9/94 [00:26<02:14,  1.58s/it]





100%|██████████| 94/94 [00:36<00:00,  2.60it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0005, Train Accuracy is:89.5933%, Test Accuracy is:97.7700
Epoch 1/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0001, Train Accuracy is:98.0633%, Test Accuracy is:98.5300
Epoch 2/50
----------


 97%|█████████▋| 91/94 [00:10<00:00,  8.67it/s]




100%|██████████| 94/94 [00:10<00:00,  8.61it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0001, Train Accuracy is:98.7233%, Test Accuracy is:98.7800
Epoch 3/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.60it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.0517%, Test Accuracy is:98.8400
Epoch 4/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.64it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.2583%, Test Accuracy is:98.8700
Epoch 5/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.61it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.4700%, Test Accuracy is:98.8400
Epoch 6/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.60it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.4600%, Test Accuracy is:98.8100
Epoch 7/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.59it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.5950%, Test Accuracy is:98.9200
Epoch 8/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.6800%, Test Accuracy is:98.9000
Epoch 9/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.59it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.6217%, Test Accuracy is:98.9400
Epoch 10/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.6917%, Test Accuracy is:98.9300
Epoch 11/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.7683%, Test Accuracy is:99.0200
Epoch 12/50
----------


100%|██████████| 94/94 [00:11<00:00,  8.54it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.7867%, Test Accuracy is:98.8900
Epoch 13/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.61it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8033%, Test Accuracy is:98.9100
Epoch 14/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8483%, Test Accuracy is:98.9400
Epoch 15/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8200%, Test Accuracy is:98.8200
Epoch 16/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8150%, Test Accuracy is:99.0300
Epoch 17/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.61it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8250%, Test Accuracy is:98.8600
Epoch 18/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.60it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8333%, Test Accuracy is:99.1300
Epoch 19/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.65it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.7817%, Test Accuracy is:98.7300
Epoch 20/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.59it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8367%, Test Accuracy is:98.9900
Epoch 21/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8950%, Test Accuracy is:99.0200
Epoch 22/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8817%, Test Accuracy is:99.0600
Epoch 23/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.61it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8700%, Test Accuracy is:98.9000
Epoch 24/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.58it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8250%, Test Accuracy is:98.9100
Epoch 25/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8983%, Test Accuracy is:99.0200
Epoch 26/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8967%, Test Accuracy is:98.9700
Epoch 27/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9050%, Test Accuracy is:98.9600
Epoch 28/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9033%, Test Accuracy is:98.9700
Epoch 29/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.64it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9000%, Test Accuracy is:98.9800
Epoch 30/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.61it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8667%, Test Accuracy is:98.9800
Epoch 31/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.65it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9133%, Test Accuracy is:98.9500
Epoch 32/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9200%, Test Accuracy is:99.1400
Epoch 33/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.59it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9383%, Test Accuracy is:99.0500
Epoch 34/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.64it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9067%, Test Accuracy is:98.9900
Epoch 35/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.64it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9017%, Test Accuracy is:99.0300
Epoch 36/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9267%, Test Accuracy is:98.9700
Epoch 37/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.65it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9067%, Test Accuracy is:98.9600
Epoch 38/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.64it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9317%, Test Accuracy is:99.1500
Epoch 39/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.61it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9167%, Test Accuracy is:99.1400
Epoch 40/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.8683%, Test Accuracy is:99.0600
Epoch 41/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.62it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9217%, Test Accuracy is:99.0700
Epoch 42/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9317%, Test Accuracy is:99.0700
Epoch 43/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.64it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9683%, Test Accuracy is:99.0900
Epoch 44/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9600%, Test Accuracy is:99.0400
Epoch 45/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.59it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9417%, Test Accuracy is:99.0900
Epoch 46/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9517%, Test Accuracy is:99.0700
Epoch 47/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9450%, Test Accuracy is:99.1100
Epoch 48/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.61it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Loss is:0.0000, Train Accuracy is:99.9483%, Test Accuracy is:99.1300
Epoch 49/50
----------


100%|██████████| 94/94 [00:10<00:00,  8.65it/s]


Loss is:0.0000, Train Accuracy is:99.9300%, Test Accuracy is:98.9800
