#### 基于经典网络架构训练图像分类模型
##### 数据预处理部分
- 数据增强:torchvision中transforms模型自带功能，比较实用
- 数据预处理：torchvision中transforms也帮我们实现好了，直接调用即可
- DataLoader模型直接读取batch数据

##### 网络模型设置
- 加载预训练模型，torchvision中有很多经典网络架构，调用起来十分方便，并且可以用人家训练好的权重参数来继续训练，也就是所谓的迁移学习
- 需要注意的是别人训练好的任务跟咱们可不是完全一样的，需要把最后的head层改一改，一般也就是最后的全连接层，改成咱们自己的任务
- 训练果可以全部重头训练，也可以只训练最后咱们任务的层，因为前几层都是特征提取的，本质任务目标是一致的

##### 网络模型保存与测试
- 模型保存的时候可以带有选择性的，例如在验证集中如果当前效果好则保存
- 读取模型进行实际测试

In [5]:
import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

#### 数据的读取和预处理操作

In [6]:
data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'


#### 制作好数据源
- data_transforms中指定了所有图像预处理操作
- ImageFolder假设所有的文件按文件夹保存好，每个文件夹下面存储同一类别的图片，文件夹名字为分类的名字

In [7]:
data_transforms = {
   'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转，-45到45度之间随机选
       transforms.CenterCrop(224), # 从中心开始裁剪
       transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转，选择一个概率
       transforms.RandomVerticalFlip(0.5), # 随机垂直翻转
       transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度，参数2为对比度，参数3为饱和度，参数4为色相
       transforms.RandomGrayscale(p=0.025), # 概率转换成灰度率，3通道就是R=G=B
       transforms.ToTensor(),
       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值，标准差
   ]),
    'valid': transforms.Compose([ transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [8]:
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x:torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x:len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

In [9]:
image_datasets

{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456,

In [10]:
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x7f8c68637700>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x7f8c68634e80>}

#### 读取标签对应的实际名字

In [12]:
with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

In [13]:
cat_to_name

{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 

#### 展示下数据
- 注意tensor的数据需要转换成numpy的格式，而且还需要还原回标准化的结果

In [17]:
def im_convert(tensor):
    """
    展示数据
    :param tensor:
    :return:
    """
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image

In [18]:
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2
dataiter = iter(dataloaders['valid'])
inputs, classes = dataiter.next()

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks = [], yticks= [])
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

AttributeError: '_SingleProcessDataLoaderIter' object has no attribute 'next'

<Figure size 2000x1200 with 0 Axes>

#### 加载models中提供的模型，并且直接用训练好的权重当做初始化参数
- 第一次执行需要下载 ，可能会比较慢

In [19]:
model_name = 'resnet' # 可选的比较多['resnet','alexnet','vgg', 'squeezenet', 'inception']
# 是否用人家训练好的特征来做
feature_extract = True

In [20]:
# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available. Training on CPU...')
else:
    print("CUDA is available! Training on GPU...")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

CUDA is not available. Training on CPU...
