Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Extend UST class to handle weighted graphs
- Loading branch information
1 parent
3df7987
commit b7ccd08
Showing
2 changed files
with
30 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,14 +36,22 @@ | |
from dppy.utils import check_random_state | ||
|
||
|
||
def ust_sampler_wilson(list_of_neighbors, root=None, | ||
def ust_sampler_wilson(W, root=None, | ||
This comment has been minimized.
Sorry, something went wrong. |
||
random_state=None): | ||
|
||
""" | ||
Compute a random spanning tree of a graph G given by his adjacency matrix | ||
param W: | ||
Adjacency matrix of the Graph | ||
param W: | ||
This comment has been minimized.
Sorry, something went wrong. |
||
scipy.sparse.csr_matrix | ||
""" | ||
rng = check_random_state(random_state) | ||
|
||
# Initialize the tree | ||
wilson_tree_graph = nx.Graph() | ||
nb_nodes = len(list_of_neighbors) | ||
#nb_nodes = len(list_of_neighbors) | ||
nb_nodes = W.shape[0] | ||
|
||
# Initialize the root, if root not specified start from any node | ||
n0 = root if root else rng.choice(nb_nodes) # size=1)[0] | ||
|
@@ -57,7 +65,10 @@ def ust_sampler_wilson(list_of_neighbors, root=None, | |
while nb_nodes_in_tree < nb_nodes: # |Tree| = |V| - 1 | ||
|
||
# visit a neighbor of n0 uniformly at random | ||
n1 = rng.choice(list_of_neighbors[n0]) # size=1)[0] | ||
#n1 = rng.choice(list_of_neighbors[n0]) # size=1)[0] | ||
weights = (W.getrow(n0).toarray())[0].astype('float') | ||
weights /= np.sum(weights) | ||
This comment has been minimized.
Sorry, something went wrong.
guilgautier
Owner
|
||
n1 = rng.choice(np.arange(nb_nodes), p=weights) | ||
This comment has been minimized.
Sorry, something went wrong. |
||
|
||
if state[n1] == -1: # not visited => continue the walk | ||
|
||
|
@@ -105,7 +116,7 @@ def ust_sampler_wilson_nodes(W, absorbing_weight=0, random_state=None): | |
:param W: | ||
Adjacency matrix of the graph | ||
:type W: | ||
array_like | ||
scipy.sparse.csr_matrix | ||
:param absorbing_weight: | ||
Weight of the node Delta added to the graph | ||
|
@@ -138,7 +149,7 @@ def ust_sampler_wilson_nodes(W, absorbing_weight=0, random_state=None): | |
all_path = [] | ||
|
||
# Compute the probabilities of transition | ||
transition_probabilities = np.pad(W, [(0, 0), (0, 1)], mode='constant', constant_values=absorbing_weight).astype('float') | ||
transition_probabilities = np.pad(W.toarray(), [(0, 0), (0, 1)], mode='constant', constant_values=absorbing_weight).astype('float') | ||
norm = np.sum(transition_probabilities, axis=1) | ||
transition_probabilities[np.nonzero(norm), :] /= norm[np.nonzero(norm), None] | ||
|
||
|
@@ -211,14 +222,15 @@ def ust_sampler_wilson_nodes(W, absorbing_weight=0, random_state=None): | |
#print("Nu=", Nu) | ||
return Y, all_path, wilson_tree_from_path | ||
|
||
def ust_sampler_aldous_broder(list_of_neighbors, root=None, | ||
def ust_sampler_aldous_broder(W, root=None, | ||
random_state=None): | ||
|
||
rng = check_random_state(random_state) | ||
|
||
# Initialize the tree | ||
aldous_tree_graph = nx.Graph() | ||
nb_nodes = len(list_of_neighbors) | ||
#nb_nodes = len(list_of_neighbors) | ||
nb_nodes = W.shape[0] | ||
|
||
# Initialize the root, if root not specified start from any node | ||
n0 = root if root else rng.choice(nb_nodes) # size=1)[0] | ||
|
@@ -231,7 +243,10 @@ def ust_sampler_aldous_broder(list_of_neighbors, root=None, | |
while nb_nodes_in_tree < nb_nodes: | ||
|
||
# visit a neighbor of n0 uniformly at random | ||
n1 = rng.choice(list_of_neighbors[n0]) # size=1)[0] | ||
#n1 = rng.choice(list_of_neighbors[n0]) # size=1)[0] | ||
weights = (W.getrow(n0).toarray())[0].astype('float') | ||
weights /= np.sum(weights) | ||
n1 = rng.choice(np.arange(nb_nodes), p=weights) | ||
|
||
if visited[n1]: | ||
pass # continue the walk | ||
|
More explicit name of variable are expected