In [None]:
import torch

from utils.tools import get_config, random_bbox, mask_image
# from utils.tools import tensor_img_to_npimg as tensor_to_img
from utils.tools import tensor_to_img
from utils.tools import local_patch, spatial_discounting_mask
from data.dataset import Dataset

import numpy as np
import cv2
import matplotlib.pyplot as plt



In [None]:
def _draw_ff_mask(config, h, w):
    """Draw a single free-form stroke mask.

    Returns a float64 array of shape (h, w) holding 1.0 on stroke pixels and
    0.0 elsewhere.
    """
    mask = np.zeros((h, w))
    # Drawn per mask so every mask in a batch gets its own vertex count.
    num_vertices = 12 + np.random.randint(config['mv'])
    for vertex in range(num_vertices):
        start_x = int(np.random.randint(w))  # column
        start_y = int(np.random.randint(h))  # row
        for _ in range(1 + np.random.randint(5)):
            angle = 0.01 + np.random.randint(config['ma'])
            if vertex % 2 == 0:
                # Mirror the angle on even-indexed vertices.
                angle = 2 * np.pi - angle
            length = 10 + np.random.randint(config['ml'])
            brush_w = 10 + np.random.randint(config['mbw'])
            # int() truncates toward zero; plain ints are the safest type for
            # OpenCV point arguments.
            end_x = int(start_x + length * np.sin(angle))
            end_y = int(start_y + length * np.cos(angle))
            # OpenCV points are (x, y) = (column, row).
            cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
            start_x, start_y = end_x, end_y
    return mask


def random_ff_mask(config, batch_size=1, to_tensor=True):
    """Generate a batch of random free-form (brush-stroke) masks.

    Each mask is drawn independently: 12 + randint(config['mv']) vertices, and
    from each vertex a polyline of 1-5 thick line segments.

    Args:
        config: dict with keys
            'shape': (h, w) of each mask,
            'mv':  extra vertex count (vertices = 12 + randint(mv)),
            'ma':  angle range (segment angle = 0.01 + randint(ma), radians),
            'ml':  extra segment length (length = 10 + randint(ml) pixels),
            'mbw': extra brush width (thickness = 10 + randint(mbw) pixels).
        batch_size: number of masks to generate (default 1).
        to_tensor: if True return a float32 torch tensor of shape
            (batch_size, 1, h, w); otherwise a list of `batch_size` float32
            numpy arrays of shape (h, w, 1).

    Returns:
        torch.Tensor or list[np.ndarray]: masks with 1.0 on stroke pixels and
        0.0 elsewhere.
    """
    h, w = config['shape']
    masks = [_draw_ff_mask(config, h, w) for _ in range(batch_size)]

    if not to_tensor:
        return [m[..., np.newaxis].astype(np.float32) for m in masks]

    # Preallocating [N, 1, H, W] (instead of concatenating) also covers batch_size == 0.
    batch = np.zeros((batch_size, 1, h, w), dtype=np.float32)
    for k, m in enumerate(masks):
        batch[k, 0] = m
    return torch.from_numpy(batch)
    

In [None]:
# Experiment configuration; a relative path is resolved against the notebook's working directory.
CONFIG_PATH = 'configs/config-gated-spectnorm.yaml'
config = get_config(CONFIG_PATH)

In [None]:
# Preview a single free-form mask as a 2-D (h, w) numpy array
# (batch_size is required by random_ff_mask; the list form avoids tensor conversion).
mask_ff = random_ff_mask(config['random_ff_settings'], batch_size=1, to_tensor=False)[0].squeeze(-1)
print(mask_ff.shape)

In [None]:
# np.squeeze drops any trailing channel axis so imshow gets a 2-D array.
plt.imshow(np.squeeze(mask_ff), cmap='gray')
plt.title('Random free-form mask (white = strokes, value 1)')
plt.show()

## Inspect the `local_patch` output and mimic its behaviour for a free-form mask

In [None]:
# Training data.
train_dataset = Dataset(
    data_path=config['train_data_path'],
    with_subfolder=config['data_with_subfolder'],
    image_shape=config['image_shape'],
    random_crop=config['random_crop'],
)

# Single-process run: no explicit sampler, so the DataLoader shuffles by itself.
# For multi-GPU training a torch.utils.data.distributed.DistributedSampler over
# `train_dataset` would be assigned here instead; DataLoader does not allow
# shuffle=True together with a sampler, hence the `sampler is None` switch.
sampler = None

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=config['batch_size'],
    shuffle=sampler is None,
    sampler=sampler,
    num_workers=config['num_workers'],
    pin_memory=True,
    drop_last=True,
)

In [None]:
# Grab a single training batch for the experiments below. (This replaces a
# training-style loop that always broke out on its first iteration.)
iterable_train_loader = iter(train_loader)
ground_truth = next(iterable_train_loader)

In [None]:
print(ground_truth.size())
# Scale by 1/255 like the other preview cells so imshow receives values in [0, 1]
# (assumes tensor_to_img returns values in [0, 255] — TODO confirm).
img = tensor_to_img(ground_truth[0]) / 255.
plt.imshow(img)
plt.title('First ground-truth image of the batch')
plt.show()


In [None]:
# Rectangular-hole inputs: sample bboxes for the whole batch, then let
# mask_image return the masked input `x` together with its `mask`.
# Device placement happens in a later cell.
bboxes = random_bbox(config, batch_size=ground_truth.size(0))
x, mask = mask_image(ground_truth, bboxes, config)
print(x.shape, mask.shape, sep='\n')

In [None]:
# Masked input and its bbox mask side by side for the first sample.
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
ax_img.imshow(tensor_to_img(x[0]) / 255)
ax_img.set_title('Masked input x[0]')
ax_mask.imshow(mask[0].squeeze(), cmap='gray')
ax_mask.set_title('Bbox mask[0]')
plt.show()

In [None]:
# Crop the ground truth at the sampled bboxes (presumably the patch seen by the
# local discriminator — confirm in utils.tools.local_patch) and report its size.
local_patch_gt = local_patch(ground_truth, bboxes)
print(local_patch_gt.size())

In [None]:
plt.imshow(tensor_to_img(local_patch_gt[0]) / 255.)
plt.title('Ground-truth local patch (first sample)')
plt.show()

In [None]:
from model.networks import Generator, LocalDis, GlobalDis

In [None]:
# Build the generator and the local/global discriminators from the config and
# move each to the current CUDA device via .cuda().
# NOTE(review): Generator is given `device=0` while the discriminators get
# `device_id=0` — confirm these keyword names against model.networks.
netG = Generator(config['netG'], use_cuda=True, device=0).cuda()
localD = LocalDis(config['netD'], use_cuda=True, device_id=0).cuda()
globalD = GlobalDis(config['netD'], use_cuda=True, device_id=0).cuda()

In [None]:
local_rank = 0
# Everything below runs on this single GPU.
ground_truth = ground_truth.cuda(local_rank)
mask_ff = random_ff_mask(
    config['random_ff_settings'], batch_size=ground_truth.size(0), to_tensor=True
).cuda(local_rank)
assert mask_ff.shape[-2:] == ground_truth.shape[-2:], 'free-form mask size must match the images'

# Re-mask the input with the free-form mask so the hole in `x` matches the mask
# handed to the generator. The `x` from mask_image presumably carries the bbox
# hole instead, which would leave that rectangle blanked outside the strokes.
x = ground_truth * (1. - mask_ff)

print(x.shape)
print(mask_ff.shape)
print(ground_truth.shape)

In [None]:
# Generator forward pass on the masked batch with the free-form mask. It returns
# three values; presumably coarse output, refined output and an attention
# offset-flow visualisation — confirm against model.networks.Generator.
x1, x2, offset_flow = netG(x, mask_ff)

In [None]:
# Composite each generator output with the known pixels: take the generator's
# values where mask_ff == 1 and keep `x` where mask_ff == 0.
# The bbox-based local_patch crops used with rectangular masks are not applied
# to free-form masks here.
x1_inpaint, x2_inpaint = (out * mask_ff + x * (1. - mask_ff) for out in (x1, x2))