Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions pairwise/pairwise_theano.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

import theano
import theano.tensor as TT

from pairwise_python import pairwise_python_inner_broadcasting

def pairwise_theano_tensor_prepare(dtype):
X = TT.matrix(dtype=str(dtype))
dists = TT.sqrt(
TT.sum(
TT.sqr(X[:, None, :] - X),
axis=2))
rval = theano.function([X],
theano.Out(dists, borrow=True),
allow_input_downcast=True)
rval.__name__ = 'pairwise_theano_tensor_' + dtype
return rval

def pairwise_theano_blas_prepare(dtype):
X = TT.matrix(dtype=str(dtype))
X_norm_2 = (X ** 2).sum(axis=1)
dists = TT.sqrt(2 * X_norm_2 - TT.dot(X, X.T))
rval = theano.function([X],
theano.Out(dists, borrow=True),
allow_input_downcast=True)
rval.__name__ = 'pairwise_theano_blas_' + dtype
return rval


benchmarks = (
pairwise_theano_tensor_prepare('float32'),
pairwise_theano_tensor_prepare('float64'),
pairwise_theano_blas_prepare('float32'),
pairwise_theano_blas_prepare('float64'),
)


5 changes: 4 additions & 1 deletion run_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
# License: MIT
from __future__ import print_function

from collections import OrderedDict
try:
from collections import OrderedDict
except:
from ordereddict import OrderedDict
import json
import os
import traceback
Expand Down