In [1]:
import numpy as np
import cv2
from tensorflow import keras
from tqdm import tqdm
import copy
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.pyplot import figure
import tensorflow as tf
import os, json, glob

In [3]:
# Per-device camera calibration, keyed by lower-cased device name
# (getFocalLengthAndDistortion lower-cases its argument before the lookup).
#   'matrix'     : [fx, fy, cx, cy] as unpacked by read_data; cx and cy are
#                  swapped there when the frame is wider than it is tall.
#   'distortion' : four coefficients; read_data appends a trailing 0 before
#                  handing them to OpenCV.
# A value of -1 marks a device without calibration: the lookup returns
# (None, None) and read_data skips the frame.
device_data = {
    'iphone se': -1,
    'iphone 4': -1,
    'iphone 4s': {
        'matrix': [606.59362793, 609.2008667, 236.86116028, 312.28497314],
        'distortion': [ 0.24675941, -0.65499198,  0.00301733, -0.00097767]
    },
    'iphone 5': {
        'matrix': [623.28759766, 626.64154053, 236.86317444, 316.909729  ],
        'distortion': [ 0.03760624, -0.043609, -0.00114902,  0.00269194]
    },
    'iphone 5c': {
        'matrix': [585.13171387, 588.14447021, 242.18914795, 321.20614624],
        'distortion': [ 0.01302955, -0.10349616, -0.0009803,  0.00301618]
    },
    'iphone 5s': {
        'matrix': [585.13171387, 588.14447021, 242.18914795, 321.20614624],
        'distortion': [ 0.01302955, -0.10349616, -0.0009803,  0.00301618]
    },
    'iphone 6': {
        'matrix': [592.50164795, 595.66986084, 236.12217712, 327.50753784],
        'distortion': [ 0.0822313, -0.18398251, -0.00631323, -0.00075782]
    },
    'iphone 6 plus': {
        'matrix': [592.50164795, 595.66986084, 236.12217712, 327.50753784],
        'distortion': [ 0.0822313, -0.18398251, -0.00631323, -0.00075782]
    },
    'iphone 6s': {
        'matrix': [592.50164795, 595.66986084, 236.12217712, 327.50753784],
        'distortion': [ 0.0822313, -0.18398251, -0.00631323, -0.00075782]
    },
    'iphone 6s plus': {
        'matrix': [592.50164795, 595.66986084, 236.12217712, 327.50753784],
        'distortion': [ 0.0822313, -0.18398251, -0.00631323, -0.00075782]
    },
    'iphone 7': {
        'matrix': [592.50164795, 595.66986084, 236.12217712, 327.50753784],
        'distortion': [ 0.0822313, -0.18398251, -0.00631323, -0.00075782]
    },
    'iphone 7 plus': {
        'matrix': [592.50164795, 595.66986084, 236.12217712, 327.50753784],
        'distortion': [ 0.0822313, -0.18398251, -0.00631323, -0.00075782]
    },
    'iphone 8': {
        'matrix': [580.34485, 581.34717, 239.41379, 319.58548],
        'distortion': [ 0.0822313, -0.18398251, -0.00631323, -0.00075782]
    },
    'iphone 8 plus': {
        'matrix': [580.34485, 581.34717, 239.41379, 319.58548],
        'distortion': [ 0.0822313, -0.18398251, -0.00631323, -0.00075782]
    },
    'iphone x': {
        'matrix': [592.16473, 593.1875, 242.00687, 320.23456],
        'distortion': [ 0.0822313, -0.18398251, -0.00631323, -0.00075782]
    },
    'iphone xs': -1,
    'iphone xs max global': -1,
    'iphone xr': -1,

    'ipad air': {
        'matrix': [578, 578, 240, 320],
        'distortion': [0.124, -0.214, 0, 0]
    },  # ipad air from Web
    'ipad air 2': {
        'matrix': [592.35223389, 595.9105835, 234.15885925, 313.48773193],
        'distortion': [ 1.93445340e-01, -5.54507077e-01,  6.13935478e-03,  3.40262457e-04]
    },
    'ipad 2': {
        'matrix': [621.54315186, 624.44012451, 233.66329956, 313.44387817],
        'distortion': [-0.0243901, -0.10230259, -0.00513017,  0.00057966]
    },
    'ipad 6': -1,
    # NOTE(review): this key has an unbalanced parenthesis — confirm it matches
    # the deviceName strings in the JSON files before "fixing" it.
    'ipad pro 2 (10.5-inch': -1,

    'ipod touch 6': -1,
    'ipad mini': {
        'matrix': [623.28759766, 626.64154053, 236.86317444, 316.909729],
        'distortion': [ 0.03760624, -0.043609, -0.00114902,  0.00269194]
    },
}
# When True, the debug cell runs and the normalization helpers plot their
# intermediate images.
debug_print = False
# Target distance handed to calculate_S when building the warps.
# presumably millimetres (read_data converts the gaze target from cm to mm) — TODO confirm
normalize_distance = 300

# Virtual camera for the face crop: focal length, and the size passed to
# cv2.warpPerspective as dsize.
normalize_face_focal_length = 200
normalize_face_size = (96, 48)

# Virtual camera for each single-eye crop (same two roles as above).
normalize_eye_focal_length = 700
normalize_eye_size = (96, 96)

In [4]:
# Intrinsic matrices of the virtual "normalized" cameras: the focal length on
# the diagonal and the principal point at the centre of the output crop.
face_norm = np.diag([normalize_face_focal_length, normalize_face_focal_length, 1.0])
face_norm[0, 2] = normalize_face_size[0] / 2
face_norm[1, 2] = normalize_face_size[1] / 2

eye_norm = np.diag([normalize_eye_focal_length, normalize_eye_focal_length, 1.0])
eye_norm[0, 2] = normalize_eye_size[0] / 2
eye_norm[1, 2] = normalize_eye_size[1] / 2


In [5]:
# Root folders for the converted records and for the source dataset.
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
dir_output = '/mnt/sata3/everyone-tfrecord2'
dir_input = '/mnt/sata4/everyone'
# Sub-folder names; the cells below join them under <root>/<phase>/.
image_dir = 'frames'
json_dir = 'data'
landmarkd_dir = 'landmark2'
tfrecord_dir = 'tfrecord2'
temptfrecord_dir = 'temp_record'

# Dataset splits iterated by the export loop.
phases = ['train', 'val']

In [6]:
def getFocalLengthAndDistortion(device_name):
    """Look up the calibration entry for a device.

    The name is lower-cased before the lookup in device_data.

    Returns:
        (matrix, distortion) as numpy arrays built from the entry's 'matrix'
        and 'distortion' lists, or (None, None) when the device is missing
        from device_data or is recorded there as -1.
    """
    # Unknown devices are treated exactly like the -1 placeholder entries.
    entry = device_data.get(device_name.lower(), -1)
    if entry == -1:
        return None, None

    return np.array(entry['matrix']), np.array(entry['distortion'])

In [7]:
def cutImage(image, landmarks):
    """Crop `image` to a padded bounding box around `landmarks`.

    Args:
        image: array indexed as [row, column, ...].
        landmarks: array whose column 0 is x and column 1 is y.

    Returns:
        (crop, shifted_landmarks) where the landmarks are re-expressed
        relative to the top-left corner of the crop.
    """
    height, width = image.shape[0], image.shape[1]
    xs, ys = landmarks[:, 0], landmarks[:, 1]

    # Landmark bounding box, clamped to the image.
    left = max(int(min(xs)), 0)
    right = min(int(max(xs)), width)
    top = max(int(min(ys)), 0)
    bottom = min(int(max(ys)), height)

    box_w = right - left
    box_h = bottom - top

    # Exactly 17 points means the jaw points were stripped, so the box is
    # extended further vertically (mostly downward) to compensate.
    if len(landmarks) == 17:
        pad_top, pad_bottom = 0.3, 1.0
    else:
        pad_top, pad_bottom = 0.1, 0.0

    x0 = int(max(left - 0.4 * box_w, 0))
    x1 = int(min(right + 0.4 * box_w, width))
    y0 = int(max(top - pad_top * box_h, 0))
    y1 = int(min(bottom + pad_bottom * box_h, height))

    return image[y0:y1, x0:x1], landmarks - [x0, y0]

In [8]:
def drawImage(image):
    """Show `image` as-is in a 10x10 inch matplotlib figure."""
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.show()

def drawCVImage(image):
    """Show an OpenCV image with a small marker at pixel (48, 48).

    The channel axis is reversed before plotting so that matplotlib receives
    the opposite channel order from the one OpenCV uses.

    The marker is drawn on a copy: callers pass buffers that they keep using
    afterwards (for example crops queued for export), so the input must not
    be modified.
    """
    canvas = image.copy()
    cv2.circle(canvas, (48, 48), 2, (0, 0, 255), -1)
    plt.figure(figsize=(10, 10))
    plt.imshow(canvas[:, :, ::-1])
    plt.show()
    return

def drawLandmarks(image, landmarks):
    """Show an OpenCV image with one filled dot per landmark.

    Args:
        image: image array with the channel axis last.
        landmarks: iterable of points whose first two entries are x and y.

    The dots are drawn on a copy so the caller's image is left untouched.
    """
    canvas = image.copy()
    for pt in landmarks:
        # Plain Python ints: OpenCV bindings can reject numpy scalars as
        # point coordinates. int() truncates just like astype(int) did.
        cv2.circle(canvas, (int(pt[0]), int(pt[1])), 1, (0, 0, 255), -1)
    plt.figure(figsize=(10, 10))
    plt.imshow(canvas[:, :, ::-1])
    plt.show()
    return

In [9]:
def modify_landmark_to_classic(landmarks):
    """Scatter a compact 24-point landmark list into a 68x2 array.

    Input row i is written to a fixed slot of the output: rows 0-2 go to
    slots 7-9 and rows 3-23 go to slots 27-47. Every other slot stays zero.
    """
    # Destination slot for each consecutive input row.
    slots = (7, 8, 9) + tuple(range(27, 48))

    classic = np.zeros((68, 2))
    for row, slot in enumerate(slots):
        classic[slot] = landmarks[row]
    return classic

def pickEffectiveLandmarks(raw_landmarks):
    """Select and reorder the landmark rows used for pose estimation.

    The output stacks, in this order: rows 36-41 (right eye), 42-47 (left
    eye), 27-30 (nose bridge), 33 (nose bottom) and 7-9 (jaws).
    """
    groups = (
        slice(36, 42),  # right eye
        slice(42, 48),  # left eye
        slice(27, 31),  # nose bridge
        slice(33, 34),  # nose bottom
        slice(7, 10),   # jaws
    )
    return np.vstack([raw_landmarks[group] for group in groups])

In [10]:
def getStandard3DFacePoints():
    """Load the reference 3-D face model from 'standard3DFace.json'.

    The file is looked up relative to the current working directory.

    Returns:
        A float32 array stacking the 'eye_right', 'eye_left', 'nose',
        'nose_bottom' and 'jaws' entries in that order, or None (after
        printing a message) when the file does not exist.
    """
    faceModel = 'standard3DFace.json'

    if os.path.isfile(faceModel):
        with open(faceModel, 'r') as f:
            data = json.load(f)

        # The stacking order must match the row order produced by
        # pickEffectiveLandmarks for the 2-D points.
        part_names = ('eye_right', 'eye_left', 'nose', 'nose_bottom', 'jaws')
        parts = [np.array(data[part], dtype='float32') for part in part_names]
        return np.vstack(parts)

    print(faceModel, 'not found!')
    return None

In [11]:
# noinspection PyShadowingNames
def estimateHeadPosition(refined_landmarks, position, camera_matrix, camera_distortion):
    """Estimate head pose from 2-D landmarks and their 3-D model points.

    Args:
        refined_landmarks: 2-D image points.
        position: matching 3-D model points.
        camera_matrix: camera intrinsic matrix.
        camera_distortion: distortion coefficients passed through to solvePnP.

    Returns:
        (rvec, tvec) as produced by the second cv2.solvePnP call.
    """
    # Pass 1: EPnP solution without an initial guess.
    _, rvec_seed, tvec_seed = cv2.solvePnP(
        position, refined_landmarks, camera_matrix, camera_distortion,
        flags=cv2.SOLVEPNP_EPNP,
    )
    # Pass 2: re-solve with the default flags, seeded with the pass-1 result.
    _, rvec, tvec = cv2.solvePnP(
        position, refined_landmarks, camera_matrix, camera_distortion,
        rvec=rvec_seed, tvec=tvec_seed, useExtrinsicGuess=True,
    )

    return rvec, tvec

In [12]:
def get_plane(p1, p2, p3):
    """Fit the plane through three 3-D points in the form z = a*x + b*y + c.

    Returns:
        The tuple (a, b, c). There is no guard for a plane parallel to the
        z axis (normal z component of zero), where the division is undefined.
    """
    # Normal vector from two edge vectors lying in the plane.
    normal = np.cross(p3 - p1, p2 - p1)
    nx, ny, nz = normal

    # Plane in implicit form: nx*x + ny*y + nz*z = offset.
    offset = np.dot(normal, p3)

    # Solve the implicit form for z.
    return -nx / nz, -ny / nz, offset / nz

In [13]:
def get_R(rotation_matrix, center):
    """Build a rotation whose third axis points from the origin to `center`.

    Args:
        rotation_matrix: 3x3 matrix; only its first column is used.
        center: 3-D point (any shape holding three values).

    Returns:
        3x3 array whose rows are the unit vectors (right, down, forward).
    """
    head_x_axis = rotation_matrix[:, 0]

    # Forward: unit vector towards the target point.
    forward = (center / np.linalg.norm(center)).reshape(3)

    # Down: perpendicular to both forward and the head's x axis.
    down = np.cross(forward, head_x_axis)
    down = down / np.linalg.norm(down)

    # Right completes the orthonormal frame.
    right = np.cross(down, forward)
    right = right / np.linalg.norm(right)

    return np.column_stack((right, down, forward)).T

In [14]:
'''
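    We need a modified S that does not scale the image but instead shifts the
    face plane along Z by a distance k.
    The plane is the one through le, re and nose_tip, written as
        z = a*x + b*y + c
    S = |   1      0      0  |
        |   0      1      0  |
        | -ak/c  -bk/c  1+k/c|

    For any point (x, y, z) lying on that plane, S maps it to (x, y, z + k).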
'''

def calculate_S(pt1, pt2, pt3, R, target_distance):
    """Build the matrix that shifts the face plane to `target_distance`.

    The three points are rotated by R, the plane z = a*x + b*y + c through
    them is fitted with get_plane, and k = target_distance - c is the shift
    that moves the plane's intercept on the z axis to `target_distance`.

    Returns:
        3x3 array: identity in its first two rows, with a third row that adds
        k to the z coordinate of points on the fitted plane.
    """
    rotated = [np.reshape(np.dot(R, point), 3) for point in (pt1, pt2, pt3)]

    a, b, c = get_plane(*rotated)
    k = target_distance - c

    shift_row = [-a*k/c, -b*k/c, 1+k/c]
    return np.array([
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        shift_row,
    ])

In [15]:
def getLandmarksFromJson(data):
    """Extract the pose-estimation landmarks from a frame's JSON record.

    Returns None when the record has no 'landmark2' entry. Otherwise the
    entry is reshaped to (N, 2), expanded with modify_landmark_to_classic and
    reduced with pickEffectiveLandmarks.
    """
    if 'landmark2' not in data:
        return None

    points = np.array(data['landmark2']).reshape([-1, 2])
    return pickEffectiveLandmarks(modify_landmark_to_classic(points))

In [16]:
def removeOutsideImage(image, landmarks, face3d):
    """Drop the jaw points, keeping the first 17 landmarks and model points.

    The jaw rows (indices 17-19) are removed unconditionally. An earlier
    version only removed them when a jaw point fell within 2% of the image
    border; that condition had been disabled, leaving a loop whose result was
    never used and which raised IndexError for inputs with fewer than 20
    points. The loop has been removed.

    Args:
        image: unused; kept so existing callers keep working.
        landmarks: 2-D landmark rows.
        face3d: matching 3-D model rows.

    Returns:
        (landmarks[0:17], face3d[0:17])
    """
    return landmarks[0:17], face3d[0:17]

In [17]:
# Reference 3-D face model shared by every frame; read_data copies it per call.
standardFace = getStandard3DFacePoints()
if standardFace is None:
    # Raise instead of calling exit(): inside a notebook exit() shuts the
    # kernel down without showing why, whereas an exception stops the run
    # with a visible traceback.
    raise FileNotFoundError('Facemodel file not found!')

In [18]:
def read_data(json_path, image_path):
    """Load one frame and everything needed to normalize it.

    Args:
        json_path: path of the frame's JSON record (needs 'deviceName',
            'XCam', 'YCam' and, for a usable frame, 'landmark2').
        image_path: path of the frame's image file.

    Returns:
        (image, camera_matrix, camera_distortion, landmarks, face_3d, lookat),
        or None when the image is missing or cannot be decoded, the device
        has no calibration, or the record has no landmarks.
        landmarks has shape (N, 1, 2), face_3d (N, 1, 3) and lookat (3, 1).
    """
    # Check the file before decoding so a missing image is reported as such.
    if not os.path.isfile(image_path):
        print(image_path, 'not exists!')
        return None

    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals an unreadable/corrupt file by returning None.
        print(image_path, 'could not be decoded!')
        return None

    with open(json_path, 'r') as f:
        data = json.load(f)

    params, dists = getFocalLengthAndDistortion(data['deviceName'])
    if params is None:
        return None

    fx, fy = params[0], params[1]
    # The calibration table stores the principal point for one orientation;
    # swap it when the frame is wider than it is tall.
    if image.shape[1] > image.shape[0]:  # width > height
        cx, cy = params[3], params[2]
    else:
        cx, cy = params[2], params[3]

    # The table holds four coefficients; append a zero as the fifth.
    camera_distortion = np.hstack((dists, 0))

    camera_matrix = np.array([
        [fx,  0, cx],
        [ 0, fy, cy],
        [ 0,  0, 1 ]
    ])

    landmarks = getLandmarksFromJson(data)

    if landmarks is None:
        print(image_path, 'has no landmark information!')
        return None

    # Work on a copy so the shared module-level face model is never modified.
    landmarks, face_3d = removeOutsideImage(image, landmarks, copy.deepcopy(standardFace))
    # make to 3-D array
    landmarks = landmarks.reshape(-1, 1, 2)
    face_3d = face_3d.reshape(-1, 1, 3)

    lookat = np.array([-data['XCam'], -data['YCam'], 0])
    # cm to mm
    lookat = lookat * 10
    lookat = lookat.reshape((3, 1))

    return image, camera_matrix, camera_distortion, landmarks, face_3d, lookat

In [19]:
def normalizeDataAndGaze(image, face, camera_matrix, look_vector, head_rotate, landmarks):
    """Warp the frame into normalized right-eye, left-eye and face crops.

    For each target (right eye centre, left eye centre, then the midpoint of
    nose tip and eye centre) a rotation R is built with get_R, a plane shift S
    with calculate_S, and the image is warped with
    W = K_norm . S . R . inv(camera_matrix), where K_norm is eye_norm for the
    eyes and face_norm for the face.

    Args:
        image: source image passed to cv2.warpPerspective.
        face: 3-D landmark positions indexed as face[:, i], i.e. one column
            per point. assumes columns 0-5 are the right eye, 6-11 the left
            eye and 15 the nose tip, following the stacking order of
            getStandard3DFacePoints — TODO confirm against the model file.
        camera_matrix: 3x3 intrinsic matrix of the source camera.
        look_vector: (3, 1) gaze target.
        head_rotate: 3x3 head rotation matrix (only used through get_R).
        landmarks: 2-D landmarks, transformed with the face warp for the
            debug plot only.

    Returns:
        (warped_image, gaze, R_list):
        warped_image is [right eye crop, left eye crop, face crop];
        gaze is np.array of the two per-eye gaze vectors, each rotated by that
        eye's R and divided by minus its z component;
        R_list is [R_right_eye, R_left_eye, R_face].
    """
    # normalizing face area
    nose_tip = face[:, 15].reshape((3, 1))
    # Each comprehension sums one coordinate row over the selected columns,
    # so these are the mean positions of both eyes, the right eye and the
    # left eye respectively.
    eye_center = np.array([sum(x) for x in face[:, 0:12]]) / 12
    re = np.array([sum(x) for x in face[:, 0:6]]) / 6
    le = np.array([sum(x) for x in face[:, 6:12]]) / 6

    # This value is not read before `center` is reassigned below.
    center = eye_center.reshape((3, 1))
    
    gaze_data = []
    warped_image = []
    R_list = []
    for eye in [re, le]:
        R = get_R(head_rotate, eye)
        S = calculate_S(re, le, nose_tip, R, normalize_distance)
        W = np.dot(np.dot(eye_norm, S), np.dot(R, np.linalg.inv(camera_matrix)))  # transformation matrix
        image_warped = cv2.warpPerspective(image, W, normalize_eye_size)  # image normalization

        # Gaze direction from this eye to the target, in the rotated frame,
        # scaled so that its z component becomes -1.
        eye = eye.reshape((3, 1))
        g = look_vector - eye
        g = np.dot(R, g)
        g = g / (-g[2])
        
        warped_image.append(image_warped)
        gaze_data.append(g)
        R_list.append(R)
        if debug_print:
            drawCVImage(image_warped)
    
    # Face crop: same construction, aimed at the midpoint between the nose
    # tip and the centre of both eyes.
    center = (nose_tip + eye_center.reshape((3, 1))) / 2
    R = get_R(head_rotate, center)
    S = calculate_S(re, le, nose_tip, R, normalize_distance)
    W = np.dot(np.dot(face_norm, S), np.dot(R, np.linalg.inv(camera_matrix)))  # transformation matrix
    image_warped = cv2.warpPerspective(image, W, normalize_face_size)  # image normalization
    # Only consumed by the debug plot below.
    transformed_lks = cv2.perspectiveTransform(landmarks, W)

    warped_image.append(image_warped)
    R_list.append(R)
    if debug_print:
        drawLandmarks(image_warped, transformed_lks.reshape((-1, 2)))

    # Disabled sanity check: back-project the right-eye gaze onto z = 0 and
    # compare it with the true target.
#     right_gaze = gaze_data[0]
#     right_eye = eye_right.reshape((3, 1))
#     original_right_gaze = np.dot(np.linalg.inv(R), right_gaze)
#     multiplier = - np.divide(right_eye[2], original_right_gaze[2])
#     target = np.add(right_eye, np.multiply(original_right_gaze, multiplier))
#     print('right_gaze:', right_gaze, ', right_eye:', right_eye, ', target(re):', target, ', target(true)', look_vector)
    
    return warped_image, np.array(gaze_data), R_list #cv2.Rodrigues(np.dot(R, hR))[0]

In [20]:
def do_normalize(subject, frame):
    """Normalize one frame of one subject.

    Reads the module-level `input_json_dir` and `input_image_dir`, which the
    driver cells set per phase before calling this function.

    Args:
        subject: subject identifier (also the sub-folder name).
        frame: frame identifier; files are named '<subject>_<frame>'.

    Returns:
        None when the frame cannot be loaded, otherwise the 8-tuple
        (image_warped, gaze_vector, eyes, poses, R_right, R_left, R_face, lookat):
        image_warped is [right eye, left eye, face] crops, eyes holds the two
        3-D eye centres concatenated, poses the three Rodrigues vectors
        (right eye, left eye, face) flattened.
    """
    name = subject + '_' + frame
    json_path = os.path.join(input_json_dir, subject, name) + '.json'
    image_path = os.path.join(input_image_dir, subject, name) + '.jpg'

    # read_data returns None for unusable frames and may raise on malformed
    # files. Both are reported with the frame name; the previous handler
    # referenced `data_name`, which is not defined in this function.
    try:
        loaded = read_data(json_path, image_path)
    except Exception as err:
        print('Failed to read data : ', name, '-', err)
        return None
    if loaded is None:
        print('Failed to read data : ', name)
        return None
    image, camera_matrix, camera_distortion, landmarks, face_3d, lookat = loaded

    # undistort landmark points and image
    landmarks = cv2.undistortPoints(landmarks, camera_matrix, camera_distortion, P=camera_matrix)
    image_undistorted = cv2.undistort(image, camera_matrix, camera_distortion)

    # pnp R and T
    # The landmarks are already undistorted pixel coordinates (P=camera_matrix),
    # so the pose must be solved with zero distortion; passing the real
    # coefficients again would compensate for lens distortion twice.
    # Records regenerated with this fix differ slightly from older ones.
    no_distortion = np.zeros_like(camera_distortion)
    hr, ht = estimateHeadPosition(landmarks, face_3d, camera_matrix, no_distortion)
    face_3d = face_3d.reshape(-1, 3).T
    translate = ht.reshape((3, 1))

    # Rodrigues expression to 3x3 rotation matrix
    hR = cv2.Rodrigues(hr)[0]  # rotation matrix
    face = np.dot(hR, face_3d) + translate  # 3D positions of facial landmarks

    # warped images, rotated gaze vectors and the per-target rotations
    image_warped, gaze_vector, R = normalizeDataAndGaze(image_undistorted,
                                                        face,
                                                        camera_matrix,
                                                        lookat,
                                                        hR,
                                                        landmarks)

    # Mean 3-D position of each eye's six landmarks.
    eye_right = np.array([sum(x) for x in face[:, 0:6]]) / 6
    eye_left = np.array([sum(x) for x in face[:, 6:12]]) / 6

    # Head orientation expressed in each normalized frame, as Rodrigues vectors.
    right_eye_pose = cv2.Rodrigues(np.dot(R[0], hR))[0]
    left_eye_pose = cv2.Rodrigues(np.dot(R[1], hR))[0]
    face_pose = cv2.Rodrigues(np.dot(R[2], hR))[0]

    eyes = np.append(eye_right, eye_left)
    poses = np.reshape(np.concatenate((right_eye_pose, left_eye_pose, face_pose), axis=0), (-1))

    return image_warped, gaze_vector, eyes, poses, R[0], R[1], R[2], lookat

In [21]:
# Debug run: normalize the first usable frame of subjects 50..70 of the first
# phase and print the resulting eye positions and poses.
if debug_print:
    phase = phases[0]
    input_dir = os.path.join(dir_input, phase)
    output_dir = os.path.join(dir_output, phase)
    input_image_dir = os.path.join(input_dir, image_dir)
    input_json_dir = os.path.join(input_dir, json_dir)
    input_landmark_dir = os.path.join(input_dir, landmarkd_dir)

    with open(os.path.join(input_dir, 'info_normalized2.json'), 'r') as f:
        data = json.load(f)['subjects']

    start_subject = 50
    end_subject = 70

    for count, subject in enumerate(data.keys(), start=1):
        if count < start_subject:
            continue
        for frame in data[subject]:
            data_name = subject + '_' + frame
            output = do_normalize(subject, frame)
            if output is None:
                # Frame could not be loaded; try the subject's next frame.
                continue
            # do_normalize returns eight values (three rotation matrices).
            image_warped, gaze_vector, eyes, eye_poses, right_inv_R, left_inv_R, face_R, lookat = output
            print(eyes)
            print(eye_poses)
            break

        if (count >= end_subject):
            break


In [22]:
import tensorflow as tf
import numpy as np
import cv2
import os
import json
from tqdm import tqdm
import IPython.display as display
from multiprocessing import Pool

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import shutil

# TensorFlow device setup. This must run before any TF op touches the GPU:
# set_memory_growth raises RuntimeError once the devices are initialized,
# which is why the failure is caught and printed rather than allowed to stop
# the notebook.
tf.debugging.set_log_device_placement(False)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

1 Physical GPUs, 1 Logical GPUs


In [23]:
def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature (bytes_list).

    A tensor of the type produced by tf.constant is converted with .numpy()
    first, since BytesList cannot unpack it directly.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def _bytes_feature2(value):
    """Wrap a sequence of bytes values in a tf.train.Feature (bytes_list).

    Unlike _bytes_feature, `value` is passed through as the list itself
    instead of being wrapped in a one-element list.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
    """Wrap a sequence of floats in a tf.train.Feature (float_list)."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a sequence of integers in a tf.train.Feature (int64_list)."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)

def make_example(subject_name, frame_name, re, le, we, gaze, eyes, poses, rR, lR, cR, gaze2d):
    """Assemble one tf.train.Example for a normalized frame.

    subject_name, frame_name and the three images (re, le, we) are stored as
    single bytes features; every other argument is flattened to a 1-D list of
    floats.
    """
    def flat_floats(values):
        # Flatten to 1-D and convert to a plain list for FloatList.
        return _float_feature(np.reshape(values, (-1)).tolist())

    feature = {
        'subject': _bytes_feature(subject_name),
        'frame': _bytes_feature(frame_name),
        'img_re': _bytes_feature(re),
        'img_le': _bytes_feature(le),
        'img_we': _bytes_feature(we),
        'gaze': flat_floats(gaze),
        'eyes': flat_floats(eyes),
        'poses': flat_floats(poses),
        'rR': flat_floats(rR),
        'lR': flat_floats(lR),
        'cR': flat_floats(cR),
        'gaze2d': flat_floats(gaze2d),
    }
    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features)

def _parse_image_function(example_raw):
  """Parse one serialized tf.Example with tf.io.parse_single_example.

  NOTE(review): `image_feature_description` is not defined in the cells shown
  here; it must exist as a global before this function is called — confirm
  where it comes from.
  """
  # Parse the input tf.Example proto using the module-level feature description.
  return tf.io.parse_single_example(example_raw, image_feature_description)

In [24]:
def write_tfrecord(writer, subject, frame, image_warped, gaze, eyes, poses, rR, lR, cR, gaze2d):
    """Encode one normalized frame and append it to `writer`.

    Args:
        writer: object with a write(bytes) method (a TFRecordWriter).
        subject, frame: identifiers, stored as ASCII bytes.
        image_warped: sequence whose first three items are the right-eye,
            left-eye and face crops.
        gaze, eyes, poses, rR, lR, cR, gaze2d: numeric data stored as floats.

    Failures while encoding or writing are reported and the frame is skipped
    (best effort), but KeyboardInterrupt and SystemExit are no longer
    swallowed.
    """
    subject_name = bytes(str(subject), encoding='ascii')
    frame_name = bytes(str(frame), encoding='ascii')

    crops = (image_warped[0], image_warped[1], image_warped[2])

    try:
        # NOTE(review): the crops originate from OpenCV and are encoded
        # without reordering channels — confirm the reader expects that order.
        re, le, we = [tf.io.encode_jpeg(crop, quality=100).numpy() for crop in crops]

        example = make_example(subject_name, frame_name, re, le, we, gaze, eyes, poses, rR, lR, cR, gaze2d)
        writer.write(example.SerializeToString())

    except Exception as err:
        print("Failed to make example: ", subject, "-", frame, ":", err)

In [25]:
def preprocess_one_subject(subject):
    """Write every usable frame of `subject` into '<subject>.tfrecords'.

    Reads the module-level `data`, `temp_out_dir` and `out_tfrecord_dir` set
    by the export loop. The records are written to the temp folder and moved
    to the destination only after the whole subject is done, so an existing
    destination file means the subject is complete and it is skipped.
    """
    file_name = subject + '.tfrecords'
    staging_path = os.path.join(temp_out_dir, file_name)
    final_path = os.path.join(out_tfrecord_dir, file_name)

    if os.path.isfile(final_path):
        return

    with tf.io.TFRecordWriter(staging_path) as writer:
        for frame in data[subject]:
            result = do_normalize(subject, frame)
            if result is not None:
                image_warped, gaze, eyes, poses, rR, lR, cR, gaze2d = result
                write_tfrecord(writer, subject, frame, image_warped, gaze, eyes, poses, rR, lR, cR, gaze2d)
    shutil.move(staging_path, final_path)

In [26]:
# Export loop: for each phase, convert every subject to a TFRecord file using
# a pool of four worker processes.
# The workers read `data`, `input_image_dir`, `input_json_dir`, `temp_out_dir`
# and `out_tfrecord_dir` as module globals, so these must be assigned before
# the pool is created.
for phase in phases:
    input_dir = os.path.join(dir_input, phase)
    output_dir = os.path.join(dir_output, phase)
    input_image_dir = os.path.join(input_dir, image_dir)
    input_json_dir = os.path.join(input_dir, json_dir)
    input_landmark_dir = os.path.join(input_dir, landmarkd_dir)
    out_tfrecord_dir = os.path.join(output_dir, tfrecord_dir)
    temp_out_dir = os.path.join(output_dir, temptfrecord_dir)

    # makedirs also creates dir_output itself on a first run and does not
    # fail when a folder already exists.
    for directory in (output_dir, out_tfrecord_dir, temp_out_dir):
        os.makedirs(directory, exist_ok=True)

    with open(os.path.join(input_dir, 'info_normalized2.json'), 'r') as f:
        data = json.load(f)['subjects']

    keys = list(data.keys())
    keys.reverse()
    pool = Pool(4)
    try:
        for _ in tqdm(pool.imap_unordered(preprocess_one_subject, keys), total=len(keys)):
            pass
        pool.close()
    except BaseException:
        # A worker error or an interrupt: stop the remaining workers instead
        # of leaving them running behind the notebook.
        pool.terminate()
        raise
    finally:
        pool.join()

#     for subject in data.keys():
#         preprocess_one_subject(subject)
#         break


100%|██████████| 1227/1227 [5:59:32<00:00, 17.58s/it]  
100%|██████████| 50/50 [14:27<00:00, 17.35s/it]
