Skip to content

Commit

Permalink
more reliable role test
Browse files Browse the repository at this point in the history
  • Loading branch information
dkaslovsky committed Mar 3, 2019
1 parent 9a7abe7 commit 1584045
Showing 1 changed file with 13 additions and 21 deletions.
34 changes: 13 additions & 21 deletions tests/test_roles/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,29 +58,21 @@ def test_roles(self):
role_pct = self.re.role_percentage
self.assertIsNone(roles)
self.assertIsNone(role_pct)
# extract role factors so roles should be populated
features = pd.DataFrame(np.random.rand(4, 6), index='a b c d'.split())
expected_roles = {
'a': 'role_0',
'b': 'role_2',
'c': 'role_0',
'd': 'role_1',
}
expected_role_pct = pd.DataFrame.from_dict(
{
'a': {'role_0': 1.00, 'role_1': 0.00, 'role_2': 0.00},
'b': {'role_0': 0.24, 'role_1': 0.23, 'role_2': 0.53},
'c': {'role_0': 0.63, 'role_1': 0.00, 'role_2': 0.37},
'd': {'role_0': 0.39, 'role_1': 0.61, 'role_2': 0.00},
},
orient='index',
)
self.re = RoleExtractor(n_roles=3)
self.re.extract_role_factors(features)

# extract role factors so roles and role_percentage should be populated
n_roles = 3
role_names = {f'role_{i}' for i in range(n_roles)}
self.re = RoleExtractor(n_roles=n_roles)
self.re.extract_role_factors(self.features)
# test roles
roles = self.re.roles
self.assertSetEqual(set(roles.keys()), set(self.features.index))
self.assertTrue(set(roles.values()).issubset(role_names))
# test role_percentage
role_pct = self.re.role_percentage
self.assertDictEqual(roles, expected_roles)
pd.testing.assert_frame_equal(role_pct.round(2), expected_role_pct)
self.assertSetEqual(set(role_pct.index), set(self.features.index))
self.assertSetEqual(set(role_pct.columns), role_names)
self.assertTrue(np.allclose(role_pct.sum(axis=1).values, np.ones((role_pct.shape[0], 1))))

def test_explain(self):
with self.assertRaises(NotImplementedError):
Expand Down

0 comments on commit 1584045

Please sign in to comment.