Bugfix metric mutual information #117

Closed
wants to merge 65 commits into from
3dcd842
test
Z712023 Jan 4, 2024
34af206
test_v2
Z712023 Jan 4, 2024
9bf9108
no-test
Z712023 Jan 5, 2024
7fd6c75
pair_v1
Z712023 Jan 9, 2024
a6ad779
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
05223ea
remove_old_mi_sim
Z712023 Jan 10, 2024
a692a54
remove mi_sim in columns
Z712023 Jan 10, 2024
730bd9b
modify single&multi_table MISim
Z712023 Jan 10, 2024
b100dd9
modify single_mi_sim by using pair_sim instance
Z712023 Jan 10, 2024
40a19c5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
0031062
modify multi_mi_sim by using pair_sim instance
Z712023 Jan 10, 2024
88eaa2a
modify multi_mi_sim by using pair_sim instance
Z712023 Jan 10, 2024
1c4026b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
8c333dd
change_class_name_err
Z712023 Jan 10, 2024
032df09
Merge branch 'feature-metric-mutual_information' of github.com:hitsz-…
Z712023 Jan 10, 2024
0ebf11d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
844c13d
modify_paircolumn
Z712023 Jan 10, 2024
ca53f1a
mi only needs dataframe
Z712023 Jan 10, 2024
e4efe41
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
93018e0
Merge branch 'main' into feature-metric-mutual_information
MooooCat Jan 16, 2024
f3ffab7
modify based on review
Z712023 Jan 16, 2024
8583aae
test
Z712023 Jan 16, 2024
a1fb0ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
b0c4282
complete test_mi_sim
Z712023 Jan 16, 2024
f9024e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
fe7b080
modify test file
Z712023 Jan 16, 2024
df1e572
change_var_name
Z712023 Jan 16, 2024
dd08734
Update sdgx/metrics/multi_table/multitable_mi_sim.py
Z712023 Jan 16, 2024
3264704
add MULTI_TABLE_DEMO_DATA
Z712023 Jan 16, 2024
63854ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
6e2ee7b
modify comments
Z712023 Jan 16, 2024
e85571e
JSD->MISIM
Z712023 Jan 16, 2024
a0ac893
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
aafa3cc
modify base of pair_column
Z712023 Jan 16, 2024
84efc35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
c55b98d
add cls
Z712023 Jan 16, 2024
d69af4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
0ddde82
change self into cls instance
Z712023 Jan 16, 2024
89d2bca
Merge branch 'feature-metric-mutual_information' of github.com:hitsz-…
Z712023 Jan 16, 2024
52720f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
94fb07a
change cls
Z712023 Jan 16, 2024
0190305
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
8384b06
series2array
Z712023 Jan 16, 2024
ac4ad49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
c1503cd
test
Z712023 Jan 16, 2024
bb83518
Merge branch 'feature-metric-mutual_information' of github.com:hitsz-…
Z712023 Jan 16, 2024
5481df2
test
Z712023 Jan 16, 2024
c649868
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
110b63a
add label_encoder for category in mi_sim
Z712023 Jan 16, 2024
107587a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
7ffa05b
use series.array
Z712023 Jan 16, 2024
5847e35
change le_fit
Z712023 Jan 16, 2024
cee6dac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
96e2c0d
change transform type to np.array instead of list
Z712023 Jan 16, 2024
3df03a8
add astype
Z712023 Jan 16, 2024
7a8f766
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
03c9fd2
series2array
Z712023 Jan 16, 2024
d4c949a
foo
Z712023 Jan 16, 2024
fe956d5
change test_suit
Z712023 Jan 16, 2024
15cdebc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
e2a0db5
all right?
Z712023 Jan 16, 2024
0e21332
all right
Z712023 Jan 16, 2024
9244401
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
8020a9a
Merge branch 'main' into feature-metric-mutual_information
MooooCat Jan 16, 2024
0aa2575
init_bug_fix
Z712023 Jan 18, 2024
2 changes: 0 additions & 2 deletions sdgx/metrics/column/base.py
@@ -61,10 +61,8 @@ def calculate(
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame
):
"""Calculate the metric value between columns between real table and synthetic table.

Args:
real_data(pd.DataFrame or pd.Series): the real (original) data table / column.

synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column.
"""
# This method should first check the input
2 changes: 1 addition & 1 deletion sdgx/metrics/multi_table/base.py
@@ -70,7 +70,7 @@ def check_output(raw_metric_value: float):
         """Check the output value.
         Args:

-            raw_metric_value (float): the calculated raw value of the JSD metric.
+            raw_metric_value (float): the calculated raw value of the Mutual Information Similarity.
         """
         raise NotImplementedError()

71 changes: 71 additions & 0 deletions sdgx/metrics/multi_table/multitable_mi_sim.py
@@ -0,0 +1,71 @@
import numpy as np
import pandas as pd
from scipy.stats import entropy
from sklearn.metrics.cluster import normalized_mutual_info_score

from sdgx.metrics.multi_table.base import MultiTableMetric
from sdgx.metrics.pair_column.mi_sim import MISim as PairMISim


class MISim(MultiTableMetric):
    """MISim: Mutual Information Similarity

    This class calculates the Mutual Information Similarity between the target columns of real data and synthetic data.

    Currently, discrete and continuous (to be discretized) columns are supported as inputs.
    """

    def __init__(self) -> None:
        super().__init__()
        self.lower_bound = 0
        self.upper_bound = 1
        self.metric_name = "mutual_information_similarity"
        self.numerical_bins = 50

    @classmethod
    def calculate(
        cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata: dict
    ) -> float:
        """
        Calculate the Mutual Information Similarity between a real table and a synthetic table.
        Args:
            real_data (pd.DataFrame): The real data.
            synthetic_data (pd.DataFrame): The synthetic data.
            metadata (dict): The metadata that describes the data type of each column.

        Returns:
            MI_similarity (float): The metric value.
        """

        # Pass in the probability distribution arrays

        columns = synthetic_data.columns
        n = len(columns)
        mi_sim_instance = PairMISim()
        nMI_sim = np.zeros((n, n))

        for i in range(len(columns)):
            for j in range(len(columns)):
                syn_pair = pd.concat(
                    [synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1
                )
                real_pair = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1)

                nMI_sim[i][j] = mi_sim_instance.calculate(real_pair, syn_pair, metadata)

        MI_sim = np.sum(nMI_sim) / n / n
        # Check that the averaged value lies within the metric bounds
        cls.check_output(MI_sim)

        return MI_sim

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the Mutual Information Similarity.
        """
        instance = cls()
        if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound:
            raise ValueError
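The double loop in `calculate` fills an n-by-n matrix with one score per ordered column pair and then averages it. A minimal sketch of that averaging step follows, assuming a pair metric that accepts two column pairs and returns a float. Plain dicts of lists stand in for the DataFrames, and `mean_pairwise_score` and `equal_rate_agreement` are hypothetical names that are not part of sdgx. The toy metric is only there to make the example runnable; it is not the mutual information computation.

```python
from typing import Callable, Dict, List, Tuple

Table = Dict[str, List]  # hypothetical stand-in for a DataFrame: column name -> values
Pair = Tuple[List, List]


def mean_pairwise_score(
    real: Table, synthetic: Table, pair_metric: Callable[[Pair, Pair], float]
) -> float:
    """Average pair_metric over every ordered column pair (i, j)."""
    columns = list(synthetic)
    n = len(columns)
    total = 0.0
    for ci in columns:
        for cj in columns:
            # Fresh local names: `real` and `synthetic` are never rebound here,
            # so every iteration still sees the full input tables.
            real_pair = (real[ci], real[cj])
            syn_pair = (synthetic[ci], synthetic[cj])
            total += pair_metric(real_pair, syn_pair)
    return total / (n * n)


def equal_rate_agreement(real_pair: Pair, syn_pair: Pair) -> float:
    """Toy pair metric in [0, 1]: how closely the row-wise equality rates match."""

    def equal_rate(pair: Pair) -> float:
        a, b = pair
        return sum(x == y for x, y in zip(a, b)) / len(a)

    return 1.0 - abs(equal_rate(real_pair) - equal_rate(syn_pair))


real = {"a": [0, 1, 0, 1], "b": [0, 1, 1, 1]}
synthetic = {"a": [0, 1, 0, 1], "b": [0, 1, 0, 1]}
score = mean_pairwise_score(real, synthetic, equal_rate_agreement)
```

Binding each pair to its own name matters because rebinding `real_data` inside the loop would leave only two columns available from the second iteration onward.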
75 changes: 75 additions & 0 deletions sdgx/metrics/pair_column/base.py
@@ -0,0 +1,75 @@
import pandas as pd

from sdgx.log import logger


class PairMetric(object):
    """PairMetric
    Metrics used to evaluate the quality of synthetic data columns.
    """

    upper_bound = None
    lower_bound = None
    metric_name = "Correlation"

    def __init__(self) -> None:
        pass

    @classmethod
    def check_input(cls, src_col: pd.Series, tar_col: pd.Series, metadata: dict):
        """Input check for a pair of columns.
        Args:
            src_col (pd.Series): the source data column.
            tar_col (pd.Series): the target data column.
            metadata (dict): The metadata that describes the data type of each column.
        """
        # Input parameters must not be None
        if src_col is None or tar_col is None:
            raise TypeError("Input contains None.")
        # check column_names
        tar_name = tar_col.name
        src_name = src_col.name

        # check column_types
        if metadata[tar_name] != metadata[src_name]:
            raise TypeError("Type of Pair is Conflicting.")

        # if both inputs are pd.Series, return them directly
        if isinstance(src_col, pd.Series) and isinstance(tar_col, pd.Series):
            return src_col, tar_col

        # otherwise, try to convert the inputs to pd.Series
        try:
            src_col = pd.Series(src_col)
            tar_col = pd.Series(tar_col)
            return src_col, tar_col
        except Exception as e:
            logger.error(f"An error occurred while converting to pd.Series: {e}")

        return None, None

    @classmethod
    def calculate(cls, src_col: pd.Series, tar_col: pd.Series, metadata):
        """Calculate the metric value for a pair of columns from the real table and the synthetic table.

        Args:
            src_col (pd.Series): the source data column.
            tar_col (pd.Series): the target data column.
            metadata (dict): The metadata that describes the data type of each column.
        """
        # This method should first check the input
        # such as:
        src_col, tar_col = cls.check_input(src_col, tar_col, metadata)

        raise NotImplementedError()

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the Mutual Information.
        """
        raise NotImplementedError()

    pass
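The input check in `PairMetric.check_input` boils down to two rules: neither column may be `None`, and both columns must map to the same data type in the metadata. A minimal sketch of those two rules, assuming a small `Column` dataclass as a hypothetical stand-in for a named `pd.Series`; `check_pair_input` is likewise an illustrative name and not part of sdgx:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Column:
    """Hypothetical stand-in for a pd.Series: a name plus its values."""

    name: str
    values: List


def check_pair_input(
    src_col: Optional[Column], tar_col: Optional[Column], metadata: Dict[str, str]
) -> Tuple[Column, Column]:
    """Validate a column pair against the metadata and return it unchanged."""
    # Rule 1: both columns must be present.
    if src_col is None or tar_col is None:
        raise TypeError("Input contains None.")
    # Rule 2: both columns must share one data type in the metadata.
    if metadata[src_col.name] != metadata[tar_col.name]:
        raise TypeError("Type of Pair is Conflicting.")
    return src_col, tar_col


metadata = {"age": "numerical", "income": "numerical", "city": "category"}
age = Column("age", [31, 45, 27])
income = Column("income", [40, 55, 38])
city = Column("city", ["a", "b", "a"])

checked = check_pair_input(age, income, metadata)
```

Checking the arguments under the same names the function receives keeps the check from referring to variables that are not defined in its scope.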
98 changes: 98 additions & 0 deletions sdgx/metrics/pair_column/mi_sim.py
@@ -0,0 +1,98 @@
import numpy as np
import pandas as pd
from scipy.stats import entropy
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.preprocessing import LabelEncoder

from sdgx.metrics.pair_column.base import PairMetric
from sdgx.utils import time2int


class MISim(PairMetric):
    """MISim: Mutual Information Similarity

    This class calculates the Mutual Information Similarity between the target columns of real data and synthetic data.

    Currently, discrete and continuous (to be discretized) columns are supported as inputs.
    """

    def __init__(self) -> None:
        super().__init__()
        self.lower_bound = 0
        self.upper_bound = 1
        self.metric_name = "mutual_information_similarity"
        self.numerical_bins = 50

    @classmethod
    def calculate(
        cls,
        src_col: pd.Series,
        tar_col: pd.Series,
        metadata: dict,
    ) -> float:
        """
        Calculate the MI similarity for the source data column and the target data column.
        Args:
            src_col (pd.Series): the source data column.
            tar_col (pd.Series): the target data column.
            metadata (dict): The metadata that describes the data type of each column.
        Returns:
            MI_similarity (float): The metric value.
        """

        # Pass in the probability distribution arrays
        instance = cls()

        col_name = src_col.name
        data_type = metadata[col_name]

        if data_type == "numerical":
            x = np.array(src_col.array)
            src_col = pd.cut(
                x,
                instance.numerical_bins,
                labels=range(instance.numerical_bins),
            )
            x = np.array(tar_col.array)
            tar_col = pd.cut(
                x,
                instance.numerical_bins,
                labels=range(instance.numerical_bins),
            )
            src_col = src_col.to_numpy()
            tar_col = tar_col.to_numpy()

        elif data_type == "category":

            le = LabelEncoder()
            src_list = list(set(src_col.array))
            tar_list = list(set(tar_col.array))
            fit_list = tar_list + src_list
            le.fit(fit_list)

            src_col = le.transform(np.array(src_col.array))
            tar_col = le.transform(np.array(tar_col.array))

        elif data_type == "datetime":
            src_col = src_col.apply(time2int)
            tar_col = tar_col.apply(time2int)
            src_col = pd.cut(
                src_col, bins=instance.numerical_bins, labels=range(instance.numerical_bins)
            )
            tar_col = pd.cut(
                tar_col, bins=instance.numerical_bins, labels=range(instance.numerical_bins)
            )
            src_col = src_col.to_numpy()
            tar_col = tar_col.to_numpy()

        MI_sim = normalized_mutual_info_score(src_col, tar_col)
        return MI_sim

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the MI similarity.
        """
        pass
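The last step of `calculate` scores the two encoded label arrays with `normalized_mutual_info_score`. For readers who want to see what that quantity is, here is a standard-library sketch of normalized mutual information with arithmetic-mean normalization (the `average_method="arithmetic"` option of the scikit-learn function). It is an illustration under that assumption and not the scikit-learn implementation; `label_entropy` and `normalized_mutual_info` are hypothetical helper names.

```python
from collections import Counter
from math import log
from typing import Hashable, Sequence


def label_entropy(labels: Sequence[Hashable]) -> float:
    """Shannon entropy (natural log) of a sequence of discrete labels."""
    n = len(labels)
    return -sum((c / n) * log(c / n) for c in Counter(labels).values())


def normalized_mutual_info(x: Sequence[Hashable], y: Sequence[Hashable]) -> float:
    """MI(x, y) divided by the arithmetic mean of H(x) and H(y)."""
    if len(x) != len(y):
        raise ValueError("Both label sequences must have the same length.")
    n = len(x)
    count_x, count_y, count_xy = Counter(x), Counter(y), Counter(zip(x, y))
    # MI = sum over observed (a, b) of p(a, b) * log(p(a, b) / (p(a) * p(b)))
    mi = sum(
        (c / n) * log(c * n / (count_x[a] * count_y[b]))
        for (a, b), c in count_xy.items()
    )
    mean_entropy = (label_entropy(x) + label_entropy(y)) / 2
    # Degenerate case: both sequences are constant; treated as identical here.
    return mi / mean_entropy if mean_entropy > 0 else 1.0


same_up_to_renaming = normalized_mutual_info([0, 0, 1, 1], ["b", "b", "a", "a"])
independent = normalized_mutual_info([0, 0, 1, 1], [0, 1, 0, 1])
```

Because the score depends only on how the labels co-occur, the integer codes produced by binning or label encoding can be arbitrary as long as they are applied row by row to equal-length columns.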
4 changes: 2 additions & 2 deletions sdgx/metrics/single_table/base.py
@@ -55,7 +55,7 @@ def check_input(cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame):

         return None, None

-    def calculate(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame):
+    def calculate(cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame):
         """Calculate the metric value between a real table and a synthetic table.

         Args:
@@ -71,7 +71,7 @@ def check_output(raw_metric_value: float):
         """Check the output value.

         Args:
-            raw_metric_value (float): the calculated raw value of the JSD metric.
+            raw_metric_value (float): the calculated raw value of the Mutual Information Similarity.
         """
         raise NotImplementedError()

67 changes: 67 additions & 0 deletions sdgx/metrics/single_table/single_mi_sim.py
@@ -0,0 +1,67 @@
import numpy as np
import pandas as pd
from scipy.stats import entropy
from sklearn.metrics.cluster import normalized_mutual_info_score

from sdgx.metrics.pair_column.mi_sim import MISim
from sdgx.metrics.single_table.base import SingleTableMetric


class SinTabMISim(SingleTableMetric):
    """MISim: Mutual Information Similarity

    This class calculates the Mutual Information Similarity between the target columns of real data and synthetic data.

    Currently, discrete and continuous (to be discretized) columns are supported as inputs.
    """

    def __init__(self) -> None:
        super().__init__()
        self.lower_bound = 0
        self.upper_bound = 1
        self.metric_name = "mutual_information_similarity"
        self.numerical_bins = 50

    @classmethod
    def calculate(cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata) -> float:
        """
        Calculate the Mutual Information Similarity between a real table and a synthetic table.
        Args:
            real_data (pd.DataFrame): The real data.
            synthetic_data (pd.DataFrame): The synthetic data.
            metadata (dict): The metadata that describes the data type of each column.
        Returns:
            MI_similarity (float): The metric value.
        """

        # Pass in the probability distribution arrays

        columns = synthetic_data.columns
        n = len(columns)
        mi_sim_instance = MISim()
        nMI_sim = np.zeros((n, n))

        for i in range(len(columns)):
            for j in range(len(columns)):
                syn_pair = pd.concat(
                    [synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1
                )
                real_pair = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1)

                nMI_sim[i][j] = mi_sim_instance.calculate(real_pair, syn_pair, metadata)

        MI_sim = np.sum(nMI_sim) / n / n
        cls.check_output(MI_sim)

        return MI_sim

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the Mutual Information Similarity.
        """
        instance = cls()
        if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound:
            raise ValueError
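`check_output` is a classmethod, yet the bounds are set in `__init__`, so it has to build an instance just to read them. One possible alternative, sketched here under the assumption that the bounds never vary per instance, is to keep them as class attributes. `BoundedMetric` is a hypothetical class used only for illustration; it is not part of sdgx.

```python
class BoundedMetric:
    """Hypothetical metric base whose bounds live on the class itself."""

    lower_bound = 0.0
    upper_bound = 1.0

    @classmethod
    def check_output(cls, raw_metric_value: float) -> float:
        """Return the value if it lies within the bounds, else raise ValueError."""
        # The chained comparison is False for NaN, so NaN is rejected as well.
        if not (cls.lower_bound <= raw_metric_value <= cls.upper_bound):
            raise ValueError(
                f"Metric value {raw_metric_value} is outside "
                f"[{cls.lower_bound}, {cls.upper_bound}]."
            )
        return raw_metric_value


checked_value = BoundedMetric.check_output(0.5)
```

A subclass can then override `lower_bound` and `upper_bound` without touching `check_output`.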
9 changes: 8 additions & 1 deletion sdgx/utils.py
@@ -3,6 +3,7 @@
import functools
import socket
import threading
import time
import urllib.request
import warnings
from contextlib import closing
@@ -26,8 +27,8 @@
"find_free_port",
"download_multi_table_demo_data",
"get_demo_single_table",
"time2int",
]

MULTI_TABLE_DEMO_DATA = {
"rossman": {
"parent_table": "store",
@@ -99,6 +100,12 @@ def get_demo_single_table(data_dir: str | Path = "./dataset"):
return pd_obj, discrete_cols


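def time2int(datetime, form):
    """Convert a datetime string to an integer Unix timestamp, parsing it with the format `form`."""
    time_array = time.strptime(datetime, form)
    # time.mktime interprets the parsed struct_time as local time
    time_stamp = int(time.mktime(time_array))
    return time_stamp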


class Singleton(type):
"""
metaclass for singleton, thread-safe.
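`time2int` takes two required arguments, the datetime string and its format, while the datetime branch of the pair metric calls `src_col.apply(time2int)`, which supplies only the string. A minimal sketch of one way to make the helper usable element by element, assuming a default format; the value of `DEFAULT_FORMAT` is an assumption, since the diff does not say which format the data uses. With pandas, a format could also be forwarded through the `args` parameter of `Series.apply`.

```python
import time

# Assumed default; the diff does not state which datetime format the data uses.
DEFAULT_FORMAT = "%Y-%m-%d %H:%M:%S"


def time2int(datetime_str: str, form: str = DEFAULT_FORMAT) -> int:
    """Parse datetime_str with form and return local-time epoch seconds."""
    time_array = time.strptime(datetime_str, form)
    # time.mktime interprets the parsed struct_time as local time.
    return int(time.mktime(time_array))


dates = ["2024-01-04", "2024-01-18"]
# An explicit format is passed for every element of the column.
stamps = [time2int(d, "%Y-%m-%d") for d in dates]
# The assumed default format is used when none is given.
with_default = time2int("2024-01-04 12:00:00")
```

Because `time.mktime` works in local time, the absolute values depend on the machine's time zone, although the ordering of the timestamps is preserved.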