Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize Sampling of LDT #744

Merged
merged 31 commits into from
May 20, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ea8f9a7
LPT tutorial render fix attempt
kareef928 Nov 13, 2020
a2c77c3
merge
kareef928 Nov 26, 2020
895d93e
unwanted changes
kareef928 Dec 10, 2020
e05d226
lpt tutorial changes
kareef928 Dec 10, 2020
aa9059d
Merge branch 'dev' of https://github.com/microsoft/graspologic into dev
kareef928 Mar 7, 2021
948de44
parallelize ldt
kareef928 Mar 8, 2021
e62cc99
parallelize ldt
kareef928 Mar 8, 2021
507a98b
functioning parallelization
kareef928 Mar 8, 2021
fc64a64
remove print
kareef928 Mar 15, 2021
e8bd71a
remove comments
kareef928 Apr 1, 2021
0edd979
format fix
kareef928 Apr 4, 2021
eda6730
add parallel test
kareef928 Apr 4, 2021
3f7c310
Merge branch 'dev' into ldt-parallel
bdpedigo Apr 6, 2021
2071206
change seeds array
kareef928 Apr 18, 2021
8d037ca
pull
kareef928 Apr 18, 2021
742d373
pull
kareef928 Apr 27, 2021
ea6f8b5
add random_state param
kareef928 May 5, 2021
96e0e91
change workers to n_jobs
kareef928 May 6, 2021
5814892
Merge branch 'dev' into ldt-parallel
bdpedigo May 6, 2021
8a7d80e
change back to workers
kareef928 May 6, 2021
6ed928c
Merge branch 'ldt-parallel' of https://github.com/kareef928/graspolog…
kareef928 May 6, 2021
c9cbbc7
fix conflict
kareef928 May 10, 2021
c816c42
add missing imports
kareef928 May 10, 2021
709942d
remove unused warnings import
bdpedigo May 17, 2021
df7e801
allow for none in type checking
bdpedigo May 17, 2021
db44da7
update description for workers
bdpedigo May 17, 2021
817d08b
black
bdpedigo May 17, 2021
18aaea2
Merge branch 'dev' into ldt-parallel
bdpedigo May 19, 2021
907a07e
Merge branch 'dev' into ldt-parallel
bdpedigo May 19, 2021
ad04089
try only generating N seeds
bdpedigo May 19, 2021
bcc4466
Merge branch 'dev' into ldt-parallel
bdpedigo May 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 30 additions & 7 deletions graspologic/inference/latent_distribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from ..utils import import_graph, fit_plug_in_variance_estimator
from ..align import SignFlips
from ..align import SeedlessProcrustes
from sklearn.utils import check_array
from sklearn.utils import check_array, check_random_state
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.metrics.pairwise import PAIRED_DISTANCES
from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS
from hyppo.ksample import KSample
from hyppo._utils import gaussian
from collections import namedtuple
from joblib import Parallel, delayed

_VALID_DISTANCES = list(PAIRED_DISTANCES.keys())
_VALID_KERNELS = list(PAIRWISE_KERNEL_FUNCTIONS.keys())
Expand All @@ -36,6 +37,7 @@ def latent_distribution_test(
metric="euclidean",
n_components=None,
n_bootstraps=500,
random_state=None,
workers=1,
size_correction=True,
pooled=False,
Expand Down Expand Up @@ -86,6 +88,17 @@ def latent_distribution_test(
n_bootstraps : int (default=500)
Number of bootstrap iterations for the backend hypothesis test.
See :class:`hyppo.ksample.KSample` for more information.

random_state : {None, int, `~np.random.RandomState`, `~np.random.Generator`}
This parameter defines the object to use for drawing random
variates.
If `random_state` is ``None`` the `~np.random.RandomState` singleton is
used.
If `random_state` is an int, a new ``RandomState`` instance is used,
seeded with `random_state`.
If `random_state` is already a ``RandomState`` or ``Generator``
instance, then that object is used.
Default is None.

workers : int (default=1)
bdpedigo marked this conversation as resolved.
Show resolved Hide resolved
Number of workers to use. If more than 1, parallelizes the code.
Expand Down Expand Up @@ -312,7 +325,9 @@ def latent_distribution_test(
Q = np.identity(X1_hat.shape[0])

if size_correction:
X1_hat, X2_hat = _sample_modified_ase(X1_hat, X2_hat, pooled=pooled)
X1_hat, X2_hat = _sample_modified_ase(
X1_hat, X2_hat, workers=workers, random_state=random_state, pooled=pooled
)

metric_func_ = _instantiate_metric_func(metric=metric, test=test)
test_obj = KSample(test, compute_distance=metric_func_)
Expand Down Expand Up @@ -407,7 +422,7 @@ def _embed(A1, A2, n_components):
return X1_hat, X2_hat


def _sample_modified_ase(X, Y, pooled=False):
def _sample_modified_ase(X, Y, workers, random_state, pooled=False):
N, M = len(X), len(Y)

# return if graphs are same order, else ensure X the larger graph.
Expand All @@ -427,12 +442,20 @@ def _sample_modified_ase(X, Y, pooled=False):
else:
get_sigma = fit_plug_in_variance_estimator(X)
X_sigmas = get_sigma(X) * (N - M) / (N * M)

# increase the variance of X by sampling from the asy dist
X_sampled = np.zeros(X.shape)
# TODO may be parallelized, but requires keeping track of random state
for i in range(N):
X_sampled[i, :] = X[i, :] + stats.multivariate_normal.rvs(cov=X_sigmas[i])
rng = check_random_state(random_state)
X_sampled = np.asarray(
Parallel(n_jobs=workers)(
delayed(add_variance)(X[i, :], X_sigmas[i], r)
for i, r in zip(range(N), rng.randint(np.iinfo(np.int32).max, size=X.shape))
bdpedigo marked this conversation as resolved.
Show resolved Hide resolved
)
)

# return the embeddings in the appropriate order
return (Y, X_sampled) if reverse_order else (X_sampled, Y)


def add_variance(X_orig, X_sigma, seed):
    """
    Add one seeded draw of zero-mean Gaussian noise to ``X_orig``.

    Parameters
    ----------
    X_orig : array-like
        Value to perturb.
    X_sigma : array-like
        Covariance of the noise; forwarded as ``cov`` to
        ``scipy.stats.multivariate_normal.rvs``.
    seed : int or array-like of ints
        Seed for the private ``numpy.random.RandomState`` that draws the
        noise. The same seed always yields the same perturbation.

    Returns
    -------
    The sum of ``X_orig`` and a single sample drawn with covariance
    ``X_sigma``.
    """
    # Draw from a private generator instead of calling np.random.seed():
    # the sampled values are the same for a given seed, but the
    # process-wide NumPy random state is no longer reseeded as a side
    # effect of every call.
    rng = np.random.RandomState(seed)
    return X_orig + stats.multivariate_normal.rvs(cov=X_sigma, random_state=rng)
9 changes: 9 additions & 0 deletions tests/test_latentdistributiontest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def test_bad_kwargs(self):
# check workers argument
with pytest.raises(TypeError):
latent_distribution_test(A1, A2, workers=0.5)
latent_distribution_test(A1, A2, workers="oops")
# check size_correction argument
with pytest.raises(TypeError):
latent_distribution_test(A1, A2, size_correction=0)
Expand Down Expand Up @@ -207,11 +208,19 @@ def test_SBM_dcorr(self):
A1 = sbm(2 * [b_size], B1)
A2 = sbm(2 * [b_size], B1)
A3 = sbm(2 * [b_size], B2)

# non-parallel test
ldt_null = latent_distribution_test(A1, A2)
ldt_alt = latent_distribution_test(A1, A3)
self.assertTrue(ldt_null[0] > 0.05)
self.assertTrue(ldt_alt[0] <= 0.05)

# parallel test
ldt_null = latent_distribution_test(A1, A2, workers=-1)
ldt_alt = latent_distribution_test(A1, A3, workers=-1)
self.assertTrue(ldt_null[0] > 0.05)
self.assertTrue(ldt_alt[0] <= 0.05)

def test_different_sizes_null(self):
np.random.seed(314)

Expand Down