## Visualize Data

In [46]:
from pathlib import Path
import numpy as np

# Load one raw clip from disk; its layout and value range are summarised by inspect_data below.
data_path = Path("data/mariel_chunli.npy")
data_raw = np.load(data_path)

def inspect_data(data):
    """Print summary statistics of an array, one ``label value`` pair per line.

    Reports, in this order: shape, dtype, max, min, the 25th/50th/75th
    percentiles (labelled ``1q``/``2q``/``3q``), mean and standard deviation.
    Nothing is returned.
    """
    # Each statistic is wrapped in a callable so it is computed right before it
    # is printed, keeping the output order tied to the evaluation order.
    report = (
        ("shape", lambda: data.shape),
        ("dtype", lambda: data.dtype),
        ("max", lambda: data.max()),
        ("min", lambda: data.min()),
        ("1q", lambda: np.percentile(data, 25)),
        ("2q", lambda: np.percentile(data, 50)),
        ("3q", lambda: np.percentile(data, 75)),
        ("mean", lambda: data.mean()),
        ("std", lambda: data.std()),
    )
    for label, compute in report:
        print(label, compute())

# Summarise the raw clip (shape, dtype, value range, quartiles, mean, std) before any preprocessing.
inspect_data(data_raw)

shape (55, 4866, 3)
dtype float64
max 1.9273922443389893
min -6.332875728607178
1q -1.006934016942978
2q -0.4554957449436188
3q -0.042966997250914574
mean -0.6010307200624114
std 0.8776853231783209


In [47]:
# Taken from https://github.com/mariel-pettee/choreo-graph/blob/main/functions/load_data.py

# Labels of all point groups; preprocess_data treats index i in this list as row i of a raw array.
ALL_POINT_LABELS = ['ARIEL', 'C7', 'CLAV', 'LANK', 'LBHD', 'LBSH', 'LBWT', 'LELB', 'LFHD', 'LFRM', 'LFSH', 'LFWT', 'LHEL', 'LIEL', 'LIHAND', 'LIWR', 'LKNE', 'LKNI', 'LMT1', 'LMT5', 'LOHAND', 'LOWR', 'LSHN', 'LTHI', 'LTOE', 'LUPA', 'LabelingHips', 'MBWT', 'MFWT', 'RANK', 'RBHD', 'RBSH', 'RBWT', 'RELB', 'RFHD', 'RFRM', 'RFSH', 'RFWT', 'RHEL', 'RIEL', 'RIHAND', 'RIWR', 'RKNE', 'RKNI', 'RMT1', 'RMT5', 'ROHAND', 'ROWR', 'RSHN', 'RTHI', 'RTOE', 'RUPA', 'STRN', 'SolvingHips', 'T10']    
# Labels whose rows are removed from the data during preprocessing.
BAD_LABELS = ['SolvingHips', 'LabelingHips']
# Labels that remain once BAD_LABELS are removed, in their original order.
POINT_LABELS = [label for label in ALL_POINT_LABELS if label not in BAD_LABELS]
# Number of usable points per pose; also sizes the network's input and output layers.
NUM_GROUPS = len(POINT_LABELS)

# Pairs of point labels that are joined by a line segment when the skeleton is drawn.
# Every label used here must be present in POINT_LABELS, since the pairs are
# resolved to indices through POINT_IDXS.
skeleton_lines = [
#     (start label, end label),
    ('LHEL', 'LTOE',), # toe to heel
    ('RHEL', 'RTOE',),
    ('LMT1', 'LMT5',), # horizontal line across foot
    ('RMT1', 'RMT5',),   
    ('LHEL', 'LMT1',), # heel to sides of feet
    ('LHEL', 'LMT5',),
    ('RHEL', 'RMT1',),
    ('RHEL', 'RMT5',),
    ('LTOE', 'LMT1',), # toe to sides of feet
    ('LTOE', 'LMT5',),
    ('RTOE', 'RMT1',),
    ('RTOE', 'RMT5',),
    ('LKNE', 'LHEL',), # heel to knee
    ('RKNE', 'RHEL',),
    ('LFWT', 'RBWT',), # connect pelvis
    ('RFWT', 'LBWT',), 
    ('LFWT', 'RFWT',), 
    ('LBWT', 'RBWT',),
    ('LFWT', 'LBWT',), 
    ('RFWT', 'RBWT',), 
    ('LFWT', 'LTHI',), # pelvis to thighs
    ('RFWT', 'RTHI',), 
    ('LBWT', 'LTHI',), 
    ('RBWT', 'RTHI',), 
    ('LKNE', 'LTHI',), 
    ('RKNE', 'RTHI',), 
    ('CLAV', 'LFSH',), # clavicle to shoulders
    ('CLAV', 'RFSH',), 
    ('STRN', 'LFSH',), # sternum & T10 (back sternum) to shoulders
    ('STRN', 'RFSH',), 
    ('T10', 'LFSH',), 
    ('T10', 'RFSH',), 
    ('C7', 'LBSH',), # back clavicle to back shoulders
    ('C7', 'RBSH',), 
    ('LFSH', 'LBSH',), # front shoulders to back shoulders
    ('RFSH', 'RBSH',), 
    ('LFSH', 'RBSH',),
    ('RFSH', 'LBSH',),
    ('LFSH', 'LUPA',), # shoulders to upper arms
    ('RFSH', 'RUPA',), 
    ('LBSH', 'LUPA',), 
    ('RBSH', 'RUPA',), 
    ('LIWR', 'LIHAND',), # wrist to hand
    ('RIWR', 'RIHAND',),
    ('LOWR', 'LOHAND',), 
    ('ROWR', 'ROHAND',),
    ('LIWR', 'LOWR',), # across the wrist 
    ('RIWR', 'ROWR',), 
    ('LIHAND', 'LOHAND',), # across the palm 
    ('RIHAND', 'ROHAND',), 
    ('LFHD', 'LBHD',), # draw lines around circumference of the head
    ('LBHD', 'RBHD',),
    ('RBHD', 'RFHD',),
    ('RFHD', 'LFHD',),
    ('LFHD', 'ARIEL'), # connect circumference points to top of head
    ('LBHD', 'ARIEL'),
    ('RBHD', 'ARIEL'),
    ('RFHD', 'ARIEL'),
]

# Position of every kept label inside POINT_LABELS.
POINT_IDXS = dict(zip(POINT_LABELS, range(len(POINT_LABELS))))
# Skeleton segments as (start_index, end_index) rows, resolved from their labels.
EDGES = np.array([(POINT_IDXS[src], POINT_IDXS[dst]) for src, dst in skeleton_lines])
EDGES.shape

(58, 2)

In [115]:
# Some frames are identical to the frame before them. For every clip, record the
# first frame index at which *every* value differs from the previous frame;
# get_data_with_pattern drops all frames before that index.

START_IDX = {}

for fn in ["betternot_and_retrograde", "beyond", "chunli", "honey", "knownbetter", "penelope"]:
    p = Path(f"data/mariel_{fn}.npy")
    data_i = np.load(p).swapaxes(0, 1)  # put the frame axis first

    # One flag per frame transition: True where all values changed at once.
    per_frame_axes = tuple(range(1, data_i.ndim))
    all_changed = (data_i[1:] != data_i[:-1]).all(axis=per_frame_axes)
    hits = np.flatnonzero(all_changed)

    if hits.size:
        # Transition k compares frame k + 1 with frame k.
        START_IDX[fn] = int(hits[0]) + 1
        print(fn, START_IDX[fn])

betternot_and_retrograde 159
beyond 97
chunli 188
honey 124
knownbetter 284
penelope 150


The raw array has three dimensions.
The first dimension is the point (marker) group.
The second is the frame index within the clip.
The third is the XYZ coordinates in 3D space.

We know the point groups from [Pettee's previous project](https://github.com/mariel-pettee/choreography/blob/master/functions/functions.py), and we also know that two of them, `LabelingHips` and `SolvingHips` (the 27th and 54th groups, i.e. 0-based indices 26 and 53), are bad point groups.

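We mask those out and swap the first two axes so that frames come first, giving an array of shape `(frames, points, 3)`. Additionally, we rescale the data to the range -1 to 1 and invert the z axis.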

In [49]:
def preprocess_data(data, normalize=True):
    """Drop the bad point groups, put frames first and flip the z axis.

    Args:
        data: raw array indexed as ``data[group, frame, xyz]``; row ``i`` is
            assumed to belong to ``ALL_POINT_LABELS[i]``.
        normalize: when true, linearly rescale all values so that the global
            minimum maps to -1 and the global maximum maps to 1.

    Returns:
        A new array indexed as ``[frame, group, xyz]`` without the rows listed
        in ``BAD_LABELS`` and with the third coordinate negated. The input
        array is not modified.
    """
    dropped_rows = [row for row, label in enumerate(ALL_POINT_LABELS) if label in BAD_LABELS]
    # np.delete returns a copy, so the in-place z flip below never touches `data`.
    kept = np.delete(data, dropped_rows, axis=0)

    frames = kept.swapaxes(0, 1)

    if normalize:
        # A single global min/max keeps the three axes on the same scale.
        lo = frames.min()
        hi = frames.max()
        frames = (frames - lo) / (hi - lo) * 2 - 1

    frames[:, :, 2] *= -1        # invert z axis

    return frames

# Preprocess the raw clip and print the summary again to check the new layout and value range.
data = preprocess_data(data_raw, normalize=True)
inspect_data(data)

shape (4866, 53, 3)
dtype float64
max 1.0
min -1.0
1q -0.2738534668068461
2q 0.3330744535997223
3q 0.5042046208517127
mean 0.15681260326528199
std 0.4136803405246239


In [68]:
%matplotlib QtAgg

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D

class FigureAnimation:
    """Animated 3D plot of a pose sequence: points as a scatter plot, skeleton as line segments.

    ``data`` is indexed as ``data[frame, point, xyz]``. The segments drawn
    between points are taken from the module-level ``EDGES`` index pairs.
    """

    def __init__(self, data):
        """Create the figure and a 3D axis whose limits span the whole sequence."""
        self.data = data

        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(projection='3d')

        # Set each axis' limits to the min/max of that coordinate over all frames and points.
        limit_setters = (self.ax.set_xlim, self.ax.set_ylim, self.ax.set_zlim)
        for axis, set_limits in enumerate(limit_setters):
            coords = data[:, :, axis]
            set_limits(coords.min(), coords.max())

        # The artists are created by the animation's init function in start().
        self.scatter_plot = None
        self.lineplots = None

    def start(self, start_frame: int = 0, end_frame: int = -1, framerate: int = 20):
        """Play frames ``start_frame`` up to (not including) ``end_frame``.

        Args:
            start_frame: index of the first frame to show.
            end_frame: exclusive end index; ``-1`` means "through the last frame".
            framerate: frames per second; passed to FuncAnimation as an
                interval of ``1000 / framerate``.

        The animation's event source is stopped once frame ``end_frame - 1``
        has been drawn. The animation object is kept on ``self.ani``.
        """
        if end_frame == -1:
            end_frame = self.data.shape[0]

        def edge_segments(pose):
            # Array of shape (n_edges, 2, xyz): both end points of every skeleton segment.
            return np.stack([pose[EDGES[:, 0]], pose[EDGES[:, 1]]], axis=1)

        def setup():
            pose = self.data[start_frame]
            self.scatter_plot = self.ax.scatter(pose[:, 0], pose[:, 1], pose[:, 2])

            drawn = []
            for segment in edge_segments(pose):
                # plot() returns a list holding the single Line3D it created.
                drawn.append(self.ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], color='black')[0])
            self.lineplots = drawn

            return (self.scatter_plot, *self.lineplots)

        def update(frame):
            pose = self.data[int(frame)]
            # Move the existing scatter points instead of creating a new artist.
            self.scatter_plot._offsets3d = (pose[:, 0], pose[:, 1], pose[:, 2])

            for lineplot, segment in zip(self.lineplots, edge_segments(pose)):
                lineplot.set_data(segment[:, 0], segment[:, 1])
                lineplot.set_3d_properties(segment[:, 2])

            self.fig.canvas.draw_idle()

            # Halt the timer after the last requested frame.
            if frame == end_frame - 1:
                self.ani.event_source.stop()

            return (self.scatter_plot, *self.lineplots)

        self.ani = FuncAnimation(
            self.fig,
            update,
            init_func=setup,
            frames=range(start_frame, end_frame),
            interval=1000/framerate,
            blit=False,
            cache_frame_data=False,
        )
        self.fig.show()


#fig_anim = FigureAnimation(data)
#fig_anim.start(start_frame=150, framerate=35)

In [51]:
def split_data(data, window_size, stride):
    """Cut a frame sequence into fixed-length, possibly overlapping windows.

    Args:
        data: array of shape (n_frames, n_points, n_channels).
        window_size: number of consecutive frames in each window; must be > 0.
        stride: offset in frames between the starts of consecutive windows;
            must be > 0.

    Returns:
        A new float64 array of shape
        (n_windows, window_size, n_points, n_channels), where
        ``n_windows = (n_frames - window_size) // stride + 1``. Window ``i``
        holds frames ``i * stride`` to ``i * stride + window_size - 1``;
        trailing frames that do not fill a whole window are dropped. A
        sequence shorter than ``window_size`` yields zero windows.

    Raises:
        ValueError: if ``window_size`` or ``stride`` is not positive.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            f"window_size and stride must be positive, got window_size={window_size}, stride={stride}"
        )

    data = np.asarray(data)
    n_frames = data.shape[0]

    # For sequences shorter than one window the formula goes to zero or below;
    # clamp it so the result is consistently an empty batch instead of an
    # error from a negative array dimension.
    n_windows = max(0, (n_frames - window_size) // stride + 1)

    # frame_idx[i, j] is the source frame of position j in window i.
    starts = np.arange(n_windows) * stride
    frame_idx = starts[:, None] + np.arange(window_size)[None, :]

    # Integer-array indexing copies, so the windows never alias `data`.
    return data[frame_idx].astype(np.float64, copy=False)

# Sanity check: 4866 frames with window 30 and stride 5 give (4866 - 30) // 5 + 1 = 968 windows.
data.shape, split_data(data, window_size=30, stride=5).shape

((4866, 53, 3), (968, 30, 53, 3))

In [None]:
def get_data_with_pattern(data_dir: Path, filenamn_pattern: str = "mariel", window_size: int = 30, stride: int = 5, normalize: bool = True):
    """Load every matching clip, preprocess it and stack all of its windows.

    Args:
        data_dir: directory that is searched (non-recursively) for clips.
        filenamn_pattern: glob pattern relative to ``data_dir``. Note that the
            default contains no wildcard and therefore only matches an entry
            named exactly ``mariel``; pass e.g. ``"mariel_*.npy"``.
        window_size: frames per window, forwarded to ``split_data``.
        stride: frame offset between windows, forwarded to ``split_data``.
        normalize: forwarded to ``preprocess_data``; normalisation is applied
            per file.

    Returns:
        Array of shape (total_windows, window_size, n_points, 3) holding the
        windows of all matching files, ordered by sorted file path.

    Raises:
        ValueError: if no file matches or no file is long enough to produce
            a single window.
    """
    prefix = "mariel_"
    all_data = []

    # Path.glob yields entries in arbitrary order; sorting keeps the window
    # order reproducible between runs and machines.
    for data_path in sorted(data_dir.glob(filenamn_pattern)):
        data = preprocess_data(np.load(data_path), normalize=normalize)

        # START_IDX is keyed by the clip name without the "mariel_" prefix.
        stem = data_path.stem
        name = stem[len(prefix):] if stem.startswith(prefix) else stem
        data = data[START_IDX.get(name, 0):]

        # A clip shorter than one window contributes nothing.
        if data.shape[0] < window_size:
            continue

        all_data.append(split_data(data, window_size, stride))

    if not all_data:
        raise ValueError(
            f"no windows produced: pattern {filenamn_pattern!r} in {data_dir} "
            f"matched no file with at least {window_size} usable frames"
        )

    return np.concatenate(all_data, axis=0)

# Build the windowed dataset from every file in data/ whose name matches "mariel_better*".
data_dir = Path("data")
all_data = get_data_with_pattern(data_dir, filenamn_pattern="mariel_better*", window_size=30, stride=5, normalize=True)
all_data.shape

(2148, 30, 53, 3)

In [119]:
# Play back a single window (index 3) of the windowed dataset at 60 frames per second.
fig_anim = FigureAnimation(all_data[3])
fig_anim.start(framerate=60)

## Network

In [53]:
# Pose Embedding

from torch import nn

class PoseEncoder(nn.Module):
    """MLP that maps a single pose to an embedding vector.

    Input:  (..., NUM_GROUPS, 3) poses, or already flattened
            (..., NUM_GROUPS * 3) vectors.
    Output: (..., embedding_dim)
    """

    def __init__(self, embedding_dim):
        super().__init__()

        # Width expected by the first linear layer: one value per coordinate of every point.
        self.input_dim = NUM_GROUPS * 3

        self.embedding = nn.Sequential(
            nn.Linear(self.input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, embedding_dim)
        )

    def forward(self, x):
        """Embed ``x``; the trailing (point, xyz) axes are flattened when present."""
        # nn.Linear acts on the last axis only, so a (..., NUM_GROUPS, 3) pose
        # has to be merged into (..., NUM_GROUPS * 3) first. Input that is
        # already flat is passed through untouched.
        if x.shape[-1] != self.input_dim:
            x = x.flatten(start_dim=-2)
        return self.embedding(x)


class PoseDecoder(nn.Module):
    """MLP that maps an embedding vector back to a pose.

    Input:  (N, embedding_dim)
    Output: (N, NUM_GROUPS, 3), squashed into [-1, 1] by the final Tanh.
    """

    def __init__(self, embedding_dim):
        super().__init__()

        # Hidden stack: Linear + ReLU for every consecutive pair of widths.
        widths = (embedding_dim, 256, 512)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())

        # Output head: one value per coordinate of every point.
        layers.append(nn.Linear(widths[-1], NUM_GROUPS * 3))
        layers.append(nn.Tanh())

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` and reshape the flat output into (N, NUM_GROUPS, 3)."""
        flat_pose = self.decoder(x)
        return flat_pose.view(-1, NUM_GROUPS, 3)

class PoseAutoencoder(nn.Module):
    """Autoencoder built from a PoseEncoder and a PoseDecoder sharing one embedding size."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.encoder = PoseEncoder(embedding_dim)
        self.decoder = PoseDecoder(embedding_dim)

    def forward(self, x):
        """Encode ``x`` to its embedding, then decode the embedding again."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction
