In [7]:
import json

# Frame rate used to convert frame indices into timestamps (seconds).
# NOTE(review): presumably the NTSC rate; the exact value is 30000/1001 (~29.97003) —
# confirm against the video metadata before relying on sub-frame timestamp precision.
FPS = 29.97


def _load_json(path):
    '''Read and parse a JSON file, making sure the file handle is closed.'''
    with open(path) as f:
        return json.load(f)


# `segments[date]` is a list of segment dicts with 'label', 'frame_start', 'frame_end'.
segments = _load_json('../../data/annotation.json')
# `len(animations[date])` is used as the number of labels for that video.
animations = _load_json('../../data/steps.json')
# Video ids (dates) for each split.
splits = _load_json('../split.json')
train_set = splits['train']
val_set = splits['val']
test_set = splits['test']

In [6]:
def construct_(segments, n_label, n_samples, fps=None):
    '''
        Pick `n_samples` evenly spaced labelled frames from one video.

        Frames are grouped by segment label (segments with a negative label are
        skipped). The sample budget is dealt out round-robin across labels, never
        giving a label more samples than it has frames. Within a label the samples
        are spread evenly over that label's frame list; the selection is
        deterministic (no randomness).

        Args:
            segments: iterable of dicts with keys 'label', 'frame_start' and
                'frame_end'; frames in range(frame_start, frame_end) are used.
            n_label: number of labels; non-negative labels must be < n_label.
            n_samples: total number of samples to return.
            fps: frame rate used to turn frame indices into seconds;
                defaults to the module-level FPS.

        Returns:
            List of dicts with keys 'timestamp', 'frame_index', 'label', 'id',
            sorted by frame index; 'id' is the position in that sorted order.

        Raises:
            ValueError: if the labelled segments contain fewer than `n_samples`
                frames in total (the allocation loop could never finish).
    '''
    if fps is None:
        fps = FPS

    # Group frame indices by label
    candidates = [[] for _ in range(n_label)]
    for segment in segments:
        label = segment['label']
        if label > -1:  # negative labels are skipped
            candidates[label].extend(range(segment['frame_start'], segment['frame_end']))

    # Without this guard the round-robin below spins forever once every
    # label is exhausted before the budget is spent.
    available = sum(len(frames) for frames in candidates)
    if available < n_samples:
        raise ValueError(
            f'requested {n_samples} samples but only {available} labelled frames are available'
        )

    # Deal the budget out one sample per label per round, skipping labels
    # that have run out of frames.
    per_label = [0] * n_label
    remaining = n_samples
    while remaining > 0:
        for i in range(n_label):
            if remaining > 0 and per_label[i] < len(candidates[i]):
                per_label[i] += 1
                remaining -= 1

    # Evenly spread each label's samples over its frames
    data = []
    for label, frames in enumerate(candidates):
        N = len(frames)
        M = per_label[label]
        if M >= N:
            # Every frame is requested. The spacing formula below would repeat
            # one frame and skip the last one here, so take all frames directly.
            indices = range(N)
        else:
            # Interior points of an (M + 1)-way split of [0, N - 1]; these are
            # distinct whenever M < N.
            indices = (int((N - 1) * j / (M + 1)) for j in range(1, M + 1))
        for idx in indices:
            data.append({
                'timestamp': frames[idx] / fps,
                'frame_index': frames[idx],
                'label': label
            })

    data.sort(key=lambda x: x['frame_index'])
    for i, sample in enumerate(data):
        sample['id'] = i
    assert len(data) == n_samples
    return data

def construct(segments, dates, n_samples):
    '''
        Build `n_samples` samples for every video id in `dates`.

        Returns a dict mapping each date to the sample list produced by
        `construct_`, passing `len(animations[date])` as that video's label
        count. Prints the date before processing each video.
    '''
    def _build_one(date):
        print('Processing:', date)
        return construct_(segments[date], len(animations[date]), n_samples)

    return {date: _build_one(date) for date in dates}

In [8]:
# Construct training set: 2000 samples per video, written as indented JSON.
# The `with` block closes the file; no explicit close() is needed.
data = construct(segments, train_set, 2000)
with open('train.json', 'w') as f:
    json.dump(data, f, indent=4)

Processing: 01152020
Processing: 01252020
Processing: 01272021
Processing: 02012021
Processing: 03232022
Processing: 06162021
Processing: 10062019
Processing: 11222019
Processing: 12022021
Processing: 12032022


In [9]:
# Construct validation set: 2000 samples per video, written as indented JSON.
# The `with` block closes the file; no explicit close() is needed.
data = construct(segments, val_set, 2000)
with open('val.json', 'w') as f:
    json.dump(data, f, indent=4)

Processing: 03152022
Processing: 11152022


In [10]:
# Construct test set: 2000 samples per video, written as indented JSON.
# The `with` block closes the file; no explicit close() is needed.
data = construct(segments, test_set, 2000)
with open('test.json', 'w') as f:
    json.dump(data, f, indent=4)

Processing: 02282021
Processing: 06092023
Processing: 06262020
Processing: 09122021
Processing: 11152019
Processing: 12022019
