In [None]:
import torch
import numpy as np
import torchvision as tv
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
import os

import tensorflow as tf

from model.networks import Generator, Discriminator




### Load pretrained checkpoints for the DeepFill inpainting generator and the inverse-design CNN

In [None]:
# Load the pretrained DeepFill generator used to inpaint (complete) field maps.
path_to_generator_ckpt = "pretrained_inpaint/states.pth"
# NOTE(review): the directory is named "pretrained_GNN" although the model is
# described as a CNN elsewhere in this notebook — confirm the path is intended.
path_to_CNN_ckpt = "pretrained_GNN/checkpoints"
use_cuda_if_available = True
device = torch.device('cuda' if torch.cuda.is_available()
                      and use_cuda_if_available else 'cpu')
# cnum_in=5 matches the input built in predict_inpaint: 3 masked RGB channels,
# a constant ones channel and the mask channel.
generator_inpaint = Generator(cnum_in=5, cnum=48, return_flow=True).to(device)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
generator_inpaint_state_dict = torch.load(path_to_generator_ckpt,
                                          map_location=device)['G']
generator_inpaint.load_state_dict(generator_inpaint_state_dict)
# Inference only: put any train/eval-dependent layers into evaluation mode.
generator_inpaint.eval()

# Load the pretrained Keras CNN that maps a field image back to a geometry
# (used by predict_geo).
forward_CNN = tf.keras.models.load_model(path_to_CNN_ckpt)

### Helper functions for field inpainting and geometry prediction

In [None]:
def pt_to_rgb(pt):
    """Return the first image of a batched CHW tensor as an HWC tensor on the CPU.

    Values are rescaled with ``x * 0.5 + 0.5``, which maps the [-1, 1] range
    used inside predict_inpaint back to [0, 1].
    """
    first_image = pt[0].cpu()
    channels_last = first_image.permute(1, 2, 0)
    return channels_last * 0.5 + 0.5
def predict_inpaint(image, mask, generator_inpaint, device=None):
    """Fill the masked region of an image with the DeepFill generator.

    Args:
        image: Pillow image to complete. Non-RGB modes (e.g. grayscale) are
            converted to RGB; only the first three channels are used.
        mask: Pillow image; every pixel whose first channel is > 0 is
            treated as masked (to be inpainted).
        generator_inpaint: generator called as ``generator_inpaint(x, mask)``
            and returning ``(x_stage1, x_stage2, offset_flow)``; only
            ``x_stage2`` is used.
        device: torch device to run on. Defaults to the device that holds the
            generator's parameters (CPU if it has none).

    Returns:
        ``(image_masked, image_inpainted)`` as HxWx3 float numpy arrays,
        rescaled by ``pt_to_rgb`` from [-1, 1] to [0, 1]. Height and width are
        cropped down to multiples of 8.
    """
    if device is None:
        first_param = next(generator_inpaint.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Grayscale/palette images would otherwise give fewer than 3 channels and
    # break the 5-channel generator input assembled below.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    image_org = T.ToTensor()(image)
    mask = T.ToTensor()(mask)

    # Crop height and width to multiples of `grid` pixels.
    _, h, w = image_org.shape
    grid = 8
    h_crop, w_crop = h // grid * grid, w // grid * grid
    image = image_org[:3, :h_crop, :w_crop].unsqueeze(0)
    mask = mask[0:1, :h_crop, :w_crop].unsqueeze(0)

    image = (image * 2 - 1.).to(device)  # map image values to [-1, 1] range
    mask = (mask > 0.).to(dtype=torch.float32, device=device)  # 1.: masked 0.: unmasked

    image_masked = image * (1. - mask)  # zero out the region to be filled

    # Generator input: masked RGB + constant ones channel + mask channel.
    ones_x = torch.ones_like(image_masked)[:, 0:1, :, :]
    x = torch.cat([image_masked, ones_x, ones_x * mask], dim=1)
    with torch.no_grad():
        _, x_stage2, _ = generator_inpaint(x, mask)

    # Keep the known pixels and take generated pixels only inside the mask.
    image_inpainted = image_masked * (1. - mask) + x_stage2 * mask
    return pt_to_rgb(image_masked).numpy(), pt_to_rgb(image_inpainted).numpy()


def predict_geo(image, model=None):
    """Predict a binary geometry from a field image and render it as RGB.

    Args:
        image: numpy array holding one field image; a batch axis is added
            before calling the model. Presumably (256, 256, 3) with values in
            [0, 1] — TODO confirm against the model's input spec.
        model: object with a Keras-style ``predict(batch, verbose=...)``
            method. Defaults to the notebook-level ``forward_CNN``.

    Returns:
        seq: the model output thresholded at 0.5, squeezed to an int 0/1 array.
        img: (IMG_HEIGHT, IMG_WIDTH, 3) float array. A 3-pixel frame is black.
            Every interior pixel is (1, s, s) with s = seq[order], i.e. red for
            0 and white for 1. ``order`` indexes a ``num_block`` x ``num_block``
            block grid whose columns are mirrored left/right. With the defaults
            below, only seq[0:32] is used.
    """
    if model is None:
        model = forward_CNN

    IMG_WIDTH = 256
    IMG_HEIGHT = 256
    num_block = 8
    border = 3  # width in pixels of the black frame
    interval_x = IMG_WIDTH / num_block
    interval_y = IMG_HEIGHT / num_block

    # verbose=0 avoids printing a Keras progress bar on every call.
    probs = model.predict(np.expand_dims(image, axis=0), verbose=0)
    seq = np.squeeze((probs > 0.5).astype("int"))

    # Column block index, mirrored about the vertical centre line so the right
    # half reuses the left half's entries.
    idx_x = np.arange(IMG_WIDTH) // interval_x
    idx_x = np.where(idx_x >= num_block // 2, num_block - 1 - idx_x, idx_x)
    # Row block index counted from the bottom of the image.
    idx_y = (IMG_HEIGHT - 1 - np.arange(IMG_HEIGHT)) // interval_y
    # order[row, col] = index into seq for that pixel.
    order = (idx_x[np.newaxis, :] * num_block + idx_y[:, np.newaxis]).astype(int)

    # Rows first: the array is indexed img[row][col].
    img = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3))
    img[..., 0] = 1
    img[..., 1] = seq[order]
    img[..., 2] = img[..., 1]
    # Black frame around the edge.
    img[:border, :] = 0
    img[-border:, :] = 0
    img[:, :border] = 0
    img[:, -border:] = 0

    return seq, img


### Field completion on the test set (random masks → inpainted field maps)

In [None]:

# Generate inpainted field maps for every test image under a random mask.
# NOTE(review): random_bbox, bbox2mask and brush_stroke_mask are neither defined
# nor imported in any visible cell — TODO import them so this cell runs on a
# fresh kernel (Restart & Run All).

N_TEST = 200         # number of test images to process
TEST_FIRST_ID = 801  # file id of the first test image (ids 801..1000)
SEED = 0             # fixes the random masks so the outputs are reproducible

path_field_map = "./predictions/field_maps/"
os.makedirs(path_field_map, exist_ok=True)

# Seed the RNGs the mask helpers may draw from (their implementation is not
# visible here; also seed Python's `random` if they use it).
np.random.seed(SEED)
torch.manual_seed(SEED)

for i in range(N_TEST):
    bbox = random_bbox()
    regular_mask = bbox2mask(bbox)
    irregular_mask = brush_stroke_mask()
    mask = torch.logical_or(irregular_mask, regular_mask).to(torch.float32)
    # Broadcast to (1, 4, 256, 256) and convert to an HxWx4 uint8 PIL image;
    # predict_inpaint only reads the first channel of the mask.
    empty = np.zeros((1, 4, 256, 256))
    mask = mask + empty
    mask = mask[0].permute(1, 2, 0) * 255
    mask = np.squeeze(mask.to(device='cpu', dtype=torch.uint8).numpy())
    mask = Image.fromarray(mask.astype(np.uint8))

    # The context manager releases the file handle once inference is done.
    with Image.open("./datasets/S11_8/test/" + str(i + TEST_FIRST_ID) + ".jpg") as image:
        img_masked, img_field_pred = predict_inpaint(image, mask, generator_inpaint)
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    img_object = Image.fromarray(np.uint8(np.clip(img_field_pred, 0, 1) * 255))
    img_object.save(path_field_map + str(i + 1) + ".jpg")



### Inverse translation: using the CNN to map completed fields back to structures

In [None]:
# Translate each completed field map back into a geometry image with the CNN.
# Reads the files written by the field-completion cell (path_field_map) and
# saves one rendered geometry per field map.

path_geo = "./predictions/geos/"
os.makedirs(path_geo, exist_ok=True)

for i in range(200):
    # The context manager closes the file once its pixels have been copied out.
    with Image.open(path_field_map + str(i + 1) + ".jpg") as field_img:
        img_field_pred = np.array(field_img) / 255
    seq, img_geo_pred = predict_geo(img_field_pred)
    img_object = Image.fromarray(np.uint8(img_geo_pred * 255))
    img_object.save(path_geo + str(i + 1) + ".jpg")


