Skip to content

Commit

Permalink
Update dec example (apache#12950)
Browse files Browse the repository at this point in the history
* update dec example

* trigger CI

* update to remove dependency on sklearn data
  • Loading branch information
ThomasDelteil authored and Jose Luis Contreras committed Nov 13, 2018
1 parent 383c516 commit ea8983b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 21 deletions.
11 changes: 10 additions & 1 deletion example/deep-embedded-clustering/README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
# DEC Implementation
This is based on the paper `Unsupervised deep embedding for clustering analysis` by Junyuan Xie, Ross Girshick, and Ali Farhadi

Abstract:

Clustering is central to many data-driven application domains and has been studied extensively in terms of distance functions and grouping algorithms. Relatively little work has focused on learning representations for clustering. In this paper, we propose Deep Embedded Clustering (DEC), a method that simultaneously learns feature representations and cluster assignments using deep neural networks. DEC learns a mapping from the data space to a lower-dimensional feature space in which it iteratively optimizes a clustering objective. Our experimental evaluations on image and text corpora show significant improvement over state-of-the-art methods.


## Prerequisite
- Install Scikit-learn: `python -m pip install --user sklearn`
- Install SciPy: `python -m pip install --user scipy`

## Data

The script is using MNIST dataset.

## Usage
run `python dec.py`
run `python dec.py`
15 changes: 8 additions & 7 deletions example/deep-embedded-clustering/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,21 @@
from __future__ import print_function

import os

import mxnet as mx
import numpy as np
from sklearn.datasets import fetch_mldata


def get_mnist():
""" Gets MNIST dataset """

np.random.seed(1234) # set seed for deterministic ordering
data_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
data_path = os.path.join(data_path, '../../data')
mnist = fetch_mldata('MNIST original', data_home=data_path)
p = np.random.permutation(mnist.data.shape[0])
X = mnist.data[p].astype(np.float32)*0.02
Y = mnist.target[p]
mnist_data = mx.test_utils.get_mnist()
X = np.concatenate([mnist_data['train_data'], mnist_data['test_data']])
Y = np.concatenate([mnist_data['train_label'], mnist_data['test_label']])
p = np.random.permutation(X.shape[0])
X = X[p].reshape((X.shape[0], -1)).astype(np.float32)*5
Y = Y[p]
return X, Y


Expand Down
25 changes: 12 additions & 13 deletions example/deep-embedded-clustering/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
from __future__ import print_function
import sys
import os
# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path = [os.path.join(curr_path, "../autoencoder")] + sys.path
import mxnet as mx
import numpy as np
import data
Expand All @@ -33,14 +30,14 @@
import logging

def cluster_acc(Y_pred, Y):
from sklearn.utils.linear_assignment_ import linear_assignment
assert Y_pred.size == Y.size
D = max(Y_pred.max(), Y.max())+1
w = np.zeros((D,D), dtype=np.int64)
for i in range(Y_pred.size):
w[Y_pred[i], int(Y[i])] += 1
ind = linear_assignment(w.max() - w)
return sum([w[i,j] for i,j in ind])*1.0/Y_pred.size, w
from sklearn.utils.linear_assignment_ import linear_assignment
assert Y_pred.size == Y.size
D = max(Y_pred.max(), Y.max())+1
w = np.zeros((D,D), dtype=np.int64)
for i in range(Y_pred.size):
w[Y_pred[i], int(Y[i])] += 1
ind = linear_assignment(w.max() - w)
return sum([w[i,j] for i,j in ind])*1.0/Y_pred.size, w

class DECModel(model.MXModel):
class DECLoss(mx.operator.NumpyOp):
Expand Down Expand Up @@ -87,9 +84,9 @@ def setup(self, X, num_centers, alpha, save_to='dec_model'):
ae_model = AutoEncoderModel(self.xpu, [X.shape[1],500,500,2000,10], pt_dropout=0.2)
if not os.path.exists(save_to+'_pt.arg'):
ae_model.layerwise_pretrain(X_train, 256, 50000, 'sgd', l_rate=0.1, decay=0.0,
lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
lr_scheduler=mx.lr_scheduler.FactorScheduler(20000,0.1))
ae_model.finetune(X_train, 256, 100000, 'sgd', l_rate=0.1, decay=0.0,
lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
lr_scheduler=mx.lr_scheduler.FactorScheduler(20000,0.1))
ae_model.save(save_to+'_pt.arg')
logging.log(logging.INFO, "Autoencoder Training error: %f"%ae_model.eval(X_train))
logging.log(logging.INFO, "Autoencoder Validation error: %f"%ae_model.eval(X_val))
Expand Down Expand Up @@ -160,6 +157,8 @@ def refresh(i):

def mnist_exp(xpu):
X, Y = data.get_mnist()
if not os.path.isdir('data'):
os.makedirs('data')
dec_model = DECModel(xpu, X, 10, 1.0, 'data/mnist')
acc = []
for i in [10*(2**j) for j in range(9)]:
Expand Down

0 comments on commit ea8983b

Please sign in to comment.