In [1]:
import os
import json
import pickle
import numpy as np
from tqdm.auto import tqdm

In [2]:
def load_dance(path, source):
    """Load one dance sequence from disk.

    Args:
        path: Path of the file to read.
        source: Dataset identifier. ``'dancerevolution'`` reads a JSON file
            and returns its ``'dance_array'`` entry as a NumPy array;
            ``'aist++'`` reads a pickle file and returns the unpickled object
            unchanged (callers in this notebook call ``.reshape`` on it, so it
            is presumably an ndarray -- TODO confirm).

    Returns:
        The loaded dance data.

    Raises:
        ValueError: If ``source`` is not one of the supported identifiers
            (previously this silently returned ``None``).
    """
    if source == 'dancerevolution':
        with open(path) as f:
            raw_dict = json.load(f)
        # Original author's note: dance is a list with shape
        # 1800 (frames) * 274 (keypoints).
        return np.array(raw_dict['dance_array'])

    if source == 'aist++':
        # NOTE(review): pickle.load executes arbitrary code embedded in the
        # file; only use this on trusted, locally generated data.
        with open(path, 'rb') as f:
            return pickle.load(f)

    raise ValueError(
        "Unknown source {!r}; expected 'dancerevolution' or 'aist++'".format(source)
    )

def rescale(x, src):
    """Map a coordinate pair from the [-1, 1] range onto a w-by-h frame.

    Index 0 is transformed as ``(v + 1) * 0.5 * w`` and index 1 as
    ``(v + 1) * 0.5 * h``, so -1 maps to 0 and +1 maps to the full extent.
    The frame is 1280x720 when ``src == 'dancerevolution'`` and 1920x1080
    otherwise (presumably pixel resolutions of the source videos -- TODO
    confirm).

    Args:
        x: Array-like whose first two entries along axis 0 are the values to
            rescale.
        src: Dataset identifier selecting the frame size.

    Returns:
        A new floating-point ndarray holding the rescaled values. The input is
        left untouched: the previous implementation wrote the result back into
        ``x``, which silently modified the caller's array (NumPy indexing
        returns views) and made repeated calls rescale already-scaled values.
    """
    w, h = (1280, 720) if src == 'dancerevolution' else (1920, 1080)

    point = np.asarray(x)
    # Keep an existing float dtype; promote anything else so the scaled values
    # are not truncated when stored.
    if np.issubdtype(point.dtype, np.floating):
        target_dtype = point.dtype
    else:
        target_dtype = np.float64
    scaled = point.astype(target_dtype, copy=True)

    scaled[0] = (scaled[0] + 1) * 0.5 * w
    scaled[1] = (scaled[1] + 1) * 0.5 * h
    return scaled

def get_distance(x1, x2, src):
    """Return the Euclidean distance between two points after rescaling.

    Both points are passed through ``rescale`` (first ``x1``, then ``x2``)
    using the frame size selected by ``src``; the distance is then computed
    from their first two components.
    """
    p, q = rescale(x1, src), rescale(x2, src)

    dx = p[0] - q[0]
    dy = p[1] - q[1]
    return (dx**2 + dy**2)**0.5

def get_mae(x1, x2, src):
    """Return the mean absolute error between two rescaled points.

    Both inputs go through ``rescale`` (``x1`` first); the absolute
    differences are summed and divided by 2.0, i.e. averaged over the two
    coordinates of a point.
    """
    scaled_a = rescale(x1, src)
    scaled_b = rescale(x2, src)
    abs_diff = abs((scaled_b - scaled_a))
    return abs_diff.sum() / 2.0

### Calculate max/min/mean pixel shift between original and jittered sequences (Dance Revolution or AIST++, selected via `source`)

In [88]:
# Compare every joint of every non-empty frame between the original and the
# jittered sequences and report mean / max / min Euclidean distance.
origin_path = '/home/dingxi/AIST++/converted'
jitter_path = '/home/dingxi/AIST++/02sigma'
source = 'aist++'

num_node = 25 if source == 'dancerevolution' else 17
target_length = 1800 if source == 'dancerevolution' else 2878

# Start the minimum at +inf: a fixed sentinel such as 2000 is smaller than the
# largest distance possible on a 1920x1080 frame (~2203).
_min, _max, _mean, count = float('inf'), 0, 0, 0
file_list = os.listdir(origin_path)
for file in tqdm(file_list):
    origin_dance_path = os.path.join(origin_path, file)
    jitter_dance_path = os.path.join(jitter_path, file)

    origin_dance = load_dance(origin_dance_path, source)
    jitter_dance = load_dance(jitter_dance_path, source)

    # One row per frame, one (a, b) coordinate pair per joint.
    origin_dance = np.array(origin_dance).reshape(-1, num_node, 2)
    jitter_dance = np.array(jitter_dance).reshape(-1, num_node, 2)

    # Never index past the end of either sequence.
    num_frames = min(target_length, len(origin_dance), len(jitter_dance))
    for frame in range(num_frames):
        # The first all-zero frame in the original ends the sequence.
        if not origin_dance[frame].any():
            break
        for joint in range(num_node):
            distance = get_distance(origin_dance[frame][joint], jitter_dance[frame][joint], source)
            _min = min(distance, _min)
            _max = max(distance, _max)
            _mean += distance
            count += 1

if count == 0:
    raise ValueError('No non-empty frames found; cannot compute statistics')

_mean = _mean / count

print(_mean)
print(_max)
print(_min)

  0%|          | 0/1504 [00:00<?, ?it/s]

27.562004671205965
808.4138505872069
0.0


### Calculate MAE between original and jittered sequences for the AIST++ dataset

Steps:
1. For each point, calculate the MAE between clean(x,y) and noisy(x,y)
2. Repeat this for all points at all frames to get the average

In [3]:
def calculate_MAE_per_point(origin_path, jitter_path, source):
    """Print the average per-point MAE between original and jittered dances.

    For every file name found in ``origin_path`` the file of the same name is
    loaded from both directories, reshaped to ``(frames, num_node, 2)`` and
    compared joint by joint with ``get_mae``. Frames are processed until the
    first all-zero frame of the original sequence.

    Args:
        origin_path: Directory holding the original sequences.
        jitter_path: Directory holding the jittered sequences (same file names).
        source: Dataset identifier; ``'dancerevolution'`` uses 25 joints and at
            most 1800 frames, anything else 17 joints and at most 2878 frames.

    Returns:
        None. The mean over all compared points is printed.

    Raises:
        ValueError: If no non-empty frame was found, so no mean can be formed.
    """
    num_node = 25 if source == 'dancerevolution' else 17
    target_length = 1800 if source == 'dancerevolution' else 2878
    mean = 0
    count = 0

    file_list = os.listdir(origin_path)
    for file in tqdm(file_list):
        origin_dance_path = os.path.join(origin_path, file)
        jitter_dance_path = os.path.join(jitter_path, file)

        origin_dance = load_dance(origin_dance_path, source)
        jitter_dance = load_dance(jitter_dance_path, source)

        origin_dance = origin_dance.reshape(-1, num_node, 2)
        jitter_dance = jitter_dance.reshape(-1, num_node, 2)

        # Never index past the end of either sequence.
        num_frames = min(target_length, len(origin_dance), len(jitter_dance))
        for frame in range(num_frames):
            # The first all-zero frame in the original ends the sequence.
            if not origin_dance[frame].any():
                break
            for joint in range(num_node):
                x1 = origin_dance[frame][joint]
                x2 = jitter_dance[frame][joint]
                mean += get_mae(x1, x2, source)
                count += 1

    if count == 0:
        raise ValueError('No non-empty frames found; cannot compute MAE')

    print(mean/count)

In [8]:
# Per-point MAE of the '03discard_w30_d15/bcurve' variant against the originals.
origin_path = '/home/dingxi/AIST++/converted'
jitter_path = '/home/dingxi/AIST++/03discard_w30_d15/bcurve'
calculate_MAE_per_point(
    origin_path=origin_path,
    jitter_path=jitter_path,
    source='aist++',
)

  0%|          | 0/1504 [00:00<?, ?it/s]

7.54476575251309


In [5]:
# Per-point MAE of the '04discard/linear' variant against the originals.
# NOTE(review): this repeats the body of calculate_MAE_per_point for 'aist++';
# consider calling that function instead.
origin_path = '/home/dingxi/AIST++/converted'
jitter_path = '/home/dingxi/AIST++/04discard/linear'
source = 'aist++'
mean = 0
count = 0

file_list = os.listdir(origin_path)
for file in tqdm(file_list):
    origin_dance_path = os.path.join(origin_path, file)
    jitter_dance_path = os.path.join(jitter_path, file)

    origin_dance = load_dance(origin_dance_path, source)
    jitter_dance = load_dance(jitter_dance_path, source)

    origin_dance = origin_dance.reshape(-1, 17, 2)
    jitter_dance = jitter_dance.reshape(-1, 17, 2)

    # Never index past the end of either sequence.
    num_frames = min(2878, len(origin_dance), len(jitter_dance))
    for frame in range(num_frames):
        # The first all-zero frame in the original ends the sequence.
        if not origin_dance[frame].any():
            break
        for joint in range(17):
            x1 = origin_dance[frame][joint]
            x2 = jitter_dance[frame][joint]
            mean += get_mae(x1, x2, source)
            count += 1

if count == 0:
    raise ValueError('No non-empty frames found; cannot compute MAE')

print(mean/count)

  0%|          | 0/1504 [00:00<?, ?it/s]

1.292664935177449


In [5]:
from dataset_holder import DanceRevolutionHolder
from dataset import DanceRevolutionDataset

# Per-point MAE where the jittered sequence comes from the dataset class
# (configured with data_in='bcurve') rather than being read directly from disk.
origin_path = '/home/dingxi/AIST++/converted'
jitter_path = '/home/dingxi/AIST++/04discard/bcurve'
source = 'aist++'
mean = 0
count = 0

file_list = os.listdir(origin_path)
holder = DanceRevolutionHolder(jitter_path, 'train', source='aist++', file_list=file_list, train_interval=2878)
dataset = DanceRevolutionDataset(holder, data_in='bcurve', bez_degree=15, window=45)

for i in tqdm(range(len(file_list))):
    data = dataset[i]
    # The last element of a sample carries the name of the file it came from.
    filename = data[-1]['filename']
    origin_dance_path = os.path.join(origin_path, filename)

    origin_dance = load_dance(origin_dance_path, source).reshape(-1, 17, 2)

    # Presumably (2, frames, 17, 1) -> (frames, 17, 2), judging by the shapes in
    # the recorded error below -- TODO confirm against the dataset class.
    jitter_dance = data[0].squeeze().transpose((1, 2, 0))

    for frame in range(2878):
        # The first all-zero frame in the original ends the sequence.
        if not origin_dance[frame].any():
            break
        for joint in range(17):
            mean += get_mae(origin_dance[frame][joint], jitter_dance[frame][joint], source)
            count += 1
print(mean/count)

ValueError: could not broadcast input array from shape (2,2701,17,1) into shape (2,2669,17,1)

### Find the actual length of the Bezier-curve-interpolated dance

In [8]:
# For one file, print the index of the first all-zero frame in the original
# sequence and then in the sequence produced by the dataset class.
origin_path = '/home/dingxi/AIST++/converted'
jitter_path = '/home/dingxi/AIST++/04discard/bcurve'
source = 'aist++'
mean = 0
count = 0

file = 'gKR_sBM_c1_d29_mKR5_ch01.pkl'

holder = DanceRevolutionHolder(jitter_path, 'train', source='aist++', file_list=[file], train_interval=2878)
dataset = DanceRevolutionDataset(holder, data_in='bcurve', bez_degree=15, window=45)

origin_dance_path = os.path.join(origin_path, file)
origin_dance = load_dance(origin_dance_path, source).reshape(-1, 17, 2)

data = dataset[0]
jitter_dance = data[0].squeeze().transpose((1, 2, 0))

# Same scan for both arrays: original first, then the dataset output.
for dance in (origin_dance, jitter_dance):
    for frame in range(2878):
        if not dance[frame].any():
            print(frame)
            break

443
476
