# 导入依赖

In [1]:
import os
import multiprocessing as mp

import torch

from models import LeNet5, loss_func_map, eval_func, init_weights
from train import build_loader, build_sgd_optimizer, train_epochs
from utils import mnist_cvs_to_imgfolder, decode_rbf_vectors_from_imgs, setup_logging

# 准备数据集

In [None]:
# MNIST CSV files (train / test demo splits) used as the raw inputs.
csv_train_p, csv_test_p = (
    "datasets/mnist_train_demo.csv",
    "datasets/mnist_test_demo.csv",
)
# Destination folders for the CSV -> image-folder export below.
export_train_dir, export_test_dir = "data/train", "data/test"

# Export step, currently disabled; uncomment both lines to run it.
# NOTE(review): the data loaders built later in this notebook read
# "data/train_demo" / "data/test_demo", not the export folders above --
# confirm which paths are intended before re-enabling.
# mnist_cvs_to_imgfolder(csv_train_p, export_train_dir)
# mnist_cvs_to_imgfolder(csv_test_p, export_test_dir)

# 加载 RBF vector

In [3]:
# Directory handed to decode_rbf_vectors_from_imgs.
rbf_vector_dir = "data/RBF_kernel"

# Decode the RBF vectors and wrap the result as a tensor in a single step;
# torch.as_tensor avoids a copy when the input allows it.
rbf_vectors = torch.as_tensor(decode_rbf_vectors_from_imgs(rbf_vector_dir))

# 设定超参数

In [None]:
# dataloader params
batch_size = 4
# Half of the CPU cores reported by multiprocessing; evaluates to 0 on a
# single-core machine.
num_workers = mp.cpu_count() // 2

# sgd params
lr = 1e-4
momentum = 0.9
weight_decay = 1e-4

# other params
num_epoches = 100
# Seed torch's random number generators so that a fresh
# "Restart Kernel -> Run All" starts from the same random state.
seed = 42
torch.manual_seed(seed)
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a usable GPU instead of failing at `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# 搭建 model、dataloader、optimizer、logger

In [None]:
# Build the network from the decoded RBF vectors, move it to the selected
# device, then run the project's init_weights over every sub-module.
model = LeNet5(rbf_vectors).to(device)
model.apply(init_weights)

# Root folders passed to build_loader for the train and test splits.
train_data_roots = ["data/train_demo"]
test_data_roots = ["data/test_demo"]

# Maps the strings "0".."9" to the integers 0..9
# (presumably class-folder name -> label; confirm against build_loader).
class_to_idx = {str(i): i for i in range(10)}

train_loader, test_loader = build_loader(
    train_data_roots,
    test_data_roots,
    class_to_idx,
    batch_size,
    num_workers,
)

optimizer = build_sgd_optimizer(
    model.parameters(),
    lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

workdir = "runtimes"
# Create the working directory up front so the log file can be opened on a
# fresh checkout; exist_ok=True keeps re-running this cell harmless.
os.makedirs(workdir, exist_ok=True)
log_p = os.path.join(workdir, "train_log.txt")
logger = setup_logging(log_p)


# 开始训练

In [None]:
# Sub-folder of the working directory that is passed to train_epochs as the
# checkpoint location.
save_ckpt_dir = os.path.join(workdir, "ckpts")
# Periods forwarded to train_epochs. Judging by their names these are
# presumably the progress-print interval (in iterations) and the save/eval
# interval (in epochs) -- TODO confirm against train.py.
print_iter_period = 30000
save_eval_epoch_period = 4

# All arguments are positional, so their order must match train_epochs'
# signature exactly.
train_epochs(
    model, train_loader, test_loader,
    optimizer, loss_func_map, eval_func,
    logger, device, save_ckpt_dir,
    num_epoches, print_iter_period, save_eval_epoch_period,
)
