In [4]:
import pandas as pd
import numpy as np
import unicodedata
import string
import sqlalchemy as _sql
import sqlalchemy.ext.declarative as _declarative
import sqlalchemy.orm as _orm
from soupsieve.util import lower
from sqlalchemy import create_engine
import re
from rapidfuzz.fuzz import ratio, partial_ratio
from rapidfuzz.distance import JaroWinkler, Levenshtein
from rapidfuzz.process import extractOne
from fuzzywuzzy import fuzz
from jellyfish import soundex
import itertools
import json

In [5]:
# TODO: change to real database
# DATABASE_URL = "postgresql+psycopg2://db_user:password@db:5432/inzynierka_db"
# engine = _sql.create_engine(DATABASE_URL)
# SessionLocal = _orm.sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Base = _declarative.declarative_base()

# Data reading

In [6]:
class ReadData:
    def __init__(self, path):
        """
        Initialize with the path to the CSV file.
        """
        self.path = path  # Can be replaced by database connection later
        self.data = None

    def read_data(self):
        """
        Read data from the CSV file using pandas.
        Returns a pandas DataFrame.
        """
        self.data = pd.read_csv(self.path)
        return self.data

In [7]:
class WorkflowLogger:
    def __init__(self, workflow_id, db_connection_string):
        """
        Initialize the logger for tracking workflow steps.
        :param workflow_id: ID of the workflow this logger is associated with.
        :param db_connection_string: Connection string for the database.
        """
        self.workflow_id = workflow_id
        self.db_engine = create_engine(db_connection_string)

    def fetch_raw_data(self, table_name):
        """
        Fetch raw data from a given table.
        :param table_name: Name of the raw data table to fetch data from.
        :return: pandas DataFrame containing the raw data.
        """
        query = f"SELECT file_content FROM {table_name};"
        with self.db_engine.connect() as conn:
            return pd.read_sql(query, conn)

    def log_step(self, step_name, step_type, parameters, data, previous_id=None):
        """
        Log the details of a step in the workflow.
        :param step_name: Name of the step (e.g., 'Preprocessing', 'Blocking').
        :param step_type: Type of the step (e.g., 'preprocessing', 'blocking').
        :param parameters: Dictionary of parameters used in the step.
        :param data: The pandas DataFrame processed in this step.
        :param previous_id: ID of the previous step in this workflow (if any).
        :return: ID of the logged step.
        """
        data_json = data.to_json(orient="records")

        # Deactivate the current step for the same step_type
        deactivate_query = """
        UPDATE workflow_steps
        SET currently_in_use = FALSE, valid_to = NOW()
        WHERE workflow_id = %s AND step_type = %s AND currently_in_use = TRUE;
        """
        with self.db_engine.connect() as conn:
            conn.execute(deactivate_query, (self.workflow_id, step_type))

        # Insert the new step
        insert_query = """
        INSERT INTO workflow_steps (
            name, step_type, parameters, data, workflow_id, previous_id, currently_in_use
        ) VALUES (%s, %s, %s, %s, %s, %s, TRUE)
        RETURNING id;
        """
        with self.db_engine.connect() as conn:
            result = conn.execute(
                insert_query,
                (step_name, step_type, json.dumps(parameters), data_json, self.workflow_id, previous_id)
            )
            return result.fetchone()[0]

    def fetch_latest_step(self, step_type):
        """
        Fetch the latest data and parameters for a given step type.
        :param step_type: The type of step to fetch (e.g., 'preprocessing', 'blocking').
        :return: Tuple (step_id, parameters, data DataFrame) of the latest step.
        """
        query = """
        SELECT id, parameters, data
        FROM workflow_steps
        WHERE workflow_id = %s AND step_type = %s AND currently_in_use = TRUE;
        """
        with self.db_engine.connect() as conn:
            result = conn.execute(query, (self.workflow_id, step_type)).fetchone()
            if result:
                step_id, parameters, data_json = result
                parameters = json.loads(parameters)
                data = pd.read_json(data_json)
                return step_id, parameters, data
            return None, None, None

    def fetch_step_by_id(self, step_id):
        """
        Fetch data and parameters for a specific step by ID.
        :param step_id: ID of the step to fetch.
        :return: Tuple (parameters, data DataFrame) of the step.
        """
        query = """
        SELECT parameters, data
        FROM workflow_steps
        WHERE id = %s;
        """
        with self.db_engine.connect() as conn:
            result = conn.execute(query, (step_id,)).fetchone()
            if result:
                parameters, data_json = result
                parameters = json.loads(parameters)
                data = pd.read_json(data_json)
                return parameters, data
            return None, None


In [8]:
# Initialize the logger
logger = WorkflowLogger(workflow_id=1, db_connection_string="postgresql+psycopg2://db_user:password@db:5432/inzynierka_db")

# Fetch raw data from the database
source_data = logger.fetch_raw_data("workflows")
source_data.head()

Unnamed: 0,file_content


In [5]:
path = "data/restaurant-nophone.csv"
rd = ReadData(path)
source_data = rd.read_data()
source_data

Unnamed: 0,name,address,city,cuisine
0,arnie morton's of chicago,435 s. la cienega blv.,los angeles,american
1,arnie morton's of chicago,435 s. la cienega blvd.,los angeles,steakhouses
2,art's delicatessen,12224 ventura blvd.,studio city,american
3,art's deli,12224 ventura blvd.,studio city,delis
4,hotel bel-air,701 stone canyon rd.,bel air,californian
...,...,...,...,...
859,ti couz,3108 16th st.,san francisco,french
860,trio cafe,1870 fillmore st.,san francisco,american
861,tu lan,8 sixth st.,san francisco,vietnamese
862,vicolo pizzeria,201 ivy st.,san francisco,pizza


# Data Pre-processing

- Normalization (lowercasing, removing diacritics, punctuations)
- Tokenization (TBD?)
- Drop duplicates

In [6]:
class DataPreprocessing:
    def __init__(self, data):
        """
        Initialize the DataPreprocessor with the data.
        :param data: pandas DataFrame containing the data to be processed.
        """
        self.data = data
        self.columns = None
        self.processed_data = data.copy()  # A copy of the data to avoid modifying the original

    def select_columns(self, columns):
        """
        Select the columns to apply preprocessing on.
        If 'all' is passed, all columns will be selected.
        :param columns: List of columns to be normalized, or 'all' to select all columns.
        """
        if columns[0] == 'all':
            # Select all columns in the DataFrame
            self.columns = self.data.columns
        else:
            # Otherwise, use the provided list of columns
            self.columns = columns

    def _ensure_non_numeric_string_columns(self):
        """
        Internal method to ensure that only non-numeric string columns are selected for string operations.
        """
        # Filter out non-string columns (Int64, Float64, etc.)
        self.columns = [
            col for col in self.columns if self.processed_data[col].dtype == 'object'
        ]

    def lowercase(self):
        """
        Convert text to lowercase in the selected columns.
        """
        if self.columns is None:
            raise ValueError("No columns selected for preprocessing. Use select_columns method first.")

        try:
            for col in self.columns:
                self.processed_data[col] = self.processed_data[col].str.lower()
        except Exception as e:
            print(f"Error applying lowercase operation: {e}")

    def remove_diacritics(self):
        """
        Remove diacritics from text in the selected columns.
        """
        if self.columns is None:
            raise ValueError("No columns selected for preprocessing. Use select_columns method first.")

        def _remove_diacritics(text):
            if isinstance(text, str):
                return ''.join(
                    c for c in unicodedata.normalize('NFKD', text)
                    if unicodedata.category(c) != 'Mn'
                )
            return text

        try:
            for col in self.columns:
                self.processed_data[col] = self.processed_data[col].apply(_remove_diacritics)
        except Exception as e:
            print(f"Error removing diacritics: {e}")

    def remove_punctuation(self):
        """
        Remove punctuation from text in the selected columns.
        """
        if self.columns is None:
            raise ValueError("No columns selected for preprocessing. Use select_columns method first.")

        # Ensure only string columns are processed
        self._ensure_non_numeric_string_columns()

        punctuation_pattern = f"[{re.escape(string.punctuation)}]"

        try:
            for col in self.columns:
                self.processed_data[col] = self.processed_data[col].str.replace(
                    punctuation_pattern, '', regex=True
                )
        except Exception as e:
            print(f"Error removing punctuation: {e}")

    def drop_duplicates(self):
        """
        Drop exact duplicates across all columns in the DataFrame. This is a mandatory step.
        """
        try:
            self.processed_data = self.processed_data.drop_duplicates()
        except Exception as e:
            print(f"Error dropping duplicates: {e}")

    def apply_preprocessing(self, lowercase=False, diacritics_removal=False, punctuation_removal=False):
        """
        Apply preprocessing steps based on user selection.
        The order is: lowercase -> diacritics removal -> punctuation removal -> drop exact duplicates.
        :param lowercase: If True, apply lowercasing to the selected columns.
        :param diacritics_removal: If True, remove diacritics from the selected columns.
        :param punctuation_removal: If True, remove punctuation from the selected columns.
        :return: Preprocessed pandas DataFrame.
        """
        try:
            if punctuation_removal:
                self.remove_punctuation()

            # Ensure non-numeric string columns are processed
            self._ensure_non_numeric_string_columns()

            if lowercase:
                self.lowercase()

            if diacritics_removal:
                self.remove_diacritics()

            # Drop exact duplicates as the mandatory last step
            self.drop_duplicates()

        except Exception as e:
            print(f"Error during preprocessing: {e}")

        return self.processed_data

    def get_processed_data(self):
        """
        Return the preprocessed data.
        :return: Preprocessed pandas DataFrame.
        """
        return self.processed_data


In [7]:
preprocessor = DataPreprocessing(source_data)
columns=['all']
lowercase=True 
diacritics_removal=True
punctuation_removal=True
preprocessor.select_columns(columns=columns)
preprocessor.apply_preprocessing(lowercase=lowercase, diacritics_removal=diacritics_removal, punctuation_removal=punctuation_removal)
preprocessed_data = preprocessor.get_processed_data()
# Log preprocessing step
preprocessing_id = logger.log_step(
    step_name="Preprocessing",
    step_type="preprocessing",
    parameters={"columns": columns, "lowercase": lowercase, "diacritics_removal": diacritics_removal, "punctuation_removal": punctuation_removal},
    data=preprocessed_data
)
preprocessed_data

Unnamed: 0,name,address,city,cuisine
0,arnie mortons of chicago,435 s la cienega blv,los angeles,american
1,arnie mortons of chicago,435 s la cienega blvd,los angeles,steakhouses
2,arts delicatessen,12224 ventura blvd,studio city,american
3,arts deli,12224 ventura blvd,studio city,delis
4,hotel belair,701 stone canyon rd,bel air,californian
...,...,...,...,...
859,ti couz,3108 16th st,san francisco,french
860,trio cafe,1870 fillmore st,san francisco,american
861,tu lan,8 sixth st,san francisco,vietnamese
862,vicolo pizzeria,201 ivy st,san francisco,pizza


# Block building

- Sorted Neighborhood Method (SNM)
- Standard Blocking Method
- SBM with dynamic sliding window


In [8]:
class BlockBuilding:
    def __init__(self, data, method):
        """
        Initialize the BlockBuilder with data and method.
        :param data: pandas DataFrame containing the entity data.
        :param method: The blocking method to use ('sorted_neighborhood', 'dynamic_sorted_neighborhood', or 'standard_blocking').
        """
        self.data = data
        self.method = method
        self.blocks = None
        self.num_blocks = 0

    def build_blocks(self, columns=None, window_size=None, max_window_size=None, match_threshold=None, n_letters=3, block_index=1):
        """
        Main function to build blocks using the selected method.
        :param columns: List of columns to generate BKVs or SKVs.
        :param window_size: Window size for sorted neighborhood method.
        :param max_window_size: Maximum window size for dynamic sorted neighborhood.
        :param match_threshold: Match threshold for dynamic sorted neighborhood.
        :param n_letters: Number of letters to concatenate for SKVs.
        :param block_index: Index of the block to display (optional).
        :return: A specific block based on block_index.
        """
        if columns is None:
            raise ValueError("You must specify the columns for generating keys.")
        
        if self.method == 'standard_blocking':
            self.standard_blocking(columns)
        elif self.method == 'sorted_neighborhood':
            if window_size is None:
                raise ValueError("Window size must be provided for the sorted neighborhood method.")
            self.sorted_neighborhood(columns, window_size, n_letters)
        elif self.method == 'dynamic_sorted_neighborhood':
            if max_window_size is None or match_threshold is None:
                raise ValueError("Both max_window_size and match_threshold must be provided for the dynamic sorted neighborhood method.")
            self.dynamic_sorted_neighborhood(columns, max_window_size, match_threshold, n_letters)
        else:
            raise ValueError("Invalid method. Use 'standard_blocking', 'sorted_neighborhood', or 'dynamic_sorted_neighborhood'.")

        return self.display_block(block_index)

    def standard_blocking(self, columns):
        """
        Perform standard blocking using Soundex codes for the selected columns.
        :param columns: List of columns to use for generating BKVs.
        """
        self.blocks = self.data.copy()
        
        # Generate Soundex code for each selected column and concatenate them
        self.blocks['BKV'] = self.blocks[columns].apply(
            lambda col: col.map(lambda x: soundex(x) if isinstance(x, str) else '')
        ).agg(' '.join, axis=1)
        
        # Group by BKV and assign block IDs
        self.blocks['block_id'] = self.blocks.groupby('BKV').ngroup() + 1

        # Update the number of blocks
        self.num_blocks = self.blocks['block_id'].nunique()

    def sorted_neighborhood(self, columns, window_size, n_letters):
        """
        Perform sorted neighborhood blocking using concatenated first `n` letters of selected columns as SKVs.
        :param columns: List of columns to use for generating SKVs.
        :param window_size: Size of the sliding window.
        :param n_letters: Number of letters to concatenate for SKVs.
        """
        self.blocks = self.data.copy()

        # Generate the SKV by concatenating the first `n_letters` of each column
        self.blocks['SKV'] = self.blocks[columns].apply(
            lambda col: col.map(lambda x: x[:min(n_letters + 1, len(x))] if isinstance(x, str) else '')
        ).agg(''.join, axis=1)

        # Sort by SKV
        self.blocks = self.blocks.sort_values(by='SKV').reset_index(drop=True)

        # Assign block IDs based on window size
        self.blocks['block_id'] = (self.blocks.index // window_size) + 1

        # Update the number of blocks
        self.num_blocks = self.blocks['block_id'].nunique()

    def dynamic_sorted_neighborhood(self, columns, max_window_size, match_threshold, n_letters):
        """
        Perform dynamic sorted neighborhood blocking using SKVs.
        :param columns: List of columns to use for generating SKVs.
        :param max_window_size: Maximum size of the sliding window.
        :param match_threshold: Match threshold for window expansion.
        :param n_letters: Number of letters to concatenate for SKVs.
        """
        self.blocks = self.data.copy()
    
        # Generate the SKV by concatenating the first `n_letters + 1` of each column
        self.blocks['SKV'] = self.blocks[columns].apply(
            lambda col: col.map(lambda x: x[:min(n_letters + 1, len(x))] if isinstance(x, str) else '')
        ).agg(''.join, axis=1)
        # Sort by SKV
        self.blocks = self.blocks.sort_values(by='SKV').reset_index(drop=True)
    
        # Initialize variables
        block_ids = []
        current_block_id = 1
        window_start = 0
    
        # Iterate over sorted data to assign dynamic block IDs
        while window_start < len(self.blocks):
            # Start with a single row
            window_end = window_start + 1
    
            # Expand window dynamically
            while window_end < len(self.blocks) and (window_end - window_start) < max_window_size:
                # Check similarity between SKVs of current and next record
                similarity = fuzz.ratio(
                    self.blocks['SKV'].iloc[window_start],
                    self.blocks['SKV'].iloc[window_end]
                )
                if similarity >= match_threshold * 100:  # Convert threshold to percentage
                    window_end += 1
                else:
                    break
    
            # Assign the same block ID to all rows in the current window
            block_ids.extend([current_block_id] * (window_end - window_start))
    
            # Move to the next record
            window_start = window_end
            current_block_id += 1
    
        # Assign block IDs back to the dataframe
        self.blocks['block_id'] = block_ids
    
        # Update the number of blocks
        self.num_blocks = current_block_id - 1


    def display_block(self, block_index=1):
        """
        Display a specific block by block_id.
        :param block_index: The index of the block to display.
        :return: DataFrame containing the specified block.
        """
        if self.blocks is None:
            raise ValueError("No blocks have been generated. Run block building first.")

        return self.blocks[self.blocks['block_id'] == block_index]

    def get_blocks(self):
        """
        Return all generated blocks.
        :return: DataFrame containing all blocks.
        """
        if self.blocks is None:
            raise ValueError("No blocks have been generated. Run block building first.")

        return self.blocks

In [9]:
prev_id, _, preprocessed_data = logger.fetch_latest_step("preprocessing")
method='standard_blocking'
columns=['city', 'cuisine']
block_builder = BlockBuilding(preprocessed_data, method=method)
block_builder.build_blocks(columns=columns)
all_blocks = block_builder.get_blocks()
# Log blocking step
blocking_id = logger.log_step(
    step_name="Blocking",
    step_type="blocking",
    parameters={"method": method, "columns": columns},
    data=all_blocks,
    previous_id=prev_id
)
block_builder.display_block(100)

Unnamed: 0,name,address,city,cuisine,BKV,block_id
169,indigo coastal grill,1397 n highland ave,atlanta,eclectic,345 242,100
802,flying biscuit the,1655 mclendon ave,atlanta,eclectic,345 242,100


In [10]:
block_builder = BlockBuilding(preprocessed_data, method='sorted_neighborhood')
block_builder.build_blocks(window_size=20, columns=['city', 'cuisine'], n_letters=4)
block_builder.get_blocks()
block_builder.display_block(1)

Unnamed: 0,name,address,city,cuisine,SKV,block_id
0,anthonys,3109 piedmont rd just south of peachtree rd,atlanta,american,atla amer,1
1,ritzcarlton dining room buckhead,3434 peachtree rd ne,atlanta,american new,atla amer,1
2,ritzcarlton cafe buckhead,3434 peachtree rd ne,atlanta,american new,atla amer,1
3,pleasant peasant,555 peachtree st at linden ave,atlanta,american,atla amer,1
4,panos pauls,1232 w paces ferry rd,atlanta,american new,atla amer,1
5,johnny rockets at,2970 cobb pkwy,atlanta,american,atla amer,1
6,atlanta fish market,265 pharr rd,atlanta,american,atla amer,1
7,rjs uptown kitchen wine bar,870 n highland ave,atlanta,american,atla amer,1
8,ritzcarlton cafe atlanta,181 peachtree st,atlanta,american new,atla amer,1
9,original pancake house at,4330 peachtree rd,atlanta,american,atla amer,1


In [11]:
block_builder = BlockBuilding(preprocessed_data, method='dynamic_sorted_neighborhood')
block_builder.build_blocks(max_window_size=20, match_threshold=0.7, columns=['city', 'cuisine'], n_letters=4)
all_blocks = block_builder.get_blocks()
all_blocks
# block_builder.display_block(1)

Unnamed: 0,name,address,city,cuisine,SKV,block_id
0,anthonys,3109 piedmont rd just south of peachtree rd,atlanta,american,atla amer,1
1,ritzcarlton dining room buckhead,3434 peachtree rd ne,atlanta,american new,atla amer,1
2,ritzcarlton cafe buckhead,3434 peachtree rd ne,atlanta,american new,atla amer,1
3,pleasant peasant,555 peachtree st at linden ave,atlanta,american,atla amer,1
4,panos pauls,1232 w paces ferry rd,atlanta,american new,atla amer,1
...,...,...,...,...,...,...
853,local nochol,30869 thousand oaks blvd,westlake village,health food,west heal,121
854,don antonios,1136 westwood blvd,westwood,italian,west ital,121
855,baja fresh,3345 kimber dr,westlake village,mexican,west mexi,121
856,falafel king,1059 broxton ave,westwood,middle eastern,west midd,121


# Field and Record Comparison:

* Q-gram comparison
* Jaro-Winkler
* Soundex

In [12]:
class Comparison:
    def __init__(self, data):
        """
        Initialize the Comparison class with the data.
        :param data: pandas DataFrame containing the data to be compared.
        """
        self.data = data

    @staticmethod
    def levenshtein_similarity(str1, str2):
        """
        Calculate the normalized Levenshtein similarity between two strings.
        Ensures the result is between 0 and 1.
        """
        from rapidfuzz.distance import Levenshtein
        score = Levenshtein.normalized_similarity(str1, str2)
        return max(0, min(1, score))  # Ensure the value is between 0 and 1

    @staticmethod
    def jaro_winkler_similarity(str1, str2):
        """
        Calculate the normalized Jaro-Winkler similarity between two strings.
        Ensures the result is between 0 and 1.
        """
        from rapidfuzz.distance import JaroWinkler
        score = JaroWinkler.similarity(str1, str2)
        return max(0, min(1, score))  # Ensure the value is between 0 and 1

    @staticmethod
    def qgram_similarity(str1, str2, q=2):
        """
        Calculate the Q-gram similarity between two strings.
        Ensures the result is between 0 and 1.
        """
        def generate_qgrams(s, q):
            return [s[i:i + q] for i in range(len(s) - q + 1)]

        qgrams1 = generate_qgrams(str1, q)
        qgrams2 = generate_qgrams(str2, q)
        matches = sum(1 for q in qgrams1 if q in qgrams2)
        total_qgrams = len(set(qgrams1 + qgrams2))
        score = matches / total_qgrams if total_qgrams > 0 else 0
        return max(0, min(1, score))  # Ensure the value is between 0 and 1

    def compare_within_blocks(self, block_col, column_algorithms):
        """
        Compare all possible pairs within each block for specified columns.
        :param block_col: The column name containing block IDs.
        :param column_algorithms: Dictionary where keys are column names and values are comparison functions.
        :return: DataFrame with comparison results for all pairs in each block.
        """
        if block_col not in self.data.columns:
            raise ValueError(f"Block column '{block_col}' not found in data.")
    
        for col in column_algorithms:
            if col not in self.data.columns:
                raise ValueError(f"Comparison column '{col}' not found in data.")
    
        # Store results in a list
        results = []
    
        # Group by block_id
        grouped = self.data.groupby(block_col)
    
        for block_id, group in grouped:
            # Get all possible pairs within the block
            pairs = list(itertools.combinations(group.iterrows(), 2))
    
            for (idx1, row1), (idx2, row2) in pairs:
                result = {
                    "block_id": block_id,
                    "row1": idx1,
                    "row2": idx2,
                }
    
                # Apply the specified algorithm to each column
                for col, comparison_func in column_algorithms.items():
                    result[f"{col}_similarity"] = comparison_func(row1[col], row2[col])
    
                results.append(result)
    
        result_df = pd.DataFrame(results)
        return result_df.sort_values(by=block_col).reset_index(drop=True)


In [13]:
prev_id, _, all_blocks = logger.fetch_latest_step("blocking")

comparison = Comparison(all_blocks)

# name	address	city	cuisine
column_algorithms = {
    "name": comparison.qgram_similarity,
    "address": comparison.jaro_winkler_similarity,
    "city": comparison.jaro_winkler_similarity,
    "cuisine": comparison.qgram_similarity
}

# Compare within blocks
comparison_results = comparison.compare_within_blocks(
    block_col="block_id",
    column_algorithms=column_algorithms
)

# Log comparison step
comparison_id = logger.log_step(
    step_name="Comparison",
    step_type="comparison",
    parameters={"block_col": "block_id", "column_algorithms": column_algorithms},
    data=comparison_results,
    previous_id=prev_id
)

comparison_results



Unnamed: 0,block_id,row1,row2,name_similarity,address_similarity,city_similarity,cuisine_similarity
0,1,0,1,0.027778,0.564021,1.000000,0.666667
1,1,7,17,0.000000,0.477569,1.000000,0.666667
2,1,7,18,0.023810,0.508041,1.000000,0.181818
3,1,7,19,0.027778,0.516082,1.000000,0.181818
4,1,8,9,0.179487,0.811765,1.000000,0.666667
...,...,...,...,...,...,...,...
6207,121,851,853,0.000000,0.657814,0.869118,0.000000
6208,121,851,852,0.040000,0.593957,1.000000,0.055556
6209,121,855,857,0.000000,0.516374,0.869118,0.000000
6210,121,852,856,0.000000,0.639434,0.836111,0.040000


# Classification

* for now Threshold based only - but it is trivial to 

In [14]:
class Classifier:
    def __init__(self, blocked_data, comparison_table):
        """
        Initialize the MatchClassifier with blocked data and comparison table.
        :param blocked_data: DataFrame containing blocked source data with block_ids and SKVs/BKVs.
        :param comparison_table: DataFrame containing pairwise comparisons with similarities.
        """
        self.blocked_data = blocked_data
        self.comparison_table = comparison_table

    def classify_matches(self, method='threshold_based', thresholds=None):
        """
        Classify the results based on the selected method.
        :param method: The classification method to use ('threshold_based', etc.).
        :param thresholds: Dictionary with thresholds for classification. Example:
                           {'not_match': 0.5, 'possible_match': 0.75, 'match': 1.0}
                           - Below 0.5: Not Match
                           - Between 0.5 and 0.75: Possible Match
                           - Above 0.75: Match
        :return: DataFrame with classifications added.
        """
        if method == 'threshold_based':
            if thresholds is None:
                raise ValueError("Thresholds must be provided for threshold-based classification.")
            return self._threshold_based_classification(thresholds)
        
        elif method == 'future_method':  # Placeholder for future methods
            raise NotImplementedError("Future classification methods are not implemented yet.")
        
        else:
            raise ValueError(f"Unknown classification method: {method}")

    def _threshold_based_classification(self, thresholds):
        """
        Perform threshold-based classification.
        :param thresholds: Dictionary with thresholds for classification. Example:
                           {'not_match': 0.4, 'match': 0.75}
                           - Below 0.4: Not Match
                           - Between 0.4 and 0.75: Possible Match
                           - Above 0.75: Match
        :return: DataFrame with classifications added.
        """
        # Extract details for row1 and row2 using `loc` and add them back to the comparison table
        row1_details = self.blocked_data.loc[self.comparison_table['row1']].reset_index(drop=True)
        row2_details = self.blocked_data.loc[self.comparison_table['row2']].reset_index(drop=True)
    
        # Add row1 and row2 details directly to the comparison table
        merged_data = self.comparison_table.copy()
        for col in self.blocked_data.columns:
            if col not in ['block_id', 'SKV', 'BKV']:  # Exclude metadata columns
                merged_data[f'row1_{col}'] = row1_details[col].values
                merged_data[f'row2_{col}'] = row2_details[col].values
    
        # Dynamically find similarity columns in the comparison table
        similarity_columns = [col for col in self.comparison_table.columns if col.endswith('_similarity')]
    
        # Calculate the average similarity
        merged_data['average_similarity'] = merged_data[similarity_columns].mean(axis=1)
    
        # Classify based on thresholds
        def classify(similarity):
            if similarity < thresholds['not_match']:
                return 'Not Match'
            elif thresholds['not_match'] <= similarity < thresholds['match']:
                return 'Possible Match'
            else:
                return 'Match'
    
        merged_data['classification'] = merged_data['average_similarity'].apply(classify)
    
        return merged_data




In [15]:
classifier = Classifier(all_blocks, comparison_results)
method = 'threshold_based'
possible_match = False
# thresholds = {'not_match': 0.6, 'match': 0.75}
thresholds = {'match': 0.75}

# not_match = 0.4
classified_results = classifier.classify_matches(
    method='threshold_based',
    thresholds=thresholds,
    possible_match=possible_match
)

# Log the classification step
classification_id = logger.log_step(
    step_name="Classification",
    step_type="classification",
    parameters={
        "method": method,
        "thresholds": thresholds,
        "possible_match": possible_match
    },
    data=classified_results
)

classified_results[classified_results['classification'] == 'Match']

Unnamed: 0,block_id,row1,row2,name_similarity,address_similarity,city_similarity,cuisine_similarity,row1_name,row2_name,row1_address,row2_address,row1_city,row2_city,row1_cuisine,row2_cuisine,average_similarity,classification
73,1,10,11,0.062500,0.957115,1.0,1.000000,bones,buckhead diner,3130 piedmont road,3073 piedmont road,atlanta,atlanta,american,american,0.754904,Match
83,1,11,12,1.000000,0.978947,1.0,0.666667,buckhead diner,buckhead diner,3073 piedmont road,3073 piedmont rd,atlanta,atlanta,american,american new,0.911404,Match
97,1,1,8,0.268293,0.783193,1.0,1.000000,ritzcarlton dining room buckhead,ritzcarlton cafe atlanta,3434 peachtree rd ne,181 peachtree st,atlanta,atlanta,american new,american new,0.762871,Match
113,1,2,8,0.533333,0.783193,1.0,1.000000,ritzcarlton cafe buckhead,ritzcarlton cafe atlanta,3434 peachtree rd ne,181 peachtree st,atlanta,atlanta,american new,american new,0.829132,Match
138,1,1,2,0.558824,1.000000,1.0,1.000000,ritzcarlton dining room buckhead,ritzcarlton cafe buckhead,3434 peachtree rd ne,3434 peachtree rd ne,atlanta,atlanta,american new,american new,0.889706,Match
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5746,101,768,773,0.157895,0.929412,1.0,1.000000,oritalia,vivande porta via,1915 fillmore st,2125 fillmore st,san francisco,san francisco,italian,italian,0.771827,Match
5973,103,777,782,0.150000,1.000000,1.0,1.000000,lulu restaurantbiscafe,lulu,816 folsom st,816 folsom st,san francisco,san francisco,mediterranean,mediterranean,0.787500,Match
6114,109,823,827,1.000000,1.000000,1.0,0.000000,chinois on main,chinois on main,2709 main st,2709 main st,santa monica,santa monica,french,pacific new wave,0.750000,Match
6129,109,816,818,0.250000,0.811429,1.0,1.000000,ocean park cafe,ocean avenue,3117 ocean park blvd,1401 ocean ave,santa monica,santa monica,american,american,0.765357,Match


# Evaluation


In [16]:
class Evaluation:
    def __init__(self, source_data, classified_data):
        """
        Initialize the Evaluation class with the source data and classified data.
        :param source_data: DataFrame containing the original source data.
        :param classified_data: DataFrame containing the classification results.
        """
        self.source_data = source_data.copy()
        self.classified_data = classified_data

    def evaluate(self):
        """
        Evaluate the classified results by deduplicating matches and probable matches,
        and tagging the source data with appropriate labels and probable match indices.
        :return: DataFrame with the source data, tags, and probable match indices.
        """
        # Split the classified data into matches, probable matches, and not matches
        matches = self.classified_data[self.classified_data['classification'] == 'Match']
        probable_matches = self.classified_data[self.classified_data['classification'] == 'Possible Match']
        not_matched_indices = self.source_data.index.difference(matches['row1']).difference(probable_matches['row1'])

        # Deduplicate matches: Keep only the row with the highest similarity for each `row1`
        deduplicated_matches = self._deduplicate(matches)

        # Deduplicate probable matches: Keep only the row with the highest similarity for each `row1`
        deduplicated_probable_matches = self._deduplicate(probable_matches)

        # Create a tag column in the source data
        self.source_data['tag'] = 'not matched'  # Default tag for all rows
        self.source_data.loc[deduplicated_matches['row1'], 'tag'] = 'deduplicated'
        self.source_data.loc[deduplicated_probable_matches['row1'], 'tag'] = 'probable match'

        # Add a column for probable match indices
        self.source_data['probable_match_index'] = None
        for _, row in deduplicated_probable_matches.iterrows():
            if row['row1'] in self.source_data.index:
                self.source_data.loc[row['row1'], 'probable_match_index'] = row['row2']

        # Remove rows from the source data that are duplicate matches
        rows_to_drop = matches[~matches.index.isin(deduplicated_matches.index)]['row1'].unique()
        deduplicated_source_data = self.source_data.drop(rows_to_drop, errors='ignore')

        return deduplicated_source_data

    def _deduplicate(self, data):
        """
        Deduplicate the data by keeping only the row with the highest average similarity for each `row1`.
        :param data: DataFrame containing classified data.
        :return: Deduplicated DataFrame.
        """
        # Sort by average similarity in descending order
        data = data.sort_values(by='average_similarity', ascending=False)

        # Drop duplicate entries for the same `row1`, keeping the one with the highest similarity
        deduplicated_data = data.drop_duplicates(subset=['row1'], keep='first')

        return deduplicated_data


In [18]:
evaluation = Evaluation(source_data, classified_results)
results = evaluation.evaluate()
results

Unnamed: 0,name,address,city,cuisine,tag,probable_match_index
0,arnie morton's of chicago,435 s. la cienega blv.,los angeles,american,probable match,10
2,art's delicatessen,12224 ventura blvd.,studio city,american,probable match,12
3,art's deli,12224 ventura blvd.,studio city,delis,probable match,9
4,hotel bel-air,701 stone canyon rd.,bel air,californian,probable match,17
5,bel-air hotel,701 stone canyon rd.,bel air,californian,probable match,9
...,...,...,...,...,...,...
859,ti couz,3108 16th st.,san francisco,french,not matched,
860,trio cafe,1870 fillmore st.,san francisco,american,not matched,
861,tu lan,8 sixth st.,san francisco,vietnamese,not matched,
862,vicolo pizzeria,201 ivy st.,san francisco,pizza,not matched,
