Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Apr 30, 2019
1 parent c3733ad commit d20dfd9
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 73 deletions.
18 changes: 9 additions & 9 deletions src/nrl/model/deepwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,25 @@

from .util import WalkerModel
from .word2vec import Word2VecParameters
from ..walker import RandomWalkParameters, StandardRandomWalker
from ..walker import StandardRandomWalker, WalkerParameters

__all__ = [
'run_deepwalk',
'DeepWalkModel',
]


def run_deepwalk(graph: Graph,
random_walk_parameters: Optional[RandomWalkParameters] = None,
word2vec_parameters: Optional[Word2VecParameters] = None
) -> Word2Vec:
def run_deepwalk(
graph: Graph,
walker_parameters: Optional[WalkerParameters] = None,
word2vec_parameters: Optional[Word2VecParameters] = None
) -> Word2Vec:
"""Run the DeepWalk algorithm to generate a Word2Vec model."""
model = DeepWalkModel(
graph=graph,
random_walk_parameters=random_walk_parameters,
walker_parameters=walker_parameters,
word2vec_parameters=word2vec_parameters,
)
return model.fit()
return model.fit(graph)


class DeepWalkModel(WalkerModel):
Expand All @@ -44,4 +44,4 @@ class DeepWalkModel(WalkerModel):
- https://github.com/jwplayer/jwalk
"""

random_walker_cls = StandardRandomWalker
walker_cls = StandardRandomWalker
33 changes: 16 additions & 17 deletions src/nrl/model/gat2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,27 @@
from .util import WalkerModel
from .word2vec import Word2VecParameters
from ..typing import Walk
from ..walker import RandomWalkParameters, StandardRandomWalker
from ..walker import WalkerParameters, StandardRandomWalker

__all__ = [
'run_gat2vec_unsupervised',
'Gat2VecUnsupervisedModel',
]


def run_gat2vec_unsupervised(graph: Graph,
structural_vertices: VertexSeq,
random_walk_parameters: Optional[RandomWalkParameters] = None,
word2vec_parameters: Optional[Word2VecParameters] = None
) -> Word2Vec:
def run_gat2vec_unsupervised(
graph: Graph,
structural_vertices: VertexSeq,
random_walk_parameters: Optional[WalkerParameters] = None,
word2vec_parameters: Optional[Word2VecParameters] = None
) -> Word2Vec:
"""Run the unsupervised GAT2VEC algorithm to generate a Word2Vec model."""
model = Gat2VecUnsupervisedModel(
graph=graph,
structural_vertices=structural_vertices,
random_walk_parameters=random_walk_parameters,
word2vec_parameters=word2vec_parameters,
)
return model.fit()
return model.fit(graph)


class Gat2VecUnsupervisedModel(WalkerModel):
Expand All @@ -44,18 +44,17 @@ class Gat2VecUnsupervisedModel(WalkerModel):
- https://github.com/snash4/GAT2VEC (reference implementation)
"""

random_walker_cls = StandardRandomWalker
walker_cls = StandardRandomWalker

def __init__(self,
graph: Graph,
structural_vertices: VertexSeq,
random_walk_parameters: Optional[RandomWalkParameters] = None,
word2vec_parameters: Optional[Word2VecParameters] = None
) -> None:
def __init__(
self,
structural_vertices: VertexSeq,
random_walk_parameters: Optional[WalkerParameters] = None,
word2vec_parameters: Optional[Word2VecParameters] = None
) -> None:
"""Initialize the GAT2VEC unsupervised model."""
super().__init__(
graph=graph,
random_walk_parameters=random_walk_parameters,
walker_parameters=random_walk_parameters,
word2vec_parameters=word2vec_parameters,
)

Expand Down
4 changes: 2 additions & 2 deletions src/nrl/model/node2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Node2VecModel(WalkerModel):
- https://github.com/apple2373/node2vec
"""

random_walker_cls = BiasedRandomWalker
walker_cls = BiasedRandomWalker

NUM_WALKS_KEY = 'num_walks'
WALK_LENGTH_KEY = 'walk_length'
Expand All @@ -41,7 +41,7 @@ class Node2VecModel(WalkerModel):
P_KEY = 'p'
Q_KEY = 'q'

def initialize(self):
def initialize(self, graph):
"""Pre-process the model by computing transition probabilities for each node in the graph."""
if not self.random_walk_parameters.is_weighted:
for edge in self.graph.es:
Expand Down
36 changes: 21 additions & 15 deletions src/nrl/model/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .word2vec import Word2VecParameters, get_word2vec_from_walks
from ..typing import Walk
from ..walker import AbstractRandomWalker, RandomWalkParameters
from ..walker import Walker, WalkerParameters

__all__ = [
'BaseModel',
Expand All @@ -21,20 +21,22 @@
class BaseModel(ABC):
"""A base model for running Word2Vec-based algorithms."""

def __init__(self,
graph: Graph,
random_walk_parameters: Optional[RandomWalkParameters] = None,
word2vec_parameters: Optional[Word2VecParameters] = None
) -> None:
def __init__(
self,
walker_parameters: Optional[WalkerParameters] = None,
word2vec_parameters: Optional[Word2VecParameters] = None
) -> None:
"""Store the graph, parameters, then initialize the model."""
self.graph = graph
self.random_walk_parameters = random_walk_parameters or RandomWalkParameters()
self.random_walk_parameters = walker_parameters or WalkerParameters()
self.word2vec_parameters = word2vec_parameters or Word2VecParameters()

# Model is saved after being fit
self.model: Optional[Word2Vec] = None

self.initialize()

@abstractmethod
def fit(self) -> Word2Vec:
def fit(self, graph: Graph) -> Word2Vec:
"""Fit the model to the graph and parameters."""

def initialize(self) -> None:
Expand All @@ -44,23 +46,27 @@ def initialize(self) -> None:
class WalkerModel(BaseModel):
"""A base model that uses a random walker to generate walks."""

random_walker_cls: Type[AbstractRandomWalker]
walker_cls: Type[Walker]

def fit(self) -> Word2Vec:
def fit(self, graph: Graph) -> Word2Vec:
"""Fit the DeepWalk model to the graph and parameters."""
walker = self.random_walker_cls(self.random_walk_parameters)
walks = walker.get_walks(self.graph)
walker = self.walker_cls(self.random_walk_parameters)
walks = walker.get_walks(graph)

# stringify output from igraph for Word2Vec
walks = (
map(str, walk)
[
vertex['label']
for vertex in walk
]
for walk in self.transform_walks(walks)
)

return get_word2vec_from_walks(
self.model = get_word2vec_from_walks(
walks=walks,
word2vec_parameters=self.word2vec_parameters,
)
return self.model

def transform_walks(self, walks: Iterable[Walk]) -> Iterable[Walk]:
"""Transform walks (by default, simply returns the walks)."""
Expand Down
22 changes: 11 additions & 11 deletions src/nrl/model/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterable, Optional

import numpy as np
import pandas as pd
from gensim.models import Word2Vec
from sklearn.metrics.pairwise import cosine_similarity

Expand Down Expand Up @@ -45,30 +46,29 @@ def save_word2vec(word2vec: Word2Vec, name: str):
word2vec.wv.save_word2vec_format(fname=name)


# def get_cosine_similarity_df(word2vec: Word2Vec) -> pd.DataFrame:
# """Get the cosine similarity matrix from the embedding as a Pandas DataFrame."""
# node_labels = ...
# labels = [node_labels[n] for n in word2vec.wv.index2word]
# sim = get_cosine_similarity(word2vec)
# return pd.DataFrame(sim, index=labels, columns=labels)
def get_cosine_similarity_df(word2vec: Word2Vec) -> pd.DataFrame:
"""Get the cosine similarity matrix from the embedding as a Pandas DataFrame."""
sim = get_cosine_similarity(word2vec)
return pd.DataFrame(sim, index=word2vec.wv.index2word, columns=word2vec.wv.index2word)


def get_cosine_similarity(word2vec: Word2Vec) -> np.ndarray:
"""Get the cosine similarity matrix from the embedding.
Warning: might be very big!
"""
return 1 - cosine_similarity(word2vec.wv.vectors)
return cosine_similarity(word2vec.wv.vectors)


def get_word2vec_from_walks(walks: Iterable[Iterable[str]],
word2vec_parameters: Optional[Word2VecParameters] = None
) -> Word2Vec:
def get_word2vec_from_walks(
walks: Iterable[Iterable[str]],
word2vec_parameters: Optional[Word2VecParameters] = None
) -> Word2Vec:
"""Train Word2Vec with the given walks."""
if word2vec_parameters is None:
word2vec_parameters = Word2VecParameters()

# TODO hack this up to be an itertor so Word2Vec doesn't complain
# TODO hack this up to be an iterator so Word2Vec doesn't complain
walks = [list(x) for x in walks]

return Word2Vec(
Expand Down
6 changes: 3 additions & 3 deletions src/nrl/walker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

"""Algorithms for generating random walks."""

from .utils import AbstractRandomWalker, RandomWalkParameters
from .utils import Walker, WalkerParameters
from .walkers import BiasedRandomWalker, RestartingRandomWalker, StandardRandomWalker

__all__ = [
'AbstractRandomWalker',
'RandomWalkParameters',
'Walker',
'WalkerParameters',
'BiasedRandomWalker',
'RestartingRandomWalker',
'StandardRandomWalker',
Expand Down
10 changes: 5 additions & 5 deletions src/nrl/walker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from ..typing import Walk

__all__ = [
'RandomWalkParameters',
'AbstractRandomWalker',
'WalkerParameters',
'Walker',
]


@dataclass
class RandomWalkParameters:
class WalkerParameters:
"""Parameters for random walks."""

#: The number of paths to get
Expand Down Expand Up @@ -51,10 +51,10 @@ class RandomWalkParameters:
is_weighted: bool = True


class AbstractRandomWalker(ABC):
class Walker(ABC):
"""An abstract class for random walkers."""

def __init__(self, parameters: RandomWalkParameters):
def __init__(self, parameters: WalkerParameters):
"""Initialize the walker with the given random walk parameters dataclass."""
self.parameters = parameters

Expand Down
8 changes: 4 additions & 4 deletions src/nrl/walker/walkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from igraph import Graph, Vertex

from .utils import AbstractRandomWalker
from .utils import Walker
from ..typing import Walk

__all__ = [
Expand All @@ -17,7 +17,7 @@
]


class StandardRandomWalker(AbstractRandomWalker):
class StandardRandomWalker(Walker):
"""Make standard random walks, choosing the neighbors at a given position uniformly."""

def get_walk(self, graph: Graph, vertex: Vertex) -> Walk:
Expand All @@ -32,7 +32,7 @@ def get_walk(self, graph: Graph, vertex: Vertex) -> Walk:
path_length += 1


class RestartingRandomWalker(AbstractRandomWalker):
class RestartingRandomWalker(Walker):
"""A random walker that restarts from the original vertex with a given probability."""

@property
Expand All @@ -56,7 +56,7 @@ def get_walk(self, graph: Graph, vertex: Vertex) -> Walk:
path_length += 1


class BiasedRandomWalker(AbstractRandomWalker):
class BiasedRandomWalker(Walker):
"""A random walker that generates second-order random walks biased by edge weights."""

NUM_WALKS_KEY = 'num_walks'
Expand Down
6 changes: 3 additions & 3 deletions tests/test_algorithm/test_deepwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from nrl.model.deepwalk import run_deepwalk
from nrl.model.word2vec import Word2VecParameters
from nrl.walker import RandomWalkParameters
from nrl.walker import WalkerParameters
from tests.constants import get_test_network


Expand All @@ -18,14 +18,14 @@ class TestDeepWalk(unittest.TestCase):
def test_deepwalk(self):
"""Test DeepWalk."""
graph = get_test_network()
random_walk_parameters = RandomWalkParameters(
random_walk_parameters = WalkerParameters(
number_paths=5,
max_path_length=10,
)
word2vec_parameters = Word2VecParameters()
word2vec = run_deepwalk(
graph=graph,
random_walk_parameters=random_walk_parameters,
walker_parameters=random_walk_parameters,
word2vec_parameters=word2vec_parameters,
)
self.assertIsInstance(word2vec, Word2Vec)
8 changes: 4 additions & 4 deletions tests/test_algorithm/test_node2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from nrl.model import Node2VecModel
from nrl.model.word2vec import Word2VecParameters
from nrl.walker import RandomWalkParameters
from nrl.walker import WalkerParameters
from tests.constants import WEIGHTED_NETWORK_PATH, get_test_network


Expand All @@ -22,7 +22,7 @@ class TestNode2Vec(unittest.TestCase):
def test_node2vec_unweighted(self):
"""Test Node2Vec."""
graph = get_test_network()
random_walk_parameters = RandomWalkParameters(
random_walk_parameters = WalkerParameters(
number_paths=5,
max_path_length=10,
is_weighted=False
Expand All @@ -37,7 +37,7 @@ def test_node2vec_unweighted(self):
def test_node2vec_weighted(self):
"""Test Node2Vec."""
graph = get_test_network(path=WEIGHTED_NETWORK_PATH)
random_walk_parameters = RandomWalkParameters(
random_walk_parameters = WalkerParameters(
number_paths=5,
max_path_length=10,
)
Expand All @@ -56,7 +56,7 @@ def test_precompute_probs(self):
d1 = n1._precompute_probabilities()

g2 = get_test_network(WEIGHTED_NETWORK_PATH)
random_walk_parameters = RandomWalkParameters(
random_walk_parameters = WalkerParameters(
number_paths=5,
max_path_length=10,
)
Expand Down

0 comments on commit d20dfd9

Please sign in to comment.