# PyTorch迁移学习训练
目的：微调（fine-tune）一个预训练的神经网络模型，用于对猫 / 狗图片进行分类。

In [1]:
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

# Fix the RNG seed so that data shuffling and the initial weights of newly
# created layers are the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Use the first GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cpu')

In [3]:
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
from IPython.core.debugger import set_trace

In [4]:
from my_train_helper import get_trainable, get_frozen, all_trainable, all_frozen, freeze_all

# 数据集

## Transforms

In [5]:
from torchvision import transforms

# Side length of the square crop fed to the network, and the per-channel
# mean / std used to normalise the RGB tensors.
_image_size = 224
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]


# Training pipeline: random crop, horizontal flip and colour jitter act as
# data augmentation before the image is converted and normalised.
train_trans = transforms.Compose(
    [
        transforms.Resize(size=256),  # some images are pretty small
        transforms.RandomCrop(size=_image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std),
    ]
)

# Validation pipeline: no randomness, just a fixed centre crop.
val_trans = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=_image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std),
    ]
)

## Dataset

In [6]:
from my_datasets import DogsCatsDataset

In [7]:
# Use the small "sample" split so the notebook runs quickly. Each split gets
# its own transform pipeline (the random augmentation is in train_trans only).
train_ds = DogsCatsDataset("../data/raw", "sample/train", transform=train_trans)
val_ds = DogsCatsDataset("../data/raw", "sample/valid", transform=val_trans)

batch_size = 2  # mini-batch size used by the data loaders
n_classes = 2  # cat vs. dog

Loading data from ../data/raw/dogscats/sample/train.
Loading data from ../data/raw/dogscats/sample/valid.


完整数据集如下

In [None]:
# train_ds = DogsCatsDataset("../data/raw", "train", transform=train_trans)
# val_ds = DogsCatsDataset("../data/raw", "valid", transform=val_trans)

# batch_size = 128
# n_classes = 2

In [8]:
# Sanity check: number of samples in the training and validation splits.
len(train_ds), len(val_ds)

(16, 8)

## DataLoader
批量读取数据（可添加多进程、抽样方法等）

In [9]:
from torch.utils.data import DataLoader


# Settings shared by both loaders: batch size and number of worker processes.
_loader_kwargs = dict(batch_size=batch_size, num_workers=4)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

# The Model
torchvision 内置了多种网络架构，并可下载对应的预训练权重（见 [pre-trained networks](https://pytorch.org/vision/stable/models.html)），例如：
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- Inception v3

更多模型可从 [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch) 下载：
- NASNet
- ResNeXt
- InceptionV4
- InceptionResnetV2
- Xception
- DPN
- ...

我们使用resnet18:

In [10]:
from torchvision import models

# pretrained=False builds the architecture with freshly initialised weights
# (nothing is downloaded). Pass pretrained=True to download the pre-trained
# weights, which is what actual transfer learning needs.
model = models.resnet18(pretrained=False)

In [11]:
# Display the layer structure of the network.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [14]:
import torchsummary

# Print per-layer output shapes and parameter counts for a 3x224x224 input.
# device="cpu": presumably because `model` has not been moved off the CPU at
# this point -- confirm if the cell is reused after a .to(DEVICE).
torchsummary.summary(model, (3, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [15]:
# Freeze every existing parameter: with requires_grad switched off, no
# gradients are computed for these weights.
for weight in model.parameters():
    weight.requires_grad_(False)

替换最后一层（全连接层 `fc`）：新建的 `nn.Linear` 的参数默认 `requires_grad = True`，因此在我们的数据集上微调时，只有这一层的权重会被更新。

In [16]:
# Replace the classification head with a fresh layer sized for our classes.
# The input width is read from the existing head instead of hard-coding 512,
# so the same line keeps working if the backbone is swapped for one with a
# different feature width. A newly created nn.Linear has requires_grad=True
# by default, so this layer remains trainable.
model.fc = nn.Linear(model.fc.in_features, n_classes)

In [18]:
def get_model(n_classes=2, pretrained=False):
    """Build a ResNet-18 with a frozen backbone and a new, trainable head.

    Args:
        n_classes: Number of outputs of the replacement ``fc`` layer.
        pretrained: Forwarded to ``models.resnet18``. Defaults to ``False``
            (the previously hard-coded value); pass ``True`` to start from
            downloaded pre-trained weights.

    Returns:
        The model, moved to ``DEVICE``. Every parameter except those of the
        new ``fc`` layer has ``requires_grad=False``.
    """
    model = models.resnet18(pretrained=pretrained)
    for param in model.parameters():
        param.requires_grad = False
    # Created after the freeze, so the new head keeps the default
    # requires_grad=True. Its input width is read from the old head rather
    # than hard-coded as 512.
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    model = model.to(DEVICE)
    return model


# Build a fresh model with the helper; this rebinds `model`, replacing the
# object that was configured step by step in the earlier cells.
model = get_model()

# 损失函数

In [19]:
# Cross-entropy loss, applied in the training loop to the raw outputs of the
# model's final linear layer and the target labels.
criterion = nn.CrossEntropyLoss()

# 优化方法

In [20]:
# Optimise only what get_trainable() selects from the model's parameters
# (presumably those with requires_grad=True, i.e. the new head -- confirm in
# my_train_helper). Adam takes no `momentum` argument, so the learning rate
# is the only setting here.
trainable_params = get_trainable(model.parameters())
optimizer = optim.Adam(trainable_params, lr=1e-3)

# 训练迭代

In [21]:
N_EPOCHS = 2

for epoch in range(N_EPOCHS):

    # ---------------- training phase ----------------
    # NOTE(review): train() also lets the BatchNorm layers of the frozen
    # backbone keep updating their running statistics.
    model.train()

    # Epoch accumulators. Each batch loss is weighted by the batch size so the
    # reported value is a per-sample mean even if the last batch is smaller.
    total_loss = 0.0
    n_correct = 0
    n_samples = 0
    for batch_i, (X, y) in enumerate(train_dl):
        X = X.to(DEVICE)
        y = y.to(DEVICE)

        # One optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        y_ = model(X)
        loss = criterion(y_, y)
        loss.backward()
        optimizer.step()

        # (batch_i and loss.item() are available here for per-batch logging.)
        # The predicted class is the index of the largest score in each row.
        y_label_ = torch.max(y_, 1)[1]
        n_samples += X.shape[0]
        total_loss += loss.item() * X.shape[0]
        n_correct += (y_label_ == y).sum().item()

    print(
        f"Epoch {epoch+1}/{N_EPOCHS} |"
        f"  train loss: {total_loss / n_samples:9.3f} |"
        f"  train acc:  {n_correct / n_samples * 100:9.3f}%"
    )

    # ---------------- validation phase ----------------
    # Important: eval() makes BatchNorm use its running statistics (and would
    # switch off dropout in a model that has any).
    model.eval()

    total_loss = 0.0
    n_correct = 0
    n_samples = 0
    # no_grad(): gradients are not needed for evaluation.
    with torch.no_grad():
        for X, y in val_dl:
            X = X.to(DEVICE)
            y = y.to(DEVICE)

            y_ = model(X)
            loss = criterion(y_, y)

            # Same bookkeeping as in the training phase.
            y_label_ = torch.max(y_, 1)[1]
            n_samples += X.shape[0]
            total_loss += loss.item() * X.shape[0]
            n_correct += (y_label_ == y).sum().item()

    print(
        f"Epoch {epoch+1}/{N_EPOCHS} |"
        f"  valid loss: {total_loss / n_samples:9.3f} |"
        f"  valid acc:  {n_correct / n_samples * 100:9.3f}%"
    )

Epoch 1/2 |  train loss:     0.923 |  train acc:     37.500%
Epoch 1/2 |  valid loss:     0.710 |  valid acc:     37.500%
Epoch 2/2 |  train loss:     0.946 |  train acc:     37.500%
Epoch 2/2 |  valid loss:     0.746 |  valid acc:     50.000%


# 练习
- 使用任意一个预训练网络结构作为backbone，并替换最后一层（head），训练网络

In [None]:
class Net(nn.Module):
    """Exercise skeleton: a pretrained backbone followed by a new head.

    As written this is an unfinished stub: ``__init__`` stores neither the
    backbone nor a head, and ``forward`` returns its input unchanged.
    """

    def __init__(self, backbone: nn.Module, n_classes: int):
        """Set up the module.

        Args:
            backbone: Network to use as the feature extractor (currently unused).
            n_classes: Number of output classes for the head (currently unused).
        """
        super().__init__()
        # TODO(exercise): keep a reference to the backbone and build the head,
        # along the lines of:
        # self.backbone
        # self.head = init_head(n_classes)
        
    def forward(self, x):
        """Return ``x`` unchanged until the exercise is implemented."""
        # TODO
        return x