# Data Split
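Split the data files into train/valid/test sets and export the file list of each set. Load the file names with either Choice 1 or Choice 2 below (they are alternatives), then run the remaining cells.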

In [None]:
import os
import random

In [None]:
# Fractions of the samples assigned to each split; the test split takes
# whatever is left after the train and valid fractions.
train_ratio, valid_ratio = 0.65, 0.15
test_ratio = 1 - train_ratio - valid_ratio

In [None]:
# Show the configured split ratios (train, valid, test) on one line.
print(*(train_ratio, valid_ratio, test_ratio))

## Load File Names from Data Files (Choice 1)

In [None]:
# Root directory that is searched recursively for data files.
data_dir = 'data_dir'
# Only files whose names end with one of these suffixes are kept;
# set to None to keep every file.
valid_suffices = ('.json',)

In [None]:
# Walk `data_dir` recursively and collect the paths of matching files,
# relative to `data_dir`. Directories and files are visited in sorted order
# so the collected list does not depend on the file system's listing order.
if valid_suffices is None:
    _suffixes = None
elif isinstance(valid_suffices, str):
    # A bare string such as '.json' would otherwise be treated as a
    # sequence of single-character suffixes.
    _suffixes = (valid_suffices,)
else:
    # str.endswith accepts a str or a tuple, but not a list or set.
    _suffixes = tuple(valid_suffices)

file_name_list = []
for root_dir, dirs, files in os.walk(data_dir):
    dirs.sort()  # sorting in place controls the order os.walk descends
    for file_name in sorted(files):
        if _suffixes is not None and not file_name.endswith(_suffixes):
            continue
        file_path = os.path.join(root_dir, file_name)
        file_name_list.append(os.path.relpath(file_path, data_dir))

## Load File Names from a List (Choice 2)

In [None]:
# Plain-text file that names the data files to use, one per line.
file_list_file = 'file_list.txt'

In [None]:
# Read the data file names from `file_list_file`, one per line, dropping
# surrounding whitespace and skipping blank lines.
# The file is decoded as UTF-8 (the same encoding the split lists are
# written with) instead of the platform default. 'utf-8-sig' also drops a
# leading byte-order mark; str.strip() would not, leaving it glued to the
# first file name.
with open(file_list_file, 'r', encoding='utf-8-sig') as f:
    file_name_list = [line.strip() for line in f if line.strip()]

## Calculate and Save

In [None]:
num_samples = len(file_name_list)  # total number of files to be split
print(num_samples)

In [None]:
# Number of samples in each split. The train and valid counts are rounded
# from their ratios, and the test split receives the remainder.
# Both roundings can go up: 5 samples at 0.7/0.3 gives
# round(3.5) + round(1.5) = 4 + 2 = 6 > 5. The counts are therefore capped
# so that test_num can never be negative and the three counts always sum
# to num_samples (and match the slices taken when writing the lists).
train_num = min(round(num_samples * train_ratio), num_samples)
valid_num = min(round(num_samples * valid_ratio), num_samples - train_num)
test_num = num_samples - train_num - valid_num

In [None]:
# Report the size of each split, one per line: train, valid, test.
print(train_num, valid_num, test_num, sep='\n')

In [None]:
# Shuffle the file names before they are sliced into splits.
# Sorting first and shuffling with a dedicated, seeded generator makes the
# split reproducible. The result then depends only on the set of file names
# and the seed, not on the order they were listed in, and it does not
# consume or disturb the global `random` state.
# Set `shuffle_seed = None` to get a different random split on every run.
shuffle_seed = 0
file_name_list.sort()
random.Random(shuffle_seed).shuffle(file_name_list)

In [None]:
# Directory that receives train.txt, valid.txt and test.txt.
output_dir = 'split'

In [None]:
# Create the output directory (and any missing parents); an existing one is fine.
os.makedirs(name=output_dir, exist_ok=True)

In [None]:
# Write the training split: the first `train_num` shuffled names, one per line.
with open(os.path.join(output_dir, 'train.txt'), 'w', encoding='utf-8') as f:
    f.writelines(item + '\n' for item in file_name_list[:train_num])

In [None]:
# Write the validation split: the `valid_num` names that follow the
# training split, one per line.
with open(os.path.join(output_dir, 'valid.txt'), 'w', encoding='utf-8') as f:
    f.writelines(
        item + '\n'
        for item in file_name_list[train_num:train_num + valid_num]
    )

In [None]:
# Write the test split: every name after the train and valid splits,
# one per line.
with open(os.path.join(output_dir, 'test.txt'), 'w', encoding='utf-8') as f:
    f.writelines(item + '\n' for item in file_name_list[train_num + valid_num:])