In [1]:
import argparse
import os
import random
from shutil import copyfile

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

from edge_connect.src.config import Config
from edge_connect.src.edge_connect import EdgeConnect

In [5]:
# Restrict CUDA to the listed GPU ids (here: only device 0).
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, [0]))

In [7]:
# Select the compute device. The CPU branch guarantees `device` is always
# bound; previously it was only assigned when CUDA was available.
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.backends.cudnn.benchmark = True   # cudnn auto-tuner
else:
    device = torch.device("cpu")
# passing 0 turns OpenCV's internal threading off (prevents deadlocks with pytorch dataloader)
cv2.setNumThreads(0)

# Seed every random number generator used in this notebook with one shared value.
SEED = 10
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
    _seed_rng(SEED)
del _seed_rng  # do not leave the loop helper behind in the notebook namespace

In [38]:
import os
import yaml
class Config(dict):
    """Attribute-style access to a YAML configuration file.

    ``config.NAME`` is resolved in this order: a regular instance attribute
    (e.g. one assigned with ``config.NAME = value``), the value loaded from the
    YAML file, the entry in ``DEFAULT_CONFIG``, and finally ``None``.  A value
    of ``None`` in the YAML file is treated the same as a missing key.
    """

    def __init__(self, config_path):
        """Read the YAML file and record its directory under the ``PATH`` key.

        Args:
            config_path (str): path to the YAML configuration file.
        """
        with open(config_path, 'r') as f:
            self._yaml = f.read()

        # safe_load() returns None for an empty document; fall back to an empty
        # mapping so the PATH assignment below cannot fail on it.
        self._dict = yaml.safe_load(self._yaml) or {}
        self._dict['PATH'] = os.path.dirname(config_path)

    def __getattr__(self, name):
        """Resolve ``name`` from the loaded YAML values, then from the defaults.

        Only called when normal attribute lookup fails.

        Returns:
            The configured value, the default value, or ``None`` if neither exists.

        Raises:
            AttributeError: for dunder names, and for the internal ``_dict`` /
                ``_yaml`` slots when they have not been set yet.
        """
        # Protocol probes (copy/pickle looking for __setstate__, libraries
        # looking for __array__, ...) must see a normal "missing attribute"
        # instead of a None value.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)

        # Read the storage through the instance __dict__: going through
        # `self._dict` here would re-enter __getattr__ and recurse without bound
        # on instances created without __init__ (copy, unpickling).
        loaded = self.__dict__.get('_dict')
        if loaded is None:
            if name in ('_dict', '_yaml'):
                raise AttributeError(name)
            loaded = {}

        value = loaded.get(name)
        if value is not None:
            return value

        # .get() yields None both for unknown keys and for None defaults.
        return DEFAULT_CONFIG.get(name)

    def print(self):
        """Print the raw YAML text between separator lines."""
        print('Model configurations:')
        print('---------------------------------')
        print(self._yaml)
        print('')
        print('---------------------------------')
        print('')


# Fallback values used by Config.__getattr__ when the YAML file omits a key.
DEFAULT_CONFIG = {
    'MODE': 1,                      # 1: train, 2: test, 3: eval
    'MODEL': 1,                     # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model
    'MASK': 3,                      # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
    'EDGE': 1,                      # 1: canny, 2: external
    'NMS': 1,                       # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny
    'SEED': 10,                     # random seed
    'GPU': [0],                     # list of gpu ids
    'DEBUG': 0,                     # turns on debugging mode
    'VERBOSE': 0,                   # turns on verbose mode in the output console

    'LR': 0.0001,                   # learning rate
    'D2G_LR': 0.1,                  # discriminator/generator learning rate ratio
    'BETA1': 0.0,                   # adam optimizer beta1
    'BETA2': 0.9,                   # adam optimizer beta2
    'BATCH_SIZE': 8,                # input batch size for training
    'INPUT_SIZE': 256,              # input image size for training 0 for original size
    'SIGMA': 2,                     # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)
    'MAX_ITERS': 2e6,               # maximum number of iterations to train the model

    'EDGE_THRESHOLD': 0.5,          # edge detection threshold
    'L1_LOSS_WEIGHT': 1,            # l1 loss weight
    'FM_LOSS_WEIGHT': 10,           # feature-matching loss weight
    'STYLE_LOSS_WEIGHT': 1,         # style loss weight
    'CONTENT_LOSS_WEIGHT': 1,       # perceptual loss weight
    'INPAINT_ADV_LOSS_WEIGHT': 0.01,# adversarial loss weight

    'GAN_LOSS': 'nsgan',            # nsgan | lsgan | hinge
    'GAN_POOL_SIZE': 0,             # fake images pool size

    'SAVE_INTERVAL': 1000,          # how many iterations to wait before saving model (0: never)
    'SAMPLE_INTERVAL': 1000,        # how many iterations to wait before sampling (0: never)
    'SAMPLE_SIZE': 12,              # number of images to sample
    'EVAL_INTERVAL': 0,             # how many iterations to wait before model evaluation (0: never)
    'LOG_INTERVAL': 10,             # how many iterations to wait before logging training status (0: never)
}

In [39]:
config = Config("edge_connect/config.yml.example")

# The cells below read `config.DEVICE`, but nothing in the notebook flow ever
# assigned it, so it silently evaluated to None. Pick the device explicitly,
# the same way main() does.
config.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
from edge_connect.src.models import EdgeModel
from edge_connect.src.metrics import PSNR, EdgeAccuracy

In [25]:
# Build the edge model and the two metric modules, then move each one to the
# configured device.
edge_model, psnr, edgeacc = (
    module.to(config.DEVICE)
    for module in (
        EdgeModel(config),
        PSNR(255.0),
        EdgeAccuracy(config.EDGE_THRESHOLD),
    )
)

In [None]:
# Output folders live next to the config file unless RESULTS overrides the
# results location.
samples_path = os.path.join(config.PATH, 'samples')
results_path = os.path.join(config.PATH, 'results')
if config.RESULTS is not None:
    results_path = os.path.join(config.RESULTS)

# Always bind `debug`: it used to be assigned only when DEBUG was set and
# non-zero, leaving the name undefined in the default (DEBUG = 0) case.
debug = config.DEBUG is not None and config.DEBUG != 0

# `model_name` was never assigned in the visible cells, so building `log_file`
# raised NameError.
# NOTE(review): 'edge' is chosen because this notebook only builds EdgeModel —
# confirm the intended log file name.
model_name = 'edge'

log_file = os.path.join(config.PATH, 'log_' + model_name + '.dat')

In [26]:
#edge_model

In [27]:
# NOTE(review): presumably restores previously saved weights into the edge
# model — confirm against EdgeModel.load().
edge_model.load()

In [None]:
# Training and validation datasets, each built from its image / edge / mask
# file lists; only the training set is augmented.
# NOTE(review): `Dataset` is not imported in any visible cell — confirm where
# it comes from.
train_dataset = Dataset(
    config,
    config.TRAIN_FLIST,
    config.TRAIN_EDGE_FLIST,
    config.TRAIN_MASK_FLIST,
    augment=True,
    training=True,
)
val_dataset = Dataset(
    config,
    config.VAL_FLIST,
    config.VAL_EDGE_FLIST,
    config.VAL_MASK_FLIST,
    augment=False,
    training=True,
)

# Iterator over validation items, SAMPLE_SIZE at a time.
sample_iterator = val_dataset.create_iterator(config.SAMPLE_SIZE)

In [None]:
# Batches for the training loop: shuffled, loaded by 4 worker processes, and
# with the incomplete final batch dropped.
train_loader = DataLoader(
    train_dataset,
    config.BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

### training

In [None]:
def log(logs):
    """Append one record to the file at ``log_file``.

    Args:
        logs: sequence of entries whose element at index 1 is the value to
            record (e.g. ``(name, value)`` pairs).  The values are written
            space-separated on a single line.
    """
    with open(log_file, 'a') as handle:
        # print() stringifies each value, joins with single spaces and ends
        # the line with '\n'.
        print(*(entry[1] for entry in logs), file=handle)

In [None]:
# Training loop for the edge model.
#
# NOTE(review): `Progbar`, `cuda`, `sample` and `save` are not defined in any
# visible cell, and the bare `eval()` below resolves to the Python builtin
# (which requires an argument) — confirm where these helpers are meant to come
# from before running this cell.
epoch = 0
model = config.MODEL
max_iteration = int(float((config.MAX_ITERS)))
total = len(train_dataset)

# A bare `return` is a SyntaxError at cell level, so an empty dataset is
# handled by never entering the loop instead.
keep_training = total > 0
if not keep_training:
    print('No training data was provided! Check \'TRAIN_FLIST\' value in the configuration file.')

while(keep_training):
    epoch += 1
    print('\n\nTraining epoch: %d' % epoch)

    progbar = Progbar(total, width=20, stateful_metrics=['epoch', 'iter'])

    for items in train_loader:
        edge_model.train()
        images, images_gray, edges, masks = cuda(*items)

        # edge model
        # train
        outputs, gen_loss, dis_loss, logs = edge_model.process(images_gray, edges, masks)

        # metrics
        precision, recall = edgeacc(edges * masks, outputs * masks)
        logs.append(('precision', precision.item()))
        logs.append(('recall', recall.item()))

        # backward
        edge_model.backward(gen_loss, dis_loss)
        iteration = edge_model.iteration

        # Everything from here to the end of the loop body is per-iteration
        # bookkeeping. It used to sit one level out, so it ran once per epoch:
        # the progress bar advanced by a single batch and the interval checks
        # only saw the last iteration of each epoch.
        if iteration >= max_iteration:
            keep_training = False
            break

        logs = [
            ("epoch", epoch),
            ("iter", iteration),
        ] + logs

        progbar.add(len(images), values=logs if config.VERBOSE else [x for x in logs if not x[0].startswith('l_')])

        # log model at checkpoints
        if config.LOG_INTERVAL and iteration % config.LOG_INTERVAL == 0:
            log(logs)

        # sample model at checkpoints
        if config.SAMPLE_INTERVAL and iteration % config.SAMPLE_INTERVAL == 0:
            sample()

        # evaluate model at checkpoints
        if config.EVAL_INTERVAL and iteration % config.EVAL_INTERVAL == 0:
            print('\nstart eval...\n')
            eval()

        # save model at checkpoints
        if config.SAVE_INTERVAL and iteration % config.SAVE_INTERVAL == 0:
            save()

### evaluation

In [None]:
# Evaluation pass over the validation set.
#
# NOTE(review): `Progbar` and `cuda` are not defined in any visible cell —
# confirm where they are meant to come from.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=config.BATCH_SIZE,
    drop_last=True,
    shuffle=True
)

# This code runs at cell level, where `self` does not exist; read the
# notebook-level `config` the same way the training and test cells do.
model = config.MODEL
total = len(val_dataset)

edge_model.eval()

progbar = Progbar(total, width=20, stateful_metrics=['it'])
iteration = 0

for items in val_loader:
    iteration += 1
    images, images_gray, edges, masks = cuda(*items)

    # Start from an empty record so `logs` is bound even when `model` is not 1
    # (it used to be undefined, or stale from an earlier cell, in that case).
    logs = []

    # edge model
    if model == 1:
        # eval
        outputs, gen_loss, dis_loss, logs = edge_model.process(images_gray, edges, masks)

        # metrics
        precision, recall = edgeacc(edges * masks, outputs * masks)
        logs.append(('precision', precision.item()))
        logs.append(('recall', recall.item()))
    logs = [("it", iteration), ] + logs
    progbar.add(len(images), values=logs)

### test

In [None]:
# Run the edge model over the test set and write one result image per input.
#
# NOTE(review): `test_dataset`, `create_dir`, `cuda`, `postprocess` and
# `imsave` are not defined in any visible cell — confirm where they are meant
# to come from.
edge_model.eval()
model = config.MODEL
create_dir(results_path)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
)

# This code runs at cell level, where `self` does not exist; derive the debug
# flag from the config directly so the cell does not depend on an earlier one.
debug = config.DEBUG is not None and config.DEBUG != 0

index = 0
for items in test_loader:
    name = test_dataset.load_name(index)
    images, images_gray, edges, masks = cuda(*items)
    index += 1

    # edge model
    if model == 1:
        outputs = edge_model(images_gray, edges, masks)
        # model output inside the mask, the given edges everywhere else
        outputs_merged = (outputs * masks) + (edges * (1 - masks))

    output = postprocess(outputs_merged)[0]
    path = os.path.join(results_path, name)
    print(index, name)

    imsave(output, path)

    if debug:
        edges = postprocess(1 - edges)[0]
        masked = postprocess(images * (1 - masks) + masks)[0]
        # splitext() copes with names that contain no dot or several dots,
        # where name.split('.') raised ValueError; `fext` keeps its leading dot.
        fname, fext = os.path.splitext(name)
        imsave(edges, os.path.join(results_path, fname + '_edge' + fext))
        imsave(masked, os.path.join(results_path, fname + '_masked' + fext))
print('\nEnd test....')

In [None]:
def main(mode=None):
    r"""Loads the configuration, prepares the runtime and runs the model.

    Args:
        mode (int): 1: train, 2: test, 3: eval, reads from config file if not specified
    """
    config = load_config(mode)

    # limit CUDA to the GPU ids listed in the configuration
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, config.GPU))

    # choose the device; the cudnn auto-tuner is only switched on with CUDA
    has_cuda = torch.cuda.is_available()
    config.DEVICE = torch.device("cuda" if has_cuda else "cpu")
    if has_cuda:
        torch.backends.cudnn.benchmark = True

    # passing 0 turns OpenCV's internal threading off (prevents deadlocks with pytorch dataloader)
    cv2.setNumThreads(0)

    # one seed for every random number generator in use
    seed = config.SEED
    for seed_rng in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seed_rng(seed)

    # build the model and initialize
    model = EdgeConnect(config)
    model.load()

    run_mode = config.MODE

    # model training
    if run_mode == 1:
        config.print()
        print('\nstart training...\n')
        model.train()
        return

    # model test
    if run_mode == 2:
        print('\nstart testing...\n')
        model.test()
        return

    # anything else is treated as eval mode
    print('\nstart eval...\n')
    model.eval()


def load_config(mode=None):
    r"""loads model config

    Parses the command line, makes sure the checkpoints directory and its
    ``config.yml`` exist (copying ``./config.yml.example`` if needed), loads
    the file and applies the mode-specific overrides.

    Args:
        mode (int): 1: train, 2: test, 3: eval, reads from config file if not specified

    Returns:
        Config: the loaded configuration with command-line overrides applied.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '--checkpoints', type=str, default='./checkpoints', help='model checkpoints path (default: ./checkpoints)')
    parser.add_argument('--model', type=int, choices=[1, 2, 3, 4], help='1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model')

    # test mode
    if mode == 2:
        parser.add_argument('--input', type=str, help='path to the input images directory or an input image')
        parser.add_argument('--mask', type=str, help='path to the masks directory or a mask file')
        parser.add_argument('--edge', type=str, help='path to the edges directory or an edge file')
        parser.add_argument('--output', type=str, help='path to the output directory')

    # Inside a Jupyter kernel sys.argv carries the kernel's own arguments
    # (e.g. "-f <connection file>"), which parse_args() rejects by exiting the
    # process. Tolerate them, but report what was skipped so typos stay visible.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print('Ignoring unrecognized arguments: %s' % ' '.join(unknown_args))

    config_path = os.path.join(args.path, 'config.yml')

    # create checkpoints path if it doesn't exist
    os.makedirs(args.path, exist_ok=True)

    # copy config template if it doesn't exist
    if not os.path.exists(config_path):
        copyfile('./config.yml.example', config_path)

    # load config file
    config = Config(config_path)

    # train mode
    if mode == 1:
        config.MODE = 1
        if args.model:
            config.MODEL = args.model

    # test mode
    elif mode == 2:
        config.MODE = 2
        config.MODEL = args.model if args.model is not None else 3
        config.INPUT_SIZE = 0

        if args.input is not None:
            config.TEST_FLIST = args.input

        if args.mask is not None:
            config.TEST_MASK_FLIST = args.mask

        if args.edge is not None:
            config.TEST_EDGE_FLIST = args.edge

        if args.output is not None:
            config.RESULTS = args.output

    # eval mode
    elif mode == 3:
        config.MODE = 3
        config.MODEL = args.model if args.model is not None else 3

    return config