In [None]:
from keras_dec import DeepEmbeddingClustering ## keras_dec.py
from keras.datasets import mnist
import numpy as np
from numpy import array
from PIL import Image
import PIL.ImageOps
from os import listdir
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.resnet50 import preprocess_input


def load_images(directory):
    """Load every image file in ``directory`` as an inverted grayscale array.

    Files are visited in sorted name order so the returned list (and any
    seeded shuffle applied to it afterwards) is reproducible; ``listdir``
    alone gives no ordering guarantee. Entries that are not regular files
    (e.g. sub-directories) are skipped instead of being handed to the
    image loader.

    Args:
        directory: path of the folder containing the images.

    Returns:
        A list with one array per image, as produced by ``img_to_array``
        after resizing to 250x250 and inverting the pixel values.
        Presumably each array is (250, 250, 1) -- confirm against the
        installed Keras image data format.
    """
    import os.path  # local import: the file only imports `listdir` from os

    images = []
    for name in sorted(listdir(directory)):
        filename = os.path.join(directory, name)
        if not os.path.isfile(filename):
            continue  # ignore nested folders and other non-file entries
        image = load_img(filename, grayscale=True, target_size=(250, 250))
        image = PIL.ImageOps.invert(image)
        images.append(img_to_array(image))
    return images

def get_mnist(source=None, fpath='mnist.npz'):
    """Load MNIST (train + test), flatten, shuffle and scale it.

    Only one data source is touched per call: the original always called
    ``mnist.load_data()`` first, which forced a download even in the
    offline ``'digits'`` mode.

    Args:
        source: ``'digits'`` reads the local archive at ``fpath``; any other
            value loads through ``keras.datasets.mnist``. When ``None``
            (the default) the module-level ``data_type`` is used.
        fpath: path of the local ``.npz`` archive holding the arrays
            ``x_train``, ``y_train``, ``x_test`` and ``y_test``.

    Returns:
        ``(X, Y)`` where ``X`` is a float32 array with one flattened image
        per row, multiplied by 0.02, and ``Y`` holds the labels in the same
        shuffled order.
    """
    if source is None:
        source = data_type

    np.random.seed(1234)  # set seed for deterministic ordering

    if source == 'digits':
        # Offline copy: never reach out to the network.
        with np.load(fpath) as data:
            x_train, y_train = data['x_train'], data['y_train']
            x_test, y_test = data['x_test'], data['y_test']
    else:
        (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_all = np.concatenate((x_train, x_test), axis=0)
    Y = np.concatenate((y_train, y_test), axis=0)
    # Flatten each (height, width) image into a single feature vector.
    X = x_all.reshape(-1, x_all.shape[1] * x_all.shape[2])

    # Apply the same permutation to samples and labels so they stay aligned.
    p = np.random.permutation(X.shape[0])
    X = X[p].astype(np.float32) * 0.02
    Y = Y[p]
    return X, Y

data_type = '' ## change to 'mnist' for mnist online, 'digits' for mnist offline, 'images' for local image dir 

if data_type in ('mnist', 'digits'):
    # Labelled MNIST digits: cluster and pass the labels to cluster().
    X, Y = get_mnist()
    # Feature size is taken from the data (28 x 28 grayscale -> 784).
    c = DeepEmbeddingClustering(n_clusters=10, input_dim=X.shape[1])
    c.initialize(X, finetune_iters=100000, layerwise_pretrain_iters=50000)
    c.cluster(X, y=Y)

elif data_type == 'images':
    # Unlabelled local images: there is no Y in this branch.
    np.random.seed(1234)
    directory = 'xx'
    x_all = array(load_images(directory))
    if x_all.ndim != 4:
        # An empty directory yields a 1-D empty array that cannot be indexed
        # below; fail with a clear message instead.
        raise ValueError('no usable images found in %r' % directory)
    x_all = x_all[:, :, :, 0]  # drop the single grayscale channel axis
    X = x_all.reshape(-1, x_all.shape[1] * x_all.shape[2])
    p = np.random.permutation(X.shape[0])
    X = X[p].astype(np.float32) * 0.02
    # The original also ran `Y = Y[p]` here, but Y is never defined for
    # images and the line raised NameError; it has been removed.

    # Feature size is taken from the data (250 x 250 grayscale -> 62500).
    c = DeepEmbeddingClustering(n_clusters=10, input_dim=X.shape[1])
    c.initialize(X, finetune_iters=100000, layerwise_pretrain_iters=50000)
    c.cluster(X, iter_max=3000)
    
