In [61]:
import pickle
import os
import random
from PIL import Image
import numpy as np

In [62]:
def shuffle(*arrs):
    arrs = list(arrs)
    for i, arr in enumerate(arrs):
        assert len(arrs[0]) == len(arrs[i])
        arrs[i] = np.array(arr)
    p = np.random.permutation(len(arrs[0]))
    return tuple(arr[p] for arr in arrs)

def to_categorical(y, nb_classes):
    y = np.asarray(y, dtype='int32')
    # high dimensional array warning
    if len(y.shape) > 2:
        warnings.warn('{}-dimensional array is used as input array.'.format(len(y.shape)), stacklevel=2)
    # flatten high dimensional array
    if len(y.shape) > 1:
        y = y.reshape(-1)
    if not nb_classes:
        nb_classes = np.max(y)+1
    Y = np.zeros((len(y), nb_classes))
    Y[np.arange(len(y)),y] = 1.
    return Y

def load_from_dir(directory, resize=None):
    train_samples = []
    train_labels = []
    test_samples = []
    test_labels = []
    label = 0
    dirs = sorted(os.walk(directory).__next__()[1])
    for d in dirs:
        files = sorted(os.walk(directory + d).__next__()[2])
        test_img_file = random.choice(files)
        for file in files:
            if test_img_file == file:
                test_img = Image.open(os.path.join(directory + d, file))
                if resize:
                    test_img = test_img.resize(resize, Image.ANTIALIAS)
                test_samples.append(np.asarray(test_img, dtype="float32")/255.)
                test_labels.append(label)
            img = Image.open(os.path.join(directory + d, file))
            if resize:
                img = img.resize(resize, Image.ANTIALIAS)
            train_samples.append(np.asarray(img, dtype="float32")/255.)
            train_labels.append(label)
        label += 1
    return train_samples, train_labels, test_samples, test_labels

def load_dataset(directory, dataset_file, resize=None, shuffle_data=False, one_hot=False):
    try:
        X_train, X_label, Y_test, Y_label = pickle.load(open(dataset_file, 'rb'))
    except Exception:
        X_train, X_label, Y_test, Y_label = load_from_dir(directory, resize)
        pickle.dump((X_train, X_label, Y_test, Y_label), open(dataset_file, 'wb'))
    if one_hot:
        X_label = to_categorical(X_label, np.max(X_label) + 1)
        Y_label = to_categorical(Y_label, np.max(Y_label) + 1)
    if shuffle_data:
        X_train, X_label = shuffle(X_train, X_label)
        Y_test, Y_label = shuffle(Y_test, Y_label)
    return X_train, X_label, Y_test, Y_label

In [63]:
img_dir = '/home/chrisjan/project/training/koi_train5/images/train/'
dataset_file = '/home/chrisjan/project/training/koi_train5/data/koi_dataset.pkl'

X_train, X_label, Y_test, Y_label = load_from_dir(img_dir, (150,200))

In [102]:
X_train[431]

array([[[ 1.        ,  1.        ,  1.        ,  1.        ],
        [ 1.        ,  1.        ,  1.        ,  1.        ],
        [ 1.        ,  1.        ,  1.        ,  1.        ],
        ..., 
        [ 1.        ,  1.        ,  1.        ,  1.        ],
        [ 1.        ,  1.        ,  1.        ,  1.        ],
        [ 1.        ,  1.        ,  1.        ,  1.        ]],

       [[ 1.        ,  1.        ,  1.        ,  1.        ],
        [ 1.        ,  1.        ,  1.        ,  1.        ],
        [ 0.97254902,  0.98823529,  0.99607843,  1.        ],
        ..., 
        [ 0.96862745,  0.98823529,  0.99215686,  1.        ],
        [ 0.96862745,  0.98823529,  0.99215686,  1.        ],
        [ 0.96862745,  0.98823529,  0.99215686,  1.        ]],

       [[ 1.        ,  1.        ,  1.        ,  1.        ],
        [ 0.99607843,  1.        ,  0.99607843,  1.        ],
        [ 0.38039216,  0.70588237,  0.87450981,  1.        ],
        ..., 
        [ 0.24313726,  0

In [41]:
X, Y = shuffle(X_train, X_label)

ValueError: could not broadcast input array from shape (400,300,3) into shape (400,300)

In [33]:
Y

array([12,  5, 15, 22, 12, 26, 21, 22, 16,  7, 14,  5, 24,  7,  6, 24,  5,
       21,  5, 17,  4, 19, 17,  9, 27, 18, 26, 19,  8, 24,  7, 26, 17,  5,
       11,  5, 23, 20, 21, 10,  3,  7, 17, 21, 18, 18, 11, 11,  5, 10,  8,
       14, 25, 28, 25, 17,  1, 12, 28,  2,  5, 17,  4, 11, 24,  0, 10, 20,
       18, 20, 23,  6, 24, 16, 25, 23,  0, 18,  6,  8, 26, 14, 17, 23, 15,
       16,  5, 12,  0, 10,  9, 27,  2,  8,  0, 25, 21,  2, 18, 16, 10,  1,
        3, 19, 27, 23,  8, 23, 21,  6, 27, 13,  7, 14,  8, 24, 28,  4, 17,
       25,  3, 13, 25, 23,  0, 18, 16,  1, 14, 12,  8, 23,  5, 12, 25,  7,
        9, 11, 19, 10, 14, 10, 14,  8, 18, 16, 29, 26, 19, 28, 20, 16,  3,
       22, 18, 17,  9, 14, 10, 14, 26, 25,  4, 20, 29,  6, 29,  4, 28, 16,
        4,  1, 10,  7, 20, 13, 20, 20, 10,  7,  0, 12,  9, 12, 25, 29, 12,
        1, 24, 15, 27, 21,  3,  9, 11,  3, 16,  6, 23, 15, 29, 17,  4, 12,
        4, 16, 23,  3, 19,  2, 16,  8, 25,  0, 29, 26,  3, 13, 19,  4, 28,
        1,  6, 20, 19, 24