## train流程
## upernet

In [None]:
# System libs
import os
import time
# import math
import random
# Numerical libs
import numpy as np
import torch
import torch.nn as nn


### 创建保存点
这里为每个 epoch 创建一个保存点（checkpoint）：将训练历史、编码器权重和解码器权重分别保存到 `args.ckpt` 目录下。

In [None]:
def checkpoint(nets, history, args, epoch_num):
    """Persist the training history and both networks' weights for one epoch.

    Args:
        nets: ``(encoder, decoder)`` pair; each must provide ``state_dict()``.
        history: training-history object, saved as-is with ``torch.save``.
        args: namespace whose ``ckpt`` attribute is the output directory.
        epoch_num: epoch number embedded in every file name.

    Writes ``history_epoch_<N>.pth``, ``encoder_epoch_<N>.pth`` and
    ``decoder_epoch_<N>.pth`` into ``args.ckpt``.
    """
    print('Saving checkpoints')
    encoder, decoder = nets
    tag = 'epoch_{}.pth'.format(epoch_num)

    # Snapshot both state dicts before anything is written to disk.
    encoder_state = encoder.state_dict()
    decoder_state = decoder.state_dict()

    # (path template, payload) pairs, saved in a fixed order.
    targets = (
        ('{}/history_{}', history),
        ('{}/encoder_{}', encoder_state),
        ('{}/decoder_{}', decoder_state),
    )
    for template, payload in targets:
        torch.save(payload, template.format(args.ckpt, tag))

### 创建优化器

In [None]:
def create_optimizers(nets, args):
    """Create one SGD optimizer per network.

    Args:
        nets: ``(encoder, decoder)`` pair.
        args: namespace providing ``lr_encoder``, ``lr_decoder``, ``beta1``
            (used as the SGD momentum) and ``weight_decay``.

    Returns:
        ``(optimizer_encoder, optimizer_decoder)`` tuple. Each optimizer is
        built over the parameter groups returned by ``group_weight`` (defined
        elsewhere in this notebook).
    """
    net_encoder, net_decoder = nets

    def _sgd_for(net, learning_rate):
        # Momentum and weight decay are shared; only the learning rate differs.
        return torch.optim.SGD(
            group_weight(net),
            lr=learning_rate,
            momentum=args.beta1,
            weight_decay=args.weight_decay)

    encoder_opt = _sgd_for(net_encoder, args.lr_encoder)
    decoder_opt = _sgd_for(net_decoder, args.lr_decoder)
    return (encoder_opt, decoder_opt)

## main
这里是已经封装好的 main 函数：依次构建编码器/解码器、数据加载器和优化器，然后逐个 epoch 训练并保存检查点。

In [None]:
def main(args):
    """Build the UPerNet encoder/decoder, then train and checkpoint per epoch.

    Depends on names defined elsewhere in the project/notebook:
    ``ModelBuilder``, ``SegmentationModule``,
    ``create_multi_source_train_data_loader``, ``UserScatteredDataParallel``,
    ``patch_replication_callback`` and ``train``.

    Args:
        args: parsed training configuration. It is read for the architecture
            names, weights paths, ``fc_dim``, ``nr_classes``, GPU count, epoch
            range and the optimizer / checkpoint settings used by helpers.
    """
    builder = ModelBuilder()
    # Build the network architectures (optionally loading pretrained weights).
    net_encoder = builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder
    )
    # The decoder needs build_decoder with its own architecture name and its
    # own weights file; it must not reuse the encoder's builder/arch/weights.
    net_decoder = builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        nr_classes=args.nr_classes,
        weights=args.weights_decoder
    )

    # Deep-supervision decoders additionally take the auxiliary loss scale.
    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(net_encoder, net_decoder, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # create loader iterator
    iterator_train = create_multi_source_train_data_loader(args=args)

    # load nets into gpu
    if args.num_gpus > 1:
        segmentation_module = UserScatteredDataParallel(
            segmentation_module,
            device_ids=range(args.num_gpus))
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder)
    optimizers = create_optimizers(nets, args)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    # Prefer the correctly spelled ``start_epoch``; fall back to the legacy
    # misspelled ``start_epoach`` for configs that still define only that.
    start_epoch = getattr(args, 'start_epoch', None)
    if start_epoch is None:
        start_epoch = args.start_epoach

    for epoch in range(start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history, epoch, args)
        checkpoint(nets, history, args, epoch)

    print('Training done')