In [1]:
# imports and functions
import csv
import os
from PIL import Image

# Target size every image is resized to before its pixels are read:
# index 0 is the x extent, index 1 is the y extent.
IMAGE_SIZE = (32, 32)

# Labels reported in the per-label counts; a file's label is the text
# before the first underscore of its name.
labels = ["ocean", "ship", "sky"]

def normalize_rgb(r, g, b):
    """Scale three colour components by 1/255.

    The components are expected to be 8-bit values in the range 0-255,
    which maps them onto 0.0-1.0. No clamping or validation is done, so
    out-of-range inputs produce out-of-range outputs.

    Returns a 3-tuple of floats in the order (r, g, b).
    """
    full_scale = 255.0
    return tuple(channel / full_scale for channel in (r, g, b))

def extract_features_labels(file_list):
    """Build per-image feature vectors and labels from image file names.

    Each entry of ``file_list`` is a bare file name located in the
    module-level ``image_dir`` and named
    ``<label>_<dataset>_<index>.<extension>``. Names starting with "."
    (hidden files) are skipped.

    For every remaining file the image is:
      * opened and converted to RGB, so every pixel is an (r, g, b) triple
        regardless of the source mode (RGBA, greyscale, palette, ...),
      * resized to ``IMAGE_SIZE`` with Lanczos resampling,
      * flattened into one list of normalized r, g, b values, walking x in
        the outer loop and y in the inner loop,
      * saved as ``<label>_<dataset>_<index>.png`` in the module-level
        ``tmp_dir``.

    Returns:
        (X, Y): ``X`` is a list of feature lists (one per image, length
        ``IMAGE_SIZE[0] * IMAGE_SIZE[1] * 3``), ``Y`` is the list of label
        strings in the same order.

    Raises:
        ValueError: if a file name (without its extension) does not consist
            of exactly three underscore-separated parts.
    """
    # X contains the features, Y contains the classes
    X = []
    Y = []

    for file_name in file_list:
        # skip hidden files such as .DS_Store
        if file_name.startswith("."):
            continue

        # label_dataset_index.extension -> strip the extension first so
        # extra dots (or a missing extension) do not break the parsing
        stem, _ = os.path.splitext(file_name)
        label, dataset_name, dataset_index = stem.split("_")

        # The context manager closes the underlying file; convert() loads
        # the pixel data and returns an independent RGB image.
        path = os.path.join(image_dir, file_name)
        with Image.open(path) as source:
            im = source.convert("RGB")
        im = im.resize(IMAGE_SIZE, resample=Image.LANCZOS)

        # get the rgb color components of all pixels
        image_features = []
        for x in range(IMAGE_SIZE[0]):
            for y in range(IMAGE_SIZE[1]):
                image_features.extend(normalize_rgb(*im.getpixel((x, y))))

        # keep a resized PNG copy of what was actually featurized
        new_name = "%s_%s_%s.%s" % (label, dataset_name, dataset_index, "png")
        im.save(os.path.join(tmp_dir, new_name), "PNG")

        # append together so X and Y always stay aligned
        X.append(image_features)
        Y.append(label)

    return X, Y

In [2]:
# directory structure
# Both paths are relative to the notebook's working directory (cwd).
cwd = os.getcwd()
image_dir = "../images/combined"  # input images read by extract_features_labels
tmp_dir = "../images/tmp"  # resized PNG copies are written here
print("cwd = " + cwd)
print("image_dir = " + image_dir)
print("tmp_dir = " + tmp_dir)

# makedirs(exist_ok=True) is a no-op when the directory already exists,
# creates missing parent directories, and avoids the check-then-create race
# of os.path.exists() followed by os.mkdir().
os.makedirs(tmp_dir, exist_ok=True)

cwd = C:\Users\jchadwick\Documents\ml-project\code
image_dir = ../images/combined
tmp_dir = ../images/tmp


In [3]:
# process images, create labels and features
# os.listdir() returns names in arbitrary order; sorting makes the row order
# of X / y (and of anything exported from them) reproducible between runs.
files = sorted(os.listdir(image_dir))

X, y = extract_features_labels(files)

# every feature row must have exactly one label
assert len(X) == len(y), "feature rows and labels are out of step"

# some stats about the dataset
print("%s observations" % (len(y)))

print("Label counts:")
for label in labels:
    print("%s - %s" % (label, y.count(label)))

3347 observations
Label counts:
ocean - 504
ship - 2347
sky - 496


In [4]:
# export to data.csv
# One row per observation: the label first, then its feature values.
# newline='' is required by the csv module; without it the writer's own
# "\r\n" terminator is translated again on Windows, producing a blank line
# after every row.
with open('data.csv', 'w', newline='') as csvfile:
    data_writer = csv.writer(csvfile, dialect='excel')
    for label, features in zip(y, X):
        row = [label]
        row.extend(features)
        data_writer.writerow(row)