Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP] Sparse Tensor #5800

Merged
merged 29 commits into from
Jun 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f98912b
squash
eric-haibin-lin Apr 29, 2017
a880bc7
draft for sgd rsp rsp (#75)
eric-haibin-lin Jun 10, 2017
6d329cd
fix lint (#78)
eric-haibin-lin Jun 10, 2017
c8d3742
fix lint (#79)
eric-haibin-lin Jun 10, 2017
16a6d7f
serial elemwise sum impl (#80)
eric-haibin-lin Jun 10, 2017
87bb1f7
bug fix for initializing module with row_sparse weight (#81)
eric-haibin-lin Jun 11, 2017
3d2d1c0
Sparse ndarray serialization and deserialization (#77)
reminisce Jun 12, 2017
1e86dbf
Fix lint (#84)
reminisce Jun 12, 2017
86f896a
Sgd with row_sparse weight, dns gradient (#83)
eric-haibin-lin Jun 13, 2017
df49954
update mshadow version (#88)
eric-haibin-lin Jun 13, 2017
3fc1703
csr slice bug fix (#90)
eric-haibin-lin Jun 13, 2017
972647d
benchmark dot code refactor (#87)
jiajiechen Jun 13, 2017
31a0233
Add unit test (#91)
jiajiechen Jun 14, 2017
0b021d5
kvstore push row sparse (#93)
reminisce Jun 14, 2017
073d8f2
Fix windows openmp build failure (#94)
reminisce Jun 14, 2017
1a129e4
update mshadow submoduel (#95)
eric-haibin-lin Jun 14, 2017
b2b3af2
Revert "update mshadow submoduel (#95)" (#96)
eric-haibin-lin Jun 14, 2017
24105cf
Refactor sparse tensor code (#99)
reminisce Jun 19, 2017
ddbe565
fix pylint (#100)
eric-haibin-lin Jun 19, 2017
905304c
Fix refactor sparse gpu test (#104)
reminisce Jun 21, 2017
a9612e3
change idx types from int32 to int64 (#101)
eric-haibin-lin Jun 21, 2017
c134687
revert LOG(DEBUG) change (#105)
eric-haibin-lin Jun 21, 2017
1914471
fix undefined zeros in optimizer.py (#106)
eric-haibin-lin Jun 21, 2017
ad8e74c
move init dns zeros to init_op.h for kvstore to use (#107)
eric-haibin-lin Jun 22, 2017
ac554ac
Refactor cast storage (#109)
reminisce Jun 23, 2017
2727528
Rowsparse kv (#111)
eric-haibin-lin Jun 25, 2017
4d4fbc7
fix lint (#113)
eric-haibin-lin Jun 26, 2017
f8556cc
rename some python funciton (#114)
eric-haibin-lin Jun 26, 2017
cb5deca
fix lint (#115)
eric-haibin-lin Jun 26, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,17 @@ del /Q *.7z
// Python unittest for CPU
def python_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/train"
}
}

// GPU test has two parts. 1) run unittest on GPU, 2) compare the results on
// both CPU and GPU
def python_gpu_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu"
}
}
Expand Down
228 changes: 228 additions & 0 deletions benchmark/python/sparse_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import ctypes

from mxnet.test_utils import *
import scipy.sparse as sp
import os
import time
import argparse

from mxnet.base import check_call, _LIB
from util import get_data, estimate_density

parser = argparse.ArgumentParser(description="Benchmark sparse operators",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
args = parser.parse_args()

# some data information
kdda = {
'data_mini': 'kdda.t.mini',
'data_name': 'kdda.t',
'data_origin_name': 'kdda.t.bz2',
'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
'feature_dim': 20216830,
'm': 200,
'batch_size': [64]
}

avazu = {
'data_mini': 'avazu-app.t.mini',
'data_name': 'avazu-app.t',
'data_origin_name': 'avazu-app.t.bz2',
'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
'feature_dim': 1000000,
'm': 500,
'batch_size': [64, 128]
}


def measure_cost(repeat, f, *args, **kwargs):
# start bench
start = time.time()
results = []
for i in range(repeat):
results.append(f(*args, **kwargs))
for result in results:
result.wait_to_read()
end = time.time()
diff = end - start
return diff / repeat


def test_dot_real(data_dict):
def get_iter(path, data_shape, batch_size):
data_train = mx.io.LibSVMIter(data_libsvm=path,
data_shape=data_shape,
batch_size=batch_size)
data_iter = iter(data_train)
return data_iter

data_dir = os.path.join(os.getcwd(), 'data')

path = os.path.join(data_dir, data_dict['data_name'])
if not os.path.exists(path):
get_data(
data_dir,
data_dict['data_name'],
data_dict['url'],
data_dict['data_origin_name']
)
assert os.path.exists(path)

k = data_dict['feature_dim']
m = data_dict['m']
density = estimate_density(path, data_dict['feature_dim'])

mini_path = os.path.join(data_dir, data_dict['data_mini'])
if not os.path.exists(mini_path):
os.system("head -n 2000 %r > %r" % (path, mini_path))
assert os.path.exists(mini_path)

print "Running Benchmarking on %r data" % data_dict['data_mini']
for batch_size in data_dict['batch_size']: # iterator through different batch size of choice
print "batch_size is %d" % batch_size
# model
data_shape = (k, )
train_iter = get_iter(mini_path, data_shape, batch_size)
weight = mx.nd.random_uniform(low=0, high=1, shape=(k, m))

csr_data = []
dns_data = []
num_batch = 0
for batch in train_iter:
data = train_iter.getdata()
csr_data.append(data)
dns_data.append(data.todense())
num_batch += 1
bag_of_data = [csr_data, dns_data]
num_repeat = 5
costs = []
for d in bag_of_data:
weight.wait_to_read()
cost = 0.
count = 0
for d_batch in d:
d_batch.wait_to_read()
cost += measure_cost(num_repeat, mx.nd.dot, d_batch, weight)
count += 1
costs.append(cost/count)
t_sparse = costs[0]
t_dense = costs[1]
ratio = t_dense / t_sparse
print('density(%)\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse')
fmt = "%0.4f\t\t%d\t%d\t%d\t%0.2f\t\t\t%0.4f\t%0.6f"
print(fmt % (density * 100, batch_size, m, k, ratio, t_dense, t_sparse))


def test_dot_synthetic():
"""benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density.
`t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost
of dot(dns, dns), with the same matrix except that it is in default storage type.
"""
def measure_cost_forward_baseline(repeat, dot, lhs, rhs):
start = time.time()
for i in range(repeat):
dot(lhs, rhs)
end = time.time()
diff = end - start
return diff / repeat

def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs):
start = time.time()
for i in range(repeat):
dot(transpose(lhs), rhs)
end = time.time()
diff = end - start
return diff / repeat

def bench_dot_forward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.todense()
rhs_dns_np = dns.asnumpy()
lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) # csr in scipy
lhs_dns_np = lhs_csr_sp.todense()

data = [dns_data, csr_data]
costs = []
for d in data:
dns.wait_to_read()
d.wait_to_read()
cost = measure_cost(repeat, mx.nd.dot, d, dns)
costs.append(cost)
ratio = costs[0] / costs[1]

costs_baseline = []
cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np)
costs_baseline.append(cost)
cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np)
costs_baseline.append(cost)
ratio_baseline = costs_baseline[0] / costs_baseline[1]
fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f"
print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1],
ratio_baseline, costs_baseline[0], costs_baseline[1]))

def bench_dot_backward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.todense()
rhs_dns_np = dns.asnumpy()
lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy())
lhs_dns_np = lhs_csr_sp.todense()

data = [dns_data, csr_data]
costs = []
for d in data:
dns.wait_to_read()
d.wait_to_read()
cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True)
costs.append(cost)
ratio = costs[0] / costs[1]

costs_baseline = []
cost = measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np)
costs_baseline.append(cost)
cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np)
costs_baseline.append(cost)
ratio_baseline = costs_baseline[0] / costs_baseline[1]
fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f"
print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1],
ratio_baseline, costs_baseline[0], costs_baseline[1]))

print("A = sparse NDArray of shape(m, k)")
print("B = dense NDArray of shape(k, n)")
print("dot_forward\tdot(csr, dns)")
print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse'
'\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse')

check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
# TODO(haibin) make these runtime options
m = 512
k = [50000, 100000]
n = [64, 128]
density = [1.00, 0.90, 0.70, 0.50, 0.30, 0.20, 0.10, 0.07, 0.05, 0.02, 0.01, 0.005, 0.001]
num_repeat = 10
# contexts = [mx.cpu(), mx.gpu(0)]
contexts = [mx.cpu()]
for i in range(2):
for ctx in contexts:
for den in density:
bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat)

print("dot_backward\tdot(csr.T, dns)")
print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse'
'\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse')
for i in range(2):
for ctx in contexts:
for den in density:
bench_dot_backward(m, k[i], n[i], den, ctx, num_repeat)


if __name__ == "__main__":
test_dot_real(avazu)
test_dot_real(kdda)
test_dot_synthetic()
33 changes: 33 additions & 0 deletions benchmark/python/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import random


def get_data(data_dir, data_name, url, data_origin_name):
if not os.path.isdir(data_dir):
os.system("mkdir " + data_dir)
os.chdir(data_dir)
if (not os.path.exists(data_name)):
import urllib
zippath = os.path.join(data_dir, data_origin_name)
urllib.urlretrieve(url, zippath)
os.system("bzip2 -d %r" % data_origin_name)
os.chdir("..")


def estimate_density(DATA_PATH, feature_size):
"""sample 10 times of a size of 1000 for estimating the density of the sparse dataset"""
if not os.path.exists(DATA_PATH):
raise Exception("Data is not there!")
density = []
P = 0.01
for _ in xrange(10):
num_non_zero = 0
num_sample = 0
with open(DATA_PATH) as f:
for line in f:
if (random.random() < P):
num_non_zero += len(line.split(" ")) - 1
num_sample += 1
density.append(num_non_zero * 1.0 / (feature_size * num_sample))
return sum(density) / len(density)

Loading