Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making layouts/auto.py work for directed graphs #807

Merged
merged 21 commits into from
Aug 12, 2021
Merged
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 10 additions & 7 deletions graspologic/layouts/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..preprocessing import cut_edges_by_weight, histogram_edge_weight
from .classes import NodePosition
from .nooverlap import remove_overlaps
from ..utils import symmetrize, largest_connected_component

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -200,11 +201,6 @@ def layout_umap(
return lcc_graph, positions


def _largest_connected_component(graph: nx.Graph) -> nx.Graph:
daxpryce marked this conversation as resolved.
Show resolved Hide resolved
largest_component = max(nx.connected_components(graph), key=len)
return graph.subgraph(largest_component).copy()


def _approximate_prune(graph: nx.Graph, max_edges_to_keep: int = 1000000):
num_edges = len(graph.edges())
logger.info(f"num edges: {num_edges}")
Expand Down Expand Up @@ -233,7 +229,7 @@ def _node2vec_for_layout(
random_seed: Optional[int] = None,
) -> Tuple[nx.Graph, np.ndarray, np.ndarray]:
graph = _approximate_prune(graph, max_edges)
graph = _largest_connected_component(graph)
graph = largest_connected_component(graph)

start = time.time()
tensors, labels = node2vec_embed(
Expand All @@ -260,7 +256,14 @@ def _node_positions_from(
sizes = _compute_sizes(degree)
covered_area = _covered_size(sizes)
scaled_points = _scale_points(down_projection_2d, covered_area)
partitions = leiden(graph, random_seed=random_seed)
if isinstance(graph, nx.DiGraph):
diane-lee-01 marked this conversation as resolved.
Show resolved Hide resolved
temp_graph = symmetrize(graph)
bdpedigo marked this conversation as resolved.
Show resolved Hide resolved
logger.warning(
"Directed graph converted to undirected graph for community detection"
)
partitions = leiden(temp_graph, random_seed=random_seed)
else:
partitions = leiden(graph, random_seed=random_seed)
positions = [
NodePosition(
node_id=key,
Expand Down