In [None]:
# default_exp core

# core

> Unpack and load the [AMASS][] dataset for training with a PyTorch iterator.

To do:

1. Saw `AMASS` and `global_index_map` break with relative paths, add test and make fix for this case
2. Add options to pass `keep` index to `AMASS`

[amass]: https://amass.is.tue.mpg.de/

In [None]:
# hide
from nbdev.showdoc import *

# Unpack Tar Files

> Console script to unpack all tar files found in a specified directory and put them in another directory, then create a symlink to be able to find the unpacked data later

## Checksum Directories

> Checksum directories to only unpack tar files when target directory either doesn't exist or has been incorrectly unpacked.

It would probably be sufficient to check if the target directory exists, but this is more thorough.

In [None]:
# export
# https://stackoverflow.com/a/54477583/6937913
import hashlib
from _hashlib import HASH as Hash
from pathlib import Path
from typing import Union


def md5_update_from_file(filename: Union[str, Path], hash: Hash) -> Hash:
    """
    Feed the bytes of `filename` into `hash` and hand the same object back.

    args:
        - `filename`: path of an existing regular file
        - `hash`: hash object to update in place

    returns:
        - the `hash` object that was passed in, after being updated
    """
    assert Path(filename).is_file()
    with open(str(filename), "rb") as stream:
        # read in fixed 4096-byte blocks so large files never sit fully in memory
        while True:
            block = stream.read(4096)
            if block == b"":
                break
            hash.update(block)
    return hash


def md5_file(filename: Union[str, Path]) -> str:
    """Return the hexadecimal MD5 digest of the contents of `filename`."""
    digest = md5_update_from_file(filename, hashlib.md5())
    return str(digest.hexdigest())


def md5_update_from_dir(directory: Union[str, Path], hash: Hash) -> Hash:
    """
    Recursively feed a directory tree into `hash`.

    Entries are visited in order of their lower-cased path string; for each
    entry its name is hashed first, then its file contents or, for a
    sub-directory, its own entries.

    args:
        - `directory`: path of an existing directory
        - `hash`: hash object to update

    returns:
        - the updated hash object
    """
    root = Path(directory)
    assert root.is_dir()

    def _lowercase_path(entry):
        return str(entry).lower()

    for entry in sorted(root.iterdir(), key=_lowercase_path):
        # the entry name is part of the digest, so renames change the hash
        hash.update(entry.name.encode())
        if entry.is_file():
            hash = md5_update_from_file(entry, hash)
        elif entry.is_dir():
            hash = md5_update_from_dir(entry, hash)
    return hash


def md5_dir(directory: Union[str, Path]) -> str:
    """Return the hexadecimal MD5 digest of a whole directory tree."""
    digest = md5_update_from_dir(directory, hashlib.md5())
    return str(digest.hexdigest())

In [None]:
#export
# Lookup table keyed by tar archive file name. For each archive:
#   - 'unpacks_to': name of the directory the archive unpacks to
#   - 'hash': MD5 that `lazy_unpack` expects `md5_dir` to return for that directory
hashes = \
{'ACCAD.tar.bz2': {'unpacks_to': 'ACCAD',
  'hash': '193442a2ab66cb116932b8bce08ecb89'},
 'BMLhandball.tar.bz2': {'unpacks_to': 'BMLhandball',
  'hash': '8947df17dd59d052ae618daf24ccace3'},
 'BMLmovi.tar.bz2': {'unpacks_to': 'BMLmovi',
  'hash': '6dfb134273f284152aa2d0838d7529d5'},
 'CMU.tar.bz2': {'unpacks_to': 'CMU',
  'hash': 'f04bc3f37f3eafebfb12ba0cf706ca72'},
 'DFaust67.tar.bz2': {'unpacks_to': 'DFaust_67',
  'hash': '7e5f11ed897da72c5159ef3c747383b8'},
 'EKUT.tar.bz2': {'unpacks_to': 'EKUT',
  'hash': '221ee4a27a03afd1808cbb11af067879'},
 'HumanEva.tar.bz2': {'unpacks_to': 'HumanEva',
  'hash': 'ca781438b08caafd8a42b91cce905a03'},
 'KIT.tar.bz2': {'unpacks_to': 'KIT',
  'hash': '3813500a3909f6ded1a1fffbd27ff35a'},
 'MPIHDM05.tar.bz2': {'unpacks_to': 'MPI_HDM05',
  'hash': 'f76da8deb9e583c65c618d57fbad1be4'},
 'MPILimits.tar.bz2': {'unpacks_to': 'MPI_Limits',
  'hash': '72398ec89ff8ac8550813686cdb07b00'},
 'MPImosh.tar.bz2': {'unpacks_to': 'MPI_mosh',
  'hash': 'a00019cac611816b7ac5b7e2035f3a8a'},
 'SFU.tar.bz2': {'unpacks_to': 'SFU',
  'hash': 'cb10b931509566c0a49d72456e0909e2'},
 'SSMsynced.tar.bz2': {'unpacks_to': 'SSM_synced',
  'hash': '7cc15af6bf95c34e481d58ed04587b58'},
 'TCDhandMocap.tar.bz2': {'unpacks_to': 'TCD_handMocap',
  'hash': 'c500aa07973bf33ac1587a521b7d66d3'},
 'TotalCapture.tar.bz2': {'unpacks_to': 'TotalCapture',
  'hash': 'b2c6833d3341816f4550799b460a1b27'},
 'Transitionsmocap.tar.bz2': {'unpacks_to': 'Transitions_mocap',
  'hash': '705e8020405357d9d65d17580a6e9b39'},
 'EyesJapanDataset.tar.bz2': {'unpacks_to': 'Eyes_Japan_Dataset',
  'hash': 'd19fc19771cfdbe8efe2422719e5f3f1'},
 'BMLrub.tar.bz2': {'unpacks_to': 'BioMotionLab_NTroje',
  'hash': '8b82ffa6c79d42a920f5dde1dcd087c3'},
 'DanceDB.tar.bz2': {'unpacks_to': 'DanceDB',
  'hash': '9ce35953c4234489036ecb1c26ae38bc'}}

## Parallel Unpacking with Joblib

> Unpacks tar files in multiple jobs to speed up unpacking the dataset.



In [None]:
# export
import json
import argparse
import functools
import os
from shutil import unpack_archive
import joblib
from tqdm.auto import tqdm


class ProgressParallel(joblib.Parallel):
    """
    `joblib.Parallel` subclass that reports progress on a tqdm bar.

    Callers must pass a `total` keyword argument when invoking the instance;
    it sets the bar length and is removed before the remaining arguments are
    forwarded to `joblib.Parallel.__call__`.
    """

    def __call__(self, *args, **kwargs):
        # `total` is for the progress bar only and must not reach joblib
        total = kwargs.pop("total")
        with tqdm(total=total) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self):
        # mirror joblib's completed-task counter onto the bar and redraw it
        bar = self._pbar
        bar.n = self.n_completed_tasks
        bar.refresh()

def lazy_unpack(tarpath, outdir):
    """
    Unpack `tarpath` into `outdir` unless a matching hash file shows it was
    already unpacked correctly.

    A file named `<unpacks_to>.hash` in `outdir` records the MD5 of the
    unpacked directory. The archive is unpacked when that file is missing
    OR when its content does not match the expected hash in `hashes`.

    args:
        - `tarpath`: path to a tar archive whose file name is a key of `hashes`
        - `outdir`: directory to unpack into

    returns:
        - None

    raises:
        - KeyError: the archive's file name is not in `hashes`
        - AssertionError: the unpacked directory's hash is not the expected one
    """
    tarpath, outdir = Path(tarpath), Path(outdir)
    expected = hashes[tarpath.name]['hash']
    unpacks_to = hashes[tarpath.name]['unpacks_to']
    hashpath = outdir / Path(unpacks_to+'.hash')
    # if the hash exists and it's correct then assume the directory is correctly unpacked
    if hashpath.exists():
        with open(hashpath) as f:
            stored = f.read().strip()
        if stored == expected:
            return None
    # no stored hash, or it doesn't match: unpack the tar file (previously a
    # mismatching hash file fell through without unpacking anything)
    unpack_archive(tarpath, outdir)
    # calculate the hash of the unpacked directory and check it's the same
    h = md5_dir(outdir/unpacks_to)
    if h != expected:
        # raised explicitly so the check is not stripped under `python -O`
        raise AssertionError(f'Directory {outdir/unpacks_to} hash {h} != {expected}')
    # save the calculated hash
    with open(hashpath, 'w') as f:
        f.write(h)
    return None

def unpack_body_models(tardir, outdir, n_jobs=1, verify=False):
    """
    Unpack every tar archive found at the top level of `tardir` into `outdir`.

    args:
        - `tardir`: directory containing the tar files (sub-directories are not searched)
        - `outdir`: directory to unpack into
        - `n_jobs`: number of joblib jobs to unpack with
        - `verify`: if True unpack with `lazy_unpack` (checksum verified,
            skipped when already unpacked), otherwise with `shutil.unpack_archive`

    raises:
        - FileNotFoundError: `tardir` cannot be listed as a directory
    """
    # only the first item of the walk is needed; taking it lazily avoids
    # traversing everything below `tardir`
    top_level = next(os.walk(tardir), None)
    if top_level is None:
        raise FileNotFoundError(f"{tardir} is not a readable directory")
    tar_root, _, filenames = top_level
    tarpaths = [
        os.path.join(tar_root, name) for name in filenames if "tar" in name.split(".")
    ]
    for tarpath in tarpaths:
        print(f"{tarpath} extracting to {outdir}")
    unpack = lazy_unpack if verify else unpack_archive
    ProgressParallel(n_jobs=n_jobs)(
        (joblib.delayed(unpack)(tarpath, outdir) for tarpath in tarpaths),
        total=len(tarpaths),
    )


def fast_amass_unpack():
    """
    Console entry point: parse the command line and unpack all body model
    tar files from one directory into another with `unpack_body_models`.
    """
    parser = argparse.ArgumentParser(
        description="Unpack all the body model tar files in a directory to a target directory"
    )
    parser.add_argument(
        "tardir",
        type=str,
        help="Directory containing tar.bz2 body model files",
    )
    parser.add_argument(
        "outdir",
        type=str,
        help="Output directory",
    )
    # boolean flags are declared with `action`; passing "store_true" as `type`
    # makes argparse raise ValueError because the string is not callable
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify the output by calculating a checksum, "
        "ensures that each tar file will only be unpacked once."
    )
    parser.add_argument(
        "-n",
        default=1,
        type=int,
        help="Number of jobs to run the tar unpacking with",
    )
    args = parser.parse_args()
    unpack_body_models(args.tardir, args.outdir, n_jobs=args.n, verify=args.verify)

Test unpacking the sample data always yields the same result:

In [None]:
import tempfile
import hashlib

# https://stackoverflow.com/a/3431838/6937913
def md5(fname):
    """Return the hexadecimal MD5 digest of the file at `fname`."""
    digest = hashlib.md5()
    with open(fname, "rb") as stream:
        # consume the file 4096 bytes at a time until read() returns b""
        block = stream.read(4096)
        while block != b"":
            digest.update(block)
            block = stream.read(4096)
    return digest.hexdigest()


# expected MD5 digest of each file in the unpacked sample archive
md5sums = {
    "amass_sample.npz": "d0b546b3619c8579ade39e3a8ccdc4e2",
    "dmpl_sample.npz": "576bb76b2a6328dc5c276c4150c466f0",
}

with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    # accumulate over every directory visited; assigning inside the loop kept
    # only the files of whichever directory happened to be walked last
    npz_paths = []
    for r, d, f in os.walk(tmpdirname):
        # `r` already starts with `tmpdirname`, so it is not joined again
        npz_paths += [os.path.join(r, x) for x in f if "npz" in x.split(".")]
    _md5sums = {os.path.split(fpath)[-1]: md5(fpath) for fpath in npz_paths}

for k in md5sums:
    assert md5sums[k] == _md5sums[k]

sample_data/sample.tar.bz2 extracting to /tmp/tmps93l76ly


  0%|          | 0/1 [00:00<?, ?it/s]

Testing that `verify=True` works as expected. Can redefine `hashes` here for testing without breaking the exported library because this cell doesn't get exported by `nbdev`.

In [None]:
import time
# test-only override of the exported table so the sample archive can be verified
hashes = {'sample.tar.bz2': {'unpacks_to': 'sample', 'hash': 'b5a86fe22ed2799d79101a532eb0ff27'}}

with tempfile.TemporaryDirectory() as tmpdirname:
    # first pass unpacks, second pass should be skipped thanks to the hash file
    durations = []
    for _ in range(2):
        start = time.time()
        unpack_body_models("sample_data/", tmpdirname, 8, verify=True)
        durations.append(time.time() - start)
    unpacking_time, skip_time = durations
    assert unpacking_time > skip_time

sample_data/sample.tar.bz2 extracting to /tmp/tmp5unr3c6d


  0%|          | 0/1 [00:00<?, ?it/s]

sample_data/sample.tar.bz2 extracting to /tmp/tmp5unr3c6d


  0%|          | 0/1 [00:00<?, ?it/s]

# Loading Functions

> Load the pose data directly from the `npz` files after unpacking.

Based on the [AMASS tutorial notebooks][amass], I would like to iterate over the dataset using a PyTorch Dataloader.

Steps to load:

1. Enumerate the paths to all the `npz` files
2. Inspect files for frame data
3. Map from a global dataset index to indexes for each frame clip

[amass]: https://github.com/nghorbani/amass/tree/master/notebooks

In [None]:
# export
import math
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

Looking at the sample data:

In [None]:
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    # collect from every walked directory; assigning inside the loop kept
    # only the files of the last directory visited
    npz_paths = []
    for r, d, f in os.walk(tmpdirname):
        npz_paths += [os.path.join(r, x) for x in f if "npz" in x.split(".")]
    for npz_path in npz_paths:
        # the context manager closes the archive once its keys and shapes are printed
        with np.load(npz_path) as cdata:
            print(npz_path)
            print("  ", [k for k in cdata.keys()])
            print("  ", [(k, cdata[k].shape) for k in cdata.keys()])

sample_data/sample.tar.bz2 extracting to /tmp/tmpephft3on


  0%|          | 0/1 [00:00<?, ?it/s]

/tmp/tmpephft3on/sample/subdir/amass_sample.npz
   ['poses', 'gender', 'mocap_framerate', 'betas', 'marker_data', 'dmpls', 'marker_labels', 'trans']
   [('poses', (601, 156)), ('gender', ()), ('mocap_framerate', ()), ('betas', (16,)), ('marker_data', (601, 85, 3)), ('dmpls', (601, 8)), ('marker_labels', (85,)), ('trans', (601, 3))]
/tmp/tmpephft3on/sample/subdir/dmpl_sample.npz
   ['poses', 'gender', 'mocap_framerate', 'betas', 'marker_data', 'dmpls', 'marker_labels', 'trans']
   [('poses', (235, 156)), ('gender', ()), ('mocap_framerate', ()), ('betas', (16,)), ('marker_data', (235, 67, 3)), ('dmpls', (235, 8)), ('marker_labels', (67,)), ('trans', (235, 3))]


## Viable Indexes

For every `npz` file I need to pull out the viable indexes:

In [None]:
# export
def viable_slice(cdata, keep):
    """
    Build the slice of frames worth keeping from one loaded `.npz` archive,
    trimming an equal share from the start and the end.
    args:

        - `cdata`: dictionary containing keys:
            ['poses', 'gender', 'mocap_framerate', 'betas',
             'marker_data', 'dmpls', 'marker_labels', 'trans']
            (only the first dimension of `cdata['poses']` is read here)
        - `keep`: ratio of the file to keep, greater than zero and at most 1.

    returns:

        - viable: slice that can access frames in the arrays:
            cdata['poses'], cdata['marker_data'], cdata['dmpls'], cdata['trans']
    """
    in_range = 0.0 < keep <= 1.0
    assert in_range, "Proportion of array to keep must be between zero and one"
    n_frames = cdata["poses"].shape[0]
    # fraction dropped from each end
    margin = (1.0 - keep) / 2.0
    first = int(n_frames * margin)
    last = int(n_frames * keep + n_frames * margin)
    return slice(first, last)

In [None]:
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    # collect from every walked directory; assigning inside the loop kept
    # only the files of the last directory visited
    npz_paths = []
    for r, d, f in os.walk(tmpdirname):
        npz_paths += [os.path.join(r, x) for x in f if "npz" in x.split(".")]
    for npz_path in npz_paths:
        # close each archive as soon as its slice has been printed
        with np.load(npz_path) as cdata:
            print(npz_path)
            print("  ", viable_slice(cdata, 0.8))

sample_data/sample.tar.bz2 extracting to /tmp/tmp1dmwr08g


  0%|          | 0/1 [00:00<?, ?it/s]

/tmp/tmp1dmwr08g/sample/subdir/amass_sample.npz
   slice(60, 540, None)
/tmp/tmp1dmwr08g/sample/subdir/dmpl_sample.npz
   slice(23, 211, None)


## Map Global Index to File Indexes

I need to be able to map from a global dataset index to contiguous sets of frames in each file. Options:

1. Whether the contiguous frames can overlap in different samples

In [None]:
# export
import warnings


@functools.lru_cache(maxsize=2)
def walk_npz_paths(npz_directory):
    """
    Recursively list every file under `npz_directory` that has `npz` as one
    of its dot-separated name components.

    Results are memoised per directory argument by `functools.lru_cache`, so
    files added after the first call for a directory are not seen.

    args:
        - `npz_directory`: directory to search (absolute or relative)

    returns:
        - tuple of file paths, each prefixed by `npz_directory` as given
    """
    npz_paths = []
    for root, _, filenames in os.walk(npz_directory):
        # `root` already begins with `npz_directory`; joining the directory on
        # again duplicated the prefix whenever a relative path was passed
        npz_paths += [
            os.path.join(root, name) for name in filenames if "npz" in name.split(".")
        ]
    return tuple(npz_paths)

def read_viable_slices(npz_paths, keep_percent):
    """
    Compute the viable frame slice of every archive in `npz_paths`.

    args:
        - `npz_paths`: iterable of paths to `.npz` files
        - `keep_percent`: percentage (0-100] of each file to keep, passed to
            `viable_slice` as a ratio

    returns:
        - dictionary mapping each readable path to its viable `slice`;
          files named `shape.npz` are skipped and archives without the
          expected keys are skipped with a warning
    """
    keep = keep_percent/100.
    viable = {}
    for npz_path in npz_paths:
        # filter out npz files that don't contain pose data
        if Path(npz_path).name in ['shape.npz']:
            continue
        try:
            # close each archive straight away instead of leaving it to the
            # garbage collector
            with np.load(npz_path) as cdata:
                viable[npz_path] = viable_slice(cdata, keep=keep)
        except KeyError:
            warnings.warn(f'Archive {npz_path} does not contain correctly formatted data')
    return viable

def global_index_map(npz_directory, overlapping, clip_length, keep=0.8, cache_map=True):
    """
    Build a lookup from a global example index to a file and frame indexes.

    args:
        - `npz_directory`: Directory containing `.npz` files
        - `overlapping`: Whether clips can overlap
        - `clip_length`: Number of consecutive frames in each clip, at least 1
        - `keep`: ratio of each file to keep, forwarded to `read_viable_slices`
            as a whole percentage
        - `cache_map`: if True the viable slices are cached with `joblib.Memory`
            in `<npz_directory>/viable_slices_memory`
    returns:
        - `global_to_array`: function mapping a global index to
            `(npz_path, list of clip_length frame indexes)`; raises
            IndexError for an index outside `[0, n_examples)`
        - `n_examples`: total number of clips in the dataset
    """
    import bisect
    from itertools import accumulate

    if clip_length < 1:
        raise ValueError(f"clip_length must be at least 1, got {clip_length}")
    npz_paths = walk_npz_paths(npz_directory)
    # array slices for each file
    if cache_map:
        cache_map_dir = Path(npz_directory)/Path('viable_slices_memory')
        memory = joblib.Memory(cache_map_dir, verbose=0)
        viable_slices = memory.cache(read_viable_slices)(npz_paths, int(100*keep))
    else:
        viable_slices = read_viable_slices(npz_paths, int(100*keep))

    # clip index -> array index
    def clip_to_array_index(i, array_slice):
        if not overlapping:
            i = i * clip_length
        return [i + array_slice.start + j for j in range(clip_length)]

    # number of clips a slice provides
    def lenslice(s):
        n_frames = s.stop - s.start
        if overlapping:
            # a file shorter than one clip provides no clips; without the
            # clamp it contributed a negative count and corrupted the totals
            return max(0, n_frames - (clip_length - 1))
        return n_frames // clip_length

    # cumulative clip counts: clip_ends[k] is the first global index that
    # lies beyond file k
    paths = list(viable_slices)
    clip_ends = list(accumulate(lenslice(viable_slices[p]) for p in paths))
    # how many examples are there in this dataset, total
    n_examples = clip_ends[-1] if clip_ends else 0

    # create map function
    def global_to_array(i):
        if not 0 <= i < n_examples:
            # IndexError lets sequence-style iteration over a dataset stop cleanly
            raise IndexError(f"index {i} is out of range for {n_examples} examples")
        # first file whose cumulative end is strictly greater than i
        file_index = bisect.bisect_right(clip_ends, i)
        first_global = clip_ends[file_index - 1] if file_index else 0
        npz_path = paths[file_index]
        return npz_path, clip_to_array_index(i - first_global, viable_slices[npz_path])

    return global_to_array, n_examples

Testing for clip length 1 and non-overlapping clips:

In [None]:
def test_global_index_map(clip_length, overlapping):
    """
    Unpack the sample data and check every global index maps to
    `clip_length` frame indexes; for non-overlapping clips also check that
    consecutive examples share no frame index.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        unpack_body_models("sample_data/", tmpdirname, 8)
        global_to_array, n_examples = global_index_map(
            tmpdirname, overlapping=overlapping, clip_length=clip_length
        )
        print(f"Number of examples in dataset: {n_examples}")
        visited = 0
        previous_frames = [-1] * clip_length
        for example in range(n_examples):
            npz_path, frames = global_to_array(example)
            cdata = np.load(npz_path)
            assert len(frames) == clip_length
            assert cdata["poses"][frames] is not None
            assert all(frame > 0 for frame in frames)
            if not overlapping:
                assert all(frame not in previous_frames for frame in frames)
                previous_frames = frames
            visited += 1
        assert visited == n_examples


test_global_index_map(1, False)

sample_data/sample.tar.bz2 extracting to /tmp/tmps9k3hq75


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 668


Testing for clip length greater than 1 and non-overlapping:

In [None]:
# non-overlapping clips of 3 and then 4 frames
for _clip_length in (3, 4):
    test_global_index_map(_clip_length, False)

sample_data/sample.tar.bz2 extracting to /tmp/tmp8hvalka_


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 222
sample_data/sample.tar.bz2 extracting to /tmp/tmp3jih9nxz


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 167


Testing for clip length greater than 1 and overlapping:

In [None]:
# overlapping clips of 3 and then 4 frames
for _clip_length in (3, 4):
    test_global_index_map(_clip_length, True)

sample_data/sample.tar.bz2 extracting to /tmp/tmpta70cr8w


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 664
sample_data/sample.tar.bz2 extracting to /tmp/tmp7rgo6dtq


  0%|          | 0/1 [00:00<?, ?it/s]

Number of examples in dataset: 662


## Load Data from `.npz` Files

I want to load the data from the `.npz` files in a standard way, so I'm going to load each entry into its own array.

In [None]:
# export
def load_npz(npz_path, indexes, load_poses=True, load_dmpls=True,
             load_trans=True, load_betas=True, load_gender=True):
    """
    Load the selected frames of one `.npz` archive into a dictionary of arrays.

    args:
        - `npz_path`: path to the `.npz` file
        - `indexes`: sequence of frame indexes to read
        - `load_poses`, `load_dmpls`, `load_trans`, `load_betas`,
          `load_gender`: whether to include the corresponding entry

    returns:
        - dictionary with one array per requested entry, each with first
          dimension `len(indexes)`; 'poses', 'dmpls', 'trans' and 'betas'
          are float32, 'betas' is repeated per frame and 'gender' is encoded
          as -1 (male), 0 (neutral) or 1 (female) per frame
    """
    data = {}
    # The archive is opened per call and closed on exit. The previous
    # `lru_cache(...)(np.load)` wrapper was rebuilt on every call, so it never
    # produced a cache hit and left the archive open.
    with np.load(npz_path) as cdata:
        # unpack and enforce data type
        if load_poses:
            data['poses'] = cdata["poses"][indexes].astype(np.float32)
        if load_dmpls:
            data['dmpls'] = cdata["dmpls"][indexes].astype(np.float32)
        if load_trans:
            data['trans'] = cdata["trans"][indexes].astype(np.float32)
        if load_betas:
            data['betas'] = np.repeat(
                cdata["betas"][np.newaxis].astype(np.float32), repeats=len(indexes), axis=0
            )
        if load_gender:
            # casting gender to integer will raise a warning in future
            # read and decode once rather than once per requested frame
            gender = str(cdata["gender"].astype(str))
            gender_code = {"male": -1, "neutral": 0, "female": 1}[gender]
            data['gender'] = np.array([gender_code for _ in indexes])

    return data

Test this works with different clip lengths and overlapping clips.

In [None]:
def test_load_npz(clip_length, overlapping):
    """
    Unpack the sample data, load the first example and check every returned
    array has `clip_length` as its first dimension.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        unpack_body_models("sample_data/", tmpdirname, 8)
        global_to_array, n_examples = global_index_map(
            tmpdirname, overlapping=overlapping, clip_length=clip_length
        )
        first_path, first_indexes = global_to_array(0)
        data = load_npz(first_path, first_indexes)
        for array in data.values():
            assert array.shape[0] == clip_length
        print([data[k].shape for k in data])


test_load_npz(1, False)
test_load_npz(3, False)
test_load_npz(3, True)

sample_data/sample.tar.bz2 extracting to /tmp/tmp2bevwk7s


  0%|          | 0/1 [00:00<?, ?it/s]

[(1, 156), (1, 8), (1, 3), (1, 16), (1,)]
sample_data/sample.tar.bz2 extracting to /tmp/tmpm2j2ugk9


  0%|          | 0/1 [00:00<?, ?it/s]

[(3, 156), (3, 8), (3, 3), (3, 16), (3,)]
sample_data/sample.tar.bz2 extracting to /tmp/tmp4lloez87


  0%|          | 0/1 [00:00<?, ?it/s]

[(3, 156), (3, 8), (3, 3), (3, 16), (3,)]


How long does it take to iterate over all the sample data?

In [None]:
import time
from pathlib import Path

In [None]:
# time one uncached pass over every single-frame example of the sample data
clip_length = 1
overlapping = False
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    global_to_array, n_examples = global_index_map(
        tmpdirname, overlapping=overlapping, clip_length=clip_length
    )
    start = time.time()
    i = 0
    while i < n_examples:
        npz_path, indexes = global_to_array(i)
        _ = load_npz(npz_path, indexes)
        i += 1
    elapsed = time.time() - start
print(f'time to iterate: {elapsed}')

sample_data/sample.tar.bz2 extracting to /tmp/tmpj5hliecp


  0%|          | 0/1 [00:00<?, ?it/s]

time to iterate: 1.8302652835845947


In [None]:
# compare a cache-filling pass with a pass served from the joblib cache,
# and report how much disk the cache takes next to the data
clip_length = 1
overlapping = False
with tempfile.TemporaryDirectory() as tmpdirname:
    def dir_size(d):
        # summed st_size of everything below `d`, in units of 1e6 bytes
        total_bytes = sum(file.stat().st_size for file in Path(d).rglob('*'))
        return total_bytes/1e6

    unpack_body_models("sample_data/", tmpdirname, 8)
    # measured before any cache directory exists
    data_size = dir_size(tmpdirname)
    global_to_array, n_examples = global_index_map(
        tmpdirname, overlapping=overlapping, clip_length=clip_length
    )
    cache_dir = os.path.join(tmpdirname, 'cache')
    os.mkdir(cache_dir)
    memory = joblib.Memory(cache_dir, verbose=0)
    cached_load_npz = memory.cache(load_npz)

    print('caching...')
    start = time.time()
    for i in range(n_examples):
        npz_path, indexes = global_to_array(i)
        _ = cached_load_npz(npz_path, indexes)
    print(f'time to cache: {time.time() - start}')

    print('loading from cache...')
    start = time.time()
    for i in range(n_examples):
        npz_path, indexes = global_to_array(i)
        _ = cached_load_npz(npz_path, indexes)
    elapsed = time.time() - start
    cache_size = dir_size(cache_dir)
print(f'time to iterate: {elapsed}')
print(f'stored size (data+cache)/total: ({data_size:.2f}MB + {cache_size:.2f}MB)/{cache_size+data_size:.2f}MB')

sample_data/sample.tar.bz2 extracting to /tmp/tmp18asdg23


  0%|          | 0/1 [00:00<?, ?it/s]

caching...
time to cache: 3.4107885360717773
loading from cache...
time to iterate: 0.48926854133605957
stored size (data+cache)/total: (2.73MB + 3.75MB)/6.49MB


Loading from npz files and caching the result to speed everything up.

# PyTorch Dataset Class

Creating a map-style PyTorch Dataset Class that uses these functions to load the data.

In [None]:
# export
class AMASS(Dataset):
    """
    Map-style PyTorch dataset over clips of an unpacked AMASS directory.

    args:
        - `unpacked_directory`: directory containing the unpacked `.npz` files
        - `clip_length`: number of frames per example
        - `overlapping`: whether clips can overlap
        - `transform`: optional callable applied to every loaded array;
            when None the arrays are returned as loaded
        - `memory`: if True, cache `load_npz` results with `joblib.Memory`
            in `<unpacked_directory>/memory`
        - `memory_bytes_limit`: passed to `joblib.Memory` as `bytes_limit`;
            a warning reminds the caller that `memory.reduce_size()` must be
            called for it to take effect
        - `to_load`: which of 'poses', 'dmpls', 'trans', 'betas', 'gender'
            to include in each example
    """
    def __init__(self, unpacked_directory, clip_length, overlapping,
                 transform=None, memory=False, memory_bytes_limit=None,
                 to_load=('poses', 'dmpls', 'trans', 'betas', 'gender')):
        self.global_to_array, self.n_examples = global_index_map(
            unpacked_directory, overlapping=overlapping, clip_length=clip_length
        )
        self.transform = transform
        # keyword arguments for `load_npz`, e.g. {'load_poses': True, ...}
        self.to_load = {
            f'load_{k}': k in to_load
            for k in ('poses', 'dmpls', 'trans', 'betas', 'gender')
        }
        caching_directory = Path(unpacked_directory) / Path('memory')
        if memory_bytes_limit is not None:
            warnings.warn(f'AMASS.memory.reduce_size() must be called reduce cache size to be less than {memory_bytes_limit}')
        self.memory = joblib.Memory(caching_directory, verbose=0, bytes_limit=memory_bytes_limit)
        self.load_npz = self.memory.cache(load_npz) if memory else load_npz

    def __len__(self):
        return self.n_examples

    def __getitem__(self, i):
        npz_path, array_index = self.global_to_array(i)
        data = self.load_npz(npz_path, array_index, **self.to_load)
        # `transform` defaults to None; calling it unconditionally raised TypeError
        if self.transform is None:
            return data
        return {k: self.transform(data[k]) for k in data}

Test I can load some data with this Dataset:

In [None]:
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    amass = AMASS(tmpdirname, overlapping=False, clip_length=1, transform=torch.tensor)
    data = amass[0]
    for k in data:
        print(k, data[k].shape)
        assert type(data[k]) is torch.Tensor

sample_data/sample.tar.bz2 extracting to /tmp/tmpw_r2sl6k


  0%|          | 0/1 [00:00<?, ?it/s]

poses torch.Size([1, 156])
dmpls torch.Size([1, 8])
trans torch.Size([1, 3])
betas torch.Size([1, 16])
gender torch.Size([1])


Test it works in a DataLoader to make batches:

In [None]:
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 8)
    amass = AMASS(tmpdirname, overlapping=False, clip_length=1, transform=torch.tensor)
    amasstrain = DataLoader(amass, batch_size=4, shuffle=True)
    for i, data in enumerate(amasstrain):
        if i == 0:
            for k in data:
                print(k, data[k].shape)
        assert data["poses"].size(0) == 4

sample_data/sample.tar.bz2 extracting to /tmp/tmptugy2bd9


  0%|          | 0/1 [00:00<?, ?it/s]

poses torch.Size([4, 1, 156])
dmpls torch.Size([4, 1, 8])
trans torch.Size([4, 1, 3])
betas torch.Size([4, 1, 16])
gender torch.Size([4, 1])


# Caching

Initially, when I tried to load from AMASS using just a single worker (more than one worker accessing the `npz` files directly would lock up), the estimated runtime just to iterate over the dataset was about 160 hours. That's not practical, so I decided to use `joblib.Memory` to cache loading of the `npz` files.

This also involves caching the dictionary of `viable_slices` above.

In [None]:
def cache_amass():
    """
    Console entry point: fill the `joblib.Memory` cache of an unpacked AMASS
    directory by iterating once over an `AMASS` dataset built with
    `memory=True`.

    The cache is keyed by the arguments of `load_npz`, so `--clip-length` and
    `--overlapping` should match the values used at training time.
    """
    parser = argparse.ArgumentParser(
        description="Cache the data using joblib.Memory"
    )
    parser.add_argument(
        "amassdir",
        type=str,
        help="Directory where AMASS has been unpacked",
    )
    # boolean flags need `action`; `type="store_true"` makes argparse raise ValueError
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Accepted for compatibility with fast_amass_unpack; "
        "currently has no effect on caching."
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        help="Number of Dataloader workers to run the caching with",
    )
    parser.add_argument("--bytes-limit",
        default=None,
        type=int,
        help="Limit of bytes to store on disk (this means some data will not be cached)"
    )
    parser.add_argument(
        "--clip-length",
        default=1,
        type=int,
        help="Number of frames per clip to cache",
    )
    parser.add_argument(
        "--overlapping",
        action="store_true",
        help="Cache overlapping clips instead of non-overlapping ones",
    )
    args = parser.parse_args()
    # The old final line read args.tardir/args.outdir/args.n, which this
    # parser never defines; the caching pass below replaces it.
    dataset = AMASS(
        args.amassdir,
        clip_length=args.clip_length,
        overlapping=args.overlapping,
        transform=torch.tensor,
        memory=True,
        memory_bytes_limit=args.bytes_limit,
    )
    loader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers)
    # loading every example once is what populates the on-disk cache
    for _ in tqdm(loader):
        pass
    if args.bytes_limit is not None:
        # NOTE(review): assumes reduce_size() without arguments applies the
        # limit given to joblib.Memory - confirm against the installed joblib
        dataset.memory.reduce_size()

In [None]:
#hide
# run nbdev's notebook2script; the recorded output lists the notebooks it converted
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted index.ipynb.
