Skip to content

Commit

Permalink
Merge pull request #112 from tomlincr/master
Browse files Browse the repository at this point in the history
Add get_params method to BaseEstimator
  • Loading branch information
benedekrozemberczki committed Aug 28, 2022
2 parents a094d0b + 4a488e0 commit 0207731
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
8 changes: 8 additions & 0 deletions karateclub/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import networkx as nx
from typing import List
import re

"""General Estimator base class."""

Expand Down Expand Up @@ -30,6 +31,13 @@ def get_memberships(self):
def get_cluster_centers(self):
"""Getting the cluster centers."""
pass

def get_params(self):
"""Get parameter dictionary for this estimator.."""
rx = re.compile(r'^\_')
params = self.__dict__
params = {key: params[key] for key in params if not rx.search(key)}
return params

def _set_seed(self):
"""Creating the initial random seed."""
Expand Down
8 changes: 8 additions & 0 deletions test/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from karateclub import DeepWalk

def test_get_params():
model = DeepWalk()
params = model.get_params()
assert len(params) != 0
assert type(params) is dict
assert '_embedding' not in params

0 comments on commit 0207731

Please sign in to comment.