In [118]:
import json
import os
from itertools import chain, product
from typing import Callable, Dict, List, Tuple

import numpy as np
import torch

from src.models.modules.shape_embedding import ShapeEmbedding

In [120]:
datasets = os.listdir("./data/")

In [122]:
keypoints = {
    "NOSE": [0],
    "LEFT_EYE": [5],
    "RIGHT_EYE": [5],
    "LEFT_EAR": [8],
    "RIGHT_EAR": [7],
    "CHEST": [11, 12],
    "LEFT_SHOULDER": [12],
    "RIGHT_SHOULDER": [11],
    "LEFT_ELBOW": [14],
    "RIGHT_ELBOW": [13],
    "LEFT_HAND": [16, 18, 20, 22],
    "RIGHT_HAND": [15, 17, 19, 21],
    "LEFT_HIP": [24],
    "RIGHT_HIP": [23],
    "LEFT_KNEE": [26],
    "RIGHT_KNEE": [25],
    "LEFT_FOOT": [28, 30, 32],
    "RIGHT_FOOT": [27, 29, 31]
}

positions_map = dict(NOSE=0,
                     CHEST=1,
                     LEFT_SHOULDER=2,
                     LEFT_ELBOW=3,
                     LEFT_HAND=4,
                     RIGHT_SHOULDER=5,
                     RIGHT_ELBOW=6,
                     RIGHT_HAND=7,
                     LEFT_HIP=8,
                     LEFT_KNEE=9,
                     LEFT_FOOT=10,
                     RIGHT_HIP=11,
                     RIGHT_KNEE=12,
                     RIGHT_FOOT=13,
                     LEFT_EYE=14,
                     RIGHT_EYE=15,
                     LEFT_EAR=16,
                     RIGHT_EAR=17)


In [124]:
def get_features(pose: List[Dict]) -> torch.Tensor:
    list_pose = list(map(lambda s: list(s.values()), pose))
    return torch.Tensor(list_pose)


def query_and_sort_in_position(pose: List) -> torch.Tensor:
    out = {}
    re_align_positions = []

    for k, v in keypoints.items():
        out[k] = pose[v].mean(dim=0)

    for k in positions_map.keys():
        re_align_positions.append(out[k].unsqueeze(0))
    re_align_positions = torch.cat(re_align_positions)
    return re_align_positions.numpy().tolist()


def assign_new_data(s: Dict) -> Dict:
    pose: List[List[Dict]] = s['pose_landmarks'][0]
    pose_world: List[List[Dict]] = s['pose_world_landmarks']
    pose = get_features(pose=pose)
    pose = query_and_sort_in_position(pose=pose)
    pose_world = get_features(pose=pose_world)
    pose_world = query_and_sort_in_position(pose=pose_world)
    s['pose_landmarks'] = pose
    s['pose_world_landmarks'] = pose_world
    return s

In [128]:
for dataset in datasets:
    for file in ['gallery.json', 'query.json', 'train.json', 'test.json']:
        try:
            with open(f"./data/{dataset}/jsons/{file}") as f:
                content = json.load(f)
                content = list(
                    filter(lambda s: s.get("pose_landmarks", None) is not None,
                        content))
                list(map(assign_new_data, content))
            with open(f"./data/{dataset}/jsons/{file}", "w") as f:
                json.dump(obj=content, fp=f, indent=2, ensure_ascii=False)
        except:
            print(dataset)
            print(file)

ltcc
gallery.json
ltcc
train.json
cuhk03
train.json
cuhk03
test.json
real28
train.json
real28
test.json
vc-clothes
train.json
vc-clothes
test.json
prcc
gallery.json
prcc
query.json
prcc
train.json
market1501
gallery.json
market1501
query.json
market1501
train.json
market1501
test.json
