# Transform the image data

In [None]:
import _pickle as pickle
import glob
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [None]:
class data_manager(object):
    '''
    Load, preprocess and pickle a train/validation image dataset.

    Images are read from ``<data_dir>/train/*.png`` and ``<data_dir>/val/*.png``.
    The class label of an image is the part of its file name before the first
    underscore (e.g. ``apple_12.png`` -> ``apple``); images whose label is not in
    ``class_labels`` are skipped. Construction loads both sets immediately and
    pickles them to ``training_data.pickle`` and ``test_data.pickle`` inside
    ``data_dir``.

    Each sample is a dict with keys:
        'c_img'    -- the image exactly as returned by ``cv2.imread`` (BGR order)
        'label'    -- one-hot numpy vector of length ``num_class``
        'features' -- the image resized to ``image_size`` x ``image_size`` with
                      pixel values mapped from [0, 255] to [-1, 1]
    '''

    def __init__(self, data_dir, class_labels, image_size):
        '''
        Args:
            data_dir: root directory containing ``train`` and ``val`` subdirectories;
                the pickled datasets are also written here.
            class_labels: ordered list of class names; the position of a name is
                its index in the one-hot label vector.
            image_size: side length in pixels of the square feature images.
        '''
        self.data_dir = data_dir
        self.class_labels = class_labels
        self.num_class = len(self.class_labels)
        self.image_size = image_size

        self.load_train_set()
        self.load_validation_set()
        self.pickle_datasets()

    def compute_label(self, label):
        '''
        Return the one-hot encoding of ``label`` over ``class_labels``.

        Raises:
            ValueError: if ``label`` is not in ``class_labels``.
        '''
        one_hot = np.zeros(self.num_class)
        one_hot[self.class_labels.index(label)] = 1.0
        return one_hot

    def compute_feature(self, image):
        '''
        Resize ``image`` to ``image_size`` x ``image_size`` and standardize its
        pixel values from [0, 255] to [-1, 1].
        '''
        image = cv2.resize(image, (self.image_size, self.image_size))
        return (image / 255.0) * 2.0 - 1.0

    def load_set(self, set_name):
        '''
        Load every labelled ``*.png`` image in ``<data_dir>/<set_name>``.

        Files whose label is not in ``class_labels`` are ignored; files OpenCV
        cannot decode are skipped with a warning. The returned list of sample
        dicts is shuffled with ``np.random.shuffle``.
        '''
        data = []

        # glob's order depends on the file system; sort it so the shuffled result
        # is reproducible whenever np.random is seeded.
        data_paths = sorted(glob.glob(os.path.join(self.data_dir, set_name, '*.png')))

        for data_path in data_paths:
            label = os.path.basename(data_path).split("_")[0]
            if label not in self.class_labels:
                continue

            img = cv2.imread(data_path)
            if img is None:
                # cv2.imread returns None (it does not raise) for unreadable or
                # corrupt files; cv2.resize would then fail with an opaque error.
                print("Warning: could not read image %s, skipping" % data_path)
                continue

            data.append({'c_img': img,
                         'label': self.compute_label(label),
                         'features': self.compute_feature(img)})

        np.random.shuffle(data)
        return data

    def load_train_set(self):
        '''
        Load the ``train`` subdirectory into ``self.train_data``.
        '''
        self.train_data = self.load_set('train')

    def load_validation_set(self):
        '''
        Load the ``val`` subdirectory into ``self.val_data``.
        '''
        self.val_data = self.load_set('val')

    def pickle_data(self, data, fname):
        '''
        Pickle ``data`` to the file ``fname`` inside ``data_dir``.
        '''
        with open(os.path.join(self.data_dir, fname), 'wb') as f:
            pickle.dump(data, f)

    def pickle_datasets(self):
        '''
        Write the train and validation sets to the data directory.
        '''
        self.pickle_data(self.train_data, "training_data.pickle")
        # The validation set is stored under the name "test_data.pickle".
        self.pickle_data(self.val_data, "test_data.pickle")
        

In [None]:
# Build the path with os.path.join so it works on every OS: the literal ".\\data"
# only resolves on Windows (elsewhere it names a directory called ".\data").
data_dir = os.path.join(".", "data")
CLASS_LABELS = ['apple','banana','nectarine','plum','peach','watermelon','pear','mango','grape',
                'orange','strawberry','pineapple','radish','carrot','potato','tomato','bellpepper',
                'broccoli','cabbage','cauliflower','celery','eggplant','garlic','spinach','ginger']
# Side length (pixels) of the square feature images.
image_size = 90
# Constructing the manager loads both splits and pickles them into data_dir.
dm = data_manager(data_dir, CLASS_LABELS, image_size)

In [None]:
# Report how many samples each split contains.
n_train = len(dm.train_data)
print("Total number of training samples is: %i" % n_train)
n_test = len(dm.val_data)
print("Total number of test samples is: %i" % n_test)

In [None]:
# Count the number of samples per class in a dataset
def summarize_data(data, class_labels):
    '''
    Return a dict mapping each class name to its sample count in ``data``.

    Counts are obtained by summing the one-hot 'label' vectors, so they are
    floats (e.g. 12.0); keys follow the order of ``class_labels``.
    '''
    totals = np.zeros(len(class_labels))
    for sample in data:
        np.add(totals, sample['label'], out=totals)
    return dict(zip(class_labels, totals))

def print_summary(data, class_labels):
    '''
    Print one line per class with the number of samples of that class in ``data``.
    '''
    for name, count in summarize_data(data, class_labels).items():
        # summarize_data returns float counts (sums of one-hot vectors); print
        # them as integers so the table reads "12" rather than "12.0".
        print("Class: {:<15} sample counts: {:<15}".format(name, int(count)))

# Print the per-class sample counts of each split.
for heading, split in (("Training sample counts", dm.train_data),
                       ("Test sample counts", dm.val_data)):
    print(heading)
    print_summary(split, CLASS_LABELS)

In [None]:
# Display a sample image from a dataset, titled with its class name
def display_data(data, class_labels, data_id):
    '''
    Show ``data[data_id]`` in a new matplotlib figure, titled with its class name.

    Args:
        data: list of sample dicts with a BGR image under 'c_img' and a one-hot
            vector under 'label'.
        class_labels: class names in one-hot index order.
        data_id: index of the sample to display.

    Raises:
        IndexError: if ``data_id`` is out of range or the label has no entry equal to 1.
    '''
    sample = data[data_id]
    # Resolve the title before opening a figure so a bad index leaves no empty plot.
    class_idx = np.where(sample['label'] == 1)[0][0]
    plt.figure()
    plt.title(class_labels[class_idx], fontsize=20)
    # OpenCV images are BGR; matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(sample['c_img'], cv2.COLOR_BGR2RGB))
    plt.show()

# display the sample at index 200 (the 201st sample) of train_data
display_data(dm.train_data, CLASS_LABELS, 200)

In [None]:
# display the sample at index 200 (the 201st sample) of val_data
display_data(dm.val_data, CLASS_LABELS, 200)