# Shapesdata

In [1]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [55]:
def _load_shape_stack(directory, count, filename="drawing", extension=".png"):
    """
    Load the drawings numbered 1..count found in `directory` into one array.

    Arguments:
    directory -- path prefix (ending with a separator) of the folder holding the drawings
    count -- number of drawings to read; files are named "<filename>(<i>)<extension>"
    filename -- common prefix of the file names
    extension -- file extension, including the leading dot

    Returns:
    stack -- float64 numpy array of shape (count, 28, 28, 3)

    Raises:
    ValueError -- if a drawing does not decode to an array of shape (28, 28, 3)
    """
    # Allocate the result once; growing an array with np.append copies the
    # whole array on every iteration.
    stack = np.empty((count, 28, 28, 3))

    for position in range(count):
        file_path = directory + filename + "(" + str(position + 1) + ")" + extension

        # The context manager closes the underlying file as soon as the
        # pixels have been copied out of the image.
        with Image.open(file_path) as img:
            pixels = np.asarray(img)

        if pixels.shape != stack.shape[1:]:
            raise ValueError(
                "{} has shape {}, expected {}".format(
                    file_path, pixels.shape, stack.shape[1:]
                )
            )

        stack[position] = pixels

    return stack


def load_images(folder="shapes/", count=100):
    """
    Read the circle, square and triangle drawings from disk.

    Arguments:
    folder -- root folder containing the "circles/", "squares/" and "triangles/" sub-folders
    count -- number of drawings to read from each sub-folder

    Returns:
    images -- three numpy arrays of shape (count, 28, 28, 3), one for each shape (circle, square, triangle)
    """

    circles = _load_shape_stack(folder + "circles/", count)
    squares = _load_shape_stack(folder + "squares/", count)
    triangles = _load_shape_stack(folder + "triangles/", count)

    return circles, squares, triangles

In [56]:
# Load the three per-shape image stacks. The bare expression on the last line
# makes the notebook display the shape of the circles array as the cell output.
circles, squares, triangles = load_images()
circles.shape

(100, 28, 28, 3)

In [61]:
def split_set(data, splits=(.7, .85)):
    """
    Shuffle `data` reproducibly and cut it into consecutive subsets.

    Arguments:
        data -- sequence to be split (numpy array or list); it is not modified
        splits -- increasing fractions between 0 and 1 marking the splitting points.
                  Ex.: (.7, .85) will split in two points producing three sets that
                  hold 70%, 15% and 15% of the data.

    Returns:
        tuple with len(splits) + 1 subsets; with the default splits:
        (train_set, dev_set, test_set)

    Raises:
        ValueError -- if `splits` is not an increasing sequence of fractions in [0, 1]
    """

    fractions = list(splits)
    if any(not 0 <= fraction <= 1 for fraction in fractions) or fractions != sorted(fractions):
        raise ValueError(
            "splits must be increasing fractions between 0 and 1, got {}".format(fractions)
        )

    # Shuffle a copy so the caller's data keeps its original order.
    shuffled = data.copy() if isinstance(data, np.ndarray) else list(data)

    # A private, fixed-seed generator keeps the split reproducible (identical
    # input always yields identical subsets) without resetting the global
    # numpy random state as a side effect.
    np.random.RandomState(42).shuffle(shuffled)

    # Convert the fractions into indices and slice between consecutive bounds.
    total = len(shuffled)
    bounds = [0] + [int(round(fraction * total)) for fraction in fractions] + [total]

    return tuple(shuffled[start:stop] for start, stop in zip(bounds[:-1], bounds[1:]))

In [161]:
def load_data():
    """
    Build shuffled train/dev/test sets of shape drawings with their labels.

    Each shape's images are split with split_set, labelled with the shape name
    ("circle", "square" or "triangle"), merged across shapes and shuffled.

    Arguments:
    None

    Returns:
    (train_x, train_y, dev_x, dev_y, test_x, test_y) -- the *_x arrays hold the
    images of each subset; the matching *_y arrays hold one label per image as
    a column vector of shape (n, 1), in the same order as the images.
    """

    class_names = ("circle", "square", "triangle")

    # One list per subset (train, dev, test), filled with one entry per shape.
    x_parts = ([], [], [])
    y_parts = ([], [], [])

    for name, images in zip(class_names, load_images()):
        for xs, ys, subset in zip(x_parts, y_parts, split_set(images)):
            xs.append(subset)
            # Every image of a shape has the same label, so the labels are
            # built from the subset's actual length instead of being shuffled
            # and split separately from the images.
            ys.append(np.full((len(subset),), name))

    np.random.seed(42)

    merged = []
    for xs, ys in zip(x_parts, y_parts):
        x = np.concatenate(xs, axis=0)
        y = np.concatenate(ys, axis=0)

        # The same permutation is applied to images and labels to keep them aligned.
        order = np.random.permutation(len(x))
        merged.append(x[order])
        merged.append(y[order][:, np.newaxis])

    return tuple(merged)

# train_x, train_y, dev_x, dev_y, test_x, test_y = load_data()

# print(test_x.shape)

In [159]:
def print_img(array, array_label, index):
    """
    Display one image and print its label.

    Arguments:
    array -- indexable collection of images
    array_label -- indexable collection of labels, in the same order as `array`
    index -- position of the image/label pair to show
    """
    # Cast the selected image to uint8 before handing it to matplotlib.
    pixels = array[index].astype(np.uint8)
    plt.imshow(pixels)

    label = array_label[index]
    print("This is a {}".format(label))
    
# print_img(dev_x, dev_y, 65)