make cluster thresholds compatible with link and dedupe
RobinL committed Nov 7, 2021
1 parent 77e22fc commit 1d34219
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions splink/cluster.py
@@ -1,9 +1,13 @@
 from pyspark.sql.dataframe import DataFrame
-from typing import Union
+from typing import Union, List
 from .model import Model
 
 from pyspark.sql.functions import expr
 
+from splink.vertically_concat import (
+    vertically_concatenate_datasets,
+)
+
 graphframes_installed = True
 try:
     from graphframes import GraphFrame
@@ -126,7 +130,7 @@ def _threshold_values_to_dict(threshold_values):
 
 
 def clusters_at_thresholds(
-    df_nodes: DataFrame,
+    df_of_dfs_nodes: Union[DataFrame, List[DataFrame]],
     df_edges: DataFrame,
     threshold_values: Union[float, list, dict],
     model: Model,
@@ -138,7 +142,10 @@ def clusters_at_thresholds(
     from a table of scored edges (scored pairwise comparisons)
     Args:
-        df_nodes (DataFrame): Dataframe of nodes (original records from which pairwise comparisons are derived)
+        df_of_dfs_nodes (Union[DataFrame, List[DataFrame]]): Dataframe or Dataframes of nodes (original records
+            from which pairwise comparisons are derived). If the link_type is `dedupe_only`, this will be a
+            single dataframe. If the link_type is `link_and_dedupe` or `link_only`, this will be a list of dataframes.
+            The provided dataframes should be the same as provided to Splink().
         df_edges (DataFrame): Dataframe of edges (pairwise record comparisons with scores)
         threshold_values (Union[float, list, dict]): Threshold values of the match probability (or score_colname)
            above which pairwise comparisons are considered to be a match. There are three options:
@@ -158,7 +165,15 @@
        DataFrame: clustered DataFrame
    """
 
-    spark = df_nodes.sql_ctx.sparkSession
+    # dfs is a list of dfs irrespective of whether input was a df or list of dfs
+    if type(df_of_dfs_nodes) == DataFrame:
+        dfs = [df_of_dfs_nodes]
+    else:
+        dfs = df_of_dfs_nodes
+
+    spark = dfs[0].sql_ctx.sparkSession
+    df_nodes = vertically_concatenate_datasets(dfs)
+
     if check_graphframes_installation:
         _check_graphframes_installation(spark)
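The substance of the change: clusters_at_thresholds now accepts either a single dataframe of nodes or a list of them, normalises the input to a list, and vertically concatenates the list into a single df_nodes before clustering. This mirrors how Splink() itself takes one dataframe for dedupe_only and a list for link_only or link_and_dedupe. As a rough illustration of the concatenation step, a minimal sketch of the idea (not splink's actual vertically_concatenate_datasets implementation, which may handle schema checks differently):

from functools import reduce
from pyspark.sql.dataframe import DataFrame

def vertically_concatenate_sketch(dfs: list) -> DataFrame:
    # Union the node dataframes by column name, so records from every
    # source dataset land in a single dataframe of nodes for clustering
    return reduce(lambda left, right: left.unionByName(right), dfs)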

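With this commit, a link job can pass clusters_at_thresholds the same list of dataframes it gave to Splink(). A hedged usage sketch: the Splink() call, get_scored_comparisons(), and the model attribute below follow the splink 2.x API as generally documented, and df_left, df_right, and spark are placeholder names:

from splink import Splink
from splink.cluster import clusters_at_thresholds

settings = {
    "link_type": "link_and_dedupe",
    # ... comparison columns, blocking rules, etc.
}

# The same list of source dataframes used to build the linker
dfs = [df_left, df_right]

linker = Splink(settings, dfs, spark)
df_edges = linker.get_scored_comparisons()

# Pass the list directly; it is concatenated internally before clustering
df_clusters = clusters_at_thresholds(
    df_of_dfs_nodes=dfs,
    df_edges=df_edges,
    threshold_values=[0.8, 0.95],
    model=linker.model,
)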