# In [1]:
'''
Load a labeled image dataset with the following folder structure
./root/label1/xxx.png
./root/label1/xxy.png
./root/label2/yxx.png
./root/label2/zxx.png
...

Return 
- a numpy array with image pixels, 
- a numpy array with the label (vector of integers)
- list of labels (list of string of dirnames)

May optionally save and load using a npz file

All images must have the same shape
Supports png, jpg, and jpeg files
'''

import os
import warnings
from glob import glob

import numpy as np
from PIL import Image


def load_files(data, paths, data_dir):
    """Print the joined path of every entry found directly in ``data_dir``.

    ``data`` and ``paths`` are accepted but not used by the current
    implementation. Nothing is returned.
    """
    for entry in os.listdir(data_dir):
        print(os.path.join(data_dir, entry))
    

def get_np_img(path):
    """Read the image file at ``path`` and return its pixels as a numpy array.

    Any exception raised by PIL while opening or decoding the file is
    propagated to the caller.
    """
    # Image.open is lazy and holds on to the file handle; the context manager
    # guarantees the handle is released even when decoding fails part-way.
    with Image.open(path) as img:
        return np.array(img)

def build_img_label_dataset(data_dir, max_per_class = None,
                           save_path = None):
    """Build an image/label dataset from a directory of class sub-folders.

    Parameters
    ----------
    data_dir : path to the root folder; each sub-directory is one class.
    max_per_class : optional cap on the number of successfully loaded images
        per class. ``None`` (or 0) means no limit.
    save_path : optional path of an ``.npz`` cache. When the cache exists it
        is loaded instead of reading the images; otherwise the freshly built
        arrays are written to it.

    Returns
    -------
    X : numpy array of stacked image pixels (first axis indexes the images).
    y : numpy array of integer class indices, one per image.
    labels : sorted list of class names (sub-directory names); ``y`` indexes
        into this list.

    Raises
    ------
    ValueError
        If no image could be loaded, or the images do not share one shape.
    """
    # Only sub-directories are classes. Stray files in the root (a cache file,
    # .DS_Store, ...) must not become labels and shift the class indices.
    labels = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )

    cache_path = None
    if save_path:
        # np.savez appends '.npz' when it is missing, so the cache has to be
        # looked up under the name it is actually written to as well.
        cache_path = str(save_path)
        if not cache_path.endswith('.npz'):
            cache_path += '.npz'
        for candidate in (save_path, cache_path):
            if os.path.exists(candidate):
                # Indexing reads the arrays fully into memory, so the archive
                # can be closed before returning.
                with np.load(candidate) as cached:
                    return cached['x'], cached['y'], labels

    xs = []
    ys = []

    for yi, label in enumerate(labels):
        class_dir = os.path.join(data_dir, label)
        img_files = []
        for ext in ('*.png', '*.jpg', '*.jpeg'):
            img_files.extend(glob(os.path.join(class_dir, ext)))
        # glob order is filesystem dependent; sort so that the subset picked
        # by max_per_class (and the row order of X) is reproducible.
        img_files.sort()

        nvalid = 0
        for path in img_files:
            if max_per_class and nvalid >= max_per_class:
                break

            try:
                pixels = get_np_img(path)
            except Exception as err:
                # Best effort: skip unreadable files, but say so, and do not
                # swallow KeyboardInterrupt / SystemExit.
                warnings.warn(
                    'Skipping unreadable image {!r}: {}'.format(path, err))
                continue

            xs.append(pixels)
            ys.append(yi)
            nvalid += 1

    if not xs:
        raise ValueError(
            'No readable images found under {!r}'.format(data_dir))

    # Raises ValueError when the images do not all share the same shape.
    X = np.stack(xs, 0)
    # uint8 can only index 256 classes; widen instead of wrapping/overflowing.
    label_dtype = np.uint8 if len(labels) <= 256 else np.int64
    y = np.array(ys, dtype=label_dtype)

    if save_path:
        np.savez(cache_path, x=X, y=y)

    return X, y, labels