In [1]:
import torch as t
from torch.utils.data import DataLoader
import torchvision as tv

#### 准备数据

In [None]:
# Preprocessing applied to every image: convert to a tensor, then normalize the
# single grayscale channel with mean 0.5 and std 0.5.
transform = tv.transforms.Compose([tv.transforms.ToTensor(),
                                   tv.transforms.Normalize((0.5,), (0.5,))])

# MNIST train / test splits; files are downloaded into ./data when missing.
train_ts = tv.datasets.MNIST(root='./data', train=True, download=True,
                             transform=transform)
test_ts = tv.datasets.MNIST(root='./data', train=False, download=True,
                            transform=transform)

# Training batches are reshuffled on every pass over the data.
train_dl = DataLoader(train_ts, batch_size=32, shuffle=True,
                      drop_last=False)
# Evaluation does not benefit from shuffling; a fixed order keeps per-batch
# results reproducible between runs.
test_dl = DataLoader(test_ts, batch_size=64, shuffle=False,
                     drop_last=False)

#### 定义模型

In [None]:
# Model architecture: a fully connected classifier that maps a flattened
# 784-value input to 100 hidden units and then to 10 per-class log-probabilities.
model = t.nn.Sequential(
    t.nn.Linear(in_features=784, out_features=100),
    t.nn.ReLU(),
    t.nn.Linear(in_features=100, out_features=10),
    t.nn.LogSoftmax(dim=1),
)

# Loss function and optimizer: NLLLoss consumes the log-probabilities produced
# by the final LogSoftmax layer; Adam updates every parameter of the model.
loss_fn = t.nn.NLLLoss(reduction="mean")
optimizer = t.optim.Adam(params=model.parameters(), lr=1e-3)

#### 训练模型

In [None]:
# Run 5 passes over the training loader, printing the batch loss every 100 batches.
for epoch in range(5):
    print("run in step : %d"%epoch)
    for batch_no, (x_train, y_train) in enumerate(train_dl, start=1):
        # Clear gradients left over from the previous batch.
        model.zero_grad()
        # Flatten every sample in the batch into a single vector for the
        # Linear input layer, then compute the loss of the forward pass.
        flat_inputs = x_train.view(x_train.size(0), -1)
        train_loss = loss_fn(model(flat_inputs), y_train)
        # Backpropagate to compute gradients, then update the parameters.
        train_loss.backward()
        optimizer.step()
        if batch_no % 100 == 0:
            print(batch_no, train_loss.item())

#### 评估模型

In [None]:
# Measure classification accuracy over the whole test loader.
# Inference runs one batch at a time (not one sample at a time) and the
# predicted class is the argmax of the model's log-probabilities; exp() is
# monotonic, so exponentiating first would not change the winner.
total = 0
correct_count = 0
# Remember the current mode so evaluation leaves the model as it found it.
was_training = model.training
model.eval()
with t.no_grad():  # no gradients are needed for evaluation
    for test_images, test_labels in test_dl:
        # Flatten each image in the batch to match the model's input layer.
        flat_images = test_images.view(test_images.shape[0], -1)
        pred_labels = model(flat_images).argmax(dim=1)
        correct_count += (pred_labels == test_labels).sum().item()
        total += test_labels.shape[0]
model.train(was_training)
print("total acc : %.2f\n"%(correct_count / total))

#### 使用模型

#### 模型保存

In [None]:
# Serialize the whole model object (not just its state_dict) to disk.
t.save(obj=model, f='./nn_mnist_model.pt')

#### 迁移学习

In [None]:
# Fine-tune the whole network: start from a pretrained ResNet-18 and replace
# its final fully connected layer with a fresh 2-class head. All parameters
# stay trainable.
# NOTE(review): recent torchvision releases prefer the `weights=` argument over
# `pretrained=True`; kept as-is for compatibility with older versions.
model_ft = tv.models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Alternatively, it can be generalized to t.nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = t.nn.Linear(num_ftrs, 2)

# Feature extraction: freeze every pretrained parameter so that only the newly
# created head (whose parameters require gradients by default) is trained.
model_conv = tv.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

num_ftrs = model_conv.fc.in_features
model_conv.fc = t.nn.Linear(num_ftrs, 2)