In [None]:
import torch
import torch.nn as nn
from torch.utils import data, model_zoo
import numpy as np
import pickle
from torch.autograd import Variable
import torch.optim as optim
import scipy.misc
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import sys
import os
import os.path as osp
import random
import time
import yaml
import easydict
from tensorboardX import SummaryWriter

from trainer_ms import AD_Trainer
from utils.loss import CrossEntropy2d
from utils.tool import adjust_learning_rate, adjust_learning_rate_D, Timer 
from dataset.gta5_dataset import GTA5DataSet
from dataset.cityscapes_dataset import cityscapesDataSet

from config import CONSTS

# ---------------------------------------------------------------------------
# Training configuration. get_arguments() copies most of these into `args`.
# ---------------------------------------------------------------------------

# Per-channel mean handed to both datasets as `mean=`.
IMG_MEAN = np.array([104.00698793, 116.66876762, 122.67891434], dtype=np.float32)
NUM_CLASSES = 19
IGNORE_LABEL = 255

# Source domain (GTA5).
DATA_DIRECTORY = CONSTS.GTA_PATH
DATA_LIST_PATH = CONSTS.GTA_TRAIN_LIST_PATH
INPUT_SIZE = '1280,720'            # 'W,H' resize size

# Target domain (Cityscapes).
TARGET = 'cityscapes'
SET = 'train'
DATA_DIRECTORY_TARGET = CONSTS.CITYSCAPES_PATH
DATA_LIST_PATH_TARGET = CONSTS.CITYSCAPES_TRAIN_LIST_PATH
INPUT_SIZE_TARGET = '1024,512'     # 'W,H' resize size

# Shared crop size, written 'W,H' (main() stores it as (h, w)).
# 640,360 is the other size noted here previously.
CROP_SIZE = '384,192'

# Augmentation switches.
# NOTE(review): get_arguments() hard-codes autoaug/autoaug_target to True and
# does not read these two constants.
AUTOAUG = False
AUTOAUG_TARGET = False

# Model.
MODEL = 'DeepLab'
NORM_STYLE = 'bn'  # or in
DROPRATE = 0.1

# Optimisation schedule.
BATCH_SIZE = 16
ITER_SIZE = 1
NUM_WORKERS = 2
LEARNING_RATE = 2.5e-4             # generator base learning rate
LEARNING_RATE_D = 1e-4             # discriminator base learning rate
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
POWER = 0.9                        # presumably the polynomial-decay exponent — confirm in adjust_learning_rate
WARM_UP = 0  # no warmup
MAX_VALUE = 2
NUM_STEPS = 100000
NUM_STEPS_STOP = 100000  # early stopping: main() saves and stops at this step
RANDOM_SEED = 1234

# Loss weights.
LAMBDA_SEG = 0.1
LAMBDA_ADV_TARGET1 = 0.0002
LAMBDA_ADV_TARGET2 = 0.001
LAMBDA_ME_TARGET = 0
LAMBDA_KL_TARGET = 0

# Checkpoints and logging.
RESTORE_FROM = 'http://vllab.ucmerced.edu/ytsai/CVPR18/DeepLab_resnet_pretrained_init-f81d91e8.pth'
SAVE_NUM_IMAGES = 2
SAVE_PRED_EVERY = 5000
SNAPSHOT_DIR = './snapshots/'
LOG_DIR = './log'

In [None]:
def get_arguments():
    """Build the training configuration used by this notebook.

    No command-line parsing happens here. Every option is a hard-coded value,
    mostly taken from the module-level constants in the configuration cell.

    Keys are written with underscores (``batch_size``, not ``batch-size``).
    ``easydict.EasyDict`` exposes each key verbatim as an attribute and does
    not translate hyphens the way ``argparse`` derives ``dest`` names. The
    rest of the notebook reads options as attributes (``args.batch_size``,
    ``args.snapshot_dir``, ...), so hyphenated keys would be unreachable.

    Returns:
      easydict.EasyDict: the options, readable both as attributes and as
      dictionary keys.
    """
    # NOTE(review): AUTOAUG / AUTOAUG_TARGET from the config cell are not used;
    # both autoaug flags are hard-coded to True below. Likewise "cpu" is True
    # even though main() moves every batch to CUDA — confirm intent.
    args = easydict.EasyDict({
        "autoaug": True,
        "autoaug_target": True,
        "model": MODEL,
        "batch_size": BATCH_SIZE,
        "iter_size": ITER_SIZE,
        "num_workers": NUM_WORKERS,
        "data_dir": DATA_DIRECTORY,
        "data_list": DATA_LIST_PATH,
        "droprate": DROPRATE,
        "ignore_label": IGNORE_LABEL,
        "input_size": INPUT_SIZE,
        "crop_size": CROP_SIZE,
        "data_dir_target": DATA_DIRECTORY_TARGET,
        "data_list_target": DATA_LIST_PATH_TARGET,
        "input_size_target": INPUT_SIZE_TARGET,
        "is_training": True,
        "learning_rate": LEARNING_RATE,  # Base learning rate for training with polynomial decay.
        "learning_rate_D": LEARNING_RATE_D,  # Base learning rate for discriminator
        "lambda_seg": LAMBDA_SEG,
        "lambda_adv_target1": LAMBDA_ADV_TARGET1,
        "lambda_adv_target2": LAMBDA_ADV_TARGET2,
        "lambda_me_target": LAMBDA_ME_TARGET,
        "lambda_kl_target": LAMBDA_KL_TARGET,
        "momentum": MOMENTUM,
        "max_value": MAX_VALUE,
        "norm_style": NORM_STYLE,
        "not_restore_last": True,
        "num_classes": NUM_CLASSES,
        "num_steps": NUM_STEPS,
        "num_steps_stop": NUM_STEPS_STOP,
        "power": POWER,
        "random_mirror": True,
        "random_scale": True,
        "fp16": True,
        "random_seed": RANDOM_SEED,
        "restore_from": RESTORE_FROM,
        "save_num_images": SAVE_NUM_IMAGES,
        "save_pred_every": SAVE_PRED_EVERY,
        "snapshot_dir": SNAPSHOT_DIR,
        "weight_decay": WEIGHT_DECAY,
        "warm_up": WARM_UP,
        "cpu": True,
        "class_balance": True,
        "use_se": True,
        "only_hard_label": 0,
        "train_bn": True,
        "sync_bn": True,
        "often_balance": True,
        "gpu_ids": '0',
        "tensorboard": True,
        "log_dir": LOG_DIR,
        "set": SET,
        "multi_gpu": False,
    })

    return args


args = get_arguments()

# Record the resolved options as <snapshot_dir>/opts.yaml, creating the
# snapshot directory first when it is missing, so each checkpoint folder
# carries the settings it was produced with.
if not osp.exists(args.snapshot_dir):
    os.makedirs(args.snapshot_dir)

opts_path = '%s/opts.yaml' % args.snapshot_dir
with open(opts_path, 'w') as opts_file:
    yaml.dump(vars(args), opts_file, default_flow_style=False)


def _parse_size(size, swap=False):
    """Convert a ``'W,H'`` option string into a tuple of two ints.

    Args:
      size: the ``'W,H'`` string, or a tuple/list produced by an earlier call
        (e.g. when ``main()`` is re-run in the same kernel).
      swap: when True, return ``(H, W)`` instead of ``(W, H)``.

    Returns:
      A 2-tuple of ints. Non-string input is returned unchanged (as a tuple),
      so an already-parsed value is never split again or swapped twice.
    """
    if not isinstance(size, str):
        return tuple(size)
    w, h = map(int, size.split(','))
    return (h, w) if swap else (w, h)


def _parse_gpu_ids(gpu_ids):
    """Return the non-negative integer ids from a comma-separated string like ``'0,1'``."""
    ids = [int(token) for token in gpu_ids.split(',') if token.strip()]
    return [gid for gid in ids if gid >= 0]


def _seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable


def _save_checkpoint(trainer, snapshot_dir, tag):
    """Save G, D1 and D2 state dicts as ``GTA5_<tag>.pth``, ``..._D1.pth``, ``..._D2.pth``."""
    prefix = osp.join(snapshot_dir, 'GTA5_' + str(tag))
    torch.save(trainer.G.state_dict(), prefix + '.pth')
    torch.save(trainer.D1.state_dict(), prefix + '_D1.pth')
    torch.save(trainer.D2.state_dict(), prefix + '_D2.pth')


def main():
    """Create the model and data loaders, then run adversarial training.

    Configuration comes from the module-level ``args``. Before
    ``AD_Trainer(args)`` is built, three fields are replaced by int tuples:
    ``args.input_size`` and ``args.input_size_target`` become ``(w, h)``, and
    ``args.crop_size`` becomes ``(h, w)``. Parsing is skipped for values that
    are already tuples, so calling ``main()`` again in the same kernel works.

    TensorBoard logs go to ``<args.log_dir>/<name of args.snapshot_dir>``.
    ``args.log_dir`` itself is no longer modified. The writer is closed even
    if training raises.
    """
    args.input_size = _parse_size(args.input_size)
    # The crop size is kept as (h, w), unlike the two resize sizes.
    args.crop_size = _parse_size(args.crop_size, swap=True)
    args.input_size_target = _parse_size(args.input_size_target)

    _seed_everything(args.random_seed)

    cudnn.enabled = True
    cudnn.benchmark = True

    gpu_ids = _parse_gpu_ids(args.gpu_ids)
    # AD_Trainer receives args, so the flag must be set before it is built.
    args.multi_gpu = len(gpu_ids) > 1
    trainer = AD_Trainer(args)
    if args.multi_gpu:
        trainer.G = torch.nn.DataParallel(trainer.G, gpu_ids)
        trainer.D1 = torch.nn.DataParallel(trainer.D1, gpu_ids)
        trainer.D2 = torch.nn.DataParallel(trainer.D2, gpu_ids)

    print(trainer)

    # Presumably max_iters makes each dataset long enough for
    # num_steps * iter_size batches; the loop below calls next() without
    # handling StopIteration — confirm in the dataset classes.
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=max_iters,
                    resize_size=args.input_size,
                    crop_size=args.crop_size,
                    scale=True, mirror=True, mean=IMG_MEAN, autoaug=args.autoaug),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
        pin_memory=True, drop_last=True)
    trainloader_iter = iter(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=max_iters,
                          resize_size=args.input_size_target,
                          crop_size=args.crop_size,
                          scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                          set=args.set, autoaug=args.autoaug_target),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
        pin_memory=True, drop_last=True)
    targetloader_iter = iter(targetloader)

    writer = None
    if args.tensorboard:
        # normpath drops the trailing '/' of e.g. './snapshots/'. Without it,
        # basename() returns '' and every run would log into the bare log dir.
        log_dir = osp.join(args.log_dir, osp.basename(osp.normpath(args.snapshot_dir)))
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)

    try:
        for i_iter in range(args.num_steps):

            loss_seg_value1 = 0
            loss_adv_target_value1 = 0
            loss_D_value1 = 0

            loss_seg_value2 = 0
            loss_adv_target_value2 = 0
            loss_D_value2 = 0

            adjust_learning_rate(trainer.gen_opt, i_iter, args)
            adjust_learning_rate_D(trainer.dis1_opt, i_iter, args)
            adjust_learning_rate_D(trainer.dis2_opt, i_iter, args)

            for _ in range(args.iter_size):
                # One source (GTA5) batch and one target (Cityscapes) batch per sub-step.
                images, labels, _, _ = next(trainloader_iter)
                images_t, labels_t, _, _ = next(targetloader_iter)

                images = images.cuda()
                labels = labels.long().cuda()
                images_t = images_t.cuda()
                labels_t = labels_t.long().cuda()

                with Timer("Elapsed time in update: %f"):
                    (loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2,
                     loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2,
                     val_loss) = trainer.gen_update(images, images_t, labels, labels_t, i_iter)
                    loss_seg_value1 += loss_seg1.item() / args.iter_size
                    loss_seg_value2 += loss_seg2.item() / args.iter_size
                    loss_adv_target_value1 += loss_adv_target1 / args.iter_size
                    loss_adv_target_value2 += loss_adv_target2 / args.iter_size
                    loss_me_value = loss_me

                    if args.lambda_adv_target1 > 0 and args.lambda_adv_target2 > 0:
                        loss_D1, loss_D2 = trainer.dis_update(pred1, pred2, pred_target1, pred_target2)
                        # NOTE(review): unlike the other terms, these are not divided by iter_size.
                        loss_D_value1 += loss_D1.item()
                        loss_D_value2 += loss_D2.item()
                    else:
                        loss_D_value1 = 0
                        loss_D_value2 = 0

            # Release the last predictions before logging.
            del pred1, pred2, pred_target1, pred_target2

            if writer is not None and i_iter % 100 == 0:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    'loss_me_target': loss_me_value,
                    'loss_kl_target': loss_kl,
                    'loss_D1': loss_D_value1,
                    'loss_D2': loss_D_value2,
                    'val_loss': val_loss,
                }
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
            '\033[1m iter = %8d/%8d \033[0m loss_seg1 = %.3f loss_seg2 = %.3f loss_me = %.3f  loss_kl = %.3f loss_adv1 = %.3f, loss_adv2 = %.3f loss_D1 = %.3f loss_D2 = %.3f, val_loss=%.3f'%(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_me_value, loss_kl, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2, val_loss))

            # Drop references to this step's losses.
            del loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2, loss_me, loss_kl, val_loss

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                _save_checkpoint(trainer, args.snapshot_dir, args.num_steps_stop)
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                _save_checkpoint(trainer, args.snapshot_dir, i_iter)
    finally:
        if writer is not None:
            writer.close()



In [None]:
# Start training. Inside Jupyter __name__ is '__main__', so this cell still
# runs; the guard only keeps an exported .py copy importable without
# launching a training run.
if __name__ == '__main__':
    main()