## train流程
## upernet
因为涉及到多数据集加载和多任务（object / part / scene / material）训练，下面把训练流程拆成几个部分分别说明。

In [None]:
# System libs
import os
import time
# import math
import random
# Numerical libs
import numpy as np
import torch
import torch.nn as nn


In [None]:
class AverageMeter:
    """Track the most recent sample and the weighted running mean of a series.

    Every statistic stays ``None`` until the first sample has been recorded.
    """

    def __init__(self):
        self.initialized = False
        self.val = self.avg = self.sum = self.count = None

    def initialize(self, val, weight):
        """Seed the meter with its first sample ``val`` carrying ``weight``."""
        self.val = self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        """Record a sample: seed on the first call, accumulate afterwards."""
        record = self.add if self.initialized else self.initialize
        record(val, weight)

    def add(self, val, weight):
        """Fold one more weighted sample into the running statistics."""
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        """Return the most recently recorded sample (``None`` if none yet)."""
        return self.val

    def average(self):
        """Return the weighted mean of all samples (``None`` if none yet)."""
        return self.avg


In [None]:
def adjust_learning_rate(optimizers, cur_iter, args):
    """Apply polynomial learning-rate decay to the encoder/decoder optimizers.

    The scale factor is ``(1 - cur_iter / args.max_iters) ** args.lr_pow``.
    The base is clamped at zero so that running past ``args.max_iters`` gives
    a learning rate of 0 instead of a complex number (fractional ``lr_pow``)
    or a rate that starts growing again (even integer ``lr_pow``).

    Args:
        optimizers: pair ``(optimizer_encoder, optimizer_decoder)``; each must
            expose ``param_groups``, a list of dicts with an ``'lr'`` entry.
        cur_iter: global iteration index, counted from 0.
        args: namespace providing ``max_iters``, ``lr_pow``, ``lr_encoder``
            and ``lr_decoder``.  ``running_lr_encoder`` and
            ``running_lr_decoder`` are written back onto it.
    """
    # Fraction of the schedule still remaining; never below zero.
    remaining = max(0.0, 1. - float(cur_iter) / args.max_iters)
    scale_running_lr = remaining ** args.lr_pow
    args.running_lr_encoder = args.lr_encoder * scale_running_lr
    args.running_lr_decoder = args.lr_decoder * scale_running_lr

    (optimizer_encoder, optimizer_decoder) = optimizers
    for param_group in optimizer_encoder.param_groups:
        param_group['lr'] = args.running_lr_encoder
    for param_group in optimizer_decoder.param_groups:
        param_group['lr'] = args.running_lr_decoder

In [None]:
def train(segmentation_module, iterator, optimizers, history, epoch, args):
    """Train ``segmentation_module`` for one epoch of ``args.epoch_iters`` steps.

    Each step pulls one ``(batch_data, source_idx)`` pair from ``iterator``,
    runs a forward pass, back-propagates the mean of ``ret['loss']['total']``,
    steps every optimizer and then refreshes the learning rates through
    ``adjust_learning_rate``.  Every ``args.disp_iter`` steps the running
    averages of timings, losses and metrics are printed, and the fractional
    epoch plus the current loss are appended to ``history['train']``.

    Args:
        segmentation_module: callable network; its result must be a dict with
            ``'loss'`` and ``'metric'`` sub-dicts of tensors.
        iterator: iterator yielding ``(batch_data, source_idx)`` pairs.
        optimizers: iterable of optimizers, stepped in order.
        history: dict with ``history['train']['epoch']`` and
            ``history['train']['loss']`` lists.
        epoch: epoch number (treated as 1-based in the bookkeeping below).
        args: namespace providing ``fix_bn``, ``epoch_iters``, ``disp_iter``,
            ``running_lr_encoder`` and ``running_lr_decoder``, plus whatever
            ``adjust_learning_rate`` reads.
    """
    task_names = ['object', 'part', 'scene', 'material']

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_losses = {task: AverageMeter() for task in task_names}
    ave_metric = {task: AverageMeter() for task in task_names}
    ave_losses['total'] = AverageMeter()

    def _summarise(meters, template):
        # One "<initial> <average>" field per task; a meter that has not
        # received any sample yet is shown as 0.
        fields = []
        for task in task_names:
            mean = meters[task].average()
            fields.append(template.format(task[0], 0 if mean is None else mean))
        return ", ".join(fields)

    # segmentation_module is the assembled network; it is switched to
    # training mode unless args.fix_bn is set.
    segmentation_module.train(not args.fix_bn)

    # main loop
    tic = time.time()
    for i in range(args.epoch_iters):
        # fetch the next batch from the iterator (source index is unused here)
        batch_data, _ = next(iterator)
        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()

        # forward pass
        ret = segmentation_module(batch_data)

        # backward pass and parameter update
        loss = ret['loss']['total'].mean()
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # time spent on this whole iteration
        batch_time.update(time.time() - tic)
        tic = time.time()

        # measure losses
        for name, tensor in ret['loss'].items():
            ave_losses[name].update(tensor.mean().item())

        # measure metrics
        for name, tensor in ret['metric'].items():
            ave_metric[name].update(tensor.mean().item())

        # calculate acc and display: logging output
        if i % args.disp_iter == 0:
            loss_info = ("Loss: total {:.4f}, ".format(ave_losses['total'].average())
                         + _summarise(ave_losses, "{} {:.2f}"))
            acc_info = "Accuracy: " + _summarise(ave_metric, "{} {:4.2f}")
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'LR: encoder {:.6f}, decoder {:.6f}, {}, {}'
                  .format(epoch, i, args.epoch_iters,
                          batch_time.average(), data_time.average(),
                          args.running_lr_encoder, args.running_lr_decoder,
                          acc_info, loss_info))

            train_log = history['train']
            train_log['epoch'].append((epoch - 1) + float(i) / args.epoch_iters)
            train_log['loss'].append(loss.item())

        # adjust learning rate
        cur_iter = i + (epoch - 1) * args.epoch_iters
        adjust_learning_rate(optimizers, cur_iter, args)


### 创建保存点
这里每个 epoch 创建一个保存点：首先取出 encoder 和 decoder 的 state_dict，然后连同训练历史 history 一起保存到 args.ckpt 目录下。

In [None]:
def checkpoint(nets, history, args, epoch_num):
    """Save the training history and both networks' weights for one epoch.

    Three files are written under ``args.ckpt`` with ``torch.save``, each
    name ending in ``epoch_<epoch_num>.pth``: the history, the encoder
    state dict and the decoder state dict.

    Args:
        nets: pair ``(net_encoder, net_decoder)``.
        history: object to store as the training history.
        args: namespace providing the output directory ``ckpt``.
        epoch_num: epoch number embedded in the file names.
    """
    print('Saving checkpoints')
    net_encoder, net_decoder = nets
    suffix_latest = 'epoch_{}.pth'.format(epoch_num)

    # Snapshot both state dicts before anything is written to disk.
    encoder_state = net_encoder.state_dict()
    decoder_state = net_decoder.state_dict()

    outputs = (
        (history, '{}/history_{}'),
        (encoder_state, '{}/encoder_{}'),
        (decoder_state, '{}/decoder_{}'),
    )
    for payload, path_template in outputs:
        torch.save(payload, path_template.format(args.ckpt, suffix_latest))

### 创建optimizer

In [None]:
def group_weight(module):
    """Split a module's parameters into weight-decay and no-decay groups.

    Weights of ``nn.Linear`` and convolution layers go into the decayed
    group.  Their biases, and the weight and bias of batch-norm layers, go
    into the group with ``weight_decay=0``.

    Args:
        module: the ``nn.Module`` whose sub-modules are scanned.

    Returns:
        A two-element list of parameter-group dicts for an optimizer:
        the decayed group first, then the no-decay group.

    Raises:
        AssertionError: if some parameter of ``module`` ended up in neither
            group (i.e. it belongs to an unhandled layer type).
    """
    weighted_layers = (nn.Linear, nn.modules.conv._ConvNd)
    decayed, undecayed = [], []

    for layer in module.modules():
        if isinstance(layer, weighted_layers):
            decayed.append(layer.weight)
            if layer.bias is not None:
                undecayed.append(layer.bias)
        elif isinstance(layer, nn.modules.batchnorm._BatchNorm):
            # Affine parameters are absent when the layer has affine=False.
            undecayed.extend(
                p for p in (layer.weight, layer.bias) if p is not None)

    # Every parameter must have been placed in exactly one of the groups.
    n_params = len(list(module.parameters()))
    assert n_params == len(decayed) + len(undecayed)
    return [dict(params=decayed), dict(params=undecayed, weight_decay=.0)]

In [None]:
def create_optimizers(nets, args):
    """Build one SGD optimizer for the encoder and one for the decoder.

    Both optimizers use the parameter groups produced by ``group_weight``
    and share momentum (``args.beta1``) and ``args.weight_decay``; only the
    base learning rate differs (``args.lr_encoder`` / ``args.lr_decoder``).

    Args:
        nets: pair ``(net_encoder, net_decoder)``.
        args: namespace providing ``lr_encoder``, ``lr_decoder``, ``beta1``
            and ``weight_decay``.

    Returns:
        The tuple ``(optimizer_encoder, optimizer_decoder)``.
    """
    net_encoder, net_decoder = nets
    shared_options = dict(momentum=args.beta1, weight_decay=args.weight_decay)

    optimizer_encoder = torch.optim.SGD(
        group_weight(net_encoder), lr=args.lr_encoder, **shared_options)
    optimizer_decoder = torch.optim.SGD(
        group_weight(net_decoder), lr=args.lr_decoder, **shared_options)
    return optimizer_encoder, optimizer_decoder

## 创建一个多重数据集的加载器
其实我觉得这一部分可以根据个人的数据集和需求自行修改。

In [None]:
def create_multi_source_train_data_loader(args):
    """Yield training batches drawn at random from several data sources.

    One ``TrainDataset`` / ``DataLoader`` pair is built per entry of
    ``broden_dataset.record_list['train']``.  The generator then runs
    forever: it picks a source with probability proportional to its number
    of records and yields that source's next batch.

    When a source's loader iterator is exhausted it is restarted.  Without
    this, the ``StopIteration`` would escape inside the generator and be
    turned into ``RuntimeError`` (PEP 479).

    Args:
        args: namespace providing ``batch_size_per_gpu``, ``num_gpus`` and
            ``workers``; it is also passed on to ``TrainDataset``.

    Yields:
        ``(batch, source_idx)`` where ``source_idx`` is the index of the
        source the batch came from.

    Raises:
        RuntimeError: if a source yields no batch even right after its
            loader was restarted.
    """
    training_records = broden_dataset.record_list['train']

    # 0: object, part, scene
    # 1: material
    loaders = []
    multi_source_iters = []
    for idx_source, records in enumerate(training_records):
        # TrainDataset signature, as noted by the original author:
        # __init__(self, records, source_idx, opt, max_sample=-1, batch_per_gpu=1)
        dataset = TrainDataset(records, idx_source, args,
                               batch_per_gpu=args.batch_size_per_gpu)
        loader = torchdata.DataLoader(
            dataset,
            batch_size=args.num_gpus,  # we have modified data_parallel
            shuffle=False,  # we do not use this param
            collate_fn=user_scattered_collate,
            num_workers=int(args.workers),
            drop_last=True,
            pin_memory=True)
        # Keep the loader itself so its iterator can be recreated later.
        loaders.append(loader)
        multi_source_iters.append(iter(loader))

    # sample from multi source
    nr_record = [len(records) for records in training_records]
    sample_prob = np.asarray(nr_record) / np.sum(nr_record)
    exhausted = object()  # sentinel: distinguishes "no batch" from any real batch
    while True:  # TODO(LYC):: set random seed.
        source_idx = np.random.choice(len(training_records), 1, p=sample_prob)[0]
        try:
            batch = next(multi_source_iters[source_idx])
        except StopIteration:
            # This source finished a full pass: start a new one.
            multi_source_iters[source_idx] = iter(loaders[source_idx])
            batch = next(multi_source_iters[source_idx], exhausted)
            if batch is exhausted:
                raise RuntimeError(
                    'training source {} produced no batches'.format(source_idx)
                ) from None
        yield batch, source_idx

## main
这里是已经封装好的main代码

In [None]:
def main(args):
    """Build the networks, data iterator and optimizers, then run training.

    For every epoch, ``train`` is run once and ``checkpoint`` is called
    afterwards with the encoder/decoder pair and the accumulated history.

    Args:
        args: namespace with the run configuration.  Read here:
            ``arch_encoder``, ``arch_decoder``, ``fc_dim``, ``nr_classes``,
            ``weights_encoder``, ``weights_decoder``, ``deep_sup_scale``,
            ``epoch_iters``, ``num_gpus``, ``start_epoch`` and
            ``num_epoch``; it is also forwarded to the helper functions.
    """
    builder = ModelBuilder()
    # Build the network structure.
    net_encoder = builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder)
    # The decoder needs its own builder call, architecture name and weights;
    # it must not reuse the encoder's builder, ``fc_dim`` as the architecture
    # name, or the encoder's weights.
    # NOTE(review): assumes ModelBuilder exposes ``build_decoder`` and that
    # ``args.weights_decoder`` exists — confirm against their definitions.
    net_decoder = builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        nr_classes=args.nr_classes,
        weights=args.weights_decoder)

    if args.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, args.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(net_encoder, net_decoder)

    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # create loader iterator
    iterator_train = create_multi_source_train_data_loader(args=args)

    # load nets into gpu
    if args.num_gpus > 1:
        segmentation_module = UserScatteredDataParallel(
            segmentation_module,
            device_ids=range(args.num_gpus))
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder)
    optimizers = create_optimizers(nets, args)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    # Prefer the conventional ``start_epoch``; fall back to the misspelled
    # ``start_epoach`` so namespaces that only define that name keep working.
    start_epoch = getattr(args, 'start_epoch', None)
    if start_epoch is None:
        start_epoch = args.start_epoach

    for epoch in range(start_epoch, args.num_epoch + 1):
        train(segmentation_module, iterator_train, optimizers, history, epoch, args)
        checkpoint(nets, history, args, epoch)

    print('Training done')