Skip to content

Commit

Permalink
PROD-86 Extend header clustering
Browse files Browse the repository at this point in the history
* WIP

* addressing import ordering error

* style linter

* Added unit tests

* addressed comments

* module level function modifications

GitOrigin-RevId: 5829c449bd0e9550d2770a0a94261049119456d3
  • Loading branch information
Jeesh96 committed May 18, 2022
1 parent e5a36a9 commit 6b6e48b
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
64 changes: 64 additions & 0 deletions src/gretel_synthetics/utils/header_clusters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import re

from functools import reduce
from typing import List
Expand All @@ -12,6 +13,53 @@

LEFT = 0
RIGHT = 1
COMPLEX_ID_PERC_UNIQ = 0.85
COMPLEX_ID_LEN = 16
TEXT_COL_LIMIT = 1500


def _is_field_complex(field: pd.Series) -> bool:
"""
Function to determine if field is a complex ID requiring special handling.
Args:
field: column values that are being evaluated to determine if field is complex
Returns:
A boolean value that signifies whether the field is complex or not.
"""

# Return False if field has no valid values

field = field.dropna()
if len(field) == 0:
return False

# Return False if field is less than 85% unique

perc_unique = field.nunique() / len(field)
if perc_unique < COMPLEX_ID_PERC_UNIQ:
return False

# Return False if field has avg len less than 16 characters

textcol = field.to_csv(header=False, index=False)
avg_len = (len(textcol) - 2 * len(field)) / len(field)

if avg_len < COMPLEX_ID_LEN:
return False

# Return False if values do not contain numbers

contains_digit = any(map(str.isdigit, textcol[0:TEXT_COL_LIMIT]))
if not contains_digit:
return False

# Return True if field contains only numbers, letters, underscore or hyphen, else return False

return bool(
re.match("^[a-zA-Z0-9\-\_]+$", textcol[0:TEXT_COL_LIMIT].replace("\n", ""))
)


def _get_correlation_matrix(df, numeric_cat: List[str] = None):
Expand Down Expand Up @@ -187,6 +235,7 @@ def cluster(
method: str = "single",
numeric_cat: List[str] = None,
plot: bool = False,
isolate_complex_field: bool = True,
) -> List[List[str]]:
"""
Given an input dataframe, extract clusters of similar headers
Expand All @@ -209,6 +258,7 @@ def cluster(
may be used to define additional categorical fields that may
not automatically get identified as such.
plot: Plot header list as a dendogram.
isolate_complex_field: Enables isolation of complex fields when clustering
Returns:
A list of lists of column names, each column name list being an identified cluster
Expand All @@ -235,6 +285,16 @@ def prepare_response(
if df.shape[1] == 1:
return prepare_response([list(df.columns)], header_prefix)

# Check for complex fields which will require their own batch
single_batch_columns = []
if isolate_complex_field:
cluster_columns = list(df.columns)
for col in cluster_columns:
if _is_field_complex(df[col]):
single_batch_columns.append(col)
cluster_columns.remove(col)
df = df.filter(cluster_columns)

# Start by getting the correlation matrix
corr_matrix = _get_correlation_matrix(df, numeric_cat)

Expand Down Expand Up @@ -276,4 +336,8 @@ def prepare_response(
plot,
)

# Re add columns that were isolated, as individual batches
for col in single_batch_columns:
col_list.append([col])

return prepare_response(col_list, header_prefix)
29 changes: 29 additions & 0 deletions tests/utils/test_header_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ def sample_df():
)


@pytest.fixture()
def sample_df_2():
    """Public dataset containing complex ID fields (e.g. 'Prospect ID')."""
    url = "https://gretel-public-website.s3.amazonaws.com/datasets/experiments/complex_id_dataset.csv"
    df = pd.read_csv(url, low_memory=False)
    return df.round(4)


def test_backward_compat(sample_df):
old_clusters = hc.cluster(sample_df)
new_clusters = hc.cluster(sample_df, average_record_length_threshold=250.0)
Expand All @@ -25,3 +34,23 @@ def test_backward_compat(sample_df):
def test_no_empty_clusters(sample_df):
clusters = hc.cluster(sample_df, average_record_length_threshold=250.0)
assert [] not in clusters


def test_no_isolation(sample_df):
    """sample_df has no fields that qualify for single-batching, so clustering
    must produce identical results regardless of the isolation flag."""
    baseline = hc.cluster(sample_df, maxsize=20, isolate_complex_field=False)
    isolated = hc.cluster(sample_df, maxsize=20)

    assert len(baseline) == len(isolated)
    assert baseline == isolated


def test_isolation(sample_df_2):
    """sample_df_2 contains a complex field ('Prospect ID') that must be split
    into its own single-column batch when isolation is enabled (the default),
    so results differ depending on the flag."""
    without_isolation = hc.cluster(sample_df_2, maxsize=20, isolate_complex_field=False)
    with_isolation = hc.cluster(sample_df_2, maxsize=20)

    assert without_isolation != with_isolation
    assert ["Prospect ID"] in with_isolation
    assert ["Prospect ID"] not in without_isolation

0 comments on commit 6b6e48b

Please sign in to comment.