In [10]:
import csv
import pathlib
import os
from collections import defaultdict
import zipfile

from tqdm import tqdm
import numpy as np

In [3]:
root_d = pathlib.Path('emage_amass30_beat2')
amass_d = root_d / 'amass_smplx'
beat2_d = root_d / 'beat_v2.0.0'
data_d = root_d / 'emage_amass30_beat2'

data_d.mkdir(parents=True,exist_ok=True)

Subset of AMASS that has fps multiple of 30.

In [4]:
data_fs = [pathlib.Path(f'{rt}/{f}')
           for rt,ds,fs in os.walk(amass_d)
           for f in fs
           if f.endswith('.npz')]
data_fs.sort()

In [5]:
amass_data_fs = list()
files_not_readable = list()
files_model_type_not_usable = list()
files_mocap_frame_rate_key_missing = list()
files_frame_rate_not_multiple_of_30 = list()
all_model_types = defaultdict(int)
all_fps = defaultdict(int)
for f in tqdm(data_fs,desc='filtering 30fps files',ncols=150):
    try:
        data = np.load(f,allow_pickle=True)
    except zipfile.BadZipFile:
        files_not_readable.append(f)
        continue
    all_model_types[data['surface_model_type'].item()] += 1
    if data['surface_model_type'].item() not in {'smplx','smplx_locked_head'}:
        files_model_type_not_usable.append(f)
        continue
    if 'mocap_frame_rate' not in data:
        files_mocap_frame_rate_key_missing.append(f)
        continue
    fps = int(data['mocap_frame_rate'].item())
    all_fps[fps] += 1
    if fps % 30 != 0:
        files_frame_rate_not_multiple_of_30.append(f)
        continue
    amass_data_fs.append(f)
print('total files:',len(data_fs))
print('usable_files:',len(amass_data_fs))
print('files not readable:',len(files_not_readable))
print('files model type not usable:',len(files_model_type_not_usable))
print('files missing mocap_frame_rate:',len(files_mocap_frame_rate_key_missing))
print('files with frame rante not multiple of 30fps:',len(files_frame_rate_not_multiple_of_30))
print('model types:',*sorted(all_model_types.items()))
print('mocap frame rate:',*sorted(all_fps.items()))
assert len(data_fs) == (len(amass_data_fs) + 
                        len(files_not_readable) + 
                        len(files_model_type_not_usable) + 
                        len(files_mocap_frame_rate_key_missing) + 
                        len(files_frame_rate_not_multiple_of_30))

filtering 30fps files: 100%|███████████████████████████████████████████████████████████████████████████████████| 17430/17430 [00:30<00:00, 575.00it/s]

total files: 17430
usable_files: 9977
files not readable: 1
files model type not usable: 0
files missing mocap_frame_rate: 462
files with frame rante not multiple of 30fps: 6990
model types: ('smplx', 17429)
mocap frame rate: (59, 13) (60, 532) (100, 6920) (120, 9436) (150, 9) (250, 57)





Number of usable files is 9977. Create symlinks to these files in the combined data dir. EMAGE code was written for BEAT2 which has all the npz files at the top level. Create the same structure here.

In [7]:
for src_f in tqdm(amass_data_fs,desc='creating symlinks',ncols=150):
    dst_f = data_d / str(src_f.relative_to(amass_d)).replace('/','__')
    dst_f.symlink_to(src_f)

creating symlinks: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 9977/9977 [00:15<00:00, 645.99it/s]


Create symlinks to BEAT2 files in the combined data dir. There is no filtering needed for BEAT2 as all of them are 30fps already.

In [8]:
beat2_data_fs = [pathlib.Path(f'{rt}/{f}')
                 for rt,ds,fs in os.walk(beat2_d)
                 for f in fs
                 if f.endswith('.npz')]
beat2_data_fs.sort()

for src_f in tqdm(beat2_data_fs,desc='creating symlinks',ncols=150):
    dst_f = data_d / str(src_f.relative_to(beat2_d)).replace('/','__')
    dst_f.symlink_to(src_f)

creating symlinks: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 2048/2048 [00:03<00:00, 669.18it/s]


Make sure there are correct number of files in the data dir. Total should be `9977+2048=12025`.

In [9]:
data_fs = list(data_d.iterdir())
data_fs.sort()

print('num files:',len(data_fs))

num files: 12025


Create the train/val/test split file similar to BEAT2 because EMAGE code expects it. Here we will use the following:
* train = AMASS30 + BEAT2-train + BEAT2-val + BEAT2-additional
* val = BEAT2-test
* test = BEAT2-test

In [28]:
train_val_test = defaultdict(list)
for train_val_test_f in beat2_d.glob('*/train_test_split.csv'):
    for row in csv.DictReader(open(train_val_test_f)):
        train_val_test[row['type']].append((train_val_test_f.parent.name,
                                            row['id']))

for k,v in train_val_test.items():
    print(f'num examples BEAT2 {k}:',len(v))

num examples BEAT2 test: 355
num examples BEAT2 train: 1383
num examples BEAT2 additional: 198
num examples BEAT2 val: 118


In [33]:
rows = list()

missing_files = list()
added_files = set()

n_test = 0
n_val = 0
for d,i in train_val_test['test']:
    id = f'{d}__smplxflame_30__{i}'
    f = data_d / f'{id}.npz'
    if not f.is_file():
        missing_files.append(f)
        continue
    rows.append({'id':id,
                 'type':'test'})
    rows.append({'id':id,
                 'type':'val'})
    n_test += 1
    n_val += 1
    added_files.add(f)
    
n_train = 0
for d,i in (train_val_test['train'] + 
            train_val_test['val'] + 
            train_val_test['additional']):
    id = f'{d}__smplxflame_30__{i}'
    f = data_d / f'{id}.npz'
    if not f.is_file():
        missing_files.append(f)
        continue
    rows.append({'id':id,
                 'type':'train'})
    n_train += 1
    added_files.add(f)
    
for f in data_fs:
    if f in added_files:
        continue
    rows.append({'id':f.with_suffix('').name,
                 'type':'train'})
    n_train += 1
    
print('num examples:',len(rows))
print('num train:',n_train)
print('num train:',n_val)
print('num test:',n_test)
print('missing files:',len(missing_files))

num examples: 12380
num train: 11670
num train: 355
num test: 355
missing files: 7


Missing files do not affect us. Moving on to write the csv.

In [34]:
csv_f = data_d / 'train_test_split.csv'
with open(csv_f,'w') as csv_fh:
    writer = csv.DictWriter(csv_fh,
                            fieldnames=['id','type'])
    writer.writeheader()
    writer.writerows(rows)

Move all the npz files to a sub-dir named `smplxflame_30`. EMAGE expects this structure.

In [25]:
!mkdir emage_amass30_beat2/smplxflame_30
!mv emage_amass30_beat2/*.npz emage_amass30_beat2/smplxflame_30
!mkdir emage_amass30_beat2/weights

mv: cannot stat 'emage_amass30_beat2/*.txt': No such file or directory
