In [1]:
import easydict
# import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os
import torch
import numpy as np
# import torch.utils.tensorboard as tb
import copy

from runner import Diffusion

In [2]:
def make_parse_args(img_name):
    """Build the run options and prepare logging, the output folder and RNG seeds.

    The options are hard-coded in an ``EasyDict`` instead of being parsed from
    the command line, so the function can be called directly from a notebook.

    Side effects:
      * attaches one stream handler to the root logger (only on the first
        call) and sets the root logger level from ``verbose``;
      * creates ``<exp>/image_samples/<img_name>``; if that folder already
        exists it is deleted and recreated (or the program exits, see below);
      * seeds ``torch`` and ``numpy`` and enables cuDNN benchmark mode.

    Args:
        img_name: Name of the sub-folder under ``<exp>/image_samples`` that is
            prepared for this run. Must be non-empty.

    Returns:
        easydict.EasyDict: The options, with ``image_folder`` rewritten to the
        full output path.

    Raises:
        ValueError: If ``img_name`` is empty, or if ``verbose`` does not name
            a logging level.
    """
    # An empty name would make the output path equal to the shared
    # '<exp>/image_samples/' folder, which the overwrite step below would
    # then delete together with the results of every other run.
    if not img_name:
        raise ValueError('img_name must be a non-empty folder name')

    # NOTE: 'sample' and 'ni' hold the string 'store_true' (presumably left
    # over from an argparse definition -- TODO confirm). Being non-empty
    # strings they are truthy, so ``args.ni`` always selects "overwrite".
    args = easydict.EasyDict({'seed': 1234,
                              'exp': 'exp',
                              'comment': '',
                              'verbose': 'info',
                              'sample': 'store_true',
                              'i': 'images',
                              'image_folder': img_name,
                              'ni': 'store_true',
                              'sample_step': 6,
                              't': 300})

    level = getattr(logging, args.verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError('level {} not supported'.format(args.verbose))

    # Attach the stream handler only once. Without this guard every call
    # (e.g. re-running the notebook cell) adds another handler to the root
    # logger and each message is printed one more time.
    logger = logging.getLogger()
    handler_name = 'make_parse_args_stream'
    already_attached = any(
        getattr(existing, 'name', None) == handler_name
        for existing in logger.handlers
    )
    if not already_attached:
        handler1 = logging.StreamHandler()
        handler1.set_name(handler_name)
        formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
        handler1.setFormatter(formatter)
        logger.addHandler(handler1)
    logger.setLevel(level)

    os.makedirs(os.path.join(args.exp, 'image_samples'), exist_ok=True)
    args.image_folder = os.path.join(args.exp, 'image_samples', args.image_folder)
    if os.path.exists(args.image_folder):
        # With the hard-coded truthy 'ni' the prompt below is never reached;
        # it is kept so that setting 'ni' to a falsy value asks before deleting.
        overwrite = bool(args.ni)
        if not overwrite:
            response = input("Image folder already exists. Overwrite? (Y/N)")
            overwrite = response.upper() == 'Y'

        if not overwrite:
            print("Output image folder exists. Program halted.")
            sys.exit(0)

        # Start from an empty folder.
        shutil.rmtree(args.image_folder)
    os.makedirs(args.image_folder)

    # The device is only reported here; it is not stored on ``args``.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    logging.info("Using device: {}".format(device))

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args

In [3]:
# Inputs for this run: the config file name and the two image paths that are
# handed to Diffusion below.
config = "celeba.yml"
origin_img = "source_images/real_images/one_young.jpg"
stroked_img = "source_images/stroked_images/stroked_one_young.jpg"

# The output folder is named after the stroked image without its directory
# and extension. splitext copes with extensions of any length ('.jpeg',
# '.tiff'), unlike slicing off the last four characters.
img_name = os.path.splitext(os.path.basename(stroked_img))[0]

args = make_parse_args(img_name)


try:
    runner = Diffusion(args, config, origin_img, stroked_img)
    runner.image_editing_sample()
except Exception:
    # Log the full traceback instead of letting the exception propagate out
    # of the cell.
    logging.error(traceback.format_exc())

INFO - <ipython-input-2-46d5e0644300> - 2023-06-08 09:53:08,730 - Using device: cuda


Loading model


    There is an imbalance between your GPUs. You may want to exclude GPU 0 which
    has less than 75% of the memory or cores of GPU 1. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


Model loaded
Start sampling


  out = torch.gather(torch.tensor(a, dtype=torch.float, device=t.device), 0, t.long())
Iteration 0: 100%|██████████| 300/300 [00:52<00:00,  5.67it/s]
Iteration 1: 100%|██████████| 300/300 [00:43<00:00,  6.94it/s]
Iteration 2: 100%|██████████| 300/300 [00:43<00:00,  6.94it/s]
Iteration 3: 100%|██████████| 300/300 [00:43<00:00,  6.92it/s]
Iteration 4: 100%|██████████| 300/300 [00:43<00:00,  6.90it/s]
Iteration 5: 100%|██████████| 300/300 [00:43<00:00,  6.91it/s]
