In [None]:
from main import *
import json

In [None]:
import matplotlib as mpl

# Default figure size (inches) and resolution used by every plot in this notebook.
mpl.rcParams.update({'figure.figsize': [15, 15], 'figure.dpi': 72})

In [None]:
def cut_mask(org_mask):
    """Crop a mask to a square window of side min(h, w) around the object.

    The window slides along the longer image axis. Where possible it is centred
    on the centre of the object's bounding box (from ``get_mask_bbox``); it is
    always clamped so that it stays fully inside the image.

    Args:
        org_mask (np.ndarray): mask/image with rows on axis 0 and columns on axis 1.
    Returns:
        tuple: ``(mask, (new_l, new_t))``. ``mask`` is the square crop (a slice of
        ``org_mask``, or ``org_mask`` itself when it is already square).
        ``new_l``/``new_t`` are the column/row offsets of the crop inside
        ``org_mask``, as consumed by ``restore_mask``.
    """
    h, w = org_mask.shape[0], org_mask.shape[1]
    l, r, t, b = get_mask_bbox(org_mask)
    obj_center_x = int(l + (r - l) / 2)
    obj_center_y = int(t + (b - t) / 2)
    new_l = 0
    new_t = 0
    if h < w:
        # Clamp the offset to [0, w - h] so the window is always exactly h wide.
        # The previous bound (center <= w - h//2) allowed new_l = w - h + 1 for
        # odd h, which produced an (h, h - 1) crop.
        new_l = max(0, min(obj_center_x - h // 2, w - h))
        mask = org_mask[:, new_l:new_l + h]
    elif h > w:
        # Same clamping along the rows: the window is exactly w tall.
        new_t = max(0, min(obj_center_y - w // 2, h - w))
        mask = org_mask[new_t:new_t + w]
    else:
        mask = org_mask
    return mask, (new_l, new_t)

def first_nonzero(arr, axis, invalid_val=-1):
    """Index of the first nonzero element of ``arr`` along ``axis``.

    Args:
        arr (np.ndarray): input array.
        axis (int): axis to scan.
        invalid_val: value reported for lanes that contain no nonzero element.
    Returns:
        np.ndarray: ``arr`` reduced over ``axis``; each entry is an index or ``invalid_val``.
    """
    nonzero = arr != 0
    # argmax over booleans returns the position of the first True.
    first_idx = nonzero.argmax(axis=axis)
    has_nonzero = nonzero.any(axis=axis)
    return np.where(has_nonzero, first_idx, invalid_val)

def last_nonzero(arr, axis, invalid_val=-1):
    """Index of the last nonzero element of ``arr`` along ``axis``.

    Args:
        arr (np.ndarray): input array.
        axis (int): axis to scan.
        invalid_val: value reported for lanes that contain no nonzero element.
    Returns:
        np.ndarray: ``arr`` reduced over ``axis``; each entry is an index or ``invalid_val``.
    """
    nonzero = arr != 0
    length = arr.shape[axis]
    # argmax on the reversed array finds the first nonzero counted from the end.
    offset_from_end = np.flip(nonzero, axis=axis).argmax(axis=axis)
    last_idx = length - 1 - offset_from_end
    return np.where(nonzero.any(axis=axis), last_idx, invalid_val)

def get_mask_bbox(img):
    """Bounding box of the nonzero pixels of a mask.

    Accepts a 2-D mask of shape (H, W) or a channel-last image of shape
    (H, W, C). For the latter, only channel 0 is inspected.

    Args:
        img (np.ndarray): mask/image with rows on axis 0 and columns on axis 1.
    Returns:
        tuple: ``(l, r, t, b)``: first/last nonzero column and first/last nonzero
        row, all inclusive. All four are -1 when the mask has no nonzero pixel.
    """
    arr = np.asarray(img)
    if arr.ndim == 3:
        arr = arr[..., 0]
    # Test pixels for nonzero directly rather than testing row/column sums, so
    # signed or float masks whose values happen to cancel are still detected.
    nonzero = arr != 0
    cols = np.flatnonzero(nonzero.any(axis=0))
    rows = np.flatnonzero(nonzero.any(axis=1))
    if cols.size == 0:
        return (-1, -1, -1, -1)
    return (cols[0], cols[-1], rows[0], rows[-1])

def restore_mask(org_mask, mask, cut_info):
    """Paste a square crop produced by ``cut_mask`` back into a full-size canvas.

    Args:
        org_mask (np.ndarray): the mask that was originally cropped; only its
            shape and dtype are used.
        mask (np.ndarray): the (possibly processed) square crop.
        cut_info (tuple): ``(new_l, new_t)`` offsets returned by ``cut_mask``.
    Returns:
        np.ndarray: array shaped like ``org_mask``, zero everywhere except the
        crop window, which holds ``mask``. If ``org_mask`` is square, no crop was
        taken, and ``mask`` is returned as-is.
    """
    # Branch on the ORIGINAL shape: the crop from cut_mask is always square, so
    # branching on mask's own shape always took the no-padding path.
    h, w = org_mask.shape[0], org_mask.shape[1]
    if h == w:
        return mask
    new_l, new_t = cut_info
    pad_mask = np.zeros_like(org_mask)
    # One of new_l / new_t is 0, so this covers both the h < w and h > w cases.
    pad_mask[new_t:new_t + mask.shape[0], new_l:new_l + mask.shape[1]] = mask
    return pad_mask

def pointwise_distance(pts1, pts2):
    """Euclidean distance between matching rows of two point arrays.

    Args:
        pts1 (np.ndarray): points as rows, e.g. [[x1, y1], [x2, y2], ...]
        pts2 (np.ndarray): points as rows with the same layout as ``pts1``
    Returns:
        np.ndarray: one distance per row, ``||pts1[i] - pts2[i]||``
    """
    delta = pts1 - pts2
    squared_len = np.square(delta).sum(axis=1)
    return np.sqrt(squared_len)

def order_points(pts):
    """Sort four 2-D points into [top-left, top-right, bottom-right, bottom-left].

    Adapted from:
    https://www.pyimagesearch.com/2016/03/21/ordering-coordinates-clockwise-with-python-and-opencv/

    Args:
        pts (np.ndarray): array of four points, [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
    Returns:
        np.ndarray: float32 array with rows ordered tl, tr, br, bl
    """
    # Split into the two left-most and the two right-most points by x.
    by_x = pts[np.argsort(pts[:, 0]), :]
    left_pair = by_x[:2, :]
    right_pair = by_x[2:, :]

    # Among the left pair, the smaller y is top-left and the larger is bottom-left.
    left_pair = left_pair[np.argsort(left_pair[:, 1]), :]
    top_left, bottom_left = left_pair

    # Of the right pair, the point farthest from top-left is taken as bottom-right.
    # This is exact for rectangles. Using distance rather than min/max on y handles
    # points that share an x or y value.
    dists = np.sqrt(np.sum((top_left - right_pair) ** 2, axis=1))
    bottom_right, top_right = right_pair[np.argsort(dists)[::-1], :]

    return np.array([top_left, top_right, bottom_right, bottom_left], dtype="float32")

In [None]:
img_name = "timecity_00130_none"
image_path = f"./data/dog/images/{img_name}.jpg"
mask_path = f"./data/dog/masks/{img_name}.jpg"
leg_pos_path = f'./data/dog/leg_pos/{img_name}.json'

# `with` closes the annotation file; `json.load(open(...))` leaked the handle.
with open(leg_pos_path, "r") as leg_pos_file:
    leg_pos_raw = json.load(leg_pos_file)

# Leg positions: the first (x, y) point of each of the first four annotated
# shapes, stored as a (1, 4, 2) array.
leg_shapes = leg_pos_raw['shapes']
if len(leg_shapes) < 4:
    raise ValueError(f"expected at least 4 leg annotations in {leg_pos_path}, got {len(leg_shapes)}")
leg_pos = np.zeros((1, 4, 2))
for i in range(4):
    leg_pos[0][i] = leg_shapes[i]['points'][0]

# Reorder the legs as [top-left, top-right, bottom-right, bottom-left].
tl, tr, br, bl = order_points(leg_pos[0])
leg_pos_sorted = np.array([[tl, tr, br, bl]])

# cv2.imread returns None (rather than raising) when the file cannot be read.
org_mask = cv2.imread(mask_path)
if org_mask is None:
    raise FileNotFoundError(f"could not read mask image: {mask_path}")

In [None]:
org_image = cv2.imread(image_path)
# cv2.imread returns None (rather than raising) when the file cannot be read;
# fail here instead of with an opaque cvtColor error.
if org_image is None:
    raise FileNotFoundError(f"could not read image: {image_path}")
image_rgb = cv2.cvtColor(org_image, cv2.COLOR_BGR2RGB)
# NOTE: the full, uncropped image is used; the cut_mask() offsets are not applied here.
image = image_rgb

# Leg keypoints in [tl, tr, br, bl] order, and a single-channel version of the mask.
poses2d = leg_pos_sorted[0]
mask = cv2.cvtColor(org_mask, cv2.COLOR_BGR2GRAY)

In [None]:
# Build a triangle mesh over the masked region and choose ARAP handle vertices.
h, w = image.shape[:2]

tri_mc = TriangleMeshCreator(interval=20, angle_constraint=20, area_constraint=200, dilated_pixel=5)
mesh = tri_mc.create(image, mask)

# Map mesh vertices to pixel coordinates.
# Assumes mesh.vertices lie in [-1, 1] (same convention as the texture mapping later) — TODO confirm.
vertices = 0.5 * (mesh.vertices + 1) * np.array([w, h]).reshape((1, 2)).astype(np.float32)

# Snap each leg keypoint to its nearest mesh vertex; those vertices become handles.
distance = cdist(poses2d, vertices)
constraint_v_ids = np.argmin(distance, axis=1)
poses2d = vertices[constraint_v_ids]
constraint_v_coords = augment_handle_points(poses2d, size=(w, h))

# Leave one leg free (not constrained by ARAP). poses2d follows order_points'
# [tl, tr, br, bl] order, so index 3 is the bottom-left leg.
# Assumes augment_handle_points keeps the input order — verify.
FREE_HANDLE_IDX = 3
constraint_v_ids = np.delete(constraint_v_ids, FREE_HANDLE_IDX)
constraint_v_coords = np.delete(np.asarray(constraint_v_coords), FREE_HANDLE_IDX, axis=0)

In [None]:
if VISUALIZE:
    # Overlay on the mesh preview: leg keypoints snapped to mesh vertices, drawn
    # with color (255, 0, 0), and target handle positions, drawn with (0, 255, 0).
    vis_image = cv2.cvtColor(mesh.get_image(), cv2.COLOR_GRAY2BGR)
    overlays = (
        (poses2d, (255, 0, 0)),
        (constraint_v_coords, (0, 255, 0)),
    )
    for points, color in overlays:
        for x, y in points.astype(int):
            cv2.circle(vis_image, (x, y), radius=3, color=color, thickness=2)

    im_utils.imshow(vis_image)

In [None]:
# Normalize the handle targets into a NEW variable. Overwriting
# constraint_v_coords made this cell non-idempotent: re-running it normalized
# twice, and re-running the overlay cell plotted normalized coordinates.
constraint_v_coords_norm = Mesh.normalize_vertices(constraint_v_coords, size=(w, h))

# Texture coordinates: map vertices from [-1, 1] to [0, 1] and flip v
# (presumably because image rows run top-down while texture v runs bottom-up — confirm).
vts = 0.5 * (mesh.vertices + 1)
vts[:, 1] = 1. - vts[:, 1]

# As-rigid-as-possible deformation with the selected vertices as handles.
# NOTE(review): w=1000. is presumably the handle-constraint weight — confirm in ARAPDeformation.
arap_deform = ARAPDeformation()
arap_deform.load_from_mesh(mesh)
arap_deform.setup()

deformed_mesh = arap_deform.deform(constraint_v_ids, constraint_v_coords_norm, w=1000.)
save_obj_format(file_path=DEFORM_MESH_PATH, vertices=deformed_mesh.vertices, faces=deformed_mesh.faces,
                texture_vertices=vts)

if VISUALIZE:
    vis_image = deformed_mesh.get_image(size=(w, h))
    im_utils.imshow(vis_image)

# Render the saved OBJ, textured with the original input image.
pt_renderer = render_utils.PytorchRenderer(use_gpu=False)
deformed_image = pt_renderer.render_w_texture(DEFORM_MESH_PATH, image_path)
# Flip rows (renderer output is presumably upside-down relative to image space — confirm),
# then swap the channel order.
deformed_image = deformed_image[::-1, :, :]
deformed_image = cv2.cvtColor(deformed_image, cv2.COLOR_BGR2RGB)

im_utils.imshow(deformed_image)

In [None]:
# Zoom into a fixed window (rows 100-399, cols 600-999) for comparison.
# NOTE(review): the original author marked these indices as "messed-up because
# of converting rgb-bgr" — confirm the intended window.
ZOOM_ROWS, ZOOM_COLS = slice(100, 400), slice(600, 1000)
# vis_image is only assigned inside `if VISUALIZE:` blocks, so guard its use.
if VISUALIZE:
    im_utils.imshow(vis_image[ZOOM_ROWS, ZOOM_COLS])
im_utils.imshow(deformed_image[ZOOM_ROWS, ZOOM_COLS])

In [None]:
# Same zoom window (rows 100-399, cols 600-999) on the undeformed input image.
im_utils.imshow(image[slice(100, 400), slice(600, 1000)])