LLAMASS is a Loader for the AMASS dataset
I'm writing this for a project that trains models on pose data from the AMASS dataset. I wanted to be able to install it easily in Colab notebooks and elsewhere. Hopefully it's also useful for other people, but be aware that this is research code, so it is not necessarily reliable.
Before using the AMASS dataset I'm expected to sign up to the license agreement here. This package doesn't require any other code from MPI, but visualization of pose data does; see below.
Requirements are handled by pip during the install, but in a new environment I would install PyTorch first to configure CUDA as required for the system.
pip install llamass
There are demos for plotting available in the amass repo and in the smplx repo. I wrote a library based on these to plot without having to think about the betas, dmpls, etc. It's called gaitplotlib and it can be installed from GitHub:
pip install git+https://github.com/gngdb/gaitplotlib.git
The AMASS website provides links to download the various parts of the AMASS dataset. Each is provided as a .tar.bz2 archive, and I had to download them from the website by hand. Save all of these in a folder somewhere.
Once installed, llamass provides a console script to unpack the tar.bz2 files downloaded from the AMASS website:
fast_amass_unpack -n 4 --verify <dir with .tar.bz2 files> <dir to save unpacked data>
This unpacks the data in parallel using 4 jobs and shows a progress bar. The --verify flag computes an md5sum of the directory the files are unpacked to and checks it against the value I found when I unpacked the data myself. It also skips tar files that have already been unpacked by looking for saved .hash files in the target directory. Verification is slower but more reliable, and it recovers from incomplete unpacking.
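As a rough illustration of the verification idea, the sketch below folds every file under a directory into a single MD5 digest and stores it in a marker file so that a later run can skip the work. The function names, the marker filename `unpacked.hash`, and the order in which files are hashed are all assumptions made for this example; they are not taken from the llamass source.

```python
import hashlib
from pathlib import Path


def md5_of_directory(directory, skip_suffix=".hash"):
    """Return one MD5 hex digest covering every file under `directory`.

    Files are visited in sorted order so the digest is reproducible.
    Marker files ending in `skip_suffix` are ignored.
    """
    digest = hashlib.md5()
    root = Path(directory)
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        if path.name.endswith(skip_suffix):
            continue
        # Include the relative path so renamed files change the digest.
        digest.update(path.relative_to(root).as_posix().encode("utf-8"))
        with open(path, "rb") as f:
            # Read in 1 MiB chunks to keep memory use flat on large files.
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
    return digest.hexdigest()


def write_marker(directory, marker_name="unpacked.hash"):
    """Record the current digest (hypothetical marker file name)."""
    (Path(directory) / marker_name).write_text(md5_of_directory(directory))


def already_unpacked(directory, marker_name="unpacked.hash"):
    """True if a marker exists and still matches the directory contents."""
    marker = Path(directory) / marker_name
    if not marker.exists():
        return False
    return marker.read_text().strip() == md5_of_directory(directory)
```

A check like this has to read every unpacked byte, which is why verification costs extra time, but it also means a half-written directory no longer matches its marker and gets unpacked again.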
Alternatively, this can be accessed from the library using the llamass.core.unpack_body_models function:
import llamass.core
llamass.core.unpack_body_models("sample_data/", unpacked_directory, 4)
sample_data/sample.tar.bz2 extracting to /tmp/tmp06iwsfhu
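For readers curious what unpacking several archives in parallel involves, here is a minimal standard-library sketch. It is not the llamass implementation: `unpack_one` and `unpack_all` are hypothetical names, there is no progress bar or verification, and it uses a thread pool for simplicity.

```python
import tarfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def unpack_one(tar_path, target_dir):
    """Extract a single .tar.bz2 archive into `target_dir`."""
    with tarfile.open(tar_path, "r:bz2") as tar:
        tar.extractall(target_dir)
    return str(tar_path)


def unpack_all(tar_dir, target_dir, n_jobs=4):
    """Extract every .tar.bz2 file in `tar_dir` using `n_jobs` threads."""
    archives = sorted(Path(tar_dir).glob("*.tar.bz2"))
    Path(target_dir).mkdir(parents=True, exist_ok=True)
    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        futures = [pool.submit(unpack_one, a, target_dir) for a in archives]
        # .result() re-raises any exception from a worker thread.
        return [f.result() for f in futures]
```

Only extract archives from a source you trust, since `extractall` writes whatever paths the archive contains.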
I've processed the files to find out how many frames are in each numpy archive unpacked when fast_amass_unpack is run. By default, the first time the AMASS Dataset object is asked for its len, it will look for a file containing this information in the specified AMASS directory. If it doesn't find it, it will recompute it, and that can take 5 minutes.
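The look-up-or-recompute behaviour described above can be sketched as a small cache helper. This is an illustration under assumptions: the layout of the JSON inside `npz_file_lens.json.gz` (a mapping from relative path to frame count) is a guess, `load_or_build_lengths` and `count_frames_npz` are hypothetical names, and counting frames via the first dimension of a `poses` array is assumed from the batch shapes shown later in this README.

```python
import gzip
import json
from pathlib import Path


def count_frames_npz(npz_path):
    """Count frames in one archive, assuming it stores a `poses` array."""
    import numpy as np  # imported lazily; only needed when the cache is missing

    with np.load(npz_path) as data:
        return int(data["poses"].shape[0])


def load_or_build_lengths(amass_dir, cache_name="npz_file_lens.json.gz",
                          count_frames=count_frames_npz):
    """Return {relative path: frame count}, reading the cache if present."""
    root = Path(amass_dir)
    cache = root / cache_name
    if cache.exists():
        # Fast path: the lengths were computed (or downloaded) earlier.
        with gzip.open(cache, "rt", encoding="utf-8") as f:
            return json.load(f)
    # Slow path: open every archive once, then save the result.
    lengths = {
        p.relative_to(root).as_posix(): count_frames(p)
        for p in sorted(root.rglob("*.npz"))
    }
    with gzip.open(cache, "wt", encoding="utf-8") as f:
        json.dump(lengths, f)
    return lengths
```

The slow path touches every archive, which is where the minutes go; once the cache file exists, the archives are not opened again just to compute a length.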
Save 5 minutes by downloading it from this repository:
wget https://github.com/gngdb/llamass/raw/master/npz_file_lens.json.gz -P <dir to save unpacked data>
details of script for splits goes here
Once the data is unpacked, it can be loaded by a PyTorch DataLoader directly using the llamass.core.AMASS Dataset class.
- overlapping: whether the clips of frames taken from each file should be allowed to overlap
- clip_length: how long the clips from each file should be
- transform: a transformation function to apply to all fields
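To make the difference between overlapping and non-overlapping clips concrete, the sketch below lists the start frame of each clip for a file with a given number of frames. `clip_starts` is a hypothetical helper written for this example; the assumptions that overlapping clips advance one frame at a time and that a trailing partial clip is dropped are mine and may not match what llamass does.

```python
def clip_starts(n_frames, clip_length, overlapping):
    """Return the start index of every full clip in a file of `n_frames`."""
    if clip_length > n_frames:
        return []  # the file is too short to hold even one clip
    # Overlapping clips slide by one frame; otherwise they tile the file.
    step = 1 if overlapping else clip_length
    return list(range(0, n_frames - clip_length + 1, step))


print(clip_starts(10, 4, overlapping=False))  # [0, 4]
print(clip_starts(10, 4, overlapping=True))   # [0, 1, 2, 3, 4, 5, 6]
```

Under these assumptions, allowing overlap turns a 10-frame file into seven clips of length 4 instead of two, so it changes the effective size of the dataset considerably.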
It is an IterableDataset, so it cannot be shuffled by the DataLoader: passing shuffle=True to the DataLoader raises an error. However, the AMASS class itself implements shuffling, which can be enabled with shuffle=True at initialisation.
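Because an iterable dataset only hands out a stream of items, any shuffling has to happen inside the dataset before the stream starts. The sketch below shows one way to do that in plain Python with a seeded generator. `ShuffledClips` is a hypothetical stand-in rather than the AMASS class, and reseeding with `seed + epoch` on each pass is an assumption about one reasonable design.

```python
import random


class ShuffledClips:
    """Iterable that yields its items in a seeded random order on each pass."""

    def __init__(self, items, shuffle=False, seed=0):
        self.items = list(items)
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

    def __iter__(self):
        order = list(range(len(self.items)))
        if self.shuffle:
            # A fresh generator per pass: reproducible for a given seed,
            # but seeded differently on each epoch.
            random.Random(self.seed + self.epoch).shuffle(order)
        self.epoch += 1
        for i in order:
            yield self.items[i]


clips = ShuffledClips(range(5), shuffle=True, seed=0)
first_pass = list(clips)   # some permutation of 0..4
second_pass = list(clips)  # a permutation drawn with a different seed
```

Two instances built with the same seed produce the same sequence of orderings, which keeps experiments repeatable.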
Also, to use more than one worker it is necessary to pass the provided worker_init_fn to the DataLoader. The same behaviour is available by using llamass.core.IterableLoader in place of DataLoader, which is safer: a DataLoader without worker_init_fn will yield duplicate data when workers load from the same files.
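The duplication problem arises because each DataLoader worker gets its own copy of an iterable dataset and, unless told otherwise, iterates over all of it. A common remedy is to give each worker a disjoint share of the files. The sketch below shows that idea in plain Python; `shard_for_worker` is a hypothetical helper and interleaved slicing is just one possible split, not necessarily what llamass.core.worker_init_fn does. In PyTorch, the worker id and worker count are available inside a worker from `torch.utils.data.get_worker_info()`.

```python
def shard_for_worker(files, worker_id, num_workers):
    """Give each worker a disjoint, interleaved slice of the file list."""
    return files[worker_id::num_workers]


files = [f"subject_{i}.npz" for i in range(5)]  # made-up file names
num_workers = 2

# Without sharding, every worker walks the full list: each file appears twice.
unsharded = [f for w in range(num_workers) for f in files]

# With sharding, the workers' slices cover each file exactly once.
sharded = [
    f
    for w in range(num_workers)
    for f in shard_for_worker(files, w, num_workers)
]

print(len(unsharded), len(sharded))  # 10 5
```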
import torch
from torch.utils.data import DataLoader
amass = llamass.core.AMASS(
    unpacked_directory,
    overlapping=False,
    clip_length=1,
    transform=torch.tensor,
    shuffle=False,
    seed=0,
)
# these are equivalent
amassloader = DataLoader(amass, batch_size=4, num_workers=2, worker_init_fn=llamass.core.worker_init_fn)
amassloader = llamass.core.IterableLoader(amass, batch_size=4, num_workers=2)
for data in amassloader:
    for k in data:
        print(k, data[k].size())
    break
poses torch.Size([4, 1, 156])
dmpls torch.Size([4, 1, 8])
trans torch.Size([4, 1, 3])
betas torch.Size([4, 1, 16])
gender torch.Size([4, 1])
poses = next(iter(llamass.core.IterableLoader(amass, batch_size=200, num_workers=2)))
poses = poses['poses'].squeeze()
# gaitplotlib
import numpy as np
import gaitplotlib.core
import matplotlib.pyplot as plt
plt.style.use('ggplot')
params = gaitplotlib.core.plottable(poses.numpy())
def plot_pose(pose_index, save_to=None):
    fig, axes = plt.subplots(1, 3, figsize=(10,6))
    for d, ax in enumerate(axes):
        dims_to_plot = [i for i in range(3) if i != d]
        joints, skeleton = params[pose_index]["joints"], params.skeleton
        j = joints[:, dims_to_plot]
        ax.scatter(*j.T, color="r", s=0.2)
        for bone in skeleton:
            a = j[bone[0]]
            b = j[bone[1]]
            x, y = list(zip(a, b))
            ax.plot(x, y, color="r", alpha=0.5)
        ax.axes.xaxis.set_ticklabels([])
        ax.axes.yaxis.set_ticklabels([])
        ax.set_aspect('equal', adjustable='box')
    if save_to is not None:
        plt.tight_layout()
        plt.savefig(save_to)
        plt.close()
    else:
        plt.show()
plot_pose(0)
# gaitplotlib
from pathlib import Path
import mediapy as media
animloc = Path(unpacked_directory)/'anim'
animloc.mkdir(exist_ok=True)
def get_frame(i, frameloc=animloc/'frame.jpeg'):
    plot_pose(i, save_to=frameloc)
    return media.read_image(frameloc)
img_arr = get_frame(0)
with media.VideoWriter(animloc/'anim.gif', codec='gif', shape=img_arr.shape[:2], fps=10) as w:
    for i in range(0, params.vertices.shape[0], 10):
        frameloc = animloc/'frame.jpeg'
        plot_pose(i, save_to=frameloc)
        img_arr = media.read_image(frameloc)
        w.add_image(img_arr)
video = media.read_video(animloc/'anim.gif')
media.show_video(video, codec='gif')