In [1]:
# System libs
import os
import time
# import math
import random
import argparse
from distutils.version import LooseVersion
# Numerical libs
import torch
import torch.nn as nn
# Our libs
from config import *
from dataset import TrainDataset
# from models import ModelBuilder, SegmentationModule
from models import *
from utils import AverageMeter, parse_devices, setup_logger
from lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback

In [2]:
from matplotlib import pyplot as plt

In [3]:
# Dataset / loader options used to exercise TrainDataset in the cells below.
# NOTE(review): this rebinds `cfg`, shadowing any `cfg` object provided by the
# `from config import *` at the top of the notebook.
cfg = dict(
    root_dataset="./data/",
    list_train="./data/training.odgt",
    list_val="./data/validation.odgt",
    num_class=2,
    imgSizes=(480, 480),
    imgMaxSize=1000,
    padding_constant=8,
    segm_downsampling_rate=8,
    random_flip=True,
)

In [4]:
from types import SimpleNamespace

# Expose the config entries as attributes (cfg.imgSizes, cfg.padding_constant, ...);
# presumably TrainDataset reads its `opt` argument via attribute access — confirm.
# Guarded so re-running this cell is a no-op: once `cfg` is already a namespace,
# SimpleNamespace(**cfg) would raise TypeError (argument after ** must be a mapping).
if isinstance(cfg, dict):
    cfg = SimpleNamespace(**cfg)


In [5]:
# Training dataset built from the .odgt index. batch_per_gpu=2: the plotting
# cells below index images [0] and [1] of a single loaded item.
# NOTE(review): the image root here is './dataset/', while the 'root_dataset'
# config entry is './data/' — confirm which one is intended.
train_image_root = './dataset/'
train_odgt_path = './data/training.odgt'
dataset_train = TrainDataset(train_image_root, train_odgt_path, cfg, batch_per_gpu=2)

# samples: 5


In [6]:
# With `user_scattered_collate`, each batch comes out of the loader as a list
# (the cells below index it as a[0]) rather than a dict.
loader_options = dict(
    batch_size=1,  # kept at 1: data_parallel was modified to work with this setup
    shuffle=False,  # this loader-level param is not used in this setup
    collate_fn=user_scattered_collate,
    num_workers=1,
    drop_last=True,
    pin_memory=True,
)
loader_train = torch.utils.data.DataLoader(dataset_train, **loader_options)

In [7]:
# Explicit iterator so single batches can be pulled and inspected with next().
iterator_train = iter(loader_train)

In [8]:
# Pull one batch for inspection. The cells below index it as a[0][...], i.e. the
# loader yields a list of sample dicts rather than a single dict.
a = next(iterator_train)

In [10]:
# `a` is a list (indexing it with a string raised "TypeError: list indices must
# be integers or slices, not str"), so index an element before using string keys.
# Show how many elements the batch holds and which keys the first one carries.
len(a), list(a[0].keys())

In [None]:
# torch.Size of the first element's image tensor.
a[0]['img_data'].size()

In [None]:
# Shape of the first element's segmentation labels (compare with the image
# shape above; the label may be at a different resolution — TODO confirm).
a[0]['seg_label'].shape

In [None]:
# First image of the first element: channels-first -> channels-last for imshow.
img0 = a[0]['img_data'][0].permute(1, 2, 0)
# NOTE(review): img_data is presumably mean/std-normalised; imshow clips float
# RGB outside [0, 1], so min-max rescale for display only (flat image -> span 1).
img0_span = float(img0.max() - img0.min()) or 1.0
plt.imshow((img0 - img0.min()) / img0_span)
plt.title("img_data[0]");

In [None]:
# Label map paired with img_data[0].
label_map = a[0]['seg_label'][0, :, :]
plt.imshow(label_map)

In [None]:
# Second image of the first element: channels-first -> channels-last for imshow.
img1 = a[0]['img_data'][1].permute(1, 2, 0)
# NOTE(review): img_data is presumably mean/std-normalised; imshow clips float
# RGB outside [0, 1], so min-max rescale for display only (flat image -> span 1).
img1_span = float(img1.max() - img1.min()) or 1.0
plt.imshow((img1 - img1.min()) / img1_span)
plt.title("img_data[1]");

In [None]:
# Label map paired with img_data[1].
label_map = a[0]['seg_label'][1, :, :]
plt.imshow(label_map)

In [None]:
# Scratch check of how SimpleNamespace treats a nested dict.
# NOTE(review): this rebinds `cfg`, discarding the dataset config built earlier;
# later cells must not rely on `cfg` still holding that config.
cfg = dict(a=dict(b=2), c=5)

In [None]:
from types import SimpleNamespace

# SimpleNamespace is not recursive: top-level keys become attributes, but a
# nested dict value stays a plain dict (read it as cfg.a['b'], not cfg.a.b).
# Guarded so re-running this cell is a no-op instead of raising TypeError
# (SimpleNamespace(**namespace) is not allowed).
if isinstance(cfg, dict):
    cfg = SimpleNamespace(**cfg)


In [None]:
# Fail fast on PyTorch versions older than 0.4.0.
assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
    'PyTorch>=0.4.0 is required'

# Training CLI, parsed below against an empty argv so every option takes its default.
parser = argparse.ArgumentParser(description="PyTorch Semantic Segmentation Training")
parser.add_argument(
    "--cfg",
    type=str,
    metavar="FILE",
    default="config/resnet50dilated-ppm_deepsup.yaml",
    help="path to config file",
)
parser.add_argument(
    "--gpus",
    default="0",  # single device; multi-GPU would be e.g. "0-3"
    help="gpus to use, e.g. 0-3 or 0,1,2,3",
)
parser.add_argument(
    "opts",
    nargs=argparse.REMAINDER,
    default=None,
    help="Modify config options using the command-line",
)
# Pass an explicit empty list: inside a Jupyter kernel, sys.argv holds the
# kernel launcher's own flags rather than training options.
args = parser.parse_args(args=[])

In [None]:
# Earlier exploration cells rebind `cfg` to a types.SimpleNamespace, which has no
# merge_from_file / merge_from_list. Re-import the project config object so this
# cell does not depend on which cells ran before it.
# NOTE(review): assumes the `config` package (star-imported at the top) exposes
# a mergeable `cfg` node — confirm.
from config import cfg

cfg.merge_from_file(args.cfg)
cfg.merge_from_list(args.opts)
# Not frozen: cfg.MODEL.* and cfg.TRAIN.* are assigned further down in this cell.

logger = setup_logger(distributed_rank=0)   # TODO: rank hard-coded to 0 (single process assumed)
logger.info("Loaded configuration file {}".format(args.cfg))
logger.info("Running with config:\n{}".format(cfg))

# Output directory; exist_ok makes re-running the cell safe and avoids a
# check-then-create race.
os.makedirs(cfg.DIR, exist_ok=True)
logger.info("Outputting checkpoints to: {}".format(cfg.DIR))
with open(os.path.join(cfg.DIR, 'config.yaml'), 'w') as f:
    f.write("{}".format(cfg))

# Resume: point the model weights at the checkpoints saved for start_epoch.
if cfg.TRAIN.start_epoch > 0:
    cfg.MODEL.weights_encoder = os.path.join(
        cfg.DIR, 'encoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
    cfg.MODEL.weights_decoder = os.path.join(
        cfg.DIR, 'decoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
    assert os.path.exists(cfg.MODEL.weights_encoder) and \
        os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exist!"

# Parse gpu ids: strip any 'gpu' prefix from the parsed device names, then
# convert to int.
gpus = [int(dev.replace('gpu', '')) for dev in parse_devices(args.gpus)]
num_gpus = len(gpus)
cfg.TRAIN.batch_size = num_gpus * cfg.TRAIN.batch_size_per_gpu

cfg.TRAIN.max_iters = cfg.TRAIN.epoch_iters * cfg.TRAIN.num_epoch
cfg.TRAIN.running_lr_encoder = cfg.TRAIN.lr_encoder
cfg.TRAIN.running_lr_decoder = cfg.TRAIN.lr_decoder

random.seed(cfg.TRAIN.seed)
torch.manual_seed(cfg.TRAIN.seed)

# NOTE(review): `main` is not defined in any visible cell; it must come from one
# of the star imports at the top, otherwise this raises NameError — confirm.
main(cfg, gpus)
