In [1]:
#!/usr/bin/env python

In [2]:
import argparse
import os
import sys

from sklearn.model_selection import train_test_split

In [3]:
parser = argparse.ArgumentParser()

parser.add_argument('--imagedir', '-i')
parser.add_argument('--keep', '-k', default=1.0) #... Used to 'down sample' images culled from video
parser.add_argument('--numimages', '-n', default=1e6) #... Mostly used for testing
parser.add_argument('--split', '-s', default=0.2)

_StoreAction(option_strings=['--split', '-s'], dest='split', nargs=None, const=None, default=0.2, type=None, choices=None, help=None, metavar=None)

In [5]:
try:
    get_ipython().__class__.__name__
    args = parser.parse_args(['-i=data/union/gunks/trapps'])
    print('In Jupyter...')
except:
    args = parser.parse_args()
    print('NOT in Jupyter...')
    
image_dir  = os.path.join(os.getcwd(), args.imagedir.strip())
num_images = args.numimages
split      = args.split
keep       = args.keep

In Jupyter...


In [6]:
#... NB: Trailing dash is ok
training_dir   = os.path.join(image_dir, 'trainval', 'training' + '-')
validation_dir = os.path.join(image_dir, 'trainval', 'validation' + '-')

In [7]:
image_list = []
class_list = []
i = 0
done = False

class_list = sorted(os.listdir(image_dir))

if 'trainval' in class_list:
    class_list.remove('trainval')
if 'unlabeled' in class_list:
    class_list.remove('unlabeled')

for subdir in class_list:
            
    for file in sorted(os.listdir(os.path.join(image_dir, subdir))):
        i += 1
        
        if i > num_images:
            done = True
            break
                
        if i % (1.0/keep) == 0:
            image_list.append(os.path.join(image_dir, subdir, file))
            
    if done: break

In [8]:
os.makedirs(training_dir, exist_ok=True)
os.makedirs(validation_dir, exist_ok=True)

In [9]:
 class_list

['apecall',
 'apoplexy',
 'baby',
 'bellyroll',
 'betty',
 'blackfly',
 'boston',
 'bunny',
 'ccbbroute',
 'citylights',
 'classic',
 'dirtychimney',
 'doublechin',
 'doubleclutch',
 'drunkardsdelight',
 'eyesore',
 'frogshead',
 'gorillamydreams',
 'handyandy',
 'herdiegerdieblock',
 'horseman',
 'jackie',
 'jane',
 'kenscrack',
 'laurel',
 'lowereaves',
 'mariadirect',
 'missbailey',
 'morningafter',
 'nosediveandretribution',
 'p38',
 'pasdedeux',
 'pauxdedeux',
 'pinklaurel',
 'raubenheimerspecial',
 'rhododendron',
 'ribs',
 'rmc',
 'rustyjam',
 'sixish',
 'sonofeasyo',
 'splashtic',
 'squiggles',
 'sundown',
 'suzieablock',
 'trappedlikearat',
 'uberfalldownclimb']

In [10]:
for subdir in class_list:
    print(subdir)
    os.makedirs(os.path.join(training_dir,   subdir), exist_ok=True)
    os.makedirs(os.path.join(validation_dir, subdir), exist_ok=True)

apecall
apoplexy
baby
bellyroll
betty
blackfly
boston
bunny
ccbbroute
citylights
classic
dirtychimney
doublechin
doubleclutch
drunkardsdelight
eyesore
frogshead
gorillamydreams
handyandy
herdiegerdieblock
horseman
jackie
jane
kenscrack
laurel
lowereaves
mariadirect
missbailey
morningafter
nosediveandretribution
p38
pasdedeux
pauxdedeux
pinklaurel
raubenheimerspecial
rhododendron
ribs
rmc
rustyjam
sixish
sonofeasyo
splashtic
squiggles
sundown
suzieablock
trappedlikearat
uberfalldownclimb


In [11]:
train_samples, validation_samples = train_test_split(image_list, test_size=0.2)
num_samples = len(train_samples) + len(validation_samples)
print(len(train_samples), len(validation_samples))

4632 1159


In [12]:
for src in validation_samples:
    
    label, file = src.split('/')[-2:]
    link = os.path.join(validation_dir, label, file)
    
    if not os.path.islink(link):
        try:
            print(src, link)
            os.symlink(src,link)
        except:
            print('ERROR: Could not create link: ' + src, link)

/home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/jackie/jackie_20180310_150838_008.jpg /home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/trainval/validation-/jackie/jackie_20180310_150838_008.jpg
/home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/pasdedeux/pasdedeux-gopro-0000000308.jpg /home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/trainval/validation-/pasdedeux/pasdedeux-gopro-0000000308.jpg
/home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/splashtic/splashtic_20180505_131026_010.jpg /home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/trainval/validation-/splashtic/splashtic_20180505_131026_010.jpg
/home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/jackie/jackie-gopro-0000000066.jpg /home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/trainval/validation-/jackie/jackie-gopro-0000000066.jpg
/home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/ribs/ribs_20180505_101842_016.jpg /home/jo

In [13]:
for src in train_samples:
    label, file = src.split('/')[-2:]
    link = os.path.join(training_dir, label, file)
    
    if not os.path.islink(link):
        try:
            print(src, link)
            os.symlink(src,link)
        except:
            print('ERROR: Could not create link: ' + src, link)

/home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/doublechin/doublechin_20180310_142154_006.jpg /home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/trainval/training-/doublechin/doublechin_20180310_142154_006.jpg
/home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/frogshead/frogshead-gopro-0000000162.jpg /home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/trainval/training-/frogshead/frogshead-gopro-0000000162.jpg
/home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/trappedlikearat/trappedlikearat_20180310_144444_012.jpg /home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/trainval/training-/trappedlikearat/trappedlikearat_20180310_144444_012.jpg
/home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/herdiegerdieblock/herdiegerdieblock-gopro-0000000365.jpg /home/joeantol/joeantolwork/project-x/data/union/gunks/trapps/trainval/training-/herdiegerdieblock/herdiegerdieblock-gopro-0000000365.jpg
/home/joeantol/joeantolwork/

In [14]:
cwd = os.getcwd()
os.chdir(os.path.join(image_dir, 'trainval'))

os.rename(training_dir, training_dir + str(num_samples))
os.rename(validation_dir, validation_dir + str(num_samples))

for link in ['training', 'validation']:
    if os.path.islink(link):
        os.unlink(link)
        
    os.symlink(link + '-' + str(num_samples), link)

os.chdir(cwd)

In [15]:
print('Successful completion of train_test_split...')

Successful completion of train_test_split...
