In [1]:
import json

# Load the inputs with context managers so every file handle is closed as
# soon as it has been parsed (the bare `json.load(open(...))` form leaves the
# handles open until garbage collection).
with open('../../data/annotation.json') as f:
    segments = json.load(f)
with open('../../data/steps.json') as f:
    animations = json.load(f)
with open('../split.json') as f:
    splits = json.load(f)

# Entries of each split; they are used below as keys into `segments`.
train_set = splits['train']
val_set = splits['val']
test_set = splits['test']

In [2]:
import random

# Fixed seed: every random draw below is reproducible on a fresh kernel.
random.seed(43)

# Frame rate used to convert the second-based margins below into frame counts.
fps = 29.97

# Minimal distance from the candidates to the query (30 s worth of frames).
min_d = int(fps * 30)

# Minimal distance from the candidates and the query to the step boundary
# (15 s worth of frames).
d_b = int(fps * 15)

def valid(x, L1, L2):
    """Tell whether offset `x` can serve as a query position.

    `x` is an offset into a segment of length `L1` that is followed by a
    segment of length `L2`. The two candidates sit at `x - d` and `x + d`;
    `x` is valid when at least one distance `d` satisfies all of:

    * `d >= min_d`,
    * `x - d >= d_b` (left candidate clear of the first segment's start),
    * `x + d - d_b >= L1` (right candidate at least `d_b` into the second
      segment),
    * `x + d + d_b < L1 + L2` (right candidate clear of the second
      segment's end).

    Returns:
        bool: True if such a `d` exists, False otherwise.
    """
    # The query itself must lie inside the first segment and stay more than
    # `d_b` frames before its end.
    if x < 0 or x + d_b >= L1:
        return False

    smallest_d = max(min_d, L1 + d_b - x)            # minimal possible d
    largest_d = min(x - d_b, L1 + L2 - 1 - x - d_b)  # maximal possible d
    return smallest_d <= largest_d

def get_samples(L1, L2, M):
    """Draw `M` (query, positive, negative) offset triples for a segment pair.

    Offsets are relative to the start of the first segment (length `L1`),
    which is followed by a segment of length `L2`. Query offsets are spread
    evenly over all offsets accepted by `valid`; for each query `m` a
    distance `d` is drawn uniformly from its feasible range and the
    candidates are placed at `m - d` (same segment as the query) and `m + d`
    (in the following segment).

    Args:
        L1: length of the segment that contains the query.
        L2: length of the segment that follows it.
        M: number of samples to draw; `M <= 0` yields an empty list.

    Returns:
        list[tuple[int, int, int]]: `(m, l, r)` triples.

    Raises:
        ValueError: if `M > 0` but the pair has no valid query offset.
    """
    if M <= 0:
        return []

    valid_x = [x for x in range(L1) if valid(x, L1, L2)]
    if not valid_x:
        # Previously this surfaced as an opaque IndexError on `valid_x[0]`.
        raise ValueError(
            'no valid query position for segment lengths L1=%d, L2=%d' % (L1, L2)
        )

    samples = []
    for i in range(M):
        # Spread the M queries evenly across the valid offsets.
        m = valid_x[len(valid_x) * i // M]

        # Feasible range of d for this query. `valid(m, L1, L2)` guarantees
        # lo <= hi, so no exception handling is needed here; the former bare
        # `except` could only hide an error and then continue with an
        # undefined or stale `d`.
        lo = max(min_d, L1 + d_b - m)
        hi = min(m - d_b, L1 + L2 - 1 - m - d_b)
        d = random.randint(lo, hi)

        l = m - d
        r = m + d
        # Sanity checks on the margins promised by `valid`.
        assert m >= d_b and m + d_b < L1
        assert d >= min_d
        assert l >= d_b and l + d_b < L1
        assert r - d_b >= L1 and r + d_b < L1 + L2
        samples.append((m, l, r))
    return samples

def construct_(segments, n_samples):
    """Build exactly `n_samples` query/candidate samples for one video.

    Every pair of adjacent segments whose labels are both `> -1` and that
    offers valid query positions in both directions is a candidate pair.
    Samples are allocated round-robin over the candidate pairs (pairs that
    involve a label-0 segment only receive a sample with probability 1/100
    per pass), then half of each pair's samples place the query in the
    former segment and the rest place it in the latter segment.

    NOTE(review): the frame arithmetic assumes `segments[i+1]` starts exactly
    where `segments[i]` ends (contiguous segments) -- confirm against the
    annotation format.

    Args:
        segments: sequence of dicts with `'label'`, `'frame_start'` and
            `'frame_end'` entries, in temporal order.
        n_samples: number of samples to generate.

    Returns:
        list[dict]: one dict per sample with keys `'query_frame_index'`,
        `'candidate_frame_indices'` (two frame indices), `'label'` (index of
        the positive candidate, i.e. the one taken from the query's own
        segment) and `'id'`.

    Raises:
        ValueError: if the candidate pairs cannot supply `n_samples` samples
            (the original loop would spin forever in that case).
    """
    n_samples_tmp = n_samples
    data = []
    candidates = []  # index i of every usable (segments[i], segments[i+1]) pair
    caps = []        # remaining number of samples each pair may still receive
    for i in range(len(segments) - 1):
        if segments[i]['label'] > -1 and segments[i+1]['label'] > -1:
            L1 = segments[i]['frame_end'] - segments[i]['frame_start']
            L2 = segments[i+1]['frame_end'] - segments[i+1]['frame_start']
            # Count the valid query offsets in each direction.
            n_former = sum(1 for x in range(L1) if valid(x, L1, L2))
            n_latter = sum(1 for x in range(L2) if valid(x, L2, L1))
            if n_former > 0 and n_latter > 0:
                candidates.append(i)
                caps.append(min(n_former, n_latter) * 2)

    # Evenly allocate the samples to each adjacent segment pair
    n_samples_ = [0] * len(candidates)
    while n_samples > 0:
        # Without any remaining capacity a pass cannot make progress; fail
        # loudly instead of looping forever.
        if not any(cap > 0 for cap in caps):
            raise ValueError(
                'cannot allocate %d samples: candidate segment pairs ran out '
                'of capacity with %d samples still unassigned'
                % (n_samples_tmp, n_samples)
            )
        for i in range(len(candidates)):
            if n_samples > 0 and caps[i] > 0:
                if segments[candidates[i]]['label'] == 0 or segments[candidates[i]+1]['label'] == 0:
                    # There are too many irrelevant segments, so we need to draw samples from them less
                    if random.randint(1, 100) == 1:
                        n_samples_[i] += 1
                        n_samples -= 1
                        caps[i] -= 1
                else:
                    n_samples_[i] += 1
                    n_samples -= 1
                    caps[i] -= 1

    for i, M in zip(candidates, n_samples_):
        hf = M // 2
        start = segments[i]['frame_start']
        L1 = segments[i]['frame_end'] - segments[i]['frame_start']
        L2 = segments[i+1]['frame_end'] - segments[i+1]['frame_start']

        # Half samples with the query timestamp in the former segment
        for m, l, r in get_samples(L1, L2, hf):
            data.append({
                'query_frame_index': start + m,
                # the first one is positive and the second one is negative
                'candidates': [start + l, start + r]
            })

        # Half samples with the query timestamp in the latter segment: sample
        # on the mirrored pair, then flip the offsets back with `L - offset`.
        samples = get_samples(L2, L1, M - hf)
        L = segments[i+1]['frame_end'] - segments[i]['frame_start'] - 1
        for m, l, r in samples:
            data.append({
                'query_frame_index': start + L - m,
                # the first one is positive and the second one is negative
                'candidates': [start + L - l, start + L - r]
            })

    for i in range(len(data)):
        # randomly shuffle the positive and the negative
        label = random.randint(0, 1)
        if label == 1:
            cans = [data[i]['candidates'][1], data[i]['candidates'][0]]
        else:
            cans = data[i]['candidates']
        data[i]['candidate_frame_indices'] = cans
        data[i].pop('candidates')
        data[i]['label'] = label
        data[i]['id'] = i
    assert len(data) == n_samples_tmp
    return data

def construct(segments, dates, n_samples):
    """Generate `n_samples` samples for every video named in `dates`.

    Args:
        segments: mapping from a video key to that video's segment list.
        dates: iterable of video keys to process, in order.
        n_samples: number of samples to build per video.

    Returns:
        dict: video key -> list of samples produced by `construct_`.
    """
    per_video = {}
    for video_key in dates:
        print('Processing:', video_key)
        video_segments = segments[video_key]
        per_video[video_key] = construct_(video_segments, n_samples)
    return per_video

In [3]:
# #####################################
#             Warning                #
# As randomness is involved, please  #
# modify the output file name before #
# running this code. Otherwise, it   #
# will overwrite the training set.   # 
# #####################################

# Build the training split (2000 samples per video) and serialize it.
data = construct(segments, train_set, 2000)
with open('train.json', 'w') as f:
    # The `with` block closes the file on exit.
    json.dump(data, f, indent=4)

Processing: 01152020


Processing: 01252020
Processing: 01272021
Processing: 02012021
Processing: 03232022
Processing: 06162021
Processing: 10062019
Processing: 11222019
Processing: 12022021
Processing: 12032022


In [4]:
#####################################
#             Warning               #
# As randomness is involved, please #
# don't run this to overwrite the   #
# original validation set.          #
#####################################

# Build the validation split (2000 samples per video) and serialize it.
data = construct(segments, val_set, 2000)
with open('val.json', 'w') as f:
    # The `with` block closes the file on exit.
    json.dump(data, f, indent=4)

Processing: 03152022


Processing: 11152022


In [5]:
#####################################
#             Warning               #
# As randomness is involved, please #
# don't run this to overwrite the   #
# original test set.                #
#####################################

# Build the test split (2000 samples per video) and serialize it.
data = construct(segments, test_set, 2000)
with open('test.json', 'w') as f:
    # The `with` block closes the file on exit.
    json.dump(data, f, indent=4)

Processing: 02282021


Processing: 06092023
Processing: 06262020
Processing: 09122021
Processing: 11152019
Processing: 12022019
