In [None]:
# default_exp core

# core

> Unpack and load the [AMASS][] dataset for training with a PyTorch iterator.

To do:

1. Saw `AMASS` and `global_index_map` break with relative paths, add test and make fix for this case
2. Add options to pass `keep` index to `AMASS`

[amass]: https://amass.is.tue.mpg.de/

In [None]:
# hide
from nbdev.showdoc import *

# Unpack Tar Files

> Console script to unpack all tar files found in a specified directory and put them in another directory, then create a symlink to be able to find the unpacked data later

In [None]:
# export
import argparse
import os
from shutil import unpack_archive
import joblib
from tqdm.auto import tqdm


class ProgressParallel(joblib.Parallel):
    """
    `joblib.Parallel` that shows a `tqdm` progress bar while the jobs run.

    Call it like `joblib.Parallel`, optionally passing the extra keyword
    `total` (number of tasks) so the bar is sized before any task is dispatched.
    """

    def __call__(self, *args, **kwargs):
        # `total` only concerns the progress bar, so it is removed before
        # delegating. It is optional: without it the bar starts unsized
        # instead of failing with a KeyError.
        total = kwargs.pop("total", None)
        with tqdm(total=total) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self):
        # Mirror the task counters kept by the parent class onto the bar.
        # NOTE(review): relies on joblib invoking `print_progress` as tasks
        # complete - confirm against the installed joblib version.
        self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()


def unpack_body_models(tardir, outdir, n_jobs=1):
    """
    Extract every tar archive found directly inside `tardir` into `outdir`.

    args:
        - `tardir`: directory holding the archives (subdirectories are not searched)
        - `outdir`: directory the archives are extracted into
        - `n_jobs`: number of parallel extraction jobs
    """
    # Only the first entry of the walk (the top level of `tardir`) is used.
    top_level = list(os.walk(tardir))[0]
    root, names = top_level[0], top_level[2]

    # A file counts as an archive when one of its dot-separated parts is "tar".
    archive_paths = []
    for name in names:
        if "tar" in name.split("."):
            archive_paths.append(os.path.join(root, name))

    for archive_path in archive_paths:
        print(f"{archive_path} extracting to {outdir}")

    jobs = (
        joblib.delayed(unpack_archive)(archive_path, outdir)
        for archive_path in archive_paths
    )
    ProgressParallel(n_jobs=n_jobs)(jobs, total=len(archive_paths))


def fast_amass_unpack():
    """
    Console entry point: parse the command line and unpack the archives.

    Positional arguments `tardir` and `outdir`, plus the optional `-n` job
    count, are forwarded to `unpack_body_models`.
    """
    cli = argparse.ArgumentParser(
        description="Unpack all the body model tar files in a directory to a target directory"
    )
    positionals = (
        ("tardir", "Directory containing tar.bz2 body model files"),
        ("outdir", "Output directory"),
    )
    for dest, help_text in positionals:
        cli.add_argument(dest, type=str, help=help_text)
    cli.add_argument(
        "-n",
        default=1,
        type=int,
        help="Number of jobs to run the tar unpacking with",
    )
    options = cli.parse_args()
    unpack_body_models(options.tardir, options.outdir, n_jobs=options.n)

Test unpacking the sample data always yields the same result:

In [None]:
import tempfile
import hashlib

# https://stackoverflow.com/a/3431838/6937913
def md5(fname):
    """Return the hexadecimal MD5 digest of the file at `fname`."""
    digest = hashlib.md5()
    with open(fname, "rb") as stream:
        # Read in 4 KiB pieces so large files are never held in memory whole.
        while True:
            block = stream.read(4096)
            if block == b"":
                break
            digest.update(block)
    return digest.hexdigest()


# Expected digests of the files contained in the bundled sample archive.
md5sums = {
    "amass_sample.npz": "d0b546b3619c8579ade39e3a8ccdc4e2",
    "dmpl_sample.npz": "576bb76b2a6328dc5c276c4150c466f0",
}

with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    # Gather the archives from every directory of the walk; the root yielded
    # by os.walk already starts with `tmpdirname`, so it is joined as-is.
    npz_paths = [
        os.path.join(r, x)
        for r, d, f in os.walk(tmpdirname)
        for x in f
        if "npz" in x.split(".")
    ]
    _md5sums = {os.path.split(fpath)[-1]: md5(fpath) for fpath in npz_paths}

for k in md5sums:
    assert md5sums[k] == _md5sums[k]

sample_data/sample.tar.bz2 extracting to /tmp/tmp7wtjmhdp


  0%|          | 0/1 [00:00<?, ?it/s]

# Loading Functions

> Load the pose data directly from the `npz` files after unpacking.

Based on the [AMASS tutorial notebooks][amass], I would like to iterate over the dataset using a PyTorch Dataloader.

Steps to load:

1. Enumerate the paths to all the `npz` files
2. Inspect files for frame data
3. Map from a global dataset index to indexes for each frame clip

[amass]: https://github.com/nghorbani/amass/tree/master/notebooks

In [None]:
# export
import math
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

Looking at the sample data:

In [None]:
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    # Gather the archives from every directory of the walk; the root yielded
    # by os.walk already starts with `tmpdirname`, so it is joined as-is.
    npz_paths = sorted(
        os.path.join(r, x)
        for r, d, f in os.walk(tmpdirname)
        for x in f
        if "npz" in x.split(".")
    )
    for npz_path in npz_paths:
        # The context manager closes the archive once it has been inspected.
        with np.load(npz_path) as cdata:
            print(npz_path)
            print("  ", [k for k in cdata.keys()])
            print("  ", [(k, cdata[k].shape) for k in cdata.keys()])

sample_data/sample.tar.bz2 extracting to /tmp/tmpsd5dr57v


  0%|          | 0/1 [00:00<?, ?it/s]

/tmp/tmpsd5dr57v/sample/subdir/amass_sample.npz
   ['poses', 'gender', 'mocap_framerate', 'betas', 'marker_data', 'dmpls', 'marker_labels', 'trans']
   [('poses', (601, 156)), ('gender', ()), ('mocap_framerate', ()), ('betas', (16,)), ('marker_data', (601, 85, 3)), ('dmpls', (601, 8)), ('marker_labels', (85,)), ('trans', (601, 3))]
/tmp/tmpsd5dr57v/sample/subdir/dmpl_sample.npz
   ['poses', 'gender', 'mocap_framerate', 'betas', 'marker_data', 'dmpls', 'marker_labels', 'trans']
   [('poses', (235, 156)), ('gender', ()), ('mocap_framerate', ()), ('betas', (16,)), ('marker_data', (235, 67, 3)), ('dmpls', (235, 8)), ('marker_labels', (67,)), ('trans', (235, 3))]


## Viable Indexes

For every `npz` file I need to pull out the viable indexes:

In [None]:
# export
def viable_slice(cdata, keep):
    """
    Compute the slice of frames to use from one loaded `.npz` file.

    An equal share of frames is dropped from the start and from the end so
    that (approximately) the central `keep` proportion remains.

    args:

        - `cdata`: mapping loaded from an `.npz` dump; only `cdata['poses']`
            is read here, and its first dimension is taken as the frame count
        - `keep`: proportion of the file to keep, greater than zero and at
            most one

    returns:

        - slice of frame indexes, computed from the length of `cdata['poses']`
    """
    assert (
        0.0 < keep <= 1.0
    ), "Proportion of array to keep must be between zero and one"
    n_frames = cdata["poses"].shape[0]
    # Half of the discarded proportion comes off each end of the recording.
    margin = (1.0 - keep) / 2.0
    first = int(n_frames * margin)
    last = int(n_frames * keep + n_frames * margin)
    return slice(first, last)

In [None]:
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    # Gather the archives from every directory of the walk; the root yielded
    # by os.walk already starts with `tmpdirname`, so it is joined as-is.
    npz_paths = sorted(
        os.path.join(r, x)
        for r, d, f in os.walk(tmpdirname)
        for x in f
        if "npz" in x.split(".")
    )
    for npz_path in npz_paths:
        # The context manager closes the archive once the slice is computed.
        with np.load(npz_path) as cdata:
            print(npz_path)
            print("  ", viable_slice(cdata, 0.8))

sample_data/sample.tar.bz2 extracting to /tmp/tmp17pnar6l


  0%|          | 0/1 [00:00<?, ?it/s]

/tmp/tmp17pnar6l/sample/subdir/amass_sample.npz
   slice(60, 540, None)
/tmp/tmp17pnar6l/sample/subdir/dmpl_sample.npz
   slice(23, 211, None)


## Map Global Index to File Indexes

I need to be able to map from a global dataset index to contiguous sets of frames in each file. Options:

1. Whether the contiguous frames can overlap in different samples

In [None]:
# export
def global_index_map(npz_directory, overlapping, clip_length, keep=0.8):
    """
    Build a lookup from a global example index to one clip in one `.npz` file.

    args:
        - `npz_directory`: Directory searched recursively for `.npz` files
        - `overlapping`: Whether clips can overlap
        - `clip_length`: Number of consecutive frames in every clip (>= 1)
        - `keep`: Proportion of each file to keep, forwarded to `viable_slice`
    returns:
        - map from global index to corresponding file and array indexes
        - total number of examples in the dataset
    raises:
        - `FileNotFoundError` if `npz_directory` is not a directory
        - `ValueError` if `clip_length` is smaller than one
        - `IndexError` (from the returned map) for an index outside
          `0 <= i < n_examples`
    """
    from bisect import bisect_right

    if not os.path.isdir(npz_directory):
        raise FileNotFoundError(f"Not a directory: {npz_directory}")
    if clip_length < 1:
        raise ValueError(f"clip_length must be at least 1, got {clip_length}")

    # Collect files from *every* directory of the walk. The root yielded by
    # os.walk already starts with `npz_directory`, so it is joined as-is
    # (re-prefixing it breaks relative paths). Sorting makes the global
    # index independent of the order the filesystem lists entries in.
    npz_paths = sorted(
        os.path.join(r, x)
        for r, _, f in os.walk(npz_directory)
        for x in f
        if "npz" in x.split(".")
    )

    # array slices for each file; each archive is closed as soon as it is read
    viable_slices = {}
    for npz_path in npz_paths:
        with np.load(npz_path) as cdata:
            viable_slices[npz_path] = viable_slice(cdata, keep=keep)

    # clip index -> array index
    def clip_to_array_index(i, array_slice):
        if not overlapping:
            i = i * clip_length
        return [i + array_slice.start + j for j in range(clip_length)]

    # number of clips a slice provides
    def lenslice(s):
        n_frames = s.stop - s.start
        if overlapping:
            # a slice shorter than one clip contributes nothing (never a
            # negative count, which would corrupt the running totals)
            return max(0, n_frames - (clip_length - 1))
        return n_frames // clip_length

    # running total of examples after each file, in `npz_paths` order
    cumulative = []
    n_examples = 0
    for npz_path in npz_paths:
        n_examples += lenslice(viable_slices[npz_path])
        cumulative.append(n_examples)

    # global index -> file, relative index
    def find_array(i):
        if not 0 <= i < n_examples:
            raise IndexError(
                f"index {i} is out of range for a dataset of {n_examples} examples"
            )
        # first file whose running total exceeds i; files that contribute
        # no clips are skipped automatically
        k = bisect_right(cumulative, i)
        offset = cumulative[k - 1] if k else 0
        return npz_paths[k], i - offset

    # create map function
    def global_to_array(i):
        npz_path, j = find_array(i)
        return npz_path, clip_to_array_index(j, viable_slices[npz_path])

    return global_to_array, n_examples

Testing for clip length 1 and non-overlapping clips:

In [None]:
def test_global_index_map(clip_length, overlapping):
    """
    Check every example produced by `global_index_map` on the sample data.

    Each example must hold `clip_length` frame indexes that are valid for its
    file, and consecutive non-overlapping clips of the same file must not
    share a frame.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        unpack_body_models("sample_data/", tmpdirname, 8)
        global_to_array, n_examples = global_index_map(
            tmpdirname, overlapping=overlapping, clip_length=clip_length
        )
        print(f"Number of examples in dataset: {n_examples}")
        # frame count per file, cached so each archive is opened only once
        n_frames = {}
        count = 0
        prev_path, prev_j = None, []
        for i in range(n_examples):
            npz_path, j = global_to_array(i)
            if npz_path not in n_frames:
                with np.load(npz_path) as cdata:
                    n_frames[npz_path] = cdata["poses"].shape[0]
            assert len(j) == clip_length
            # frame 0 is a legitimate index, so the lower bound is inclusive
            assert all(0 <= k < n_frames[npz_path] for k in j)
            if not overlapping:
                # equal frame numbers only clash when they are in the same file
                if npz_path == prev_path:
                    assert all(k not in prev_j for k in j)
                prev_path, prev_j = npz_path, j
            count += 1
        assert count == n_examples


test_global_index_map(1, False)

sample_data/sample.tar.bz2 extracting to /tmp/tmp1lm2xiom


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 668


Testing for clip length greater than 1 and non-overlapping:

In [None]:
# Longer clips that must not share frames.
test_global_index_map(clip_length=3, overlapping=False)
test_global_index_map(clip_length=4, overlapping=False)

sample_data/sample.tar.bz2 extracting to /tmp/tmp0tf97i2_


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 222
sample_data/sample.tar.bz2 extracting to /tmp/tmpbahxxekr


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 167


Testing for clip length greater than 1 and overlapping clips:

In [None]:
# Longer clips that are allowed to share frames.
test_global_index_map(clip_length=3, overlapping=True)
test_global_index_map(clip_length=4, overlapping=True)

sample_data/sample.tar.bz2 extracting to /tmp/tmpdnfak_ly


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 664
sample_data/sample.tar.bz2 extracting to /tmp/tmpty_xfg6j


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 662


## Load Data from `.npz` Files

I want to load the data from the `.npz` files in a standard way, so I'm going to load each entry into its own array.

In [None]:
# export
def load_npz(npz_path, indexes):
    """
    Load the frames `indexes` of one `.npz` file as a dictionary of arrays.

    args:
        - `npz_path`: path to the `.npz` file
        - `indexes`: list of frame indexes to extract
    returns:
        - dict with keys `poses`, `dmpls`, `trans`, `betas` and `gender`;
          every array has `len(indexes)` entries along its first axis.
          `poses`, `dmpls`, `trans` and `betas` are float32, `betas` is the
          per-file vector repeated for every frame, and `gender` is an
          integer code (male: -1, neutral: 0, female: 1) repeated likewise
    raises:
        - `KeyError` if the stored gender is not male, neutral or female
    """
    gender_codes = {"male": -1, "neutral": 0, "female": 1}
    n = len(indexes)

    # The context manager closes the archive when loading is done; indexing
    # with a list and `astype` both copy, so the results outlive the file.
    with np.load(npz_path) as cdata:
        # unpack and enforce data type
        poses = cdata["poses"][indexes].astype(np.float32)
        dmpls = cdata["dmpls"][indexes].astype(np.float32)
        trans = cdata["trans"][indexes].astype(np.float32)
        betas = np.repeat(
            cdata["betas"][np.newaxis].astype(np.float32), repeats=n, axis=0
        )
        # Read the gender once rather than once per frame. It is converted
        # through `str` because casting it straight to an integer will raise
        # a warning in future.
        gender_code = gender_codes[str(cdata["gender"].astype(str))]

    gender = np.array([gender_code] * n)

    return dict(poses=poses, dmpls=dmpls, trans=trans, betas=betas, gender=gender)

Test this works with different clip lengths and overlapping clips.

In [None]:
def test_load_npz(clip_length, overlapping):
    """
    Load the first example of the sample data and check that every returned
    array has `clip_length` entries along its first axis.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        unpack_body_models("sample_data/", tmpdirname, 8)
        index_map, _ = global_index_map(
            tmpdirname, overlapping=overlapping, clip_length=clip_length
        )
        first_path, first_indexes = index_map(0)
        arrays = load_npz(first_path, first_indexes)
        for array in arrays.values():
            assert array.shape[0] == clip_length
        print([array.shape for array in arrays.values()])


test_load_npz(clip_length=1, overlapping=False)
test_load_npz(clip_length=3, overlapping=False)
test_load_npz(clip_length=3, overlapping=True)

sample_data/sample.tar.bz2 extracting to /tmp/tmp522nzaaa


  0%|          | 0/1 [00:00<?, ?it/s]

[(1, 156), (1, 8), (1, 3), (1, 16), (1,)]
sample_data/sample.tar.bz2 extracting to /tmp/tmpd9q5dp3o


  0%|          | 0/1 [00:00<?, ?it/s]

[(3, 156), (3, 8), (3, 3), (3, 16), (3,)]
sample_data/sample.tar.bz2 extracting to /tmp/tmpbg16wz5r


  0%|          | 0/1 [00:00<?, ?it/s]

[(3, 156), (3, 8), (3, 3), (3, 16), (3,)]


# PyTorch Dataset Class

Creating a map-style PyTorch Dataset Class that uses these functions to load the data.

In [None]:
# export
class AMASS(Dataset):
    """
    Map-style PyTorch dataset over clips of frames stored in `.npz` files.

    args:
        - `unpacked_directory`: directory containing the unpacked `.npz` files
        - `clip_length`: number of consecutive frames in each example
        - `overlapping`: whether examples may share frames
        - `transform`: optional callable applied to every array of an
          example (e.g. `torch.tensor`); when `None` the arrays are returned
          as loaded
        - `keep`: proportion of each file to keep, forwarded to
          `global_index_map`
    """

    def __init__(
        self, unpacked_directory, clip_length, overlapping, transform=None, keep=0.8
    ):
        self.global_to_array, self.n_examples = global_index_map(
            unpacked_directory,
            overlapping=overlapping,
            clip_length=clip_length,
            keep=keep,
        )
        self.transform = transform

    def __len__(self):
        return self.n_examples

    def __getitem__(self, i):
        # Sequence-style indexing: negative indexes count from the end, and
        # anything out of range raises IndexError so that plain iteration
        # over the dataset terminates.
        if not -self.n_examples <= i < self.n_examples:
            raise IndexError(
                f"index {i} is out of range for a dataset of {self.n_examples} examples"
            )
        if i < 0:
            i += self.n_examples
        data = load_npz(*self.global_to_array(i))
        # `transform` defaults to None, in which case the arrays pass through.
        if self.transform is None:
            return data
        return {k: self.transform(data[k]) for k in data}

Test I can load some data with this Dataset:

In [None]:
# Build the dataset on freshly unpacked sample data and inspect one example.
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    amass = AMASS(
        tmpdirname,
        clip_length=1,
        overlapping=False,
        transform=torch.tensor,
    )
    data = amass[0]
    for k, v in data.items():
        print(k, v.shape)
        # every entry must have been converted by the transform
        assert type(v) is torch.Tensor

sample_data/sample.tar.bz2 extracting to /tmp/tmp_4vxaxa5


  0%|          | 0/1 [00:00<?, ?it/s]

poses torch.Size([1, 156])
dmpls torch.Size([1, 8])
trans torch.Size([1, 3])
betas torch.Size([1, 16])
gender torch.Size([1])


Test it works in a DataLoader to make batches:

In [None]:
# Batch the dataset with a DataLoader and check the size of every batch.
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    amass = AMASS(
        tmpdirname,
        clip_length=1,
        overlapping=False,
        transform=torch.tensor,
    )
    amasstrain = DataLoader(amass, batch_size=4, shuffle=True)
    for i, data in enumerate(amasstrain):
        # print the batch shapes once, for the first batch only
        if i == 0:
            for k in data:
                print(k, data[k].shape)
        assert data["poses"].size(0) == 4

sample_data/sample.tar.bz2 extracting to /tmp/tmpq_71o724


  0%|          | 0/1 [00:00<?, ?it/s]

poses torch.Size([4, 1, 156])
dmpls torch.Size([4, 1, 8])
trans torch.Size([4, 1, 3])
betas torch.Size([4, 1, 16])
gender torch.Size([4, 1])


In [None]:
#hide
from nbdev.export import notebook2script

# Export the cells marked for export into the library module.
notebook2script()