In [1]:
import os
from matplotlib import pyplot
import numpy as np
import numpy.testing as npt
from scipy.spatial.transform import Rotation
import pickle
import copy
from PIL import Image
from IPython.display import display
from ipywidgets import interact
import ipywidgets as widgets
from torch.utils.data import Subset
import h5py
import time
import cv2
import tqdm
import pprint

In [2]:
from facenet_pytorch import MTCNN
import trimesh
import pyrender

In [3]:
from trackertraincode.facemodel.bfm import BFMModel
from trackertraincode.facemodel import keypoints68
from trackertraincode.datatransformation import _ensure_image_nchw
from trackertraincode.datasets.dshdf5pose import Hdf5PoseDataset
from trackertraincode.datasets.dshdf5 import open_dataset
from trackertraincode import vis
from trackertraincode import utils
import trackertraincode.datatransformation as dtr
from trackertraincode.datasets.preprocessing import imencode, ImageFormat
from scripts.filter_dataset import filter_file_by_frames
import face3drotationaugmentation

In [4]:
filename = os.path.join(os.environ['DATADIR'],'lapa-megaface.h5')

In [5]:
def set_field_for_has_exactly_one_face(filename):
    mtcnn = MTCNN(keep_all=True, device='cpu', min_face_size = 32)
    ds = Hdf5PoseDataset(filename, monochrome=False, transform=dtr.to_numpy, whitelist=['/images'])
    mask = np.zeros((len(ds),), dtype='?')
    for i, sample in enumerate(tqdm.tqdm(ds)):
        myboxes, probs = mtcnn.detect(Image.fromarray(sample['image']))
        if len(probs) != 1:
            continue
        mask[i] = True
    ds.close()
    with h5py.File(filename,'r+') as f:
        ds = f.require_dataset('has_one_face', shape=mask.shape, dtype=mask.dtype)
        ds[...] = mask

In [6]:
# Enable if "has_one_face" dataset is not yet present in the hdf5
#set_field_for_has_exactly_one_face(filename)

In [7]:
with h5py.File(filename,'r') as f:
    N = len(f['2dfit_v2/quats'])
    g = f['2dfit_v2']
    h = f['pseudolabels']
    pt2d_68 = f['pt2d_68'][...]
    rois = f['rois'][...].astype(np.float64)
    pred_offline = g['pt3d_68'][...]
    pred_nn = h['pt3d_68'][:N,...]
    quats_offline = g['quats'][:N,...]
    quats_nn = h['quats'][:N,...]
    coords_offline = g['coords'][:N,...]
    coords_nn = h['coords'][:N,...]
    shapeparam_offline = g['shapeparams'][:N,...]
    shapeparam_nn = h['shapeparams'][:N,...]
    has_one_face = f['has_one_face'][...]

In [129]:
#rot_magnitudes_offline = Rotation.from_quat(quats_offline).magnitude()
rot_magnitudes_offline = np.abs(utils.as_hpb(Rotation.from_quat(quats_offline))[:,0])
mask_small_rotation = rot_magnitudes_offline < np.pi/180.*5.
mask_large_rotation = rot_magnitudes_offline > np.pi/180.*15.
mask_rotations = ~(mask_small_rotation | mask_large_rotation)
diameters = np.linalg.norm(rois[:,[2,3]]-rois[:,[0,1]],axis=-1)
rot_differences = (Rotation.from_quat(quats_offline).inv() * Rotation.from_quat(quats_nn)).magnitude()
mask_small = diameters < 196
delta = 1.*rot_differences + np.linalg.norm(coords_offline - coords_nn,axis=-1)*100./diameters + np.average(np.linalg.norm(pred_nn - pred_offline, axis=-1)*0.5, axis=-1)

In [130]:
def candidates_for_bad_fits():
    # Bad = nn and offline fit are in good agreement except for the jaw on one side where the agreement must be bad.
    not_chin = list(set(range(68)).difference(set(keypoints68.chin_left+keypoints68.chin_right)))
    delta_nn = np.linalg.norm(pt2d_68 - pred_nn[:,:,:2], axis=-1)
    delta_offline = np.linalg.norm(pt2d_68 - pred_offline[:,:,:2], axis=-1)
    delta_chin_nn = delta_nn[:,keypoints68.chin_left+keypoints68.chin_right]
    delta_chin_offline = delta_offline[:,keypoints68.chin_left+keypoints68.chin_right]
    bad_mask1 = np.average(delta_chin_nn, axis=-1) > np.average(delta_chin_offline, axis=-1) + diameters*0.02
    bad_mask2 = np.average(delta_offline, axis=-1) < diameters * 0.1
    bad_mask3 = np.average(delta_nn[:,not_chin], axis=-1) < diameters * 0.1
    return bad_mask1 & bad_mask2 & bad_mask3

In [131]:
np.count_nonzero(candidates_for_bad_fits())

3416

In [150]:
mask = has_one_face & (~mask_small) # mask_rotations & (~candidates_for_bad_fits()) &
order = np.nonzero(mask)[0]

In [151]:
len(order)

7109

In [152]:
order = np.flip(order[np.argsort(delta[order])])

In [141]:
#bad_sequences = bad_sequences.union(order)

In [122]:
#order = np.setdiff1d(order, list(bad_sequences))

In [16]:
facerender = vis.FaceRender()

In [17]:
ds = Hdf5PoseDataset(filename, monochrome=False, transform=dtr.to_numpy)

In [153]:
# TODO: bad_frames ...

# These are the frames I identified as bad fits. Thus they shall not be used.
bad_sequences = {1, 11, 12, 19, 21, 24, 28, 29, 32, 33, 34, 35, 44, 49, 51, 52, 53, 58, 59, 64, 65, 72, 73, 78, 82, 83, 86, 88, 89, 92, 97, 100, 107, 110,
 120, 122, 124, 126, 130, 131, 133, 144, 146, 147, 149, 154, 158, 159, 166, 167, 168, 170, 180, 181, 185, 186, 189, 190, 191, 192, 196, 197,
 198, 199, 200, 201, 202, 203, 204, 205, 222, 229, 237, 240, 249, 252, 255, 257, 258, 259, 260, 262, 263, 264, 266, 270, 272, 273, 276, 278,
 279, 280, 283, 285, 287, 289, 291, 292, 298, 300, 301, 307, 308, 309, 310, 311, 316, 325, 326, 330, 331, 332, 336, 341, 342, 345, 354, 355,
 359, 361, 362, 365, 369, 378, 383, 384, 387, 389, 390, 392, 393, 394, 396, 398, 399, 401, 403, 404, 405, 408, 409, 412, 413, 414, 418, 423,
 429, 430, 431, 438, 443, 444, 446, 450, 453, 461, 462, 472, 477, 478, 479, 486, 487, 488, 489, 490, 495, 497, 499, 503, 504, 505, 509, 511,
 512, 514, 521, 524, 525, 526, 531, 532, 533, 535, 539, 540, 544, 545, 548, 549, 550, 551, 557, 558, 564, 566, 573, 578, 579, 588, 589, 592,
 594, 595, 600, 601, 602, 612, 618, 619, 620, 621, 628, 629, 632, 637, 641, 642, 644, 650, 660, 661, 663, 671, 673, 674, 675, 679, 683, 684,
 685, 687, 692, 695, 696, 701, 702, 711, 716, 717, 724, 725, 728, 733, 734, 735, 736, 737, 741, 742, 746, 750, 751, 755, 756, 757, 758, 762,
 766, 770, 771, 774, 777, 778, 779, 782, 784, 787, 794, 795, 797, 798, 803, 807, 809, 810, 812, 814, 818, 828, 830, 835, 837, 838, 840, 841,
 842, 843, 844, 846, 848, 853, 859, 861, 865, 872, 876, 877, 883, 887, 888, 892, 893, 894, 895, 900, 901, 903, 904, 918, 919, 924, 925, 932,
 934, 935, 938, 942, 943, 944, 951, 953, 954, 956, 962, 967, 973, 974, 975, 979, 980, 982, 984, 986, 994, 996, 999, 1001, 1002, 1005, 1014,
 1018, 1023, 1035, 1040, 1042, 1044, 1046, 1048, 1049, 1055, 1063, 1064, 1067, 1068, 1074, 1075, 1080, 1083, 1084, 1087, 1090, 1100, 1104,
 1105, 1106, 1108, 1115, 1116, 1118, 1119, 1122, 1123, 1127, 1128, 1133, 1144, 1145, 1148, 1149, 1150, 1151, 1155, 1156, 1158, 1160, 1161,
 1162, 1163, 1168, 1174, 1175, 1176, 1178, 1179, 1180, 1181, 1182, 1184, 1185, 1190, 1191, 1192, 1193, 1203, 1205, 1209, 1211, 1212, 1216,
 1219, 1222, 1236, 1238, 1239, 1241, 1244, 1247, 1248, 1252, 1253, 1256, 1257, 1260, 1261, 1263, 1264, 1272, 1277, 1278, 1279, 1280, 1288,
 1291, 1292, 1293, 1298, 1299, 1302, 1303, 1309, 1314, 1319, 1324, 1325, 1327, 1329, 1335, 1339, 1340, 1341, 1342, 1343, 1347, 1353, 1356,
 1357, 1359, 1364, 1365, 1366, 1367, 1368, 1369, 1377, 1380, 1382, 1386, 1392, 1402, 1403, 1404, 1407, 1410, 1415, 1422, 1424, 1426, 1430,
 1431, 1432, 1437, 1440, 1446, 1448, 1450, 1451, 1452, 1454, 1458, 1461, 1466, 1467, 1475, 1476, 1481, 1487, 1491, 1494, 1501, 1506, 1508,
 1511, 1513, 1514, 1515, 1517, 1521, 1522, 1523, 1525, 1526, 1527, 1532, 1533, 1534, 1536, 1537, 1542, 1544, 1546, 1550, 1553, 1559, 1562,
 1564, 1567, 1569, 1570, 1572, 1574, 1576, 1581, 1582, 1584, 1585, 1587, 1596, 1597, 1600, 1601, 1608, 1610, 1614, 1617, 1620, 1624, 1634,
 1635, 1641, 1644, 1651, 1653, 1655, 1656, 1660, 1664, 1666, 1667, 1670, 1680, 1682, 1688, 1689, 1695, 1699, 1707, 1710, 1711, 1712, 1716,
 1717, 1723, 1729, 1733, 1736, 1739, 1742, 1746, 1751, 1752, 1753, 1755, 1758, 1759, 1760, 1761, 1762, 1763, 1764, 1774, 1776, 1778, 1779,
 1781, 1782, 1783, 1788, 1789, 1794, 1796, 1798, 1799, 1800, 1802, 1805, 1810, 1811, 1819, 1820, 1826, 1827, 1830, 1835, 1838, 1840, 1841,
 1842, 1846, 1847, 1853, 1855, 1859, 1860, 1861, 1862, 1864, 1870, 1871, 1873, 1875, 1877, 1879, 1880, 1888, 1889, 1892, 1893, 1896, 1897,
 1899, 1903, 1904, 1906, 1913, 1915, 1916, 1917, 1920, 1924, 1925, 1926, 1928, 1930, 1932, 1934, 1942, 1943, 1948, 1958, 1960, 1961, 1962,
 1965, 1966, 1967, 1968, 1973, 1974, 1979, 1980, 1981, 1982, 1983, 1986, 1996, 2000, 2002, 2005, 2009, 2010, 2013, 2014, 2022, 2023, 2024,
 2025, 2027, 2031, 2033, 2036, 2040, 2043, 2046, 2047, 2051, 2052, 2053, 2057, 2058, 2061, 2062, 2063, 2064, 2068, 2072, 2075, 2079, 2081,
 2093, 2101, 2102, 2103, 2104, 2113, 2117, 2118, 2119, 2126, 2130, 2135, 2138, 2143, 2144, 2153, 2154, 2159, 2161, 2163, 2165, 2168, 2173,
 2174, 2175, 2176, 2177, 2178, 2179, 2181, 2183, 2186, 2188, 2189, 2190, 2192, 2193, 2194, 2197, 2198, 2200, 2201, 2202, 2203, 2204, 2205,
 2209, 2215, 2217, 2220, 2222, 2225, 2229, 2230, 2234, 2235, 2245, 2249, 2256, 2257, 2259, 2264, 2265, 2266, 2270, 2273, 2282, 2285, 2287,
 2291, 2294, 2295, 2298, 2302, 2305, 2307, 2310, 2311, 2316, 2318, 2319, 2325, 2327, 2331, 2333, 2334, 2340, 2343, 2344, 2348, 2352, 2355,
 2356, 2357, 2368, 2369, 2371, 2376, 2382, 2383, 2395, 2399, 2406, 2408, 2415, 2417, 2423, 2424, 2426, 2427, 2428, 2433, 2436, 2438, 2440,
 2441, 2442, 2444, 2447, 2448, 2449, 2450, 2453, 2459, 2462, 2464, 2465, 2466, 2468, 2469, 2470, 2473, 2477, 2482, 2483, 2485, 2488, 2491,
 2494, 2497, 2499, 2501, 2503, 2504, 2508, 2511, 2517, 2518, 2519, 2521, 2522, 2525, 2527, 2530, 2533, 2539, 2540, 2541, 2543, 2544, 2551,
 2553, 2554, 2555, 2556, 2557, 2558, 2559, 2560, 2561, 2562, 2563, 2567, 2574, 2575, 2576, 2582, 2584, 2585, 2586, 2587, 2588, 2593, 2595,
 2596, 2597, 2600, 2601, 2611, 2613, 2624, 2627, 2630, 2632, 2633, 2635, 2638, 2639, 2640, 2641, 2642, 2646, 2653, 2655, 2660, 2661, 2662,
 2667, 2672, 2674, 2675, 2676, 2677, 2681, 2682, 2683, 2685, 2686, 2690, 2692, 2698, 2699, 2702, 2705, 2708, 2709, 2710, 2716, 2717, 2719,
 2720, 2724, 2731, 2734, 2737, 2739, 2740, 2742, 2744, 2746, 2749, 2751, 2755, 2757, 2758, 2759, 2760, 2761, 2764, 2766, 2773, 2775, 2784,
 2785, 2791, 2792, 2793, 2794, 2803, 2806, 2808, 2810, 2812, 2820, 2822, 2823, 2824, 2830, 2831, 2834, 2835, 2836, 2838, 2839, 2841, 2843,
 2845, 2846, 2847, 2852, 2853, 2855, 2861, 2863, 2864, 2870, 2871, 2872, 2879, 2884, 2887, 2889, 2890, 2891, 2896, 2897, 2899, 2908, 2925,
 2930, 2932, 2933, 2935, 2939, 2940, 2943, 2948, 2949, 2951, 2956, 2957, 2959, 2961, 2964, 2971, 2972, 2974, 2976, 2981, 2987, 2988, 2993,
 2994, 3000, 3002, 3004, 3005, 3008, 3024, 3027, 3029, 3030, 3043, 3044, 3047, 3051, 3055, 3058, 3063, 3064, 3067, 3071, 3077, 3078, 3079,
 3082, 3084, 3090, 3092, 3095, 3097, 3100, 3103, 3105, 3107, 3112, 3114, 3118, 3124, 3126, 3130, 3131, 3133, 3137, 3141, 3142, 3148, 3149,
 3150, 3156, 3157, 3162, 3163, 3164, 3167, 3169, 3175, 3177, 3180, 3183, 3184, 3186, 3187, 3190, 3192, 3198, 3199, 3201, 3203, 3205, 3208,
 3209, 3211, 3212, 3220, 3223, 3225, 3227, 3229, 3231, 3232, 3234, 3235, 3236, 3238, 3239, 3241, 3249, 3253, 3254, 3255, 3258, 3259, 3263,
 3266, 3270, 3277, 3289, 3298, 3301, 3306, 3307, 3308, 3309, 3315, 3317, 3321, 3325, 3327, 3329, 3330, 3332, 3336, 3339, 3340, 3342, 3352,
 3353, 3355, 3357, 3361, 3366, 3368, 3373, 3380, 3384, 3390, 3391, 3393, 3399, 3403, 3404, 3405, 3406, 3413, 3414, 3417, 3418, 3423, 3425,
 3427, 3433, 3444, 3446, 3448, 3450, 3454, 3455, 3462, 3464, 3466, 3467, 3471, 3472, 3473, 3475, 3478, 3481, 3489, 3501, 3503, 3507, 3516,
 3518, 3519, 3522, 3526, 3528, 3529, 3531, 3533, 3534, 3536, 3537, 3540, 3545, 3546, 3548, 3554, 3560, 3563, 3564, 3566, 3568, 3575, 3580,
 3581, 3583, 3584, 3590, 3593, 3595, 3596, 3598, 3599, 3601, 3603, 3607, 3608, 3609, 3610, 3611, 3612, 3613, 3614, 3617, 3620, 3621, 3624,
 3625, 3627, 3628, 3632, 3635, 3636, 3640, 3647, 3653, 3655, 3660, 3669, 3670, 3672, 3680, 3683, 3684, 3685, 3686, 3688, 3689, 3690, 3691,
 3692, 3694, 3696, 3698, 3711, 3713, 3714, 3715, 3724, 3725, 3726, 3727, 3728, 3730, 3736, 3742, 3744, 3745, 3746, 3761, 3765, 3768, 3774,
 3784, 3786, 3788, 3789, 3790, 3792, 3793, 3797, 3799, 3800, 3802, 3805, 3807, 3808, 3814, 3817, 3828, 3834, 3837, 3839, 3844, 3846, 3849,
 3851, 3852, 3854, 3855, 3859, 3864, 3865, 3866, 3867, 3868, 3872, 3873, 3880, 3887, 3888, 3891, 3897, 3898, 3900, 3907, 3912, 3917, 3919,
 3921, 3925, 3926, 3927, 3928, 3929, 3932, 3934, 3941, 3942, 3947, 3953, 3958, 3978, 3979, 3986, 3991, 3993, 4000, 4001, 4009, 4015, 4018,
 4023, 4024, 4025, 4026, 4027, 4029, 4036, 4037, 4038, 4039, 4040, 4041, 4042, 4050, 4054, 4056, 4057, 4067, 4069, 4070, 4071, 4072, 4073,
 4074, 4079, 4080, 4086, 4088, 4089, 4090, 4092, 4095, 4108, 4110, 4117, 4118, 4120, 4121, 4123, 4128, 4130, 4131, 4132, 4136, 4138, 4140,
 4141, 4142, 4143, 4149, 4150, 4151, 4158, 4159, 4162, 4174, 4181, 4184, 4194, 4195, 4197, 4199, 4204, 4205, 4206, 4207, 4209, 4210, 4214,
 4218, 4219, 4220, 4224, 4226, 4232, 4233, 4237, 4239, 4240, 4242, 4247, 4266, 4273, 4274, 4281, 4283, 4284, 4285, 4286, 4287, 4288, 4290,
 4299, 4300, 4303, 4304, 4313, 4315, 4316, 4317, 4319, 4323, 4328, 4331, 4332, 4333, 4334, 4336, 4337, 4338, 4347, 4351, 4353, 4359, 4360,
 4365, 4366, 4370, 4371, 4373, 4379, 4381, 4383, 4384, 4385, 4389, 4395, 4399, 4401, 4402, 4406, 4414, 4416, 4417, 4419, 4421, 4422, 4425,
 4429, 4430, 4431, 4433, 4453, 4454, 4456, 4462, 4466, 4469, 4471, 4474, 4480, 4482, 4485, 4487, 4494, 4497, 4501, 4503, 4504, 4505, 4508,
 4509, 4515, 4518, 4519, 4520, 4523, 4529, 4532, 4534, 4542, 4543, 4544, 4546, 4549, 4551, 4552, 4553, 4563, 4564, 4567, 4571, 4573, 4577,
 4580, 4588, 4590, 4594, 4595, 4596, 4597, 4600, 4602, 4603, 4606, 4608, 4611, 4613, 4615, 4622, 4623, 4624, 4625, 4631, 4632, 4639, 4641,
 4648, 4649, 4653, 4656, 4663, 4665, 4668, 4671, 4672, 4673, 4675, 4680, 4682, 4685, 4694, 4697, 4701, 4702, 4706, 4709, 4713, 4714, 4716,
 4720, 4721, 4722, 4723, 4725, 4726, 4728, 4730, 4731, 4737, 4738, 4739, 4742, 4745, 4747, 4748, 4754, 4756, 4759, 4765, 4767, 4768, 4777,
 4778, 4779, 4784, 4787, 4791, 4792, 4793, 4797, 4798, 4802, 4803, 4806, 4810, 4816, 4818, 4819, 4820, 4822, 4823, 4824, 4829, 4836, 4837,
 4838, 4848, 4851, 4852, 4853, 4855, 4859, 4866, 4868, 4874, 4883, 4888, 4889, 4891, 4893, 4895, 4897, 4899, 4900, 4905, 4906, 4908, 4909,
 4910, 4917, 4918, 4924, 4925, 4927, 4934, 4936, 4940, 4944, 4945, 4947, 4948, 4956, 4957, 4960, 4963, 4964, 4965, 4972, 4979, 4980, 4982,
 4984, 4995, 4997, 4998, 5001, 5002, 5003, 5004, 5011, 5013, 5016, 5017, 5020, 5021, 5025, 5035, 5040, 5041, 5042, 5043, 5044, 5045, 5048,
 5052, 5055, 5056, 5057, 5058, 5059, 5060, 5061, 5062, 5063, 5064, 5065, 5067, 5074, 5076, 5081, 5087, 5089, 5091, 5094, 5100, 5101, 5102,
 5103, 5104, 5105, 5106, 5108, 5118, 5121, 5122, 5123, 5124, 5126, 5130, 5132, 5134, 5137, 5139, 5144, 5150, 5152, 5153, 5155, 5157, 5163,
 5171, 5172, 5179, 5181, 5182, 5187, 5192, 5193, 5202, 5203, 5205, 5206, 5207, 5208, 5209, 5210, 5212, 5214, 5219, 5222, 5223, 5225, 5226,
 5227, 5239, 5241, 5243, 5246, 5248, 5253, 5255, 5262, 5263, 5264, 5266, 5268, 5273, 5274, 5275, 5278, 5279, 5284, 5292, 5299, 5301, 5304,
 5307, 5309, 5313, 5320, 5321, 5324, 5325, 5326, 5330, 5332, 5333, 5334, 5335, 5339, 5341, 5342, 5345, 5357, 5359, 5360, 5369, 5375, 5379,
 5386, 5392, 5393, 5394, 5408, 5412, 5415, 5416, 5417, 5424, 5430, 5432, 5433, 5434, 5444, 5446, 5448, 5449, 5452, 5454, 5458, 5460, 5461,
 5463, 5466, 5473, 5480, 5482, 5483, 5485, 5486, 5487, 5488, 5491, 5493, 5494, 5496, 5498, 5499, 5502, 5507, 5510, 5511, 5513, 5514, 5520,
 5527, 5528, 5536, 5540, 5542, 5549, 5550, 5552, 5554, 5555, 5556, 5562, 5566, 5572, 5574, 5577, 5579, 5588, 5589, 5590, 5591, 5595, 5596,
 5597, 5599, 5603, 5605, 5606, 5609, 5614, 5616, 5617, 5618, 5619, 5626, 5627, 5630, 5631, 5632, 5635, 5636, 5637, 5638, 5642, 5643, 5651,
 5652, 5656, 5658, 5659, 5663, 5664, 5665, 5666, 5669, 5670, 5678, 5683, 5699, 5705, 5724, 5726, 5729, 5731, 5734, 5735, 5736, 5741, 5744,
 5746, 5750, 5752, 5753, 5754, 5756, 5757, 5760, 5765, 5766, 5767, 5768, 5770, 5775, 5777, 5781, 5782, 5784, 5787, 5789, 5791, 5794, 5797,
 5802, 5804, 5805, 5806, 5808, 5814, 5815, 5823, 5829, 5836, 5839, 5840, 5841, 5843, 5847, 5848, 5851, 5852, 5853, 5866, 5868, 5870, 5877,
 5878, 5881, 5893, 5897, 5901, 5903, 5907, 5916, 5923, 5924, 5926, 5927, 5929, 5936, 5937, 5944, 5946, 5947, 5948, 5951, 5952, 5953, 5954,
 5955, 5959, 5965, 5968, 5972, 5974, 5978, 5982, 5987, 5993, 6002, 6004, 6006, 6014, 6018, 6020, 6024, 6025, 6026, 6028, 6036, 6038, 6044,
 6045, 6047, 6049, 6050, 6053, 6062, 6069, 6073, 6077, 6082, 6084, 6086, 6096, 6103, 6105, 6106, 6108, 6109, 6112, 6114, 6115, 6116, 6117,
 6119, 6120, 6122, 6124, 6126, 6127, 6130, 6131, 6133, 6139, 6140, 6141, 6143, 6144, 6152, 6160, 6161, 6165, 6167, 6169, 6171, 6175, 6177,
 6181, 6184, 6185, 6186, 6188, 6189, 6190, 6191, 6192, 6198, 6199, 6200, 6201, 6204, 6206, 6212, 6213, 6216, 6218, 6219, 6221, 6224, 6227,
 6228, 6230, 6238, 6251, 6254, 6256, 6257, 6258, 6263, 6265, 6267, 6270, 6271, 6274, 6275, 6276, 6279, 6287, 6288, 6291, 6297, 6306, 6310,
 6314, 6320, 6321, 6323, 6326, 6327, 6329, 6337, 6340, 6344, 6348, 6349, 6360, 6363, 6376, 6377, 6378, 6383, 6384, 6385, 6386, 6389, 6390,
 6391, 6392, 6398, 6399, 6400, 6406, 6416, 6417, 6418, 6419, 6420, 6423, 6429, 6431, 6433, 6436, 6437, 6438, 6447, 6449, 6451, 6453, 6455,
 6456, 6458, 6461, 6465, 6473, 6475, 6487, 6490, 6494, 6496, 6497, 6499, 6500, 6504, 6505, 6509, 6515, 6518, 6520, 6521, 6522, 6524, 6530,
 6535, 6536, 6540, 6543, 6545, 6548, 6556, 6562, 6566, 6569, 6573, 6577, 6579, 6583, 6584, 6585, 6586, 6587, 6588, 6591, 6595, 6596, 6598,
 6599, 6604, 6606, 6607, 6609, 6611, 6613, 6614, 6615, 6623, 6624, 6625, 6626, 6628, 6632, 6635, 6637, 6638, 6641, 6642, 6643, 6645, 6647,
 6648, 6650, 6654, 6658, 6659, 6662, 6663, 6665, 6667, 6668, 6669, 6672, 6676, 6678, 6680, 6683, 6685, 6687, 6695, 6697, 6700, 6703, 6705,
 6706, 6708, 6713, 6716, 6723, 6724, 6726, 6731, 6741, 6744, 6750, 6752, 6755, 6761, 6762, 6764, 6765, 6768, 6770, 6771, 6776, 6779, 6780,
 6783, 6790, 6791, 6795, 6796, 6797, 6798, 6800, 6814, 6816, 6818, 6820, 6826, 6835, 6843, 6844, 6845, 6857, 6860, 6861, 6862, 6863, 6864,
 6865, 6868, 6869, 6870, 6871, 6872, 6873, 6878, 6879, 6880, 6881, 6884, 6885, 6891, 6893, 6894, 6895, 6898, 6902, 6916, 6918, 6922, 6923,
 6926, 6927, 6928, 6930, 6935, 6937, 6940, 6943, 6946, 6947, 6949, 6956, 6959, 6967, 6969, 6970, 6973, 6979, 6980, 6982, 6984, 6987, 6988,
 7004, 7009, 7017, 7019, 7020, 7021, 7022, 7025, 7026, 7031, 7033, 7038, 7043, 7044, 7045, 7046, 7050, 7057, 7058, 7059, 7061, 7062, 7067,
 7071, 7072, 7073, 7075, 7081, 7085, 7089, 7091, 7092, 7093, 7094, 7096, 7099, 7100, 7103, 7104, 7105, 7106, 7108, 7109, 7110, 7111, 7112,
 7113, 7119, 7124, 7125, 7127, 7128, 7132, 7135, 7136, 7137, 7139, 7140, 7141, 7142, 7144, 7145, 7147, 7148, 7156, 7158, 7159, 7161, 7163,
 7164, 7165, 7176, 7177, 7178, 7179, 7182, 7192, 7199, 7203, 7205, 7214, 7216, 7222, 7224, 7230, 7239, 7242, 7243, 7244, 7246, 7250, 7253,
 7255, 7256, 7257, 7263, 7268, 7270, 7273, 7276, 7277, 7278, 7281, 7284, 7285, 7292, 7297, 7299, 7308, 7313, 7317, 7319, 7320, 7324, 7325,
 7327, 7330, 7332, 7334, 7336, 7337, 7339, 7340, 7343, 7349, 7350, 7354, 7356, 7361, 7362, 7363, 7364, 7372, 7376, 7380, 7381, 7389, 7396,
 7398, 7399, 7400, 7404, 7406, 7409, 7412, 7413, 7414, 7417, 7418, 7420, 7421, 7425, 7426, 7427, 7428, 7429, 7433, 7435, 7439, 7440, 7442,
 7443, 7446, 7450, 7454, 7457, 7458, 7462, 7466, 7467, 7468, 7469, 7470, 7471, 7472, 7478, 7480, 7482, 7484, 7490, 7494, 7501, 7504, 7507,
 7508, 7512, 7515, 7520, 7523, 7524, 7526, 7528, 7532, 7535, 7539, 7542, 7543, 7544, 7550, 7555, 7563, 7565, 7570, 7573, 7574, 7576, 7580,
 7581, 7591, 7592, 7595, 7597, 7598, 7606, 7607, 7615, 7616, 7617, 7618, 7627, 7630, 7632, 7633, 7635, 7637, 7640, 7646, 7648, 7650, 7655,
 7657, 7663, 7666, 7667, 7668, 7669, 7674, 7679, 7680, 7683, 7685, 7688, 7691, 7695, 7696, 7699, 7700, 7701, 7703, 7712, 7714, 7715, 7725,
 7726, 7736, 7740, 7743, 7747, 7749, 7750, 7753, 7757, 7758, 7762, 7764, 7767, 7770, 7773, 7775, 7776, 7778, 7779, 7784, 7786, 7787, 7790,
 7793, 7794, 7796, 7799, 7802, 7805, 7810, 7814, 7822, 7824, 7827, 7828, 7832, 7833, 7835, 7837, 7838, 7840, 7841, 7842, 7844, 7846, 7848,
 7855, 7859, 7862, 7863, 7871, 7874, 7882, 7883, 7887, 7888, 7889, 7890, 7891, 7893, 7894, 7896, 7898, 7899, 7902, 7908, 7916, 7924, 7928,
 7941, 7943, 7946, 7950, 7951, 7952, 7953, 7956, 7961, 7964, 7971, 7974, 7977, 7978, 7979, 7982, 7983, 7986, 7988, 7993, 7994, 7995, 7996,
 7999, 8002, 8005, 8014, 8017, 8029, 8037, 8040, 8042, 8044, 8046, 8053, 8059, 8064, 8065, 8066, 8067, 8069, 8071, 8072, 8075, 8081, 8082,
 8083, 8089, 8090, 8097, 8100, 8103, 8105, 8107, 8108, 8112, 8113, 8114, 8115, 8119, 8120, 8121, 8122, 8123, 8124, 8125, 8126, 8127, 8128,
 8129, 8130, 8131, 8137, 8140, 8144, 8147, 8148, 8151, 8152, 8154, 8155, 8157, 8158, 8162, 8164, 8165, 8171, 8173, 8177, 8179, 8180, 8181,
 8182, 8183, 8189, 8190, 8192, 8195, 8197, 8200, 8208, 8210, 8211, 8212, 8213, 8215, 8223, 8228, 8230, 8231, 8232, 8233, 8246, 8247, 8250,
 8251, 8261, 8263, 8264, 8265, 8266, 8267, 8268, 8271, 8280, 8281, 8284, 8286, 8289, 8290, 8291, 8293, 8294, 8297, 8298, 8299, 8302, 8308,
 8310, 8312, 8313, 8314, 8317, 8326, 8327, 8328, 8333, 8336, 8346, 8351, 8352, 8353, 8359, 8369, 8371, 8372, 8373, 8374, 8375, 8382, 8384,
 8392, 8393, 8402, 8403, 8404, 8408, 8409, 8411, 8412, 8413, 8414, 8415, 8418, 8419, 8423, 8428, 8431, 8433, 8434, 8436, 8437, 8438, 8439,
 8444, 8446, 8447, 8452, 8453, 8454, 8457, 8458, 8463, 8465, 8466, 8473, 8480, 8481, 8486, 8492, 8494, 8498, 8499, 8502, 8503, 8504, 8515,
 8519, 8523, 8525, 8529, 8530, 8534, 8539, 8540, 8550, 8555, 8560, 8561, 8562, 8564, 8565, 8566, 8571, 8572, 8576, 8585, 8587, 8589, 8594,
 8595, 8601, 8608, 8610, 8614, 8615, 8617, 8622, 8623, 8625, 8626, 8627, 8628, 8635, 8637, 8638, 8640, 8642, 8644, 8646, 8647, 8648, 8652,
 8654, 8656, 8659, 8660, 8661, 8665, 8666, 8670, 8673, 8696, 8697, 8698, 8699, 8700, 8702, 8710, 8714, 8716, 8723, 8724, 8725, 8727, 8729,
 8730, 8737, 8739, 8740, 8741, 8742, 8743, 8745, 8746, 8757, 8758, 8762, 8764, 8766, 8767, 8770, 8771, 8772, 8775, 8776, 8781, 8784, 8787,
 8788, 8789, 8790, 8797, 8799, 8801, 8804, 8806, 8807, 8808, 8811, 8812, 8814, 8820, 8822, 8828, 8829, 8830, 8833, 8837, 8838, 8839, 8844,
 8847, 8848, 8849, 8850, 8856, 8859, 8863, 8864, 8866, 8867, 8869, 8870, 8871, 8872, 8873, 8877}
# TODO: actually use these predictions instead of offline fits.
indices_use_nn_prediction = set()  # Default is offline fit

In [154]:
button = widgets.Button(description='Bad')
button_use_nn_pred = widgets.Button(description='Use NN')
button_use_offline_pred = widgets.Button(description='Use Offline')
button_next = widgets.Button(description='Next')
button_prev = widgets.Button(description='Prev')
slider = widgets.IntSlider(value = 0, min=0, max=len(order)-1)
image_widget = widgets.Image()

label = widgets.Label("---")


my_widgets = widgets.HBox([button_prev, slider, button_next, button, button_use_nn_pred, button_use_offline_pred, label])


current_index = 0


def on_next(_):
    global slider
    slider.value += 1 if slider.value < slider.max else 0
    #update_label()
    #render_sample(current_index)

    
def on_prev(_):
    global slider
    slider.value -= 1 if slider.value > slider.min else 0
    #update_label()
    #render_sample(current_index)
    

def on_button_clicked(_):
    bad_sequences.add(current_index)
    try:
        indices_use_nn_prediction.remove(current_index)
    except KeyError:
        pass
    update_label()


def on_button_use_nn_pred(_):
    try:
        bad_sequences.remove(current_index)
    except KeyError:
        pass
    indices_use_nn_prediction.add(current_index)
    update_label()

    
def on_button_use_offline_pred(_):
    try:
        bad_sequences.remove(current_index)
    except KeyError:
        pass
    try:
        indices_use_nn_prediction.remove(current_index)
    except KeyError:
        pass
    update_label()


def update_label():
    if current_index in bad_sequences:
        label.value = f"Bad Label {current_index}"
        return
    if current_index in indices_use_nn_prediction:
        label.value = f"NN {current_index}"
        return
    label.value = f"Offline {current_index}"

def value_changed(change):
    render_sample(change.new)
    
button.on_click(on_button_clicked)
button_use_nn_pred.on_click(on_button_use_nn_pred)
button_use_offline_pred.on_click(on_button_use_offline_pred)
button_prev.on_click(on_prev)
button_next.on_click(on_next)
slider.observe(value_changed, 'value')


display(my_widgets)
display(image_widget)

def _visualize(sample):
    img = sample['image']
    rendering = facerender.set(
        sample['coord'][:2],
        sample['coord'][2],
        Rotation.from_quat(sample['pose']),
        sample['shapeparam'][:50],
        img.shape[:2],
    )
    sample = copy.copy(sample)
    rendering = Image.fromarray(rendering)
    img = Image.fromarray(img)
    img = Image.blend(rendering, img, 0.4)
    sample['image'] = np.asarray(img)
    img = vis.draw_dataset_sample(sample)
    return img


#@interact(idx = (0,len(order)-1))
def render_sample(idx): #, overlay, show_network_prediction):
    i = order[idx]
    sample_image = ds[i]['image']
    h, w = sample_image.shape[:2]    
    
    global current_index
    current_index = i
    
    img = sample_image
    if img.shape[-1] == 1:
        img = np.broadcast_to(img, img.shape[:2]+(3,))
    
    nn_sample = {
        'image' : np.asarray(img),
        'coord' : coords_nn[i],
        'pose' : quats_nn[i],
        'pt3d_68' : pred_nn[i],
        'shapeparam' : shapeparam_nn[i],
    }

    offline_sample = {
        'image' : np.asarray(img),
        'coord' : coords_offline[i],
        'pose' : quats_offline[i],
        'pt3d_68' : pred_offline[i],
        'shapeparam' : shapeparam_offline[i],
    }
    
    vis.draw_points3d(
        img,
        ds[i]['pt2d_68'],
        labels=False
    )
    nn_vis = _visualize(nn_sample)
    nn_vis = cv2.putText(nn_vis, 'nn', (15,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255))
    offline_vis = _visualize(offline_sample)
    offline_vis = cv2.putText(offline_vis, 'offline', (15,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255))
    img = np.r_['1,1,1',img,nn_vis,offline_vis]
    
    if i in bad_sequences:
        linewidth = 2
        color = (255,0,0)
        cv2.line(img, (0,0), (img.shape[1],img.shape[0]), color, linewidth)
        cv2.line(img, (0,img.shape[0]), (img.shape[1],0), color, linewidth)
    
    update_label()

    image_widget.value = imencode(img,ImageFormat.PNG)

HBox(children=(Button(description='Prev', style=ButtonStyle()), IntSlider(value=0, max=7108), Button(descripti…

Image(value=b'')

In [137]:
# Note stopped at 101

In [143]:
pprint.pprint(bad_sequences,compact=True, width=140)
pprint.pprint(indices_use_nn_prediction,compact=True)

{1, 11, 12, 19, 21, 24, 28, 29, 32, 33, 34, 35, 44, 49, 51, 52, 53, 58, 59, 64, 65, 72, 73, 78, 82, 83, 86, 88, 89, 92, 97, 100, 107, 110,
 120, 122, 124, 126, 130, 131, 133, 144, 146, 147, 149, 154, 158, 159, 166, 167, 168, 170, 180, 181, 185, 186, 189, 190, 191, 192, 196, 197,
 198, 199, 200, 201, 202, 203, 204, 205, 222, 229, 237, 240, 249, 252, 255, 257, 258, 259, 260, 262, 263, 264, 266, 270, 272, 273, 276, 278,
 279, 280, 283, 285, 287, 289, 291, 292, 298, 300, 301, 307, 308, 309, 310, 311, 316, 325, 326, 330, 331, 332, 336, 341, 342, 345, 354, 355,
 359, 361, 362, 365, 369, 378, 383, 384, 387, 389, 390, 392, 393, 394, 396, 398, 399, 401, 403, 404, 405, 408, 409, 412, 413, 414, 418, 423,
 429, 430, 431, 438, 443, 444, 446, 450, 453, 461, 462, 472, 477, 478, 479, 486, 487, 488, 489, 490, 495, 497, 499, 503, 504, 505, 509, 511,
 512, 514, 521, 524, 525, 526, 531, 532, 533, 535, 539, 540, 544, 545, 548, 549, 550, 551, 557, 558, 564, 566, 573, 578, 579, 588, 589, 592,
 594, 595, 600,

In [148]:
good_indices = np.setdiff1d(
    np.nonzero(has_one_face & (~mask_small))[0],
    bad_sequences)

In [155]:
len(good_indices)

7109

In [160]:
good_indices_file = 'lapa_megaface_for_3d_rot_aug_v3.txt'
augmented_filename = os.path.join(os.environ['DATADIR'],'lapa-megaface-augmented.h5')
filename = os.path.join(os.environ['DATADIR'],'lapa-megaface.h5')
# Temporary data
destination = os.path.join('/tmp','lapa-megaface_augmented_w_offline_fits.h5')
filtered_destination = os.path.join('/tmp','lapa-megaface_augmented_good_fitted_faces.h5')

In [157]:
with open(os.path.join(os.environ['DATADIR'],good_indices_file),'w',encoding='utf-8') as f:
    f.write(','.join(map(str,good_indices)))

In [158]:
#with open(os.path.join(os.environ['DATADIR'],good_indices_file),'r',encoding='utf-8') as f:
#    good_indices = np.array([int(x) for x in f.read().split(',')])

In [161]:
with h5py.File(filename, 'r') as f_input, h5py.File(destination, 'w') as f_output:
    from_to = [
        ('images','images'),
        ('rois','rois'),
        ('2dfit_v2/quats','quats'),
        ('2dfit_v2/coords','coords'),
        ('2dfit_v2/pt3d_68','pt3d_68'),
        ('2dfit_v2/shapeparams','shapeparams'),
    ]
    for from_, to in from_to:
        f_input.copy(from_, f_output, to)
with h5py.File(destination, 'r') as f_output, h5py.File(filtered_destination, 'w') as f_filtered:
    filter_file_by_frames(f_output, f_filtered,good_frame_indices=good_indices)

In [162]:
def as_rotaug_sample(sample):
    fields = dict(sample)
    fields['rot'] = Rotation.from_quat(fields.pop('pose'))
    xys = fields.pop('coord')
    fields['xy'] = xys[:2]
    fields['scale'] = xys[2]
    fields['image'] = np.asarray(fields.pop('image'))
    return fields

In [163]:
rng = np.random.RandomState(seed=12345678)

In [164]:
augds = Hdf5PoseDataset(filtered_destination, transform=dtr.to_numpy, monochrome=False)

In [171]:
# TODO: fix it not updating the display
#visualizer = face3drotationaugmentation.SampleVisualizerWindow()

In [172]:
with face3drotationaugmentation.dataset_writer(augmented_filename) as writer:
    for i, sample in enumerate(tqdm.tqdm(map(as_rotaug_sample, augds), total=len(augds))):
        gen = face3drotationaugmentation.augment_sample(rng=rng, angle_step=5., prob_closed_eyes=0.5, prob_spotlight=0.001, sample=sample)
        name  = f'sample{i:02d}'
        del sample['index']
        writer.write(name, sample)
        for new_sample in gen:
            writer.write(name,new_sample)

100%|█████████████████████████████████████| 7109/7109 [3:04:56<00:00,  1.56s/it]
