In [31]:
import pandas as pd
import numpy as np
import json

# Source records for the split, and a path template whose "{}" is filled in
# with each entry of output_files when the splits are written.
# NOTE(review): output_path is an absolute, user-specific location; the
# directory must already exist on the machine running this notebook.
input_path = "../usable_micro_labeled_video_list.json"
output_path = "/home/manoj/train_test_val_splits/{}"
# Order must match the [train, val, test] order returned by
# train_val_test_splits(), because write_splits() pairs them up by position.
output_files = ["train.json", "validation.json", "test.json"]
# Both values are CUMULATIVE cut points over the shuffled rows, not per-split
# shares: train = first 75%, val = rows between 75% and 85% (i.e. 10%),
# test = the remaining 15%.
TRAIN_PERCENT = 0.75
VALID_PERCENT = 0.85


def train_val_test_splits(df, train_percent, valid_percent, random_state=None):
	'''
	Shuffles the dataframe and splits it into train, val and test dataframes.

	Parameters
	----------
	df : pandas.DataFrame
		Rows to split. The input frame itself is not modified.
	train_percent : float
		Fraction of rows (0..1) at which the train split ends.
	valid_percent : float
		CUMULATIVE fraction (train + val) at which the val split ends. The val
		split therefore holds ``valid_percent - train_percent`` of the rows and
		the test split holds the remaining ``1 - valid_percent``.
	random_state : int, optional
		Seed forwarded to ``DataFrame.sample`` so the shuffle (and hence the
		split) can be reproduced. The default ``None`` keeps the previous
		behaviour of a different shuffle on every call.

	Returns
	-------
	list of pandas.DataFrame
		``[train, val, test]``, in that order.

	Raises
	------
	ValueError
		If the cut points do not satisfy
		``0 <= train_percent <= valid_percent <= 1``; out-of-order cut points
		would otherwise yield overlapping train and test splits.
	'''
	if not 0 <= train_percent <= valid_percent <= 1:
		raise ValueError(
			"expected 0 <= train_percent <= valid_percent <= 1, got "
			"train_percent={!r}, valid_percent={!r}".format(train_percent, valid_percent)
		)

	n_rows = len(df)
	train_end = int(train_percent * n_rows)
	valid_end = int(valid_percent * n_rows)

	# frac=1 returns every row in random order, i.e. a shuffle.
	shuffled = df.sample(frac=1, random_state=random_state)

	# Positional slicing instead of np.split(): np.split on a DataFrame goes
	# through DataFrame.swapaxes, which recent pandas versions deprecate.
	train = shuffled.iloc[:train_end]
	val = shuffled.iloc[train_end:valid_end]
	test = shuffled.iloc[valid_end:]

	print("Train dataset size: ", train.shape)
	print("Val dataset size: ", val.shape)
	print("Test dataset size: ", test.shape)
	return [train, val, test]

def write_splits(df, train_percent, valid_percent):
	'''
	Splits the dataframe with train_val_test_splits() and saves each resulting
	piece as a records-oriented JSON file.

	Each file name in output_files is substituted into the output_path
	template; names and splits are paired by position (train, val, test).
	'''
	print("Creating train, val, test splits")
	frames = train_val_test_splits(df, train_percent, valid_percent)
	for file_name, frame in zip(output_files, frames):
		destination = output_path.format(file_name)
		frame.to_json(destination, orient = 'records')


if __name__ == "__main__":
	# Read the input records from input_path, then shuffle, split and write
	# them out using the module-level cut points.
	df = pd.read_json(input_path, orient="records")
	write_splits(df, train_percent=TRAIN_PERCENT, valid_percent=VALID_PERCENT)
	print("Completed writing files.")


Creating train, val, test splits
Train dataset size:  (8590, 3)
Val dataset size:  (1145, 3)
Test dataset size:  (1719, 3)
Completed writing files.


In [32]:
# Read the three splits back from disk. The paths are built from the same
# output_path template used when writing, so the two cannot drift apart.
train = pd.read_json(output_path.format("train.json"))
test = pd.read_json(output_path.format("test.json"))
val = pd.read_json(output_path.format("validation.json"))

In [33]:
# Non-null value count for each column of the reloaded train split.
train.count()

composite_json    8590
macro             8590
micro             8590
dtype: int64

In [34]:
# Non-null value count for each column of the reloaded test split.
test.count()

composite_json    1719
macro             1719
micro             1719
dtype: int64

In [35]:
# Non-null value count for each column of the reloaded validation split.
val.count()

composite_json    1145
macro             1145
micro             1145
dtype: int64

In [36]:
# Per-'macro'-label non-null counts of the other columns in the train split,
# so the label distribution can be compared across splits.
train.groupby('macro').count()

Unnamed: 0_level_0,composite_json,micro
macro,Unnamed: 1_level_1,Unnamed: 2_level_1
benchpress,2510,2510
deadlift,2561,2561
squat,3519,3519


In [37]:
# Per-'macro'-label non-null counts of the other columns in the test split,
# so the label distribution can be compared across splits.
test.groupby('macro').count()

Unnamed: 0_level_0,composite_json,micro
macro,Unnamed: 1_level_1,Unnamed: 2_level_1
benchpress,496,496
deadlift,498,498
squat,725,725


In [38]:
# Per-'macro'-label non-null counts of the other columns in the validation
# split, so the label distribution can be compared across splits.
val.groupby('macro').count()

Unnamed: 0_level_0,composite_json,micro
macro,Unnamed: 1_level_1,Unnamed: 2_level_1
benchpress,322,322
deadlift,331,331
squat,492,492
