In [1]:
import cv2
from matplotlib import pyplot as plt
import pickle5 as pickle
from PIL import Image
import json
import cv2
import torch
import numpy as np
import os
import matplotlib.pyplot as plt

# Training-split annotations: `meta` is indexed by integer sample id below, and each
# entry is read for 'file_name', 'joint_2d', 'camera_coor_3d', 'scale' and 'rot'.
# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files.
anno_pkl_path = "../../datasets/data_230710/annotations/train/anno.pkl"
with open(anno_pkl_path, "rb") as anno_file:
    meta = pickle.load(anno_file)
    

In [2]:



def get_transform(center, scale, res, rot=0):
    """Build the 3x3 homogeneous affine mapping source pixels into a crop.

    The crop box is a square of side ``res[0] * scale`` source pixels centred on
    ``center``. Its x extent is stretched to ``res[1]`` columns and its y extent
    to ``res[0]`` rows. A non-zero ``rot`` appends a rotation about the centre of
    the output crop.

    Args:
        center: (x, y) centre of the crop box in source pixels.
        scale: box side expressed as a multiple of ``res[0]``.
        res: (rows, cols) of the output crop.
        rot: rotation angle in degrees; 0 means no rotation.

    Returns:
        np.ndarray of shape (3, 3).
    """
    box_side = res[0] * scale
    cx, cy = float(center[0]), float(center[1])

    mat = np.zeros((3, 3))
    mat[0, 0] = float(res[1]) / box_side
    mat[1, 1] = float(res[0]) / box_side
    mat[0, 2] = res[1] * (0.5 - cx / box_side)
    mat[1, 2] = res[0] * (0.5 - cy / box_side)
    mat[2, 2] = 1

    if rot == 0:
        return mat

    # The angle is negated to match the rotation direction used when cropping.
    angle_rad = -rot * np.pi / 180
    sin_a, cos_a = np.sin(angle_rad), np.cos(angle_rad)
    rotation = np.array([[cos_a, -sin_a, 0.0],
                         [sin_a, cos_a, 0.0],
                         [0.0, 0.0, 1.0]])

    # Rotate about the crop centre: shift it to the origin, rotate, shift back.
    to_origin = np.eye(3)
    to_origin[0, 2] = -res[1] / 2
    to_origin[1, 2] = -res[0] / 2
    from_origin = np.eye(3)
    from_origin[0, 2] = res[1] / 2
    from_origin[1, 2] = res[0] / 2

    return from_origin.dot(rotation.dot(to_origin.dot(mat)))


def transform(pt, center, scale, res, invert=0, rot=0):
    """Map a 1-based (x, y) pixel location through the crop transform.

    Args:
        pt: (x, y) point in 1-based pixel coordinates.
        center, scale, res, rot: forwarded to ``get_transform``.
        invert: if truthy, apply the inverse mapping (crop -> source).

    Returns:
        np.ndarray of two ints: the mapped 1-based (x, y). Fractional parts are
        truncated toward zero by ``astype(int)`` before the +1 offset.
    """
    mat = get_transform(center, scale, res, rot=rot)
    if invert:
        # Inverted through torch; np.linalg.inv was the commented-out alternative.
        mat = torch.inverse(torch.from_numpy(mat)).numpy()
    homog = np.array([pt[0] - 1, pt[1] - 1, 1.])
    mapped = mat.dot(homog)
    return mapped[:2].astype(int) + 1


def crop(img, center, scale, res, rot=0):
    """Crop (optionally rotated) a square box from ``img`` and resize it.

    Parts of the crop box that fall outside ``img`` are zero-filled.

    Args:
        img: image array, (H, W) or (H, W, C).
        center: (x, y) centre of the crop box in ``img`` pixels.
        scale: box side as a multiple of ``res[0]`` (see ``get_transform``).
        res: (rows, cols) of the output image.
        rot: rotation in degrees; 0 disables rotation.

    Returns:
        The crop resized to ``res[0]`` rows by ``res[1]`` columns. It is built on
        an ``np.zeros`` (float64) canvas.
    """
    # Source-image corners of the crop box (transform works in 1-based coords).
    # In crop space, x spans res[1] columns and y spans res[0] rows (matching
    # get_transform), so the bottom-right corner is (res[1], res[0]).
    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
    br = np.array(transform([res[1] + 1, res[0] + 1], center, scale, res, invert=1)) - 1

    # Extra margin so that, after rotation, the corners still contain image context.
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if rot != 0:
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Destination range inside new_img.
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Source range inside img.
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])

    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
                                                        old_x[0]:old_x[1]]

    if rot != 0:
        # myimrotate / myimresize stand in for the removed scipy.misc helpers.
        new_img = myimrotate(new_img, rot)
        # Strip the rotation margin. With pad == 0, `[0:-0]` would yield an
        # empty array, so only slice when there is a margin to remove.
        if pad > 0:
            new_img = new_img[pad:-pad, pad:-pad]

    # The cv2-backed resize expects (width, height) = (cols, rows).
    new_img = myimresize(new_img, [res[1], res[0]])
    return new_img


def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False):
    """Rotate ``img`` by ``angle`` degrees with OpenCV.

    Args:
        img: image array; only its first two dims (height, width) drive geometry.
        angle: degrees, passed to cv2.getRotationMatrix2D (positive rotates
            counter-clockwise in OpenCV's convention).
        center: (x, y) tuple of the rotation centre; defaults to the image centre.
        scale: isotropic scale factor applied together with the rotation.
        border_value: fill value for output pixels that map outside ``img``.
        auto_bound: if True, enlarge the output canvas so the whole rotated
            image fits. This cannot be combined with an explicit ``center``.

    Returns:
        The warped image.

    Raises:
        ValueError: if ``center`` is given together with ``auto_bound``.
    """
    if center is not None and auto_bound:
        raise ValueError('`auto_bound` conflicts with `center`')

    height, width = img.shape[:2]
    if center is None:
        center = ((width - 1) * 0.5, (height - 1) * 0.5)
    assert isinstance(center, tuple)

    rot_mat = cv2.getRotationMatrix2D(center, angle, scale)
    out_size = (width, height)
    if auto_bound:
        abs_cos = np.abs(rot_mat[0, 0])
        abs_sin = np.abs(rot_mat[0, 1])
        bound_w = height * abs_sin + width * abs_cos
        bound_h = height * abs_cos + width * abs_sin
        # Shift the result so the rotated content sits centred in the bigger canvas.
        rot_mat[0, 2] += (bound_w - width) * 0.5
        rot_mat[1, 2] += (bound_h - height) * 0.5
        out_size = (int(np.round(bound_w)), int(np.round(bound_h)))

    return cv2.warpAffine(img, rot_mat, out_size, borderValue=border_value)

def myimresize(img, size, return_scale=False):
    """Resize ``img`` with cv2.resize's default interpolation.

    Args:
        img: image array; its first two dims are (height, width).
        size: target (width, height), in the order cv2.resize expects.
        return_scale: if True, also return the per-axis scale factors.

    Returns:
        The resized image. If ``return_scale`` is True, returns
        ``(resized, w_scale, h_scale)``, where each scale is target / source.
    """
    src_h, src_w = img.shape[:2]
    target_w, target_h = size[0], size[1]
    resized = cv2.resize(img, (target_w, target_h))
    if return_scale:
        return resized, target_w / src_w, target_h / src_h
    return resized
    

def j2d_processing(kp, scale, r, res=512):
    """Apply the crop / rotation augmentation to 2D keypoints, in place.

    The crop is centred on the middle of a ``res`` x ``res`` image and produces
    a ``res`` x ``res`` output.

    Args:
        kp: 2D array whose columns 0-1 hold each keypoint's (x, y). It is
            modified in place.
        scale: crop scale (see ``get_transform``).
        r: rotation in degrees.
        res: side of the square image/crop. It defaults to 512, the previously
            hard-coded value. NOTE(review): the crop loop in this notebook uses
            800, so pass ``res=800`` for keypoints of those crops -- confirm.

    Returns:
        ``kp`` itself, with columns 0-1 overwritten by the integer-valued
        coordinates returned by ``transform`` (1-based, as ``transform``
        returns them).
    """
    half = res / 2
    for i in range(kp.shape[0]):
        # transform expects 1-based input, hence the +1.
        kp[i, 0:2] = transform(kp[i, 0:2] + 1, (half, half), scale,
                               [res, res], rot=r)
    return kp

# A second, byte-identical `transform` definition was removed from here. It only
# silently shadowed the `transform` defined just above `crop`, which remains the
# single definition used by `crop` and `j2d_processing`.

# Per-keypoint colours for the 21 joints, with values in [0, 1]. The trailing
# [:, ::-1] reverses the channel axis (presumably RGB -> BGR for OpenCV drawing).
# NOTE(review): joint 19 is (0.7, 0.0, 0.8), while its chain siblings have equal
# first/third channels -- possibly a typo; kept unchanged.
colors = np.array([
    [0.4, 0.4, 0.4],                                                     # 0
    [0.4, 0.0, 0.0], [0.6, 0.0, 0.0], [0.8, 0.0, 0.0], [1.0, 0.0, 0.0],  # 1-4
    [0.4, 0.4, 0.0], [0.6, 0.6, 0.0], [0.8, 0.8, 0.0], [1.0, 1.0, 0.0],  # 5-8
    [0.0, 0.4, 0.2], [0.0, 0.6, 0.3], [0.0, 0.8, 0.4], [0.0, 1.0, 0.5],  # 9-12
    [0.0, 0.2, 0.4], [0.0, 0.3, 0.6], [0.0, 0.4, 0.8], [0.0, 0.5, 1.0],  # 13-16
    [0.4, 0.0, 0.4], [0.6, 0.0, 0.6], [0.7, 0.0, 0.8], [1.0, 0.0, 1.0],  # 17-20
])[:, ::-1]

# Bones as ((parent, child), colour-of-child), one per non-root joint, in joint
# order. The parent table gives five 4-joint chains hanging off joint 0:
# 0-1-2-3-4, 0-5-6-7-8, 0-9-10-11-12, 0-13-14-15-16, 0-17-18-19-20.
# (Not referenced by `visualize`, which keeps its own parents array.)
bones = [((parent, child), colors[child, :])
         for child, parent in enumerate((-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11,
                                         0, 13, 14, 15, 0, 17, 18, 19))
         if parent >= 0]
    
def visualize(image, joint_2d):
    """Draw the 21-joint skeleton onto ``image`` in place and return it.

    Each joint gets a radius-2 circle drawn with thickness 2. Every non-root
    joint also gets a 2-px line to its parent. Both use that joint's
    ``colors`` entry scaled by 255.

    Args:
        image: image array that cv2 drawing functions can modify in place.
        joint_2d: indexable of at least 21 (x, y) points; coordinates are
            truncated with ``int``.

    Returns:
        ``image`` (the same object, now drawn on).
    """
    parents = np.array([-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 0, 13, 14, 15, 0, 17, 18, 19])
    for joint_idx in range(21):
        joint_pt = (int(joint_2d[joint_idx][0]), int(joint_2d[joint_idx][1]))
        joint_color = colors[joint_idx] * 255
        cv2.circle(image, joint_pt, 2, joint_color, thickness=2)
        if joint_idx == 0:
            continue  # the root has no parent to connect to
        parent_idx = parents[joint_idx]
        parent_pt = (int(joint_2d[parent_idx][0]), int(joint_2d[parent_idx][1]))
        cv2.line(image, joint_pt, parent_pt, joint_color, 2)
    return image


# Crop samples 600-649, fill their exactly-black pixels with a background image,
# and dump the intermediate images to `sample_img/` for visual inspection.
raw_res = 800  # side of the square crop written to disk
bg_path = "../../datasets/data_230710/background"
img_path = os.path.join("../../datasets/data_230710", "images/train")
out_dir = "sample_img"

# os.listdir order is arbitrary; sort so the idx -> background mapping is reproducible.
bg_list = sorted(os.listdir(bg_path))
if not bg_list:
    raise FileNotFoundError(f"no background images found in {bg_path}")
# cv2.imwrite silently returns False when the target directory does not exist.
os.makedirs(out_dir, exist_ok=True)


def _imread_checked(path):
    """cv2.imread (BGR) that raises instead of returning None on failure."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return img


def _imwrite_checked(path, img):
    """cv2.imwrite that raises instead of returning False on failure."""
    if not cv2.imwrite(path, img):
        raise OSError(f"could not write image: {path}")


for idx in range(600, 650):
    # Drop the first path component of file_name (presumably the split
    # directory that img_path already includes -- confirm).
    name = '/'.join(meta[idx]['file_name'].split('/')[1:])
    # numpy BGR array. The previous BGR->RGB->(2,1,0) round trip was a no-op.
    image = _imread_checked(os.path.join(img_path, name))

    # Not used by this loop; kept as notebook globals for inspection.
    joint_2d = np.array(meta[idx]['joint_2d'])
    joint_3d = meta[idx]['camera_coor_3d']
    scale = meta[idx]['scale']
    rot = meta[idx]['rot']

    cropped_img = crop(image, (raw_res / 2, raw_res / 2), scale, [raw_res, raw_res], rot=rot)
    bg_img = cv2.resize(_imread_checked(os.path.join(bg_path, bg_list[idx % len(bg_list)])),
                        (raw_res, raw_res))

    _imwrite_checked(os.path.join(out_dir, f'{idx}_st_ori.jpg'), image)
    _imwrite_checked(os.path.join(out_dir, f'{idx}_st_cropped.jpg'), cropped_img)

    # Replace pixels that are exactly zero in every channel with the background.
    # This covers crop padding, but also any genuinely black subject pixels.
    iaz = np.where(np.all(cropped_img == 0, axis=-1))
    cropped_img[iaz] = bg_img[iaz]

    _imwrite_checked(os.path.join(out_dir, f'{idx}_st_bg.jpg'), bg_img)
    _imwrite_checked(os.path.join(out_dir, f'{idx}_st_output.jpg'), cropped_img)
