# Week 2 Tuesday
# June 11, 2019

# train_val_test_split development notebook
## Code to perform a stratified split of class examples into train, validation, and test sets for use with the generator class

Input format:

data ====== class1
        |
        |== class2
        |
        |== class3
        
Output format:

data ====== train ===== class1
        |            |
        |            |= class2
        |            |
        |            |= class3
        |
        |== val ======= class1
        |            |
        |            |= class2
        |            |
        |            |= class3
        |
        |== test ====== class1
                     |
                     |= class2
                     |
                     |= class3

In [10]:
!ls data

goals  nongoals


In [11]:
import random
import shutil
import os

In [12]:
def train_test_split(root='./data', classes=['goals', 'nongoals'], split_ratio=[0.8, 0.1, 0.1], seed=None):
    """Copy each class's files into stratified train/val/test sub-directories.

    For every class ``c`` in ``classes`` the regular files found in
    ``<root>/<c>`` are shuffled and copied (not moved) into
    ``<root>/train/<c>``, ``<root>/val/<c>`` and ``<root>/test/<c>``.
    The three subsets of a class are always disjoint.

    Parameters
    ----------
    root : str
        Directory that contains one sub-directory per class.
    classes : list of str
        Names of the class sub-directories to split.
    split_ratio : sequence of three floats
        ``[train, val, test]`` fractions. Only the train and test fractions
        are used to compute the cut points; the validation set receives
        whatever lies between them.
    seed : int or None, optional
        When given, the shuffle is driven by a private ``random.Random(seed)``
        so the split is reproducible. When ``None`` (default) the module-level
        ``random`` generator is used, as before.

    Raises
    ------
    AssertionError
        If a class directory contains no regular files.

    Notes
    -----
    Existing output directories are reused and never emptied. Re-running
    without a fixed ``seed`` on already-populated output directories can
    therefore leave the same file in more than one split.
    """
    # Make the train, val and test directories (and remember their paths).
    split_dirs = {}
    for split_name in ('train', 'val', 'test'):
        split_dirs[split_name] = os.path.join(root, split_name)
        if not os.path.isdir(split_dirs[split_name]):
            os.mkdir(split_dirs[split_name])

    # A dedicated generator keeps a seeded split independent of global state.
    rng = random if seed is None else random.Random(seed)

    for cls in classes:
        cls_dir = os.path.join(root, cls)
        print('Accessing files in ' + cls_dir)

        # Keep regular files only: shutil.copy2 cannot copy directories such
        # as a stray ``.ipynb_checkpoints``. Sorting first makes the shuffle
        # independent of the arbitrary order returned by os.listdir.
        cls_list = sorted(
            name for name in os.listdir(cls_dir)
            if os.path.isfile(os.path.join(cls_dir, name))
        )
        assert len(cls_list) > 0, 'no files found in ' + cls_dir

        rng.shuffle(cls_list)
        n_files = len(cls_list)

        # Index where the test set starts; always leave at least one test file.
        split_test = n_files - round(n_files * split_ratio[2])
        split_test = max(0, min(split_test, n_files - 1))

        # Index where the train set ends. It must never run past the start of
        # the test set, otherwise the same file lands in both train and test.
        split_train = round(n_files * split_ratio[0])
        split_train = max(0, min(split_train, split_test))

        subsets = {
            'train': cls_list[:split_train],
            'val': cls_list[split_train:split_test],
            'test': cls_list[split_test:],
        }

        for split_name in ('train', 'val', 'test'):
            dest_dir = os.path.join(split_dirs[split_name], cls)
            if not os.path.isdir(dest_dir):
                os.mkdir(dest_dir)
            for file_name in subsets[split_name]:
                shutil.copy2(os.path.join(cls_dir, file_name), dest_dir)

In [13]:
# Split both classes 80/10/10; the copies are written under ./data/{train,val,test}.
train_test_split(
    root='./data',
    classes=['goals', 'nongoals'],
    split_ratio=[0.8, 0.1, 0.1],
)

Accessing files in ./data/goals
Accessing files in ./data/nongoals


In [14]:
!ls ./data

goals  nongoals  test  train  val


In [18]:
!ls ./data/val/goals

0.mkv	  1144.mkv  1404.mkv  1585.mkv	394.mkv  589.mkv  71.mkv   892.mkv
1002.mkv  115.mkv   1425.mkv  15.mkv	396.mkv  58.mkv   727.mkv  8.mkv
1014.mkv  1162.mkv  1434.mkv  166.mkv	397.mkv  600.mkv  74.mkv   907.mkv
1015.mkv  1171.mkv  1437.mkv  191.mkv	417.mkv  61.mkv   752.mkv  90.mkv
1031.mkv  1196.mkv  1438.mkv  198.mkv	420.mkv  620.mkv  756.mkv  912.mkv
1037.mkv  1202.mkv  1441.mkv  209.mkv	425.mkv  627.mkv  758.mkv  914.mkv
103.mkv   1230.mkv  1467.mkv  218.mkv	431.mkv  632.mkv  763.mkv  916.mkv
104.mkv   1242.mkv  1469.mkv  233.mkv	441.mkv  63.mkv   785.mkv  926.mkv
1050.mkv  1257.mkv  146.mkv   242.mkv	456.mkv  640.mkv  789.mkv  933.mkv
1055.mkv  1277.mkv  1489.mkv  247.mkv	470.mkv  641.mkv  792.mkv  972.mkv
1061.mkv  1289.mkv  1498.mkv  252.mkv	486.mkv  644.mkv  7.mkv    983.mkv
1074.mkv  1294.mkv  1506.mkv  278.mkv	488.mkv  651.mkv  805.mkv  987.mkv
1078.mkv  129.mkv   1517.mkv  286.mkv	506.mkv  659.mkv  812.mkv  993.mkv
1080.mkv  1308.mkv  1527.mkv  313.mkv	512.m

In [22]:
# Sanity check: count the files that were copied into every split/class folder.
(goals_train, nongoals_train,
 goals_val, nongoals_val,
 goals_test, nongoals_test) = (
    len(os.listdir(split_cls_dir))
    for split_cls_dir in (
        './data/train/goals',
        './data/train/nongoals',
        './data/val/goals',
        './data/val/nongoals',
        './data/test/goals',
        './data/test/nongoals',
    )
)

# One summary line per split.
print(
    "Training set: {} goals, {} non-goals".format(goals_train, nongoals_train),
    "Validation set: {} goals, {} non-goals".format(goals_val, nongoals_val),
    "Testing set: {} goals, {} non-goals".format(goals_test, nongoals_test),
    sep="\n",
)

Training set: 1282 goals, 1277 non-goals
Validation set: 161 goals, 159 non-goals
Testing set: 160 goals, 160 non-goals
