In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
%matplotlib inline


Transfer Learning for Computer Vision Tutorial
==============================================

These two major transfer learning scenarios look as follows:

-  **Finetuning the convnet**: Instead of random initialization, we
   initialize the network with a pretrained network, like the one that is
   trained on imagenet 1000 dataset. Rest of the training looks as
   usual.
-  **ConvNet as fixed feature extractor**: Here, we will freeze the weights
   for all of the network except that of the final fully connected
   layer. This last fully connected layer is replaced with a new one
   with random weights and only this layer is trained.


In [3]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

cudnn.benchmark = True
plt.ion()   # interactive mode

Load Data
---------

We will use torchvision and torch.utils.data packages for loading the
data.

The problem we're going to solve today is to train a model to classify
**ants** and **bees**. We have about 120 training images each for ants and bees.
There are 75 validation images for each class. Usually, this is a very
small dataset to generalize upon, if trained from scratch. Since we
are using transfer learning, we should be able to generalize reasonably
well.

This dataset is a very small subset of imagenet.

.. Note ::
   Download the data from
   `here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`_
   and extract it to the current directory.



In [4]:
# Data augmentation and normalization for training
# Just normalization for validation
image_transforms  = {
    'train': transforms.Compose([
        transforms.Resize(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

dataset = 'microsphere2'
train_directory = os.path.join(dataset, 'train')
valid_directory = os.path.join(dataset, 'valid')

batch_size = 8
num_classes = 240

data = {
    'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
    'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid'])

}


train_data_size = len(data['train'])
valid_data_size = len(data['valid'])

train_data = DataLoader(data['train'], batch_size=batch_size, shuffle=True)
valid_data = DataLoader(data['valid'], batch_size=batch_size, shuffle=True)

print(train_data_size, valid_data_size)

1038 346


迁移学习
---------


这里使用ResNet-50的预训练模型。


In [5]:
resnet50 = models.resnet50(pretrained=True)

for param in resnet50.parameters():
    param.requires_grad = True

fc_inputs = resnet50.fc.in_features
resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(128, 240),
    nn.LogSoftmax(dim=1)

)
resnet50 = resnet50.to('cuda:0')
loss_func = nn.NLLLoss()
optimizer = optim.AdamW(resnet50.parameters(),lr=1e-4)

Training the model
------------------

Now, let's write a general function to train a model. Here, we will
illustrate:

-  Scheduling the learning rate
-  Saving the best model

In the following, parameter ``scheduler`` is an LR scheduler object from
``torch.optim.lr_scheduler``.



In [6]:
def train_and_valid(model, loss_function, optimizer, epochs=25):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    history = []
    best_acc = 0.0
    best_epoch = 0

    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1, epochs))

        model.train()

        train_loss = 0.0
        train_acc = 0.0
        valid_loss = 0.0
        valid_acc = 0.0

        for i, (inputs, labels) in enumerate(train_data):
            inputs = inputs.to(device)
            labels = labels.to(device)

            #因为这里梯度是累加的，所以每次记得清零
            optimizer.zero_grad()

            outputs = model(inputs)

            loss = loss_function(outputs, labels)

            loss.backward()

            optimizer.step()

            train_loss += loss.item() * inputs.size(0)

            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))

            acc = torch.mean(correct_counts.type(torch.FloatTensor))

            train_acc += acc.item() * inputs.size(0)

        with torch.no_grad():
            model.eval()

            for j, (inputs, labels) in enumerate(valid_data):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)

                loss = loss_function(outputs, labels)

                valid_loss += loss.item() * inputs.size(0)

                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))

                acc = torch.mean(correct_counts.type(torch.FloatTensor))

                valid_acc += acc.item() * inputs.size(0)

        avg_train_loss = train_loss/train_data_size
        avg_train_acc = train_acc/train_data_size

        avg_valid_loss = valid_loss/valid_data_size
        avg_valid_acc = valid_acc/valid_data_size

        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])

        if best_acc < avg_valid_acc:
            best_acc = avg_valid_acc
            best_epoch = epoch + 1

        epoch_end = time.time()

        print("Epoch: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation: Loss: {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(
            epoch+1, avg_valid_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start
        ))
        print("Best Accuracy for validation : {:.4f} at epoch {:03d}".format(best_acc, best_epoch))

        torch.save(model, 'models/'+dataset+'_model_'+str(epoch+1)+'.pt')
    return model, history


In [None]:
num_epochs = 3000
trained_model, history = train_and_valid(resnet50, loss_func, optimizer, num_epochs)
torch.save(history, 'models/'+dataset+'_history.pt')

history = np.array(history)
plt.plot(history[:, 0:2])
plt.legend(['Tr Loss', 'Val Loss'])
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.ylim(0, 1)
plt.savefig(dataset+'_loss_curve.png')
plt.show()

plt.plot(history[:, 2:4])
plt.legend(['Tr Accuracy', 'Val Accuracy'])
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.savefig(dataset+'_accuracy_curve.png')
plt.show()

Epoch: 1/3000
Epoch: 001, Training: Loss: 5.4603, Accuracy: 0.6744%, 
		Validation: Loss: 5.4603, Accuracy: 0.5780%, Time: 22.0370s
Best Accuracy for validation : 0.0058 at epoch 001
Epoch: 2/3000
Epoch: 002, Training: Loss: 5.2170, Accuracy: 0.3854%, 
		Validation: Loss: 5.2170, Accuracy: 1.1561%, Time: 13.5910s
Best Accuracy for validation : 0.0116 at epoch 002
Epoch: 3/3000
Epoch: 003, Training: Loss: 4.9013, Accuracy: 1.2524%, 
		Validation: Loss: 4.9013, Accuracy: 1.4451%, Time: 13.3761s
Best Accuracy for validation : 0.0145 at epoch 003
Epoch: 4/3000
Epoch: 004, Training: Loss: 4.7920, Accuracy: 1.5414%, 
		Validation: Loss: 4.7920, Accuracy: 2.0231%, Time: 13.4747s
Best Accuracy for validation : 0.0202 at epoch 004
Epoch: 5/3000
Epoch: 005, Training: Loss: 4.5637, Accuracy: 1.5414%, 
		Validation: Loss: 4.5637, Accuracy: 2.0231%, Time: 13.4998s
Best Accuracy for validation : 0.0202 at epoch 004
Epoch: 6/3000
Epoch: 006, Training: Loss: 4.5622, Accuracy: 2.2158%, 
		Validation: L

Epoch: 046, Training: Loss: 3.1409, Accuracy: 23.0250%, 
		Validation: Loss: 3.1409, Accuracy: 6.6474%, Time: 13.4850s
Best Accuracy for validation : 0.1098 at epoch 030
Epoch: 47/3000
Epoch: 047, Training: Loss: 3.0258, Accuracy: 26.1079%, 
		Validation: Loss: 3.0258, Accuracy: 10.6936%, Time: 13.3010s
Best Accuracy for validation : 0.1098 at epoch 030
Epoch: 48/3000
Epoch: 048, Training: Loss: 2.9370, Accuracy: 26.3969%, 
		Validation: Loss: 2.9370, Accuracy: 12.1387%, Time: 13.4278s
Best Accuracy for validation : 0.1214 at epoch 048
Epoch: 49/3000
Epoch: 049, Training: Loss: 2.9929, Accuracy: 25.5299%, 
		Validation: Loss: 2.9929, Accuracy: 8.6705%, Time: 13.4897s
Best Accuracy for validation : 0.1214 at epoch 048
Epoch: 50/3000
Epoch: 050, Training: Loss: 2.9596, Accuracy: 28.9017%, 
		Validation: Loss: 2.9596, Accuracy: 10.6936%, Time: 13.5637s
Best Accuracy for validation : 0.1214 at epoch 048
Epoch: 51/3000
Epoch: 051, Training: Loss: 3.0061, Accuracy: 28.2274%, 
		Validation: L

Epoch: 091, Training: Loss: 3.7502, Accuracy: 63.7765%, 
		Validation: Loss: 3.7502, Accuracy: 10.1156%, Time: 13.4378s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 92/3000
Epoch: 092, Training: Loss: 3.7960, Accuracy: 61.4644%, 
		Validation: Loss: 3.7960, Accuracy: 12.7168%, Time: 13.5475s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 93/3000
Epoch: 093, Training: Loss: 3.9741, Accuracy: 63.0058%, 
		Validation: Loss: 3.9741, Accuracy: 10.6936%, Time: 13.5681s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 94/3000
Epoch: 094, Training: Loss: 3.6113, Accuracy: 61.7534%, 
		Validation: Loss: 3.6113, Accuracy: 10.9827%, Time: 13.6250s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 95/3000
Epoch: 095, Training: Loss: 3.5627, Accuracy: 65.3179%, 
		Validation: Loss: 3.5627, Accuracy: 13.5838%, Time: 13.4917s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 96/3000
Epoch: 096, Training: Loss: 3.6404, Accuracy: 65.7033%, 
		Validation:

Epoch: 135/3000
Epoch: 135, Training: Loss: 4.6626, Accuracy: 81.6956%, 
		Validation: Loss: 4.6626, Accuracy: 10.4046%, Time: 13.5261s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 136/3000
Epoch: 136, Training: Loss: 4.7614, Accuracy: 82.1773%, 
		Validation: Loss: 4.7614, Accuracy: 10.1156%, Time: 13.8035s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 137/3000
Epoch: 137, Training: Loss: 4.7675, Accuracy: 82.0809%, 
		Validation: Loss: 4.7675, Accuracy: 12.7168%, Time: 13.8421s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 138/3000
Epoch: 138, Training: Loss: 4.8707, Accuracy: 82.9480%, 
		Validation: Loss: 4.8707, Accuracy: 10.1156%, Time: 13.5467s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 139/3000
Epoch: 139, Training: Loss: 4.6101, Accuracy: 81.4066%, 
		Validation: Loss: 4.6101, Accuracy: 11.2717%, Time: 13.9914s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 140/3000
Epoch: 140, Training: Loss: 4.5639, Accuracy: 81.

Epoch: 179/3000
Epoch: 179, Training: Loss: 6.0171, Accuracy: 88.5356%, 
		Validation: Loss: 6.0171, Accuracy: 14.7399%, Time: 13.4446s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 180/3000
Epoch: 180, Training: Loss: 5.6926, Accuracy: 87.5723%, 
		Validation: Loss: 5.6926, Accuracy: 10.1156%, Time: 13.6757s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 181/3000
Epoch: 181, Training: Loss: 5.3411, Accuracy: 87.2832%, 
		Validation: Loss: 5.3411, Accuracy: 9.5376%, Time: 13.4565s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 182/3000
Epoch: 182, Training: Loss: 5.4390, Accuracy: 89.9807%, 
		Validation: Loss: 5.4390, Accuracy: 11.2717%, Time: 13.4667s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 183/3000
Epoch: 183, Training: Loss: 5.5044, Accuracy: 88.6320%, 
		Validation: Loss: 5.5044, Accuracy: 10.9827%, Time: 13.4029s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 184/3000
Epoch: 184, Training: Loss: 5.5980, Accuracy: 87.4

Epoch: 223/3000
Epoch: 223, Training: Loss: 6.6581, Accuracy: 92.1965%, 
		Validation: Loss: 6.6581, Accuracy: 10.6936%, Time: 13.4707s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 224/3000
Epoch: 224, Training: Loss: 6.7637, Accuracy: 91.2331%, 
		Validation: Loss: 6.7637, Accuracy: 10.9827%, Time: 13.4521s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 225/3000
Epoch: 225, Training: Loss: 5.9932, Accuracy: 91.3295%, 
		Validation: Loss: 5.9932, Accuracy: 11.8497%, Time: 13.3687s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 226/3000
Epoch: 226, Training: Loss: 5.8992, Accuracy: 92.1002%, 
		Validation: Loss: 5.8992, Accuracy: 11.8497%, Time: 13.3997s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 227/3000
Epoch: 227, Training: Loss: 6.1955, Accuracy: 90.9441%, 
		Validation: Loss: 6.1955, Accuracy: 12.1387%, Time: 13.5483s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 228/3000
Epoch: 228, Training: Loss: 6.4738, Accuracy: 94.

Epoch: 267/3000
Epoch: 267, Training: Loss: 6.8231, Accuracy: 93.9306%, 
		Validation: Loss: 6.8231, Accuracy: 11.5607%, Time: 13.4734s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 268/3000
Epoch: 268, Training: Loss: 8.5478, Accuracy: 92.6782%, 
		Validation: Loss: 8.5478, Accuracy: 9.2486%, Time: 13.4111s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 269/3000
Epoch: 269, Training: Loss: 6.1742, Accuracy: 93.4489%, 
		Validation: Loss: 6.1742, Accuracy: 11.5607%, Time: 13.4114s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 270/3000
Epoch: 270, Training: Loss: 6.5744, Accuracy: 93.6416%, 
		Validation: Loss: 6.5744, Accuracy: 11.5607%, Time: 13.4449s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 271/3000
Epoch: 271, Training: Loss: 6.5659, Accuracy: 94.1233%, 
		Validation: Loss: 6.5659, Accuracy: 12.1387%, Time: 13.5412s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 272/3000
Epoch: 272, Training: Loss: 6.7216, Accuracy: 92.8

Epoch: 311/3000
Epoch: 311, Training: Loss: 7.2504, Accuracy: 94.7977%, 
		Validation: Loss: 7.2504, Accuracy: 12.4277%, Time: 13.4197s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 312/3000
Epoch: 312, Training: Loss: 6.7760, Accuracy: 93.5453%, 
		Validation: Loss: 6.7760, Accuracy: 12.1387%, Time: 13.4130s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 313/3000
Epoch: 313, Training: Loss: 6.7755, Accuracy: 94.7977%, 
		Validation: Loss: 6.7755, Accuracy: 8.6705%, Time: 13.4273s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 314/3000
Epoch: 314, Training: Loss: 7.3115, Accuracy: 94.6050%, 
		Validation: Loss: 7.3115, Accuracy: 10.9827%, Time: 13.7136s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 315/3000
Epoch: 315, Training: Loss: 7.0051, Accuracy: 94.1233%, 
		Validation: Loss: 7.0051, Accuracy: 10.9827%, Time: 13.6089s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 316/3000
Epoch: 316, Training: Loss: 7.2195, Accuracy: 94.9

Epoch: 355, Training: Loss: 7.0365, Accuracy: 93.6416%, 
		Validation: Loss: 7.0365, Accuracy: 11.5607%, Time: 13.4377s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 356/3000
Epoch: 356, Training: Loss: 8.0634, Accuracy: 94.7013%, 
		Validation: Loss: 8.0634, Accuracy: 10.6936%, Time: 13.4186s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 357/3000
Epoch: 357, Training: Loss: 6.9342, Accuracy: 95.0867%, 
		Validation: Loss: 6.9342, Accuracy: 14.1618%, Time: 13.4068s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 358/3000
Epoch: 358, Training: Loss: 7.1644, Accuracy: 95.5684%, 
		Validation: Loss: 7.1644, Accuracy: 13.0058%, Time: 13.5111s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 359/3000
Epoch: 359, Training: Loss: 7.7887, Accuracy: 96.5318%, 
		Validation: Loss: 7.7887, Accuracy: 13.8728%, Time: 13.4208s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 360/3000
Epoch: 360, Training: Loss: 7.2604, Accuracy: 94.9904%, 
		Valida

Epoch: 399, Training: Loss: 7.7871, Accuracy: 95.0867%, 
		Validation: Loss: 7.7871, Accuracy: 13.8728%, Time: 13.4622s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 400/3000
Epoch: 400, Training: Loss: 8.3737, Accuracy: 95.6647%, 
		Validation: Loss: 8.3737, Accuracy: 12.4277%, Time: 13.5150s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 401/3000
Epoch: 401, Training: Loss: 8.9162, Accuracy: 94.9904%, 
		Validation: Loss: 8.9162, Accuracy: 11.5607%, Time: 13.5373s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 402/3000
Epoch: 402, Training: Loss: 7.6438, Accuracy: 95.0867%, 
		Validation: Loss: 7.6438, Accuracy: 11.5607%, Time: 13.4587s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 403/3000
Epoch: 403, Training: Loss: 7.2153, Accuracy: 95.8574%, 
		Validation: Loss: 7.2153, Accuracy: 13.2948%, Time: 13.4157s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 404/3000
Epoch: 404, Training: Loss: 7.6590, Accuracy: 95.8574%, 
		Valida

Epoch: 443/3000
Epoch: 443, Training: Loss: 9.1259, Accuracy: 95.4721%, 
		Validation: Loss: 9.1259, Accuracy: 11.8497%, Time: 13.3784s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 444/3000
Epoch: 444, Training: Loss: 8.3612, Accuracy: 96.6281%, 
		Validation: Loss: 8.3612, Accuracy: 11.8497%, Time: 13.4562s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 445/3000
Epoch: 445, Training: Loss: 8.0593, Accuracy: 96.1464%, 
		Validation: Loss: 8.0593, Accuracy: 10.9827%, Time: 13.4468s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 446/3000
Epoch: 446, Training: Loss: 7.7782, Accuracy: 96.9171%, 
		Validation: Loss: 7.7782, Accuracy: 10.9827%, Time: 13.4718s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 447/3000
Epoch: 447, Training: Loss: 7.9973, Accuracy: 96.5318%, 
		Validation: Loss: 7.9973, Accuracy: 11.5607%, Time: 13.4552s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 448/3000
Epoch: 448, Training: Loss: 8.1786, Accuracy: 96.

Epoch: 487/3000
Epoch: 487, Training: Loss: 8.6865, Accuracy: 96.0501%, 
		Validation: Loss: 8.6865, Accuracy: 9.8266%, Time: 15.4429s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 488/3000
Epoch: 488, Training: Loss: 8.2499, Accuracy: 96.4355%, 
		Validation: Loss: 8.2499, Accuracy: 11.8497%, Time: 15.4170s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 489/3000
Epoch: 489, Training: Loss: 8.1987, Accuracy: 96.1464%, 
		Validation: Loss: 8.1987, Accuracy: 12.4277%, Time: 15.3513s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 490/3000
Epoch: 490, Training: Loss: 8.2146, Accuracy: 96.9171%, 
		Validation: Loss: 8.2146, Accuracy: 14.1618%, Time: 15.3463s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 491/3000
Epoch: 491, Training: Loss: 9.5019, Accuracy: 96.2428%, 
		Validation: Loss: 9.5019, Accuracy: 10.9827%, Time: 14.9573s
Best Accuracy for validation : 0.1532 at epoch 078
Epoch: 492/3000
Epoch: 492, Training: Loss: 8.4715, Accuracy: 95.4