#### Definition of Clustering Function:

In [41]:
### Clustering:

from sklearn.cluster import KMeans

# Function for data clustering, and computation of cluster center vectors and inv. sq. deviation vectors
def form_clusters(data, n_kmeans, min_dev=1e-6):
    """Cluster ``data`` with k-means and summarise every usable cluster.

    Parameters
    ----------
    data : 2-D array
        Matrix whose rows are clustered; it is indexed as ``data[mask, :]``.
    n_kmeans : int
        Number of clusters requested from ``KMeans``.
    min_dev : float, optional
        Value substituted for a per-feature standard deviation that is
        exactly zero, so that the reciprocal below stays finite.

    Returns
    -------
    (centers, inv_sq_devs) : tuple of two lists
        ``centers[i]`` is the per-feature mean of cluster ``i`` and
        ``inv_sq_devs[i]`` is ``1 / std**2`` per feature for the same cluster.
        Clusters with fewer than two members are left out of both lists, so
        the lists may be shorter than ``n_kmeans`` (or empty).
    """
    kmeans = KMeans(n_clusters=n_kmeans, random_state=0, init='k-means++', algorithm='elkan')
    lbls_kmeans = kmeans.fit_predict(data)

    # Kept as a list so that label vectors from other clustering methods can be appended.
    lbls_set = [lbls_kmeans]

    centers = []
    inv_sq_devs = []

    # Per cluster: mean vector and reciprocal squared per-feature deviation.
    for lbls in lbls_set:
        for k in range(int(lbls.max()) + 1):
            members = lbls == k
            # Ignore clusters with fewer than two elements (the spread cannot be determined).
            if np.count_nonzero(members) < 2:
                continue
            cluster = data[members, :]
            centers.append(cluster.mean(axis=0))
            dev = np.std(cluster, axis=0)
            dev[dev == 0] = min_dev  # avoid Infs and NaNs in the reciprocal
            inv_sq_devs.append(np.reciprocal(np.square(dev)))
    # No explicit `del` of the loop temporaries: they are function locals, and
    # deleting them raised UnboundLocalError when every cluster was skipped.
    return centers, inv_sq_devs


#### Definition of RBF (Fi) Function:

In [42]:
# Function to calculate one Fi column for given dataset (for one RB center)
def calc_fi_column(data, center, inv_sq_devs):
    """Compute one Fi column: the response of every row of ``data`` to one centre.

    For each row ``x`` the value is
    ``exp(-sqrt(sum_j (x_j - center_j)**2 * inv_sq_devs_j))``.

    Parameters
    ----------
    data : 2-D array, one sample per row
        Dataset for which the Fi values are computed.
    center : 1-D array-like, one entry per column of ``data``
        Centre vector of the basis function.
    inv_sq_devs : 1-D array-like, one entry per column of ``data``
        Reciprocal squared deviation per feature, used as weights.

    Returns
    -------
    1-D array with one value per row of ``data``.
    """
    # Coerce the centre to a float64 row so the subtraction broadcasts over the
    # rows of `data` and is carried out in floating point even for integer
    # (e.g. uint8) inputs, which would otherwise wrap around.
    center_row = np.asarray(center, dtype=np.float64).reshape(1, -1)
    sq_diff = np.square(data - center_row)
    weighted_sq_dist = np.dot(sq_diff, inv_sq_devs)
    return np.exp(-np.sqrt(weighted_sq_dist))

# Function to compute whole Fi output for given dataset (for all RB centers)
def fi_transform(data, all_centers, all_inv_sq_devs):
    """Build the full Fi matrix of ``data`` for every RB centre.

    data            - dataset matrix for which the Fi values are computed
    all_centers     - list of centre vectors, one per output column
    all_inv_sq_devs - list of reciprocal sq. deviation vectors, indexed in
                      parallel with ``all_centers``

    Returns a matrix with one row per row of ``data`` and one column per
    centre; column ``i`` is ``calc_fi_column`` evaluated for centre ``i``.
    """
    n_rows = data.shape[0]
    n_centers = len(all_centers)
    fi_matrix = np.empty((n_rows, n_centers))
    for idx, center in enumerate(all_centers):
        fi_matrix[:, idx] = calc_fi_column(data, center, all_inv_sq_devs[idx])

    return fi_matrix

#### Definition of RBF Transformer (sklearn compatible object):

In [46]:
# http://scikit-learn.org/stable/modules/classes.html#module-sklearn.base

from sklearn.base import TransformerMixin

# SKLEARN Compatible Transformer - supports fit method (clustering data) and transform method (calculating fi values)
class myRBFtransformer(TransformerMixin):
    """Transformer that maps samples to Fi values around per-label cluster centres.

    ``fit`` clusters the samples of each label separately with
    ``form_clusters`` and stores the resulting centres and reciprocal squared
    deviations; ``transform`` evaluates ``fi_transform`` against all of them,
    giving one output column per stored centre.
    """

    # Transformer initialization (default to 50 kMeans clusters per label)
    def __init__(self, n_kmeans=50):
        self.centers     = []
        self.inv_sq_devs = []
        self.n_kmeans    = n_kmeans

    def fit(self, X, y):
        """Cluster the rows of ``X`` label by label and store the cluster statistics.

        X - 2-D sample matrix, indexed as ``X[mask, :]``
        y - label per row of ``X``

        Labels are taken from the values actually present in ``y`` (in sorted
        order), so a label with no samples never reaches the clustering step.
        Any previously fitted state is discarded. Returns ``self``.
        """
        y = np.asarray(y)
        self.centers = []
        self.inv_sq_devs = []
        for label in np.unique(y):
            data = X[y == label, :]
            centers, inv_sq_devs = form_clusters(data, self.n_kmeans)
            self.centers.extend(centers)
            self.inv_sq_devs.extend(inv_sq_devs)
        return self

    def transform(self, X, y=None):
        """Return the Fi matrix of ``X`` for the centres found by ``fit``.

        ``y`` is accepted but not used. Each value is
        ``exp(-sqrt(weighted squared distance))`` as computed by
        ``fi_transform`` / ``calc_fi_column``.
        """
        return fi_transform(X, self.centers, self.inv_sq_devs)

In [47]:
import numpy as np
# Relative locations of the data directory and its 'Clusters/' subdirectory.
DATA_folder  = '../../Data/'
CLUSTER_folder = DATA_folder+'Clusters/'
### Data Transformation:
# Load the two .npy arrays used below: the sample matrix and its label vector.
data = np.load(CLUSTER_folder+'train_imgs.npy')
lbls = np.load(CLUSTER_folder+'train_lbls.npy')
# Shape of the sub-matrix holding only the rows whose label is 0.
data[lbls==0,:].shape

(5923, 784)

In [49]:
# Build the transformer with 10 k-means clusters per label instead of the default 50.
rbf = myRBFtransformer(n_kmeans = 10)

In [50]:
# Fit on the first 5000 samples and transform those same samples in one call.
# fit_transform is not defined in this notebook; presumably it is inherited from TransformerMixin.
a = rbf.fit_transform(X=data[:5000,:],y=lbls[:5000])

In [51]:
# Largest Fi value per sample (maximum over the per-centre columns), first 20 samples.
# NOTE(review): every recorded value is below 1e-3, i.e. no sample responds noticeably to
# any centre. Possibly caused by the 1e-6 floor on zero deviations in form_clusters, which
# gives such features a weight of 1e12 - TODO confirm.
a.max(axis=1)[:20]

array([  5.96230507e-11,   7.29714668e-06,   7.02719853e-16,
         9.43751200e-05,   5.80700201e-07,   9.33582647e-09,
         1.34884251e-06,   5.30407799e-07,   1.38960260e-04,
         4.51619094e-07,   1.47673402e-06,   1.18912113e-06,
         1.63669014e-09,   5.99345714e-06,   3.24922518e-04,
         1.03333804e-05,   2.51904559e-09,   2.30081078e-06,
         3.35790053e-07,   7.83227159e-07])

In [53]:
# First entry of the loaded label vector.
lbls[0]

5