MooooCat changed the title from "API normalized_mutual_info_score may violate Symmetry." to "[Bug] API normalized_mutual_info_score may violate Symmetry." on Jan 18, 2024.
Description
The output of the normalized_mutual_info_score API may violate symmetry, although the official documentation (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html) states:
"This metric is furthermore symmetric: switching label_true with label_pred will return the same score value."
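For reference, here is why the property holds mathematically, sketched with U and V as the two labelings, P(i, j) as their empirical joint distribution, P and P' as its marginals, and H as entropy. The documentation normalizes mutual information by a generalized mean of the two entropies; the geometric mean (average_method='geometric', as in the reproducer below) is shown:

```latex
\mathrm{MI}(U,V) = \sum_{i}\sum_{j} P(i,j)\,\log\frac{P(i,j)}{P(i)\,P'(j)},
\qquad
\mathrm{NMI}(U,V) = \frac{\mathrm{MI}(U,V)}{\sqrt{H(U)\,H(V)}}
```

Swapping U and V only transposes P(i, j), which leaves the double sum unchanged, and the geometric mean is symmetric in its arguments, so NMI(U, V) = NMI(V, U) in exact arithmetic. A floating-point implementation may, however, add the terms of the sum in a different order when the arguments are swapped, which can change the last bits of the result.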
Reproduce
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.cluster import normalized_mutual_info_score

a = "guest"
b = "user"
c = "admin"
src = [a, c, b, b, c, a, c, a, c, c]
tar = [a, b, a, a, b, c, b, b, c, a]
test_nums = 100
for i in range(test_nums):
    le = LabelEncoder()
    src_list = list(set(src))
    tar_list = list(set(tar))
    fit_list = tar_list + src_list
    le.fit(fit_list)
    src_col = le.transform(src)
    tar_col = le.transform(tar)
    test1 = normalized_mutual_info_score(src_col, tar_col, average_method='geometric')
    test2 = normalized_mutual_info_score(tar_col, src_col, average_method='geometric')
    print(f"iter:{i}: test1:{test1} test2:{test2}")
    print(src_col, tar_col)
    print(tar_col, src_col)
    assert test2 == test1
I also tried changing average_method, but the error was still there.
Expected behavior
The symmetry property of normalized mutual information should hold. Otherwise, we may need to write our own implementation.
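A possible workaround, sketched under the assumption (not confirmed in this issue) that the mismatch is floating-point rounding rather than a logic error: compare the two directions with a tolerance instead of `==`, or, if bit-exact symmetry is required, average both argument orders. `symmetric_nmi` is a hypothetical helper, not part of scikit-learn; the string labels are passed directly because the score works on any discrete labels, so the LabelEncoder step is not needed here.

```python
import math

from sklearn.metrics.cluster import normalized_mutual_info_score


def symmetric_nmi(labels_a, labels_b, average_method="arithmetic"):
    """Hypothetical helper (not a scikit-learn API): average NMI over both orders.

    IEEE 754 addition is commutative, so forward + backward equals
    backward + forward bit for bit; the result therefore does not depend
    on which labeling is passed first.
    """
    forward = normalized_mutual_info_score(labels_a, labels_b, average_method=average_method)
    backward = normalized_mutual_info_score(labels_b, labels_a, average_method=average_method)
    return 0.5 * (forward + backward)


# Same labelings as the reproducer above.
src = ["guest", "admin", "user", "user", "admin", "guest", "admin", "guest", "admin", "admin"]
tar = ["guest", "user", "guest", "guest", "user", "admin", "user", "user", "admin", "guest"]

test1 = normalized_mutual_info_score(src, tar, average_method="geometric")
test2 = normalized_mutual_info_score(tar, src, average_method="geometric")

# Compare with a tolerance instead of ==; any gap left should be rounding-sized.
print(test1, test2, abs(test1 - test2))
assert math.isclose(test1, test2, rel_tol=1e-12, abs_tol=1e-12)

# The wrapper is exactly symmetric by construction.
assert symmetric_nmi(src, tar, "geometric") == symmetric_nmi(tar, src, "geometric")
```

The averaged score can differ from either single-direction value in the last bit, so the tolerance check is usually enough; the wrapper only matters when code downstream relies on exact equality.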
Context