## 大致流程

- 把训练数据（如 flower_photos.tgz）上传至 Google Drive

- 在 Colab 中挂载 Google Drive

- 数据解压缩

- 数据集预处理（如果需要）

- model 的定义

- model 的训练：notebook 中无法通过命令行传参，运行所需的参数需要手动写在 `parse_args` 的 `args` 列表里，例如：
  
```python
config = parser.parse_args(args=["--img_path", "./flower_data/", "--vgg_version", "vgg11"])
```

以下是一个在 Colab 上训练 VGG 的 notebook（ipynb）示例。

## 上传所需数据到 google drive

把 flower_photos.tgz 上传到 Google Drive 的 `MyDrive/image-processing-by-dl/` 目录下（后面的解压命令使用的就是这个路径）。

## Data preprocessing

挂载 Google Drive（挂载点为 `/content/drive`）

In [None]:
# Mount Google Drive at /content/drive so the uploaded archive can be read.
from google.colab import drive
drive.mount('/content/drive')
# Show the working directory: the next cell reads ./drive/MyDrive/... relative to it.
!pwd

In [None]:
!mkdir flower_data
!tar -xvf ./drive/MyDrive/image-processing-by-dl/flower_photos.tgz -C flower_data

## Data split

In [None]:
import os
from shutil import copy, rmtree
import random

def mk_file(file_path: str):
    """Make *file_path* an empty directory.

    If something already exists at that path it is removed first (recursively),
    so the caller always gets a fresh, empty directory. Missing parent
    directories are created as needed.
    """
    target = file_path
    if not os.path.exists(target):
        os.makedirs(target)
        return
    # Delete the old folder (and its contents) so no stale files survive, then recreate it.
    rmtree(target)
    os.makedirs(target)


def main():
    """Split ``flower_data/flower_photos`` into ``flower_data/train`` and ``flower_data/val``.

    Expects the extracted dataset at ``<cwd>/flower_data/flower_photos`` with one
    sub-directory per class. For every class, ``int(num_images * 0.1)`` randomly
    chosen images are copied to ``val/<class>`` and the rest to ``train/<class>``.
    Existing ``train``/``val`` folders are deleted and recreated first.

    Raises:
        AssertionError: if ``flower_data/flower_photos`` does not exist.
    """
    # Fixed seed so the split is reproducible.
    random.seed(0)

    # Fraction of each class that goes to the validation set.
    split_rate = 0.1

    # Points at the extracted flower_photos folder.
    data_root = os.path.join(os.getcwd(), "flower_data")
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    # os.listdir order is filesystem-dependent; sort so the seeded sampling
    # below picks the same images on every machine.
    flower_class = sorted(cla for cla in os.listdir(origin_flower_path)
                          if os.path.isdir(os.path.join(origin_flower_path, cla)))

    # Create (or reset) train/<class> and val/<class> folders.
    train_root = os.path.join(data_root, "train")
    val_root = os.path.join(data_root, "val")
    for split_root in (train_root, val_root):
        mk_file(split_root)
        for cla in flower_class:
            mk_file(os.path.join(split_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = sorted(os.listdir(cla_path))
        num = len(images)
        # A set gives O(1) membership tests; a list made the loop O(n^2).
        eval_images = set(random.sample(images, k=int(num * split_rate)))
        train_dir = os.path.join(train_root, cla)
        val_dir = os.path.join(val_root, cla)
        for index, image in enumerate(images):
            dest_dir = val_dir if image in eval_images else train_dir
            copy(os.path.join(cla_path, image), dest_dir)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")

# Build the train/val split. Re-running this cell deletes and recreates
# flower_data/train and flower_data/val (mk_file removes existing folders).
main()

## Define model

In [4]:
import torch
import torch.nn as nn

class VGG(nn.Module):
    """VGG classifier: a convolutional encoder followed by a fully connected head.

    Args:
        vgg_version: key selecting the encoder layout built by ``get_vgg_enc``
            ("vgg11", "vgg13", "vgg16" or "vgg19").
        num_classes: number of output logits.
    """

    def __init__(self, vgg_version, num_classes=1000):
        super().__init__()
        self.enc = get_vgg_enc(vgg_version)    # encoder (feature extraction)
        # Pool to a fixed 7x7 grid so the head's 512*7*7 input size holds for any
        # input of at least 32x32. For 224x224 inputs the encoder's five 2x2
        # max-pools already give 7x7, so this is an identity op. It has no
        # parameters, so state_dict keys (and saved checkpoints) are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # NOTE(review): the reference VGG classifier also applies Dropout after the
        # second ReLU. Inserting it here would shift the final Linear's index in
        # `head` and break loading of existing checkpoints, so it is left as is.
        self.head = nn.Sequential(  # prediction head
            nn.Linear(512*7*7, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes)
        )

    def forward(self, img):
        """Return class logits of shape (batch, num_classes) for a batch of images."""
        feats = self.enc(img)  # 5 maxpools: (3x224x224) -> (512x7x7)
        feats = self.avgpool(feats)
        feats = torch.flatten(feats, start_dim=1)
        return self.head(feats)

def get_vgg_enc(vgg_version):
    """Build the convolutional encoder for *vgg_version*.

    Args:
        vgg_version: a key of ``cfgs`` ("vgg11", "vgg13", "vgg16", "vgg19").

    Returns:
        nn.Sequential of 3x3 conv (padding 1, stride 1) + ReLU pairs and 2x2
        stride-2 max-pools, laid out as in ``cfgs[vgg_version]``. The first conv
        takes 3 input channels.

    Raises:
        ValueError: if *vgg_version* is not a key of ``cfgs``.
    """
    # Explicit raise instead of assert: asserts are stripped under `python -O`,
    # which would turn a bad name into an unhelpful KeyError further down.
    if vgg_version not in cfgs:
        raise ValueError(
            f"vgg version not found: {vgg_version!r} (available: {', '.join(cfgs)})")

    layers = []
    in_channels = 3  # RGB input images
    for v in cfgs[vgg_version]:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            # padding=1 with a 3x3 kernel keeps the spatial size; only pools shrink it.
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1, stride=1))
            layers.append(nn.ReLU())
            in_channels = v
    return nn.Sequential(*layers)
        
# Encoder layout for each VGG variant, read by get_vgg_enc:
# an int is the output channel count of a 3x3 conv (+ ReLU) layer,
# 'M' is a 2x2, stride-2 max-pool. Every variant has five 'M' entries.
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # 8 conv layers
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # 10 conv layers
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],  # 13 conv layers
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],  # 16 conv layers
}

## Train Model

In [None]:
import argparse
import os
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from torchvision import transforms, datasets, utils

import matplotlib.pyplot as plt

def get_config(args=None):
    """Parse the training options.

    Args:
        args: list of command-line tokens to parse. Defaults to
            ``["--img_path", "./flower_data/", "--vgg_version", "vgg11"]``,
            because a notebook cell has no usable command line to read from.

    Returns:
        argparse.Namespace with img_path, vgg_version, output_path, lr, epoch
        and batch_size.
    """
    if args is None:
        args = ["--img_path", "./flower_data/", "--vgg_version", "vgg11"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--img_path", type=str, required=True,
                        help="dataset root containing 'train' and 'val' sub-folders")
    parser.add_argument("--vgg_version", type=str, required=True,
                        choices=["vgg11", "vgg13", "vgg16", "vgg19"],
                        help="vgg version, optional: vgg11, vgg13, vgg16, vgg19")
    parser.add_argument("--output_path", type=str, help="output file's saving path", default="./output", required=False)
    parser.add_argument("--lr", type=float, help="learning rate", default=0.0002, required=False)
    parser.add_argument("--epoch", type=int, help="epoch", default=20, required=False)
    parser.add_argument("--batch_size", type=int, help="batch size", default=32, required=False)

    return parser.parse_args(args=args)

def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one pass over *loader* and return ``(mean loss, accuracy)``.

    Trains (forward, backward, optimizer step, with a tqdm progress bar) when
    *optimizer* is given; otherwise evaluates with gradients disabled. Loss and
    accuracy are both averaged per sample, so train and val values are comparable.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, seen = 0.0, 0, 0
    batches = tqdm(loader, desc=desc) if training else loader
    with torch.set_grad_enabled(training):
        for inputs, labels in batches:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # Assumes criterion averages over the batch (CrossEntropyLoss default):
            # scale back to a sum so the final mean is per sample even when the
            # last batch is smaller.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=-1) == labels).sum().item()
            seen += batch_size
    seen = max(seen, 1)  # avoid division by zero on an empty loader
    return loss_sum / seen, correct / seen


def _save_history(output_path, history):
    """Save each metric list as ``<name>.npy`` and plot all of them to ``result.png``."""
    for name, values in history.items():
        np.save(os.path.join(output_path, f"{name}.npy"), values)
    plt.figure()
    # One subplot per metric in a 2x2 grid; title is the key with spaces.
    for pos, (name, values) in enumerate(history.items(), start=1):
        plt.subplot(2, 2, pos)
        plt.plot(values)
        plt.title(name.replace("_", " "))
    plt.tight_layout()  # keep the subplot titles from overlapping
    plt.savefig(os.path.join(output_path, "result.png"))


def main(config):
    """Train a VGG classifier on an ImageFolder dataset and record its history.

    Reads ``<img_path>/train`` and ``<img_path>/val`` and trains for ``config.epoch``
    epochs with Adam and cross-entropy. Writes the following into ``config.output_path``:

    - ``VGG_checkpoint.pth``: weights with the best validation accuracy.
    - ``*_record.npy``: per-epoch loss and accuracy arrays.
    - ``result.png``: a 2x2 plot of the four curves.

    Args:
        config: namespace from ``get_config`` (img_path, vgg_version,
            output_path, lr, epoch, batch_size).
    """
    print(config)
    assert config.img_path, "img_path is needed"
    # Create the output folder here so a custom --output_path also works
    # (torch.save / np.save fail when the directory is missing).
    os.makedirs(config.output_path, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    }

    # Worker count: at most 8 and never more than the CPU count; 0 (load in the
    # main process) when batch_size <= 1. os.cpu_count() may return None.
    nw = min(os.cpu_count() or 1, 8, config.batch_size if config.batch_size > 1 else 0)

    train_root = os.path.join(config.img_path, "train")
    val_root = os.path.join(config.img_path, "val")
    train_set = datasets.ImageFolder(root=train_root, transform=transform["train"])
    train_loader = DataLoader(train_set, shuffle=True, batch_size=config.batch_size, num_workers=nw)
    val_set = datasets.ImageFolder(root=val_root, transform=transform["val"])
    val_loader = DataLoader(val_set, shuffle=False, batch_size=config.batch_size, num_workers=nw)
    print(f"length of train set: {len(train_set)}")
    print(f"length of val set: {len(val_set)}")

    class2idx = train_set.class_to_idx
    print(class2idx)

    model = VGG(config.vgg_version, len(class2idx)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    # Per-epoch metrics; keys double as .npy file names and plot titles.
    history = {
        "train_loss_record": [],
        "train_acc_record": [],
        "val_loss_record": [],
        "val_acc_record": [],
    }
    ckpt_path = os.path.join(config.output_path, "VGG_checkpoint.pth")
    best_val_acc = float("-inf")  # guarantees a checkpoint after the first epoch

    for epoch in range(config.epoch):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer,
            desc=f"[{epoch + 1}/{config.epoch}] train")
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        history["train_loss_record"].append(train_loss)
        history["train_acc_record"].append(train_acc)
        history["val_loss_record"].append(val_loss)
        history["val_acc_record"].append(val_acc)

        if val_acc > best_val_acc:
            torch.save(model.state_dict(), ckpt_path)
            best_val_acc = val_acc
        print(f"[epoch:{epoch+1:03d}/{config.epoch:03d}] train loss:{train_loss:.4f}, train acc:{train_acc:.4f} | val loss:{val_loss:.4f} val acc:{val_acc:.4f}")

    _save_history(config.output_path, history)


config = get_config()

# makedirs(exist_ok=True) also covers a custom --output_path and is a no-op
# when the folder already exists.
os.makedirs(config.output_path, exist_ok=True)

myseed = 42069  # set a random seed for reproducibility
# Deterministic cuDNN kernels; benchmark mode would pick algorithms nondeterministically.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)  # seeds every GPU, including the current one

main(config)